This post explores the basics of vanilla recurrent neural networks (RNNs) and long short-term memory networks (LSTMs).
What neural architecture can effectively handle sequences?
Suppose we want a neural network to perform a sequence-processing task, such as:
- generating human names
- image captioning
- sentiment classification of text
What kind of neural architecture would be suited for these types of problems?
The problem with dense neural networks
A vanilla dense neural network would not elegantly handle sequences, for the following reasons:
- It cannot handle variable-length inputs and outputs
- The input/output sequence sizes would be limited to the number of input/output neurons
- Information from other parts of the sequence is not persisted
  - But in many sequence-processing problems, it is helpful to get context from other parts of the sequence
Recurrent neural networks (RNNs)
A recurrent neural network is composed of a sequence of connected cells, each representing some index or “time” $t$ in the sequence to be processed/generated. The trainable parameters of the cells are shared, so you can also think of an RNN as having one cell that is recurrently connected to itself $t$ times.
This structure allows it to accept and output sequences of any length, and to persist information from other parts of the sequence. For those reasons, it is better suited than dense neural networks to processing sequences.
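Here is a minimal sketch of parameter sharing. It assumes a scalar input and hidden state and uses arbitrary, hypothetical weights `W_X`, `W_H`, `B_H`. The same three numbers process sequences of any length, and each hidden state depends on every input that came before it.

```python
import math

# Hypothetical scalar parameters, shared by every time step ("cell").
W_X, W_H, B_H = 0.5, 0.8, 0.0

def run_rnn(xs, h0=0.0):
    """Apply the same recurrent cell to each element of xs; return all hidden states."""
    h = h0
    hidden_states = []
    for x in xs:  # one iteration per time step t
        h = math.tanh(W_X * x + W_H * h + B_H)
        hidden_states.append(h)
    return hidden_states

short = run_rnn([1.0, 2.0])
long_seq = run_rnn([1.0, 2.0, 0.5, -1.0, 3.0])
print(len(short), len(long_seq))  # 2 5: same parameters, different lengths
```

A dense layer, by contrast, needs a separate weight for every input position, so its sequence length is fixed in advance.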
Aside: processing the input with an embedding layer
This section is just a very high-level overview of an embedding layer in case you are not aware of it. I think that it is good to know about because it is so often used in text processing. Before being passed to the RNN layer, the network’s input is often first processed by an embedding layer.
What are the problems with the input representation?
The input tokens are often represented as one-hot vectors, which can get very high-dimensional with a large dictionary of tokens. It’s not efficient to work with these very sparse, high-dimensional vectors. You might run out of memory, and it’s much harder to learn with higher-dimensional features (see the curse of dimensionality).
The one-hot representation also doesn’t allow the network to easily determine similarity and relationships between tokens, because every pair of distinct one-hot vectors is equally far apart. But being able to do that is extremely useful. For example, the network can generalize better if it interprets similar tokens (e.g. synonyms, or words of the same kind such as nouns or verbs) similarly.
How does the embedding layer address those problems?
The embedding layer addresses these problems by transforming those inputs into dense, lower-dimensional representations. The transformation is learned so that useful relationships can effectively be determined from the resulting representations.
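Note that although a dense layer applied to one-hot vectors can do this too, an embedding layer does it much more efficiently. Multiplying a one-hot vector by a weight matrix simply selects one row of that matrix, so the embedding layer skips the multiplication and looks the row up directly.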
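The comparison can be made concrete with a small sketch. It uses pure Python, hypothetical sizes `VOCAB_SIZE` and `EMBED_DIM`, and random rather than learned weights. The dense-layer path multiplies a one-hot vector by the matrix, and the lookup path just indexes a row. Both return the same values.

```python
import random

VOCAB_SIZE = 6  # hypothetical dictionary size
EMBED_DIM = 3   # hypothetical embedding size

rng = random.Random(0)
# One row per token; random here, learned during training in practice.
embedding = [[rng.uniform(-1, 1) for _ in range(EMBED_DIM)] for _ in range(VOCAB_SIZE)]

def one_hot(token_id):
    return [1.0 if i == token_id else 0.0 for i in range(VOCAB_SIZE)]

def dense_layer(vec, weights):
    """Row vector times matrix (no bias): VOCAB_SIZE * EMBED_DIM multiplications."""
    return [sum(vec[i] * weights[i][j] for i in range(VOCAB_SIZE)) for j in range(EMBED_DIM)]

def embedding_lookup(token_id):
    """Index the matrix directly: no multiplications at all."""
    return embedding[token_id]

token = 4
print(dense_layer(one_hot(token), embedding))
print(embedding_lookup(token))  # same values
```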
Vanilla RNNs
In this section I’ll describe the details of a vanilla RNN layer, which you can think of as a single RNN cell that is recurrently connected to itself a number of times according to the input sequence length.
RNN cell details
Each cell for time $t$ in a vanilla RNN can have two inputs and two outputs (see the RNN diagram above):
Inputs:
- $h_{t-1}$ from the cell at time $t-1$
- $x_t$ from outside the network
Outputs:
- $h_t$ to the cell at time $t+1$
  - $h_t = \tanh(W_{x}x_t + W_{h}h_{t-1} + b_h)$
- $o_t$ to outside the network
  - $o_t = \text{softmax}(W_{o}h_t + b_o)$
Each cell shares the same trainable parameters:
- $W_{x}$ representing the weights on the external input $x_t$
- $W_{h}$ and $b_h$ representing the weights/bias on the internal input $h_{t-1}$
- $W_{o}$ and $b_o$ representing the weights/bias for the output $o_t$
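The cell equations map almost line for line onto code. Below is a minimal pure-Python sketch of the forward pass only (lists as vectors, no training). The sizes and weight values are hypothetical and were chosen only so that the example runs.

```python
import math

def matvec(W, v):
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def add(*vecs):
    return [sum(vals) for vals in zip(*vecs)]

def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def rnn_cell(x_t, h_prev, params):
    W_x, W_h, b_h, W_o, b_o = params
    h_t = [math.tanh(v) for v in add(matvec(W_x, x_t), matvec(W_h, h_prev), b_h)]
    o_t = softmax(add(matvec(W_o, h_t), b_o))
    return h_t, o_t

# Hypothetical sizes: 2-d input, 3-d hidden state, 2 output classes.
params = (
    [[0.1, -0.2], [0.4, 0.3], [-0.5, 0.2]],                 # W_x: hidden x input
    [[0.2, 0.0, -0.1], [0.1, 0.3, 0.0], [0.0, -0.2, 0.1]],  # W_h: hidden x hidden
    [0.0, 0.1, -0.1],                                       # b_h
    [[0.3, -0.1, 0.2], [-0.4, 0.5, 0.1]],                   # W_o: output x hidden
    [0.0, 0.0],                                             # b_o
)

h = [0.0, 0.0, 0.0]  # initial hidden state h_0
for x in [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]:
    h, o = rnn_cell(x, h, params)  # the same params at every time step
    print([round(p, 3) for p in o])
```

The same `params` tuple is passed unchanged to every call. That is the parameter sharing across cells.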
How does backpropagation with RNNs work?
The total loss for the RNN is the sum of the losses on each cell’s output to outside the network ($o_t$): $L = \sum_t L_t$.
We calculate and apply the gradients of this loss with respect to the trainable parameters. Remember that those parameters are shared across all cells.
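Here is a sketch of the total loss. As an assumption (the post does not name a loss), it applies cross-entropy to each softmax output $o_t$ against a hypothetical target class $y_t$. The total is then the sum over time steps.

```python
import math

def cross_entropy(probs, target):
    """Per-step loss L_t: negative log-probability assigned to the target class."""
    return -math.log(probs[target])

# Hypothetical softmax outputs o_1..o_3 and target classes y_1..y_3.
outputs = [[0.7, 0.3], [0.2, 0.8], [0.5, 0.5]]
targets = [0, 1, 1]

step_losses = [cross_entropy(o, y) for o, y in zip(outputs, targets)]
total_loss = sum(step_losses)  # L = L_1 + L_2 + L_3
print(step_losses, total_loss)
```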
I want to give you an idea of what the gradient calculation looks like to make the weaknesses of RNNs more apparent.
Example gradient of the loss at time step $t$ with respect to $W_h$:
$\frac{\partial{L_t}}{\partial{W_h}} = \frac{\partial{L_{t}}}{\partial{W_{h_t}}} + \frac{\partial{L_{t}}}{\partial{W_{h_{t-1}}}} + \frac{\partial{L_{t}}}{\partial{W_{h_{t-2}}}} + …$
This is the sum of the gradients of the loss at time step $t$ with respect to each $W_{h_i}$, for every time step $i \le t$. Here $W_{h_i}$ denotes the copy of $W_h$ used by the cell at time $i$. The parameter is shared, so $W_{h_i} = W_h$ for every $i$, but each copy is treated as a separate parameter for gradient calculations and their gradients are summed.
Expanding it out, we get:
$\frac{\partial{L_t}}{\partial{W_h}} = \frac{\partial{L_{t}}}{\partial{h_t}}\frac{\partial{h_t}}{\partial{W_{h_t}}} + \frac{\partial{L_{t}}}{\partial{h_t}}\frac{\partial{h_t}}{\partial{h_{t-1}}}\frac{\partial{h_{t-1}}}{\partial{W_{h_{t-1}}}} + \frac{\partial{L_{t}}}{\partial{h_t}}\frac{\partial{h_t}}{\partial{h_{t-1}}}\frac{\partial{h_{t-1}}}{\partial{h_{t-2}}}\frac{\partial{h_{t-2}}}{\partial{W_{h_{t-2}}}} + …$
The gradient contributions from earlier time steps contain more factors of the form $\frac{\partial{h_t}}{\partial{h_{t-1}}}$. Let’s expand that factor.
- Recall that $h_t = \tanh(W_{x}x_t + W_{h}h_{t-1} + b_h)$
- Let $k = W_{x}x_t + W_{h}h_{t-1} + b_h$
- $\frac{\partial{h_t}}{\partial{h_{t-1}}} = \frac{\partial{h_t}}{\partial{k}}\frac{\partial{k}}{\partial{h_{t-1}}} = \frac{\partial{(\tanh(k))}}{\partial{k}}W_h$
- Note that $\frac{\partial{(\tanh(k))}}{\partial{k}} = 1 - \tanh^2(k)$, which lies in $(0, 1]$ and equals $1$ only when $k = 0$
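So the gradient contribution from the time step $n$ steps before $t$ contains a product of $n$ factors of the form $\frac{\partial{(\tanh(k))}}{\partial{k}}W_h$, one per intervening step. If these factors are consistently smaller than $1$ in magnitude, the product shrinks exponentially with $n$. If they are consistently larger, it grows exponentially. This is why the contributions from earlier time steps can much more easily vanish/explode!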
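The expansion can be checked numerically. The sketch below assumes a scalar RNN, so every partial derivative is an ordinary derivative. It also uses hypothetical parameter values and, for simplicity, a squared-error loss on the final hidden state $h_T$ instead of a softmax output. `bptt_grad` accumulates the terms of the sum starting from the latest time step. At each step back it multiplies in one factor $\frac{\partial{h_i}}{\partial{h_{i-1}}} = (1 - h_i^2)W_h$, using $h_i = \tanh(k)$. The result should match a finite-difference estimate.

```python
import math

W_X, B_H = 0.7, 0.1          # hypothetical fixed parameters
XS = [0.5, -1.0, 2.0, 0.3]   # hypothetical input sequence x_1..x_4
Y = 0.2                      # hypothetical target for the final step

def forward(w_h):
    """Return hidden states [h_0, h_1, ..., h_T] of a scalar vanilla RNN."""
    hs = [0.0]
    for x in XS:
        hs.append(math.tanh(W_X * x + w_h * hs[-1] + B_H))
    return hs

def loss(w_h):
    # Assumed loss: squared error on the final hidden state only (L_T).
    return 0.5 * (forward(w_h)[-1] - Y) ** 2

def bptt_grad(w_h):
    """dL_T/dW_h as the sum of the per-time-step terms in the expansion."""
    hs = forward(w_h)
    T = len(XS)
    chain = hs[T] - Y  # dL_T/dh_T
    grad = 0.0
    for i in range(T, 0, -1):
        local = (1 - hs[i] ** 2) * hs[i - 1]  # direct dh_i/dW_{h_i}
        grad += chain * local
        chain *= (1 - hs[i] ** 2) * w_h       # multiply by dh_i/dh_{i-1}
    return grad

w_h = 0.9
eps = 1e-6
numeric = (loss(w_h + eps) - loss(w_h - eps)) / (2 * eps)  # central difference
print(bptt_grad(w_h), numeric)
```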
Problems with vanilla RNNs
From the previous section on backpropagation, you can see that the gradient contribution from earlier parts of the sequence can much more easily vanish/explode.
This can lead to complete failure to learn or difficulty in learning long-term dependencies.
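The following scalar sketch shows how quickly these products grow or shrink. For simplicity it assumes that the pre-activation $k$ is the same at every step, so the contribution from $n$ steps back is scaled by $\left(\frac{\partial{(\tanh(k))}}{\partial{k}}W_h\right)^n$. The values of $W_h$ and $k$ are hypothetical.

```python
import math

def gradient_scale(w_h, steps, k=0.0):
    """(tanh'(k) * w_h) ** steps: scaling of a contribution from `steps` steps back.
    Simplifying assumption: scalar RNN with the same pre-activation k at every step."""
    d_tanh = 1.0 - math.tanh(k) ** 2
    return (d_tanh * w_h) ** steps

for w_h in (0.5, 1.0, 1.5):
    print(w_h, [gradient_scale(w_h, n) for n in (1, 10, 50)])

# A saturated tanh (large |k|) shrinks each factor even when w_h > 1.
print(gradient_scale(1.5, 50, k=2.0))
```

With $W_h = 0.5$, the contribution from 50 steps back is scaled by roughly $10^{-15}$. With $W_h = 1.5$, it is scaled by roughly $6 \times 10^{8}$.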
Bidirectional RNNs
Bidirectional RNNs not only process the sequence from start to end, but also from end to start. This might be helpful in some cases in which there are dependencies from later parts of the sequence as well as from earlier parts.
For simplicity, this post only includes technical details of unidirectional RNNs.
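The idea can still be illustrated with a small scalar sketch. It assumes a common setup in which each direction has its own weights (hypothetical values here) and the two directions’ hidden states are paired up per time step.

```python
import math

def run_rnn(xs, w_x, w_h, b):
    """Scalar vanilla RNN over xs; returns hidden states in processing order."""
    h, states = 0.0, []
    for x in xs:
        h = math.tanh(w_x * x + w_h * h + b)
        states.append(h)
    return states

def run_birnn(xs):
    # Hypothetical parameters: each direction has its own weights.
    forward = run_rnn(xs, 0.5, 0.8, 0.0)                # start -> end
    backward = run_rnn(xs[::-1], -0.3, 0.6, 0.1)[::-1]  # end -> start, re-aligned to time order
    # At each time step, pair context from earlier inputs with context from later inputs.
    return list(zip(forward, backward))

pairs = run_birnn([1.0, 2.0, 3.0])
print(pairs)
```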
Long short-term memory networks (LSTMs)
A long short-term memory network (LSTM) is a type of RNN that tries to alleviate vanilla RNNs’ difficulty with longer sequences. It adds something called a cell state that is maintained over time, along with gates that can add information to it or remove information from it.
Cell details
Inputs
The cell for time $t$ has the following inputs:
- $h_{t-1}$ from the cell at time $t-1$
- $C_{t-1}$ from the cell at time $t-1$
- $x_t$ from outside the network
Cell gates and outputs
- Forget gate
  - $f_t = \text{sigmoid}(W_f \cdot [h_{t-1}, x_t] + b_f)$
  - This will decide what information we are going to throw away from the incoming cell state $C_{t-1}$
- Input gate
  - $i_t = \text{sigmoid}(W_i \cdot [h_{t-1}, x_t] + b_i)$
  - This will decide what new information we are going to store in the cell state $C_t$
- Update cell state
  - $\bar{C_t} = \tanh(W_c \cdot [h_{t-1}, x_t] + b_c)$
  - This represents the candidate new information to store in the cell state
  - $C_t = f_t*C_{t-1} + i_{t}*\bar{C_t}$, where $*$ denotes element-wise multiplication
  - This is output to the cell for time $t+1$
- Output gate
  - $o_t = \text{sigmoid}(W_o \cdot [h_{t-1}, x_t] + b_o)$
  - $h_t = o_t * \tanh(C_t)$
  - This is the hidden state that is output 1) as input to the cell at time $t+1$, and 2) to outside the network
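Putting the gate equations together, one LSTM step can be sketched in pure Python. In this sketch, vectors are plain lists, $[h_{t-1}, x_t]$ is list concatenation, and products are element-wise. The sizes (`HIDDEN`, `INPUT`) and the random weights are hypothetical.

```python
import math
import random

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def affine(W, b, v):
    """W . v + b, with W given as a list of rows."""
    return [sum(w * x for w, x in zip(row, v)) + bias for row, bias in zip(W, b)]

def lstm_cell(x_t, h_prev, c_prev, p):
    z = h_prev + x_t  # the concatenation [h_{t-1}, x_t]
    f_t = [sigmoid(v) for v in affine(p["W_f"], p["b_f"], z)]      # forget gate
    i_t = [sigmoid(v) for v in affine(p["W_i"], p["b_i"], z)]      # input gate
    c_bar = [math.tanh(v) for v in affine(p["W_c"], p["b_c"], z)]  # candidate values
    c_t = [f * c + i * cb for f, c, i, cb in zip(f_t, c_prev, i_t, c_bar)]
    o_t = [sigmoid(v) for v in affine(p["W_o"], p["b_o"], z)]      # output gate
    h_t = [o * math.tanh(c) for o, c in zip(o_t, c_t)]
    return h_t, c_t

HIDDEN, INPUT = 2, 1  # hypothetical sizes
rng = random.Random(0)
p = {}
for gate in "fico":  # forget, input, candidate, output
    p[f"W_{gate}"] = [[rng.uniform(-0.5, 0.5) for _ in range(HIDDEN + INPUT)]
                      for _ in range(HIDDEN)]
    p[f"b_{gate}"] = [0.0] * HIDDEN

h, c = [0.0] * HIDDEN, [0.0] * HIDDEN
for x in ([1.0], [-0.5], [2.0]):
    h, c = lstm_cell(x, h, c, p)  # the same p is used at every time step
print(h, c)
```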
Why do we use sigmoid in the forget and input gates?
Since the sigmoid function outputs a number between 0 and 1, it acts as an intuitive gatekeeper. 0 would mean “let nothing through” and 1 would mean “let everything through”.
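For example, suppose the forget gate has a hypothetical pre-activation of $+20$ and the input gate one of $-20$. The cell-state update $C_t = f_t*C_{t-1} + i_t*\bar{C_t}$ then keeps the old value almost unchanged and ignores the candidate. Scalars are used for simplicity:

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

f_t = sigmoid(20.0)    # close to 1: "let everything through" from C_{t-1}
i_t = sigmoid(-20.0)   # close to 0: "let nothing through" from the candidate
c_prev, c_bar = 0.9, -0.7  # hypothetical old cell state and candidate value
c_t = f_t * c_prev + i_t * c_bar
print(f_t, i_t, c_t)   # c_t is very close to c_prev
```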
Problems with LSTMs
- Although it works better than vanilla RNNs for long-term dependencies, it still does not work well for very long-term dependencies.
- As with vanilla RNNs, sequences are processed serially rather than in parallel (each step needs the previous step’s hidden state), which is slower
Addressing problems with LSTMs
The attention mechanism and the transformer architecture address some of the problems with LSTMs. But those are out of scope for this post.