ModifiedResNet

class lib.model.networks.clip.ModifiedResNet(input_resolution: int, width: int, layer_config: tuple[int, int, int, int], output_dim: int, heads: int, name='ModifiedResNet')

Bases: object

A ResNet class that is similar to torchvision’s but contains the following changes:

  • There are now 3 “stem” convolutions as opposed to 1, with an average pool instead of a max pool.

  • Strided convolutions are anti-aliased: an average pool is prepended to every convolution with stride > 1.

  • The final pooling layer is a QKV attention pool instead of an average pool.
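The anti-aliasing change can be illustrated without any deep learning library. The following is a minimal pure-Python sketch, not the class's actual implementation; the helper names `avg_pool2d` and `strided_subsample` are hypothetical. It shows why averaging before downsampling helps: plain striding keeps one sample per window and can miss a high-frequency pattern entirely, while an average pool summarises every value in the window.

```python
def avg_pool2d(image, size):
    """Average each non-overlapping size x size window of a 2D list."""
    rows, cols = len(image) // size, len(image[0]) // size
    pooled = []
    for r in range(rows):
        row = []
        for c in range(cols):
            window = [image[r * size + i][c * size + j]
                      for i in range(size) for j in range(size)]
            row.append(sum(window) / len(window))
        pooled.append(row)
    return pooled


def strided_subsample(image, stride):
    """Keep every stride-th row and column, as a strided 1x1 convolution would."""
    return [row[::stride] for row in image[::stride]]


# A 4x4 checkerboard is the highest-frequency pattern the grid can hold.
checkerboard = [[(r + c) % 2 for c in range(4)] for r in range(4)]

aliased = strided_subsample(checkerboard, 2)   # only ever samples the 0 cells
smoothed = avg_pool2d(checkerboard, 2)         # every window averages to 0.5
```

In this toy case the strided result is all zeros, so the pattern vanishes, whereas the pooled result preserves its mean intensity.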

Parameters:
  • input_resolution (int) – The input resolution of the model, e.g. 224.

  • width (int) – The width of the model, e.g. 64.

  • layer_config (tuple[int, int, int, int]) – The number of Bottleneck blocks in each of the four layers.

  • output_dim (int) – The output dimension of the model.

  • heads (int) – The number of heads for the QKV attention.

  • name (str) – The name of the model. Default is “ModifiedResNet”.
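The parameters interact through the shape of the feature map that reaches the attention pool. The sketch below assumes this class follows the reference CLIP layout: a total downsampling factor of 32 (stem plus three stride-2 layers) and Bottleneck blocks with an expansion of 4, which gives `width * 32` output channels. Neither assumption is confirmed by this page, and `attention_pool_shapes` is a hypothetical helper, not part of the library.

```python
def attention_pool_shapes(input_resolution, width):
    """Estimate the tensor shape seen by the final attention pool.

    Assumes the reference CLIP layout: 32x total downsampling and
    width * 32 channels after the last Bottleneck layer.
    """
    downsample = 32                       # assumption: stem (4x) * layers 2-4 (2x each)
    spatial = input_resolution // downsample
    channels = width * 32                 # assumption: width * 8 planes * expansion 4
    tokens = spatial * spatial + 1        # one token per position plus a pooled query token
    return spatial, channels, tokens


spatial, channels, tokens = attention_pool_shapes(224, 64)
print(f"feature map: {spatial}x{spatial}x{channels}, attention tokens: {tokens}")
```

Under these assumptions, `channels` should be divisible by `heads` so the attention can split the embedding evenly across heads.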

Methods Summary

__call__()

Implements the forward pass of the ModifiedResNet model.

Methods Documentation

__call__() Model

Implements the forward pass of the ModifiedResNet model.

Returns:

The modified ResNet model.

Return type:

keras.models.Model
