DistributionHead
- class DistributionHead(tag, distribution, in_features, out_features, loss_weight=1.0, metrics=None)
Bases: BaseHead
Methods: calculate_loss, forward, get_outputs, reset
- Parameters:
tag (str)
distribution (Distribution)
in_features (int)
out_features (int)
loss_weight (float)
metrics (Metric | list[torchmetrics.metric.Metric] | dict[str, torchmetrics.metric.Metric])
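The `metrics` parameter accepts a single metric, a list of metrics, or a name-to-metric dict, and it defaults to `None`. The following is a minimal sketch, under stated assumptions and not this class's actual code, of one way that union could be normalized into a single name-to-metric mapping. `normalize_metrics`, `DummyMetric`, and the `ClassName_index` naming scheme for list entries are hypothetical. Plain Python objects stand in for `torchmetrics` metrics so the example runs without PyTorch.

```python
from collections.abc import Mapping
from typing import Any, Dict


def normalize_metrics(metrics: Any) -> Dict[str, Any]:
    """Hypothetical helper: coerce the accepted `metrics` forms into a dict.

    Accepts None, a single metric, a list/tuple of metrics, or a
    name -> metric mapping. The naming scheme for unnamed metrics is an
    assumption made for this sketch.
    """
    if metrics is None:
        # Default (metrics=None): no metrics attached to the head.
        return {}
    if isinstance(metrics, Mapping):
        # Already named: copy so later changes do not affect the caller's dict.
        return dict(metrics)
    if isinstance(metrics, (list, tuple)):
        # Unnamed sequence: derive keys from the class name plus position.
        return {f"{type(m).__name__}_{i}": m for i, m in enumerate(metrics)}
    # Single metric: key it by its class name.
    return {type(metrics).__name__: metrics}


class DummyMetric:
    """Stand-in for a torchmetrics Metric in this sketch."""


m = DummyMetric()
print(list(normalize_metrics(m)))       # keys of the single-metric case
print(list(normalize_metrics([m, m])))  # keys of the list case
```

Copying a provided dict and deriving keys for unnamed metrics keeps metric names stable for logging. Whether `DistributionHead` does anything like this internally is not stated in this reference.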