BaseHead
- class BaseHead
Bases: Module
Base class of all Head classes. A head is a module that takes the output of the last layer of the body and produces the output of the model. It also calculates the loss and metrics. Users have to define the following attributes:
- tag: str
Tag for a head. Prefix ‘head.’ is added automatically.
- loss_weight: float
Loss weight for loss calculations.
- metrics: Metric | list[Metric] | dict[str, Metric]
Metrics for training, validation and test stage.
Users have to define the following methods:
- forward(self, inputs: Any) -> torch.Tensor
Forward method of the head.
- get_outputs(self) -> dict[str, Any]
Get outputs of the head. It produces a dictionary of outputs from internal state of the head.
- reset(self)
Reset the internal states of the head.
- calculate_loss(self, outputs: dict[str, Any], batch: dict[str, Any]) -> torch.Tensor
Calculate loss of the head.
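The contract above can be sketched without the library itself. The example below is a minimal, dependency-free illustration, not the library's implementation: `ToyHead` is a hypothetical stand-in that does not subclass the real `BaseHead`, plain Python floats stand in for `torch.Tensor`, and the output key `predictions`, the batch key `label.toy`, and the mean-squared-error loss are all assumptions made for illustration.

```python
from typing import Any


class ToyHead:
    """Hypothetical stand-in mirroring the four methods a head must define."""

    def __init__(self, scale: float = 2.0) -> None:
        self.scale = scale
        self._predictions: list[float] = []  # internal state filled by forward()

    def forward(self, inputs: list[float]) -> list[float]:
        # Map the body's outputs to model outputs and remember them.
        self._predictions = [self.scale * x for x in inputs]
        return self._predictions

    def get_outputs(self) -> dict[str, Any]:
        # Build the outputs dictionary from the internal state.
        return {"predictions": list(self._predictions)}

    def reset(self) -> None:
        # Clear the internal state so the next batch starts clean.
        self._predictions = []

    def calculate_loss(self, outputs: dict[str, Any], batch: dict[str, Any]) -> float:
        # Mean squared error against the labels (assumed key "label.toy").
        preds = outputs["predictions"]
        labels = batch["label.toy"]
        return sum((p - y) ** 2 for p, y in zip(preds, labels)) / len(preds)


head = ToyHead()
head.forward([1.0, 2.0])
outputs = head.get_outputs()
loss = head.calculate_loss(outputs, {"label.toy": [2.0, 5.0]})
head.reset()
```

The call order shown here (forward, then get_outputs, then calculate_loss, then reset) is one plausible reading of the method descriptions; the order in which the library actually invokes them is not stated on this page.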
Methods
- calculate_loss
- forward
- get_outputs
- reset
Attributes
- has_metrics
- label_tag: Tag of target label.
- loss_weight: Loss weight for loss calculations.
- metrics
- tag: Tag for a head.
- property label_tag: str
Tag of the target label. If the head's tag is “head.my_tag”, then label_tag is “label.my_tag”.
- property loss_weight: float
Loss weight for loss calculations.
- property tag: str
Tag for a head. Prefix ‘head.’ is added automatically.
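The naming convention described for `tag` and `label_tag` can be illustrated with plain string handling. This is only a sketch of the convention as documented: `with_head_prefix` and `to_label_tag` are hypothetical helpers, not part of the library, and the sketch assumes the user-supplied tag does not already carry the `head.` prefix.

```python
def with_head_prefix(tag: str) -> str:
    """Add the 'head.' prefix to a user-supplied tag (hypothetical helper)."""
    return f"head.{tag}"


def to_label_tag(head_tag: str) -> str:
    """Swap the 'head.' prefix for 'label.' (hypothetical helper)."""
    return "label." + head_tag.removeprefix("head.")


head_tag = with_head_prefix("my_tag")
label_tag = to_label_tag(head_tag)
```

With a user-supplied tag of `my_tag`, this gives `head.my_tag` for the head and `label.my_tag` for the label, matching the example in the `label_tag` description.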