BatchLoss

class lib.training.loss.BatchLoss(unweighted: list[dict[str, Tensor]], weighted: list[dict[str, Tensor]], mask: Tensor | None = None)

Bases: object

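Dataclass holding the loss values for a batch of data.

Parameters:
  • unweighted (list[dict[str, Tensor]]) – For each side output, the unweighted loss scalars of each loss function for each item in the batch.

  • weighted (list[dict[str, Tensor]]) – For each side output, the weighted loss scalars of each loss function for each item in the batch.

  • mask (Tensor | None) – The mask loss scalar for each item in the batch if learn_mask is selected, otherwise None. Default: None

Attributes Summary

mask

The mask loss scalar for each item in the batch if learn_mask is selected, otherwise None.

total

The single total weighted loss scalar over all items in the batch, used for backpropagation.

Methods Summary

to_cpu()

Detaches all contained loss values and moves them to CPU.

Attributes Documentation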

mask: Tensor | None = None

The mask loss scalar for each item in the batch if learn_mask is selected, otherwise None. Default: None

property total: Tensor

The single total weighted loss scalar over all items in the batch, used for backpropagation.
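This page does not say how `total` combines the stored values. The sketch below shows one plausible reduction, assuming that `total` sums every weighted loss across side outputs and loss functions, adds the per-item mask loss when present, and averages over the batch. The function name `total_loss` is hypothetical and is not part of `BatchLoss`.

```python
from typing import Optional

import torch
from torch import Tensor


def total_loss(weighted: list[dict[str, Tensor]], mask: Optional[Tensor] = None) -> Tensor:
    """Hypothetical reduction of BatchLoss values to one scalar; the real `total` may differ."""
    # Stack every (batch,)-shaped weighted loss from every side output and loss function,
    # then sum them so each batch item gets one combined loss.
    per_item = torch.stack([loss for output in weighted for loss in output.values()]).sum(dim=0)
    if mask is not None:
        # Assumption: the per-item mask loss is added as-is.
        per_item = per_item + mask
    # Assumption: average over the batch to get the single scalar used for backward().
    return per_item.mean()
```

Because the result is a 0-dimensional tensor, it can be passed straight to `.backward()`. If the real reduction sums over the batch instead of averaging, only the final line would change.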

Methods Documentation

to_cpu() → Self

Detaches all contained loss values and moves them to CPU.

Returns:

This object with all tensors detached and moved to CPU.

Return type:

Self
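The sketch below shows what `to_cpu` plausibly does. It uses a hypothetical stand-in dataclass, `LossSketch`, which has the same fields as `BatchLoss`. It assumes the method updates the object in place and returns it, which is how "This object" reads. The real implementation is not shown on this page.

```python
from dataclasses import dataclass
from typing import Optional

import torch
from torch import Tensor


@dataclass
class LossSketch:
    """Hypothetical stand-in with the same fields as BatchLoss (not the real class)."""

    unweighted: list[dict[str, Tensor]]
    weighted: list[dict[str, Tensor]]
    mask: Optional[Tensor] = None

    def to_cpu(self) -> "LossSketch":
        # Assumption: the real method updates the object in place and returns it.
        # detach() cuts each tensor from the autograd graph; cpu() copies it to host
        # memory and returns the same tensor if it is already on the CPU.
        self.unweighted = [{k: v.detach().cpu() for k, v in out.items()} for out in self.unweighted]
        self.weighted = [{k: v.detach().cpu() for k, v in out.items()} for out in self.weighted]
        if self.mask is not None:
            self.mask = self.mask.detach().cpu()
        return self


# Example: losses computed with autograd enabled, then detached for logging.
x = torch.tensor([1.0, 2.0], requires_grad=True)
losses = LossSketch(unweighted=[{"mae": x * 2}], weighted=[{"mae": x * 3}]).to_cpu()
```

Detaching before keeping loss values around, for example for logging, avoids holding references to the autograd graph and to GPU memory after the backward pass.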

unweighted: list[dict[str, Tensor]]

For each side output, a dict of the unweighted loss scalars for each item in the batch, keyed by loss function.

weighted: list[dict[str, Tensor]]

For each side output, a dict of the weighted loss scalars for each item in the batch, keyed by loss function.
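The sketch below shows the nested layout that `unweighted` and `weighted` describe. It assumes two side outputs and a batch of 4. It also assumes two hypothetical loss functions (`mae`, `ssim`) with hypothetical weights. Finally, it assumes that each weighted value is its unweighted value scaled by a per-function weight, which this page does not state.

```python
import torch

batch_size = 4
# Hypothetical loss-function names and weights; the real keys depend on the configured losses.
weights = {"mae": 1.0, "ssim": 0.25}

# One dict per side output; each value holds one loss scalar per batch item.
unweighted = [
    {"mae": torch.rand(batch_size), "ssim": torch.rand(batch_size)},  # side output 0
    {"mae": torch.rand(batch_size), "ssim": torch.rand(batch_size)},  # side output 1
]
# Assumption: each weighted value is the unweighted value scaled by its function's weight.
weighted = [{name: loss * weights[name] for name, loss in out.items()} for out in unweighted]

# Batch mean of each function's weighted loss per side output, e.g. for logging.
summary = [{name: loss.mean().item() for name, loss in out.items()} for out in weighted]
```

Lists shaped like this match the constructor signature, e.g. `BatchLoss(unweighted=unweighted, weighted=weighted)`, with `mask` left at its default of None when learn_mask is not selected.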