Head
- class Head(tag, output_module, loss_fn, loss_weight=1.0, metrics=None)
Bases: BaseHead
Methods
calculate_loss
forward
get_outputs
reset
Attributes
- Parameters:
tag (str)
output_module (Module)
loss_fn (Callable[[Tensor, Tensor], Tensor])
loss_weight (float) – defaults to 1.0
metrics (torchmetrics.metric.Metric | list[torchmetrics.metric.Metric] | dict[str, torchmetrics.metric.Metric]) – optional; defaults to None
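The parameter list suggests how a head is wired together: an output module produces predictions, `loss_fn` compares two tensors (presumably predictions and targets), and `loss_weight` presumably scales that head's loss. The sketch below is a hypothetical stand-in, `HeadSketch`, written in plain PyTorch. It is not this library's `Head` or `BaseHead`, and the behaviour of `forward`, `calculate_loss`, and the metric normalization are assumptions inferred only from the parameter names and types.

```python
from typing import Callable

import torch
from torch import Tensor, nn


class HeadSketch(nn.Module):
    """Hypothetical stand-in for Head; NOT the library's implementation."""

    def __init__(
        self,
        tag: str,
        output_module: nn.Module,
        loss_fn: Callable[[Tensor, Tensor], Tensor],
        loss_weight: float = 1.0,
        metrics=None,  # a single metric, a list of metrics, a dict, or None
    ) -> None:
        super().__init__()
        self.tag = tag
        self.output_module = output_module
        self.loss_fn = loss_fn
        self.loss_weight = loss_weight
        # Assumption: normalize every accepted form to a name -> metric dict.
        if metrics is None:
            self.metrics = {}
        elif isinstance(metrics, dict):
            self.metrics = dict(metrics)
        elif isinstance(metrics, list):
            # Index suffix keeps keys unique when two metrics share a class.
            self.metrics = {
                f"{type(m).__name__}_{i}": m for i, m in enumerate(metrics)
            }
        else:
            self.metrics = {type(metrics).__name__: metrics}

    def forward(self, features: Tensor) -> Tensor:
        # Assumption: the head applies its output module to shared features.
        return self.output_module(features)

    def calculate_loss(self, outputs: Tensor, targets: Tensor) -> Tensor:
        # Assumption: loss_weight scales this head's loss before heads are combined.
        return self.loss_weight * self.loss_fn(outputs, targets)


# Usage: a regression head whose loss counts half as much as an unweighted head.
head = HeadSketch(
    tag="regression",
    output_module=nn.Linear(4, 1),
    loss_fn=nn.functional.mse_loss,
    loss_weight=0.5,
)
predictions = head(torch.randn(2, 4))  # shape (2, 1)
# MSE between all-zeros and all-ones is 1.0, so the weighted loss is 0.5.
loss = head.calculate_loss(torch.zeros(2, 1), torch.ones(2, 1))
```

Keeping `loss_weight` inside the head, rather than in the training loop, lets each head in a multi-head model carry its own weighting. Whether the real `Head` does this is not stated in this reference.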