It’s impossible to name a single best text classifier. In fields such as computer vision, there’s a strong consensus about a general way of designing models: deep networks with lots of residual connections. Text classification, by contrast, is still far from converging on any narrow set of approaches.

In this article, we’ll focus on a few main generalized approaches to text classification and their use cases. Along with the high-level discussion, we offer a collection of hands-on tutorials and tools that can help you build your own models.

Text Classification Benchmarks

The toolbox of a modern machine learning practitioner who focuses on text mining ranges from TF-IDF features and linear SVMs to word embeddings (word2vec) and attention-based neural architectures.
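To make the simplest end of that toolbox concrete, here is a minimal sketch of TF-IDF weighting in plain Python. It uses the textbook variant (raw term count times log(N / df)); the toy corpus is made up for illustration. In practice you would reach for scikit-learn’s `TfidfVectorizer`, which applies a smoothed IDF and L2 normalization by default, and feed its output to a linear classifier such as an SVM.

```python
import math
from collections import Counter


def tf_idf(docs: list[list[str]]) -> list[dict[str, float]]:
    """Return one {term: weight} dict per tokenized document.

    Weight = raw term count * log(N / df), with N the number of documents
    and df the number of documents containing the term. No smoothing or
    normalization is applied.
    """
    n_docs = len(docs)
    # Document frequency: count each term at most once per document.
    df = Counter(term for doc in docs for term in set(doc))
    vectors = []
    for doc in docs:
        tf = Counter(doc)
        vectors.append(
            {term: count * math.log(n_docs / df[term]) for term, count in tf.items()}
        )
    return vectors


# Hypothetical toy corpus, already tokenized by whitespace.
docs = [
    "the match ended in a draw".split(),
    "the market ended the day higher".split(),
    "the striker scored in the match".split(),
]
for vec in tf_idf(docs):
    # Show the three highest-weighted terms of each document.
    print(sorted(vec.items(), key=lambda kv: -kv[1])[:3])
```

A word like “the”, which appears in every document, gets a weight of zero, while terms unique to one document get the largest boost.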

When the effectiveness of a method is demonstrated, it’s important to distinguish between two settings: research and competition.

When researchers compare text classification algorithms, they use them as they are, possibly augmented with a few tricks, on well-known datasets that let them compare their results with many other attempts at the same problem.

Some well-known text classification benchmarks:

• AG’s news articles

• Sogou news corpora

• Amazon Review Full

• Amazon Review Polarity

• DBPedia

• Yahoo Answers

• Yelp Review Full

• Yelp Review Polarity

We’ve made a special folder on Google Drive so you can download them right away.

Deep vs. Shallow Learning

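The really remarkable thing about the datasets widely adopted in NLP research is that both simple and very complex models work on them very well. To showcase this, let’s discuss two papers:

• Character-level Convolutional Networks for Text Classification (Zhang et al., 2015)

• Bag of Tricks for Efficient Text Classification (Joulin et al., 2016)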

The datasets in both cases are the same, and the results in terms of accuracy are roughly the same across all the experiments. But the training and inference times differ greatly between the two.

The second, lighter model takes literally seconds to train, while the first needs several hours, a difference that becomes a game changer when it comes to choosing hyperparameters.

What makes the first paper’s approach interesting is that its model doesn’t make any assumptions about the data. At the lowest level, it treats the text as a sequence of characters, allowing the convolutional layers to build features in a completely content-agnostic way.
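Treating text as a sequence of characters means quantizing each character into a one-hot vector over a fixed alphabet, producing a matrix that convolutions can slide over. The sketch below assumes a small hypothetical alphabet and a maximum length of 16; the actual character-level model uses its own, larger alphabet and a much longer fixed input length. Characters outside the alphabet, and padding positions, become all-zero columns.

```python
import string

# Hypothetical alphabet for illustration only; not the paper's exact character set.
ALPHABET = string.ascii_lowercase + string.digits + " .,;:!?'\"-()"
CHAR_INDEX = {ch: i for i, ch in enumerate(ALPHABET)}


def quantize(text: str, max_len: int = 16) -> list[list[int]]:
    """One-hot encode text as a len(ALPHABET) x max_len matrix.

    Row = character in the alphabet, column = position in the text.
    Text is lowercased and truncated to max_len; unknown characters
    and padding positions are left as all-zero columns.
    """
    matrix = [[0] * max_len for _ in ALPHABET]
    for pos, ch in enumerate(text.lower()[:max_len]):
        row = CHAR_INDEX.get(ch)
        if row is not None:
            matrix[row][pos] = 1
    return matrix


m = quantize("Hi!", max_len=5)
# Column sums show which positions hold a known character.
print([sum(row[c] for row in m) for c in range(5)])
```

Because nothing here depends on words, tokenization, or language-specific rules, the convolutional layers that consume this matrix have to discover every useful pattern themselves.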

The second paper features a much lighter model that’s designed to work fast on a CPU and consists of a joint embedding layer and a softmax classifier.
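The core of such a lightweight model fits in a few lines: look up an embedding for each token, average them into one document vector, and pass it through a single linear layer with a softmax. The sketch below shows only the forward pass with random, untrained weights; the class name and toy vocabulary are hypothetical, and the real fastText implementation adds further tricks (such as word n-gram features) on top of this idea.

```python
import math
import random


def softmax(scores: list[float]) -> list[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


class BagOfEmbeddingsClassifier:
    """Average token embeddings, then apply one linear layer and a softmax."""

    def __init__(self, vocab: dict[str, int], dim: int, n_classes: int, seed: int = 0):
        rng = random.Random(seed)
        self.vocab = vocab
        # Embedding table: one dim-sized vector per vocabulary entry.
        self.emb = [[rng.uniform(-0.1, 0.1) for _ in range(dim)] for _ in vocab]
        # Linear classifier: one weight row and one bias per class.
        self.w = [[rng.uniform(-0.1, 0.1) for _ in range(dim)] for _ in range(n_classes)]
        self.b = [0.0] * n_classes

    def predict_proba(self, tokens: list[str]) -> list[float]:
        ids = [self.vocab[t] for t in tokens if t in self.vocab]
        dim = len(self.emb[0])
        # Mean of the known token vectors (stays a zero vector if none are known).
        doc = [0.0] * dim
        for i in ids:
            for k in range(dim):
                doc[k] += self.emb[i][k] / len(ids)
        scores = [sum(wk * xk for wk, xk in zip(row, doc)) + bk
                  for row, bk in zip(self.w, self.b)]
        return softmax(scores)


vocab = {"good": 0, "bad": 1, "movie": 2}
clf = BagOfEmbeddingsClassifier(vocab, dim=4, n_classes=2)
print(clf.predict_proba(["good", "movie"]))
```

Averaging makes the cost linear in the number of tokens with a tiny constant, which is why this kind of model trains so quickly on a CPU.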

On the other hand, if you take a look at some of the winning solutions on Kaggle, you’ll see they are dominated by highly customized complex ensembles.

A good example is the recent Quora Question Pairs competition and the ongoing DeepHack.Turing, where top-ranking solutions combine several different models: gradient boosting machines, RNNs, and CNNs.

The practical lesson here is that, whatever results published research reports for certain methods, getting the best performance on a particular real-world task is closer to art than to science and requires careful tuning of complicated pipelines.

The contrast with research is especially striking in the writeup of a winning Kaggle solution.

Neural network-based text classifiers typically follow the same linear meta-architecture:

• Embedding

• Deep representation

• Fully connected part
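The three stages can be read as a simple function composition: token ids become vectors, a sequence of vectors is condensed into one feature vector, and that vector is mapped to class scores. The sketch below makes the interfaces explicit; the toy stage implementations (a lookup table, max pooling, and a fixed linear map) are hypothetical stand-ins chosen only to show the shapes, and in a real model the middle stage would be an RNN or CNN.

```python
from typing import Callable

Vector = list[float]


def classify(
    token_ids: list[int],
    embed: Callable[[int], Vector],
    represent: Callable[[list[Vector]], Vector],
    dense: Callable[[Vector], Vector],
) -> Vector:
    """Run the pipeline: T ids -> T x d vectors -> h features -> C scores."""
    vectors = [embed(i) for i in token_ids]  # Embedding
    features = represent(vectors)            # Deep representation
    return dense(features)                   # Fully connected part


# Hypothetical toy stages, for shape illustration only (no training involved).
TABLE = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}


def toy_embed(token_id: int) -> Vector:
    return TABLE[token_id]


def toy_represent(vectors: list[Vector]) -> Vector:
    # Max-pool over time: element-wise maximum across all token vectors.
    return [max(column) for column in zip(*vectors)]


def toy_dense(features: Vector) -> Vector:
    # Two class scores from a fixed linear map.
    return [features[0] - features[1], features[1] - features[0]]


print(classify([0, 0], toy_embed, toy_represent, toy_dense))
```

Keeping the stages behind narrow interfaces like these is what lets you swap an RNN for a CNN in the middle without touching the embedding or the classifier head.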

Embedding

Embedding layers take a sequence of word ids as an input and produce a sequence of corresponding vectors as an output. Their functionality is really straightforward, and since the actual semantics of those vectors are not interesting for our problem, the only remaining question is “What is the best way to initialize the weights?”

Depending on the problem, the answers may be as counterintuitive as the advice “generate your own synthetic labels, train word2vec on them, and init the embedding layer with them.”

But for all practical purposes, you can use a pre-trained set of embeddings and jointly fine-tune it for your particular model. The resulting word vectors will likely stop exhibiting the same properties they have in a vanilla word2vec model, but that doesn’t matter in this case.
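Initializing from pre-trained vectors usually means building an embedding matrix row by row: copy the pre-trained vector when a vocabulary word has one, and fall back to small random values otherwise. In the sketch below the `pretrained` dict is a hypothetical stand-in for vectors you would load with a tool such as gensim’s `KeyedVectors.load_word2vec_format`; in PyTorch, the resulting matrix could then be passed to `nn.Embedding.from_pretrained(..., freeze=False)` so it stays trainable for fine-tuning.

```python
import random


def build_embedding_matrix(
    vocab: list[str],
    pretrained: dict[str, list[float]],
    dim: int,
    seed: int = 0,
) -> tuple[list[list[float]], int]:
    """Create one row per vocabulary word.

    Words found in `pretrained` (with the right dimension) are copied;
    the rest get small Gaussian noise. Also returns how many were found.
    """
    rng = random.Random(seed)
    matrix, hits = [], 0
    for word in vocab:
        vec = pretrained.get(word)
        if vec is not None and len(vec) == dim:
            matrix.append(list(vec))  # copy, so fine-tuning never mutates the source
            hits += 1
        else:
            matrix.append([rng.gauss(0.0, 0.1) for _ in range(dim)])
    return matrix, hits


def embed(token_ids: list[int], matrix: list[list[float]]) -> list[list[float]]:
    """Embedding layer forward pass: one row lookup per token id."""
    return [matrix[i] for i in token_ids]


vocab = ["<pad>", "cat", "dog"]
pretrained = {"cat": [0.1, 0.2, 0.3], "dog": [0.4, 0.5, 0.6]}
matrix, hits = build_embedding_matrix(vocab, pretrained, dim=3)
print(hits, embed([2, 1], matrix))
```

Checking the hit rate is a cheap sanity test: a low value usually points to a tokenization mismatch between your pipeline and the one used to train the vectors.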

The go-to solution here is to use pretrained word2vec embeddings and a lower learning rate for the embedding layer (for example, the general learning rate multiplied by 0.1).
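A lower learning rate for one part of the network simply means scaling the update for that parameter group. The plain-SGD sketch below uses hypothetical parameter and gradient values to show the embedding weights moving ten times more slowly than the rest. In PyTorch the same effect is achieved by passing parameter groups to the optimizer, for example `torch.optim.Adam([{"params": model.embedding.parameters(), "lr": base_lr * 0.1}, {"params": other_params}], lr=base_lr)`, where `model.embedding` and `other_params` are hypothetical names for your own model’s parts.

```python
def sgd_step(
    params: dict[str, list[float]],
    grads: dict[str, list[float]],
    lr: float,
    lr_multipliers: dict[str, float],
) -> None:
    """In-place SGD update; groups listed in lr_multipliers use a scaled learning rate."""
    for name, values in params.items():
        group_lr = lr * lr_multipliers.get(name, 1.0)
        for i, g in enumerate(grads[name]):
            values[i] -= group_lr * g


# Hypothetical parameters and gradients for two groups.
params = {"embedding": [0.5, -0.2], "classifier": [0.1, 0.3]}
grads = {"embedding": [1.0, 1.0], "classifier": [1.0, 1.0]}
# The embedding group gets 0.1x the general learning rate.
sgd_step(params, grads, lr=0.1, lr_multipliers={"embedding": 0.1})
print(params)
```

With identical gradients, the classifier weights move by 0.1 while the embedding weights move by only 0.01, so the pre-trained vectors are adjusted gently rather than overwritten early in training.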

Deep representation

The main purpose of the deep representation part is to condense all relevant information in its output while suppressing the parts that could lead to identifying a single sample from it. This is highly desirable because a network with high capacity is likely to overfit on particular examples and perform poorly on the test set.

Recurrent neural network (RNN)

When the problem consists of obtaining a single prediction for a given document (spam/not spam), the most straightforward and reliable architecture is a multilayer fully connected classifier applied to the hidden state of a recurrent network. The semantics of this state are considered irrelevant, and the entire vector is treated as a compressed description of the text.
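The idea of classifying from the final hidden state can be sketched without any framework. To keep it short, the version below assumes a plain tanh RNN cell in place of an LSTM, omits biases, and uses random, untrained weights; the “multilayer fully connected” head is one ReLU hidden layer followed by an output layer. The class and helper names are hypothetical.

```python
import math
import random


def matvec(m: list[list[float]], v: list[float]) -> list[float]:
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def rand_matrix(rows: int, cols: int, rng: random.Random) -> list[list[float]]:
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]


class RNNClassifier:
    """Tanh RNN over embeddings; the final hidden state feeds a small MLP."""

    def __init__(self, emb_dim: int, hidden: int, n_classes: int, seed: int = 0):
        rng = random.Random(seed)
        self.hidden = hidden
        self.w_xh = rand_matrix(hidden, emb_dim, rng)   # input-to-hidden
        self.w_hh = rand_matrix(hidden, hidden, rng)    # hidden-to-hidden
        self.w_1 = rand_matrix(hidden, hidden, rng)     # dense hidden layer
        self.w_2 = rand_matrix(n_classes, hidden, rng)  # output layer

    def encode(self, embeddings: list[list[float]]) -> list[float]:
        h = [0.0] * self.hidden
        for x in embeddings:
            # h_t = tanh(W_xh x_t + W_hh h_{t-1})
            h = [math.tanh(a + b)
                 for a, b in zip(matvec(self.w_xh, x), matvec(self.w_hh, h))]
        return h  # treated as an opaque summary of the whole text

    def scores(self, embeddings: list[list[float]]) -> list[float]:
        h = self.encode(embeddings)
        z = [max(0.0, v) for v in matvec(self.w_1, h)]  # ReLU hidden layer
        return matvec(self.w_2, z)                       # one score per class


clf = RNNClassifier(emb_dim=3, hidden=4, n_classes=2)
print(clf.scores([[0.1, 0.2, 0.3], [0.0, -0.1, 0.4]]))
```

Nothing downstream interprets individual components of `h`; the classifier head simply learns whatever mapping from that vector to labels works best.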

Here are several useful sources:

• A great starting point for understanding how to use LSTMs for text classification (in this case, sentiment analysis).

• LSTM Tutorial for PyTorch

• Chris Olah’s almost classic tutorial, Understanding LSTM Networks

Since the main work is being done in the recurrent layer, it’s important to make sure that it captures only the relevant information. It’s a frequent challenge for natural language applications and an open scientific problem.

On a high level, there are two things that can be done here:

• Use bidirectional LSTMs. This is almost always a good idea, because each position then draws on the context on both sides of the word instead of relying on purely sequential, left-to-right “reading.”

• Use a transitional layer for embeddings. LSTMs learn to distinguish important and unimportant parts of the sequence by themselves, but we can’t be sure that the representation from the embedding layer is the best input, especially if we don’t fine-tune the embeddings. Adding a layer that’s applied to each word embedding independently can improve your results, acting as a simple attention layer.
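Both ideas can be sketched together: a transitional layer applied to each embedding on its own, followed by two recurrent passes, one reading left to right and one right to left, whose final states are concatenated. This is a minimal sketch under stated assumptions: plain tanh RNN cells stand in for LSTMs, there are no biases, the weights are random, and the transitional layer is assumed to be a dense layer with a tanh, since the text doesn’t fix its exact form. All names are hypothetical.

```python
import math
import random

Matrix = list[list[float]]


def matvec(m: Matrix, v: list[float]) -> list[float]:
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def rand_matrix(rows: int, cols: int, rng: random.Random) -> Matrix:
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]


def init_params(emb_dim: int, hidden: int, seed: int = 0) -> dict[str, Matrix]:
    rng = random.Random(seed)
    return {
        "w_t": rand_matrix(emb_dim, emb_dim, rng),    # transitional layer
        "w_xh_f": rand_matrix(hidden, emb_dim, rng),  # forward RNN
        "w_hh_f": rand_matrix(hidden, hidden, rng),
        "w_xh_b": rand_matrix(hidden, emb_dim, rng),  # backward RNN
        "w_hh_b": rand_matrix(hidden, hidden, rng),
    }


def transitional(xs: list[list[float]], w: Matrix) -> list[list[float]]:
    # The same dense layer + tanh, applied to each word embedding independently.
    return [[math.tanh(v) for v in matvec(w, x)] for x in xs]


def run_rnn(xs: list[list[float]], w_xh: Matrix, w_hh: Matrix) -> list[float]:
    h = [0.0] * len(w_hh)
    for x in xs:
        h = [math.tanh(a + b) for a, b in zip(matvec(w_xh, x), matvec(w_hh, h))]
    return h


def bidirectional_encode(embeddings: list[list[float]], p: dict[str, Matrix]) -> list[float]:
    xs = transitional(embeddings, p["w_t"])
    forward = run_rnn(xs, p["w_xh_f"], p["w_hh_f"])         # reads left to right
    backward = run_rnn(xs[::-1], p["w_xh_b"], p["w_hh_b"])  # reads right to left
    return forward + backward  # list concatenation: 2 * hidden features


params = init_params(emb_dim=3, hidden=4)
print(bidirectional_encode([[0.1, 0.2, 0.3], [0.0, -0.1, 0.4]], params))
```

Because the transitional layer treats every token independently, it can learn to amplify or damp individual embeddings before the recurrent layers see them, which is where its resemblance to a simple attention layer comes from.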

Convolutional neural network (CNN)

An alternative way to train a deep text classifier is to use convolutional networks. Typically, given a large enough receptive field, you can achieve the same results as with a dedicated attention layer. There’s no single trick here, but keeping a lot of feature maps in the beginning and reducing their number exponentially later helps to avoid learning irrelevant patterns.
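The sketch below shows the shape of that advice: stacked 1-D convolutions over the embedding sequence, with the number of feature maps halving at each layer, followed by max-over-time pooling into a fixed-size vector. The (64, 32, 16) schedule, the kernel width of 3, the ReLU, and the choice of max pooling are illustrative assumptions rather than settings from the text; weights are random and there are no biases, so this shows data flow, not training.

```python
import random

Seq = list[list[float]]  # T vectors, channels-last


def conv1d_relu(xs: Seq, kernels: list[Seq]) -> Seq:
    """Valid (unpadded) 1-D convolution followed by ReLU.

    Each kernel is a list of `width` weight vectors matching the input
    channel count; the output has len(xs) - width + 1 positions and one
    channel per kernel.
    """
    width = len(kernels[0])
    out = []
    for t in range(len(xs) - width + 1):
        window = xs[t:t + width]
        out.append([
            max(0.0, sum(w * x
                         for w_row, x_row in zip(kernel, window)
                         for w, x in zip(w_row, x_row)))
            for kernel in kernels
        ])
    return out


def random_kernels(n_maps: int, width: int, in_channels: int, rng: random.Random) -> list[Seq]:
    return [[[rng.uniform(-0.5, 0.5) for _ in range(in_channels)] for _ in range(width)]
            for _ in range(n_maps)]


def cnn_features(embeddings: Seq, widths=(3, 3, 3), maps=(64, 32, 16), seed: int = 0) -> list[float]:
    """Stack conv layers with a shrinking number of feature maps, then max-pool over time."""
    rng = random.Random(seed)
    xs = embeddings
    for width, n_maps in zip(widths, maps):
        if len(xs) < width:
            raise ValueError("sequence too short for the stacked kernel widths")
        xs = conv1d_relu(xs, random_kernels(n_maps, width, len(xs[0]), rng))
    # Max-over-time pooling: one value per final feature map, whatever the text length.
    return [max(column) for column in zip(*xs)]


rng = random.Random(1)
emb = [[rng.uniform(-1, 1) for _ in range(8)] for _ in range(12)]
print(len(cnn_features(emb)))
```

The receptive field of stacked stride-1 convolutions is 1 + Σ(width − 1), so three width-3 layers let each final position see 7 consecutive tokens; adding layers or widening kernels is how you grow it toward the “large enough” range.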

Take a look at this simple implementation of a CNN classifier in PyTorch. It shows how to train and evaluate a convolutional classifier with its own embedding layer.

Dense Classifier

The fully connected part performs a series of transformations on the deep representation and finally outputs a score for each class. The best practice here is to apply the transformations as follows: