TrainerBase

class plugins.train.trainer.base.TrainerBase(model: ModelBase, config: TrainConfig)

Bases: ABC

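A trainer plugin interface. Subclasses must implement the method “train_batch”, which takes a batch of inputs to the model and the target images for the model outputs, and returns the loss for each side. Subclasses must also implement “get_sampler”, which returns the sampler class for the Torch DataLoader.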

Parameters:
  • model (ModelBase) – The model plugin

  • config (TrainConfig) – The training configuration options
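The contract can be illustrated without faceswap or Torch installed. The sketch below is a minimal, hypothetical stand-in that mirrors the two abstract methods documented here. `TrainerInterface`, `MinimalTrainer` and `IncompleteTrainer` are illustrative names and are not part of the plugin API, and the real TrainerBase constructor is not reproduced. The sketch shows that Python's ABC machinery refuses to instantiate a subclass that overrides only one of the abstract methods.

```python
from abc import ABC, abstractmethod


class TrainerInterface(ABC):
    """Hypothetical stand-in mirroring the documented TrainerBase contract."""

    def __init__(self, model, config):
        self.model = model      # the model plugin
        self.config = config    # the training configuration options

    @abstractmethod
    def get_sampler(self):
        """Return the sampler *class* that the DataLoader should use."""

    @abstractmethod
    def train_batch(self, inputs, targets, optimizer, meta):
        """Run one forward and backward pass and return the loss per side."""


class IncompleteTrainer(TrainerInterface):
    # Only get_sampler is overridden, so this class remains abstract.
    def get_sampler(self):
        return object


class MinimalTrainer(TrainerInterface):
    def get_sampler(self):
        return object  # placeholder: a real trainer returns a torch sampler class

    def train_batch(self, inputs, targets, optimizer, meta):
        # Placeholder: one zero loss per input side (A, B, ...).
        return [0.0 for _ in inputs]


trainer = MinimalTrainer(model=None, config={})
print(trainer.train_batch(["a", "b"], [], None, None))  # [0.0, 0.0]

try:
    IncompleteTrainer(model=None, config={})
except TypeError as err:
    print("cannot instantiate:", type(err).__name__)  # cannot instantiate: TypeError
```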

Methods Summary

get_sampler()

Override to set the sampler that the Torch DataLoader should use

register_loss(loss)

Registers the selected loss functions to the underlying model nn.Module

train_batch(inputs, targets, optimizer, meta)

Override to run a single forward and backward pass through the model for one batch

Methods Documentation

abstractmethod get_sampler() → type[RandomSampler | DistributedSampler]

Override to set the sampler that the Torch DataLoader should use

Returns:

The sampler class that the Torch DataLoader should use
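The return annotation is a sampler class (type[...]), not an instance. The following is a minimal sketch of one way a `get_sampler` body and its caller could fit together. The selection rule (use DistributedSampler only when a torch.distributed process group is active) and the name `choose_sampler` are illustrative assumptions and are not the plugin's documented behavior. In the sketch, the caller instantiates the returned class and passes it to the DataLoader.

```python
from __future__ import annotations

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, TensorDataset


def choose_sampler() -> type[RandomSampler | DistributedSampler]:
    """Hypothetical get_sampler body: DistributedSampler only under an active process group."""
    if dist.is_available() and dist.is_initialized():
        return DistributedSampler
    return RandomSampler


dataset = TensorDataset(torch.arange(10, dtype=torch.float32))
sampler_cls = choose_sampler()  # a class, not an instance
# The caller, not the trainer, binds the sampler class to the dataset.
loader = DataLoader(dataset, batch_size=4, sampler=sampler_cls(dataset))
print(sampler_cls.__name__, [len(batch[0]) for batch in loader])
```

In a single, non-distributed process this prints `RandomSampler [4, 4, 2]`. The ten samples are drawn in shuffled order in batches of four, and the last batch holds the remainder.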

register_loss(loss: LossCollator) → None

Registers the selected loss functions to the underlying model nn.Module

Parameters:

loss (LossCollator) – The configured loss functions

Return type:

None
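Registering a loss on a model commonly means attaching it as a child nn.Module. The loss then follows the model across devices and its buffers appear in the model's state_dict. The sketch below assumes LossCollator is itself an nn.Module. `ToyLossCollator` and this free-standing `register_loss` function are hypothetical illustrations and are not the plugin's implementation.

```python
import torch
from torch import nn


class ToyLossCollator(nn.Module):
    """Hypothetical stand-in for LossCollator: a weighted sum of two losses."""

    def __init__(self):
        super().__init__()
        # A buffer moves with the module on .to(device) and is saved in state_dict.
        self.register_buffer("weights", torch.tensor([1.0, 0.5]))

    def forward(self, pred, target):
        l1 = nn.functional.l1_loss(pred, target)
        mse = nn.functional.mse_loss(pred, target)
        return self.weights[0] * l1 + self.weights[1] * mse


def register_loss(model: nn.Module, loss: nn.Module) -> None:
    """Hypothetical register_loss: attach the loss as a named child module."""
    model.add_module("loss", loss)


net = nn.Linear(4, 4)
register_loss(net, ToyLossCollator())
print([name for name, _ in net.named_children()])  # ['loss']
print(sorted(net.state_dict().keys()))  # ['bias', 'loss.weights', 'weight']
```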

abstractmethod train_batch(inputs: list[torch.Tensor], targets: list[torch.Tensor], optimizer: Optimizer, meta: BatchMeta) → list[BatchLoss]

Override to run a single forward and backward pass through the model for one batch

Parameters:
  • inputs (list[torch.Tensor]) – The batch of input image tensors, one tensor per model input (length num_inputs)

  • targets (list[torch.Tensor]) – The target images, one tensor per model output size (length num_outputs), each of shape (batch_size, num_inputs, height, width, 3) as float32 in the 0.0 - 1.0 range

  • optimizer (Optimizer) – The configured Optimizer to use

  • meta (BatchMeta) – The meta information for the batch

Returns:

The loss for each input to the model, in order (A, B, …)
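The following minimal sketch shows what an override might do. It rests on these assumptions:

- `ToyModel` is a hypothetical two-side model that returns one channels-last image per input.
- Targets follow the documented (batch_size, num_inputs, height, width, 3) layout at two output sizes.
- Only the last (full-size) target is compared against the model output.
- The `meta` argument is omitted, and plain float losses stand in for the BatchLoss return type.

```python
import torch
from torch import nn


class ToyModel(nn.Module):
    """Hypothetical two-side model: one per-pixel RGB mapping per side (A, B)."""

    def __init__(self, num_inputs: int = 2):
        super().__init__()
        # nn.Linear acts on the last (channel) axis of channels-last images.
        self.sides = nn.ModuleList([nn.Linear(3, 3) for _ in range(num_inputs)])

    def forward(self, inputs):
        return [torch.sigmoid(layer(x)) for layer, x in zip(self.sides, inputs)]


def train_batch(model, inputs, targets, optimizer):
    """Hypothetical train_batch: one forward and backward pass, loss per side."""
    optimizer.zero_grad()
    outputs = model(inputs)  # one (batch, H, W, 3) tensor per side
    full_size = targets[-1]  # assumed: last entry matches the size the model outputs
    side_losses = [nn.functional.mse_loss(out, full_size[:, side])
                   for side, out in enumerate(outputs)]
    torch.stack(side_losses).sum().backward()  # single backward pass over all sides
    optimizer.step()
    return [loss.item() for loss in side_losses]  # in side order (A, B, ...)


torch.manual_seed(0)
batch, height, width = 4, 8, 8
model = ToyModel()
inputs = [torch.rand(batch, height, width, 3) for _ in range(2)]
# Documented target layout (batch_size, num_inputs, H, W, 3), float32 in [0, 1],
# at two output sizes (half and full resolution); only the full size is used here.
targets = [torch.rand(batch, 2, height // 2, width // 2, 3),
           torch.rand(batch, 2, height, width, 3)]
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
print(train_batch(model, inputs, targets, optimizer))
```

Summing the per-side losses before calling backward() means one backward pass updates both sides. The individual values are still returned separately, so they can be reported per side.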

Attributes Documentation

batch_size

The number of samples in each batch trained through the model.

config

Training configuration options

loss_func: LossCollator

The selected loss functions for the model

model

The model plugin to train the batch on

sampler

The data sampler that the data loader should use
