Javed Qadrud-Din was an Insight Fellow in Fall 2017. He is currently a machine learning engineer at Casetext where he works on natural language processing for the legal industry. Prior to Insight, he was at IBM Watson.

In late 2018, Google open-sourced BERT, a powerful deep learning algorithm for natural language processing. BERT can be pre-trained on a massive corpus of unlabeled data and then fine-tuned to a task for which you have a limited amount of data. This allows BERT to provide significantly higher performance than models that can only leverage a small task-specific dataset. In the first problem that we applied BERT to at Casetext (the company where I work), we obtained a 66% improvement in accuracy over the best model we had tried up until that point. This post will show you how to fine-tune BERT for a simple text classification task of your own.

How BERT works: a brief overview

BERT learns useful text representations by being pre-trained on two different tasks:

1. In a sentence with some words hidden (BERT masks about 15% of the input tokens), BERT is trained to predict what the hidden words are, and

2. Given two sentences, BERT is trained to determine whether the second sentence actually follows the first in a piece of text, or whether they are just two unrelated sentences.

The beauty of using these two tasks for pre-training is that the training sets can be obtained programmatically, rather than through costly human annotation efforts. As a result, BERT can be pre-trained on a truly massive corpus of text, in the process learning rich representations of language that are impossible to learn with small labeled datasets.
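To make the “obtained programmatically” point concrete, here is a minimal Python sketch of how both kinds of training examples can be cut from raw, unlabeled text. It is a simplification, not BERT’s actual data pipeline: the function names are hypothetical, masking here always substitutes a [MASK] token (the real recipe sometimes keeps or randomizes the chosen tokens), and the “not next” sentence is drawn from the same small list rather than from a different document.

```python
import random

MASK = "[MASK]"

def make_masked_example(tokens, mask_prob=0.15, rng=random):
    """Hide a random subset of tokens; the hidden originals become the labels.

    Returns (masked_tokens, labels) where labels maps position -> original token.
    """
    masked = list(tokens)
    labels = {}
    for i, tok in enumerate(tokens):
        if rng.random() < mask_prob:
            labels[i] = tok
            masked[i] = MASK
    return masked, labels

def make_next_sentence_example(sentences, idx, rng=random):
    """Pair sentences[idx] with its real successor or with an unrelated sentence.

    Returns (sentence_a, sentence_b, is_next). Assumes idx + 1 < len(sentences)
    and at least three sentences in total.
    """
    a = sentences[idx]
    if rng.random() < 0.5:
        return a, sentences[idx + 1], True
    # Any sentence other than `a` itself or its true successor counts as unrelated.
    others = [s for j, s in enumerate(sentences) if j not in (idx, idx + 1)]
    return a, rng.choice(others), False

# Usage: neither kind of example needed a human annotator.
rng = random.Random(0)
print(make_masked_example("the court reversed the judgment below".split(), rng=rng))
print(make_next_sentence_example(["A.", "B.", "C.", "D."], 0, rng=rng))
```

The labels in both cases come straight from the text itself, which is what lets pre-training scale to corpora far larger than anything people could annotate by hand.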

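BERT can then be fine-tuned on a task of your choosing that benefits from the rich representations of language it learned during pre-training. For classification, a small output layer is added on top of the pre-trained network, and both are trained together on your labeled data.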
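The classifier that run_classifier.py builds on top of BERT is small: it takes BERT’s pooled output for the first ([CLS]) token, applies one linear layer, and turns the result into class probabilities with a softmax. The pure-Python sketch below shows only that final step on made-up numbers; the function name is hypothetical, and it leaves out the dropout applied during training.

```python
import math

def classify_pooled(pooled, weights, bias):
    """Linear layer + softmax over a pooled [CLS] vector.

    pooled:  list of H floats (H is 768 for the "base" models)
    weights: one list of H floats per class
    bias:    one float per class
    Returns one probability per class; the probabilities sum to 1.
    """
    logits = [sum(w * x for w, x in zip(row, pooled)) + b
              for row, b in zip(weights, bias)]
    top = max(logits)  # subtract the max before exp() for numerical stability
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Toy example with H = 2 and two classes.
print(classify_pooled([1.0, 2.0], [[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0]))
```

During fine-tuning, the gradient of the classification loss flows through this layer into BERT’s own weights as well, so the whole network adapts to your task rather than just the new layer.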

Using BERT for text classification

Google’s documentation on BERT is generally good, but how to use BERT on a simple text classification task isn’t immediately obvious. By “simple text classification task,” we mean a task in which you want to classify/categorize portions of text that are roughly one sentence to a paragraph in length. The BERT documentation shows you how to classify the relationships between pairs of sentences, but it doesn’t detail how to use BERT to label single portions of text.

After digging into the BERT code to figure out how to do it, I’ve come up with this workflow, which will let you easily classify single passages of text:

1. Clone the BERT GitHub repo onto your own machine. Just open up your terminal and type the following:

git clone https://github.com/google-research/bert.git

2. Download the BERT model files. These are the weights and other necessary files to represent the information BERT learned in pre-training. You’ll need to pick which BERT pre-trained weights you want. If you don’t have access to a Google TPU, you’ll want to pick one of the “base” models. You should pick a “cased” model or an “uncased” model depending on whether you think letter casing will be helpful for the task you’re trying to solve. Save this into the directory where you cloned the git repository and unzip it. Here are links to the files for English:

Files for other languages can be found on the BERT project GitHub page.
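Before moving on, it can help to confirm that the unzipped folder has the files the training command reads from it: vocab.txt, bert_config.json, and the bert_model.ckpt checkpoint, which TensorFlow stores as several files sharing that prefix (for example bert_model.ckpt.index). This small Python check is a convenience sketch; the function name is hypothetical, and the folder you pass in is wherever you unzipped the model.

```python
from pathlib import Path

# Files that $BERT_BASE_DIR must contain. bert_model.ckpt is a checkpoint
# prefix, so we look for its .index file rather than a file literally
# named bert_model.ckpt.
REQUIRED_FILES = ["vocab.txt", "bert_config.json", "bert_model.ckpt.index"]

def missing_model_files(model_dir):
    """Return the required file names that are not present in model_dir."""
    folder = Path(model_dir)
    return [name for name in REQUIRED_FILES if not (folder / name).is_file()]

# Usage (hypothetical path): print(missing_model_files("./path/to/weights"))
```

An empty list means the folder looks complete; anything else usually means the zip was not fully extracted or the path points one level too high.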

3. Put your data into the format BERT expects. Create a folder named data in the directory where you cloned BERT (the commands below read from it via --data_dir=./data). You’ll be adding three separate files there called train.tsv, dev.tsv, and test.tsv (TSV stands for tab-separated values). In train.tsv and dev.tsv you should have four tab-separated columns with no header row, as follows:

Column 1: An ID for the row (can be just a count, or even just the same number or letter for every row, if you don’t care to keep track of each individual example).

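Column 2: A label for the row as an int. These are the classification labels that your classifier aims to predict. With --task_name=cola (used in the training command below), run_classifier.py only knows about labels 0 and 1, because ColaProcessor.get_labels() returns ["0", "1"]. If you have more than two classes, edit that list so it contains every label you use.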

Column 3: A column of all the same letter. This is a throw-away column that you need to include because the data reader in run_classifier.py expects the text to be in the fourth column.

Column 4: The text examples you want to classify.

Here is an example of what the data in train.tsv and dev.tsv should look like:

1 0 a an example of text that should fit in class 0

2 1 a an example of text that should fit in class 1

3 0 a another class 0 example

4 2 a a class 2 example
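If you build these files from Python rather than a spreadsheet, a minimal sketch like the one below writes (label, text) pairs in this four-column, headerless layout. The function name and the filler parameter are illustrative assumptions. It also collapses tabs and newlines inside the text into single spaces, since each row is split on tab characters and a stray tab would shift the columns.

```python
def write_train_dev_tsv(path, examples, filler="a"):
    """Write (label, text) pairs as headerless 4-column TSV rows.

    Column 1: running ID, column 2: integer label,
    column 3: throw-away filler letter, column 4: the text.
    """
    with open(path, "w", encoding="utf-8") as f:
        for row_id, (label, text) in enumerate(examples, start=1):
            # A tab or newline inside the text would break the column layout,
            # so collapse every run of whitespace into a single space.
            clean = " ".join(text.split())
            f.write(f"{row_id}\t{int(label)}\t{filler}\t{clean}\n")

# Usage (hypothetical path and data):
# write_train_dev_tsv("data/train.tsv", [(0, "an example of text"), (1, "another example")])
```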

test.tsv should have a slightly different format.

Column 1: an ID for each example, similar to column 1 in the train and dev files, and

Column 2: the text you want to classify. Also, test.tsv should have a header line (whereas train and dev should not). Here is an example of what test.tsv should look like:

id sentence

1 my first test example

2 another test example. Yay this is fun!

3 yet another test example
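A quick way to catch formatting mistakes before training is to check every row’s column count. The sketch below is a hypothetical helper, not part of the BERT repo: for train.tsv and dev.tsv it expects four columns with an integer label, and for test.tsv it skips the header line and expects two columns.

```python
def check_tsv(path, is_test=False):
    """Return a list of human-readable problems found in a BERT-style TSV file."""
    with open(path, encoding="utf-8") as f:
        lines = f.read().splitlines()
    first_line_no = 1
    if is_test:
        lines = lines[1:]  # test.tsv starts with a header line; skip it
        first_line_no = 2
    expected = 2 if is_test else 4
    problems = []
    for line_no, line in enumerate(lines, start=first_line_no):
        cols = line.split("\t")
        if len(cols) != expected:
            problems.append(f"line {line_no}: expected {expected} columns, got {len(cols)}")
        elif not is_test and not cols[1].isdigit():
            problems.append(f"line {line_no}: label {cols[1]!r} is not an integer")
    return problems

# Usage (hypothetical paths):
# print(check_tsv("data/train.tsv"))
# print(check_tsv("data/test.tsv", is_test=True))
```

An empty list means every row has the expected shape; it does not check that the labels match the processor’s label list.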

If you’re looking for an easy way to get data into this format, I recommend making it into a CSV file and then using the pandas Python package to convert it into a TSV. If you don’t already have a CSV file containing your data, you can make one by using a tool like Google Sheets and exporting as a CSV. If you do this, make sure to put your columns in the right order before you export it. Here’s code to use pandas to convert CSV to TSV:

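import pandas as pd

# By default, read_csv treats the first row of the CSV as column names
df = pd.read_csv('path/to/your/csv/here.csv')

# Write tab-separated values with no index column and no header row
# (if you are creating test.tsv, set header=True instead of False)
df.to_csv('path/of/your/choice.tsv', sep='\t', index=False, header=False)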

You can choose how much of your data goes into the train, dev, and test sets, but a good rule of thumb is 80% in train and 10% each in dev and test.
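In Python, the 80/10/10 rule of thumb can be applied with a seeded shuffle, as in this sketch (the function name and default seed are illustrative). It does not stratify by label, so with very small or very imbalanced datasets you may want to check each split’s class counts. The test rows keep their labels here so you can score predictions later; when writing test.tsv itself, drop the label.

```python
import random

def split_80_10_10(rows, seed=0):
    """Shuffle rows reproducibly and split them 80% / 10% / 10%.

    Returns (train, dev, test). Any rounding remainder goes to test.
    """
    rows = list(rows)
    # Shuffle first so the splits don't inherit any ordering in the source file.
    random.Random(seed).shuffle(rows)
    n_train = int(len(rows) * 0.8)
    n_dev = int(len(rows) * 0.1)
    train = rows[:n_train]
    dev = rows[n_train:n_train + n_dev]
    test = rows[n_train + n_dev:]
    return train, dev, test
```

Fixing the seed means rerunning your preprocessing produces the same splits, so dev results stay comparable across experiments.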

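4. Run training. Navigate to the directory you cloned BERT into, and type the following commands (or put them in a shell script and run the script). The --task_name=cola flag tells run_classifier.py to use its data processor for the CoLA dataset, which reads one piece of text and one label per row; that is why our files follow that layout.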

export BERT_BASE_DIR=./path/to/weights/downloaded/in/step2
python bert/run_classifier.py \
  --task_name=cola \
  --do_train=true \
  --do_eval=true \
  --data_dir=./data \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
  --max_seq_length=128 \
  --train_batch_size=32 \
  --learning_rate=2e-5 \
  --num_train_epochs=3.0 \
  --output_dir=./bert_output/

If you get an out-of-memory error, you may need to run this on a machine whose GPU has more on-board memory, or on a TPU (see the instructions for TPUs in the BERT GitHub repo). You can also try reducing --train_batch_size, though training will run slower as a result. Note that max_seq_length counts WordPiece tokens, not words, and includes the [CLS] and [SEP] markers BERT adds; longer inputs are truncated. If your typical text is longer than 128 tokens, you can increase max_seq_length up to a maximum of 512, though the model will run slower and you may run out of memory.
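To estimate how long a run will be, and which checkpoint number to expect at the end, it helps to know how run_classifier.py sizes training: it sets the number of training steps to the number of training examples divided by train_batch_size, times num_train_epochs (truncated to an integer), and spends the first warmup_proportion of those steps (default 0.1) warming up the learning rate. The sketch below repeats that arithmetic in plain Python; the function name is hypothetical and the dataset size is made up.

```python
def training_steps(num_train_examples, train_batch_size=32,
                   num_train_epochs=3.0, warmup_proportion=0.1):
    """Return (num_train_steps, num_warmup_steps) using run_classifier.py's formula."""
    num_train_steps = int(num_train_examples / train_batch_size * num_train_epochs)
    num_warmup_steps = int(num_train_steps * warmup_proportion)
    return num_train_steps, num_warmup_steps

# A made-up training set of 8,000 rows with the flags used above:
print(training_steps(8000))                       # (750, 75)
print(training_steps(8000, train_batch_size=16))  # (1500, 150): half the batch, twice the steps
```

Since training stops at num_train_steps, the last checkpoint in bert_output should carry that step number (model.ckpt-750 in the first case).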

Training can take a long time, so expect this step to run for a while. You should see progress output as it runs.

Once it’s finished running, you’ll find a report on how the model did on the dev set in bert_output/eval_results.txt.
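When --do_eval=true, run_classifier.py writes its dev-set metrics to eval_results.txt in the output directory, one “key = value” line per metric (for example eval_accuracy and eval_loss). A small parser like this hypothetical helper lets you compare runs programmatically:

```python
def read_eval_results(path="./bert_output/eval_results.txt"):
    """Parse 'key = value' lines into a dict mapping metric name -> float."""
    results = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            if " = " not in line:
                continue  # ignore anything that is not a metric line
            key, value = line.strip().split(" = ", 1)
            results[key] = float(value)
    return results

# Usage: print(read_eval_results()["eval_accuracy"])
```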

Using BERT to predict on new data

If you want to run inference on new data, put that data into test.tsv in the same format as in step 3 above. Then go into the bert_output directory and find the model.ckpt files with the highest number after “model.ckpt-”. This set of files contains the weights for the model you trained. Once you’ve determined the highest checkpoint number, run the following commands in the terminal or through a shell script:

export BERT_BASE_DIR=./path/to/weights/downloaded/in/step2
export TRAINED_CLASSIFIER=./bert_output/model.ckpt-[highest checkpoint number you saw]
python bert/run_classifier.py \
  --task_name=cola \
  --do_predict=true \
  --data_dir=./data \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$TRAINED_CLASSIFIER \
  --max_seq_length=128 \
  --output_dir=./bert_output/
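If the output directory holds many checkpoints, you can find the highest step number with a few lines of Python instead of reading the file list by eye. This sketch (hypothetical helper name) looks for files named like model.ckpt-<step>.index and compares the step numbers as integers, so model.ckpt-1000 correctly beats model.ckpt-999 even though it sorts earlier as text. If you already have TensorFlow installed, tf.train.latest_checkpoint(output_dir) returns the most recent checkpoint prefix by reading the checkpoint state file in that directory.

```python
import re
from pathlib import Path

def latest_checkpoint_prefix(output_dir="./bert_output"):
    """Return the checkpoint prefix with the highest step number, or None.

    A checkpoint such as model.ckpt-750 is stored as several files that share
    that prefix (model.ckpt-750.index, model.ckpt-750.data-00000-of-00001, ...).
    """
    best_step, best_prefix = -1, None
    for index_file in Path(output_dir).glob("model.ckpt-*.index"):
        match = re.fullmatch(r"model\.ckpt-(\d+)\.index", index_file.name)
        if match and int(match.group(1)) > best_step:
            best_step = int(match.group(1))
            best_prefix = str(index_file.with_suffix(""))  # drop ".index"
    return best_prefix

# Usage: the printed prefix is a value you can use for TRAINED_CLASSIFIER.
# print(latest_checkpoint_prefix())
```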

Make sure the max_seq_length parameter is the same as you set it to during training. You should now get a file in bert_output called test_results.tsv. It has one column per class you were aiming to classify, and each row holds the predicted probability of each class for one example. The rows are in the same order as the rows of data in test.tsv.
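To turn those probabilities into hard predictions, take the column with the highest probability in each row. Column positions follow the order of the label list in the data processor, so with the cola processor’s default labels, column 0 means label 0 and column 1 means label 1. A minimal sketch, with a hypothetical function name:

```python
def predicted_labels(path="./bert_output/test_results.tsv"):
    """Return the index of the most probable class for each row of test_results.tsv."""
    predictions = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # skip blank lines, if any
            probs = [float(p) for p in line.split("\t")]
            # argmax: position of the largest probability in this row
            predictions.append(max(range(len(probs)), key=probs.__getitem__))
    return predictions

# Usage: pair each prediction with the IDs from test.tsv, which share the same row order.
# print(predicted_labels())
```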

The power of BERT in practice

BERT has greatly increased our capacity to do transfer learning in NLP — an important step forward for the field. As I previously mentioned, for the first problem that we applied BERT to at Casetext, we obtained a 66% improvement in accuracy over the best model we had tried up until that point.

Casetext is now using BERT as part of our system to determine whether judicial opinions in our database have been overturned. To do this, humans hand-annotated approximately 10,000 examples of sentences in legal text where a court has overturned a prior court’s decision. Leveraging BERT’s unsupervised pre-training allowed us to obtain excellent results, even with this relatively small number of hand-labeled examples. In our production system, we set the threshold on our model’s score low enough that we catch all (or very nearly all) sentences that contain overturning language. In other words, we set the threshold to favor recall very strongly, at the expense of precision. We then have human experts review all of these potentially overturning sentences and select the ones that are truly overturning. Before using BERT, we needed experts to read 9.8% of all sentences in every new judicial opinion that came out. Now, they need to read only 3.32%, roughly a two-thirds reduction in a costly process.
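The thresholding idea can be sketched in Python: on a labeled dev set, pick the highest score cutoff that still catches the required share of truly positive sentences, then measure what fraction of all sentences that cutoff sends to human review. This is an illustration with hypothetical names and toy numbers, not Casetext’s production code.

```python
import math

def threshold_for_recall(scores, labels, target_recall=1.0):
    """Highest cutoff that keeps recall on positive examples >= target_recall.

    scores: model probability of the positive class, one per example
    labels: 1 for truly positive examples, 0 otherwise
    An example is flagged for review when its score >= the returned cutoff.
    """
    positives = sorted((s for s, y in zip(scores, labels) if y == 1), reverse=True)
    if not positives:
        raise ValueError("need at least one positive example")
    # Number of positives we must catch (at least one).
    needed = max(1, math.ceil(target_recall * len(positives)))
    return positives[needed - 1]

def review_fraction(scores, cutoff):
    """Share of all examples whose score puts them in front of a human reviewer."""
    return sum(s >= cutoff for s in scores) / len(scores)

# Toy dev set: three positives among six sentences.
scores = [0.9, 0.8, 0.3, 0.2, 0.1, 0.05]
labels = [1, 0, 1, 0, 1, 0]
cutoff = threshold_for_recall(scores, labels)   # catch every positive
print(cutoff, review_fraction(scores, cutoff))  # 0.1 0.8333333333333334
```

A model that ranks true positives higher lets the same recall target be met with a higher cutoff, which is what shrinks the share of sentences experts have to read.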

Interestingly, we achieved this substantial improvement without doing our own BERT pre-training on a corpus of legal text. We used Google’s pre-trained weights as is; they were trained on BooksCorpus and English Wikipedia. Once we have time to pre-train on our own legal corpus, we expect to obtain even better results.

If you have questions, feel free to post them to me. If you found this helpful, please do give a clap!