BaseHead
- class BaseHead
Bases: Module
Base class of all head classes. A head is a module that takes the output of the last layer of the body and produces the output of the model. It also calculates the loss and metrics. Users must define the following attributes:
- tag: str
Tag for the head. The prefix ‘head.’ is added automatically.
- loss_weight: float
Weight applied to this head's loss in loss calculations.
- metrics: Metric | list[Metric] | dict[str, Metric]
Metrics for the training, validation, and test stages.
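The attribute list does not say how `metrics` or `loss_weight` are consumed. The following is a minimal, dependency-free sketch under stated assumptions, not the library's implementation: it normalizes the three accepted `metrics` forms into a single name-to-metric mapping and combines per-head losses as a weighted sum keyed by head tag. The `Metric` alias, the `metric_<i>` naming scheme, and the helpers `normalize_metrics` and `total_loss` are all hypothetical.

```python
from typing import Callable, Union

# Hypothetical stand-in: here a "metric" is any callable (preds, targets) -> float.
Metric = Callable[[list[float], list[float]], float]
MetricsSpec = Union[Metric, list[Metric], dict[str, Metric]]


def normalize_metrics(metrics: MetricsSpec) -> dict[str, Metric]:
    """Map any accepted form to name -> metric (hypothetical naming scheme)."""
    if isinstance(metrics, dict):
        return dict(metrics)
    if isinstance(metrics, list):
        # Assumption: list entries are named by their position.
        return {f"metric_{i}": m for i, m in enumerate(metrics)}
    # A single metric becomes a one-entry mapping.
    return {"metric_0": metrics}


def total_loss(losses: dict[str, float], loss_weights: dict[str, float]) -> float:
    """Assumed combination rule: sum of loss_weight * loss over all heads."""
    return sum(loss_weights[tag] * loss for tag, loss in losses.items())


def mae(preds: list[float], targets: list[float]) -> float:
    """Mean absolute error, used only as an example metric."""
    return sum(abs(p - t) for p, t in zip(preds, targets)) / len(preds)


print(sorted(normalize_metrics([mae, mae])))  # ['metric_0', 'metric_1']
print(total_loss({"head.a": 2.0, "head.b": 1.0}, {"head.a": 0.5, "head.b": 2.0}))  # 3.0
```

A weighted sum is the usual way to let one head dominate or recede in multi-head training; whether `BaseHead` applies `loss_weight` exactly this way should be checked against its source.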
Users must define the following methods:
- forward(self, inputs: Any) -> torch.Tensor
Forward method of the head.
- get_outputs(self) -> dict[str, Any]
Get the outputs of the head as a dictionary built from the head's internal state.
- reset(self)
Reset the internal states of the head.
- calculate_loss(self, outputs: dict[str, Any], batch: dict[str, Any]) -> torch.Tensor
Calculate the loss of the head.
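To illustrate the contract above, here is a minimal sketch of a concrete head. To stay dependency-free it does not import the real `BaseHead` or PyTorch: `SketchHead` is a hypothetical abstract stand-in that mirrors only the documented attributes and methods, and plain Python lists replace `torch.Tensor`. The toy `MeanRegressionHead`, its tag, and the batch key `label.mean_regression` (modeled on the documented `label_tag` convention) are likewise hypothetical.

```python
from abc import ABC, abstractmethod
from typing import Any


class SketchHead(ABC):
    """Hypothetical stand-in for BaseHead, mirroring only the documented contract."""

    tag: str
    loss_weight: float

    @abstractmethod
    def forward(self, inputs: Any) -> Any: ...

    @abstractmethod
    def get_outputs(self) -> dict[str, Any]: ...

    @abstractmethod
    def reset(self) -> None: ...

    @abstractmethod
    def calculate_loss(self, outputs: dict[str, Any], batch: dict[str, Any]) -> float: ...


class MeanRegressionHead(SketchHead):
    """Toy head: predicts the mean of each input row, scored with squared error."""

    tag = "mean_regression"
    loss_weight = 1.0

    def __init__(self) -> None:
        # Internal state accumulated across forward calls.
        self._preds: list[float] = []

    def forward(self, inputs: list[list[float]]) -> list[float]:
        preds = [sum(row) / len(row) for row in inputs]
        self._preds.extend(preds)
        return preds

    def get_outputs(self) -> dict[str, Any]:
        # Build the output dictionary from the internal state.
        return {"preds": list(self._preds)}

    def reset(self) -> None:
        # Clear the internal state, e.g. between epochs or stages.
        self._preds.clear()

    def calculate_loss(self, outputs: dict[str, Any], batch: dict[str, Any]) -> float:
        preds, targets = outputs["preds"], batch["label.mean_regression"]
        return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


head = MeanRegressionHead()
head.forward([[1.0, 3.0], [2.0, 4.0]])  # preds are [2.0, 3.0]
loss = head.calculate_loss(head.get_outputs(), {"label.mean_regression": [2.0, 5.0]})
print(loss)  # 2.0
head.reset()
```

Keeping predictions in internal state, rather than only returning them, is what lets `get_outputs` and `reset` be called separately from `forward`.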
Methods
- calculate_loss
- forward
- get_outputs
- reset
Attributes
- has_metrics
- label_tag: Tag of the target label.
- loss_weight: Weight applied to this head's loss in loss calculations.
- metrics
- tag: Tag for the head.
- property label_tag: str
Tag of the target label. If the head's tag is “head.my_tag”, then label_tag is “label.my_tag”.
- property loss_weight: float
Weight applied to this head's loss in loss calculations.
- property tag: str
Tag for the head. The prefix ‘head.’ is added automatically.
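The documented relationship between `tag` and `label_tag` can be sketched with plain properties. This is a hypothetical illustration (the class name `HeadTagsSketch` and its constructor are assumptions), not the library's code: it stores the raw tag, adds the ‘head.’ prefix on access, and swaps that prefix for ‘label.’ to derive the label tag.

```python
class HeadTagsSketch:
    """Hypothetical sketch of the documented tag / label_tag / loss_weight properties."""

    def __init__(self, tag: str, loss_weight: float = 1.0) -> None:
        self._tag = tag  # raw tag as supplied by the user, without a prefix
        self._loss_weight = loss_weight

    @property
    def tag(self) -> str:
        # Documented behavior: the 'head.' prefix is added automatically.
        return f"head.{self._tag}"

    @property
    def label_tag(self) -> str:
        # Documented example: 'head.my_tag' -> 'label.my_tag'.
        return "label." + self.tag.removeprefix("head.")

    @property
    def loss_weight(self) -> float:
        return self._loss_weight


h = HeadTagsSketch("my_tag", loss_weight=0.5)
print(h.tag)        # head.my_tag
print(h.label_tag)  # label.my_tag
```

Deriving `label_tag` from `tag` keeps the two names in sync, so a batch entry for a head can be located from the head's tag alone. `str.removeprefix` requires Python 3.9 or later.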