TEMPORAL FUSION TRANSFORMER

Kaynak:https://twitter.com/jeande_d

Bu yazımda sizlere Google’ın yakın zamanda geliştirmiş olduğu Temporal Fusion Transformer (TFT) mimarisini açıklamak ve Python’da örnek bir veri seti üzerinden uygulama yapmak istiyorum.

Gün geçmiyor ki makine öğrenme algoritmalarının bir yenisi çıkmasın ve bir ihtiyaca merhem olmasın. Bugün sizlere tanıtımını yapacağım TFT derin öğrenme mimarisi, zaman serileri üzerine güven aralıklarıyla tahmin üretebilen bir mimari/yöntemdir. Bu mimarinin, benim için önemli olmasında iki büyük neden vardır. Bunların birincisi, güven aralıklı tahminler elde edebiliyor olmamız; ikincisi ise yorumlanabilen bir derin öğrenme algoritması olmasıdır.

Bildiğiniz üzere klasik derin öğrenme ağlarında, hangi parametrenin (özniteliğin) ne kadar önemli olduğunu (önem düzeyini) ve modelin neden bu tahmini yaptığını anlayamıyorduk. Tahminler, bizler için kapalı kutu oluyordu. Açıklanabilir Yapay Zeka (Explainable AI) ihtiyacı işte tam bu noktada devreye girdi. TFT mimarisi ile beraber, derin ağlarda eğitilmiş modelin hangi parametrenin ne derece önemli olduğunu ve Shap değerlerini öğrenebiliyoruz. Böylelikle, bir zaman serisinde modelin bu tahmini neden yaptığını ilgili kişilere (müşterilerimize ya da iş birimine) açıklayabiliyor ve modelin karar alma surecini yorumlayabiliyoruz. Gerekirse sistematik hataları düzeltebiliyoruz. Birçok insan, bu mimarinin yorumlanabilirlik gücüyle neden bu kadar önemli olduğunu anlatmaya çalıştığımı anlayacaktır.

TFT ile klasik zaman serisi tahmin yöntemleri arasındaki temel fark; TFT derin öğrenme mimarisi, farklı tiplerde verilerle eğitilebilmektedir. TFT modeli statik olarak, değişkenin zamanla değişebilen ve bilinemeyecek değişkenleri modele özellik olarak gönderebilmemize izin vermesidir. DeepAR gibi bu tarz değişkenleri kabul eden mimariler son zamanlarda yaygınlaşmaya başlasa da, bu denli detay verebiliyor olmak ve derin öğrenme mimarisinin yorumlanabilirliğe açık olması bu mimarinin, diğer mimarilerden bir adım öne çıkmasını sağlıyor. Çoklu zaman serisi ve çok adımlı tahmin yapmamıza izin vermesi de diğer bir güzel tarafı.

Temporal Fusion Transformer (TFT) mimarisi, bir tür dikkat odaklı (attention-based) derin öğrenme mimarisidir. Bu yöntemin faydasını, model çıktısının yorumlanmasında görüyor olacağız. Geleneksel derin öğrenme mimarileri, hedef değişken için çok da ilgili olmayan değişkenlere gereksiz önem gösterebilir. Dikkat temelli değişken seçimi, modelin genelleştirmesini iyileştirebilir. Transformer; dikkat mekanizmasına dayanan ve diziyi etkili bir şekilde hesaplayabilen, tekrarlayan sinir ağlarını tamamen ortadan kaldıran yeni bir kodlayıcı kod çözücü modelidir. Dikkat mekanizması; insan beyninin çalıştığı şekilde önemsiz gördüklerini filtrelerken, belirli bilgi parçalarına nasıl odaklanabileceğini taklit etmesi üzerinde çalışmaktadır. Örneğin; çok gürültülü bir ortamda insan beyni istemsiz olarak, karşılıklı konuştuğu kişinin dudaklarına odaklanıp, ne söylediğini anlamaya/tahmin etmeye çalışmaktadır.

TFT ayrıca, gereksiz bileşenleri bastırmak ve ilgili özelliklerin akıllıca seçilmesi için performans sağlayan geçit katmanı kullanır.

Kaynak: https://arxiv.org/pdf/1912.09363.pdf

