DistributionHead

class DistributionHead(tag, distribution, in_features, out_features, loss_weight=1.0, metrics=None)

Bases: BaseHead

Methods

calculate_loss

forward

get_outputs

reset

Parameters:
  • tag (str)

  • distribution (Distribution)

  • in_features (int)

  • out_features (int)

  • loss_weight (float)

  • metrics (Metric | list[torchmetrics.metric.Metric] | dict[str, torchmetrics.metric.Metric])
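This page lists only signatures. As a rough, dependency-free illustration of how in_features, out_features and loss_weight might fit together in a head that predicts a distribution, here is a sketch of a hypothetical Gaussian head. The class name GaussianHeadSketch, the choice of a Normal distribution, the two outputs (mean and log standard deviation), the zero initialization and the averaged negative log-likelihood are all assumptions. This is not DistributionHead's actual implementation; tag, distribution and metrics are left out, and the forward and calculate_loss methods only borrow the names listed above.

```python
import math


class GaussianHeadSketch:
    """Hypothetical head, NOT the library's DistributionHead.

    Projects an `in_features`-long input to 2 parameters (mean, log std)
    of a Normal distribution and scores targets with a weighted
    negative log-likelihood.
    """

    def __init__(self, in_features: int, loss_weight: float = 1.0) -> None:
        self.in_features = in_features
        self.out_features = 2  # assumption: one mean and one log-std per sample
        self.loss_weight = loss_weight
        # Zero-initialized linear layer so the example is reproducible.
        self.weights = [[0.0] * in_features for _ in range(self.out_features)]
        self.bias = [0.0] * self.out_features

    def forward(self, x: list[float]) -> list[float]:
        # Linear projection: one output per distribution parameter.
        return [
            sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(self.weights, self.bias)
        ]

    def calculate_loss(self, batch: list[list[float]], targets: list[float]) -> float:
        # Average Gaussian NLL over the batch, then scale by loss_weight.
        total = 0.0
        for x, y in zip(batch, targets):
            mean, log_std = self.forward(x)
            z = (y - mean) / math.exp(log_std)
            total += 0.5 * math.log(2 * math.pi) + log_std + 0.5 * z * z
        return self.loss_weight * total / len(batch)


head = GaussianHeadSketch(in_features=3, loss_weight=0.5)
loss = head.calculate_loss([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]], [0.0, 2.0])
print(f"{loss:.3f}")  # about 0.959 with the zero-initialized weights
```

Scaling the loss by loss_weight mirrors the constructor argument of the same name; reading it as a per-head multiplier, so that several heads' losses can be summed with different weights, is an assumption rather than something this page states.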