from collections import defaultdict
from typing import Any, Callable
import pytorch_lightning as pl
import torch
import torch.distributions
import torch.nn as nn
from torchmetrics import Metric, MetricCollection
from .util import merge_dicts
class MetricModule(nn.Module):
    """Per-stage metric container used by heads.

    It is a wrapper of ``MetricCollection`` that keeps one separate clone of
    the given metrics for each of the training, validation and test stages.
    It is used by ``BaseHead``; in general it is not used directly by users.

    Attributes set by ``__init__``:
        tag: The tag exactly as passed in.
        head_tag: The tag with the ``'head.'`` prefix; key read from the
            ``outputs`` dictionary.
        label_tag: ``'label.'`` followed by the tag without its ``'head.'``
            prefix; key read from the ``batch`` dictionary.
    """

    # Stages that get their own clone of the metric collection.
    _STAGES = ('train', 'val', 'test')

    def __init__(
        self, tag: str, metrics: Metric | list[Metric] | dict[str, Metric]
    ):
        """Build the per-stage metric collections.

        Args:
            tag: Tag of the head, with or without the ``'head.'`` prefix.
            metrics: Metric(s) wrapped into a ``MetricCollection`` and cloned
                once per stage with the prefix ``'<stage>/<tag>.'``.
        """
        super().__init__()
        self.tag = tag
        # Accept the tag with or without the 'head.' prefix and derive both
        # lookup keys from the bare name.
        bare_tag = tag.removeprefix('head.')
        self.head_tag = f'head.{bare_tag}'
        self.label_tag = f'label.{bare_tag}'
        metrics = MetricCollection(metrics)
        # NOTE(review): the keys carry a leading underscore, presumably so
        # they do not clash with existing nn.Module attributes such as
        # `train` -- confirm before renaming them.
        self._metric_dict = nn.ModuleDict(
            {
                f'_{stage}': metrics.clone(prefix=f'{stage}/{tag}.')
                for stage in self._STAGES
            }
        )

    def _collection(self, stage):
        """Return the metric collection stored for ``stage``."""
        return self._metric_dict[f'_{stage}']

    def forward(self, outputs, batch, stage):
        """Call the collection of ``stage`` on the head output and its label.

        Args:
            outputs: Dictionary holding the head output under ``head_tag``.
            batch: Dictionary holding the target under ``label_tag``.
            stage: One of ``'train'``, ``'val'`` or ``'test'``.

        Returns:
            Whatever calling the underlying metric collection returns.
        """
        return self._collection(stage)(
            outputs[self.head_tag], batch[self.label_tag]
        )

    def compute(self, stage):
        """Return ``compute()`` of the collection stored for ``stage``."""
        return self._collection(stage).compute()

    def update(self, outputs, batch, stage):
        """Update the collection of ``stage`` with the head output and label."""
        self._collection(stage).update(
            outputs[self.head_tag], batch[self.label_tag]
        )

    def reset(self, stage):
        """Reset the collection stored for ``stage``."""
        self._collection(stage).reset()
