ModifiedResNet
- class lib.model.networks.clip.ModifiedResNet(input_resolution: int, width: int, layer_config: tuple[int, int, int, int], output_dim: int, heads: int, name='ModifiedResNet')
Bases: object
A ResNet class that is similar to torchvision’s but contains the following changes:
There are three “stem” convolutions instead of one, followed by an average pool instead of a max pool.
Strided convolutions are anti-aliased: an average pool is prepended to every convolution with stride > 1.
The final pooling layer uses QKV attention instead of an average pool.
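The attention pooling step can be hard to picture from the description alone. The NumPy sketch below assumes the layout of OpenAI's original CLIP `AttentionPool2d`: the spatial positions are flattened into tokens, a mean token is prepended, a learned positional embedding is added, and only the mean token acts as the query. The function `attention_pool_2d` and its weight arguments are hypothetical names, biases are omitted, and this is not the Keras code used by this class.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along ``axis``."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention_pool_2d(features, pos_embed, w_q, w_k, w_v, w_out, heads):
    """Hypothetical CLIP-style attention pool for one channels-last feature map.

    features: (H, W, C); pos_embed: (H*W + 1, C); w_q/w_k/w_v: (C, C);
    w_out: (C, output_dim). Biases are omitted for brevity (assumption).
    """
    h, w, c = features.shape
    assert c % heads == 0, "channels must be divisible by heads"
    head_dim = c // heads
    tokens = features.reshape(h * w, c)                            # (HW, C)
    # Prepend the mean token; it is the only query, so its output is the embedding
    tokens = np.concatenate([tokens.mean(axis=0, keepdims=True), tokens], axis=0)
    tokens = tokens + pos_embed                                     # (HW + 1, C)
    q = (tokens[:1] @ w_q).reshape(1, heads, head_dim).transpose(1, 0, 2)
    k = (tokens @ w_k).reshape(-1, heads, head_dim).transpose(1, 0, 2)
    v = (tokens @ w_v).reshape(-1, heads, head_dim).transpose(1, 0, 2)
    # Scaled dot-product attention per head: (heads, 1, HW + 1)
    attn = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(head_dim))
    out = (attn @ v).transpose(1, 0, 2).reshape(1, c)               # merge heads
    return (out @ w_out)[0]                                         # (output_dim,)


rng = np.random.default_rng(0)
h = w = 7                      # e.g. input_resolution 224 // 32
c, heads, out_dim = 64, 4, 32
features = rng.standard_normal((h, w, c))
pos_embed = rng.standard_normal((h * w + 1, c)) / np.sqrt(c)
w_q, w_k, w_v = (rng.standard_normal((c, c)) / np.sqrt(c) for _ in range(3))
w_out = rng.standard_normal((c, out_dim)) / np.sqrt(c)
emb = attention_pool_2d(features, pos_embed, w_q, w_k, w_v, w_out, heads)
print(emb.shape)  # (32,)
```

Compared with a global average pool, the query token can weight spatial positions unevenly, and the output projection maps the pooled features directly to the requested output dimension.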
- Parameters:
input_resolution (int) – The input resolution of the model. Default is 224.
width (int) – The width of the model. Default is 64.
layer_config (tuple[int, int, int, int]) – A tuple containing the number of Bottleneck blocks in each of the four layers.
output_dim (int) – The output dimension of the model.
heads (int) – The number of heads for the QKV attention.
name (str) – The name of the model. Default is “ModifiedResNet”.
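To see how `input_resolution`, `width` and `layer_config` interact, the pure-Python sketch below traces the feature-map size through the network. It assumes this port follows the layout of OpenAI's CLIP ModifiedResNet: a stride-2 first stem convolution, a 2×2 stem average pool, Bottleneck blocks with a channel expansion of 4, and layers 2 to 4 downsampling by 2. The helper `modified_resnet_shapes` is hypothetical and not part of this module.

```python
def modified_resnet_shapes(input_resolution, width, layer_config):
    """Return (stage, spatial_size, channels) tuples, assuming CLIP's layout."""
    expansion = 4                         # Bottleneck output = planes * 4 (assumed)
    shapes = []
    size = input_resolution // 2          # first stem conv has stride 2
    shapes.append(("stem_conv1", size, width // 2))
    shapes.append(("stem_conv2", size, width // 2))
    shapes.append(("stem_conv3", size, width))
    size //= 2                            # stem average pool (replaces max pool)
    shapes.append(("stem_avgpool", size, width))
    for idx, blocks in enumerate(layer_config):
        # layer1 keeps the resolution; layers 2-4 halve it with an
        # anti-aliased downsample (average pool before a stride-1 conv)
        stride = 1 if idx == 0 else 2
        size //= stride
        channels = width * (2 ** idx) * expansion
        shapes.append((f"layer{idx + 1} x{blocks}", size, channels))
    return shapes


for row in modified_resnet_shapes(224, 64, (3, 4, 6, 3)):
    print(row)
# last line printed: ('layer4 x3', 7, 2048)
```

Under these assumptions the attention pool receives an `input_resolution // 32` square map with `width * 32` channels, so `heads` should divide `width * 32` evenly.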
Methods Summary
__call__() – Implements the forward pass of the ModifiedResNet model.
Methods Documentation
- __call__() → Model
Implements the forward pass of the ModifiedResNet model.
- Returns:
The ModifiedResNet model.
- Return type:
keras.models.Model