Overview

Focalnet Architecture

Goal: Explore how to use a Focalnet to extract patch features in place of our standard feature extraction conv stack.

Baseline

Here is our current standard patch feature extraction network. It has three average pooling layers, and the final spatial resolution is 1/8 of the original size before being flattened. There are very few parameters in here in total - a tiny feature extractor.
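For reference, here is a hypothetical stand-in for that baseline - the real layer counts and filter sizes differ, but the shape of the pipeline is the same: a few small convs interleaved with three 2x average pools (so 1/8 resolution on the 64x64 patches used later), then a flatten.

from tensorflow.keras import layers, Sequential

# Hypothetical sketch of the baseline extractor described above; layer sizes are
# illustrative, not the actual values from our codebase.
baseline_extractor = Sequential([
    layers.Conv2D(16, 3, padding="same", activation="relu"),
    layers.AveragePooling2D(2),
    layers.Conv2D(32, 3, padding="same", activation="relu"),
    layers.AveragePooling2D(2),
    layers.Conv2D(64, 3, padding="same", activation="relu"),
    layers.AveragePooling2D(2),   # 1/8 of the original spatial resolution
    layers.Flatten(),
])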

We want to explore replacing our normal patch feature extractor with a Focalnet (https://arxiv.org/pdf/2203.11926.pdf).

The Focalnet is composed of a stack of BasicLayers. Each BasicLayer is made of several FocalBlocks, and each FocalBlock contains a FocalModulation layer. In other words:

PatchEmbed
BasicLayer x 4
    FocalNetBlock x 2
        FocalModulation
    PatchEmbed (downsample)
GlobalAvgPool


Since we are replacing our patch feature extractor with the Focalnet, it makes sense to compare how they transform the data. A focalnet with 4 BasicLayers will do 3 downsamples just like the regular patch feature extraction, but due to the large downsampling we do in the first patch embedding, the final feature maps are much smaller at 2x2. That said, there is way more computation in each BasicLayer than there is in the whole standard feature extractor combined.
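A quick back-of-the-envelope comparison of the two pipelines, assuming the 64x64 inputs and patch size of 4 discussed below:

image_size = 64

# Standard extractor: three 2x average pools -> 1/8 resolution before flattening.
standard_resolution = image_size // 2 ** 3        # 8 -> 8x8 feature maps

# Focalnet: /4 initial patch embed, then a /2 downsample after each of the
# first 3 BasicLayers -> 1/32 resolution overall.
focalnet_resolution = image_size // 4 // 2 ** 3   # 2 -> 2x2 feature maps
print(standard_resolution, focalnet_resolution)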

One major difference is that the regular patch feature extractor uses some form of pooling to downsample, which has no learnable parameters. The Focalnet uses a Patch Embedding layer that does have learnable parameters.
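A small illustrative example of that difference (the channel counts here are arbitrary): average pooling has nothing to learn, while a kernel=stride patch-embedding conv carries a full weight tensor.

import tensorflow as tf
from tensorflow.keras import layers

x = tf.zeros([1, 64, 64, 32])
pool = layers.AveragePooling2D(pool_size=2)           # fixed averaging
embed = layers.Conv2D(64, kernel_size=2, strides=2)   # learned downsample

pool(x), embed(x)             # build the layers so the weights exist
print(pool.count_params())    # 0
print(embed.count_params())   # 2*2*32*64 + 64 = 8256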

BasicLayer

After splitting the patches up into subpatches (whose size is determined by the focalnet_patch_size parameter) and doing the initial patch embedding, we have the input to the first BasicLayer. We pass this through a few FocalBlocks and a downsample with another PatchEmbedding layer.

for block in self.blocks:
    x = block(x, height, width, channels)  # shape is unchanged through the blocks
x, height_o, width_o, channels_o = self.downsample(x)  # halve H and W, double channels

FocalBlock

The FocalBlock is essentially a FocalModulation layer with a skip connection, a few LayerNorms, and an MLP.

shortcut = x
x = self.norm1(x)
x = self.modulation(x)  # focal modulation does the spatial mixing
x = shortcut + x        # residual connection around the modulation
x = self.norm2(x)
x = x + self.mlp(x)     # residual connection around the MLP

FocalModulation

Where all the magic occurs.

See this diagram yoinked from the paper showing the difference between focal modulation and self-attention.

Focalnet Parameters

At the highest level, there are a few things that determine the architecture of the focalnet:

depth - a list that determines the number of BasicLayers and their depths. The tiny Focalnet default is [2, 2, 6, 2]. This means there are 4 BasicLayers with depths of 2, 2, 6, and 2 respectively.

embed_dim - the starting number of feature maps, which doubles after each BasicLayer

patch_size - the factor by which the initial patch embedding downsamples the input image. Since the input image size is divided by the patch size, bigger values result in more downsampling.
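To make these concrete, here are the channel dimensions that fall out of the tiny defaults; the values follow directly from the parameters above.

depths = [2, 2, 6, 2]   # 4 BasicLayers with this many FocalBlocks each
embed_dim = 96          # channels after the first patch embedding
patch_size = 4          # downsample factor of the first patch embedding

# The channel dimension doubles after each BasicLayer's downsample.
dims = [embed_dim * 2 ** i for i in range(len(depths))]
print(dims)  # [96, 192, 384, 768]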

Embedding Dimension

Patch embedding dimension controls the number of feature maps we ultimately create. As with a normal stack of convolutions, the spatial resolution gets smaller and the channel dimension / number of filters gets larger in deeper layers. However, instead of using max or average pooling to downsample, we use a Patch Embedding layer.

The following image shows how the embedding dimension, aka the channel dimension, of our feature maps doubles after each BasicLayer.

Patch Size

Now let's look at the patch size. The following image shows how the patch size determines the spatial resolution for the whole Focalnet pipeline. Since we start off with 64x64 images and default to a patch size of 4, we have 16x16 feature maps going into the first BasicLayer. Since we downsample after every BasicLayer except the last and a standard FocalNet has 4 BasicLayers, we end up with tiny 2x2 feature maps.

This might seem bad, but notice the last step. We apply a global average pool over the spatial dimensions, erasing them entirely - all of the features now reside in the channel dimension. This vector makes up the patch features passed to the dense layers we use to predict the joint coordinates.
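A minimal sketch of that last step, assuming the 2x2 feature maps with 768 channels (96 * 2^3) from the walkthrough above:

import tensorflow as tf
from tensorflow.keras import layers

features = tf.random.normal([1, 2, 2, 768])          # output of the last BasicLayer
pooled = layers.GlobalAveragePooling2D()(features)   # average over the 2x2 grid
print(pooled.shape)  # (1, 768): the spatial dims are gone, only channels remain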

Depth

Deeper/thinner focalnets ([3, 3, 16, 3] instead of [2, 2, 6, 2], with embedding dim 64 instead of 96) have slightly better performance, fewer parameters, and fewer flops, but a higher time cost due to the increased number of sequential blocks.

In a single BasicLayer, the feature maps go through multiple FocalModulation blocks without changing shape. It's only at the last step of the BasicLayer that the feature maps get downsampled. The downsample is done via a PatchEmbedding (fancy conv) that also doubles the embedding dimension, taking the feature maps from (Ph, Pw, d) to (Ph/2, Pw/2, 2*d).

Assume a downsample cuts the spatial resolution in half and doubles the embedding dimension. The total number of values in the feature maps gets cut in half after every downsample, since the output feature map (Ph/2 * Pw/2 * 2d) is half the size of the input (Ph * Pw * d).
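Worked out with example numbers (a 16x16x96 feature map entering a downsample):

Ph, Pw, d = 16, 16, 96                       # feature-map shape entering a downsample
before = Ph * Pw * d                         # 24576 values
after = (Ph // 2) * (Pw // 2) * (2 * d)      # 12288 values
print(after / before)                        # 0.5 -> exactly half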

Let's look at what happens to the resolution of the feature maps with a Focalnet with depths=[2, 2, 6, 2]. We have 4 BasicLayers, and because we downsample after every BasicLayer besides the last, we have 3 downsamples. The final feature maps will be 1/8th of the original spatial resolution and 8x deeper in the channel dimension. This means that, for equal depth, each subsequent BasicLayer works on feature maps half the size of the one before it.

PatchEmbed

This is where the original patch embedding comes in. It is the first chance to “tokenize” the image into local regions: groups of pixels in the input image are discretized into nonoverlapping patches in this first layer. This is useful because by tokenizing the regions, you are creating two levels of context - global and local. The patch embedding is really just a convolution with a kernel size and stride chosen so that there is less overlap than a standard 3x3 conv with a stride of 1.

If you are not doing anything fancy and using the conv embedding, the PatchEmbed is just doing this:

self.proj = layers.Conv2D(embed_dim, kernel_size=patch_size, strides=patch_size)

Notice how the stride and kernel size are the same. That means that when we convolve the kernel with one part of the image, the next convolution sees a completely different set of pixels. Compare this with something like a 3x3 conv with a stride of 1, which has overlap between receptive fields. Since we move the kernel completely past the pixels covered by the previous convolution, each value in the output feature map represents a distinct part of the input image.
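A quick sketch of the non-overlapping tiling in action, using the assumed 64x64 input, patch size 4, and embedding dim 96:

import tensorflow as tf
from tensorflow.keras import layers

patch_size, embed_dim = 4, 96
proj = layers.Conv2D(embed_dim, kernel_size=patch_size, strides=patch_size)

x = tf.zeros([1, 64, 64, 3])
print(proj(x).shape)  # (1, 16, 16, 96): each of the 16x16 outputs covers its own 4x4 block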

Non-overlapping patches are one design choice; there is also a way to make the patches overlap, and the authors point out that this improves performance slightly, though it is not the default.

With a patch size of, say, 4, the first 4x4 block of pixels in the top left is collapsed into a single spatial location in the output feature map.

FocalModulation

Here is what the FocalModulation is actually doing.

Similar to how a vision transformer takes in flattened image patches and projects them into queries, keys, and values, a Focalnet takes in an input and projects it into queries, context, and gates. The main difference is that in self attention each query gets multiplied by every key, making the attention calculation quadratic in the number of tokens. In the focalnet, the context is run through a series of convs at various receptive fields and gated before getting multiplied element-wise with the queries.

In other words, as we pass the context through the FocalLayers and do the projection from feature maps to queries, context, and gates, we are saying that the features have the information to tell you, for a given patch, which other patches affect it and which other patches it affects, as well as how.
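To make the projection and gating concrete, here is a minimal channels-last Keras sketch of a focal modulation layer, loosely following the paper's reference design. The layer names (f, h, proj) and the defaults (focal_level=2, focal_window=3, focal_factor=2) are assumptions borrowed from the paper, not necessarily what our codebase uses.

import tensorflow as tf
from tensorflow.keras import layers

class FocalModulation(layers.Layer):
    # Sketch only: expects channels-last inputs of shape (B, H, W, dim).
    def __init__(self, dim, focal_level=2, focal_window=3, focal_factor=2, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.focal_level = focal_level
        # One linear projection produces queries, context, and per-level gates.
        self.f = layers.Dense(2 * dim + (focal_level + 1))
        # Depthwise convs with growing kernels build the hierarchical context.
        self.focal_convs = [
            layers.DepthwiseConv2D(
                kernel_size=focal_factor * k + focal_window,
                padding="same", use_bias=False, activation="gelu")
            for k in range(focal_level)
        ]
        self.h = layers.Conv2D(dim, kernel_size=1)  # turns the summed context into a modulator
        self.proj = layers.Dense(dim)               # output projection

    def call(self, x):
        # Split the projection into queries, context, and gates.
        q, ctx, gates = tf.split(
            self.f(x), [self.dim, self.dim, self.focal_level + 1], axis=-1)
        ctx_all = 0.0
        for level, conv in enumerate(self.focal_convs):
            ctx = conv(ctx)                               # widen the receptive field
            ctx_all += ctx * gates[..., level:level + 1]  # gate this level's contribution
        # Global context: average over space, gated by the final gate.
        ctx_global = tf.nn.gelu(tf.reduce_mean(ctx, axis=[1, 2], keepdims=True))
        ctx_all += ctx_global * gates[..., self.focal_level:]
        # Element-wise modulation of the queries - no quadratic attention matrix.
        return self.proj(q * self.h(ctx_all))

# Example: a 16x16 feature map with 96 channels in, same shape out.
y = FocalModulation(dim=96)(tf.random.normal([1, 16, 16, 96]))
print(y.shape)  # (1, 16, 16, 96)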

MLP

Why use an MLP, aka dense layers? It’s the most straightforward thing to try. Before you start trying to get fancy with convolutions and other things with various inductive biases, just wire everything together and see what it can learn.
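For reference, a minimal sketch of such a block MLP, assuming the common expand-by-4 pattern; the actual ratio and activation in our code may differ.

from tensorflow.keras import layers, Sequential

def make_mlp(dim, mlp_ratio=4):
    # Expand to a wider hidden layer, apply a nonlinearity, then project back to dim.
    return Sequential([
        layers.Dense(dim * mlp_ratio, activation="gelu"),
        layers.Dense(dim),
    ])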

As the FocalModulation layer is actually the lowermost level of the Focalnet, this projection is getting repeated many times in the course of a single forward pass. The basic idea is still shared with self attention: each location within a feature map contains information about other locations, and we should use that information to our advantage.

Implementation wise, you would hope that the various FocalLayers within a FocalModulation block could be run in parallel, since they take the same input (the context tensor), apply independent operations (depthwise convs at different kernel sizes), and then sum the results - though, as the question at the end notes, each level's conv is actually fed the previous level's output. In practice, I have no idea how this works GPU-threading-wise.

Intuitions from the paper

“Compared to pooling [100, 35], depth-wise convolution is learnable and structure-aware. In contrast to regular convolution, it is channel-wise and thus computationally much cheaper.”

“feature maps obtained via hierarchical contextualization are condensed into a modulator. In an image, the relation between a visual token (query) and its surrounding contexts often depends on the content itself. For example, the model might rely on local fine-grained features for encoding the queries of salient visual objects, but mainly global coarse-grained features for the queries of background scenes. Based on this intuition, we use a gating mechanism to control how much to aggregate from different levels for each query. Specifically, we use a linear layer to obtain spatial- and level-aware gating weights G = f_g(X) ∈ R^(H×W×(L+1)). Then, we perform a weighted sum through an element-wise multiplication to obtain a single feature map Z_out which has the same size as the input X,”

Passing the context vector through depthwise convolutional layers with different kernel sizes gives hierarchical contexts. Why do we keep updating and passing the same context vector through the different kernel-sized depthwise conv layers, as opposed to passing the single context vector to the four different depthwise conv layers independently and concatenating or averaging the results?
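A small sketch to make the two alternatives in that question concrete; the kernel sizes and level count are illustrative, and the per-level gating is omitted.

import tensorflow as tf
from tensorflow.keras import layers

focal_convs = [layers.DepthwiseConv2D(k, padding="same", use_bias=False) for k in (3, 5, 7)]
ctx_in = tf.random.normal([1, 16, 16, 96])

# (a) What the paper describes: reuse the same tensor, so receptive fields compound
#     and the three levels effectively see 3-, 7-, and 13-pixel-wide neighborhoods.
ctx, ctx_all = ctx_in, 0.0
for conv in focal_convs:
    ctx = conv(ctx)
    ctx_all += ctx

# (b) The alternative posed above: each conv sees the raw context independently,
#     so the receptive fields stay at 3, 5, and 7 before the results are combined.
ctx_all_alt = tf.add_n([conv(ctx_in) for conv in focal_convs])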