Ayrıca; LSTM sequence to sequence encode / decoders kullandığı için kısa dönemli ilişkileri özetler ve LSTM blokları arasındaki ilişkiyi tanımlarken, uzun dönemli ilişkilerin yakalanması için dikkat ağlarına bırakır. LSTM encoder / decoder kullanılmasının sebebi ise, context – embedding yaratmaktır çünkü mimari, farklı tipte girdi alabilmeye izin vermektedir. Daha önce belirttiğim gibi; TFT mimarisi şu an bilinenler, gelecekte bilinecek olanlar ve geçmişten bildiklerimiz gibi çeşitli girdiler alabiliyordu. Bilinen girdiler encoder’da işlenirken, gelecekte bilinemeyecek olanlar decoder’da işlenir (encoder; çok boyutlu veriyi az boyuta, decoder ise az boyutlu verinin boyutunu arttırmak için kullanılır).

 

UYGULAMA

Bu çalışmada, kütüphane geliştiricilerin kullandıkları veri setini kullanacağız. Bu veri seti Kaggle’daki Stallion verisidir. Açıkçası bu veri setini değil, başka bir veri seti kullanmayı arzu ediyordum; fakat örnek olarak kullandıkları veri setinde TFT mimarisinin artı yönleri çok güzel açıklandığından bu veri seti ve kodlar üzerinden uygulamaya devam etmeye karar verdim. Ayrıca, çalışmanın kodlarına bu link üzerinden erişebilirsiniz. Bu çalışmada, PyTorch’un kendi sitesinde örnek olarak gösterdikleri kodlar kullanılmıştır.

import os
import copy
import torch
import warnings
import numpy as np
import pandas as pd
from pathlib import Path
import pytorch_lightning as pl
from pytorch_forecasting.data import GroupNormalizer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_forecasting.data.examples import get_stallion_data
from pytorch_forecasting.metrics import SMAPE, PoissonLoss, QuantileLoss
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_forecasting import Baseline, TemporalFusionTransformer, 
TimeSeriesDataSet
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import 
optimize_hyperparameters
warnings.filterwarnings("ignore")

Gerekli kütüphaneleri çağırdıktan sonra verimizi yükleyelim ve göstermiş oldukları manipülasyonları yapalım. Zaman serileri tahmininde sizlerin de bildiği üzere, özel takvim günleri çok önemlidir. Örneğin; babalar gününde saat / takım elbise, sevgililer gününde ise çiçek satışlarının artması gibi. Talep tahminleme modeli geliştirilirken, veri bilimciler özel takvim günlerini mutlaka modellerinde yer vermelidir. Bu uygulamada da SKU bazında talep tahminleme yapılacağı için ilgili olan özel takvim verileri modele belirtilmiştir. Daha önce de açıkladığım gibi, TFT modeli bu tarz verileri kullanabilmektedir.

data = get_stallion_data()
# add time index
data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month
data["time_idx"] -= data["time_idx"].min()
# add additional features
data["month"] = data.date.dt.month.astype(str).astype("category") # categories have be strings
data["log_volume"] = np.log(data.volume + 1e-8)
data["avg_volume_by_sku"] = data.groupby(["time_idx", "sku"], 
observed=True).volume.transform("mean")
data["avg_volume_by_agency"] = data.groupby(["time_idx", "agency"], 
observed=True).volume.transform("mean")
# we want to encode special days as one variable and thus need to first reverse one-hot encoding
special_days = [
 "easter_day",
 "good_friday",
 "new_year",
 "christmas",
 "labor_day",
 "independence_day",
 "revolution_day_memorial",
 "regional_games",
 "fifa_u_17_world_cup",
 "football_gold_cup",
 "beer_capital",
 "music_fest",
]
data[special_days] = data[special_days].apply(lambda x: x.map({0: "-", 1: x.name})).astype("category")
data.sample(10, random_state=521)

Yukarıdaki manipülasyonlardan da görüldüğü üzere; veri setinde zaman (aylık), sku ve agency bazında hacmin ortalaması alınmaktadır. Fakat şu unutulmamalıdır ki; gelecekte veri setini eğitim ve validasyon olarak ikiye böldüğümüzde, validasyonda yazan ortalama değerlerini bilemiyor olacağız. Tam bu noktada TFT mimarisinin en beğendiğim özelliği devreye giriyor. Biz bu mimaride, gelecekte (tahmin yapacağımız anda) bu bilgiyi bilemiyor olacağız. Bu bilgiyi kullanabilmek için TFT mimarisi ‘time_varying_unknown_reals’ argümanı almaktadır. Yani, zamanla değişen ve gelecekte bilinmeyecek bilgiler olarak bu değişkeni modelimize belirtmiş oluyoruz.

