Trainer
- class plugins.train.trainer.distributed.Trainer(model: ModelBase, config: TrainConfig)
Bases: Trainer
Distributed training with torch.nn.DataParallel (each batch is split across the available GPUs)
- Parameters:
model (ModelBase) – The model that this trainer will train
config (TrainConfig) – The training configuration options
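The sketch below shows what wrapping a network in torch.nn.DataParallel involves, under stated assumptions. It is not the trainer's implementation. `wrap_for_data_parallel` is a hypothetical helper, and `net` is a toy stand-in for the network held by ModelBase. nn.Linear acts on the last dimension, so it accepts the channels-last `(batch, height, width, 3)` images described under `train_batch`.

```python
import torch
from torch import nn


def wrap_for_data_parallel(module: nn.Module) -> nn.Module:
    """Hypothetical helper: move a module to the GPU (if any) and wrap it.

    On a multi-GPU machine each forward call splits the batch along dim 0,
    runs one replica per GPU and gathers the outputs on the first GPU.
    Without CUDA, DataParallel simply calls the wrapped module.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return nn.DataParallel(module.to(device))


# Hypothetical stand-in for the network held by ModelBase (channels-last input).
net = nn.Sequential(nn.Linear(3, 8), nn.ReLU(), nn.Linear(8, 3), nn.Sigmoid())
model = wrap_for_data_parallel(net)

batch = torch.rand(4, 16, 16, 3)  # (batch_size, height, width, 3), values in 0.0-1.0
out = model(batch)
print(tuple(out.shape))  # (4, 16, 16, 3)
```

DataParallel runs in a single process. PyTorch's own documentation recommends DistributedDataParallel for larger multi-GPU setups.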
Methods Summary
get_sampler() – Obtain a standard random sampler
register_loss(loss) – Registers the selected loss functions with the underlying model's nn.Module
train_batch(inputs, targets, optimizer, meta) – Run a single forward and backward pass through the model for one batch
Methods Documentation
- get_sampler() → type[RandomSampler]
Obtain a standard random sampler
- Returns:
The random sampler class (RandomSampler itself, not an instance)
- Return type:
type[RandomSampler]
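Because the return type is `type[RandomSampler]`, the caller receives the class and must instantiate it with a dataset. The sketch below illustrates that pattern with a toy `TensorDataset`. The standalone `get_sampler` function is a hypothetical stand-in for the method, and the trainer's real call site is not shown in this documentation.

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, Sampler, TensorDataset


def get_sampler() -> type[Sampler]:
    """Hypothetical stand-in: return the sampler *class*, not an instance."""
    return RandomSampler


dataset = TensorDataset(torch.arange(10))  # toy dataset of 10 indexed items
sampler_cls = get_sampler()
sampler = sampler_cls(dataset)  # without replacement: each index once per epoch
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

# Every item appears exactly once, in shuffled order across the batches.
seen = sorted(int(i) for (batch,) in loader for i in batch)
print(seen)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```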
- register_loss(loss: LossCollator) → None
Registers the selected loss functions with the underlying model's nn.Module
- Parameters:
loss (LossCollator) – The configured loss functions
- Return type:
None
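One common reason to attach loss functions to the nn.Module itself is so that DataParallel computes the loss on each replica and only gathers small per-sample loss tensors. The sketch below assumes that motivation; the documentation does not state it. `ModelWithLoss` is hypothetical. A plain nn.MSELoss stands in for the configured LossCollator, whose API is not documented here.

```python
from typing import Optional

import torch
from torch import nn


class ModelWithLoss(nn.Module):
    """Hypothetical wrapper that runs the model and its loss in one forward pass."""

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.model = model
        self.loss_fn: Optional[nn.Module] = None

    def register_loss(self, loss_fn: nn.Module) -> None:
        # Assigning an nn.Module attribute registers it as a child module,
        # so DataParallel replicates it alongside the model.
        self.loss_fn = loss_fn

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        if self.loss_fn is None:
            raise RuntimeError("register_loss() must be called before training")
        per_element = self.loss_fn(self.model(inputs), targets.to(inputs.device))
        # Keep the batch dimension so DataParallel can gather per-sample
        # losses from every replica along dim 0.
        return per_element.flatten(start_dim=1).mean(dim=1)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = ModelWithLoss(nn.Sequential(nn.Linear(3, 3), nn.Sigmoid())).to(device)
net.register_loss(nn.MSELoss(reduction="none"))  # element-wise, no reduction
parallel = nn.DataParallel(net)

images = torch.rand(4, 8, 8, 3)
per_sample = parallel(images, images)
print(tuple(per_sample.shape))  # (4,)
```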
- train_batch(inputs: list[torch.Tensor], targets: list[torch.Tensor], optimizer: Optimizer, meta: BatchMeta) → list[BatchLoss]
Run a single forward and backward pass through the model for one batch
- Parameters:
inputs (list[torch.Tensor]) – List of length (num_inputs) of input image tensors for the model, one batch per input
targets (list[torch.Tensor]) – List of length (num_outputs) of target images, one entry per model output size, each of shape (batch_size, num_inputs, height, width, 3) as float32 in the 0.0 to 1.0 range
optimizer (Optimizer) – The configured Optimizer to use
meta (BatchMeta) – The meta information for the batch
- Returns:
The loss for each input to the model, in order (A, B, …)
- Return type:
list[BatchLoss]
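A rough sketch of a single training step consistent with the parameter descriptions above follows. It is hypothetical and not the trainer's code. `train_batch_sketch` omits the `meta` argument because BatchMeta is not documented here. It assumes a single model output size, so only `targets[0]` is used, and it returns plain floats instead of BatchLoss objects.

```python
import torch
from torch import nn


def train_batch_sketch(
    model: nn.Module,
    loss_fn: nn.Module,
    inputs: list[torch.Tensor],
    targets: list[torch.Tensor],
    optimizer: torch.optim.Optimizer,
) -> list[float]:
    """Hypothetical step: one forward pass per input, one backward pass overall."""
    optimizer.zero_grad()
    side_losses = []
    for side, x in enumerate(inputs):
        pred = model(x)
        # Assumes a single output size: targets[0] has shape
        # (batch_size, num_inputs, height, width, 3); dim 1 selects the input.
        target = targets[0][:, side].to(pred.device)
        side_losses.append(loss_fn(pred, target))
    # Sum the per-input losses so one backward pass updates all shared weights.
    torch.stack(side_losses).sum().backward()
    optimizer.step()
    return [loss.item() for loss in side_losses]  # in input order (A, B, ...)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = nn.Sequential(nn.Linear(3, 8), nn.ReLU(), nn.Linear(8, 3), nn.Sigmoid())
model = nn.DataParallel(net.to(device))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

batch_size, num_inputs = 2, 2
inputs = [torch.rand(batch_size, 16, 16, 3) for _ in range(num_inputs)]
targets = [torch.rand(batch_size, num_inputs, 16, 16, 3)]  # one output size
losses = train_batch_sketch(model, nn.MSELoss(), inputs, targets, optimizer)
print(len(losses))  # 2
```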