Vasudev's visions

Vision Transformer ...1

This post covers the ideas behind building the PaliGemma model from scratch in Python. I learnt a lot along the way, including key concepts like Rotary Embeddings and the KV Cache. Big thanks to Umar Jamil's YouTube walkthrough, without which I'd have had far less help!

If you're new to this, well don't worry, I'm here to help :)

Assuming you have a basic understanding of the Transformer model: a Vision Transformer is just a Transformer applied to images. Use cases revolve around object detection and segmentation. Say your girlfriend sends you a photo and, being as obsessive as I am, you pay close attention to every little detail; when a computer wants to do that, it can use a Vision Transformer. Given enough data, Vision Transformers can match or outperform CNNs in both accuracy and compute efficiency.

So, a Vision-Language Model is something that learns from both images and text, instead of text only as in a normal Transformer. Let's start from the basic structure.
So PaliGemma has 2 parts, which hook up roughly as sketched just after the list:

  1. SigLip Vision Encoder
  2. Gemma Text Decoder
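Before zooming in, here's a toy sketch of how the two halves talk to each other: the vision side produces a sequence of patch embeddings, and the text side consumes them as a prefix next to the prompt tokens. The modules and sizes below are made-up stand-ins purely to show the tensor flow, not PaliGemma's real components:

import torch
from torch import nn

# toy stand-ins (made-up sizes), just to show the flow of tensors
vision_encoder = nn.Linear(3 * 16 * 16, 768)        # pretend "flattened patch -> embedding"
text_embedding = nn.Embedding(1000, 768)            # pretend token-embedding table

pixel_patches = torch.randn(1, 256, 3 * 16 * 16)    # 256 flattened 16x16 RGB patches
text_tokens = torch.randint(0, 1000, (1, 8))        # 8 prompt tokens

image_embeds = vision_encoder(pixel_patches)        # [1, 256, 768]
text_embeds = text_embedding(text_tokens)           # [1, 8, 768]

# the text decoder then attends over the concatenated sequence [image ; text]
decoder_input = torch.cat([image_embeds, text_embeds], dim=1)   # [1, 264, 768]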

(Figure: PaliGemma architecture)

Here we'll only be covering the SigLip Vision Encoder, since that's the part we have coded so far.
It consists of multiple building blocks tied together. We first convert the image into input embeddings. Let's go through them one by one:

def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
    # pixel_values: [batch, channels, height, width]
    _, _, height, width = pixel_values.shape

    # project each patch into a vector: [batch, embed_dim, patches_h, patches_w]
    patch_embeds = self.patch_embedding(pixel_values)
    # flatten the patch grid: [batch, embed_dim, num_patches]
    embeddings = patch_embeds.flatten(2)
    # swap to [batch, num_patches, embed_dim]
    embeddings = embeddings.transpose(1, 2)

    # attach the positional encoding to the flattened patches
    embeddings = embeddings + self.position_embedding(self.position_ids)

    return embeddings
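For context, here's a minimal sketch of the __init__ that the forward above relies on, i.e. where patch_embedding, position_embedding, and position_ids come from. It follows the standard SigLIP/ViT patch-embedding recipe (a strided Conv2d plus a learned positional table); the config attribute names image_size, patch_size, and num_channels are assumptions on my part, so check them against your own config class:

import torch
from torch import nn

class SigLipVisionEmbeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size          # e.g. 224 (assumed config field)
        self.patch_size = config.patch_size          # e.g. 16  (assumed config field)

        # each non-overlapping patch_size x patch_size patch becomes one embed_dim vector
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,         # 3 for RGB (assumed config field)
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        # one learned positional vector per patch
        self.num_positions = (self.image_size // self.patch_size) ** 2
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions).expand((1, -1)),
            persistent=False,
        )

Because the stride equals the kernel size, the convolution slides over the image without overlap, so a 224x224 image with 16x16 patches gives 14 x 14 = 196 patches.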

Inside the attention block (SigLipAttention), we define the basic elements of attention, the queries, keys, and values, as linear projections of the hidden states:

# each projection maps [batch, num_patches, embed_dim] -> [batch, num_patches, embed_dim]
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
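To see where these projections go, here's a minimal, self-contained sketch of a full multi-head attention block: split the queries, keys, and values into heads, take the scaled dot-product, softmax, and project the result back. The attribute names num_attention_heads and out_proj are assumptions on my part (they mirror common conventions) and this is a sketch, not necessarily the repo's exact implementation:

import torch
from torch import nn

class SigLipAttention(nn.Module):
    # sketch of a SigLIP-style multi-head self-attention block
    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads    # assumed config field
        self.head_dim = self.embed_dim // self.num_heads
        self.scale = self.head_dim ** -0.5              # 1 / sqrt(head_dim)

        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch, seq_len, _ = hidden_states.shape

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # split embed_dim into heads: [batch, num_heads, seq_len, head_dim]
        q = query_states.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = key_states.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = value_states.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # scaled dot-product attention: softmax(Q K^T / sqrt(head_dim)) V
        attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.scale
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        attn_output = torch.matmul(attn_weights, v)

        # merge the heads back and project out: [batch, seq_len, embed_dim]
        attn_output = attn_output.transpose(1, 2).contiguous().reshape(batch, seq_len, self.embed_dim)
        return self.out_proj(attn_output)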

These pieces are then tied together inside an encoder layer, whose components are set up as:

self.embed_dim = config.hidden_size
self.self_attn = SigLipAttention(config)        # multi-head self-attention
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SigLipMLP(config)        # feed-forward multi-layer perceptron
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
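And here's a sketch of how the layer's forward pass typically ties those four pieces together: pre-norm residual connections around the attention block and around the MLP, mirroring the standard SigLIP/ViT encoder layer. The method below continues the __init__ shown above:

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    # residual connection around (layer norm -> self-attention)
    residual = hidden_states
    hidden_states = self.layer_norm1(hidden_states)
    hidden_states = self.self_attn(hidden_states)
    hidden_states = residual + hidden_states

    # residual connection around (layer norm -> MLP)
    residual = hidden_states
    hidden_states = self.layer_norm2(hidden_states)
    hidden_states = self.mlp(hidden_states)
    hidden_states = residual + hidden_states

    return hidden_states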

That wraps up the first leg of our scroll through PaliGemma.

There's More...

This was the first of 'n' parts in a series I'm writing that covers the basics of Vision-Language Models, using PaliGemma as the reference. So, watch out :)