TLDR: In Active Learning we use a “human in the loop” approach to data labelling, drastically reducing the amount of data that needs to be labelled and making machine learning applicable when labelling costs would otherwise be too high. In our paper we present BatchBALD: a new practical method for choosing batches of informative points in Deep Active Learning which avoids the labelling redundancies that plague existing methods. Our approach is grounded in information theory and builds on useful intuitions. We have also made our implementation available on GitHub at https://github.com/BlackHC/BatchBALD.

What’s Active Learning?

Using deep learning and a large labelled dataset, we are able to obtain excellent performance on a range of important tasks. Often, however, we only have access to a large unlabelled dataset. For example, it is easy to acquire lots of stock photos, but labelling these images is time-consuming and expensive. This excludes many applications from benefiting from recent advances in deep learning.

In Active Learning we only ask experts to label the most informative data points instead of labelling the whole dataset upfront. The model is then retrained using these newly acquired data points and all previously labelled data points. This process is repeated until we are happy with the accuracy of our model.

Figure 1: Active learning loop. The active learning steps of retraining, scoring, labelling, and acquisition are repeated until the model has sufficient accuracy.

To perform Active Learning, we need a measure of informativeness, usually in the form of an acquisition function. It is called an “acquisition function” because the score it computes determines which data points we want to acquire. We send the unlabelled data points that maximise the acquisition function to an expert and ask for labels.
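The loop in figure 1 can be sketched as a small Python function. This is a minimal sketch, not the authors' implementation: `train`, `acquisition_fn` and `oracle` are hypothetical callables standing in for model training, the acquisition function and the human expert, and for simplicity the loop runs for a fixed number of rounds instead of stopping once the model is accurate enough.

```python
from typing import Callable, Dict, Hashable, Sequence


def active_learning_loop(
    pool: Sequence[Hashable],
    train: Callable[[Dict], object],
    acquisition_fn: Callable[[object, Hashable], float],
    oracle: Callable[[Hashable], int],
    acquisition_size: int,
    n_rounds: int,
) -> Dict:
    """Retrain, score, acquire the top points and label them, for n_rounds rounds."""
    labelled: Dict = {}
    unlabelled = list(pool)
    for _ in range(n_rounds):
        if not unlabelled:
            break
        model = train(labelled)  # retrain on every label acquired so far
        scores = {x: acquisition_fn(model, x) for x in unlabelled}  # score the pool
        # acquire the highest-scoring points ...
        batch = sorted(unlabelled, key=scores.__getitem__, reverse=True)[:acquisition_size]
        for x in batch:
            labelled[x] = oracle(x)  # ... and ask the expert for their labels
        unlabelled = [x for x in unlabelled if x not in labelled]
    return labelled


# Toy usage: the "model" is just the set of labelled points; the score prefers larger numbers.
result = active_learning_loop(
    pool=range(10),
    train=lambda labelled: set(labelled),
    acquisition_fn=lambda model, x: float(x),
    oracle=lambda x: x % 2,
    acquisition_size=2,
    n_rounds=3,
)
print(sorted(result))  # [4, 5, 6, 7, 8, 9]
```

The interesting design decision lives entirely in `acquisition_fn`; the rest of the loop stays the same whichever acquisition function is plugged in.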

The problem is…

Usually, the informativeness of unlabelled points is assessed individually, with one popular acquisition function being BALD. However, assessing informativeness individually can lead to extreme waste because a single informative point can have lots of (near-identical) copies. This means that if we naively acquire the top-K most informative points, we might end up asking an expert to label K near-identical points!

Figure 2: BALD scores (informativeness) for 1000 randomly-chosen points from the MNIST dataset (hand-written digits). The points are colour-coded by digit label and sorted by score. The model used for scoring has been trained to 90% accuracy first. If we were to pick the top scoring points (e.g. scores above 0.6), most of them would be 8s, even though we can assume that after acquiring the first couple of them our model would consider them less informative than other available data. Points are slightly shifted on the x-axis by digit label to avoid overlaps.
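To make this failure mode concrete, here is a toy illustration with made-up scores (the point names and values are hypothetical): when the most informative point has near-identical copies, taking the top-K individual scores fills the whole batch with those copies.

```python
from typing import List, Sequence


def top_k(scores: Sequence[float], k: int) -> List[int]:
    """Indices of the k highest individual scores: the naive batch acquisition."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]


# Hypothetical pool: "A" is the most informative point and has three near-identical
# copies, while "B" and "C" are distinct points that are only slightly less informative.
names = ["A", "A_copy1", "A_copy2", "A_copy3", "B", "C"]
scores = [0.90, 0.89, 0.885, 0.88, 0.70, 0.65]

picked = [names[i] for i in top_k(scores, 3)]
print(picked)  # ['A', 'A_copy1', 'A_copy2']
```

Because each score is computed in isolation, nothing in `top_k` can notice that the second and third picks add almost nothing once `A` is labelled.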

Our contribution

In our work, we efficiently expand the notion of acquisition functions to batches (sets) of data points and develop a new acquisition function that takes into account similarities between data points when acquiring a batch. For this, we take the commonly-used BALD acquisition function and extend it to BatchBALD in a grounded way, which we will explain below.

Figure 3: Idealised acquisitions of BALD and BatchBALD. If a dataset were to contain many (near) replicas for each data point, then BALD would select all replicas of a single informative data point at the expense of other informative data points, reducing data efficiency.

However, knowing how to score batches of points is not sufficient! We still have the challenge of finding the batch with the highest score. The naive solution would be to try all subsets of data points, but that wouldn’t work because there are exponentially many possibilities.

We show that our acquisition function satisfies a very useful property called submodularity, which allows us to follow a greedy approach: selecting points one by one, and conditioning each new point on all points previously added to the batch. Using the submodularity property, we can show that this greedy approach finds a subset that is “good enough” (i.e. $1 - 1/e$-approximate).

Overall, this leads our acquisition function BatchBALD to outperform BALD: it needs fewer iterations and fewer data points to reach high accuracy for similar batch sizes, significantly reducing redundant model retraining and expert labelling, hence cost and time.

Moreover, it is empirically as good as, but much faster than, the optimal choice of acquiring individual points sequentially, where we retrain the model after every single point acquisition.

Figure 4: Performance and training duration of BALD and BatchBALD on MNIST. BatchBALD with acquisition size 10 performs about as well as BALD with acquisition size 1, but it only requires a fraction of the time because it needs to retrain the model fewer times. Compared to BALD with acquisition size 10, BatchBALD also requires fewer acquisitions to reach 95% accuracy.
(a) Performance on MNIST. BatchBALD outperforms BALD with acquisition size 10 and performs close to the optimum of acquisition size 1.
(b) Relative total time on MNIST, normalised to training BatchBALD with acquisition size 10 to 95% accuracy. The stars mark when 95% accuracy is reached for each method.

Before we explain our acquisition function, however, we need to understand what the BALD acquisition function does.

What’s BALD?

BALD stands for “Bayesian Active Learning by Disagreement”.

As the “Bayesian” in the name tells us, this assumes a Bayesian setting which allows us to capture uncertainties in the predictions of our model. In a Bayesian model, the parameters are not just numbers (point estimates) that get updated during training but probability distributions.

This allows the model to quantify its beliefs: a wide distribution for a parameter means that the model is uncertain about its true value, whereas a narrow one expresses high certainty.

BALD scores a data point $x$ based on how well the model’s predictions $y$ inform us about the model parameters $\boldsymbol{\omega}$. For this, it computes the mutual information $\mathbb{I}(y; \boldsymbol{\omega})$. Mutual information is well-known in information theory and captures the information overlap between quantities.
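In practice the posterior $p(\boldsymbol{\omega} \mid \mathcal{D}_{\mathrm{train}})$ is approximated with $K$ sampled parameter sets (for example via MC dropout, discussed below), and the mutual information can then be estimated as the entropy of the averaged prediction minus the average entropy of the individual predictions. The sketch below assumes the class probabilities for one point are already available as a list of $K$ probability vectors; it uses natural logarithms, so scores are in nats.

```python
import math
from typing import Sequence


def entropy(p: Sequence[float]) -> float:
    """Shannon entropy in nats; zero-probability classes contribute nothing."""
    return -sum(q * math.log(q) for q in p if q > 0)


def bald_score(probs: Sequence[Sequence[float]]) -> float:
    """Estimate I(y; omega | x, D_train) from probs[k][c] = p(y = c | x, omega_k)."""
    K, C = len(probs), len(probs[0])
    mean = [sum(p[c] for p in probs) / K for c in range(C)]  # p(y | x, D_train)
    expected_entropy = sum(entropy(p) for p in probs) / K  # E_omega[H(y | x, omega)]
    return entropy(mean) - expected_entropy
```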

When using the BALD acquisition function to select a batch of $b$ points, we select the top-$b$ points with the highest BALD scores, which is standard practice in the field. This is the same as maximising the following batch acquisition function

$$a_{\mathrm{BALD}}\left(\{x_1, \ldots, x_b\}, p(\boldsymbol{\omega} \mid \mathcal{D}_{\mathrm{train}})\right) := \sum_{i=1}^{b} \mathbb{I}\left(y_i; \boldsymbol{\omega} \mid x_i, \mathcal{D}_{\mathrm{train}}\right)$$

with

$$\{x_1^*, \ldots, x_b^*\} := \underset{\{x_1, \ldots, x_b\} \subseteq \mathcal{D}_{\mathrm{pool}}}{\arg\max}\; a_{\mathrm{BALD}}\left(\{x_1, \ldots, x_b\}, p(\boldsymbol{\omega} \mid \mathcal{D}_{\mathrm{train}})\right).$$

Intuitively, if we think of the information content of the predictions and of the model parameters as sets, the mutual information corresponds to the intersection of these sets, which captures the notion that mutual information measures information overlap.
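Because $a_{\mathrm{BALD}}$ is a sum of per-point terms, the top-$b$ rule and the arg max over all size-$b$ subsets coincide. A quick brute-force check with made-up per-point scores (the score values and helper names are hypothetical):

```python
from itertools import combinations
from typing import List, Sequence, Tuple


def top_b(scores: Sequence[float], b: int) -> List[int]:
    """Standard practice: take the b points with the highest individual BALD scores."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:b]


def argmax_sum(scores: Sequence[float], b: int) -> Tuple[int, ...]:
    """Brute force over all size-b subsets of the pool, maximising a_BALD = sum of scores."""
    return max(combinations(range(len(scores)), b), key=lambda s: sum(scores[i] for i in s))


scores = [0.3, 0.9, 0.1, 0.7, 0.5]  # hypothetical per-point BALD scores
print(sorted(top_b(scores, 2)), list(argmax_sum(scores, 2)))  # [1, 3] [1, 3]
```

The brute force grows combinatorially with the pool size, which is why the sum form is convenient; it is also exactly why BALD cannot account for overlap between the chosen points.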

Figure 5: Intuition behind BALD. Areas in grey contribute to the BALD score. Areas in dark grey are double-counted.

In fact, Yeung shows that this intuition is well-grounded: we can define an information measure $\mu^*$ that allows us to express information-theoretic quantities using set operations:

$$\begin{aligned} \mathbb{H}(x, y) &= \mu^*(x \cup y) \\ \mathbb{I}(x; y) &= \mu^*(x \cap y) \\ \mathbb{E}_{p(y)}\, \mathbb{H}(x \mid y) &= \mu^*(x \setminus y) \end{aligned}$$

Figure 5 visualises the scores that BALD computes as the areas of the intersections of these sets when acquiring a batch of 3 points. Because BALD is a simple sum, mutual information shared between data points is double-counted, and BALD overestimates the true mutual information. This is why naively using BALD on a dataset with lots of (near-identical) copies of the same point will lead us to select all the copies: we double-count the mutual information in the intersection between all of them!
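For two points, inclusion-exclusion on $\mu^*$ makes the double-counting explicit (conditioning on $x_i$ and $\mathcal{D}_{\mathrm{train}}$ is dropped for brevity):

$$
\begin{aligned}
\mathbb{I}(y_1; \boldsymbol{\omega}) + \mathbb{I}(y_2; \boldsymbol{\omega})
&= \mu^*(y_1 \cap \boldsymbol{\omega}) + \mu^*(y_2 \cap \boldsymbol{\omega}) \\
&= \mu^*\big((y_1 \cup y_2) \cap \boldsymbol{\omega}\big) + \mu^*(y_1 \cap y_2 \cap \boldsymbol{\omega}) \\
&= \mathbb{I}(y_1, y_2; \boldsymbol{\omega}) + \mathbb{I}(y_1; y_2; \boldsymbol{\omega}).
\end{aligned}
$$

The last term is the interaction information $\mathbb{I}(y_1; y_2) - \mathbb{I}(y_1; y_2 \mid \boldsymbol{\omega})$. Assuming the predictions for different points are conditionally independent given $\boldsymbol{\omega}$ (which holds when the likelihood factorises over data points), $\mathbb{I}(y_1; y_2 \mid \boldsymbol{\omega}) = 0$, so the double-counted term equals $\mathbb{I}(y_1; y_2) \ge 0$: the more two predictions tell us about each other, as for near-identical copies, the more BALD overcounts.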

Figure 6: Intuition behind BatchBALD. BatchBALD takes into account similarities between the data points.

BatchBALD

In order to avoid double-counting, we want to compute the quantity $\mu^*\left(\left(\bigcup_i y_i\right) \cap \boldsymbol{\omega}\right)$, as depicted in figure 6, which corresponds to the mutual information $\mathbb{I}(y_1, \ldots, y_b; \boldsymbol{\omega} \mid x_1, \ldots, x_b, \mathcal{D}_{\mathrm{train}})$ between the joint of the $y_i$ and $\boldsymbol{\omega}$:

$$a_{\mathrm{BatchBALD}}\left(\{x_1, \ldots, x_b\}, p(\boldsymbol{\omega} \mid \mathcal{D}_{\mathrm{train}})\right) := \mathbb{I}\left(y_1, \ldots, y_b; \boldsymbol{\omega} \mid x_1, \ldots, x_b, \mathcal{D}_{\mathrm{train}}\right).$$

Expanding the definition of the mutual information, we obtain the difference between the following two terms:

$$a_{\mathrm{BatchBALD}}\left(\{x_1, \ldots, x_b\}, p(\boldsymbol{\omega} \mid \mathcal{D}_{\mathrm{train}})\right) = \mathbb{H}\left(y_1, \ldots, y_b \mid x_1, \ldots, x_b, \mathcal{D}_{\mathrm{train}}\right) - \mathbb{E}_{p(\boldsymbol{\omega} \mid \mathcal{D}_{\mathrm{train}})}\left[\mathbb{H}\left(y_1, \ldots, y_b \mid x_1, \ldots, x_b, \boldsymbol{\omega}\right)\right].$$

The first term captures the general uncertainty of the model. The second term captures the expected uncertainty for a given draw of the model parameters.
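For small batches, both terms can be estimated from $K$ sampled parameter sets. The sketch below is a brute-force estimator, not the paper's implementation: it assumes the class probabilities `probs[i][k][c]` of every point were computed with the same $K$ parameter samples, uses the conditional independence of the $y_i$ given $\boldsymbol{\omega}$ to turn the second term into a sum, and enumerates all $C^b$ label combinations for the first term, so it is only practical for small $b$. The example predictions are hypothetical.

```python
import itertools
import math
from typing import Sequence


def entropy(p: Sequence[float]) -> float:
    """Shannon entropy in nats."""
    return -sum(q * math.log(q) for q in p if q > 0)


def batchbald_score(probs: Sequence[Sequence[Sequence[float]]]) -> float:
    """Estimate I(y_1..y_b; omega | x_1..x_b, D_train).

    probs[i][k][c] = p(y_i = c | x_i, omega_k) for b points, K shared samples, C classes.
    """
    b, K, C = len(probs), len(probs[0]), len(probs[0][0])
    # First term: entropy of p(y_1..y_b) ~= 1/K sum_k prod_i p(y_i | x_i, omega_k),
    # enumerating every joint label configuration (C**b of them).
    joint = [
        sum(math.prod(probs[i][k][ys[i]] for i in range(b)) for k in range(K)) / K
        for ys in itertools.product(range(C), repeat=b)
    ]
    # Second term: given omega the y_i are independent, so the joint entropy is a sum.
    expected_conditional = sum(entropy(probs[i][k]) for i in range(b) for k in range(K)) / K
    return entropy(joint) - expected_conditional


# Hypothetical predictions of 4 parameter samples for two binary-labelled points.
x = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
z = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]
print(batchbald_score([x]), batchbald_score([x, x]), batchbald_score([x, z]))
# ≈ 0.693, 0.693, 1.386 (log 2, log 2, 2 log 2)
```

The duplicated pair `[x, x]` scores only $\log 2$, whereas summing the two individual BALD scores would give $2 \log 2$; the pair `[x, z]`, whose predictions vary independently across parameter samples, keeps the full $2 \log 2$.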

We can see that the score is going to be large when the model has different explanations for the data point that it is confident about individually (yielding a small second term) but the predictions are disagreeing with each other (yielding a large first term), hence the “by Disagreement” in the name.
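As a worked example for a single point, take two hypothetical, equally weighted parameter samples whose confident predictions contradict each other:

$$
\begin{aligned}
p(y \mid x, \boldsymbol{\omega}_1) = (1, 0), \qquad p(y \mid x, \boldsymbol{\omega}_2) &= (0, 1) \\
\mathbb{H}(y \mid x, \mathcal{D}_{\mathrm{train}}) = \mathbb{H}\big(\tfrac{1}{2}, \tfrac{1}{2}\big) &= \log 2 \\
\mathbb{E}_{p(\boldsymbol{\omega} \mid \mathcal{D}_{\mathrm{train}})}\big[\mathbb{H}(y \mid x, \boldsymbol{\omega})\big] = \tfrac{1}{2} \cdot 0 + \tfrac{1}{2} \cdot 0 &= 0 \\
\mathbb{I}(y; \boldsymbol{\omega} \mid x, \mathcal{D}_{\mathrm{train}}) = \log 2 - 0 &= \log 2 \approx 0.69 \text{ nats}
\end{aligned}
$$

If instead both samples predicted $(\tfrac{1}{2}, \tfrac{1}{2})$, both terms would equal $\log 2$ and the score would be $0$: the model is uncertain, but its explanations do not disagree.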

Submodularity

Now to determine which data points to acquire, we are going to use submodularity.

Given a set function $f: 2^{\Omega} \to \mathbb{R}$ on a ground set $\Omega$, we call $f$ submodular if

$$f(A \cup \{x, y\}) - f(A) \le \left(f(A \cup \{x\}) - f(A)\right) + \left(f(A \cup \{y\}) - f(A)\right)$$

for all $A \subseteq \Omega$ and elements $x, y \in \Omega$. Submodularity tells us that there are diminishing returns: selecting two points increases the score more than adding either one of them individually, but less than the separate improvements together. We show in appendix A of the paper that our acquisition function fulfils this property.
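The definition can be checked by brute force on a small ground set. This sketch is only an illustration (the paper proves the property analytically): it tests the submodularity inequality for every $A$ and every pair of distinct $x, y \notin A$, here on a hypothetical coverage function, a textbook example of a submodular function, and on $|S|^2$, which is not submodular.

```python
from itertools import combinations
from typing import Callable, FrozenSet, Iterable


def is_submodular(f: Callable[[FrozenSet], float], ground: Iterable, tol: float = 1e-12) -> bool:
    """Check f(A ∪ {x, y}) - f(A) <= (f(A ∪ {x}) - f(A)) + (f(A ∪ {y}) - f(A))."""
    ground = list(ground)
    for r in range(len(ground) - 1):
        for subset in combinations(ground, r):
            A = frozenset(subset)
            rest = [e for e in ground if e not in A]
            for x, y in combinations(rest, 2):
                f_A = f(A)
                gain_xy = f(A | {x, y}) - f_A
                gain_x = f(A | {x}) - f_A
                gain_y = f(A | {y}) - f_A
                if gain_xy > gain_x + gain_y + tol:
                    return False
    return True


# Hypothetical points, each "covering" some features.
covers = {"a": {1, 2}, "b": {2, 3}, "c": {3, 4, 5}, "d": {1, 5}}


def coverage(S: FrozenSet) -> int:
    """Number of distinct features covered by the points in S."""
    return len(set().union(*(covers[e] for e in S)))


def squared_size(S: FrozenSet) -> int:
    """|S|**2 rewards pairs more than singles, so it violates the inequality."""
    return len(S) ** 2


print(is_submodular(coverage, covers), is_submodular(squared_size, covers))  # True False
```

The same checker could be pointed at any small-pool estimator of a batch acquisition function as a sanity test, though an exhaustive check is only feasible for a handful of points.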

Nemhauser et al. have shown that, for monotone (non-decreasing) submodular functions, one can use a greedy algorithm to pick points with a guarantee that their score is at least $1 - 1/e \approx 63\%$ as good as the optimal one. Such an algorithm is called $1 - 1/e$-approximate.

The greedy algorithm starts with an empty batch $A = \{\}$ and computes $a_{\mathrm{BatchBALD}}(A \cup \{x\})$ for all unlabelled data points $x$, adds the highest-scoring $x$ to $A$, and repeats this process until $A$ reaches the acquisition size.

This is explained in more detail in the paper.
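A minimal sketch of this greedy loop in Python, under stated assumptions: it scores candidate batches with a brute-force estimator of $a_{\mathrm{BatchBALD}}$ that enumerates all joint label configurations from $K$ shared parameter samples (only feasible for small batches), and the predictions in the example are hypothetical. For clarity it recomputes the joint score from scratch for every candidate; a practical implementation would reuse the joint distribution of the points already in the batch.

```python
import itertools
import math
from typing import List, Sequence

Probs = Sequence[Sequence[float]]  # probs[k][c] = p(y = c | x, omega_k) for one point


def entropy(p: Sequence[float]) -> float:
    return -sum(q * math.log(q) for q in p if q > 0)


def batchbald_score(batch: Sequence[Probs]) -> float:
    """Brute-force estimate of I(y_1..y_b; omega) from K shared parameter samples."""
    K, C = len(batch[0]), len(batch[0][0])
    joint = [
        sum(math.prod(p[k][y] for p, y in zip(batch, ys)) for k in range(K)) / K
        for ys in itertools.product(range(C), repeat=len(batch))
    ]
    expected_conditional = sum(entropy(p[k]) for p in batch for k in range(K)) / K
    return entropy(joint) - expected_conditional


def greedy_batchbald(pool: Sequence[Probs], acquisition_size: int) -> List[int]:
    """Start from A = {} and repeatedly add the point x maximising a_BatchBALD(A ∪ {x})."""
    batch: List[int] = []
    candidates = list(range(len(pool)))
    while candidates and len(batch) < acquisition_size:
        best = max(candidates, key=lambda i: batchbald_score([pool[j] for j in batch + [i]]))
        batch.append(best)
        candidates.remove(best)
    return batch


x = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
z = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]
pool = [x, x, z]  # index 1 is an exact copy of index 0
print(greedy_batchbald(pool, 2))  # [0, 2]
```

All three points score $\log 2$ on their own, so top-$b$ BALD has no reason to avoid taking both copies; the greedy BatchBALD step conditions on the first pick and prefers `z`.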

Consistent MC Dropout

We implement Bayesian neural networks using MC dropout. However, as an important difference from other implementations, we require consistent MC dropout: to be able to compute the joint entropies between data points, we need to compute $a_{\mathrm{BatchBALD}}$ using the same sampled model parameters for every data point.

To see why, we investigated how the scores change when different sets of sampled model parameters are used for MC dropout inference (figure 7).

Without consistent MC dropout, scores would be computed using different sets of sampled model parameters, losing the correlations between the predictions $y_i$ for nearby $x_i$, and, given the spread of the scores, acquisition would essentially be no better than random.

Figure 7: BatchBALD scores for different sets of 100 sampled model parameters. This shows the BatchBALD scores for 1000 randomly picked points from the pool set while selecting the 10th point in a batch, for an MNIST model that has already reached 90% accuracy. The scores for a single set of 100 model parameters are shown in blue. The BatchBALD estimates show strong banding: the score differences between different sets of sampled parameters are larger than the differences between different data points for a given set (within a single band, or “trajectory”).
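The difference between consistent and per-point MC dropout comes down to whether the dropout masks are sampled once and shared by every pool point. The sketch below illustrates this with a hypothetical one-layer logistic model with dropout applied to its inputs (real MC dropout networks drop hidden units as well); `mc_predictions` and its parameters are illustrative names, not the authors' code.

```python
import math
import random
from typing import List, Sequence


def predict(weights: Sequence[float], x: Sequence[float], mask: Sequence[int], p_drop: float) -> float:
    """Hypothetical tiny model: logistic regression on dropout-masked, rescaled inputs."""
    z = sum(w * xi * m / (1 - p_drop) for w, xi, m in zip(weights, x, mask))
    return 1 / (1 + math.exp(-z))


def sample_masks(rng: random.Random, K: int, dim: int, p_drop: float) -> List[List[int]]:
    """K dropout masks; each entry is dropped (0) with probability p_drop."""
    return [[0 if rng.random() < p_drop else 1 for _ in range(dim)] for _ in range(K)]


def mc_predictions(weights, pool, K, p_drop=0.5, seed=0, consistent=True):
    """Return preds[i][k]: the prediction for pool point i under its k-th dropout mask."""
    rng = random.Random(seed)
    dim = len(weights)
    shared = sample_masks(rng, K, dim, p_drop) if consistent else None
    preds = []
    for x in pool:
        # consistent: every point reuses the same K masks; otherwise each point gets fresh ones
        masks = shared if consistent else sample_masks(rng, K, dim, p_drop)
        preds.append([predict(weights, x, m, p_drop) for m in masks])
    return preds


weights = [0.1, 0.2, 0.4, 0.8]
pool = [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]  # two identical points
same = mc_predictions(weights, pool, K=20, consistent=True)
fresh = mc_predictions(weights, pool, K=20, consistent=False)
print(same[0] == same[1], fresh[0] == fresh[1])
```

With shared masks, the two copies get identical prediction vectors across samples, so a joint-entropy estimate can see that they are redundant; with fresh masks per point, that correlation is lost.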

Experiments on MNIST, Repeated MNIST and EMNIST

We have run experiments on classifying EMNIST, which is a dataset of handwritten letters and digits consisting of 47 classes and 120000 data points.

Figure 8: Examples of all 47 classes of EMNIST.

We show an improvement over BALD, which performs worse than even random acquisition when acquiring large batches:

Figure 9: Performance on EMNIST. BatchBALD consistently outperforms both random acquisition and BALD while BALD is unable to beat random acquisition.

This is because, compared to BatchBALD and random acquisition, BALD actively selects redundant points. To understand this better, we can look at the acquired class labels and compute the entropy of their distribution. The higher the entropy, the more diverse the acquired labels are:

Figure 10: Entropy of acquired class labels over acquisition steps on EMNIST. BatchBALD steadily acquires a more diverse set of data points.
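The label-diversity measure behind figure 10 can be computed as the entropy of the empirical distribution of acquired labels. A minimal sketch, assuming natural logarithms (the base used for the figure is not stated here) and hypothetical label lists:

```python
import math
from collections import Counter
from typing import Hashable, Sequence


def label_entropy(labels: Sequence[Hashable]) -> float:
    """Entropy (nats) of the empirical distribution of acquired class labels."""
    n = len(labels)
    # sum of p * log(1 / p) over the observed classes
    return sum(c / n * math.log(n / c) for c in Counter(labels).values())


redundant = [8] * 20          # every acquired point has the same label
diverse = list(range(4)) * 5  # 4 classes acquired equally often
print(label_entropy(redundant), label_entropy(diverse))  # 0.0 and ≈ 1.386 (= log 4)
```

The maximum possible value is the log of the number of classes, reached only when every class is acquired equally often.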

We can also look at the actual distribution of acquired classes at the end of training, and see that BALD undersamples some classes while BatchBALD manages to pick data points from different classes more uniformly (without knowing the classes, of course). Random acquisition also picks classes more uniformly than BALD, but not as well as BatchBALD.

Figure 14: Histogram of acquired class labels on EMNIST. BatchBALD left, random acquisition center, and BALD right. Classes are sorted by number of acquisitions. Several EMNIST classes are underrepresented in BALD and random acquisition while BatchBALD acquires classes more uniformly. The histograms were created from all acquired points.

Figure 11: Histogram of acquired class labels on EMNIST. BatchBALD left and BALD right. Classes are sorted by number of acquisitions, and only the lower half is shown for clarity. Several EMNIST classes are underrepresented in BALD while BatchBALD acquires classes more uniformly. The histograms were created from all acquired points.

To see how much better BatchBALD copes with pathological cases, we also experimented with a version of MNIST that we call Repeated MNIST. It is simply MNIST repeated 3 times with some added Gaussian noise, and it shows how BALD falls into a trap: picking the top $b$ individual points is detrimental because there are too many similar points. But BALD is not the only acquisition function to fail in this regime.

Figure 15: Performance on Repeated MNIST. BALD, BatchBALD, Var Ratios, Mean STD and random acquisition: acquisition size 10 with 10 MC dropout samples.

Figure 12: Performance on Repeated MNIST with acquisition size 10. BatchBALD outperforms BALD while BALD performs worse than random acquisition due to the replications in the dataset.
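Constructing a Repeated MNIST-style dataset takes only a few lines. This is a sketch under assumptions: the dataset is a list of `(pixels, label)` pairs, each copy gets independent per-pixel Gaussian noise, and `noise_std=0.1` is a placeholder value, since the exact noise scale is not given here.

```python
import random
from typing import List, Sequence, Tuple

Example = Tuple[List[float], int]


def make_repeated(dataset: Sequence[Tuple[Sequence[float], int]], repeats: int = 3,
                  noise_std: float = 0.1, seed: int = 0) -> List[Example]:
    """Concatenate `repeats` copies of the dataset, adding N(0, noise_std**2) noise to every pixel."""
    rng = random.Random(seed)
    return [
        ([v + rng.gauss(0.0, noise_std) for v in pixels], label)
        for _ in range(repeats)
        for pixels, label in dataset
    ]


toy = [([0.0, 0.5, 1.0], 7), ([1.0, 0.0, 0.0], 3)]  # hypothetical 3-pixel "images"
repeated = make_repeated(toy)
print(len(repeated), [label for _, label in repeated])  # 6 [7, 3, 7, 3, 7, 3]
```

Each original image then appears three times with slightly different pixels, which is the kind of redundancy that scoring points one at a time cannot detect. The sketch does not clip pixel values to their valid range.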

We also played around with different acquisition sizes and found that on MNIST, BatchBALD can even acquire 40 points at a time with little loss of data efficiency while BALD deteriorates quickly.

Figure 13: Performance on MNIST for increasing acquisition sizes (panels: BALD and BatchBALD). BALD’s performance drops drastically as the acquisition size increases. BatchBALD maintains strong performance even with increasing acquisition size.

Final thoughts

We found it quite surprising that a standard acquisition function, used widely in active learning, performed worse even compared to a random baseline, when evaluated on batches of data. We enjoyed digging into the core of the problem, trying to understand why it failed, which led to some new insights about the way we use information theory tools in the field. In many ways, the true lesson here is that when something fails — pause and think.