max_prediction_length = 6
max_encoder_length = 24
training_cutoff = data["time_idx"].max() - max_prediction_length
training = TimeSeriesDataSet(
 data[lambda x: x.time_idx <= training_cutoff],
 time_idx="time_idx",
 target="volume",
 group_ids=["agency", "sku"],
 min_encoder_length=max_encoder_length // 2, # keep encoder length long 
(as it is in the validation set)
 max_encoder_length=max_encoder_length,
 min_prediction_length=1,
 max_prediction_length=max_prediction_length,
 static_categoricals=["agency", "sku"],
 static_reals=["avg_population_2017",
"avg_yearly_household_income_2017"],
 time_varying_known_categoricals=["special_days", "month"],
 variable_groups={"special_days": special_days}, # group of categorical 
variables can be treated as one variable
 time_varying_known_reals=["time_idx", "price_regular",
"discount_in_percent"],
 time_varying_unknown_categoricals=[],
 time_varying_unknown_reals=[
 "volume",
 "log_volume",
 "industry_volume",
 "soda_volume",
 "avg_max_temp",
 "avg_volume_by_agency",
 "avg_volume_by_sku",
 ],
 target_normalizer=GroupNormalizer(
 groups=["agency", "sku"], transformation="softplus"
 ), # use softplus and normalize by group
 add_relative_time_idx=True,
 add_target_scales=True,
 add_encoder_length=True,
)
# create validation set (predict=True) which means to predict the last 
max_prediction_length points in time
# for each series
validation = TimeSeriesDataSet.from_dataset(training, data, predict=True,
stop_randomization=True)
# create dataloaders for model
batch_size = 128 # set this between 32 to 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size,
num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size
* 10, num_workers=0)

Yukarıda gösterildiği gibi; modelin gelecekte bilemeyeceği değişkenleri (avg_volume_by_sku), bileceğimiz statik değerleri (avg_population_2017), statik kategorik değişkenleri (gelecekte bileceğimiz özel takvim verisi) ve zamanla değişebilecek ama T anında bildiğimiz (time_idx, price_regular, discount_in_percent) değişkenleri nasıl öğrenmesi / görmesi gerektiğini belirtiyoruz. Kodda gösterildiği üzere, 6 aylık bir tahmin yapmamız gerektiği belirtilmektedir. 6 aylık tahmin için ise en fazla son 24 ayın bilgilerini, LSTM encoder’a göndermesi gerektiğini modele belirtiyoruz ve son 6 ayı validasyon veri seti olarak training_cutoff olarak belirtiyoruz. Tahminlerimizi, ‘agency’ ve ‘sku’ kırılımında istediğimiz için veri setini gruplandırması gerektiğini belirtiyoruz.

 

ÖNERİLEN ÖĞRENME ORANINI BULMA

# configure network and trainer
pl.seed_everything(42)
trainer = pl.Trainer(
 gpus=0,
 # clipping gradients is a hyperparameter and important to prevent 
divergance
 # of the gradient for recurrent neural networks
 gradient_clip_val=0.1,
)
tft = TemporalFusionTransformer.from_dataset(
 training,
 # not meaningful for finding the learning rate but otherwise very 
important
 learning_rate=0.03,
 hidden_size=16, # most important hyperparameter apart from learning 
rate
 # number of attention heads. Set to up to 4 for large datasets
 attention_head_size=1,
 dropout=0.1, # between 0.1 and 0.3 are good values
 hidden_continuous_size=8, # set to <= hidden_size
 output_size=7, # 7 quantiles by default
 loss=QuantileLoss(),
 # reduce learning rate if no improvement in validation loss after x 
epochs
 reduce_on_plateau_patience=4,
)
# find optimal learning rate
res = trainer.tuner.lr_find(
 tft,
 train_dataloaders=train_dataloader,
 val_dataloaders=val_dataloader,
 max_lr=10.0,
 min_lr=1e-6,
)
print(f"suggested learning rate: {res.suggestion()}")
fig = res.plot(show=True, suggest=True)
fig.show()

Bu kodun çıktısıyla önerilen öğrenme oranını bulabilirsiniz.

 

MODEL EĞİTİMİ

# configure network and trainer
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4,
patience=10, verbose=False, mode="min")
lr_logger = LearningRateMonitor() # log the learning rate
logger = TensorBoardLogger("lightning_logs") # logging results to a 
tensorboard
trainer = pl.Trainer(
 max_epochs=30,
 gpus=0,
 weights_summary="top",
 gradient_clip_val=0.1,
 limit_train_batches=30, # coment in for training, running valiation 
every 30 batches
 # fast_dev_run=True, # comment in to check that networkor dataset has 
no serious bugs
 callbacks=[lr_logger, early_stop_callback],
 logger=logger,
)
tft = TemporalFusionTransformer.from_dataset(
 training,
 learning_rate=0.03,
 hidden_size=16,
 attention_head_size=1,
 dropout=0.1,
 hidden_continuous_size=8,
 output_size=7, # 7 quantiles by default
 loss=QuantileLoss(),
 log_interval=10, # uncomment for learning rate finder and otherwise, 
e.g. to 10 for logging every 10 batches
 reduce_on_plateau_patience=4,
)
# fit network
trainer.fit(
 tft,
 train_dataloaders=train_dataloader,
 val_dataloaders=val_dataloader,
)

