Flamingo — Intuitively and Exhaustively Explained
The Architecture Behind Modern Visual Language Modeling
In this article we’ll discuss Flamingo, a landmark paper in “multimodal modeling”.
First we’ll define “multimodal models” as a class of machine learning models capable of understanding numerous types of data. We’ll then briefly explore landmark papers in image classification and text generation, then describe how Flamingo combined these technologies to achieve state of the art performance in use cases containing both images and text.
By the end of this article you’ll have a thorough understanding of how Flamingo achieved state-of-the-art performance, paving the way for today’s advanced A.I. systems like GPT-4 and Google Gemini.
Who is this useful for? Anyone interested in natural language processing, computer vision, or multimodal modeling.
How advanced is this post? This is an intermediate post that assumes some basic knowledge of machine learning.
Pre-requisites: None, but I included some relevant reference material throughout the article should you be confused about a specific topic.
There is also a curated list of resources at the end of the article for related reading.
Multimodal Modeling Before Flamingo
“Vision-Language modeling” is what most people think of when they think of “multimodal modeling”. Before we get into the nuts and bolts, let’s define these two ideas:
Multimodal modeling is an umbrella term for any machine learning model that deals with multiple “modalities”. You can think of a modality as a type of data; things like text, images, tables, and audio are all considered by data scientists as different “modalities”. Multimodal models are models that can somehow work with multiple modalities.
Vision-language modeling is probably the most popular form of multimodality: machine learning models that can, in some way, do tasks requiring the simultaneous understanding of both images and text.
In reality, vision-language modeling is a catchall term for a broad class of problems:
Visual Question Answering: Given an image and a textual question about that image, generate a response
Captioning: Given an image, describe the content of the image textually
Visual Dialogue: Hold a coherent and organic conversation that contains both images and text
Image Classification: Given an image, categorize that image into one of a fixed set of predefined textual classes.
Before Flamingo, highly specific models were state-of-the-art on highly specific multimodal tasks:
KAT was state of the art on the OKVQA dataset, a visual question answering dataset
“A Good Embedding Is All You Need?” was state of the art on VQAv2, a different visual question answering dataset
SimVLM was state of the art on COCO, an image captioning dataset
VIOLET was state of the art on MSVDQA, a visual question answering dataset focused on video
One of the big claims from the CLIP paper (a landmark paper we’ll cover in the next section) was that this highly specialized performance doesn’t scale well to real-world situations. If you have a model that’s good on one dataset, but can’t do a similar task on a similar dataset, is it really a good model?
CLIP tackled this problem before Flamingo, but CLIP-style models can only classify images. The idea behind Flamingo was to bridge the gap between CLIP-style models, which have a robust understanding of the content of images, and the textual understanding and generative ability of language models, creating a system that could converse about both text and images robustly and flexibly.
The Precursors to Flamingo
Before we talk about Flamingo itself, it’s critical to understand the two key modeling strategies it inherits: CLIP and decoder-only transformers:
CLIP
The whole idea of CLIP was to create a general-purpose image classifier that could be used in a variety of cases without any further training. To achieve this, the authors of the CLIP paper used a strategy called “contrastive learning”.
Contrastive learning is a subtle re-framing of the problem of image classification. Instead of learning “This label goes with this image, and that label goes with that image”, contrastive learning says “This label is closer in similarity to this image, and that label is closer in similarity to that image”. This subtle change in thinking opened up a whole new approach to representing images that has been widely used ever since.
CLIP employs two components to build this idea of “closeness”, an image encoder and a text encoder. These both learn, in unison through the training process, to jointly align pairs of images and text in a high-dimensional space. By putting similar images and text in similar spots, and different images and text in different spots, CLIP style models learn to make decisions about which text belongs with which image.
The important part for us is the image encoder. For CLIP to be successful, it has to train the image encoder to understand things like dogs, clothes, and skateboards within an image so that those images can be placed in the right spot with the right text. In other words, CLIP image encoders are really good at distilling an image down into its general meaning, a quality Flamingo employs to achieve vision-language modeling.
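To make the idea concrete, here’s a minimal NumPy sketch of the scoring step: given one image embedding and a few candidate caption embeddings (random stand-ins here, not real CLIP outputs), pick the caption whose embedding is closest to the image’s.

```python
import numpy as np

def l2_normalize(x, axis=-1):
    # Scale each embedding to unit length so dot products become cosine similarities.
    return x / np.linalg.norm(x, axis=axis, keepdims=True)

# Hypothetical embeddings: 1 image and 3 candidate captions, each a 512-dimensional vector.
rng = np.random.default_rng(0)
image_embedding = l2_normalize(rng.normal(size=(1, 512)))
text_embeddings = l2_normalize(rng.normal(size=(3, 512)))  # e.g. "a dog", "a skateboard", "a shirt"

# Cosine similarity between the image and every caption; the highest-scoring caption "wins".
similarities = image_embedding @ text_embeddings.T          # shape [1, 3]
print(similarities, int(similarities.argmax()))
```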
That was a super quick run-through; feel free to refer to my article on CLIP for more information:
Decoder-Only Transformers
So that was CLIP, the technology Flamingo uses to understand images. Decoder-only transformers are what Flamingo uses to understand text.
From a super high level, you can think of language models as a big stack of blocks. The purpose of each of these blocks is to refine a representation of the input text, block by block, and then use that representation to predict what text should follow the input.
The thing that makes each of these blocks within the transformer special is their use of “attention”. Attention is a form of modeling where the representations of multiple words in the input are combined to create an abstract and highly contextualized representation.
This is done by feeding the attention mechanism three inputs: a “Query”, a “Key”, and a “Value”. Don’t get hung up on the names; we’ll build a more intuitive understanding of all this later. For now, I just want to share the high-level workings of the attention mechanism.
The query and key, which are derived from the input to the attention mechanism, get multiplied together to create what I like to call the “attention matrix”.
Then, the attention matrix is used as a filter to transform the value matrix into the final output.
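Here’s a minimal NumPy sketch of that two-step process (scaled dot-product attention) with toy shapes; real models also split this across many attention heads, which I’m ignoring here.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))   # subtract max for numerical stability
    return e / e.sum(axis=axis, keepdims=True)

def attention(query, key, value):
    # Step 1: query times key builds the "attention matrix".
    attn_matrix = softmax(query @ key.T / np.sqrt(key.shape[-1]), axis=-1)   # [n_queries, n_keys]
    # Step 2: the attention matrix filters the value matrix into the final output.
    return attn_matrix @ value                                               # [n_queries, d]

# Toy example: 4 query tokens attending over 6 key/value tokens, each 8-dimensional.
rng = np.random.default_rng(0)
q, k, v = rng.normal(size=(4, 8)), rng.normal(size=(6, 8)), rng.normal(size=(6, 8))
print(attention(q, k, v).shape)   # (4, 8)
```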
So, in essence, the attention mechanism uses some inputs to filter other inputs. We’ll cover attention more thoroughly later in this article; for now, just understanding the vibe from a high level is sufficient. If you do want to dig in deeper, check out my article on transformers, or my article on GPT, a popular decoder-only transformer:
The end goal of a decoder-only transformer is to use the attention mechanisms, within the numerous blocks of the model, to understand the input and figure out what the next output should be. By being able to guess the next word well, language models can construct an entire output one word at a time.
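In code, that “one word at a time” idea is just a loop: score every possible next token, pick one, append it, and repeat. Here’s a hedged sketch where `next_token_logits` is a dummy stand-in for the entire decoder stack:

```python
import numpy as np

def next_token_logits(token_ids, vocab_size=100):
    # Dummy stand-in for a real decoder-only transformer: one score per vocabulary entry.
    rng = np.random.default_rng(sum(token_ids))
    return rng.normal(size=vocab_size)

def generate(prompt_ids, max_new_tokens=5):
    token_ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        logits = next_token_logits(token_ids)
        token_ids.append(int(np.argmax(logits)))   # greedy decoding: take the most likely token
    return token_ids

print(generate([7, 42, 3]))   # prompt tokens followed by 5 generated tokens
```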
Flamingo in a Nutshell
Now that we understand the essence of CLIP style image encoding (which turns an image into a vector that conveys the image's general meaning) and transformer style language models (which use attention to iteratively output words), we can start digging into Flamingo.
At its highest level, Flamingo consists of four key components:
A Vision Encoder, which re-represents images into their general meaning
A Perceiver Resampler, a system that combines the information from a variable number of images into a fixed number of features (allowing the model to understand things like video or a series of images)
A Language Model, a pre-trained decoder-only transformer like GPT-3 or LLaMA. The Flamingo paper used Chinchilla.
Gated Cross Attention, which allows Flamingo to slowly learn to inject image information into the language model throughout the training process.
Flamingo uses these systems to understand an arbitrary input sequence of images and text and to generate textual output. Let’s break down each of these components, one by one, to build a complete understanding of how Flamingo functions.
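Before digging into each piece, here’s a purely conceptual sketch of how data flows through those four components. Every function body below is a dummy stand-in chosen only to make the shapes and the call order visible; none of it reflects how the real components actually compute their outputs.

```python
import numpy as np

rng = np.random.default_rng(0)
S, d, R, vocab = 4, 16, 8, 100   # toy sizes: regions per image, feature dim, resampled tokens, vocabulary

def vision_encoder(image):
    # Stand-in for the NFNet/CLIP-style encoder: one feature vector per spatial region.
    return rng.normal(size=(S, d))

def perceiver_resampler(per_image_features):
    # Stand-in: collapse however many images came in down to exactly R visual tokens.
    stacked = np.concatenate(per_image_features, axis=0)               # [T * S, d]
    return stacked[:R]                                                 # [R, d]

def language_model(text_tokens, visual_tokens):
    # Stand-in for the frozen decoder with gated cross-attention layers: next-token scores.
    return rng.normal(size=vocab)

frames = [np.zeros((224, 224, 3)) for _ in range(3)]                   # e.g. three video frames
visual_tokens = perceiver_resampler([vision_encoder(f) for f in frames])
print(language_model([5, 9, 2], visual_tokens).shape)                  # (100,)
```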
The Vision Encoder
Flamingo uses a CLIP style image encoder (the one we previously discussed) to encode images. This is a common strategy used in numerous multimodal architectures that have to do with images.
The idea is, instead of Flamingo needing to learn about images from scratch, it can employ the high quality summarizations from a pre-trained CLIP image encoder. Thus, Flamingo offloads a lot of the work of understanding images, and instead only has to reason about image distillations.
Flamingo doesn’t actually use CLIP, but the NFNet F6 model. For our purposes, the only conceptual difference is that NFNet produces summarizations about sub-regions of an image rather than the whole image. This makes it easier for Flamingo to understand subtleties within the image. NFNet also does a lot of other cool stuff, I might cover it in a future article, but for our purposes, this is more of a line item. Conceptually, NFNet is like a fancy version of CLIP.
The Perceiver Resampler
In wanting to create a flexible and robust multimodal model, the authors of Flamingo created a system that was good at handling both images and video. Video data is a difficult type of data to do machine learning on; there’s a lot of information in even small video files, and extracting the right information efficiently can be computationally expensive and difficult.
Flamingo addresses the problems of video with the “Perceiver Resampler”. The perceiver resampler can be thought of as a summarization system that compresses an arbitrarily long video down into a fixed set of descriptive tokens. It’s not conceptually difficult, but there are a lot of moving parts. Let’s look at it from a high level, then at a more nuts-and-bolts lower level.
The Perceiver Resampler — High Level
Conceptually, you can think of the perceiver resampler as a filter; it takes in a fixed number of predefined tokens and uses input images extracted from video to filter those tokens. Regardless of the number of images in the input, the same fixed number of tokens comes out the other end.
From a high level, the perceiver resampler fits into the greater Flamingo architecture in the following way:
The images are extracted from the prompt. In their place, an <image> token is placed in the text so that the model knows where the image was. The output from the perceiver resampler is then used to incrementally filter the internal state of the LLM throughout various layers, ultimately allowing the LLM to converse about the images.
We’ll cover how cross-attention interweaves the image representation into the LLM in later sections. For now, let’s zoom into the perceiver resampler and see how it works.
The Perceiver Resampler — Nuts and Bolts
To understand the perceiver resampler in more detail, let’s work through its components step-by-step.
First, the input image, or sequence of images, is passed through a vision encoder. This summarizes the content of the image in a way that’s easy to interpret for ML systems. This is our NFNet image encoder we discussed in a previous section.
The attention mechanism (which the image encodings are ultimately fed into) tends to shuffle inputs around and thus lose track of where a particular piece of information was in an input sequence. As a result, it’s customary to add a time encoding, which embeds the time of an input into the value of the input itself.
Flamingo uses a learned time vector for each frame in the input. During training, Flamingo has spots for 8 input frames total. These are added to the features extracted from the vision encoder.
I was a bit surprised by this. I thought the whole point of Flamingo was to be generalizable; limiting video input to eight frames seemed like kind of a silly design choice. Apparently, though, the model is robust to interpolating between time embeddings to fit in more frames as necessary. So, if you want to add more frames, just make new time tokens by interpolating between the eight trained ones.
Although our model was trained with a fixed number of 8 frames, at inference time, we input 30 frames at 3 FPS. This is achieved by linearly interpolating the learnt temporal position embedding of the Perceiver Resampler at inference time. — The Flamingo Paper.
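Here’s a hedged NumPy sketch of that interpolation trick: start from 8 learned temporal embeddings (random stand-ins here, since the real ones are trained parameters) and linearly interpolate them out to 30 frames.

```python
import numpy as np

d = 16                                    # toy internal dimension
trained_frames, inference_frames = 8, 30

# Stand-in for the 8 learned temporal embeddings.
time_embeddings = np.random.default_rng(0).normal(size=(trained_frames, d))

# Linearly interpolate each feature dimension onto 30 evenly spaced positions.
old_positions = np.linspace(0.0, 1.0, trained_frames)
new_positions = np.linspace(0.0, 1.0, inference_frames)
interpolated = np.stack(
    [np.interp(new_positions, old_positions, time_embeddings[:, j]) for j in range(d)],
    axis=1,
)
print(interpolated.shape)   # (30, 16) — one temporal embedding per input frame
```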
Another quick note, you might think “Hey, if we need to add information about time, then why not location? If attention mixes up our input, wouldn’t it be useful to say something like ‘this information came from the top right of an image’?” The answer to that is absolutely, but we don’t have to add positional information because our vision encoder does already:
Note that we only use temporal encodings and no explicit spatial grid position encodings; we did not observe improvements from the latter. This rationale behind is likely that CNNs, such as our NFNet encoder, are known to implicitly include spatial information — The Flamingo Paper.
Now that we have extracted features from each image, and we’ve added all necessary information about time, we can use attention to filter out the right information from the image via the learned tokens.
Let’s follow some data through the attention mechanism to get a thorough idea of how it functions.
Step 1) Flattening
The features extracted from the images are of shape [T, S, d], where T is the number of images, S is the number of spatial grids, and d is the length of the feature vectors. In most machine learning contexts, “tokens” are vectors of some length. d is sometimes referred to as the “internal dimension”, as it’s the size of the token vectors within the model (note: in this article d is depicted as length 6, but in reality the internal dimension of modern models is very large, on the order of hundreds or thousands, so these vectors are much, much longer than depicted).

Before passing through attention, these tokens get flattened along the space and time dimensions; T and S become a single dimension of length T * S, resulting in a two-dimensional matrix of shape [T * S, d].
Keep in mind, while this might appear like we’re shuffling the image data around:
This operation is done consistently across successive runs, so while flattening mixes up the order of space and time, it does so in the same way for any given input.
We don’t really need the order to be preserved. Recall that spatial information is automatically encoded by the image encoder, and time information was explicitly added to each token. As long as things are done consistently, the order of the inputs doesn’t really matter.
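In code, adding the temporal embeddings and flattening is just a broadcasted add followed by a reshape. A minimal sketch with toy shapes:

```python
import numpy as np

T, S, d = 8, 4, 16                               # frames, spatial regions per frame, feature length
rng = np.random.default_rng(0)

image_features = rng.normal(size=(T, S, d))      # output of the vision encoder
time_embeddings = rng.normal(size=(T, 1, d))     # stand-in for the learned temporal embeddings

# Add the same temporal embedding to every spatial region of its frame...
with_time = image_features + time_embeddings     # still [T, S, d]

# ...then flatten space and time into a single axis of length T * S.
flattened = with_time.reshape(T * S, d)
print(flattened.shape)                           # (32, 16)
```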
Step 2) Creating the Key and Value
Now that the features from our images are properly processed, it’s time to pass them into the attention mechanism. Recall that, from a high level, the whole idea of the perceiver resampler is to use images to filter a fixed number of tokens.
This “filtering” happens inside the attention mechanism, which is essentially just a series of matrix multiplications.
The whole idea of this attention mechanism is that a fixed number of tokens are extracted from a variable sequence of images, so it makes sense that some clever matrix manipulation would be required to get everything working right. That’s why there are a bunch of arrows before the attention mechanism in the perceiver resampler.
For Flamingo, the flattened features from the image, labeled Xf, are concatenated with a set of learned tokens, X, of shape [R, d], to construct the “key” and “value” inputs. Here, R represents an arbitrary, fixed number of tokens, and d (recall) is the internal dimension of the model. The “query” is simply the learned tokens X.
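A minimal sketch of that construction, with the flattened image features and the learned tokens as random stand-ins:

```python
import numpy as np

T_times_S, R, d = 32, 8, 16
rng = np.random.default_rng(0)

X_f = rng.normal(size=(T_times_S, d))      # flattened image features from the previous step
X = rng.normal(size=(R, d))                # learned tokens (trained parameters in the real model)

query = X                                  # [R, d]
key = np.concatenate([X_f, X], axis=0)     # [T*S + R, d]
value = np.concatenate([X_f, X], axis=0)   # [T*S + R, d]
print(query.shape, key.shape, value.shape)
```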
One question you might be wondering about is “Why are the learned tokens appended to the image information?” I think this is so the perceiver resampler can more intimately control the information that comes out of the attention mechanism. Keep that in mind as we go through the next section.
Step 3) Running Through the Attention Mechanism
Now that we have our Query, Key, and Value defined, we can run them through attention. Attention is just two matrix multiplication operations. First, the Query and Key are multiplied together to construct the attention matrix.
Then the attention matrix is multiplied by the value to construct the final output.
Et voilà, the attention mechanism in the perceiver resampler has now extracted information from all images into a fixed number of output tokens of shape [R, d], where R is the arbitrary, fixed number of learned tokens and d is the (also arbitrary) internal model dimension.
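A self-contained sketch of that step, showing why the output size is fixed: there are only R queries, so only R tokens come out, no matter how many frames went in.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

T_times_S, R, d = 32, 8, 16                                  # toy sizes
rng = np.random.default_rng(0)
X_f = rng.normal(size=(T_times_S, d))                        # flattened image features
X = rng.normal(size=(R, d))                                  # learned tokens
query, key, value = X, np.concatenate([X_f, X]), np.concatenate([X_f, X])

attn_matrix = softmax(query @ key.T / np.sqrt(d))            # [R, T*S + R]
output = attn_matrix @ value                                 # [R, d], regardless of the number of frames
print(output.shape)                                          # (8, 16)
```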
Step 4) Constructing the Final Output
After the attention mechanism, a “skip” connection is applied. Basically, the attention mechanism is great at extracting subtle features, but it mixes everything around like crazy. It’s useful to allow some of the older, simpler structure from before the attention mechanism to be present. Thus, the learned tokens from before the attention mechanism are added to its output, preserving some of that older, simpler information.
The output of the skip connection is passed through a feed forward network, which is your classic prototypical neural network. Another skip connection is applied for the same reason as before.
Generally, with machine learning systems, it helps to do the same thing in a few passes so that complex operations can be done incrementally over numerous layers. In the perceiver resampler, this information extraction using attention, skip connections, and feed-forward layers is done a few times. The output tokens from one iteration are fed right back into the input of the next; hence the “x num_layers” in the diagram above.
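Putting the pieces together, a single resampler layer plus the outer loop might look roughly like the sketch below. This is a hedged approximation, not the paper’s exact implementation: the real model uses multi-head attention with learned projections, normalization layers, and separate parameters per layer, all of which I’ve dropped for brevity.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def resampler_layer(latents, image_features, W1, W2):
    # Cross-attention: the learned latents query the image features (plus themselves)...
    kv = np.concatenate([image_features, latents], axis=0)
    attended = softmax(latents @ kv.T / np.sqrt(latents.shape[-1])) @ kv
    latents = latents + attended                  # skip connection 1
    # ...then a feed-forward network, again with a skip connection.
    hidden = np.maximum(0.0, latents @ W1)        # ReLU feed-forward
    return latents + hidden @ W2                  # skip connection 2

T_times_S, R, d, num_layers = 32, 8, 16, 3
rng = np.random.default_rng(0)
image_features = rng.normal(size=(T_times_S, d))
latents = rng.normal(size=(R, d))                 # the learned tokens
W1, W2 = rng.normal(size=(d, 4 * d)), rng.normal(size=(4 * d, d))

for _ in range(num_layers):                       # the "x num_layers" in the diagram
    latents = resampler_layer(latents, image_features, W1, W2)
print(latents.shape)                              # (8, 16)
```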
Alright, so that’s the perceiver resampler. Using a set of learned tokens (which are just vectors of numbers that can change throughout the training process), multiple stacks of attention allow the perceiver resampler to extract information from a sequence of input images. In the next section we’ll discuss how that information is fed into a language model so that the language model can converse about images.
Combining Visual and Textual Information With Gated Cross Attention
Kudos for hanging in there; hopefully you’re finding this as fascinating as I am. It’s been about six months since you started reading this article; recall that the whole point is to get a pre-trained language model to be able to understand and converse about images.