Trainer

class plugins.train.trainer.distributed.Trainer(model: ModelBase, config: TrainConfig)

Bases: Trainer

Distributed training with torch.nn.DataParallel

Parameters:
  • model (ModelBase) – The model that will be running this trainer

  • config (TrainConfig) – The training configuration options

Methods Summary

get_sampler()

Obtain a standard random sampler

register_loss(loss)

Registers the selected loss functions with the underlying model's nn.Module

train_batch(inputs, targets, optimizer, meta)

Run a single forward and backward pass through the model for a single batch

Methods Documentation

get_sampler() → type[RandomSampler]

Obtain a standard random sampler

Returns:

The Random sampler
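
The signature indicates that the sampler class itself is returned, not an instance. A minimal sketch of how such a class might be consumed, assuming a standard torch.utils.data.DataLoader; the get_sampler function and the toy dataset here are hypothetical stand-ins, not this plugin's code:

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset


def get_sampler() -> type[RandomSampler]:
    """Hypothetical stand-in: return the sampler class, not an instance."""
    return RandomSampler


# Toy dataset of 8 samples (assumption, for illustration only)
dataset = TensorDataset(torch.arange(8, dtype=torch.float32).unsqueeze(1))

sampler_cls = get_sampler()
sampler = sampler_cls(dataset)  # the caller instantiates the class with a dataset
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

# Every sample is visited exactly once per epoch, in random order
seen = sorted(int(x) for (batch,) in loader for x in batch.flatten())
```

Returning the class leaves the choice of dataset to the caller, which is one plausible reason for the type[RandomSampler] return annotation.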

register_loss(loss: LossCollator) → None

Registers the selected loss functions with the underlying model's nn.Module

Parameters:

loss (LossCollator) – The configured loss functions

Return type:

None
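
One way to read "underlying model": torch.nn.DataParallel exposes the module it wraps as its .module attribute, and a loss attached there as a submodule is replicated together with the model. The sketch below is an assumption about the general pattern, not this trainer's implementation; the free function register_loss, the submodule name "loss" and the use of nn.L1Loss in place of a LossCollator are all hypothetical:

```python
import torch
from torch import nn


def register_loss(parallel_model: nn.DataParallel, loss: nn.Module) -> None:
    """Hypothetical helper: attach the loss to the wrapped module as a submodule."""
    # .module is the original nn.Module that DataParallel wraps
    parallel_model.module.add_module("loss", loss)


wrapped = nn.DataParallel(nn.Linear(4, 4))
register_loss(wrapped, nn.L1Loss())  # nn.L1Loss stands in for a LossCollator

# The loss is now a registered child of the underlying module
children = dict(wrapped.module.named_children())
```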

train_batch(inputs: list[torch.Tensor], targets: list[torch.Tensor], optimizer: Optimizer, meta: BatchMeta) → list[BatchLoss]

Run a single forward and backward pass through the model for a single batch

Parameters:
  • inputs (list[torch.Tensor]) – The batch of input image tensors for the model, as a list of length num_inputs

  • targets (list[torch.Tensor]) – List of length num_outputs containing the target images, each in shape (batch_size, num_inputs, height, width, 3), at every model output size, as float32 in the range 0.0 to 1.0

  • optimizer (Optimizer) – The configured Optimizer to use

  • meta (BatchMeta) – The meta information for the batch

Returns:

The loss for each input to the model in order (A, B, …)
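
A single forward and backward pass with torch.nn.DataParallel can be sketched roughly as follows. This is a simplified illustration under stated assumptions, not the trainer's actual code: TinyModel, the L1 loss, the flat tensor shapes and the plain float return values (in place of BatchLoss) are hypothetical, and the meta argument is omitted.

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for the real model."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer(x)


def train_batch(model, inputs, targets, optimizer) -> list[float]:
    """Run one forward and backward pass; return one loss per input."""
    optimizer.zero_grad()
    # Forward pass: DataParallel splits each batch across the available GPUs
    losses = [nn.functional.l1_loss(model(x), y) for x, y in zip(inputs, targets)]
    # Backward pass on the summed loss, then a single optimizer step
    torch.stack(losses).sum().backward()
    optimizer.step()
    return [loss.item() for loss in losses]


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.DataParallel(TinyModel().to(device))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Two inputs ("A" and "B"), each with a batch size of 2
inputs = [torch.rand(2, 4, device=device) for _ in range(2)]
targets = [torch.rand(2, 4, device=device) for _ in range(2)]
losses = train_batch(model, inputs, targets, optimizer)
```

With no GPU present, DataParallel simply calls the wrapped module directly, so the same sketch runs on the CPU.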