class BaseHead(nn.Module):
    """Base class of all Head classes.

    A head is a module that takes the output of the last layer of the body
    and produces the output of the model. It also calculates the loss and
    metrics.

    Subclasses have to define the following attributes:

    * tag: str
        Tag for a head. Prefix 'head.' is added automatically.
    * loss_weight: float
        Loss weight for loss calculations. Reads as 1.0 when never set.
    * metrics: Metric | list[Metric] | dict[str, Metric]
        Metrics for training, validation and test stage. ``tag`` must be
        assigned first because the metrics setter reads it. Use
        ``has_metrics`` to check whether metrics were assigned.

    Subclasses have to define the following methods:

    * forward(self, inputs: Any) -> torch.Tensor
        Forward method of the head.
    * get_outputs(self) -> dict[str, Any]
        Get outputs of the head. It produces a dictionary of outputs from
        internal state of the head.
    * reset(self)
        Reset the internal states of the head.
    * calculate_loss(self, outputs: dict[str, Any], batch: dict[str, Any]) -> torch.Tensor
        Calculate loss of the head.
    """

    # Names that are routed to the properties below by __setattr__.
    _SPECIAL_ATTRIBUTES = ('tag', 'loss_weight', 'metrics')

    def __init__(self):
        """Initialise the head with no tag, loss weight or metrics set."""
        super().__init__()
        self._tag = None
        self._loss_weight = None
        self._metrics = None

    def __setattr__(self, name, value):
        # The special attributes are assigned through object.__setattr__ so
        # that the property setters defined on this class run.
        # NOTE(review): presumably required because nn.Module.__setattr__
        # would register module-like values (e.g. a Metric) directly instead
        # of invoking the property setter -- confirm against PyTorch.
        if name in BaseHead._SPECIAL_ATTRIBUTES:
            return object.__setattr__(self, name, value)
        return super().__setattr__(name, value)

    @property
    def tag(self) -> str:
        """Tag for a head. Prefix 'head.' is added automatically.

        Raises:
            NotImplementedError: If the tag has not been assigned yet.
        """
        if self._tag is None:
            raise NotImplementedError(f'Define {self.__class__.__name__}.tag')
        return self._tag

    @tag.setter
    def tag(self, value: str):
        if not isinstance(value, str):
            raise TypeError(f'Invalid type for "tag": {type(value)}')
        if not value.startswith('head.'):
            value = f'head.{value}'
        self._tag = value

    @property
    def metrics(self) -> MetricModule:
        """Metrics of the head, wrapped in a ``MetricModule``.

        Raises:
            NotImplementedError: If no metrics have been assigned.
        """
        if self._metrics is None:
            raise NotImplementedError(
                f'Define {self.__class__.__name__}.metrics'
            )
        return self._metrics

    @metrics.setter
    def metrics(self, value: Metric | list[Metric] | dict[str, Metric]):
        # Reads self.tag, so assigning metrics before the tag raises
        # NotImplementedError from the tag property.
        self._metrics = MetricModule(tag=self.tag, metrics=value)

    @property
    def has_metrics(self):
        """Whether metrics have been assigned to this head."""
        return self._metrics is not None

    @property
    def loss_weight(self) -> float:
        """Loss weight for loss calculations. Falls back to 1.0 if unset."""
        if self._loss_weight is None:
            self._loss_weight = 1.0
        return self._loss_weight

    @loss_weight.setter
    def loss_weight(self, value: float):
        if not isinstance(value, (float, int)):
            raise TypeError(f'Invalid type for "loss_weight": {type(value)}')
        if value < 0:
            raise ValueError('loss_weight < 0')
        self._loss_weight = value

    @property
    def label_tag(self) -> str:
        """Tag of target label. If the tag of head is "head.my_tag" then
        label_tag is "label.my_tag".
        """
        # The tag setter guarantees the 'head.' prefix is present.
        return f"label.{self.tag.removeprefix('head.')}"

    def forward(self, inputs: Any) -> torch.Tensor:
        """Forward method of the head. Must be overridden."""
        raise NotImplementedError(f'Define {self.__class__.__name__}.forward()')

    def get_outputs(self) -> dict[str, Any]:
        """Build the output dictionary from internal state. Must be overridden."""
        raise NotImplementedError(
            f'Define {self.__class__.__name__}.get_outputs()'
        )

    def reset(self):
        """Reset the internal states of the head. Must be overridden."""
        raise NotImplementedError(f'Define {self.__class__.__name__}.reset()')

    def calculate_loss(
        self, outputs: dict[str, Any], batch: dict[str, Any]
    ) -> torch.Tensor:
        """Calculate loss of the head. Must be overridden."""
        raise NotImplementedError(
            f'Define {self.__class__.__name__}.calculate_loss()'
        )
