BaseHead

class BaseHead

Bases: Module

Base class for all head classes. A head is a module that takes the output of the last layer of the body and produces the output of the model. It also calculates the loss and metrics. Subclasses must define the following attributes:

  • tag: str

    Tag for the head. The prefix ‘head.’ is added automatically.

  • loss_weight: float

    Loss weight for loss calculations.

  • metrics: Metric | list[Metric] | dict[str, Metric]

    Metrics for the training, validation, and test stages.
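Because `metrics` accepts three different shapes, code that consumes it probably normalizes them into a single form first. The following is a minimal sketch of one way to do that, not the library's implementation: `normalize_metrics` is a hypothetical helper, `Metric`, `Accuracy`, and `F1` are stand-in classes, keying list entries by class name is an illustrative choice, and treating `None` as "no metrics" is an assumption.

```python
from typing import Any


class Metric:
    """Hypothetical stand-in for a metric object; only its class name is used here."""


class Accuracy(Metric):
    pass


class F1(Metric):
    pass


def normalize_metrics(metrics: Any) -> dict:
    """Hypothetical helper: turn any accepted `metrics` shape into a name -> Metric dict."""
    if metrics is None:
        # Assumption: a head may have no metrics at all.
        return {}
    if isinstance(metrics, Metric):
        # A single metric is keyed by its class name.
        return {type(metrics).__name__: metrics}
    if isinstance(metrics, list):
        # List entries are keyed by class name (a later duplicate overwrites an earlier one).
        return {type(m).__name__: m for m in metrics}
    if isinstance(metrics, dict):
        # A dict already carries user-chosen names; copy it so the caller's dict is untouched.
        return dict(metrics)
    raise TypeError(f"unsupported metrics type: {type(metrics).__name__}")
```

For example, `normalize_metrics([Accuracy(), F1()])` yields a dict with the keys `"Accuracy"` and `"F1"`, while a dict such as `{"val_acc": Accuracy()}` keeps its own key.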

Subclasses must define the following methods:

  • forward(self, inputs: Any) -> torch.Tensor

    Forward method of the head.

  • get_outputs(self) -> dict[str, Any]

    Get the outputs of the head as a dictionary built from the head's internal state.

  • reset(self)

    Reset the internal state of the head.

  • calculate_loss(self, outputs: dict[str, Any], batch: dict[str, Any]) -> torch.Tensor

    Calculate the loss of the head.
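To make the required attributes and methods concrete, here is a minimal, framework-free sketch of a head. It is illustrative only: `MeanHeadSketch` does not subclass the real `BaseHead`, it uses plain floats instead of `torch.Tensor` so it runs without PyTorch, and the `"predictions"` output key, the mean-squared-error loss, and the assumption that targets live in the batch under the head's label tag (`"label.my_tag"`) are all hypothetical choices.

```python
from typing import Any


class MeanHeadSketch:
    """Illustrative head that fills in the BaseHead contract with plain Python.

    In real code this would subclass BaseHead and work with torch.Tensor values.
    """

    # Required attributes (values are illustrative).
    tag = "my_tag"
    loss_weight = 0.5

    # Hypothetical keys: the output key is invented; the label key mirrors the
    # documented label_tag for a head tagged "my_tag".
    OUTPUT_KEY = "predictions"
    LABEL_KEY = "label.my_tag"

    def __init__(self) -> None:
        self.metrics: list = []  # no metrics in this sketch
        self._preds: list[float] = []

    def forward(self, inputs: Any) -> list[float]:
        # Assumes `inputs` is a list of non-empty rows of floats from the body.
        preds = [sum(row) / len(row) for row in inputs]
        self._preds.extend(preds)  # keep internal state for get_outputs()
        return preds

    def get_outputs(self) -> dict[str, Any]:
        # Build the output dictionary from internal state.
        return {self.OUTPUT_KEY: list(self._preds)}

    def reset(self) -> None:
        # Drop accumulated state, e.g. between batches or epochs.
        self._preds.clear()

    def calculate_loss(self, outputs: dict[str, Any], batch: dict[str, Any]) -> float:
        # Mean squared error between predictions and targets.
        preds = outputs[self.OUTPUT_KEY]
        targets = batch[self.LABEL_KEY]
        return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)
```

A plausible call order for one step, assumed rather than documented, is `forward`, then `get_outputs`, then `calculate_loss`, then `reset`. With inputs `[[1.0, 3.0], [2.0, 2.0]]` the predictions are `[2.0, 2.0]`, and against targets `[2.0, 4.0]` the loss is `2.0`.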

Methods

calculate_loss

forward

get_outputs

reset

Attributes

has_metrics

label_tag

Tag of the target label.

loss_weight

Loss weight for loss calculations.

metrics

tag

Tag for a head.

property label_tag: str

Tag of the target label. If the head's tag is “head.my_tag”, then label_tag is “label.my_tag”.
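The mapping from a head's tag to its label tag can be expressed as a small string transformation. This is a sketch of the documented rule, not the library's code; `label_tag_from_head_tag` is a hypothetical helper, and rejecting tags without the `head.` prefix is an assumption.

```python
def label_tag_from_head_tag(head_tag: str) -> str:
    """Hypothetical helper: map "head.<name>" to "label.<name>" as label_tag describes."""
    prefix = "head."
    if not head_tag.startswith(prefix):
        # Assumption: a head tag always carries the automatic "head." prefix.
        raise ValueError(f"expected a tag starting with {prefix!r}, got {head_tag!r}")
    # Swap only the leading prefix; the rest of the tag is kept unchanged.
    return "label." + head_tag[len(prefix):]
```

With this rule, `label_tag_from_head_tag("head.my_tag")` returns `"label.my_tag"`, matching the example in the property description.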

property loss_weight: float

Loss weight for loss calculations.
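This page does not say exactly where `loss_weight` is applied. A common convention in multi-head models is to combine per-head losses as a weighted sum, and the sketch below assumes that convention; `HeadLoss` and `total_loss` are hypothetical names, not part of the documented API.

```python
from dataclasses import dataclass


@dataclass
class HeadLoss:
    """Hypothetical record of one head's unweighted loss and its loss_weight."""

    tag: str
    loss: float
    loss_weight: float


def total_loss(head_losses: list[HeadLoss]) -> float:
    """Combine per-head losses as a weighted sum (an assumed convention)."""
    return sum(h.loss_weight * h.loss for h in head_losses)
```

For instance, a classification head with loss `0.5` and weight `1.0` plus a regression head with loss `2.0` and weight `0.25` would give a total of `1.0` under this assumption.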

property tag: str

Tag for the head. The prefix ‘head.’ is added automatically.
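One way to add a prefix automatically is to store the bare tag and expose the prefixed form through a read-only property. The sketch below illustrates that pattern under assumptions: `TaggedHeadSketch` and its `_tag` backing attribute are hypothetical, and this page does not show how `BaseHead` actually receives the user's tag.

```python
class TaggedHeadSketch:
    """Hypothetical stand-in showing how a 'head.' prefix could be added automatically."""

    PREFIX = "head."

    def __init__(self, tag: str) -> None:
        # Store the bare tag the user supplies, e.g. "my_tag".
        self._tag = tag

    @property
    def tag(self) -> str:
        # Expose the prefixed form, e.g. "head.my_tag".
        return self.PREFIX + self._tag
```

Here `TaggedHeadSketch("my_tag").tag` evaluates to `"head.my_tag"`. Because `tag` is a property without a setter, callers cannot overwrite the prefixed value directly.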