Deconstructing BERT: Distilling 6 Patterns from 100 Million Parameters

From BERT’s tangled web of attention, some intuitive patterns emerge.

The year 2018 marked a turning point for the field of Natural Language Processing, with a series of deep-learning models achieving state-of-the-art results on NLP tasks ranging from question answering to sentiment classification. Most recently, Google’s BERT algorithm has emerged as a sort of “one model to rule them all,” based on its superior performance over a wide variety of tasks.

BERT builds on two key ideas that have been responsible for many of the recent advances in NLP: (1) the transformer architecture and (2) unsupervised pre-training. The transformer is a sequence model that forgoes the recurrent structure of RNNs for a fully attention-based approach, as described in the classic Attention Is All You Need. BERT is also pre-trained; its weights are learned in advance through two unsupervised tasks: masked language modeling (predicting a missing word given the left and right context) and next-sentence prediction (predicting whether one sentence follows another). Thus BERT doesn’t need to be trained from scratch for each new task; instead, its pre-trained weights are fine-tuned. For more details about BERT, check out The Illustrated BERT.
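The masked-language-modeling objective can be sketched in a few lines. This is a simplified illustration, not BERT's exact recipe: the real procedure selects about 15% of tokens, and of those replaces 80% with [MASK], 10% with a random token, and leaves 10% unchanged. The toy helper below performs only the basic [MASK] substitution:

```python
import random

def mask_tokens(tokens, mask_prob=0.15, seed=0):
    """Simplified masked-LM input construction: a fraction of tokens is
    replaced by [MASK], and the model must predict the originals from the
    surrounding (bidirectional) context. BERT's full recipe also sometimes
    substitutes a random token or leaves the token unchanged."""
    rng = random.Random(seed)
    masked, labels = [], []
    for tok in tokens:
        if rng.random() < mask_prob:
            masked.append("[MASK]")
            labels.append(tok)   # target the model must recover
        else:
            masked.append(tok)
            labels.append(None)  # no prediction at unmasked positions
    return masked, labels

tokens = "i went to the store".split()
masked, labels = mask_tokens(tokens)
```

Because the targets come from the input itself, no labeled data is needed, which is what makes this pre-training "unsupervised."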

BERT is a (multi-headed) beast

BERT is not like traditional attention models that use a flat attention structure over the hidden states of an RNN. Instead, BERT uses multiple layers of attention (12 or 24, depending on the model), and also incorporates multiple attention “heads” in every layer (12 or 16). Since model weights are not shared between layers, a single BERT model effectively has up to 24 x 16 = 384 different attention mechanisms.
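To make the multi-head structure concrete, here is a toy NumPy sketch of multi-head self-attention. It is illustrative only: real BERT applies learned per-head query/key/value projections in every layer, whereas this version simply slices the input into per-head subspaces so each head computes its own attention distribution:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(x, num_heads):
    """Toy multi-head self-attention: split the hidden dimension into
    num_heads independent subspaces, each producing its own attention
    distribution over the sequence."""
    seq_len, hidden = x.shape
    head_dim = hidden // num_heads
    outputs, attn_maps = [], []
    for h in range(num_heads):
        # Real BERT derives Q/K/V from learned per-layer projections;
        # here we reuse a slice of the input for all three, for brevity.
        q = k = v = x[:, h * head_dim:(h + 1) * head_dim]
        scores = q @ k.T / np.sqrt(head_dim)   # scaled dot-product
        attn = softmax(scores, axis=-1)        # one distribution per position
        outputs.append(attn @ v)
        attn_maps.append(attn)
    return np.concatenate(outputs, axis=-1), np.stack(attn_maps)

x = np.random.randn(5, 64)                    # 5 tokens, hidden size 64
out, attn = multi_head_attention(x, num_heads=16)
```

Each head's attention map is an independent seq_len x seq_len matrix, which is why a 24-layer, 16-head model yields 384 distinct attention patterns to inspect.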

Visualizing BERT

Because of BERT’s complexity, it can be difficult to intuit the meaning of its learned weights. Deep-learning models in general are notoriously opaque, and various visualization tools have been developed to help make sense of them. However, I hadn’t found one that could shed light on the attention patterns that BERT was learning. Fortunately, Tensor2Tensor has an excellent tool for visualizing attention in encoder-decoder transformer models, so I modified it to work with BERT’s architecture, using a PyTorch implementation of BERT. The adapted interface is shown below, and you can run it yourself using the notebooks on GitHub.

The tool visualizes attention as lines connecting the position being updated (left) with the position being attended to (right). Colors identify the corresponding attention head(s), while line thickness reflects the attention score. At the top of the tool, the user can select the model layer, as well as one or more attention heads (by clicking on the color patches at the top, representing the 12 heads).
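As a rough text stand-in for what the tool draws, the hypothetical helper below lists each attention "edge" whose weight clears a threshold, with the printed score playing the role of line thickness (names and the tiny attention matrix are made up for illustration):

```python
def describe_attention(tokens, attn, threshold=0.1):
    """Text analogue of the attention visualization: enumerate each
    (attending token, attended-to token, weight) edge whose attention
    weight meets the threshold. Rows of `attn` are the positions being
    updated; columns are the positions being attended to."""
    edges = []
    for i, src in enumerate(tokens):
        for j, tgt in enumerate(tokens):
            if attn[i][j] >= threshold:
                edges.append((src, tgt, round(attn[i][j], 2)))
    return edges

tokens = ["i", "went", "to"]
attn = [[0.05, 0.90, 0.05],   # "i" attends mostly to "went"
        [0.10, 0.10, 0.80],   # "went" attends mostly to "to"
        [0.30, 0.30, 0.40]]
edges = describe_attention(tokens, attn)
```

In the actual tool, clicking a single token on the left filters the display to just that row of the attention matrix.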

What does BERT actually learn?

I used the tool to explore the attention patterns of various layers / heads of the pre-trained BERT model (the BERT-Base, uncased version). I experimented with different input values, but for demonstration purposes, I just use the following inputs:

Sentence A: I went to the store. Sentence B: At the store, I bought fresh strawberries.

BERT uses WordPiece tokenization and inserts special classifier ([CLS]) and separator ([SEP]) tokens, so the actual input sequence is: [CLS] i went to the store . [SEP] at the store , i bought fresh straw ##berries . [SEP]
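The split of “strawberries” into “straw” and “##berries” comes from WordPiece's greedy longest-match-first tokenization against a subword vocabulary. A minimal sketch, using a tiny made-up vocabulary in place of BERT's roughly 30,000-entry one:

```python
def wordpiece_tokenize(word, vocab):
    """Greedy longest-match-first tokenization: repeatedly peel off the
    longest vocabulary entry from the front of the word, marking
    non-initial pieces with the '##' continuation prefix."""
    tokens, start = [], 0
    while start < len(word):
        end, piece = len(word), None
        while end > start:
            sub = word[start:end]
            if start > 0:
                sub = "##" + sub  # continuation pieces carry the prefix
            if sub in vocab:
                piece = sub
                break
            end -= 1
        if piece is None:
            return ["[UNK]"]      # no piece matched: whole word is unknown
        tokens.append(piece)
        start = end
    return tokens

# Tiny illustrative vocabulary; BERT's real vocabulary is far larger.
vocab = {"straw", "##berries", "store", "went"}
print(wordpiece_tokenize("strawberries", vocab))  # ['straw', '##berries']
```

Subword splitting lets BERT handle rare or unseen words without an enormous vocabulary.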

I found some fairly distinctive and surprisingly intuitive attention patterns. Below I identify six key patterns, and for each one I show visualizations for a particular layer / head that exhibited the pattern.

Pattern 1: Attention to next word

In this pattern, most of the attention at a particular position is directed to the next token in the sequence. Below we see an example of this for layer 2, head 0. (The selected head is indicated by the highlighted square in the color bar at the top.) The figure on the left shows the attention for all tokens, while the one on the right shows the attention for one selected token (“i”). In this example, virtually all of the attention is directed to “went,” the next token in the sequence.
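One way to quantify this pattern is to measure how much weight each position places on the token immediately to its right, averaged over the sequence. A small sketch using a hypothetical attention matrix that exhibits the next-token pattern:

```python
import numpy as np

def next_token_attention(attn):
    """Given one head's attention matrix (rows = attending positions,
    columns = attended-to positions), return the average weight each
    position places on the immediately following token."""
    n = attn.shape[0]
    # The last token has no next token, so it is excluded.
    return float(np.mean([attn[i, i + 1] for i in range(n - 1)]))

# Hypothetical 4-token attention matrix: nearly all mass sits just
# off the main diagonal, i.e. on the next token.
attn = np.array([
    [0.02, 0.95, 0.02, 0.01],
    [0.01, 0.02, 0.95, 0.02],
    [0.01, 0.02, 0.02, 0.95],
    [0.25, 0.25, 0.25, 0.25],  # last token: no "next", mass spreads out
])
score = next_token_attention(attn)  # 0.95: strongly next-token focused
```

A head whose score is near 1.0 behaves much like the layer 2, head 0 example shown here; a score near 1/seq_len indicates no next-token preference.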