HALP: High-Accuracy Low-Precision Training

by Chris De Sa, Megan Leszczynski, Jian Zhang, Alana Marzoev, Chris Aberger, Kunle Olukotun, and Chris Ré

Using fewer bits of precision to train machine learning models limits training accuracy. Or does it? This post describes cases in which we can get high-accuracy solutions using low-precision computation via a technique called bit centering, and our theory to explain what's going on.

Low-precision computation has been gaining a lot of traction in machine learning. Companies have even started developing new hardware architectures that natively support and accelerate low-precision operations, including Microsoft's Project Brainwave and Google's TPU. Even though using low precision can have a lot of systems benefits, low-precision methods have been used primarily for inference, not for training. Previous low-precision training algorithms suffered from a fundamental tradeoff: when calculations use fewer bits, more round-off error is added, which limits training accuracy. According to conventional wisdom, this tradeoff limits practitioners' ability to deploy low-precision training algorithms in their systems. But is this tradeoff really fundamental? Is it possible to design algorithms that use low precision without it limiting their accuracy?

It turns out that yes, it is sometimes possible to get high-accuracy solutions from low-precision training, and here we'll describe a new variant of stochastic gradient descent (SGD) called high-accuracy low-precision (HALP) that can do it. HALP can do better than previous algorithms because it reduces the two sources of noise that limit the accuracy of low-precision SGD: gradient variance and round-off error.

To reduce noise from gradient variance, HALP uses a known technique called stochastic variance-reduced gradient (SVRG). SVRG periodically uses full gradients to decrease the variance of the gradient samples used in SGD.
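To make the SVRG idea concrete, here is a minimal sketch in plain Python on a toy one-dimensional least-squares problem. The data, step size, epoch length, and the names `grad_i`, `full_grad`, and `svrg` are all assumptions chosen for illustration; this is not the setup used in the paper.

```python
import random

# Toy 1-D least squares: f_i(w) = 0.5 * (a_i * w - b_i)^2 (illustrative data).
rng = random.Random(0)
N = 50
a = [rng.uniform(0.5, 1.5) for _ in range(N)]
b = [2.0 * ai + rng.gauss(0.0, 0.5) for ai in a]

def grad_i(w, i):
    """Gradient of the i-th loss term at w."""
    return a[i] * (a[i] * w - b[i])

def full_grad(w):
    """Gradient of the average loss f(w) = (1/N) * sum_i f_i(w)."""
    return sum(grad_i(w, i) for i in range(N)) / N

def svrg(w0, alpha=0.05, epochs=30, steps_per_epoch=100):
    w_tilde = w0
    for _ in range(epochs):
        g_tilde = full_grad(w_tilde)      # full gradient at the anchor point
        w = w_tilde
        for _ in range(steps_per_epoch):
            i = rng.randrange(N)
            # Variance-reduced sample: its expectation over i is full_grad(w),
            # and its variance shrinks as w and w_tilde approach the optimum.
            v = grad_i(w, i) - grad_i(w_tilde, i) + g_tilde
            w -= alpha * v
        w_tilde = w                       # last iterate becomes the new anchor
    return w_tilde

# Closed-form minimizer of the toy problem, used only to check the result.
w_star = sum(ai * bi for ai, bi in zip(a, b)) / sum(ai * ai for ai in a)
w_svrg = svrg(0.0)
```

The correction term `- grad_i(w_tilde, i) + g_tilde` is what distinguishes this from plain SGD: it costs one extra full gradient per epoch and leaves the expected update direction unchanged.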

To reduce noise from quantizing numbers into a low-precision representation, HALP uses a new technique we call bit centering. The intuition behind bit centering is that as we get closer to the optimum, the gradient gets smaller in magnitude and in some sense carries less information, so we should be able to compress it. By dynamically re-centering and re-scaling our low-precision numbers, we can lower the quantization noise as the algorithm converges.

Why was low-precision SGD limited?

First, to set the stage: we want to solve training problems of the form \[ \text{minimize } f(w) = \frac{1}{N} \sum_{i=1}^N f_i(w) \text{ over } w \in \mathbb{R}^d. \] This is the classic empirical risk minimization problem used to train many machine learning models, including deep neural networks. One standard way of solving this is with stochastic gradient descent, which is an iterative algorithm that approaches the optimum by running \[ w_{t+1} = w_t - \alpha \nabla f_{i_t}(w_t) \] where \( \alpha \) is the step size and \( i_t \) is an index randomly chosen from \( \{1, \ldots, N\} \) at each iteration. We want to run an algorithm like this, but to make the iterates \( w_t \) low-precision. That is, we want them to use fixed-point arithmetic with a small number of bits, typically 8 or 16 bits (this is small compared with the 32-bit or 64-bit floating point numbers that are standard for these algorithms). But when this is done directly to the SGD update rule, we run into a representation problem: the solution to the problem \( w^* \) may not be representable in the chosen fixed-point representation. For example, if we use an 8-bit fixed-point representation that can store the integers \( \{ -128, -127, \ldots, 127 \} \), and the true solution is \( w^* = 100.5 \), then we can't get any closer than a distance of \( 0.5 \) to the solution since we can't even represent non-integers. Beyond this, the round-off error that results from converting the gradients to fixed-point can slow down convergence. These effects together limit the accuracy of low-precision SGD.
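The representation problem in the example above can be reproduced in a few lines. This is a sketch under assumptions: the `quantize` helper is hypothetical, it uses round-to-nearest with saturation, and the objective is a single quadratic (so the loop is plain gradient descent rather than SGD).

```python
def quantize(x, delta=1.0, bits=8):
    """Round x to the nearest multiple of delta that fits in a signed `bits`-bit integer."""
    lo, hi = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    k = max(lo, min(hi, round(x / delta)))   # saturate instead of overflowing
    return k * delta

w_star = 100.5                               # optimum from the example in the text

# Best any 8-bit integer iterate can do: the nearest grid points are 100 and 101.
best_error = min(abs(k - w_star) for k in range(-128, 128))

# Low-precision gradient descent on f(w) = 0.5 * (w - w_star)^2,
# rounding the iterate back onto the integer grid after every step.
w, alpha = 0.0, 0.1
for _ in range(200):
    grad = w - w_star
    w = quantize(w - alpha * grad)
final_error = abs(w - w_star)
```

No matter how long the loop runs, `final_error` cannot drop below `best_error`, because every iterate is an integer. With round-to-nearest, the iterate can also stop moving once the step becomes smaller than half the grid spacing, which is one way round-off error hurts convergence.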

Bit Centering

When we are running SGD, in some sense what we are actually doing is averaging (or summing up) a bunch of gradient samples. The key idea behind bit centering is that as the gradients become smaller, we can average them with less error using the same number of bits. To see why, think about averaging a bunch of numbers in \([-100, 100]\) and compare this to averaging a bunch of numbers in \([-1, 1]\). In the former case, we'd need to choose a fixed-point representation that can cover the entire range \([-100, 100]\) (for example, \( \{ -128, -127, \ldots, 126, 127 \} \)), while in the latter case, we can choose one that covers \([-1, 1]\) (for example, \( \{ -\frac{128}{127}, -\frac{127}{127}, \ldots, \frac{126}{127}, \frac{127}{127} \} \)). This means that with a fixed number of bits, the delta, the difference between adjacent representable numbers, is smaller in the latter case than in the former: as a consequence, the round-off error will also be lower.

This observation gives us a key insight. To average the numbers in the range \([-1, 1]\) with less error than the ones in \([-100, 100]\), we needed to use a different fixed-point representation. This suggests that we should dynamically update the low-precision representation: as the gradients get smaller, we should use fixed-point numbers that have a smaller delta and cover a smaller range.
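The comparison between the two ranges can be checked numerically. The sketch below is illustrative only: `make_quantizer` and `mean_abs_roundoff` are hypothetical helpers, and round-to-nearest is an assumption. It rounds the same samples onto the two 8-bit grids from the text and measures the average round-off error.

```python
import random

def make_quantizer(delta, bits=8):
    """Build a round-to-nearest quantizer onto the grid {k * delta} with signed `bits`-bit k."""
    lo, hi = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    def q(x):
        k = max(lo, min(hi, round(x / delta)))
        return k * delta
    return q

def mean_abs_roundoff(values, q):
    """Average absolute error introduced by quantizing each value with q."""
    return sum(abs(q(v) - v) for v in values) / len(values)

rng = random.Random(0)
wide = [rng.uniform(-100, 100) for _ in range(10_000)]   # samples in [-100, 100]
narrow = [v / 100 for v in wide]                         # same samples scaled into [-1, 1]

err_wide = mean_abs_roundoff(wide, make_quantizer(1.0))          # grid {-128, ..., 127}
err_narrow = mean_abs_roundoff(narrow, make_quantizer(1 / 127))  # grid {-128/127, ..., 127/127}
```

Both grids use the same 8 bits. Only the spacing differs, and each sample's round-off error is at most half the spacing of the grid it is rounded onto.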

But how do we know how to update our representation? What range do we need to cover? Well, if our objective is strongly convex with parameter \( \mu \), then whenever we take a full gradient at some point \( w \), we can bound the location of the optimum with \[ \| w - w^* \| \le \frac{1}{\mu} \| \nabla f(w) \|. \] This inequality gives us a range of values in which the solution can be located, and so whenever we compute a full gradient, we can re-center and re-scale the low-precision representation to cover this range. This process is illustrated in the following figure.

We call this operation bit centering. Note that even if our objective is not strongly convex, we can still perform bit centering: \( \mu \) then becomes a hyperparameter of the algorithm. With periodic bit centering, the quantization error decreases as the algorithm converges, and it turns out that this can let it converge to arbitrarily accurate solutions.
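As a rough illustration of re-centering and re-scaling, the sketch below turns the bound \( \| w - w^* \| \le \| \nabla f(w) \| / \mu \) into a grid spacing. It assumes one particular choice, dividing the radius by the largest positive 8-bit integer, and a toy quadratic objective. The helper names are hypothetical, and the loop is idealized: it quantizes the exact offset to the optimum, which a real algorithm cannot do, purely to show what each re-centered grid is able to represent.

```python
import math

def bit_center(full_grad, mu, bits=8):
    """Return (radius, delta) for a grid covering the ball ||w - w*|| <= ||grad|| / mu."""
    radius = math.sqrt(sum(g * g for g in full_grad)) / mu
    delta = radius / (2 ** (bits - 1) - 1)
    return radius, delta

def quantize_offset(z, delta, bits=8):
    """Round each coordinate of the offset z onto {k * delta}, saturating at the ends."""
    if delta == 0.0:
        return [0.0 for _ in z]
    lo, hi = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    return [max(lo, min(hi, round(zi / delta))) * delta for zi in z]

# Toy objective f(w) = 0.5 * mu * ||w - w_star||^2, so grad f(w) = mu * (w - w_star).
mu = 1.0
w_star = [100.5, -3.25]
center = [0.0, 0.0]
deltas, errors = [], []
for _ in range(3):
    grad = [mu * (c - s) for c, s in zip(center, w_star)]
    radius, delta = bit_center(grad, mu)
    # Idealized step: quantize the exact offset to the optimum on the new grid.
    z = quantize_offset([s - c for c, s in zip(center, w_star)], delta)
    center = [c + zi for c, zi in zip(center, z)]   # re-center for the next round
    deltas.append(delta)
    errors.append(math.dist(center, w_star))
```

Each pass shrinks the gradient, which shrinks the radius, which shrinks `delta`, so the same 8 bits resolve the optimum more finely every time the grid is re-centered.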

HALP

HALP is our algorithm that runs SVRG and, at every epoch, uses the full gradient to perform bit centering and update the low-precision representation. The full details and algorithm statement are in the paper; here, we'll just present an overview of those results. First, we showed that for strongly convex, Lipschitz-smooth functions (the standard setting under which the convergence rate of SVRG was originally analyzed), as long as the number of bits \( b \) we use satisfies \[ 2^b > O\left(\kappa \sqrt{d} \right) \] where \( \kappa \) is the condition number of the problem, then for an appropriate setting of the step size and epoch length (details for how to set these are in the paper), HALP will converge at a linear rate to arbitrarily accurate solutions. More explicitly, for some \( 0 < \gamma < 1 \), \[ \mathbf{E}\left[ f(\tilde w_{K+1}) - f(w^*) \right] \le \gamma^K \left( f(\tilde w_1) - f(w^*) \right) \] where \( \tilde w_{K+1} \) denotes the value of the iterate after the \( K \)-th epoch. We can see this happening in the following figure.

This figure evaluates HALP on linear regression on a synthetic dataset with 100 features and 1000 examples. It compares HALP with baseline full-precision SGD and SVRG, low-precision SGD (LP-SGD), and a low-precision version of SVRG without bit centering (LP-SVRG). Notice that HALP converges to very high-accuracy solutions even with only 8 bits (although it is eventually limited by floating-point error). In this case, HALP converges to an even higher-accuracy solution than full-precision SVRG because HALP uses less floating-point arithmetic and is therefore less sensitive to floating-point inaccuracy.
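To show how SVRG and bit centering fit together, here is a heavily simplified sketch on a toy one-dimensional least-squares problem. It is not the algorithm statement from the paper. The assumptions are: only the offset `z` from the current center is stored on the 8-bit grid, gradients are computed in ordinary floating point, rounding is stochastic (unbiased), and \( \mu \) is treated as a hyperparameter set to half the true curvature. All names, data, and constants are illustrative.

```python
import math
import random

rng = random.Random(0)
N = 50
a = [rng.uniform(0.5, 1.5) for _ in range(N)]
b = [2.0 * ai + rng.gauss(0.0, 0.5) for ai in a]

def grad_i(w, i):
    """Gradient of f_i(w) = 0.5 * (a_i * w - b_i)^2."""
    return a[i] * (a[i] * w - b[i])

def full_grad(w):
    return sum(grad_i(w, i) for i in range(N)) / N

def stochastic_round(x, delta, bits):
    """Unbiased rounding of x onto {k * delta}, saturated to a signed `bits`-bit range."""
    lo, hi = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    scaled = x / delta
    k = math.floor(scaled)
    if rng.random() < scaled - k:     # round up with probability equal to the fractional part
        k += 1
    return max(lo, min(hi, k)) * delta

def halp_sketch(w0, mu, alpha=0.05, bits=8, epochs=30, steps_per_epoch=100):
    w_tilde = w0                      # center, kept in full precision
    for _ in range(epochs):
        g_tilde = full_grad(w_tilde)
        # Bit centering (re-scale): the grid covers roughly |z| <= |g_tilde| / mu.
        delta = abs(g_tilde) / (mu * (2 ** (bits - 1) - 1))
        if delta == 0.0:
            break                     # gradient is exactly zero: nothing left to do
        z = 0.0                       # low-precision offset from the center
        for _ in range(steps_per_epoch):
            i = rng.randrange(N)
            v = grad_i(w_tilde + z, i) - grad_i(w_tilde, i) + g_tilde   # SVRG sample
            z = stochastic_round(z - alpha * v, delta, bits)
        w_tilde += z                  # bit centering (re-center) for the next epoch
    return w_tilde

w_star = sum(ai * bi for ai, bi in zip(a, b)) / sum(ai * ai for ai in a)
mu_guess = 0.5 * sum(ai * ai for ai in a) / N   # hyperparameter: half the true curvature
w_halp = halp_sketch(0.0, mu_guess)
```

Within an epoch the inner loop only ever stores `z` as an 8-bit multiple of `delta`. Accuracy beyond 8 bits comes from the center `w_tilde` moving and `delta` shrinking from one epoch to the next.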

...and there's more!

This was only a selection of results: there's a lot more in the paper.

We showed that HALP matches SVRG's convergence trajectory, even for deep learning models.

We implemented HALP efficiently, and showed that it can run up to \( 4 \times \) faster than full-precision SVRG on the CPU.

We also implemented HALP in TensorQuant, a deep learning library, and showed that it can exceed the validation performance of plain low-precision SGD on some deep learning tasks.

The obvious but exciting next step is to implement HALP efficiently on low-precision hardware, following up on our work for the next generation of compute architectures (at ISCA 2017).