TrainerBase
- class plugins.train.trainer.base.TrainerBase(model: ModelBase, config: TrainConfig)
Bases: ABC
A trainer plugin interface. Subclasses must implement the method “train_batch”, which takes the inputs to the model and the target images for the model outputs and returns the loss for each side. Subclasses must also implement “get_sampler”.
- Parameters:
model (ModelBase) – The model plugin
config (TrainConfig) – The Training Configuration options
Methods Summary
get_sampler() – Override to set the sampler that the torch DataLoader should use
register_loss(loss) – Registers the selected loss functions to the underlying model nn.Module
train_batch(inputs, targets, optimizer, meta) – Override to run a single forward and backward pass through the model for a single batch
Methods Documentation
- abstractmethod get_sampler() type[RandomSampler | DistributedSampler]
Override to return the sampler class that the torch DataLoader should use
- Return type:
The sampler class (not an instance) that the torch DataLoader should use
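Because get_sampler returns a sampler class rather than an instance, the caller is left to construct it with the dataset. The sketch below illustrates that pattern with plain-Python stand-ins for torch.utils.data.RandomSampler and DistributedSampler so it runs without torch. The stand-in class names and the `distributed` flag are assumptions: the documented method takes no arguments, so a real trainer would read this choice from its own state. The real DistributedSampler also shuffles and pads by default, which this stand-in omits.

```python
import random
from typing import Iterator, Sequence


class RandomSamplerStandIn:
    """Stand-in for torch's RandomSampler: yields every index in shuffled order."""

    def __init__(self, data_source: Sequence) -> None:
        self.data_source = data_source

    def __iter__(self) -> Iterator[int]:
        indices = list(range(len(self.data_source)))
        random.shuffle(indices)
        return iter(indices)


class DistributedSamplerStandIn:
    """Stand-in for torch's DistributedSampler: yields only this rank's shard.

    Simplified: no shuffling or padding, unlike the real sampler's defaults.
    """

    def __init__(self, data_source: Sequence, num_replicas: int, rank: int) -> None:
        self.data_source = data_source
        self.num_replicas = num_replicas
        self.rank = rank

    def __iter__(self) -> Iterator[int]:
        # Every num_replicas-th index, starting at this process's rank
        return iter(range(self.rank, len(self.data_source), self.num_replicas))


def get_sampler(distributed: bool) -> type:
    """Hypothetical get_sampler: return a sampler class, not an instance."""
    return DistributedSamplerStandIn if distributed else RandomSamplerStandIn


# The caller, not the trainer, instantiates the returned class.
data = list(range(10))
sampler_cls = get_sampler(distributed=True)
print(list(sampler_cls(data, num_replicas=2, rank=1)))  # [1, 3, 5, 7, 9]
```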
- register_loss(loss: LossCollator) None
Registers the selected loss functions to the underlying model nn.Module
- Parameters:
loss (LossCollator) – The configured loss functions
- Return type:
None
- abstractmethod train_batch(inputs: list[torch.Tensor], targets: list[torch.Tensor], optimizer: Optimizer, meta: BatchMeta) list[BatchLoss]
Override to run a single forward and backward pass through the model for a single batch
- Parameters:
inputs (list[torch.Tensor]) – The batch of input image tensors to the model, one tensor per model input (length num_inputs)
targets (list[torch.Tensor]) – List of length num_outputs holding the target images at each model output size. Each tensor has shape (batch_size, num_inputs, height, width, 3) and is float32 in the range 0.0 to 1.0
optimizer (Optimizer) – The configured Optimizer to use
meta (BatchMeta) – The meta information for the batch
- Return type:
The loss for each input to the model in order (A, B, …)
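A single training step follows the usual sequence of forward pass, loss, backward pass and optimizer step. The sketch below replaces torch autograd and the Optimizer with plain Python so it runs anywhere: each side is a hypothetical one-weight linear model trained with mean squared error and plain SGD. It also simplifies the targets to one list per side rather than one tensor per output size, but it keeps the documented return contract of one loss per input, in order. The function name and its arguments are illustrative only.

```python
from typing import List


def train_batch_sketch(weights: List[float],
                       inputs: List[List[float]],
                       targets: List[List[float]],
                       lr: float = 0.01) -> List[float]:
    """Hypothetical train_batch: one forward and backward pass per side.

    Side i is modelled as y = weights[i] * x. Weights are updated in place
    and the pre-update loss of each side is returned in input order.
    """
    losses = []
    for side, (xs, ts) in enumerate(zip(inputs, targets)):
        w = weights[side]
        n = len(xs)
        # Forward pass: predictions and mean squared error for this side
        preds = [w * x for x in xs]
        loss = sum((p - t) ** 2 for p, t in zip(preds, ts)) / n
        # Backward pass: d(loss)/dw = mean(2 * (pred - target) * x)
        grad = sum(2 * (p - t) * x for p, t, x in zip(preds, ts, xs)) / n
        # Optimizer step: plain SGD update of this side's weight
        weights[side] = w - lr * grad
        losses.append(loss)
    return losses  # one loss per input, in order (A, B, ...)


weights = [0.0, 0.0]
batch_inputs = [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]   # side A, side B
batch_targets = [[2.0, 4.0, 6.0], [3.0, 6.0, 9.0]]
first = train_batch_sketch(weights, batch_inputs, batch_targets)
second = train_batch_sketch(weights, batch_inputs, batch_targets)
print(first, second)  # each side's loss drops between the two steps
```

Running the step twice on the same batch with a small learning rate should lower each side's loss, which is a quick sanity check for any train_batch implementation.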
Attributes
- batch_size
The batch size for each training iteration through the model.
- config
Training configuration options
- loss_func: LossCollator
The selected loss functions for the model
- model
The model plugin to train the batch on
- sampler
The data sampler that the data loader should use