class Head(BaseHead):
    """Head that applies an output module and scores it with a loss function.

    Every call to ``forward`` stores its result; ``get_outputs`` joins the
    stored results along dimension 1 and ``reset`` discards them.
    """

    def __init__(
        self,
        tag: str,
        output_module: nn.Module,
        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        loss_weight: float = 1.0,
        metrics: Metric | list[Metric] | dict[str, Metric] | None = None,
    ):
        """
        Args:
            tag: Tag for the head. Prefix 'head.' is added automatically.
            output_module: Module applied to the inputs in ``forward``.
            loss_fn: Called as ``loss_fn(output, label)`` in
                ``calculate_loss``.
            loss_weight: Loss weight for loss calculations.
            metrics: Optional metrics; left unset when ``None``.
        """
        super().__init__()
        # The tag has to be assigned before the metrics: the metrics setter
        # of the base class reads it.
        self.tag = tag
        self.output_module = output_module
        self.loss_fn = loss_fn
        self.loss_weight = loss_weight
        if metrics is not None:
            self.metrics = metrics
        # Results of forward() calls since the last reset().
        self._ys = []

    def forward(self, inputs: Any) -> torch.Tensor:
        """Apply the output module, remember the result and return it."""
        y = self.output_module(inputs)
        self._ys.append(y)
        return y

    def get_outputs(self):
        """Return ``{tag: outputs}`` with the stored results joined on dim 1.

        ``torch.cat`` is applied to the stored list, so ``forward`` has to be
        called at least once after the last ``reset``.
        """
        return {self.tag: torch.cat(self._ys, dim=1)}

    def reset(self):
        """Discard the stored results of previous ``forward`` calls."""
        self._ys = []

    def calculate_loss(
        self, outputs: dict[str, Any], batch: dict[str, Any]
    ) -> torch.Tensor:
        """Apply ``loss_fn`` to ``outputs[tag]`` and ``batch[label_tag]``."""
        return self.loss_fn(outputs[self.tag], batch[self.label_tag])
class DistributionHead(BaseHead):
    """Head that parameterises a distribution and samples its output.

    One linear layer is created per constrained argument of the distribution
    class (skipping ``'probs'``). ``forward`` maps each layer output into the
    argument's constraint, builds the distribution and returns a sample. The
    loss is the mean negative log-likelihood of the label.
    """

    def __init__(
        self,
        tag: str,
        distribution: type[torch.distributions.Distribution],
        in_features: int,
        out_features: int,
        loss_weight: float = 1.0,
        metrics: Metric | list[Metric] | dict[str, Metric] | None = None,
    ):
        """
        Args:
            tag: Tag for the head. Prefix 'head.' is added automatically.
            distribution: Distribution class; it is called with keyword
                arguments named after its ``arg_constraints`` keys.
            in_features: Input size of every linear layer.
            out_features: Output size of every linear layer.
            loss_weight: Loss weight for loss calculations.
            metrics: Optional metrics; left unset when ``None``.
        """
        super().__init__()
        # The tag has to be assigned before the metrics: the metrics setter
        # of the base class reads it.
        self.tag = tag
        self.loss_weight = loss_weight
        if metrics is not None:
            self.metrics = metrics
        linears = {}
        transforms = {}
        for arg_name, constraint in distribution.arg_constraints.items():
            # 'logits' is preferred, so 'probs' gets no layer of its own.
            if arg_name == 'probs':
                continue
            linears[arg_name] = nn.Linear(in_features, out_features)
            transforms[arg_name] = torch.distributions.transform_to(constraint)
        self.distribution = distribution
        self.linears = nn.ModuleDict(linears)
        self.transforms = transforms
        # Maps an output key to the values collected since the last reset().
        self._outputs = defaultdict(list)

    def forward(self, x):
        """Build the distribution from ``x`` and return a sample of it.

        The transformed distribution arguments are stored under
        ``'<tag>.<argument>'`` and the sample under ``'<tag>'``.
        """
        kwargs = {
            k: self.transforms[k](layer(x)) for k, layer in self.linears.items()
        }
        m = self.distribution(**kwargs)
        y = m.sample()
        for k, v in kwargs.items():
            self._outputs[f'{self.tag}.{k}'].append(v)
        self._outputs[self.tag].append(y)
        return y

    def get_outputs(self):
        """Return every stored output, each joined along dimension 1."""
        return {k: torch.cat(v, dim=1) for k, v in self._outputs.items()}

    def reset(self):
        """Discard everything stored by previous ``forward`` calls."""
        self._outputs = defaultdict(list)

    def calculate_loss(self, outputs, batch) -> torch.Tensor:
        """Return the mean negative log-probability of the label.

        The distribution is rebuilt from the ``'<tag>.<argument>'`` entries of
        ``outputs`` and evaluated at ``batch[label_tag]``.
        """
        kwargs = {k: outputs[f'{self.tag}.{k}'] for k in self.linears}
        m = self.distribution(**kwargs)
        return -torch.mean(m.log_prob(batch[self.label_tag]))
