Collate

class lib.training.data.collate.Collate(input_size: int, output_sizes: tuple[int, ...], color_order: T.Literal['bgr', 'rgb'], config: TrainConfig, landmarks: LandmarkMatcher | None)

Bases: object

Collation function that processes a batch of samples into model input and output tensors, applying augmentation

Parameters:
  • input_size (int) – The pixel size of the model input

  • output_sizes (tuple[int, ...]) – The pixel sizes of the model output

  • color_order (T.Literal['bgr', 'rgb']) – The color order that the model expects

  • config (TrainConfig) – The training configuration for the model

  • landmarks (LandmarkMatcher | None) – The landmark matching object for the A and B sides of the model if warp_to_landmarks is enabled, otherwise None
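The color_order parameter implies that loaded images may need their channels reordered before they reach the model. The sketch below shows one way to do this with NumPy. It assumes images are loaded in BGR order (OpenCV's default), that channels sit on the last axis, and that any mask channels follow the three color channels. The helper name to_model_color_order is hypothetical and is not part of this class's API.

```python
import numpy as np


def to_model_color_order(stacked: np.ndarray, color_order: str) -> np.ndarray:
    """Return a BGR-loaded image stack in the channel order the model expects.

    Hypothetical helper. Assumes channels are on the last axis, the first three
    are B, G, R, and any mask channels come after them.
    """
    if color_order == "bgr":
        return stacked
    if color_order != "rgb":
        raise ValueError(f"Unsupported color order: {color_order!r}")
    # Reverse only the three color channels (B, G, R -> R, G, B) and keep
    # any trailing mask channels in place
    return np.concatenate([stacked[..., 2::-1], stacked[..., 3:]], axis=-1)


pixel = np.array([[[10, 20, 30, 255]]], dtype=np.uint8)  # one BGR pixel + mask
print(to_model_color_order(pixel, "rgb")[0, 0].tolist())  # [30, 20, 10, 255]
```

Reversing only the leading three channels, rather than the whole last axis, keeps stacked mask channels from being shuffled into the color positions.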

Methods Summary

__call__(data)

Prepare the loaded samples for feeding the model, creating targets and applying augmentation

Methods Documentation

__call__(data: list[tuple[tuple[npt.NDArray[np.uint8], int], ...]]) → tuple[list[torch.Tensor], list[torch.Tensor], BatchMeta]

Prepare the loaded samples for feeding the model, creating targets and applying augmentation

Parameters:

data (list[tuple[tuple[npt.NDArray[np.uint8], int], ...]]) – Batch of data tuples with the loaded stacked image and masks from each loader in the first position and the image file index of each item in the batch in the second position
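One reading of the data type hint is that each batch item is a tuple holding one (stacked_image, file_index) pair per loader, for example side A followed by side B. Under that assumption, the hypothetical helper below splits a batch into one stacked uint8 array and one list of file indices per loader. It only illustrates the nesting of the structure and is not this class's implementation.

```python
import numpy as np


def unpack_batch(data):
    """Split a collate batch into per-loader image stacks and file indices.

    Hypothetical helper. Assumes each batch item is a tuple holding one
    (stacked_image, file_index) pair per loader, e.g. side A then side B.
    """
    num_loaders = len(data[0])
    images, indices = [], []
    for loader in range(num_loaders):
        # Stack this loader's arrays along a new leading batch axis
        images.append(np.stack([item[loader][0] for item in data]))
        indices.append([item[loader][1] for item in data])
    return images, indices


# Two batch items from two loaders; each array is an 8x8 BGR image + 1 mask channel
batch = [
    ((np.zeros((8, 8, 4), dtype=np.uint8), 0), (np.zeros((8, 8, 4), dtype=np.uint8), 100)),
    ((np.zeros((8, 8, 4), dtype=np.uint8), 1), (np.zeros((8, 8, 4), dtype=np.uint8), 101)),
]
images, indices = unpack_batch(batch)
print([img.shape for img in images])  # [(2, 8, 8, 4), (2, 8, 8, 4)]
print(indices)  # [[0, 1], [100, 101]]
```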

Returns:

  • feed – List of num_inputs input tensors for the model, each of shape (batch_size, H, W, C)

  • targets – List of num_outputs target image tensors, one per model output size, each of shape (batch_size, num_inputs, height, width, 3) as float32 in the 0.0 to 1.0 range

  • meta – The meta information for the batch

Return type:

tuple[list[torch.Tensor], list[torch.Tensor], BatchMeta]
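The return shapes can be illustrated with a small NumPy sketch. It assumes the loaded images have already been stacked into a single uint8 array of shape (batch_size, num_inputs, H, W, C), and it uses a simple nearest-neighbour resize as a stand-in for whatever interpolation the real class applies. It omits augmentation and landmark warping entirely. Because the documentation does not state a dtype for feed, the sketch leaves feed as uint8 and only scales targets to float32 in the 0.0 to 1.0 range. Both helper names are hypothetical.

```python
import numpy as np


def resize_nearest(images: np.ndarray, size: int) -> np.ndarray:
    """Nearest-neighbour resize of (..., H, W, C) arrays to (..., size, size, C)."""
    rows = (np.arange(size) * images.shape[-3]) // size
    cols = (np.arange(size) * images.shape[-2]) // size
    # Adjacent advanced indices replace the H and W axes with (size, size)
    return images[..., rows[:, None], cols[None, :], :]


def build_feed_and_targets(batch, input_size, output_sizes):
    """batch: uint8 array of shape (batch_size, num_inputs, H, W, C).

    Hypothetical sketch with no augmentation or landmark warping.
    """
    num_inputs = batch.shape[1]
    # One (batch_size, input_size, input_size, C) array per model input
    feed = [resize_nearest(batch[:, i], input_size) for i in range(num_inputs)]
    # One float32 target per output size, color channels only, scaled to 0.0-1.0
    color = batch[..., :3]
    targets = [resize_nearest(color, size).astype(np.float32) / 255.0
               for size in output_sizes]
    return feed, targets


batch = np.full((2, 2, 64, 64, 4), 255, dtype=np.uint8)  # 2 items, 2 inputs, BGR + mask
feed, targets = build_feed_and_targets(batch, input_size=32, output_sizes=(16, 32))
print([f.shape for f in feed])     # [(2, 32, 32, 4), (2, 32, 32, 4)]
print([t.shape for t in targets])  # [(2, 2, 16, 16, 3), (2, 2, 32, 32, 3)]
print(targets[0].dtype, targets[0].max())  # float32 1.0
```

The actual method returns torch.Tensor objects. Arrays like these could be converted with torch.from_numpy; NumPy is used here only to keep the sketch dependency-light.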