Collate
- class lib.training.data.collate.Collate(input_size: int, output_sizes: tuple[int, ...], color_order: T.Literal['bgr', 'rgb'], config: TrainConfig, landmarks: LandmarkMatcher | None)
Bases: object
Collation function for processing a batch of samples into input and output tensors, applying augmentation
- Parameters:
input_size (int) – The pixel size of the model input
output_sizes (tuple[int, ...]) – The pixel sizes of the model output
color_order (T.Literal['bgr', 'rgb']) – The color order that the model expects
config (TrainConfig) – The training configuration for the model
landmarks (LandmarkMatcher | None) – The landmark matching object for the A and B sides of the model if warp_to_landmarks is enabled, otherwise None
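The color_order parameter determines which channel order the model receives. The following minimal sketch shows what that conversion can look like. It assumes the loaded images are in BGR order (OpenCV's default) with any mask data in channels beyond the first three. The helper name to_model_color_order is hypothetical and is not part of the Collate API.

```python
import numpy as np


def to_model_color_order(images: np.ndarray, color_order: str) -> np.ndarray:
    """Hypothetical helper: reorder BGR color channels to the order the model expects.

    Assumes the first three channels are BGR color and any further channels are masks,
    which are left untouched.
    """
    if color_order not in ("bgr", "rgb"):
        raise ValueError(f"Unsupported color order: {color_order!r}")
    if color_order == "bgr":
        return images  # already in the expected order
    out = images.copy()
    out[..., :3] = images[..., 2::-1]  # swap B and R, keep G and any mask channels
    return out


# One pixel: B=10, G=20, R=30, plus a mask channel of 255
pixel = np.array([[[10, 20, 30, 255]]], dtype=np.uint8)
print(to_model_color_order(pixel, "rgb")[0, 0].tolist())  # [30, 20, 10, 255]
```

Copying before reordering keeps the caller's array unchanged, so the same loaded sample can safely be reused for a BGR model.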
Methods Summary
__call__(data) – Prepare the loaded samples for feeding the model, creating targets and applying augmentation
Methods Documentation
- __call__(data: list[tuple[tuple[npt.NDArray[np.uint8], int], ...]]) → tuple[list[torch.Tensor], list[torch.Tensor], BatchMeta]
Prepare the loaded samples for feeding the model, creating targets and applying augmentation
- Parameters:
data (list[tuple[tuple[npt.NDArray[np.uint8], int], ...]]) – Batch of data tuples, holding for each loader the loaded stacked image and masks in the first position and the image file index for that item in the second
- Returns:
feed – List of length num_inputs of input tensors for the model, each of shape (batch_size, H, W, C)
targets – List of length num_outputs of target images, one for each model output size, each of shape (batch_size, num_inputs, height, width, 3) as float32 in the range 0.0 to 1.0
meta – The meta information for the batch
- Return type:
tuple[list[torch.Tensor], list[torch.Tensor], BatchMeta]
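The input and output layout of __call__ can be illustrated with a minimal NumPy sketch. It is not the Collate implementation: it applies no augmentation, resizes with nearest-neighbour sampling, returns NumPy arrays rather than torch.Tensor objects, and substitutes a plain list of file indices for BatchMeta. The names sketch_collate and _nearest_resize are hypothetical, and the sketch assumes square images whose first three channels are color, with any masks stacked in the remaining channels.

```python
import numpy as np
import numpy.typing as npt

# One loader's entry for one batch item: (stacked image + masks, image file index)
Sample = tuple[npt.NDArray[np.uint8], int]


def _nearest_resize(images: np.ndarray, size: int) -> np.ndarray:
    """Nearest-neighbour resize of square images laid out as (..., H, W, C)."""
    src = images.shape[-2]
    idx = np.arange(size) * src // size  # source row/column for each output pixel
    return images[..., idx, :, :][..., :, idx, :]


def sketch_collate(data: list[tuple[Sample, ...]],
                   input_size: int,
                   output_sizes: tuple[int, ...]):
    """Hypothetical illustration of the shapes produced by Collate.__call__."""
    num_inputs = len(data[0])  # number of loaders, e.g. sides A and B
    # Per loader: stack color channels across the batch -> (batch_size, H, W, 3) in 0.0-1.0
    stacked = [np.stack([item[i][0][..., :3] for item in data]).astype(np.float32) / 255.0
               for i in range(num_inputs)]
    feed = [_nearest_resize(imgs, input_size) for imgs in stacked]
    # Combine loaders -> (batch_size, num_inputs, H, W, 3), then one target per output size
    combined = np.stack(stacked, axis=1)
    targets = [_nearest_resize(combined, size) for size in output_sizes]
    # Stand-in for BatchMeta: the file index of each batch item, per loader
    meta = [[item[i][1] for item in data] for i in range(num_inputs)]
    return feed, targets, meta


rng = np.random.default_rng(0)
# 8 batch items from 2 loaders, each a 64x64 image with one extra mask channel
batch = [tuple((rng.integers(0, 256, (64, 64, 4), dtype=np.uint8), idx) for _ in range(2))
         for idx in range(8)]
feed, targets, meta = sketch_collate(batch, input_size=64, output_sizes=(32, 64))
print([f.shape for f in feed])     # [(8, 64, 64, 3), (8, 64, 64, 3)]
print([t.shape for t in targets])  # [(8, 2, 32, 32, 3), (8, 2, 64, 64, 3)]
```

Stacking the loaders along axis 1 is what gives each target its num_inputs dimension, while the feed keeps one separate batch per loader.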