In this article, we will discuss some of the limitations of Encoder-Decoder models, which motivated the development of the Attention Mechanism. After that, we will talk about the concepts behind Attention Models and their application in machine translation.
Note: The attention mechanism is a slightly advanced topic and requires an understanding of Encoder-Decoder and Long Short-Term Memory (LSTM) models. Kindly refer to my previous articles on Encoder-Decoder Models and LSTM.
Before we discuss the concepts of Attention models, we will start by revisiting the task of machine translation using Encoder-Decoder models.
Citation Note: The content and structure of this article are based on my understanding of the PadhAI deep learning lectures from One-Fourth Labs.
Let's take the example of translating text from Hindi to English with a seq2seq model, which uses an encoder-decoder architecture. The underlying network in the encoder and decoder can be anything from a plain RNN to an LSTM.
At a high level, the encoder reads the entire sentence only once and encodes all of its information, accumulated through the previous hidden states and previous inputs, into a single encoded vector. The decoder then uses this encoded vector at each time-step to produce a new word.
The problem with this approach is that the encoder reads the entire sentence only once and has to squeeze everything it remembers into one fixed-length encoded vector. For longer sentences, the encoder may fail to retain the starting parts of the sequence, resulting in a loss of information.
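To make the information-loss argument concrete, here is a deliberately simple toy: a scalar RNN encoder with the contractive update hₜ = tanh(0.5 · hₜ₋₁ + xₜ) that returns only its final state, mimicking the single encoded vector. The function names and the recurrence weight 0.5 are hypothetical choices for illustration, not a trained model. Changing only the first input word changes the final encoding less and less as the sentence gets longer.

```python
import math

def encode(inputs, w_h=0.5):
    """Toy scalar RNN encoder (hypothetical): h_t = tanh(w_h * h_{t-1} + x_t).

    Only the final state is returned, mimicking the single encoded vector
    that a vanilla encoder-decoder hands to the decoder.
    """
    h = 0.0
    for x in inputs:
        h = math.tanh(w_h * h + x)
    return h

def first_word_influence(length):
    """How much the final encoding changes when ONLY the first input changes."""
    tail = [0.3] * (length - 1)  # identical remaining inputs in both sentences
    return abs(encode([1.0] + tail) - encode([-1.0] + tail))

for n in (1, 5, 20):
    print(n, first_word_influence(n))
```

Because tanh never amplifies differences and the recurrence weight is 0.5, the gap caused by the first word shrinks by at least half at every step. Gated units such as LSTMs are designed to weaken this decay, but the decoder still receives only one fixed-length vector.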
Is this how humans translate a sentence?
Do you think that the entire input sequence (or sentence) is important at every time-step during decoding? Can we place special emphasis on certain words rather than giving equal importance to all of them? The attention mechanism was developed to address these challenges.
We humans translate each output word by focusing only on certain words in the input. At each time-step, we take only the relevant information from a long sentence and translate that particular word. Ideally, at each time-step, we should feed only the relevant information (the encodings of the relevant words) to the decoder for the translation.
How do we know which words are important, or which ones deserve more attention? For now, assume that we have an oracle that tells us which words to focus on at a given time-step t. With the oracle's help, can we design a better architecture that feeds only the relevant information to the decoder?
So for each input word, we assign a weight α (between 0 and 1) that represents the importance of that word for the output at time-step t. For example, α₁₂ represents the importance of the first input word for the output word at the second time-step. In general, αⱼₜ represents the weight associated with the jᵗʰ input word at the tᵗʰ time-step.
For example, at time-step 2, we could simply take a weighted average of the input word representations, using the weights αⱼ₂, and feed it into the decoder. In this scenario, we do not feed the single complete encoded vector into the decoder, but rather the weighted representation of the words. In effect, we give more importance (attention) to the important words based on the weights given by the oracle. (Thanks to the oracle!)
Intuitively, this approach should work better than the vanilla encoder-decoder architecture because we are not overloading the decoder with irrelevant information.
Don't be fooled: in reality, there is no oracle. If there is no oracle, how do we learn the weights?
Notation: From now on, we will refer to the decoder state at the tᵗʰ time-step as Sₜ and the encoder state at the jᵗʰ time-step as hⱼ.
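The weighted average at time-step 2 can be sketched in a few lines. The three 2-dimensional word vectors and the "oracle" weights αⱼ₂ below are made-up numbers for illustration; the point is that the decoder receives one blended vector dominated by the word the oracle says matters most.

```python
# Hypothetical 3-word input, each word represented by a 2-dim vector.
word_vectors = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

# Hypothetical oracle weights alpha_{j2} for output time-step t = 2 (they sum to 1).
oracle_weights_t2 = [0.1, 0.8, 0.1]

def weighted_average(vectors, weights):
    """Sum_j weights[j] * vectors[j], computed dimension by dimension."""
    dim = len(vectors[0])
    return [sum(w * v[d] for w, v in zip(weights, vectors)) for d in range(dim)]

context_t2 = weighted_average(word_vectors, oracle_weights_t2)
print(context_t2)  # approximately [0.2, 0.9], dominated by the second word
```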
The parameter αⱼₜ has to be learned from the data. To enable this, we define a function that computes an intermediate score eⱼₜ = f(Sₜ₋₁, hⱼ).
The function that calculates the intermediate score eⱼₜ takes two arguments. Let's discuss what they are. At the tᵗʰ time-step, we are trying to find out how important the jᵗʰ word is, so the function that computes the weights should depend on the vector representation of the word itself (i.e., hⱼ) and the decoder state up to that particular time-step (i.e., Sₜ₋₁).
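The text does not pin down the exact form of the scoring function, only that it depends on Sₜ₋₁ and hⱼ. One common choice is the additive score from Bahdanau et al., eⱼₜ = vᵀ tanh(W Sₜ₋₁ + U hⱼ). The sketch below uses that form with small hand-picked matrices W, U and vector v; these are hypothetical stand-ins for parameters that would normally be learned during training.

```python
import math

# Hypothetical, hand-picked parameters (in practice these are learned).
W = [[0.5, -0.2], [0.1, 0.4]]   # projects the previous decoder state S_{t-1}
U = [[0.3, 0.7], [-0.6, 0.2]]   # projects the encoder state h_j
v = [1.0, -1.0]                 # reduces the hidden layer to a single score

def matvec(m, x):
    """Matrix-vector product for small list-of-lists matrices."""
    return [sum(a * b for a, b in zip(row, x)) for row in m]

def additive_score(s_prev, h_j):
    """e_jt = v . tanh(W s_{t-1} + U h_j), an additive (Bahdanau-style) score."""
    hidden = [math.tanh(a + b) for a, b in zip(matvec(W, s_prev), matvec(U, h_j))]
    return sum(vi * hi for vi, hi in zip(v, hidden))

s_prev = [0.2, -0.1]                                    # hypothetical S_{t-1}
encoder_states = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]   # hypothetical h_1, h_2, h_3
scores = [additive_score(s_prev, h) for h in encoder_states]
print(scores)
```

Note that these raw scores are unnormalized: they can be negative and need not sum to anything in particular.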
The score eⱼₜ captures the importance of the jᵗʰ input word for decoding the tᵗʰ output word. Applying the softmax function over all input words, αⱼₜ = exp(eⱼₜ) / Σₖ exp(eₖₜ), we normalize these scores to get our parameter αⱼₜ, which lies between 0 and 1 and sums to 1 over j.
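The softmax normalization can be written directly. The scores below are hypothetical values of e₁ₜ, e₂ₜ, e₃ₜ for a single time-step. Subtracting the maximum score before exponentiating does not change the result mathematically but avoids overflow for large scores.

```python
import math

def softmax(scores):
    """Normalize raw scores e_jt (one per input word j) into weights alpha_jt."""
    m = max(scores)                          # subtract the max for numerical stability
    exps = [math.exp(e - m) for e in scores]
    total = sum(exps)
    return [x / total for x in exps]

scores = [2.0, 1.0, 0.1]                     # hypothetical e_1t, e_2t, e_3t
alphas = softmax(scores)
print([round(a, 3) for a in alphas])         # → [0.659, 0.242, 0.099]
```

The highest score still gets the largest weight, but every word keeps a nonzero share, which keeps the whole computation differentiable and therefore learnable by gradient descent.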
The parameter αⱼₜ denotes the probability of focusing on the jᵗʰ word to produce the tᵗʰ output word.
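Putting the pieces together, one decoder time-step of attention computes a score for every encoder state, turns the scores into the distribution αⱼₜ, and feeds the resulting weighted average (the context vector) to the decoder. For brevity this sketch swaps in the simplest possible scoring function, a plain dot product eⱼₜ = Sₜ₋₁ · hⱼ; that choice, the function names, and all the numbers are assumptions for illustration, and any learned score (such as the additive one) can be passed in via `score_fn`.

```python
import math

def dot_score(s_prev, h_j):
    """Assumed stand-in for the learned scoring function: e_jt = S_{t-1} . h_j."""
    return sum(a * b for a, b in zip(s_prev, h_j))

def attention_step(s_prev, encoder_states, score_fn=dot_score):
    """One decoder time-step t: scores e_jt -> weights alpha_jt -> context vector."""
    scores = [score_fn(s_prev, h) for h in encoder_states]
    m = max(scores)
    exps = [math.exp(e - m) for e in scores]
    total = sum(exps)
    alphas = [x / total for x in exps]       # a probability distribution over input words j
    dim = len(encoder_states[0])
    context = [sum(a * h[d] for a, h in zip(alphas, encoder_states)) for d in range(dim)]
    return alphas, context

encoder_states = [[2.0, 0.0], [0.0, 2.0], [0.1, 0.1]]  # hypothetical h_1, h_2, h_3
s_prev = [0.0, 3.0]                                     # hypothetical S_{t-1}
alphas, context = attention_step(s_prev, encoder_states)
best = max(range(len(alphas)), key=alphas.__getitem__)
print(best)  # → 1 (the second input word gets almost all of the attention)
```

Because αⱼₜ is a probability distribution over input positions, the context vector can be read as the expected encoder state under that distribution.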