Explainable, data-efficient text classification

Improving ULMFiT with the right kind of attention

In this article you can find:

  • an introduction describing core ideas, history and applications of Transfer Learning,
  • a review of recent developments in natural language processing — including the ULMFiT algorithm (utilizing pre-trained language models based on recurrent neural networks),
  • a novel network architecture (a modification of ULMFiT) which enables more accurate text classification, especially when trained on small datasets; its performance is evaluated in several scenarios, and it came in 3rd in the Automatic Cyber-bullying Detection contest (PolEval 2019 Task 6–1),
  • a way of training groups of extremely lightweight (<100KB) classifiers meant to process documents from similar domains, as well as a way to deploy them efficiently,
  • an interactive demo of attention visualization showing not only which parts of the text are important to the classifier, but also how it understands them.

Table of contents:

Key to practical deep learning — transfer learning


-In natural language processing (NLP)

-Contextualized word representations

-Current state of the art

Proposed architecture


-Branching Attention — proposed classifier head

Experiments and results



-Benefits of head-only training

-Attention visualization

-Full IMDB, other datasets

-What did not work, possible improvements

Key to practical deep learning — transfer learning

In computer vision

Applying deep neural networks to unstructured data, most commonly images, can yield impressive results. Academic papers and prototypes aside, mass-market cars understand their surroundings in real time by analyzing video feeds. In some areas, cars drive fully autonomously. In these cases, and many others, deep learning is used to transform the raw pixel values of images into some level of understanding of the scene.

Figure 1: Open-source example of real-time object detection: neural network returns bounding boxes and classes of all detected objects. Source: YOLO v3 object detection

While understanding images comes naturally to humans, for a computer it is no easy feat. It takes a lot of computation and enormous amounts of data to teach a machine how to see the world somewhat similarly to how humans do. In practical terms, it also takes a lot of time and money. The example above? It was trained using more than 14 million manually annotated images (ImageNet, then COCO). Yet there are many practical applications in all kinds of niches: countless industrial ones, and many used by NGOs, e.g. to recognize individual whales from their pictures or to detect deforestation using satellite imagery.

Key to using deep learning without huge datasets and budgets? Re-purposing already trained models, and the knowledge they contain; more formally: transfer learning. Computer vision underwent a revolution of sorts in the mid-2010s, when it became standard practice to use models pre-trained on ImageNet for all kinds of applications.

The pre-training task is simple: for each image, decide what the main object in it is by choosing one of the pre-defined categories (1,000 in the commonly used ILSVRC subset of ImageNet; the full dataset defines over 20,000). To solve it, the neural networks learn how to extract meaningful representations of images from raw pixel data. First, convolutional layers learn to detect simple features, such as edges and corners. The next layers use that information to recognize more complex shapes, perhaps wings and beaks when looking at images of birds. The final layers can recognize specific types of objects, e.g. birds, cars or planes of various types.

We can use the pre-trained networks to compute representations of images and train our models to use these representations. Alternatively, we can remove the last several layers of the pre-trained network, replace them with new, randomly-initialized ones, and train the whole network to perform the task at hand. This way it is possible to quickly build prototype systems, using datasets with only a few thousand, or a few hundred, labeled images. For example, 400 training samples were sufficient to achieve 85% accuracy in a 4-category classification of disorders present in OCT images of retinas, a very different kind of image than the pictures of animals and vehicles the original network was trained on. With some additional effort, e.g. for classification tasks, the data requirements can be lowered further. Taking it to the extreme, classifying images into categories without seeing any examples of them is also possible (zero-shot learning, e.g. DeViSE).

In natural language processing (NLP)

Automated processing of natural language is a challenging problem. Systems performing various tasks in this domain have existed for several decades, but until relatively recently they were predominantly rule-based. These systems can perform very well (example: categorizing questions into 50 categories with 97.2% accuracy) and their behavior can be easily understood and debugged. Unfortunately, since they are based on manually-developed systems of rules and knowledge bases, their development is labor-intensive. Their complexity grows immensely with expected functionality, so they are generally only applicable to well-defined, narrow tasks. Furthermore, their operation can be disrupted by typos or grammatical errors.

The motivation for developing statistical, data-driven NLP systems is clear. In many cases, it is easier to gather a dataset relevant to the task at hand, than it is to develop a massive set of rules. It translates into lower costs and a shorter time of system development. Additionally, it provides an opportunity for the system to perform better with time, as more data is gathered. Besides, a lot of text data is freely available — in forms of books, online publications, etc. — and arguably contains most of humanity’s knowledge. Making use of it, directly, as a knowledge base could enable the construction of new, very capable systems.

Building a system able to fully understand natural language is an extremely hard, unsolved problem. Resolving linguistic ambiguities often requires context, knowledge of idioms, detecting sarcasm, and even general knowledge and human “common sense” (see Winograd Schema Challenge). Nevertheless, the problem is well worth pursuing — because of immediate, direct applications, as well as a possibility of getting closer to understanding general intelligence.

The first statistical methods in NLP started much simpler. They represented text documents as counts of the words they contain (a representation called bag-of-words). Improvements of this approach suggested using frequencies instead of counts, and often frequencies in relation to how common a given word is generally (TF-IDF). These methods completely disregard word order, so they use only a small part of the available information. Using n-grams (sequences of n words) instead of individual words, with similar further processing, is a viable way of incorporating word order information. Unfortunately, using high n values is infeasible, as the number of possible word combinations grows exponentially with n.
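The count-based representations described above can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: the three toy documents are made up, and the TF-IDF formula used (term frequency times log(N / document frequency)) is one common variant among several.

```python
import math
from collections import Counter

def ngrams(tokens, n):
    """Return all contiguous n-grams of a token list as tuples."""
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]

def bag_of_words(tokens):
    """Raw counts: word order is discarded."""
    return Counter(tokens)

def tf_idf(docs):
    """One common variant: tf = count / doc length, idf = log(N / doc frequency)."""
    n_docs = len(docs)
    doc_freq = Counter(term for doc in docs for term in set(doc))
    weights = []
    for doc in docs:
        counts = Counter(doc)
        weights.append({
            term: (count / len(doc)) * math.log(n_docs / doc_freq[term])
            for term, count in counts.items()
        })
    return weights

# Toy corpus, made up for illustration
docs = [
    "the movie was good".split(),
    "the movie was bad".split(),
    "good acting good plot".split(),
]
bow = bag_of_words(docs[2])
weights = tf_idf(docs)
bigrams = ngrams(docs[0], 2)
```

In this toy corpus the rare word "bad" receives a higher TF-IDF weight than the common word "the", and the bigrams keep local word order that the bag-of-words counts lose.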

The count-based representations described above remained the foundation of state-of-the-art statistical solutions for a long time. Because of their simplicity and computational efficiency, for some applications, they remain important baselines. For example, for topic classification and sentiment analysis, trying Naive Bayes or SVM classifiers, based on bag-of-words or bi-gram features, might be a good idea (in case of the full IMDB task described below, it achieves 90–91% accuracy, compared to 95% with ULMFiT — but is orders of magnitude faster).

The first wave of changes came in 2013, with the advent of Word2Vec. Models from this category produced a numerical representation (called “word vector”, or “embedding”) for each word in their vocabulary. They were built as shallow neural networks, trained on short sequences from a large corpus of freely available text. Representations of each word were based on its usual context and already captured a lot of the language’s semantics. Famously, analogies between words correspond to arithmetic relationships between their vectors, e.g. king - man + woman ≈ queen.

Figure 2: Semantic relationships in word embedding space. Coordinates of points marked “queen”, “woman”, “man” and “king” are the word vectors/embeddings. Source: Towards Understanding Linear Word Analogies, blog post
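The analogy arithmetic can be tried on a toy example. The 3-dimensional vectors below are hand-picked for illustration and do not come from any trained model; real embeddings have hundreds of dimensions, and cosine similarity is assumed here as the closeness measure.

```python
import math

def cosine(u, v):
    """Cosine similarity between two vectors of equal length."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

# Hypothetical vectors, hand-picked for illustration only
toy_vectors = {
    "king":  [0.9, 0.8, 0.1],
    "queen": [0.9, 0.1, 0.8],
    "man":   [0.1, 0.9, 0.1],
    "woman": [0.1, 0.1, 0.9],
    "apple": [0.0, 0.2, 0.2],
}

def analogy(a, b, c, vectors):
    """Return the word closest to vec(a) - vec(b) + vec(c), excluding the inputs."""
    target = [x - y + z for x, y, z in zip(vectors[a], vectors[b], vectors[c])]
    candidates = (w for w in vectors if w not in {a, b, c})
    return max(candidates, key=lambda w: cosine(target, vectors[w]))

result = analogy("king", "man", "woman", toy_vectors)
```

With these toy vectors, `result` is "queen". Excluding the three input words from the candidates is a common convention in analogy evaluation.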

Each word in the vocabulary was represented as a fixed-length vector (e.g. 300 numbers) — set of coordinates of a point in the embedding space. Since the vocabulary size (number of unique words) is usually orders of magnitude higher, the representation is distributed — words and concepts are represented by certain combinations of coordinates, rather than a single element of the vector.

As part of a complete system, word embeddings would be used to encode each word of a document, producing a variable-length sequence of vectors (length corresponding to the number of words). Many algorithms using these representations were developed — from averaging all the word vectors and training an SVM classifier on the results, to passing the sequence through recurrent or convolutional networks.

Word embeddings have one central problem. When used to encode words of a document, they encode each word individually — ignoring the context. The vector representing root will be the same in:

  • root of a tree
  • square root of two
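A minimal sketch of why this happens: a static embedding is a plain lookup table, so the surrounding words never enter the computation. The table below is hypothetical, with made-up 2-dimensional vectors.

```python
# Hypothetical 2-dimensional embedding table, for illustration only
embedding_table = {
    "root": [0.4, -0.2], "of": [0.0, 0.1], "a": [0.1, 0.0],
    "tree": [0.7, 0.3], "square": [-0.5, 0.6], "two": [-0.6, 0.2],
}

def encode(text, table):
    """Look up each word independently; its neighbours play no role."""
    return [table[word] for word in text.split()]

botany = encode("root of a tree", embedding_table)
maths = encode("square root of two", embedding_table)

# "root" is at position 0 in the first phrase and at position 1 in the second
same_vector = botany[0] == maths[1]
```

Here `same_vector` is true: both senses of "root" map to exactly the same numbers, and any disambiguation is left to the downstream model.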

Words with more than one possible meaning are a problem for most practical applications of NLP. For some applications, e.g. sentiment analysis, syntactic ambiguity has to be resolved as well. This can range from simple problems (was a phrase negated?) to very complex, sometimes impossible ones (was the phrase used ironically?).

Let us consider a practical scenario. We want to build a system which will estimate public sentiment towards brands, by classifying tweets about them as positive or negative, and then calculate a fraction of positive ones. We need training data — we might download a number of relevant tweets about various brands, and then manually label them (or crowdsource the task).

If we use word embeddings and feed the encoded text to a neural network, we will be able to train it to perform the classification task — distinguishing the positive messages from negative ones. But because word embeddings do not capture context, just individual words, our network has to learn the whole structure of the language, and even higher-level concepts like irony or sarcasm, at the same time! All based on our precious, manually-labeled data. This is clearly inefficient.

Contextualized word representations

To help neural networks “understand” the structure of languages, researchers have developed several network architectures and training algorithms, which compute contextualized representations of words. The key difference compared to word embedding-based approaches: this understanding of context is trained on unlabeled, freely available text from a given language (or multiple languages). Data is often taken from Wikipedia, public-domain books, or simply scraped from the Internet.

In 2018–2019 these methods started a “revolution” of sorts, analogous to the one in computer vision several years earlier. Deep neural networks are now commonly used to compute rich, contextualized representations of text, as opposed to the context-unaware word vectors. These representations can then be used by another neural network (e.g. ELMo). More commonly, the pre-trained network has its last layer replaced with a set of different ones, designed for the task at hand (often called a downstream task). The weights of the new layers are initialized randomly and trained using the labeled data to perform the downstream task. The process is much easier since the hardest part, language understanding, has mostly been done.

The initial training, meant to train the network’s “understanding” of text is based on unlabeled data. Labels — things for the model to predict — are automatically generated from the data itself. Training with such labels is often referred to as unsupervised pre-training.

Current state of the art

The most common approaches to unsupervised pre-training are:

  • masked language modeling: predicting a few deleted words from the remaining context (e.g. BERT)
  • language modeling: predicting the next word given its predecessors (e.g. GPT-2, ULMFiT)
  • replaced token detection: a recent but very promising strategy, where some words are replaced by words generated by a separate, auxiliary language model, and the pre-training task is to recognize them (e.g. ELECTRA)

Figure 3: Two most popular pre-training objectives. Source: ELECTRA
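To make the "labels generated from the data itself" idea concrete, here is a simplified sketch of how training examples for the first two objectives could be derived from raw text. The function names, the `[MASK]` placeholder and the 15% masking fraction are assumptions made for illustration; real implementations add further details (e.g. sub-word tokens, batching, and more elaborate masking rules).

```python
import random

def next_word_examples(tokens):
    """Language modeling: each prefix is an input, the following token its label."""
    return [(tokens[:i], tokens[i]) for i in range(1, len(tokens))]

def masked_example(tokens, mask_fraction=0.15, mask_token="[MASK]", seed=0):
    """Masked language modeling: hide a few tokens and keep them as labels."""
    rng = random.Random(seed)
    n_masked = max(1, round(len(tokens) * mask_fraction))
    positions = sorted(rng.sample(range(len(tokens)), n_masked))
    masked = [mask_token if i in positions else t for i, t in enumerate(tokens)]
    labels = {i: tokens[i] for i in positions}
    return masked, labels

tokens = "the chef cooked the meal".split()
lm_pairs = next_word_examples(tokens)
masked_tokens, masked_labels = masked_example(tokens)
```

No manual annotation is involved in either case: the text itself provides both the inputs and the targets.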

Approaches to tokenization (splitting text into its basic chunks) also vary. The simplest, but sub-optimal, strategy would be to split the text every time a space character is encountered. Word-based tokenization is used in some cases, but usually includes rules specific to each language (e.g. splitting “don’t” into “do” and “n’t”). Sub-word tokenization (e.g. SentencePiece) is perhaps the most common: it is derived from data, based on the frequency of particular character sequences, often treating white space as yet another character. It is also worth noting that character-level language models are sometimes used in practice (e.g. FLAIR).
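The sketch below contrasts whitespace and character-level tokenization, and shows a single frequency-based merge step in the spirit of byte-pair encoding, one of the algorithms used for sub-word vocabularies. It is a toy illustration, not the SentencePiece implementation; the corpus and helper names are made up, and the "▁" whitespace marker merely mirrors the SentencePiece convention.

```python
from collections import Counter

text = "don't stop"
whitespace_tokens = text.split()
# Character-level tokens, with white space kept as an ordinary symbol
char_tokens = list(text.replace(" ", "▁"))

def most_frequent_pair(words):
    """Count adjacent symbol pairs across all words; return the most common one."""
    pairs = Counter()
    for symbols in words:
        pairs.update(zip(symbols, symbols[1:]))
    return pairs.most_common(1)[0][0]

def merge_pair(symbols, pair):
    """Replace every occurrence of the adjacent pair with one merged symbol."""
    merged, i = [], 0
    while i < len(symbols):
        if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
            merged.append(symbols[i] + symbols[i + 1])
            i += 2
        else:
            merged.append(symbols[i])
            i += 1
    return merged

# Toy corpus: start from characters and apply one merge step
corpus = ["log", "low", "slow"]
words = [list(w) for w in corpus]
pair = most_frequent_pair(words)
words = [merge_pair(w, pair) for w in words]
```

Repeating the merge step many times on a large corpus yields a vocabulary of frequent character sequences, which is what "derived from data" means in practice.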

Two groups of network architectures dominate the space:

  1. Transformer networks — based on the self-attention mechanism
    Most notable examples of Transformer networks include Google’s BERT (+ its many variants) and OpenAI’s GPT-2. They helped achieve state-of-the-art results in many, if not most, NLP problems. Training is very easily parallelizable, and bidirectional models such as BERT can use tokens both before and after the one being interpreted.
    There are many variants of the attention mechanism, summarized here; they generally aggregate information from different parts of the input as a weighted average of these parts, with various trainable modules computing the weights.
    To find implementations of many architectures, pre-trained models, tutorials and more, see HuggingFace Transformers.
  2. Recurrent networks (and hybrids of recurrent + other architectures).
    ULMFiT is perhaps the most popular in this group; others worth mentioning include MultiFiT, SHA-RNN and the recently open-sourced Mogrifier LSTM.
    Training is harder to parallelize, and for that very reason models tend to be smaller and more efficient to train: there is no way to accelerate training by just throwing more compute at the problem. All the recurrent networks mentioned above can be trained from scratch on a desktop computer with a single GPU. They should fine-tune for a downstream task relatively quickly. They also tend to perform better if labeled data is scarce (see e.g. MultiFiT). Another advantage, compared to Transformer-based models: they can handle arbitrarily long inputs, whereas Transformers can only process chunks of the length they were designed for and trained on (with some tricks available to process longer documents part by part and aggregate the results).
    Fast.ai contains an implementation of ULMFiT, together with a pre-trained model of the English language. The MultiFiT repository contains pre-trained models of several other languages, and pre-training scripts if another language is needed.

Proposed architecture

This article describes a novel network architecture for classification models working on text documents: a modification of ULMFiT. Using that paper’s nomenclature, a different classifier head is proposed.

ULMFiT — recap

Training a ULMFiT model from scratch consists of 3 steps:

  1. language model pre-training — training a language model on a large, general text corpus (e.g. Wikipedia)
  2. language model fine-tuning — continuing training of the language model on texts from the domain of the classification task (want to classify tweets? train on a large set of tweets in the same language)
  3. classifier fine-tuning — the last layer of the language model is replaced with a classifier head (described below). The remaining, pre-trained part is then called an encoder.
    First, only the randomly-initialized parameters of the classifier head are trained. Afterwards, layers of the encoder gradually start being optimized, last-to-first.
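The "last-to-first" order of step 3 can be written down as a simple schedule. The sketch assumes one additional layer group is unfrozen per epoch; the group names are hypothetical, and the actual number of epochs per stage is a training hyper-parameter.

```python
def gradual_unfreezing(layer_groups, epochs):
    """Yield, for each epoch, the layer groups that are being optimized.

    layer_groups is ordered first-to-last (embedding ... classifier head).
    Epoch 0 trains only the head; each later epoch unfreezes one more group,
    moving from the last encoder layer towards the first.
    """
    for epoch in range(epochs):
        n_trainable = min(epoch + 1, len(layer_groups))
        yield layer_groups[-n_trainable:]

# Hypothetical group names, for illustration only
groups = ["embedding", "lstm_1", "lstm_2", "lstm_3", "head"]
schedule = list(gradual_unfreezing(groups, epochs=6))
```

In a real training loop, each entry of `schedule` would determine which parameters are handed to the optimizer in that epoch.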

The network architecture used in the ULMFiT paper is depicted below. It contains:

  • a trainable embedding layer,
  • 3 recurrent layers (LSTM with many regularization methods added: AWD-LSTM),
  • a classifier head (which aggregates the variable-length sequence into a fixed-length representation and calculates the classification decision)

Figure 4: ULMFiT classification — processing an example input sequence

The classifier head performs Concat Pooling, then passes its results through one fully connected layer and the output layer.

Concat Pooling is a simple concatenation of 3 elements:

  • average-pooling — element-wise average along the sequence dimension (averaging the vectors representing each word),
  • max-pooling — element-wise maximum along the sequence dimension,
  • the last output vector of the encoder.
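The three elements listed above can be computed as follows. This is a plain-Python sketch on a tiny, made-up sequence of 2-dimensional encoder outputs; the order of concatenation follows the list above and may differ between implementations.

```python
def concat_pooling(outputs):
    """Concatenate [average-pool, max-pool, last output] of a sequence of vectors."""
    columns = list(zip(*outputs))  # one tuple per feature dimension
    avg_pool = [sum(col) / len(col) for col in columns]
    max_pool = [max(col) for col in columns]
    last = list(outputs[-1])
    return avg_pool + max_pool + last

# Made-up encoder outputs: 3 tokens, 2 features each
encoder_outputs = [
    [1.0, -2.0],
    [3.0, 0.0],
    [2.0, 5.0],
]
pooled = concat_pooling(encoder_outputs)
```

Whatever the sequence length, the result has a fixed size of three times the encoder's output dimension, which is what allows it to be passed to a fully connected layer.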

Branching Attention — proposed classifier head

The described architecture is meant to directly address two main problems with using average-pooling, and/or max-pooling to aggregate text representations generated by recurrent language models.

  1. Position-wise statistics (average- and max-pooling) equally weight all tokens in the sequence — even though often only part of it is relevant to the classification task. In case of sentiment analysis in movie reviews — a reviewer might write about a sad, tragic story told in a great movie. While most of the text might have negative sentiment, we only want to focus on the part describing the movie, not the content of its plot.
  2. Average- and max-pooling only make semantic sense if the text features of interest align with axes of the coordinate system of the representation space. For sentiment analysis, we would like to see one neuron (one dimension of the representation space) express exactly the positive or negative sentiment in the text, and nothing else. While for simple concepts like sentiment it can be approximately true, it is not likely to hold, even approximately, for more complex features. If the condition does not hold, averaging across a long sequence (hundreds or thousands of tokens) is very likely to corrupt stored information.

Therefore, the classifier head contains two “branches”, each answering one question:

  • ATT branch: “which parts are relevant, what should we pay attention to?”
  • AGG branch: “what are the features we want to aggregate?”

Both are implemented as simple, fully connected neural networks. Numbers and sizes of their layers are additional hyper-parameters, the choice of which is discussed in the “Experiments and results” section. The networks are applied independently to each sequence element. Values returned by the ATT branch, a scalar (eⱼ) for each sequence element, are then passed through a Softmax function to obtain proper weights (aⱼ) for a weighted average. A weighted average of vectors (bⱼ) returned by AGG becomes the final representation of the sequence (C).

Equations for the Branching Attention architecture. Non-linear functions f and g are realized by neural networks ATT and AGG. Vectors in bold.
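Written out from the description above, the computation can be expressed as follows. The symbols eⱼ, aⱼ, bⱼ, C, f and g come from the text; the notation hⱼ for the encoder output at position j and n for the sequence length is assumed here.

```latex
\begin{aligned}
e_j &= f(\mathbf{h}_j) \\
a_j &= \frac{\exp(e_j)}{\sum_{k=1}^{n} \exp(e_k)} \\
\mathbf{b}_j &= g(\mathbf{h}_j) \\
\mathbf{C} &= \sum_{j=1}^{n} a_j \, \mathbf{b}_j
\end{aligned}
```

Because the weights aⱼ are non-negative and sum to one, C is a proper weighted average of the vectors bⱼ.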

Diagram of the “Branching Attention” classifier head:

Figure 5: Branching Attention classifier head

Note: if the AGG branch is skipped, and ATT only has the output layer, the whole aggregation reduces to dot-product attention with a single, trainable query. It does not produce particularly good results, as discussed in the next section.
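A forward pass through the head can be sketched in plain Python. This is not the author's implementation: the two branches are passed in as ordinary functions, and the `att` and `agg` stand-ins below are single layers with fixed, made-up weights, whereas in the real model they are small trainable networks.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of scalars."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def branching_attention(outputs, att, agg):
    """Aggregate a sequence of encoder outputs into one fixed-length vector.

    att maps one vector to a scalar score; agg maps one vector to a feature
    vector. Both are applied independently to every sequence element.
    """
    weights = softmax([att(h) for h in outputs])
    features = [agg(h) for h in outputs]
    dim = len(features[0])
    pooled = [sum(w * f[d] for w, f in zip(weights, features)) for d in range(dim)]
    return pooled, weights

# Hypothetical stand-ins for the two branches (fixed weights, illustration only)
def att(h):
    return 2.0 * h[0]  # a single linear scoring layer

def agg(h):
    return [max(0.0, h[0] + h[1]), max(0.0, h[0] - h[1])]  # linear layer + ReLU

encoder_outputs = [[0.0, 1.0], [1.0, 1.0], [0.0, -1.0]]
representation, weights = branching_attention(encoder_outputs, att, agg)
```

Passing an identity function as `agg` and a single dot product with a query vector as `att` gives the reduced variant described in the note above.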

Experiments and results

Most experiments were performed using a popular sentiment classification dataset: IMDB. It contains 50,000 movie reviews, written by users of the imdb.com website. They are labeled as positive if the user’s rating of the movie was 7 or above, and negative if it was 4 or lower. The classes are balanced: there are equal numbers of positive and negative reviews. There are no more than 30 reviews per movie. Document lengths vary greatly: many are relatively short, but 2% are longer than 1000 words. It is also worth noting that the labels are somewhat “noisy”, e.g. there are some positive reviews labeled as negative.

IMDB samples

To simulate a small dataset situation, yet make statistically significant comparisons between architectures, the experiments were conducted in the following way: models were trained on 1000-element samples of the IMDB training dataset, and evaluated on the whole test set. This was repeated 20 times for each architecture and hyper-parameter set. The training-related hyper-parameters and a single dropout multiplier (as recommended by ULMFiT authors) were tuned separately for Branching Attention and Concat Pooling heads.
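The evaluation protocol can be outlined as below. The sketch makes several assumptions: `train_and_score` is a caller-supplied placeholder for the actual training and test-set evaluation, the toy dataset and dummy scoring function exist only so that the code runs, and a fixed seed is used so that every architecture would see the same 20 samples.

```python
import random
import statistics

def repeated_subsample_eval(train_set, train_and_score, sample_size=1000,
                            repetitions=20, seed=0):
    """Train on several random subsets and summarize the resulting scores.

    train_and_score(subset, run_seed) trains a model on the subset and
    returns its score on the full test set.
    """
    rng = random.Random(seed)
    scores = []
    for run in range(repetitions):
        subset = rng.sample(train_set, sample_size)
        scores.append(train_and_score(subset, run))
    return statistics.mean(scores), statistics.stdev(scores), scores

# Placeholder "model": scores a subset by its fraction of positive labels.
# It stands in for real training only so that the sketch runs end to end.
def dummy_train_and_score(subset, run_seed):
    return sum(label for _, label in subset) / len(subset)

toy_train_set = [(f"review {i}", i % 2) for i in range(25000)]
mean_score, std_score, all_scores = repeated_subsample_eval(
    toy_train_set, dummy_train_and_score)
```

Reporting the spread of the scores alongside the mean is what makes comparisons between architectures meaningful at this dataset size.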

Below, the results of the proposed architecture (Branching Attention) are compared with the baseline (Concat Pooling). Two additional variants are provided, as a minimal ablation study:

  • the AGG branch is removed (and encoder output is averaged with weights from the ATT branch),
  • the ATT branch is removed, and AGG-extracted features are weighted uniformly.

Table 1: Classification performance with different classifier heads, averaged over training on 20 different 1000-element datasets. Parameters of the whole network were optimized (with gradual unfreezing).

While by no means ground-breaking, the results above seem encouraging. Without changing the encoder — the most important part of the model — the number of incorrectly classified samples was reduced by over 10%. The proposed architecture seems to make better use of the encoder’s representation of text.

In the best configuration presented above, the Branching Attention head had 30% fewer parameters compared to the original ULMFiT architecture. In other configurations, performing only slightly worse, the number of parameters decreased by up to 85%.

It is worth noting that removing either branch of the Branching Attention head causes a significant drop in performance — even below the baseline levels. While not shown here for brevity, reducing the depth of either branch (to a single layer) also results in a small, but consistent, drop in performance. The results support using the complete network architecture, with both branches, as previously described.

The distributions of accuracy scores from individual experiment runs are shown below, as a box plot. It is meant to visualize the variance of results for each configuration, resulting from different sampling of the training dataset and different random initialization.

Worth noting: the single, negative outlier present for all four architectures, denoted as a dot, comes from training on the same sample of the training dataset — presumably of inferior quality. Likely, if training was repeated with the same dataset each time (only changing the random seed), the variance of accuracy for each configuration would be much lower. However, the use of different dataset samples was meant to ensure that any conclusions are more general — not specific to a particular dataset. Overall, the improvement of mean accuracy cannot be reasonably explained as random noise.

Head-only training

What other benefits can we get from using a classifier head with a trainable aggregation, instead of a fixed one? We might not need to modify the parameters of the encoder at all. According to ULMFiT, we should first optimize the classifier head, and then gradually add layers of the encoder to the optimizer’s scope; this approach was used in the previous section. However, training exclusively the classifier’s head has important practical benefits, as discussed in the “Benefits of head-only training” section. The table below summarizes results obtained this way, optimizing only the parameters of the classifier’s head.

Table 2: Classification performance with different classifier heads, averaged from training on 20 1000-element datasets. Optimizing only parameters of the classifier head

Somewhat surprisingly, for the full Branching Attention head, the results did not deteriorate at all. Training classifiers this way can be a viable, practical approach, especially for relatively small datasets. As expected, the gap between the proposed architecture and Concat Pooling has increased.

The variant without the AGG branch performed above expectations — better than when the whole network was optimized (in the previous section). It might indicate that the training process when optimizing the full network could be improved upon — but such attempts were not successful.

Benefits of head-only training

Why is it important that the proposed architecture performs much better than the baseline when only training the classifier head? There are several practical advantages to building systems this way. Training is faster and requires less memory. A less obvious benefit lies in the low number of parameters that are unique to the model being trained.

New use cases open in scenarios where many classifiers operate on texts from the same domain, e.g. “tweets” or “news articles” for a given language. It becomes very cheap to train, store and run a new model:

  • training takes less than a minute (on a modern GPU)
  • a stored classifier head takes ~100kB of disk space (ballpark figure, less for some configurations)
  • at inference time, documents for all models can be batched together, and efficiently encoded on the same GPU/TPU; the model-specific part of inference is small, can be quickly run on a CPU.
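The deployment pattern from the last point can be sketched as follows. All names are hypothetical: `fake_encoder` stands in for the shared, frozen language-model encoder, and the two lambda "heads" stand in for trained classifier heads.

```python
def classify_for_all_models(documents, encode_batch, heads):
    """Encode all documents once, then run each lightweight head on its own input.

    documents: list of (model_name, text) pairs
    encode_batch: shared encoder, maps a list of texts to a list of representations
    heads: dict mapping model_name to a small classifier function
    """
    encoded = encode_batch([text for _, text in documents])  # one batched call
    return [heads[name](vector) for (name, _), vector in zip(documents, encoded)]

# Hypothetical stand-ins so that the sketch runs: a fake encoder and two fake heads
def fake_encoder(texts):
    return [[len(t), t.count("!")] for t in texts]

heads = {
    "is_long": lambda v: v[0] > 20,
    "is_excited": lambda v: v[1] > 0,
}
requests = [
    ("is_long", "short text"),
    ("is_excited", "great news!"),
    ("is_long", "a considerably longer piece of text"),
]
predictions = classify_for_all_models(requests, fake_encoder, heads)
```

The expensive encoder call is shared by all requests, while the per-model work is limited to a small function that can run on a CPU.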

Where can it be useful? There are many possible scenarios, examples might include:

  • AutoML platforms, where a non-ML-savvy user provides a labeled dataset and expects an API to a model performing classification.
  • recommender systems: the “product” half of the cold start problem. For example, when a new article is posted on Medium, nobody has interacted with it. Traditional (collaborative filtering) algorithms recommend things based on the opinions of people with similar tastes. This approach does not work for new content. We could train “clap prediction” (rating prediction) models for a group of active users, based on their past activity. We could then generate artificial ratings for new articles before anyone has had a chance to read them.
  • personalization, where enough data is available. Recommending social media posts based on their content, rather than other people’s interaction with it, hashtags, etc. Taking Twitter as an example, while it is infeasible to run all new tweets against all users’ preference models, it would be possible to use such preference models as a final filtering layer on top of the current system.

Attention visualization

By analyzing attention weights in a neural network we can understand it better: we can see which parts of the input are relevant to the task at hand. Interesting visualizations have been demonstrated e.g. in machine translation and image caption generation.

With Branching Attention, in some circumstances, we can take it a step further. If the classification problem is binary, we can achieve quite good results by setting the output size of the last (or only) layer of the AGG branch to 1. Effectively, for each input token, we will obtain a single, scalar value, and the final decision will be a weighted average of these values (after a simple transformation). In the case of sentiment analysis, we can show “local” sentiment after processing each token, as well as importance scores of specific parts of the document.

See the example below, and check out the interactive demo. In the demo, it is possible to visualize both weights and sentiment at the same time (as in the example), just the weights or just the sentiment. It is also possible to test the model on a different piece of text.

Figure 6: Example movie review — attention visualization

Opacity of the color behind each token represents the attention weight associated with it. Hue denotes the value of a feature (sentiment) calculated for that token, considering its left context, on a red-to-green scale.
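A rendering along these lines could look like the sketch below. It is not the code behind the demo: the function name, the HTML layout and the example values are assumptions, and the per-token sentiment is assumed to be already scaled to the range [0, 1].

```python
import html

def render_tokens(tokens, weights, sentiments):
    """Return HTML spans: opacity encodes attention weight, hue encodes sentiment.

    weights are attention weights (any non-negative scale); sentiments are
    per-token scores in [0, 1], where 0 is negative (red) and 1 is positive (green).
    """
    peak = max(weights)
    spans = []
    for token, weight, sentiment in zip(tokens, weights, sentiments):
        red = round(255 * (1 - sentiment))
        green = round(255 * sentiment)
        alpha = weight / peak if peak > 0 else 0.0
        spans.append(
            f'<span style="background: rgba({red}, {green}, 0, {alpha:.2f})">'
            f'{html.escape(token)}</span>'
        )
    return " ".join(spans)

# Hypothetical values, for illustration only
markup = render_tokens(["not", "good", "."], [0.1, 0.3, 0.6], [0.5, 0.2, 0.0])
```

Dividing by the largest weight makes the most attended token fully opaque, which keeps the picture readable for long documents where individual weights are tiny.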

A few observations:

  • The model correctly recognizes sentiment in most cases, and “pays attention” to relevant parts (although “holes started to appear in the story” was not recognized as relevant).
  • Negations, even if placed several words away from their object, are interpreted correctly (“I do not consider it a good movie”).
  • Large weights are assigned to full stop characters. Several possible explanations of this fact were explored, as described in the “What did not work” section. The remaining, working hypothesis: when encountering the full stop, the model has already processed a full statement, and subsequent words are unlikely to change its meaning. The model learns to access the content of whole sentences through the encoder’s state after processing the full stop.

Full IMDB, other datasets

To verify if the proposed architecture performs well generally, or just for the particular dataset type and size, several other experiments were conducted.

Full IMDB dataset

Table 3: classification performance on the full IMDB dataset

When training on the full IMDB dataset, in its default train/test split, the results are approximately equal to those of a Concat Pooling-based classifier (albeit Branching Attention seems to be less sensitive to the choice of training-related hyper-parameters). Intuitively, the large training dataset might contain enough information to modify the encoder in such a way, that its output will be meaningfully aggregated by average- and max-pooling.

Full IMDB dataset, head-only training

Table 4: classification performance on the full IMDB dataset, optimizing only the classifier head

When training on the full IMDB dataset, but optimizing only the classifier head, Branching Attention performs significantly better than the baseline. As expected, it performs worse than when the whole network is optimized. Nevertheless, this approach can be useful in some cases, as discussed in the “Benefits of head-only training” section.

Automatic Cyber-bullying Detection — PolEval 2019 Task 6

Using an older version of the code, a system based on Branching Attention was entered into the PolEval 2019 contest (Task 6). The goal of the task was to prepare a system assessing which tweets, written in Polish, might contain cyber-bullying content. A crowd-labeled training set of ~10,000 example tweets was provided. It was highly imbalanced, with an overwhelming majority of examples not containing cyber-bullying.

The system used a sub-word tokenization mechanism, SentencePiece, to deal with misspellings common on Twitter. This type of tokenization is generally beneficial for Polish, as it is a fusional language. Its abundant prefixes and suffixes dramatically increase the number of unique words, which is a problem for traditional, word-based tokenization approaches. The language model was trained on Polish Wikipedia and then fine-tuned on a large, unlabeled dataset of Polish-language tweets.

In the contest, the system came in 3rd (proceedings, page 106), behind a ULMFiT-like system with a larger, improved language model (similar to, and by co-authors of, MultiFiT). However, it ranked just ahead of a BERT-based system.

Performance of Branching Attention and Concat Pooling was then compared in the same way as in the “IMDB samples” experiment, with hyper-parameters tuned on the IMDB dataset. Both models were trained 20 times, each time on a different sub-sample of the training dataset, and with different random initialization. To directly compare models’ performance and avoid issues with setting a detection threshold, the Average Precision metric was chosen.
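For reference, Average Precision can be computed as the mean of the precision values at each rank where a positive example appears, with examples ranked by the model's score. The sketch below uses made-up labels and scores and ignores tied scores; library implementations may handle ties differently.

```python
def average_precision(labels, scores):
    """Mean of precision@k over the ranks k at which a positive example appears."""
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    hits, precision_sum = 0, 0.0
    for rank, (_, label) in enumerate(ranked, start=1):
        if label == 1:
            hits += 1
            precision_sum += hits / rank
    return precision_sum / hits if hits else 0.0

# Made-up example: 2 positives among 6 documents
labels = [0, 1, 0, 0, 1, 0]
scores = [0.1, 0.9, 0.8, 0.3, 0.7, 0.2]
ap = average_precision(labels, scores)
```

The metric depends only on the ordering of the scores, so no detection threshold has to be chosen, which is convenient for a highly imbalanced task.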

Table 5: Average Precision in Cyber-bullying detection tasks, 20 repetitions, 1000-element training dataset samples

In a completely different context (a different language, a different tokenization strategy and a different, more complex task), and without further hyper-parameter optimization, Branching Attention provided a modest but noticeable performance improvement. This can be considered a successful test of the proposed architecture.

Question categorization — SemEval-2019 Task 8A

The method was further verified on the Fact-Checking Questions dataset, extracted from “Qatar Living”, a community question-answering forum. Each document, the first post in a topic, consists of a category, a subject line, and the post’s content. The fields were concatenated, with delimiters between them. The task was to assign one of three categories to each question: “factual”, “opinion/advice” or “social”.
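The field concatenation described above could look roughly like the sketch below. The delimiter tokens, the function name and the example post are all hypothetical; the article does not specify which delimiters were actually used.

```python
# Hypothetical delimiter tokens marking the start of each field.
FIELD_DELIMITERS = {
    "category": "xxcat",
    "subject": "xxsubj",
    "content": "xxbody",
}


def build_document(category, subject, content):
    """Join the three fields of a forum post into a single text for the classifier."""
    fields = (("category", category), ("subject", subject), ("content", content))
    parts = []
    for name, value in fields:
        parts.append(FIELD_DELIMITERS[name])  # tells the model which field follows
        parts.append(value.strip())
    return " ".join(parts)


# Made-up example post, for illustration only.
example = build_document(
    "Visas and Permits",
    "Exit permit",
    "Do I need an exit permit to travel?",
)
```

Keeping explicit delimiters lets a single text encoder see all three fields while still being able to tell them apart.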

In an experimental setup identical to the one described above, with the same hyper-parameters, the obtained results were:

Table 6: Accuracy in question categorization tasks, 20 repetitions, 1000-element training dataset samples

Once again, the network using Branching Attention performed significantly better, despite both networks sharing the same pre-trained encoder.

What did not work, possible improvements

Several attempts to improve the network’s performance were made:

  1. Bi-directional encoder. The output of the encoder was concatenated with the output of a second one, trained to process text backwards. It provided a richer representation of the document, at each point conditioned on its whole content. GPU memory requirements approximately doubled, but performance did not improve noticeably. However, due to memory constraints, other parameters had to be changed, which may have degraded performance. Further experiments, utilizing GPUs with more memory (ideally 16GB+), could change this conclusion. Joint training of the forward and backward language models could improve results, as well as save some memory by sharing weights in the embeddings/softmax layers.
    A similar result was reported by the authors of ULMFiT. It is worth noting, however, that Branching Attention aggregation should be able to make better use of the additional information by calculating relevant features for each sequence element and only then averaging them. Overall, while the initial results were negative, this idea still seems to be worth further research.
  2. Using earlier layers of the encoder. Following ELMo, outputs of the earlier layers of the encoder (LSTM-1 and LSTM-2) were used as additional inputs to the classifier head. This did not improve the results, but, according to the original authors, the optimal layer or combination of layers to use might vary between tasks.
  3. Using the representation of the last token. One hypothesis for why the model learns to assign high weights to full stop characters was that it looks for the end of the document. The representation of the last token is conditioned on the whole document, so it could hypothetically contain all the relevant information. However, directly incorporating the last token’s representation in various ways did not improve classification results.
  4. Using LSTM’s “cell” values. For each element of the input sequence, an LSTM layer produces two vectors: “cell” state values c and output/hidden values h. While c values are meant to store information for longer, h values are produced as a gated version of c, and the output gate is parameterized by the local context. Normally, only h values are passed to the next layer.
    Another possible explanation of why representations of full stop characters might be weighted heavily is that, in a language model setting, they are used to predict the first word of a new sentence. To do that, they have to access more of the document’s context, possibly causing the output gates to be more open.
    This hypothesis inspired an idea to directly use the non-gated c values in the classifier head. However, the classification results did not improve (and the efficiency of computations was reduced because the optimized cuDNN implementation could not be used).
    Upon closer inspection of corresponding c and h vectors in several examples, h values did not resemble gated versions of c (in the sense of approximating hard, binary gating). It appeared as if the gating operation simply provided an additional, trainable, non-linear transformation, somewhat similar to a size 1 convolution. Further analysis of what happens in the output gates of LSTM networks, and how it corresponds to the syntax of the text being processed, seems to be an interesting topic in its own right, but is beyond the scope of this article.
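The relation between c and h discussed in the last item can be illustrated with a small sketch. In a standard LSTM, the output is h = o * tanh(c), where the output gate o is a sigmoid of a pre-activation computed from the local context. The numbers below are made up, and the "hard gating" function is only an idealized reference point for comparison, not something the network computes.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def lstm_output(cell, gate_preactivations):
    """Standard LSTM output: h = o * tanh(c), with o = sigmoid(pre-activation)."""
    return [sigmoid(z) * math.tanh(c) for c, z in zip(cell, gate_preactivations)]


def hard_gated(cell, gate_preactivations):
    """Idealized binary gating: each unit is either passed through or zeroed."""
    return [math.tanh(c) if z > 0 else 0.0 for c, z in zip(cell, gate_preactivations)]


# Hypothetical values for a single time step with four hidden units.
cell = [1.5, -0.8, 0.3, 2.0]
gate_preactivations = [0.4, -0.2, 1.0, -0.6]

h_soft = lstm_output(cell, gate_preactivations)
h_hard = hard_gated(cell, gate_preactivations)
```

With moderate gate pre-activations such as these, every unit of `h_soft` is a scaled-down, non-zero copy of tanh(c), rather than being fully kept or fully removed. This is consistent with the observation that the output gate acts more like an extra learned non-linear transformation than a binary switch.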


  • Across several tasks, the proposed solution outperformed the original ULMFiT architecture or performed equally well.
  • Classifier heads based on “Branching Attention” have an advantage when datasets are relatively small, and when just the classifier head is being optimized.
  • Optimizing only a classifier head, on top of a pre-trained language model, can be a viable strategy if the labeled training data is scarce and/or many classifiers operating on similar documents are needed. In the latter scenario, if just the parameters of classifier heads are different, a much more efficient deployment is possible, and far less disk space is needed.
  • All variants of the Branching Attention architecture add some degree of explainability to the classifier. Not only can we find out which parts were deemed important, but also, to some degree, how they were understood (see demo).
  • In its recommended configuration, the proposed architecture has fewer parameters and requires fewer computations than the baseline. Since its performance seems to be mostly better than or equal to that of the baseline, it can be considered a drop-in replacement for classifiers using the original ULMFiT configuration.

Code and the fine-tuned language model, needed to reproduce this article’s results, are provided.