This post is about the basics of the attention mechanism in sequence-to-sequence models. The attention mechanism is a foundation for highly performant NLP models like the Transformer.
What is a sequence-to-sequence model?
A sequence-to-sequence model consists of an encoder and a decoder. The encoder encodes a variable-length source sequence into a representation that the decoder then decodes into a variable-length target sequence.
Such models are often used for tasks like language translation and question answering.
Non-attentional RNN sequence-to-sequence models
To clarify the motivations behind the design choices of the attention mechanism, I think that it is helpful to first examine non-attentional RNN sequence-to-sequence models and their shortcomings.
In a non-attentional RNN sequence-to-sequence model, the encoder and decoder are n-layer RNNs (e.g., 4-layer LSTMs), and the encoder RNN may be bidirectional. The encoder encodes the variable-length input sequence into a single fixed-length context vector for the decoder (this is just the encoder’s last hidden state).
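To make the bottleneck concrete, here is a minimal sketch of a non-attentional encoder. It assumes a plain (Elman) RNN with tiny hand-picked weights instead of a multi-layer LSTM. The weight values and the helper names `matvec` and `rnn_encode` are hypothetical. The only point is that the decoder receives a context vector of the same fixed size no matter how long the input is.

```python
import math

def matvec(m, v):
    """Multiply matrix m (a list of rows) by vector v."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]

def rnn_encode(inputs, w_x, w_h):
    """Toy Elman RNN encoder: h_t = tanh(W_x x_t + W_h h_{t-1}), with h_0 = 0.

    Returns every hidden state h_1..h_T (hypothetical toy weights, not an LSTM).
    """
    h = [0.0] * len(w_h)
    states = []
    for x in inputs:
        pre = [a + b for a, b in zip(matvec(w_x, x), matvec(w_h, h))]
        h = [math.tanh(p) for p in pre]
        states.append(h)
    return states

# Hypothetical weights: 1-dimensional inputs, 2-dimensional hidden states.
W_X = [[0.9], [-0.4]]
W_H = [[0.5, 0.1], [0.2, 0.3]]

short_states = rnn_encode([[1.0], [0.5], [-1.0]], W_X, W_H)
long_states = rnn_encode([[0.1 * t] for t in range(50)], W_X, W_H)

# Without attention, the decoder only sees the last hidden state.
short_context = short_states[-1]
long_context = long_states[-1]
print(len(short_context), len(long_context))  # 2 2: same size regardless of input length
```

Three input words and fifty input words are both squeezed into two numbers here. That is the compression problem described next.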
Problems
- The performance of such a model degrades quickly as the length of the input sequence increases. It is difficult for the encoder to effectively compress a long input sequence into a single fixed-length context vector. It is like trying to tell a complicated story with just a few words. So, what if we encoded the sequence into multiple context vectors?
- Consider a human performing a language translation task. When they try to output the next translated word, they don’t need to pay attention to every word in the source sentence. Intuitively, this is more efficient than looking at the whole sentence, especially for long sentences. Non-attentional RNN sequence-to-sequence models have no mechanism like this. So why don’t we introduce one?
Attentional sequence-to-sequence models
Attentional sequence-to-sequence models that use global attention (as described in the papers cited in this post) address the aforementioned problems with the following changes:
- Instead of a single fixed-length context vector, the decoder gets a separate context vector for each of its time steps. Each input sequence therefore yields a sequence of context vectors, one per decoding step, and each context vector draws on the entire sequence of encoder hidden states.
- A new attention mechanism that reflects how much attention should be paid to various parts of the input when decoding at a particular time step. This is used in creating the context vectors.
For (hopefully) easier communication, we will use these definitions (chosen to be consistent with the symbols in the 2015 Bahdanau paper that introduced the attention mechanism) from now on:
- $h_j$: the hidden state of the encoder at time $j$
- $c_i$: context vector that is input to the decoder at time step $i$
- $s_i$: the hidden state of the decoder at time $i$
- $y_i$: the target word at time step $i$
Determining the context vectors
Each decoding time step $i$ is associated with a context vector $c_i$, which gives the decoder the source-sequence information it needs to generate the target word $y_i$. How should we determine $c_i$?
It seems that $c_i$ should package information about encoder hidden states. So should $c_i$ just be a sum of encoder hidden states? Actually, we can do something better than that.
Notice that the encoder hidden states $h_1, h_2, \dots, h_{T_{encoder}}$ may not be equally important for determining a target word. For example, consider an English-to-Spanish translation task in which the input is “how are you” and the expected output is “cómo estás”. When generating the word “cómo”, the decoder mainly needs to pay attention to the first input word “how”, rather than the entire input sentence.
Imagine a much longer sequence to translate (e.g., 100+ words). If the decoder/translator is not selective about which inputs it pays attention to, it can easily be overwhelmed. That can make it much harder to train to translate effectively, and it can also make it much slower. So it is crucial for the decoder/translator to be selectively attentive to the encoder hidden states.
So, how can we be more selectively attentive to $h_1, h_2, \dots, h_{T_{encoder}}$ during a decoding step?
To begin, we can weight each of those encoder hidden states when calculating the context vector $c_i$ for decoding step $i$, such that:
- $c_i = \sum_{j=1}^{T_{encoder}} \alpha_{ij} h_{j}$
- where $\alpha_{ij}$ is a scalar weight that represents how much $s_i$ should attend to $h_j$. For each $i$, the weights should be non-negative and sum to 1 over all $j$, so we can compute them with a softmax. Then $c_i$ is effectively a weighted average of the encoder hidden states.
- $\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^{T_{encoder}} \exp(e_{ik})}$
- this is the softmax that represents the weight on the encoder hidden state $h_j$, when considering decoder hidden state $s_i$
- $e_{ij}$ is a scalar alignment score representing how well $h_j$ matches with $s_i$
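Here is a minimal sketch of the two equations above in plain Python. It assumes the alignment scores $e_{ij}$ for one decoder step have already been computed. The encoder states `H` and scores `e_i` are made-up numbers, and the helper names are hypothetical. Subtracting the maximum score before exponentiating is a standard trick: it leaves the softmax unchanged but avoids overflow.

```python
import math

def softmax(scores):
    """Numerically stable softmax: subtract the max score before exponentiating."""
    m = max(scores)
    exps = [math.exp(e - m) for e in scores]
    total = sum(exps)
    return [x / total for x in exps]

def context_vector(scores, encoder_states):
    """Return (alpha_i, c_i) where alpha_i = softmax(e_i) and c_i = sum_j alpha_ij * h_j."""
    alphas = softmax(scores)
    dim = len(encoder_states[0])
    c = [0.0] * dim
    for alpha, h in zip(alphas, encoder_states):
        for k in range(dim):
            c[k] += alpha * h[k]
    return alphas, c

# Hypothetical encoder states h_1..h_3 (think "how", "are", "you") and scores e_i1..e_i3.
H = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
e_i = [3.0, 0.5, 0.1]  # h_1 matches s_i best
alphas, c_i = context_vector(e_i, H)
print([round(a, 3) for a in alphas])
print([round(x, 3) for x in c_i])
```

Because `alphas` sums to 1, `c_i` lies "between" the encoder states and is pulled mostly toward `h_1`, the state with the highest score.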
As an alternative view, the context vector $c_i$ can be formulated as:
$c_i = \mathrm{f}(s_i, h) = \mathrm{softmax}(\mathrm{align}(s_i, h))h$
- where $h$ is the matrix whose rows are the encoder hidden states $h_1, \dots, h_{T_{encoder}}$, and $\mathrm{align}(s_i, h)$ outputs a vector of alignment scores between $s_i$ and each row of $h$
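The matrix view also makes it easy to compute the context vectors for every decoder step at once. The sketch below stacks the decoder states into a matrix $S$ and the encoder states into $H$, both with one state per row. Purely for illustration, it assumes dot-product alignment (one of the options discussed below), so that $\mathrm{align}(s_i, h)$ is row $i$ of $S H^T$. All states are made-up values, and the helper names are hypothetical.

```python
import math

def matmul(a, b):
    """Multiply matrices given as lists of rows."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]

def transpose(m):
    return [list(col) for col in zip(*m)]

def softmax_rows(m):
    """Apply a numerically stable softmax to each row independently."""
    out = []
    for row in m:
        mx = max(row)
        exps = [math.exp(x - mx) for x in row]
        total = sum(exps)
        out.append([x / total for x in exps])
    return out

def all_context_vectors(S, H):
    """Row i of the result is c_i = softmax(align(s_i, h)) h, here with dot-product alignment.

    S: T_decoder x d decoder states; H: T_encoder x d encoder states (rows are h_j).
    """
    scores = matmul(S, transpose(H))   # T_decoder x T_encoder alignment scores e_ij
    weights = softmax_rows(scores)     # each row sums to 1
    return matmul(weights, H)          # T_decoder x d context vectors

H = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # hypothetical encoder states
S = [[2.0, 0.0], [0.0, 2.0]]               # hypothetical decoder states s_1, s_2
C = all_context_vectors(S, H)
```

Each row of `C` is the same weighted average as in the per-step formula. Written this way, the whole computation is two matrix products and a row-wise softmax.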
Determining the alignments between target and source sequence
How should we weight the encoder hidden states for each decoder step? Specifically, what is the function that determines the alignment score $e_{ij}$, that represents how well $s_i$ matches with $h_j$?
Here are some proposed alignment functions $a$ from various papers, where $e_{ij} = a(s_i, h_j)$.
1. Additive
- $a(s_{i}, h_j) = v_a^T \tanh(W_a[s_{i}; h_j])$
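- where $v_a$ is a weight vector and $W_a$ is a weight matrix, both learned
- This passes the concatenated hidden states through a feed-forward network with one hidden layer (with $\tanh$ as its activation function), then projects the hidden layer's output to a scalar score with $v_a$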
- The Luong paper states that their implementation of this (compared to general and dot-product) “does not yield good performances and more analysis should be done to understand the reason”
- Referenced in:
- “Neural Machine Translation by Jointly Learning to Align and Translate” (Bahdanau et al., 2015)
- “Effective Approaches to Attention-based Neural Machine Translation” (Luong et al., 2015)
- (Note that the Luong paper calls this “concat”. I refer to it as just “additive” because TensorFlow does.)
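Here is a minimal sketch of the additive score for a single pair $(s_i, h_j)$. Small hand-picked values stand in for the learned parameters $W_a$ and $v_a$, and the function name `additive_score` is hypothetical. Because the states are concatenated first, $s_i$ and $h_j$ do not need to have the same dimension.

```python
import math

def additive_score(s, h, W_a, v_a):
    """e = v_a^T tanh(W_a [s; h]), where [s; h] is the concatenation of s and h."""
    concat = s + h  # list concatenation = [s; h]
    hidden = [math.tanh(sum(w * x for w, x in zip(row, concat))) for row in W_a]
    return sum(v * z for v, z in zip(v_a, hidden))

# Hypothetical shapes: s has 2 dims, h has 3 dims, the hidden layer has 4 units.
# So W_a is 4 x 5 and v_a has 4 entries (both would be learned in practice).
W_a = [[0.1, -0.2, 0.3, 0.0, 0.5],
       [0.4, 0.1, -0.1, 0.2, 0.0],
       [-0.3, 0.2, 0.2, 0.1, -0.1],
       [0.0, 0.3, 0.0, -0.4, 0.2]]
v_a = [1.0, -0.5, 0.5, 0.25]
s_i = [0.5, -1.0]
h_j = [1.0, 0.0, 2.0]
e_ij = additive_score(s_i, h_j, W_a, v_a)
```

Each score costs a full small network evaluation. That is part of why additive attention is slower than the dot-product variants below.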
2. General
- $a(s_{i}, h_j) = s_{i}^T W_a h_j$
- where $W_a$ is a weight matrix to be learned
- According to the paper’s authors, it works better than dot product for local attention in their experiments
- Referenced in “Effective Approaches to Attention-based Neural Machine Translation” (Luong et al., 2015)
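Here is a minimal sketch of the general score with a made-up $W_a$. The name `general_score` is hypothetical. Note that $W_a$ has shape $\dim(s_i) \times \dim(h_j)$, so the two states may differ in size.

```python
def general_score(s, h, W_a):
    """e = s^T W_a h; W_a has shape len(s) x len(h)."""
    W_h = [sum(w * x for w, x in zip(row, h)) for row in W_a]  # W_a h
    return sum(a * b for a, b in zip(s, W_h))

# Hypothetical: 2-dim decoder state, 3-dim encoder state, so W_a is 2 x 3.
W_a = [[1.0, 0.0, 0.5],
       [0.0, 2.0, -1.0]]
s_i = [1.0, 1.0]
h_j = [2.0, 1.0, 0.0]
e_ij = general_score(s_i, h_j, W_a)
print(e_ij)  # 4.0
```

When the dimensions match and $W_a$ is the identity matrix, the general score reduces to the plain dot product.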
3. Dot-product
- $a(s_{i}, h_j) = s_{i}^T h_j$
- According to the paper’s authors, it works well for global attention in their experiments
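- This is much faster than additive attention in practice, since it can be implemented with highly optimized matrix multiplication. That matters because we have to compute $T_{encoder} \times T_{decoder}$ alignment scores, where $T_{encoder}$ and $T_{decoder}$ are the lengths of the encoder and decoder hidden state sequences.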
- Referenced in “Effective Approaches to Attention-based Neural Machine Translation” (Luong et al., 2015)
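Unlike additive attention, the dot product has no parameters, but it requires $s_i$ and $h_j$ to have the same dimension. The sketch below uses made-up states and hypothetical helper names. It computes the full $T_{decoder} \times T_{encoder}$ grid of scores with explicit loops; an optimized implementation would compute the same grid as one matrix product.

```python
def dot_score(s, h):
    """e = s^T h; s and h must have the same dimension."""
    if len(s) != len(h):
        raise ValueError("dot-product alignment needs equal dimensions")
    return sum(a * b for a, b in zip(s, h))

def score_matrix(decoder_states, encoder_states):
    """All T_decoder x T_encoder alignment scores e_ij = s_i^T h_j."""
    return [[dot_score(s, h) for h in encoder_states] for s in decoder_states]

S = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]  # T_decoder = 4 (hypothetical)
H = [[2.0, 1.0], [0.0, 3.0], [1.0, -1.0]]             # T_encoder = 3 (hypothetical)
E = score_matrix(S, H)
for row in E:
    print(row)
```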
4. Scaled dot-product
- $a(s_{i}, h_j) = \frac{s_{i}^T h_j}{\sqrt{d_h}}$
- where $d_h$ is the dimension of $h_j$
- The paper notes that additive and dot-product attention perform similarly for small values of $d_h$, but additive outperforms unscaled dot-product attention for larger values. The authors suspect that the dot products grow large in magnitude as $d_h$ grows, pushing the softmax into regions with extremely small gradients. To counteract this, they scale the dot products by $\frac{1}{\sqrt{d_h}}$. If the components of $s_i$ and $h_j$ are independent with mean 0 and variance 1, then $s_i^T h_j$ is a sum of $d_h$ independent terms, each with mean 0 and variance 1. The sum therefore has variance $d_h$, and dividing by $\sqrt{d_h}$ brings the variance back to 1 regardless of $d_h$.
- Referenced in “Attention Is All You Need” (Vaswani et al., 2017)
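The variance argument can be checked numerically with a small simulation. The simulation assumes the components of $s_i$ and $h_j$ are independent standard normal values. That is the assumption behind the argument, not a property of real hidden states. The seed, $d_h = 128$, and the sample count are arbitrary choices. The raw dot products should have variance close to $d_h$, while the scaled ones should stay close to 1.

```python
import math
import random

def scaled_dot_score(s, h):
    """e = s^T h / sqrt(d_h)."""
    return sum(a * b for a, b in zip(s, h)) / math.sqrt(len(h))

def variance(xs):
    """Population variance of a list of numbers."""
    mean = sum(xs) / len(xs)
    return sum((x - mean) ** 2 for x in xs) / len(xs)

rng = random.Random(0)
d_h = 128
raw, scaled = [], []
for _ in range(2000):
    # Assumption: components are independent with mean 0 and variance 1.
    s = [rng.gauss(0.0, 1.0) for _ in range(d_h)]
    h = [rng.gauss(0.0, 1.0) for _ in range(d_h)]
    raw.append(sum(a * b for a, b in zip(s, h)))
    scaled.append(scaled_dot_score(s, h))

print(round(variance(raw)), round(variance(scaled), 2))  # roughly d_h and 1
```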
It seems that additive, dot-product, and scaled dot-product attention are the most popular; at the time of writing, TensorFlow only has built-in attention layers for those.
Aside: interpreting the alignment functions
This section explores some questions I had/have about the alignment functions. Why do they work? When might we want to use one over another?
The dot product
Why does the dot product work as an alignment function?
Recall what the alignment function is supposed to do. It should output a larger value when the encoder hidden state $h_j$ is more “relevant” for generating the next target word at decoder state $s_i$, and a smaller value when it is less so. So “relevance” in this context is a similarity measure (equivalently, an inverse distance) in the abstract spaces of the hidden states.
Recall that the dot product for vectors $a$ and $b$ is $a \cdot b = ||a|| ||b|| \cos{\theta}$, where $\theta$ is the angle between them. Visualize how it changes with different values of $\theta$ (with fixed magnitudes of $a$ and $b$) from $0°$ to $180°$:
- $>0$ and largest when $\theta = 0°$ ($a$ and $b$ are in the same direction)
- $0$ when $\theta = 90°$ ($a$ and $b$ are orthogonal)
- $<0$ and smallest when $\theta = 180°$ ($a$ and $b$ are in opposite directions)
From the above, you can see that the dot product outputs a larger value when its vectors are more “similar”, and a smaller value when they are less so. So, the dot product can be interpreted as a measure of similarity, which can function as a measure of “relevance” in the abstract spaces of the encoder/decoder hidden states. And as previously mentioned, measuring “relevance” is what we want the alignment function to do. So from this perspective, it makes sense to me that the dot product works as an alignment function.
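Here is a short check of the angle argument above, using 2-D vectors of fixed length 2. The helper name `vec_at_angle` is hypothetical. With both magnitudes fixed, the dot product is $4\cos{\theta}$, so only the angle changes the score.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def vec_at_angle(theta_deg, length=2.0):
    """2-D vector of the given length at angle theta (degrees) from the x-axis."""
    t = math.radians(theta_deg)
    return [length * math.cos(t), length * math.sin(t)]

a = vec_at_angle(0)
for theta in (0, 45, 90, 135, 180):
    b = vec_at_angle(theta)
    # With ||a|| = ||b|| = 2, a . b = 4 cos(theta): largest at 0, zero at 90, smallest at 180.
    print(theta, round(dot(a, b), 3))
```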
Dot product versus cosine
Why not use cosine similarity ($\cos{\theta}$), which is the dot product of the normalized vectors, instead of the dot product ($||a|| ||b|| \cos{\theta}$) to measure similarity? When might the magnitudes of $a$ and $b$ be useful?
The vector magnitudes in the dot product could allow it to capture additional information, such as frequency, that cosine similarity discards. That might be useful for certain tasks. But admittedly, I currently don’t know exactly why the authors of the papers cited in this post chose the dot product over cosine similarity for their language translation tasks.
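Here is a tiny illustration of the difference, with made-up vectors. Two encoder states point in the same direction, so they get the same cosine score, but the dot product gives the larger one a higher score. Whether that extra magnitude signal actually helps remains an open question; this only shows where it enters.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def cosine(a, b):
    """Cosine similarity: the dot product of a and b divided by ||a|| ||b||."""
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))

s = [1.0, 1.0]
h_small = [1.0, 1.0]
h_large = [3.0, 3.0]  # same direction as h_small, three times the magnitude

print(dot(s, h_small), dot(s, h_large))        # 2.0 6.0: the dot product sees the magnitude
print(cosine(s, h_small), cosine(s, h_large))  # both ~1.0: cosine ignores it
```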
Alignment functions that multiply the hidden states by learnable weights
What could be the motivation for multiplying the encoder/decoder hidden state vectors by learnable weights, as in the additive ($v_a^T \tanh(W_a[s_{i}; h_j])$) and general ($s_{i}^T W_a h_j$) alignment functions?
- The encoder and decoder hidden states might be represented in very different ways. Intuitively, it seems hard to determine the similarity of things that are represented very differently, at least without first changing their representations. By multiplying those states by learnable weights, we hope to transform them into new spaces in which determining similarity is easier.
- The new spaces that the weights transform the hidden states into could have much lower dimensionality, and thus be computationally cheaper to work with (although there is of course a non-trivial computational cost for transforming them into these spaces).
- The dimensions of the encoder and decoder hidden states may not match. For example, a bidirectional encoder concatenates its forward and backward states, which can make them twice the size of the decoder's. Multiplying by a matrix can transform one of them into a vector of compatible dimensionality.
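Here is a minimal sketch of the last two points. It uses two separate projection matrices, $W_s$ and $W_h$. These are hypothetical and not notation from the cited papers. Each matrix maps its hidden state into a shared, smaller space, where a plain dot product can be taken. Multiplying out, the score equals $s_i^T W_s^T W_h h_j$, so this is the general alignment function with $W_a = W_s^T W_h$ constrained to low rank. All sizes and values below are made up.

```python
def matvec(m, v):
    """Multiply matrix m (a list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]

def projected_score(s, h, W_s, W_h):
    """Project s and h into a shared k-dimensional space, then take their dot product."""
    return sum(a * b for a, b in zip(matvec(W_s, s), matvec(W_h, h)))

# Hypothetical sizes: decoder state d_s = 3, encoder state d_h = 4, shared space k = 2.
W_s = [[1.0, 0.0, 0.0],
       [0.0, 1.0, 0.0]]        # k x d_s
W_h = [[0.5, 0.5, 0.0, 0.0],
       [0.0, 0.0, 0.5, 0.5]]   # k x d_h
s_i = [2.0, 1.0, 5.0]
h_j = [1.0, 3.0, 4.0, 0.0]
e_ij = projected_score(s_i, h_j, W_s, W_h)
print(e_ij)  # 6.0
```

The dot product is taken over $k = 2$ dimensions instead of 3 or 4. The two projections are the extra cost mentioned above.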
Determining the next target word
The context vector will be used to help determine the next target word.
The specifics of the various ways in which that happens are out of scope for this post. I don’t think that they are core to understanding the basics of the attention mechanism. But you can read the cited papers for information on those.
Conclusion
In this post, we explored how the attention mechanism addresses some shortcomings of non-attentional sequence-to-sequence models to better handle longer sequences.
In another post, we will see how the attention mechanism is used in the Transformer model.
Acknowledgments
- “Sequence to Sequence Learning with Neural Networks” (Sutskever et al., 2014)
- “Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation” (Cho et al., 2014)
- “Neural Machine Translation by Jointly Learning to Align and Translate” (Bahdanau et al., 2015)
- “Effective Approaches to Attention-based Neural Machine Translation” (Luong et al., 2015)
- “Attention Is All You Need” (Vaswani et al., 2017)