Source code for deep_time_series.core

from collections import defaultdict
from typing import Any

import pytorch_lightning as pl
import torch
import torch.distributions
import torch.nn as nn
from torchmetrics import Metric, MetricCollection

from .util import merge_dicts


class MetricModule(nn.Module):
    def __init__(
        self,
        tag: str,
        metrics: Metric | list[Metric] | dict[str, Metric],
    ):
        super().__init__()
        self.tag = tag

        if tag.startswith('head.'):
            self.head_tag = tag
            self.label_tag = f'label.{tag[5:]}'
        else:
            self.head_tag = f'head.{tag}'
            self.label_tag = f'label.{tag}'

        metrics = MetricCollection(metrics)
        # Keep a separate metric state per stage so that train/val/test
        # updates never mix.
        self._metric_dict = nn.ModuleDict(
            {
                '_train': metrics.clone(prefix=f'train/{tag}.'),
                '_val': metrics.clone(prefix=f'val/{tag}.'),
                '_test': metrics.clone(prefix=f'test/{tag}.'),
            }
        )

    def forward(self, outputs, batch, stage):
        return self._metric_dict[f'_{stage}'](
            outputs[self.head_tag], batch[self.label_tag]
        )

    def compute(self, stage):
        return self._metric_dict[f'_{stage}'].compute()

    def update(self, outputs, batch, stage):
        self._metric_dict[f'_{stage}'].update(
            outputs[self.head_tag], batch[self.label_tag]
        )

    def reset(self, stage):
        self._metric_dict[f'_{stage}'].reset()
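
# Illustrative sketch (not part of the original module): how MetricModule
# can be used on its own. The tag, the tensor shapes, and the
# MeanAbsoluteError metric are assumptions for this example. Predictions
# are read from outputs['head.<tag>'] and targets from batch['label.<tag>'].
def _example_metric_module():
    from torchmetrics import MeanAbsoluteError

    metric_module = MetricModule(tag='target', metrics=MeanAbsoluteError())
    outputs = {'head.target': torch.randn(8, 24, 1)}
    batch = {'label.target': torch.randn(8, 24, 1)}

    # __call__ updates the 'train' state and returns this batch's values,
    # e.g. {'train/target.MeanAbsoluteError': tensor(...)}.
    batch_metrics = metric_module(outputs, batch, stage='train')
    epoch_metrics = metric_module.compute(stage='train')
    metric_module.reset(stage='train')
    return batch_metrics, epoch_metrics
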
class BaseHead(nn.Module):
    _SPECIAL_ATTRIBUTES = ('tag', 'loss_weight', 'metrics')

    def __init__(self):
        """Base class of all Head classes."""
        super().__init__()
        self._tag = None
        self._loss_weight = None
        self._metrics = None

    def __setattr__(self, name, value):
        # Route the special attributes to their property setters; otherwise
        # nn.Module's __setattr__ would intercept Module values (such as a
        # MetricModule) before the property setter could run.
        if name in BaseHead._SPECIAL_ATTRIBUTES:
            return object.__setattr__(self, name, value)
        else:
            return super().__setattr__(name, value)

    @property
    def tag(self) -> str:
        """Tag for a head. Prefix 'head.' is added automatically."""
        if self._tag is None:
            raise NotImplementedError(f'Define {self.__class__.__name__}.tag')
        else:
            return self._tag

    @tag.setter
    def tag(self, value: str):
        if not isinstance(value, str):
            raise TypeError(f'Invalid type for "tag": {type(value)}')

        if not value.startswith('head.'):
            value = f'head.{value}'

        self._tag = value

    @property
    def metrics(self) -> MetricModule:
        if self._metrics is None:
            raise NotImplementedError(
                f'Define {self.__class__.__name__}.metrics'
            )
        else:
            return self._metrics

    @metrics.setter
    def metrics(self, value: Metric | list[Metric] | dict[str, Metric]):
        metric_module = MetricModule(tag=self.tag, metrics=value)
        self._metrics = metric_module

    @property
    def has_metrics(self):
        return self._metrics is not None

    @property
    def loss_weight(self) -> float:
        """Loss weight for loss calculations."""
        if self._loss_weight is None:
            self._loss_weight = 1.0
        return self._loss_weight

    @loss_weight.setter
    def loss_weight(self, value: float):
        if not isinstance(value, (float, int)):
            raise TypeError(f'Invalid type for "loss_weight": {type(value)}')
        elif value < 0:
            raise ValueError('loss_weight < 0')
        self._loss_weight = value

    @property
    def label_tag(self) -> str:
        """Tag of the target label.

        If the tag of the head is "head.my_tag", then label_tag is
        "label.my_tag".
        """
        return f'label.{self.tag[5:]}'

    def forward(self, inputs: Any) -> torch.Tensor:
        raise NotImplementedError(
            f'Define {self.__class__.__name__}.forward()'
        )

    def get_outputs(self) -> dict[str, Any]:
        raise NotImplementedError(
            f'Define {self.__class__.__name__}.get_outputs()'
        )

    def reset(self):
        raise NotImplementedError(f'Define {self.__class__.__name__}.reset()')

    def calculate_loss(
        self, outputs: dict[str, Any], batch: dict[str, Any]
    ) -> torch.Tensor:
        raise NotImplementedError(
            f'Define {self.__class__.__name__}.calculate_loss()'
        )
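
# Illustrative sketch (not part of the original module): a minimal concrete
# head. Subclasses of BaseHead are expected to produce predictions in
# forward(), buffer them for get_outputs(), and score them against
# batch[self.label_tag] in calculate_loss(). The class name, the MSE loss,
# and the feature sizes below are assumptions for this example.
class _ExampleRegressionHead(BaseHead):
    def __init__(self, tag: str, in_features: int, out_features: int):
        super().__init__()
        self.tag = tag  # Stored as 'head.<tag>' by the property setter.
        self.linear = nn.Linear(in_features, out_features)
        self._outputs = []

    def forward(self, inputs):
        y = self.linear(inputs)
        self._outputs.append(y)
        return y

    def get_outputs(self):
        # Concatenate the buffered per-step predictions along the time axis.
        return {self.tag: torch.cat(self._outputs, dim=1)}

    def reset(self):
        self._outputs = []

    def calculate_loss(self, outputs, batch):
        return nn.functional.mse_loss(
            outputs[self.tag], batch[self.label_tag]
        )
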
class DistributionHead(BaseHead):
    def __init__(
        self,
        tag: str,
        distribution: type[torch.distributions.Distribution],
        in_features: int,
        out_features: int,
        loss_weight: float = 1.0,
        metrics: Metric | list[Metric] | dict[str, Metric] | None = None,
    ):
        super().__init__()
        self.tag = tag
        self.loss_weight = loss_weight
        if metrics is not None:
            self.metrics = metrics

        linears = {}
        transforms = {}
        for k, v in distribution.arg_constraints.items():
            # Skip 'probs'; 'logits' is preferred.
            if k == 'probs':
                continue
            # One linear projection per distribution parameter, plus a
            # transform that maps the unconstrained outputs into the
            # parameter's support (e.g. onto the positive reals for scale).
            linears[k] = nn.Linear(in_features, out_features)
            transforms[k] = torch.distributions.transform_to(v)

        self.distribution = distribution
        self.linears = nn.ModuleDict(linears)
        self.transforms = transforms

        self._outputs = defaultdict(list)

    def forward(self, x):
        kwargs = {
            k: self.transforms[k](layer(x))
            for k, layer in self.linears.items()
        }

        m = self.distribution(**kwargs)
        y = m.sample()

        # Buffer both the distribution parameters and the sampled values so
        # that calculate_loss() can rebuild the distribution later.
        for k, v in kwargs.items():
            self._outputs[f'{self.tag}.{k}'].append(v)
        self._outputs[self.tag].append(y)

        return y

    def get_outputs(self):
        outputs = {}
        for k, v in self._outputs.items():
            # Concatenate the buffered per-step outputs along the time axis.
            outputs[k] = torch.cat(v, dim=1)
        return outputs

    def reset(self):
        self._outputs = defaultdict(list)

    def calculate_loss(self, outputs, batch) -> torch.Tensor:
        kwargs = {k: outputs[f'{self.tag}.{k}'] for k in self.linears.keys()}
        m = self.distribution(**kwargs)
        # Negative log-likelihood of the labels under the predicted
        # distribution.
        return -torch.mean(m.log_prob(batch[self.label_tag]))
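
# Illustrative sketch (not part of the original module): using
# DistributionHead with a Normal likelihood. For Normal, arg_constraints
# yields 'loc' (real) and 'scale' (positive), so the head builds one Linear
# per parameter and transform_to() maps the raw 'scale' outputs onto the
# positive reals. The tag and shapes are assumptions for this example.
def _example_distribution_head():
    head = DistributionHead(
        tag='target',
        distribution=torch.distributions.Normal,
        in_features=32,
        out_features=1,
    )
    x = torch.randn(8, 1, 32)    # Features for one decoding step.
    y = head(x)                  # A sample from Normal(loc, scale).

    # Contains 'head.target', 'head.target.loc' and 'head.target.scale'.
    outputs = head.get_outputs()
    batch = {'label.target': torch.randn(8, 1, 1)}
    nll = head.calculate_loss(outputs, batch)
    head.reset()
    return y, nll
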
class ForecastingModule(pl.LightningModule):
    _SPECIAL_ATTRIBUTES = (
        'encoding_length',
        'decoding_length',
        'head',
        'heads',
    )

    def __init__(self):
        """Base class of all forecasting modules."""
        super().__init__()
        self._encoding_length = None
        self._decoding_length = None
        self._heads = None

    def __setattr__(self, name, value):
        # Route the special attributes to their property setters instead of
        # LightningModule's attribute handling.
        if name in ForecastingModule._SPECIAL_ATTRIBUTES:
            return object.__setattr__(self, name, value)
        else:
            return super().__setattr__(name, value)

    @property
    def encoding_length(self) -> int:
        """Encoding length."""
        if self._encoding_length is None:
            raise NotImplementedError(
                f'Define {self.__class__.__name__}.encoding_length'
            )
        else:
            return self._encoding_length

    @encoding_length.setter
    def encoding_length(self, value: int):
        if not isinstance(value, int):
            raise TypeError(
                f'Invalid type for "encoding_length": {type(value)}'
            )
        elif value <= 0:
            raise ValueError('Encoding length <= 0.')
        self._encoding_length = value

    @property
    def decoding_length(self) -> int:
        """Decoding length."""
        if self._decoding_length is None:
            raise NotImplementedError(
                f'Define {self.__class__.__name__}.decoding_length'
            )
        else:
            return self._decoding_length

    @decoding_length.setter
    def decoding_length(self, value: int):
        if not isinstance(value, int):
            raise TypeError(
                f'Invalid type for "decoding_length": {type(value)}'
            )
        elif value <= 0:
            raise ValueError('Decoding length <= 0.')
        self._decoding_length = value

    @property
    def heads(self) -> list[BaseHead]:
        if self._heads is None:
            raise NotImplementedError(
                f'Define {self.__class__.__name__}.heads'
            )
        else:
            return self._heads

    @heads.setter
    def heads(self, heads: list[BaseHead]):
        if not isinstance(heads, list):
            raise TypeError(f'Invalid type for "heads": {type(heads)}')
        elif not all(isinstance(head, BaseHead) for head in heads):
            raise TypeError(
                f'Invalid type for "heads": {[type(v) for v in heads]}'
            )
        self._heads = nn.ModuleList(heads)

    @property
    def head(self) -> BaseHead:
        if self._heads is None:
            raise NotImplementedError(
                f'Define {self.__class__.__name__}.heads'
            )
        elif len(self._heads) != 1:
            raise Exception(
                '"head" is ambiguous for a multi-head module; use "heads".'
            )
        else:
            return self._heads[0]

    @head.setter
    def head(self, head: BaseHead):
        if not isinstance(head, BaseHead):
            raise TypeError(f'Invalid type for "head": {type(head)}')
        self.heads = [head]
    def encode(self, inputs: dict[str, Any]) -> dict[str, Any]:
        raise NotImplementedError(f'Define {self.__class__.__name__}.encode()')

    def decode_train(self, inputs: dict[str, Any]) -> dict[str, Any]:
        # By default, training-time decoding is identical to evaluation-time
        # decoding. Override this to change the behavior, e.g. for teacher
        # forcing.
        return self.decode_eval(inputs)

    def decode_eval(self, inputs: dict[str, Any]) -> dict[str, Any]:
        raise NotImplementedError(
            f'Define {self.__class__.__name__}.decode_eval()'
        )

    def decode(self, inputs):
        if self.training:
            return self.decode_train(inputs)
        else:
            return self.decode_eval(inputs)

    def forward(self, inputs: dict[str, Any]) -> dict[str, Any]:
        encoder_outputs = self.encode(inputs)
        decoder_inputs = merge_dicts([inputs, encoder_outputs])
        outputs = self.decode(decoder_inputs)
        return outputs

    def make_chunk_specs(self):
        pass
    def calculate_loss(
        self, outputs: dict[str, Any], batch: dict[str, Any]
    ) -> torch.Tensor:
        loss = 0
        for head in self.heads:
            loss += head.loss_weight * head.calculate_loss(outputs, batch)
        return loss

    def forward_metrics(
        self,
        outputs: dict[str, Any],
        batch: dict[str, Any],
        stage: str,
    ) -> dict[str, Any]:
        metrics = {}
        for head in self.heads:
            if not head.has_metrics:
                continue
            # This is the dictionary's update method. Don't confuse it with
            # TorchMetrics' update.
            metrics.update(
                # __call__ of the TorchMetrics collection: updates the state
                # and returns this batch's metric values.
                head.metrics(outputs=outputs, batch=batch, stage=stage)
            )
        return metrics

    def update_metrics(
        self,
        outputs: dict[str, Any],
        batch: dict[str, Any],
        stage: str,
    ) -> None:
        for head in self.heads:
            if not head.has_metrics:
                continue
            head.metrics.update(outputs=outputs, batch=batch, stage=stage)

    def compute_metrics(self, stage: str) -> dict[str, Any]:
        metrics = {}
        for head in self.heads:
            if not head.has_metrics:
                continue
            metrics.update(head.metrics.compute(stage=stage))
        return metrics

    def reset_metrics(self, stage: str) -> None:
        for head in self.heads:
            if not head.has_metrics:
                continue
            head.metrics.reset(stage=stage)
    def training_step(
        self, batch: dict[str, Any], batch_idx: int
    ) -> torch.Tensor:
        outputs = self(batch)
        loss = self.calculate_loss(outputs, batch)
        # Update the metric states and evaluate them on this batch.
        metrics = self.forward_metrics(outputs, batch, stage='train')

        self.log('train/loss', loss)
        # Log the per-batch metrics.
        self.log_dict(metrics)

        return loss

    def on_train_epoch_end(self) -> None:
        # Don't log epoch-averaged training metrics; just reset the states.
        self.reset_metrics(stage='train')

    def validation_step(
        self,
        batch: dict[str, Any],
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        outputs = self(batch)
        loss = self.calculate_loss(outputs, batch)
        # Accumulate the metric states; don't log them yet.
        self.update_metrics(outputs, batch, stage='val')
        # The loss is averaged over the epoch.
        self.log('val/loss', loss)

    def on_validation_epoch_end(self) -> None:
        metrics = self.compute_metrics(stage='val')
        self.log_dict(metrics)
        self.reset_metrics(stage='val')

    def test_step(
        self,
        batch: dict[str, Any],
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        outputs = self(batch)
        loss = self.calculate_loss(outputs, batch)
        # Accumulate the metric states; don't log them yet.
        self.update_metrics(outputs, batch, stage='test')
        # The loss is averaged over the epoch.
        self.log('test/loss', loss)

    def on_test_epoch_end(self) -> None:
        metrics = self.compute_metrics(stage='test')
        self.log_dict(metrics)
        self.reset_metrics(stage='test')
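
# Illustrative sketch (not part of the original module): a minimal concrete
# ForecastingModule with a GRU encoder and an autoregressive decoder, reusing
# the _ExampleRegressionHead sketch defined above. The input key
# 'encoding.target', the lengths, and the layer sizes are assumptions for
# this example.
class _ExampleForecaster(ForecastingModule):
    def __init__(self):
        super().__init__()
        self.encoding_length = 24
        self.decoding_length = 12
        self.head = _ExampleRegressionHead(
            tag='target', in_features=32, out_features=1
        )
        self.rnn = nn.GRU(input_size=1, hidden_size=32, batch_first=True)

    def encode(self, inputs):
        # Summarize the encoding window into the final hidden state.
        _, h = self.rnn(inputs['encoding.target'])
        return {'hidden_state': h}

    def decode_eval(self, inputs):
        self.head.reset()
        h = inputs['hidden_state']
        # Feed the last observed value back in, one step at a time; the head
        # buffers each step's prediction for get_outputs().
        y = inputs['encoding.target'][:, -1:, :]
        for _ in range(self.decoding_length):
            out, h = self.rnn(y, h)
            y = self.head(out)
        return self.head.get_outputs()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
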