DistributionHead
- class DistributionHead(tag, distribution, in_features, out_features, loss_weight=1.0, metrics=None)
Bases: BaseHead
Methods
calculate_loss
forward
get_outputs
reset
- Parameters:
tag (str)
distribution (Distribution)
in_features (int)
out_features (int)
loss_weight (float)
metrics (Metric | list[torchmetrics.metric.Metric] | dict[str, torchmetrics.metric.Metric])
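The parameter list does not say how the head uses its arguments. The following is a minimal, hypothetical sketch of one common pattern for a distribution head, not this class's implementation. It assumes that the head is a single linear projection from `in_features` inputs to `out_features` outputs, that those outputs are read as the (mean, log standard deviation) of a Gaussian, and that the loss is the Gaussian negative log-likelihood multiplied by `loss_weight`. The class name `GaussianHeadSketch` and its `seed` argument are invented for illustration. The sketch uses only the standard library so that it runs without PyTorch, and it omits `metrics`, `get_outputs`, and `reset`, whose behavior the reference does not describe.

```python
import math
import random


class GaussianHeadSketch:
    """Hypothetical, stdlib-only stand-in for a distribution head.

    NOT the library's DistributionHead. Assumes one linear layer mapping
    `in_features` inputs to `out_features` == 2 outputs, read as the
    (mean, log_std) of a Gaussian, with the loss being the Gaussian
    negative log-likelihood scaled by `loss_weight`.
    """

    def __init__(self, tag, in_features, out_features=2, loss_weight=1.0, seed=0):
        if out_features != 2:
            raise ValueError("this sketch assumes out_features == 2 (mean, log_std)")
        self.tag = tag  # kept only as a label; its real role is undocumented
        self.in_features = in_features
        self.out_features = out_features
        self.loss_weight = loss_weight
        rng = random.Random(seed)
        # One weight row and one bias per output (assumed linear projection).
        self.weights = [
            [rng.uniform(-0.1, 0.1) for _ in range(in_features)]
            for _ in range(out_features)
        ]
        self.bias = [0.0] * out_features

    def forward(self, x):
        """Project one feature vector to Gaussian parameters (mean, log_std)."""
        if len(x) != self.in_features:
            raise ValueError(f"expected {self.in_features} features, got {len(x)}")
        mean, log_std = [
            sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(self.weights, self.bias)
        ]
        return mean, log_std

    def calculate_loss(self, x, target):
        """Gaussian negative log-likelihood of `target`, scaled by loss_weight."""
        mean, log_std = self.forward(x)
        var = math.exp(2.0 * log_std)  # sigma^2 = exp(2 * log sigma)
        nll = 0.5 * math.log(2.0 * math.pi) + log_std + (target - mean) ** 2 / (2.0 * var)
        return self.loss_weight * nll


head = GaussianHeadSketch(tag="target", in_features=3, loss_weight=0.5)
mean, log_std = head.forward([1.0, 2.0, 3.0])
loss = head.calculate_loss([1.0, 2.0, 3.0], target=1.5)
```

Predicting a log standard deviation rather than the standard deviation itself keeps the scale positive without constraining the linear layer. In the real class, how the `out_features` outputs map onto distribution parameters is presumably determined by the `distribution` argument, which this sketch hard-codes as Gaussian.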