class ForecastingModule(pl.LightningModule):
    """Base class of all forecasting modules.

    Subclasses assign ``encoding_length``, ``decoding_length`` and either
    ``heads`` (a list of ``BaseHead``) or ``head`` (a single ``BaseHead``),
    and implement ``encode`` and ``decode_eval``. ``decode_train`` falls back
    to ``decode_eval`` unless it is overridden.
    """

    # Names that are routed to the properties below by __setattr__.
    _SPECIAL_ATTRIBUTES = (
        'encoding_length',
        'decoding_length',
        'head',
        'heads',
    )

    def __init__(self):
        """Initialise the module with no lengths or heads set."""
        super().__init__()
        self._encoding_length = None
        self._decoding_length = None
        self._heads = None

    def __setattr__(self, name, value):
        # The special attributes are assigned through object.__setattr__ so
        # that the validating property setters defined on this class run.
        if name in ForecastingModule._SPECIAL_ATTRIBUTES:
            return object.__setattr__(self, name, value)
        return super().__setattr__(name, value)

    @property
    def encoding_length(self) -> int:
        """Encoding length.

        Raises:
            NotImplementedError: If it has not been assigned yet.
        """
        if self._encoding_length is None:
            raise NotImplementedError(
                f'Define {self.__class__.__name__}.encoding_length'
            )
        return self._encoding_length

    @encoding_length.setter
    def encoding_length(self, value: int):
        if not isinstance(value, int):
            raise TypeError(
                f'Invalid type for "encoding_length": {type(value)}'
            )
        if value <= 0:
            raise ValueError('Encoding length <= 0.')
        self._encoding_length = value

    @property
    def decoding_length(self) -> int:
        """Decoding length.

        Raises:
            NotImplementedError: If it has not been assigned yet.
        """
        if self._decoding_length is None:
            raise NotImplementedError(
                f'Define {self.__class__.__name__}.decoding_length'
            )
        return self._decoding_length

    @decoding_length.setter
    def decoding_length(self, value: int):
        if not isinstance(value, int):
            raise TypeError(
                f'Invalid type for "decoding_length": {type(value)}'
            )
        if value <= 0:
            raise ValueError('Decoding length <= 0.')
        self._decoding_length = value

    @property
    def heads(self) -> list[BaseHead]:
        """Heads of the model, stored as an ``nn.ModuleList``.

        Raises:
            NotImplementedError: If no heads have been assigned.
        """
        if self._heads is None:
            raise NotImplementedError(f'Define {self.__class__.__name__}.heads')
        return self._heads

    @heads.setter
    def heads(self, heads: list[BaseHead]):
        if not isinstance(heads, list):
            raise TypeError(f'Invalid type for "heads". {type(heads)}')
        if not all(isinstance(head, BaseHead) for head in heads):
            raise TypeError(
                f'Invalid type for "heads". {[type(v) for v in heads]}'
            )
        self._heads = nn.ModuleList(heads)

    @property
    def head(self) -> BaseHead:
        """The only head of a single-head model.

        Raises:
            NotImplementedError: If no heads have been assigned.
            Exception: If the number of heads is not exactly one.
        """
        if self._heads is None:
            raise NotImplementedError(f'Define {self.__class__.__name__}.heads')
        if len(self._heads) != 1:
            raise Exception('Multi-head model cannot use head.')
        return self._heads[0]

    @head.setter
    def head(self, head: BaseHead):
        if not isinstance(head, BaseHead):
            raise TypeError(f'Invalid type for "head". {type(head)}')
        self.heads = [head]

    def encode(self, inputs: dict[str, Any]) -> dict[str, Any]:
        """Encode the inputs. Must be overridden."""
        raise NotImplementedError(f'Define {self.__class__.__name__}.encode()')

    def decode_train(self, inputs: dict[str, Any]) -> dict[str, Any]:
        """Decode in training mode. Delegates to ``decode_eval`` by default."""
        return self.decode_eval(inputs)

    def decode_eval(self, inputs: dict[str, Any]) -> dict[str, Any]:
        """Decode in evaluation mode. Must be overridden."""
        # The exception has to be raised; merely constructing it would let
        # this method return None silently.
        raise NotImplementedError(
            f'Define {self.__class__.__name__}.decode_eval()'
        )

    def decode(self, inputs):
        """Dispatch to ``decode_train`` or ``decode_eval`` on ``self.training``."""
        if self.training:
            return self.decode_train(inputs)
        return self.decode_eval(inputs)

    def forward(self, inputs: dict[str, Any]) -> dict[str, Any]:
        """Encode the inputs, then decode the inputs merged with the encoding.

        NOTE(review): ``merge_dicts`` is defined elsewhere; how it resolves
        duplicate keys cannot be seen from here -- confirm.
        """
        encoder_outputs = self.encode(inputs)
        decoder_inputs = merge_dicts([inputs, encoder_outputs])
        return self.decode(decoder_inputs)

    def make_chunk_specs(self):
        """Placeholder that does nothing in the base class."""
        pass

    def calculate_loss(
        self, outputs: dict[str, Any], batch: dict[str, Any]
    ) -> torch.Tensor:
        """Return the sum over all heads of ``loss_weight * head loss``.

        With an empty list of heads the result is the integer ``0``.
        """
        loss = 0
        for head in self.heads:
            loss += head.loss_weight * head.calculate_loss(outputs, batch)
        return loss

    def _heads_with_metrics(self):
        """Yield the heads that have metrics assigned."""
        return (head for head in self.heads if head.has_metrics)

    def forward_metrics(
        self,
        outputs: dict[str, Any],
        batch: dict[str, Any],
        stage: str,
    ) -> dict[str, Any]:
        """Call the metrics of every head and merge the returned dictionaries."""
        metrics = {}
        for head in self._heads_with_metrics():
            # It's a dictionary's update method.
            # Don't confuse with update of TorchMetric.
            metrics.update(
                # __call__ of TorchMetric.
                head.metrics(outputs=outputs, batch=batch, stage=stage)
            )
        return metrics

    def update_metrics(
        self,
        outputs: dict[str, Any],
        batch: dict[str, Any],
        stage: str,
    ) -> None:
        """Update the metrics of every head for ``stage``; returns nothing."""
        for head in self._heads_with_metrics():
            head.metrics.update(outputs=outputs, batch=batch, stage=stage)

    def compute_metrics(self, stage: str) -> dict[str, Any]:
        """Compute the metrics of every head for ``stage`` and merge them."""
        metrics = {}
        for head in self._heads_with_metrics():
            metrics.update(head.metrics.compute(stage=stage))
        return metrics

    def reset_metrics(self, stage: str) -> None:
        """Reset the metrics of every head for ``stage``."""
        for head in self._heads_with_metrics():
            head.metrics.reset(stage=stage)

    def _evaluation_step(self, batch: dict[str, Any], stage: str) -> None:
        """Run one validation/test step: log the loss and update the metrics."""
        outputs = self(batch)
        loss = self.calculate_loss(outputs, batch)
        # Don't log metrics yet.
        self.update_metrics(outputs, batch, stage=stage)
        # loss will be epoch averaged.
        self.log(f'{stage}/loss', loss)

    def _log_epoch_metrics(self, stage: str) -> None:
        """Compute, log and then reset the metrics accumulated for ``stage``."""
        metrics = self.compute_metrics(stage=stage)
        self.log_dict(metrics)
        self.reset_metrics(stage=stage)

    def training_step(
        self, batch: dict[str, Any], batch_idx: int
    ) -> torch.Tensor:
        """Run the model on ``batch``, log loss and metrics, return the loss."""
        outputs = self(batch)
        loss = self.calculate_loss(outputs, batch)
        # Update and evaluate metric.
        metrics = self.forward_metrics(outputs, batch, stage='train')
        self.log('train/loss', loss)
        # Log instant metrics.
        self.log_dict(metrics)
        return loss

    def on_train_epoch_end(self) -> None:
        """Reset the training metrics at the end of an epoch."""
        # Don't log epoch averaged metrics. Just reset the states.
        self.reset_metrics(stage='train')

    def validation_step(
        self,
        batch: dict[str, Any],
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        """Log ``'val/loss'`` and update the validation metrics."""
        self._evaluation_step(batch, stage='val')

    def on_validation_epoch_end(self) -> None:
        """Log the accumulated validation metrics and reset them."""
        self._log_epoch_metrics(stage='val')

    def test_step(
        self,
        batch: dict[str, Any],
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        """Log ``'test/loss'`` and update the test metrics."""
        self._evaluation_step(batch, stage='test')

    def on_test_epoch_end(self) -> None:
        """Log the accumulated test metrics and reset them."""
        self._log_epoch_metrics(stage='test')