Yukarıdaki çıktıda mimarimizin özetini görebilmekteyiz.

 

EN İYİ PARAMETRELERİ BULMA

import pickle
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import
optimize_hyperparameters
# create study
study = optimize_hyperparameters(
 train_dataloader,
 val_dataloader,
 model_path="optuna_test",
 n_trials=200,
 max_epochs=50,
 gradient_clip_val_range=(0.01, 1.0),
 hidden_size_range=(8, 128),
 hidden_continuous_size_range=(8, 128),
 attention_head_size_range=(1, 4),
 learning_rate_range=(0.001, 0.1),
 dropout_range=(0.1, 0.3),
 trainer_kwargs=dict(limit_train_batches=30),
 reduce_on_plateau_patience=4,
 use_learning_rate_finder=False, # use Optuna to find ideal learning 
rate or use in-built learning rate finder
)
# save study results - also we can resume tuning at a later point in time
with open("test_study.pkl", "wb") as fout:
 pickle.dump(study, fout)
# show best hyperparameters
print(study.best_trial.params)

Kodları ile en uygun parametreleri bulabiliriz. Şimdi, validasyon setimizin hata metriğini ve validasyon setine yaptığımız güven aralıklı tahminlerimizi görelim.

# load the best model according to the validation loss
# (given that we use early stopping, this is not necessarily the last epoch)
best_model_path = trainer.checkpoint_callback.best_model_path
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
# calcualte mean absolute error on validation set
actuals = torch.cat([y[0] for x, y in iter(val_dataloader)])
predictions = best_tft.predict(val_dataloader)
(actuals - predictions).abs().mean()
# Ortalama mutlak hatamız 254 olduğunu görüyoruz.
# raw predictions are a dictionary from which all kind of information 
including quantiles can be extracted
raw_predictions, x = best_tft.predict(val_dataloader, mode="raw",
return_x=True)
for idx in range(10): # plot 10 examples
 best_tft.plot_prediction(x, raw_predictions, idx=idx,
add_loss_to_title=True);

Validasyon setinden alınan 3 farklı SKU’nun tahminleri şekildeki gibi olmuştur.

 

PARAMETRE ÖNEM GRAFİĞİ

interpretation = best_tft.interpret_output(raw_predictions, reduction="sum")
best_tft.plot_interpretation(interpretation)

Parametre önem grafiğini şekildeki gibi görebilmekteyiz.

 

SONUÇ

TFT mimarisi, biz veri bilimciler için acı noktalarımıza dokunan faydalı bir mimari olduğu çok net olarak görülmektedir. Derin öğrenme mimarileriyle geliştirdiğimiz tahminleri yorumlayabilmek ve özniteliklerimizin önem düzeyleri hakkında bilgi sahibi olmak büyük bir lüks. En azından şimdilik.

Dünya literatürünü incelediğimizde; açıklanabilir yapay zekanın gün geçtikçe öneminin artması ve istatistik teorileri tabanlı algoritmaların ön plana çıkması, ilerideki hayatımızda veri okur-yazarlığımızı hep üst seviyede tutmamız gerektiğinin zorunlu olduğunu anlıyoruz. Kısa da olsa tanıtımını yaptığım bu mimarinin, günlük iş hayatınızda faydalı olması dileğiyle.

Saygılarımla,
Varsayımlarınızın sağlanması dileğiyle,
Veri ile kalın, Hoşça kalın…

Utku Kubilay ÇINAR – Data Scientist @ Alghanim Industries

 

Kapak Görseli: Suzanne D. Williams on Unsplash

 

Yazar Hakkında
Toplam 21 yazı
Utku Kubilay ÇINAR
Utku Kubilay ÇINAR
YTÜ - Doktora - Veri Bilimi - Alghanim Industries - Data Scientist
Yorumlar (Yorum yapılmamış)

Bir yanıt yazın

E-posta adresiniz yayınlanmayacak. Gerekli alanlar * ile işaretlenmişlerdir

×

Bir Şeyler Ara