Source code for deep_time_series.model.mlp

import numpy as np
import torch
import torch.nn as nn

from ..chunk import DecodingChunkSpec, EncodingChunkSpec, LabelChunkSpec
from ..core import ForecastingModule, Head
from ..layer import Unsqueesze  # sic: the layer is spelled this way in the library


class MLP(ForecastingModule):
    def __init__(
        self,
        hidden_size,
        encoding_length,
        decoding_length,
        target_names,
        nontarget_names,
        n_hidden_layers,
        activation=nn.ELU,
        dropout_rate=0.0,
        lr=1e-3,
        optimizer: torch.optim.Optimizer = torch.optim.Adam,
        optimizer_options=None,
        loss_fn=None,
        metrics=None,
        head=None,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.encoding_length = encoding_length
        self.decoding_length = decoding_length

        if optimizer_options is None:
            self.hparams.optimizer_options = {}

        if loss_fn is None:
            loss_fn = nn.MSELoss()

        n_outputs = len(target_names)
        n_features = len(nontarget_names) + n_outputs
        self.use_nontargets = n_outputs != n_features

        layers = [
            nn.Flatten(),
            nn.Linear(n_features * encoding_length, hidden_size),
            activation(),
        ]

        for _ in range(n_hidden_layers):
            if dropout_rate > 1e-6:
                layers.append(nn.Dropout(p=dropout_rate))
            layers.append(nn.Linear(hidden_size, hidden_size))
            layers.append(activation())

        if dropout_rate > 1e-6:
            layers.append(nn.Dropout(p=dropout_rate))

        layers.append(Unsqueesze(1))

        self.body = nn.Sequential(*layers)

        if head is not None:
            self.head = head
        else:
            self.head = Head(
                tag='targets',
                output_module=nn.Linear(hidden_size, n_outputs),
                loss_fn=loss_fn,
                metrics=metrics,
            )

    def encode(self, inputs):
        # (B, L, F).
        if self.use_nontargets:
            x = torch.cat(
                [inputs['encoding.targets'], inputs['encoding.nontargets']],
                dim=2,
            )
        else:
            x = inputs['encoding.targets']

        return {'x': x}

    def decode(self, inputs):
        # F = T + C.
        # (B, L, F).
        x = inputs['x']
        # (B, L, C).
        if self.use_nontargets:
            c = inputs['decoding.nontargets']

        B = x.size(0)
        EL = self.encoding_length

        self.head.reset()
        for i in range(self.decoding_length):
            # (B, 1, hidden_size).
            y = self.body(x)
            # (B, 1, n_outputs).
            y = self.head(y)

            if i + 1 == self.decoding_length:
                break

            # (B, 1, F).
            if self.use_nontargets:
                z = torch.cat([y, c[:, i:i + 1, :]], dim=2)
            else:
                z = y

            # (B, EL, F).
            x = x.view(B, EL, -1)
            # Slide the window: drop the oldest step, append the new one.
            # (B, L, F).
            x = torch.cat([x[:, 1:, :], z], dim=1)

        outputs = self.head.get_outputs()

        return outputs

    def make_chunk_specs(self):
        E = self.encoding_length
        D = self.decoding_length

        chunk_specs = [
            EncodingChunkSpec(
                tag='targets',
                names=self.hparams.target_names,
                range_=(0, E),
                dtype=np.float32,
            ),
            LabelChunkSpec(
                tag='targets',
                names=self.hparams.target_names,
                range_=(E, E + D),
                dtype=np.float32,
            ),
        ]

        if self.use_nontargets:
            chunk_specs += [
                EncodingChunkSpec(
                    tag='nontargets',
                    names=self.hparams.nontarget_names,
                    range_=(1, E + 1),
                    dtype=np.float32,
                ),
                DecodingChunkSpec(
                    tag='nontargets',
                    names=self.hparams.nontarget_names,
                    range_=(E + 1, E + D),
                    dtype=np.float32,
                ),
            ]

        return chunk_specs

    def configure_optimizers(self):
        return self.hparams.optimizer(
            self.parameters(),
            lr=self.hparams.lr,
            **self.hparams.optimizer_options,
        )
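
A minimal usage sketch follows, assuming only what the source above shows: the 'encoding.*'/'decoding.*' chunk keys read by encode() and decode(), and a Head that records each prediction it is called with. The column names and sizes are hypothetical, chosen purely for illustration.

import torch

from deep_time_series.model.mlp import MLP

# Hypothetical configuration: 1 target column, 2 covariate columns.
model = MLP(
    hidden_size=64,
    encoding_length=24,
    decoding_length=12,
    target_names=['load'],
    nontarget_names=['temperature', 'hour'],
    n_hidden_layers=2,
    dropout_rate=0.1,
)

B, E, D = 8, 24, 12
inputs = {
    # Target history over the encoding window, range (0, E): (B, E, 1).
    'encoding.targets': torch.randn(B, E, 1),
    # Covariates shifted one step ahead, range (1, E + 1): (B, E, 2).
    'encoding.nontargets': torch.randn(B, E, 2),
    # Future covariates, range (E + 1, E + D): (B, D - 1, 2).
    'decoding.nontargets': torch.randn(B, D - 1, 2),
}

state = model.encode(inputs)
state.update(inputs)  # decode() reads both 'x' and 'decoding.nontargets'.
outputs = model.decode(state)

Note that decode() runs a sliding-window autoregression: at each of the D steps, the previous prediction, concatenated with that step's covariates, replaces the oldest row of the fixed-length input window. This is why 'decoding.nontargets' carries only D - 1 steps; no covariate is consumed after the final prediction.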