Source code for deep_time_series.model.single_shot_transformer

import numpy as np
import torch
import torch.nn as nn

from ..chunk import DecodingChunkSpec, EncodingChunkSpec, LabelChunkSpec
from ..core import ForecastingModule, Head
from ..layer import PositionalEncoding


class SingleShotTransformer(ForecastingModule):
    def __init__(
        self,
        encoding_length,
        decoding_length,
        target_names,
        nontarget_names,
        d_model,
        n_heads,
        n_layers,
        dim_feedforward=None,
        dropout_rate=0.0,
        lr=1e-3,
        optimizer: torch.optim.Optimizer = torch.optim.Adam,
        optimizer_options=None,
        loss_fn=None,
        metrics=None,
        head=None,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.encoding_length = encoding_length
        self.decoding_length = decoding_length

        if optimizer_options is None:
            self.hparams.optimizer_options = {}

        if loss_fn is None:
            loss_fn = nn.MSELoss()

        if dim_feedforward is None:
            dim_feedforward = 4 * d_model

        n_targets = len(target_names)
        n_nontargets = len(nontarget_names)
        n_features = n_nontargets + n_targets

        self.use_nontargets = n_nontargets > 0

        # Project the concatenated (targets + nontargets) features to d_model.
        self.encoder_d_matching_layer = nn.Linear(
            in_features=n_features,
            out_features=d_model,
        )

        if self.use_nontargets:
            # Project the known future covariates to d_model for the decoder.
            self.decoder_d_matching_layer = nn.Linear(
                in_features=n_nontargets,
                out_features=d_model,
            )

        self.positional_encoding = PositionalEncoding(
            d_model=d_model,
            max_len=max(encoding_length, decoding_length),
        )

        encoder_layer = nn.TransformerEncoderLayer(
            d_model,
            n_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout_rate,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, n_layers)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model,
            n_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout_rate,
            batch_first=True,
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, n_layers)

        if head is not None:
            self.head = head
        else:
            self.head = Head(
                tag='targets',
                output_module=nn.Linear(d_model, n_targets),
                loss_fn=loss_fn,
                metrics=metrics,
            )

    def encode(self, inputs):
        # L: encoding length.
        # x: (B, L, F).
        if self.use_nontargets:
            x = torch.cat(
                [inputs['encoding.targets'], inputs['encoding.nontargets']],
                dim=2,
            )
        else:
            x = inputs['encoding.targets']

        x = self.encoder_d_matching_layer(x)
        x = self.positional_encoding(x)

        # memory: (B, L, d_model).
        memory = self.encoder(x)

        return {'memory': memory}

    def decode(self, inputs):
        # L: decoding length.
        memory = inputs['memory']
        if self.use_nontargets:
            x = inputs['decoding.nontargets']
            x = self.decoder_d_matching_layer(x)
        else:
            # No future covariates: feed zeros with the decoding length.
            # The dtype and device of memory are reused automatically.
            x = memory.new_zeros(
                memory.size(0), self.decoding_length, memory.size(2)
            )

        x = self.positional_encoding(x)

        # x: (B, L, d_model).
        tgt_mask = self.generate_square_subsequent_mask(x.size(1))
        x = self.decoder(tgt=x, memory=memory, tgt_mask=tgt_mask)

        # The head maps (B, L, d_model) to (B, L, n_targets).
        self.head.reset()
        self.head(x)
        outputs = self.head.get_outputs()

        return outputs
    def generate_square_subsequent_mask(self, sz):
        r"""Generate a square mask for the sequence.

        The masked positions are filled with float('-inf'). Unmasked
        positions are filled with float(0.0).
        """
        return torch.triu(
            torch.full((sz, sz), float('-inf')), diagonal=1
        ).to(self.device)
    def make_chunk_specs(self):
        E = self.encoding_length
        D = self.decoding_length

        chunk_specs = [
            EncodingChunkSpec(
                tag='targets',
                names=self.hparams.target_names,
                range_=(0, E),
                dtype=np.float32,
            ),
            LabelChunkSpec(
                tag='targets',
                names=self.hparams.target_names,
                range_=(E, E + D),
                dtype=np.float32,
            ),
        ]

        if self.use_nontargets:
            chunk_specs += [
                EncodingChunkSpec(
                    tag='nontargets',
                    names=self.hparams.nontarget_names,
                    range_=(0, E),
                    dtype=np.float32,
                ),
                DecodingChunkSpec(
                    tag='nontargets',
                    names=self.hparams.nontarget_names,
                    range_=(E, E + D),
                    dtype=np.float32,
                ),
            ]

        return chunk_specs

    def configure_optimizers(self):
        return self.hparams.optimizer(
            self.parameters(),
            lr=self.hparams.lr,
            **self.hparams.optimizer_options,
        )
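
For reference, a minimal usage sketch follows. It only calls the constructor and make_chunk_specs defined above; the column names ('value', 'feature') and the window lengths are illustrative assumptions, not part of the module.

# Illustrative sketch: column names and window lengths below are assumptions,
# not part of the module above.
model = SingleShotTransformer(
    encoding_length=24,           # past steps fed to the encoder
    decoding_length=12,           # future steps predicted in one shot
    target_names=['value'],       # columns to forecast
    nontarget_names=['feature'],  # known future covariates; pass [] if none
    d_model=64,
    n_heads=4,
    n_layers=2,
    dropout_rate=0.1,
    lr=1e-3,
)

# Chunk specs describe how the encoding, decoding and label windows are
# sliced from a time series when a dataset is built for this model.
chunk_specs = model.make_chunk_specs()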