Key to practical deep learning — transfer learning
- In natural language processing (NLP)
- Contextualized word representations
- Current state of the art
- Branching Attention — proposed classifier head
Experiments and results
- Benefits of head-only training
- Full IMDB, other datasets
- What did not work, possible improvements
Applying deep neural networks to unstructured data, most commonly images, can yield impressive results. Academic papers and prototypes aside, mass-market cars understand their surroundings in real time by analyzing video feeds. In some areas, cars drive fully autonomously. In these cases, and many others, deep learning is used to transform the raw pixel values of images into some level of understanding of the scene.
While understanding images comes naturally to humans, for a computer it is no easy feat. It takes a lot of computation and enormous amounts of data to teach a machine to see the world somewhat similarly to how humans do. In practical terms, it also takes a lot of time and money. The example above? It was trained using more than 14 million manually annotated images (ImageNet, then COCO). Yet, there are many practical applications in all kinds of niches — countless industrial ones, and many used by NGOs — e.g. to recognize individual whales from their pictures or to detect deforestation using satellite imagery.
The key to using deep learning without huge datasets and budgets? Re-purposing already trained models, and the knowledge they contain — more formally: transfer learning. Computer vision underwent a revolution of sorts when, in the mid-2010s, it became standard practice to use models pre-trained on ImageNet for all kinds of applications.
The pre-training task is simple: for each image, decide what the main object in it is by choosing one of the (20,000+) pre-defined categories. To solve it, the neural networks learn to extract meaningful representations of images from raw pixel data. The first convolutional layers learn to detect simple features, such as edges and corners. Subsequent layers use that information to recognize more complex shapes — perhaps wings and beaks when looking at images of birds. The final layers can recognize specific types of objects — e.g. birds, cars or planes of various kinds.
We can use the pre-trained networks to compute representations of images and train our models on these representations. Alternatively, we can remove the last several layers of the pre-trained network, replace them with new, randomly-initialized ones, and train the whole network to perform the task at hand. This way it is possible to quickly build prototype systems using datasets with only a few thousand, or even a few hundred, labeled images. For example, 400 training samples were sufficient to achieve 85% accuracy in a 4-category classification of disorders present in OCT images of retinas — a very different kind of image than the pictures of animals and vehicles the original network was trained on. With some additional effort, e.g. for classification tasks, the data requirements can be lowered further. Taking it to the extreme, classifying images into categories without seeing any examples of them is also possible (zero-shot learning, e.g. DeViSE).
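A minimal sketch of this head-replacement pattern in PyTorch — the small convolutional `encoder` below merely stands in for a real pre-trained network, and all layer sizes are illustrative:

```python
import torch
import torch.nn as nn

# Stand-in for a network pre-trained on a large dataset (e.g. ImageNet).
# In practice, you would load real pre-trained weights here.
encoder = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),  # -> (batch, 16)
)

# Freeze the "pre-trained" layers; only the new head will be trained.
for p in encoder.parameters():
    p.requires_grad = False

# New, randomly initialized head for the task at hand,
# e.g. 4 retinal-disorder categories.
head = nn.Linear(16, 4)
model = nn.Sequential(encoder, head)

logits = model(torch.randn(8, 3, 64, 64))  # 8 RGB images, 64x64
print(logits.shape)  # torch.Size([8, 4])
```

With the encoder frozen, only the head's parameters appear in the optimizer's scope, which is what makes training with a few hundred labeled images feasible.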
Automated processing of natural language is a challenging problem. Systems performing various tasks in this domain have existed for several decades, but until relatively recently they were predominantly rule-based. These systems can perform very well (example: categorizing questions into 50 categories with 97.2% accuracy), and their behavior can be easily understood and debugged. Unfortunately, since they are based on manually developed systems of rules and knowledge bases, their development is labor-intensive. Their complexity grows immensely with the expected functionality, so they are generally only applicable to well-defined, narrow tasks. Furthermore, their operation can be disrupted by typos or grammatical errors.
The motivation for developing statistical, data-driven NLP systems is clear. In many cases, it is easier to gather a dataset relevant to the task at hand than it is to develop a massive set of rules. That translates into lower costs and shorter development time. Additionally, it provides an opportunity for the system to perform better over time, as more data is gathered. Besides, a lot of text data is freely available — in the form of books, online publications, etc. — and arguably contains most of humanity’s knowledge. Making use of it directly, as a knowledge base, could enable the construction of new, very capable systems.
Building a system able to fully understand natural language is an extremely hard, unsolved problem. Resolving linguistic ambiguities often requires context, knowledge of idioms, detecting sarcasm, and even general knowledge and human “common sense” (see Winograd Schema Challenge). Nevertheless, the problem is well worth pursuing — because of immediate, direct applications, as well as a possibility of getting closer to understanding general intelligence.
The first statistical methods in NLP started much simpler. They represented text documents as counts of the words they contain (a representation called bag-of-words). Improvements of this approach suggested using frequencies instead of raw counts, often relative to how common a given word is in general (TF-IDF). These methods completely disregard word order, so they use only a small part of the available information. Using n-grams (sequences of n words) instead of individual words, with similar further processing, is a viable way of incorporating word-order information. Unfortunately, using high values of n is infeasible, as the number of possible word combinations grows exponentially with n.
The count-based representations described above remained the foundation of state-of-the-art statistical solutions for a long time. Because of their simplicity and computational efficiency, for some applications, they remain important baselines. For example, for topic classification and sentiment analysis, trying Naive Bayes or SVM classifiers based on bag-of-words or bi-gram features might be a good idea (in the case of the full IMDB task described below, this achieves 90–91% accuracy, compared to 95% with ULMFiT — but is orders of magnitude faster).
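A sketch of such a baseline with scikit-learn; the four-sentence "dataset" is a toy stand-in (a real run would use thousands of labeled reviews):

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import make_pipeline

# Toy stand-in for a labeled dataset (1 = positive, 0 = negative).
texts = [
    "great movie, loved it",
    "wonderful acting, great fun",
    "terrible plot, awful film",
    "boring and awful, hated it",
]
labels = [1, 1, 0, 0]

# TF-IDF features over uni- and bi-grams, fed to a Naive Bayes classifier:
# a fast, surprisingly strong baseline for sentiment analysis.
baseline = make_pipeline(TfidfVectorizer(ngram_range=(1, 2)), MultinomialNB())
baseline.fit(texts, labels)

print(baseline.predict(["great fun, loved the acting"]))
```

Such a pipeline trains in seconds even on the full IMDB dataset, which is why it remains a useful sanity check before reaching for a neural model.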
The first wave of changes came in 2013, with the advent of Word2Vec. Models from this category produce a numerical representation (called a “word vector”, or “embedding”) for each word in their vocabulary. They were built as shallow neural networks, trained on short sequences from a large corpus of freely available text. The representation of each word was based on its usual context and already captured a lot of the language’s semantics. Famously, analogies between words correspond to arithmetic relationships between their vectors, e.g.
king - man + woman ~= queen.
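A toy illustration of this arithmetic. The tiny hand-crafted vectors below are purely illustrative (real embeddings have hundreds of dimensions and are learned from data, not designed); here the two dimensions can be read as "gender" and "royalty":

```python
import numpy as np

# Hand-crafted 2-D "embeddings": dimensions are (gender, royalty).
vectors = {
    "man":   np.array([ 1.0, 0.0]),
    "woman": np.array([-1.0, 0.0]),
    "king":  np.array([ 1.0, 1.0]),
    "queen": np.array([-1.0, 1.0]),
}

def nearest(v, vocab):
    """Return the word whose vector has the highest cosine similarity to v."""
    cos = lambda a, b: a @ b / (np.linalg.norm(a) * np.linalg.norm(b))
    return max(vocab, key=lambda w: cos(v, vocab[w]))

result = vectors["king"] - vectors["man"] + vectors["woman"]
print(nearest(result, vectors))  # queen
```

In real embedding spaces the analogy holds only approximately, so the nearest-neighbor lookup (by cosine similarity) is the standard way to evaluate it.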
Each word in the vocabulary was represented as a fixed-length vector (e.g. 300 numbers) — a set of coordinates of a point in the embedding space. Since the vocabulary size (the number of unique words) is usually orders of magnitude larger than the vector length, the representation is distributed — words and concepts are represented by combinations of coordinates, rather than by single elements of the vector.
As part of a complete system, word embeddings would be used to encode each word of a document, producing a variable-length sequence of vectors (length corresponding to the number of words). Many algorithms using these representations were developed — from averaging all the word vectors and training an SVM classifier on the results, to passing the sequence through recurrent or convolutional networks.
Word embeddings have one central problem. When used to encode words of a document, they encode each word individually — ignoring the context. The vector representing root will be the same whether the text mentions the root of a tree, a square root, or the root user of a server.
Words with more than one possible meaning are a problem for most practical applications of NLP. For some applications, e.g. sentiment analysis, syntactic ambiguity has to be resolved as well. This can range from simple problems (was a phrase negated?) to very complex, sometimes impossible ones (was the phrase used ironically?).
Let us consider a practical scenario. We want to build a system which will estimate public sentiment towards brands, by classifying tweets about them as positive or negative, and then calculate a fraction of positive ones. We need training data — we might download a number of relevant tweets about various brands, and then manually label them (or crowdsource the task).
If we use word embeddings and feed the encoded text to a neural network, we will be able to train it to perform the classification task — distinguishing the positive messages from negative ones. But because word embeddings do not capture context, just individual words, our network has to learn the whole structure of the language, and even higher-level concepts like irony or sarcasm, at the same time! All based on our precious, manually-labeled data. This is clearly inefficient.
To help neural networks “understand” the structure of languages, researchers have developed several network architectures and training algorithms which compute contextualized representations of words. The key difference compared to word-embedding-based approaches: this understanding of context is trained on unlabeled, freely available text in a given language (or multiple languages). Data is often taken from Wikipedia, public domain books, or simply scraped from the Internet.
In 2018–2019 they started a “revolution” of sorts, analogous to the one in computer vision several years earlier. Deep neural networks are now commonly used to compute rich, contextualized representations of text — as opposed to context-unaware word vectors. These representations can then be used by another neural network (e.g. ELMo). More commonly, the pre-trained network has its last layer replaced with a set of different ones, designed for the task at hand (often called a downstream task). The weights of the new layers are initialized randomly and trained, using the labeled data, to perform the downstream task. The process is much easier since the hardest part — language understanding — has mostly been done.
The initial training, meant to develop the network’s “understanding” of text, is based on unlabeled data. Labels — things for the model to predict — are automatically generated from the data itself. Training with such labels is often referred to as unsupervised pre-training.
The most common approaches to unsupervised pre-training are:
- language modeling — predicting the next word given the preceding ones (used e.g. by ULMFiT and GPT),
- masked language modeling — predicting words that have been masked out of a sentence (used e.g. by BERT).
Approaches to tokenization — splitting text into its basic chunks — also vary. The simplest, but sub-optimal, strategy would be to split every time a space character is encountered. Word-based tokenization is used in some cases, but usually includes rules specific to each language (e.g. splitting “don’t” into “do” and “n’t”). Sub-word tokenization (e.g. SentencePiece) is perhaps the most common — it is derived from data, based on the frequency of particular character sequences, often treating white space as yet another character. It is also worth noting that character-level language models are sometimes used in practice (e.g. FLAIR).
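A minimal sketch of the greedy longest-match idea behind sub-word tokenization. The vocabulary below is hand-picked for illustration; real tools such as SentencePiece learn it from corpus statistics:

```python
# Toy sub-word vocabulary; real vocabularies are learned from data
# and contain tens of thousands of entries.
vocab = {"un", "happi", "ness", "token", "ize", "r", "s"}

def tokenize(word, vocab):
    """Greedy longest-match sub-word tokenization of a single word."""
    pieces = []
    while word:
        # Try the longest prefix still present in the vocabulary...
        for end in range(len(word), 0, -1):
            if word[:end] in vocab:
                pieces.append(word[:end])
                word = word[end:]
                break
        else:
            # ...falling back to a single character if nothing matches,
            # so any input (typos included) can still be tokenized.
            pieces.append(word[0])
            word = word[1:]
    return pieces

print(tokenize("unhappiness", vocab))  # ['un', 'happi', 'ness']
print(tokenize("tokenizers", vocab))   # ['token', 'ize', 'r', 's']
```

The character-level fallback is what makes sub-word schemes robust to misspellings and rare words — no token is ever "out of vocabulary".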
Two groups of network architectures dominate the space:
- recurrent neural networks, usually based on LSTM cells (e.g. ULMFiT, ELMo),
- Transformers, based on self-attention (e.g. BERT, GPT).
This article describes a novel network architecture for classification models working on text documents — a modification of ULMFiT . Using that paper’s nomenclature, a different classifier head is proposed.
Training a ULMFiT model from scratch consists of 3 steps:
1. pre-training a language model on a large, general corpus (e.g. Wikipedia),
2. fine-tuning the language model on text from the target domain,
3. replacing the language model’s head with a classifier head and fine-tuning the resulting classifier on labeled data.
The network architecture used in the ULMFiT paper is depicted below. It contains:
- an embedding layer,
- a 3-layer AWD-LSTM encoder,
- a classifier head.
The classifier head performs Concat Pooling, then passes its results through one fully connected layer and the output layer.
Concat pooling is a simple concatenation of 3 elements:
- the encoder’s output for the last element of the sequence,
- an element-wise max-pool over all sequence elements,
- an element-wise average-pool over all sequence elements.
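A sketch of concat pooling over an encoder's output sequence (shapes are illustrative; the reference implementation lives in the fastai library):

```python
import torch

batch, seq_len, hid = 4, 10, 8
# Stand-in for the recurrent encoder's outputs: one vector per token.
outputs = torch.randn(batch, seq_len, hid)

last = outputs[:, -1, :]              # hidden state at the last time step
max_pool = outputs.max(dim=1).values  # element-wise max over time
avg_pool = outputs.mean(dim=1)        # element-wise mean over time

# Concatenate the three fixed-length summaries into one vector per document.
pooled = torch.cat([last, max_pool, avg_pool], dim=1)
print(pooled.shape)  # torch.Size([4, 24])
```

Note that all three components are fixed functions of the sequence — nothing in the aggregation itself is trainable, which is exactly what the proposed head changes.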
The described architecture is meant to directly address two main problems with using average-pooling and/or max-pooling to aggregate text representations generated by recurrent language models.
Therefore, the classifier head contains two “branches”, each answering one question about every element of the sequence:
- ATT: how important is this element for the classification?
- AGG: what information does it contribute?
Both are implemented as simple, fully connected neural networks. The numbers and sizes of their layers are additional hyper-parameters, the choice of which is discussed in the “Results” section. The networks are applied independently to each sequence element. Values returned by the ATT branch — a scalar (eⱼ) for each sequence element — are then passed through a Softmax function to obtain proper weights (aⱼ) for a weighted average. A weighted average of the vectors (bⱼ) returned by AGG becomes the final representation of the sequence (C).
Diagram of the “Branching Attention” classifier head:
Note: if the AGG branch is skipped, and ATT only has the output layer, the whole aggregation reduces to dot-product attention with a single, trainable query. This does not produce particularly good results, as discussed in the next section.
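A sketch of the described head in PyTorch. The branch depths and layer sizes below are illustrative placeholders (the actual hyper-parameters are discussed in the “Results” section); the symbols e, a, b, C follow the description above:

```python
import torch
import torch.nn as nn

class BranchingAttention(nn.Module):
    """Aggregates a sequence of token vectors into one document vector."""
    def __init__(self, hid, att_hid=32, agg_hid=32, agg_out=16):
        super().__init__()
        # ATT branch: one scalar e_j per token (how important is it?).
        self.att = nn.Sequential(nn.Linear(hid, att_hid), nn.ReLU(),
                                 nn.Linear(att_hid, 1))
        # AGG branch: one vector b_j per token (what does it contribute?).
        self.agg = nn.Sequential(nn.Linear(hid, agg_hid), nn.ReLU(),
                                 nn.Linear(agg_hid, agg_out))

    def forward(self, x):               # x: (batch, seq, hid)
        e = self.att(x)                 # (batch, seq, 1)
        a = torch.softmax(e, dim=1)     # weights a_j sum to 1 over the sequence
        b = self.agg(x)                 # (batch, seq, agg_out)
        return (a * b).sum(dim=1)       # C: (batch, agg_out)

head = BranchingAttention(hid=8)
c = head(torch.randn(4, 10, 8))
print(c.shape)  # torch.Size([4, 16])
```

The output C would then be passed to the final, fully connected output layer of the classifier.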
Most experiments were performed using a popular sentiment classification dataset: IMDB. It contains 50,000 movie reviews written by users of the imdb.com website. A review is labeled as positive if the user’s rating of the movie was 7 or above, and negative if it was 4 or lower. The classes are balanced — there are equal numbers of positive and negative reviews for each movie — and there are no more than 30 reviews per movie. Document lengths vary greatly: many are relatively short, but 2% are longer than 1000 words. It is also worth noting that the labels are somewhat “noisy”, e.g. there are some positive reviews labeled as negative.
To simulate a small dataset situation, yet make statistically significant comparisons between architectures, the experiments were conducted in the following way: models were trained on 1000-element samples of the IMDB training dataset, and evaluated on the whole test set. This was repeated 20 times for each architecture and hyper-parameter set. The training-related hyper-parameters and a single dropout multiplier (as recommended by ULMFiT authors) were tuned separately for Branching Attention and Concat Pooling heads.
Below, the results of the proposed architecture (Branching Attention) are compared with the baseline (Concat Pooling). Two additional variants are provided as a minimal ablation study: Branching Attention without the AGG branch, and without the ATT branch.
While by no means ground-breaking, the results above seem encouraging. Without changing the encoder — the most important part of the model — the number of incorrectly classified samples was reduced by over 10%. The proposed architecture seems to make better use of the encoder’s representation of text.
In the best configuration presented above, the Branching Attention head had 30% fewer parameters compared to the original ULMFiT architecture. In other configurations, performing only slightly worse, the number of parameters decreased by up to 85%.
It is worth noting that removing either branch of the Branching Attention head causes a significant drop in performance — even below the baseline levels. While not shown here for brevity, reducing the depth of either branch (to a single layer) also results in a small, but consistent, drop in performance. The results support using the complete network architecture, with both branches, as previously described.
The distributions of accuracy scores from individual experiment runs are shown below, as a box plot. It is meant to visualize the variance of results for each configuration, resulting from different sampling of the training dataset and different random initialization.
Worth noting: the single, negative outlier present for all four architectures, denoted as a dot, comes from training on the same sample of the training dataset — presumably of inferior quality. Likely, if training was repeated with the same dataset each time (only changing the random seed), the variance of accuracy for each configuration would be much lower. However, the use of different dataset samples was meant to ensure that any conclusions are more general — not specific to a particular dataset. Overall, the improvement of mean accuracy cannot be reasonably explained as random noise.
What other benefits can we get from using a classifier head with a trainable aggregation, instead of a fixed one? We might not need to modify the parameters of the encoder at all. According to ULMFiT , we should first optimize the classifier head, and then gradually add layers of the encoder to the optimizer’s scope — this approach was described in the previous section. However, training exclusively the classifier’s head has important practical benefits, as discussed in the next section. The table below summarizes results obtained this way — only optimizing parameters of the classifier’s head.
Somewhat surprisingly, for the full Branching Attention head, the results did not deteriorate at all. Training classifiers this way can be a viable, practical approach, especially for relatively small datasets. As expected, the gap between the proposed architecture and Concat Pooling widened.
The variant without the AGG branch performed above expectations — better than when the whole network was optimized (in the previous section). It might indicate that the training process when optimizing the full network could be improved upon — but such attempts were not successful.
Why is it important that the proposed architecture performs much better than the baseline when only training the classifier head? There are several practical advantages to building systems this way. Training is faster and requires less memory. A less obvious benefit lies in the low number of parameters that are unique to the model being trained.
New use cases open up in scenarios where many classifiers operate on texts from the same domain, e.g. “tweets” or “news articles” in a given language. It becomes very cheap to train, store and run a new model: the frozen, pre-trained encoder is shared between all models, so only the small classifier head has to be trained and stored, and the encoder’s representation of each text can be computed once and reused by every classifier.
Where can it be useful? There are many possible scenarios; examples might include:
By analyzing attention weights in a neural network we can understand it better — we can see which parts of the input are relevant to the task at hand. Interesting visualizations have been demonstrated e.g. in machine translation and image caption generation .
With Branching Attention, in some circumstances, we can take it a step further. If the classification problem is binary, we can achieve quite good results by setting the last (or only) dimension of the AGG branch to 1. Effectively, for each input token, we will obtain a single, scalar value, and the final decision will be a weighted average of these values (after a simple transformation). In the case of sentiment analysis, we can show “local” sentiment after processing each token, as well as importance scores of specific parts of the document .
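A toy numeric illustration of this special case (all values below are made up): with a 1-dimensional AGG output, each token j contributes a scalar sentiment bⱼ, and the document score is the attention-weighted average of these scalars:

```python
import numpy as np

# Hypothetical per-token values for a short review.
tokens     = ["the", "plot", "was", "brilliant"]
sentiment  = np.array([0.0, -0.1, 0.0, 0.9])  # b_j: "local" sentiment
raw_scores = np.array([0.1, 1.0, 0.1, 3.0])   # e_j: from the ATT branch

# Softmax turns raw scores into weights a_j (shown as opacity in the demo).
weights = np.exp(raw_scores) / np.exp(raw_scores).sum()

# The final decision is a weighted average of the per-token sentiments.
document_score = (weights * sentiment).sum()
print(round(float(document_score), 3))  # 0.712
```

Because both the weights and the per-token sentiments are scalars, they can be rendered directly on top of the text — which is exactly what the visualization below does.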
See the example below, and check out the interactive demo . In the demo, it is possible to visualize both weights and sentiment at the same time (as in the example), just the weights or just the sentiment. It is also possible to test the model on a different piece of text.
The opacity of the color behind each token indicates the attention weight associated with it. Hue denotes the value of a feature (sentiment) calculated for that token, considering its left context, on a red-to-green scale.
A few observations:
To verify if the proposed architecture performs well generally, or just for the particular dataset type and size, several other experiments were conducted.
When training on the full IMDB dataset, in its default train/test split, the results are approximately equal to those of a Concat Pooling-based classifier (although Branching Attention seems to be less sensitive to the choice of training-related hyper-parameters). Intuitively, the large training dataset might contain enough information to modify the encoder in such a way that its output can be meaningfully aggregated by average- and max-pooling.
When training on the full IMDB dataset, but optimizing only the classifier head, Branching Attention performs significantly better than the baseline. As expected, it performs worse than when the whole network is optimized. Nevertheless, this approach can be useful in some cases, as discussed in the “Benefits of head-only training” section.
Using an older version of the code , a system based on Branching Attention was entered into the PolEval 2019 contest — task 6. Its goal was to prepare a system assessing which tweets, written in Polish, might contain cyber-bullying content. A crowd-labeled training set of ~10,000 example tweets was provided. It was highly imbalanced — with an overwhelming majority of examples not containing cyber-bullying.
The system used a sub-word tokenization mechanism — SentencePiece — to deal with misspellings common on Twitter. This type of tokenization is generally beneficial for Polish, as it is a fusional language. Its abundant prefixes and suffixes dramatically increase the number of unique words — a problem for traditional, word-based tokenization approaches. The language model was trained on Polish Wikipedia and then fine-tuned on a large, unlabelled dataset of Polish-language tweets.
In the contest, the system came in 3rd (proceedings, page 106), behind a ULMFiT-like system with a larger, improved language model (similar to, and by co-authors of, MultiFiT). However, it ranked just ahead of a BERT-based system.
The performance of Branching Attention and Concat Pooling was then compared in the same way as in the “IMDB sample” experiment, with hyper-parameters tuned on the IMDB dataset. Both models were trained 20 times, each time on a different sub-sample of the training dataset and with a different random initialization. To directly compare the models’ performance and avoid issues with setting a detection threshold, the Average Precision metric was chosen.
In a completely different context — a different language, tokenization strategy, and a different (more complex) task, without further hyper-parameter optimization — Branching Attention provided a modest but noticeable performance improvement. This can be considered a successful test of the proposed architecture.
Another verification of the method’s results was performed using the Fact-Checking Questions dataset. The dataset was extracted from “Qatar Living” — a community question-answering forum. Each document — first post in a topic — consists of a category, a subject line, and the post’s content. The fields were concatenated together, with delimiters between them. The task: assigning a category to each question: “factual”, “opinion/advice” or “social”.
In an experimental setup identical to the one described above, with the same hyper-parameters, the obtained results were:
Once again, the network using Branching Attention performed significantly better — despite both networks sharing the same pre-trained encoder.
Several attempts to improve the network’s performance were made:
Code and the fine-tuned language model, needed to reproduce this article’s results, are provided.