DeepMind's recent Neural Arithmetic Logic Unit (NALU) is a very neat idea: a simple module that gives neural nets a notion of numeracy. Contrary to popular belief, neural nets are not very good at arithmetic and counting (if they can do it at all). If you train an adder network on numbers between 0 and 10, it will do okay if you give it 3 + 5, but it won't be able to extrapolate and will fail miserably on 1000 + 3000. Similarly, if a net is trained to count up to 10, it won't be able to count to 20. The NALU, by contrast, is able to track time, perform arithmetic, translate numerical language into scalars, execute computer code, and count objects in images.

The central idea behind the Neural ALU is a differentiable function that outputs -1, 0 or 1, rendering the concepts of addition and subtraction trainable. The beauty also lies in the simplicity of the formula: tanh(m) * sigmoid(w), built from fundamental building blocks. If you think about it: tanh saturates to -1 or 1, and sigmoid saturates to 0 or 1, so the product of the two saturates to one of -1, 0 or 1.

Here is the plot of the function:
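If you want to regenerate the plot yourself, here is a minimal matplotlib sketch of the tanh(m) * sigmoid(w) surface (my own reproduction, assuming a reasonably recent matplotlib):

import numpy as np
import matplotlib.pyplot as plt

m, w = np.meshgrid(np.linspace(-10, 10, 200), np.linspace(-10, 10, 200))
z = np.tanh(m) / (1 + np.exp(-w))  # tanh(m) * sigmoid(w)

fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
ax.plot_surface(m, w, z, cmap="coolwarm")
ax.set_xlabel("m")
ax.set_ylabel("w")
ax.set_zlabel("tanh(m) * sigmoid(w)")
plt.show()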

This image shows what the NAC (Neural Accumulator) looks like:
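In formulas, the NAC is just a linear layer a = W x whose weight matrix is W = tanh(W_hat) * sigmoid(M_hat), elementwise. A minimal numpy sketch of the forward pass:

import numpy as np

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

def nac(x, W_hat, M_hat):
    W = np.tanh(W_hat) * sigmoid(M_hat)  # entries pushed toward -1, 0 or 1
    return W @ x                         # plain linear layer, no bias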

The NALU is just the NAC plus a second copy of it run in log space (and mapped back with exp), combined through a learned gate:
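Following the paper's formulas: the gate is g = sigmoid(G x), the multiplicative path is m = exp(W log(|x| + eps)) with the same W = tanh(W_hat) * sigmoid(M_hat), and the output is y = g * a + (1 - g) * m. A numpy sketch, reusing nac and sigmoid from the sketch above:

import numpy as np

def nalu(x, W_hat, M_hat, G, eps=1e-8):
    a = nac(x, W_hat, M_hat)                 # additive path (plain NAC)
    W = np.tanh(W_hat) * sigmoid(M_hat)
    m = np.exp(W @ np.log(np.abs(x) + eps))  # NAC in log space: a sum of logs is a product
    g = sigmoid(G @ x)                       # learned gate between the two paths
    return g * a + (1 - g) * m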

The cost function is

0.5 * (y_hat - y) ** 2

so the partial derivative dJ/dm_0 is

dJ/dm_0 = (y_hat - y) * dy_hat/dm_0
        = (y_hat - y) * d(x0 * tanh(m_0) * sigmoid(w_0))/dm_0
        = (y_hat - y) * x0 * (1 - tanh(m_0)**2) * sigmoid(w_0)

where 1 - tanh(m_0)**2 is the derivative of tanh.
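A quick finite-difference check of that derivative (a standalone sketch with arbitrary values I picked for illustration):

import math

def sigmoid(x):
    return 1 / (1 + math.exp(-x))

def loss(m0, w0, x0, y):
    y_hat = x0 * math.tanh(m0) * sigmoid(w0)
    return 0.5 * (y_hat - y) ** 2

m0, w0, x0, y = 0.3, -0.2, 1.5, 1.0
y_hat = x0 * math.tanh(m0) * sigmoid(w0)
analytic = (y_hat - y) * x0 * (1 - math.tanh(m0) ** 2) * sigmoid(w0)
h = 1e-6
numeric = (loss(m0 + h, w0, x0, y) - loss(m0 - h, w0, x0, y)) / (2 * h)
print(analytic, numeric)  # the two should agree to ~6 decimal places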

Here is a toy NALU implemented in x86-64 (with SSE!) that uses real Intel FPU ALUs.

; Neural ALU implementation in x86_64
;
; nasm -felf64 nalu.s
; gcc -no-pie nalu.o -o nalu -g
;
%define USE_SUB 1
%define EPOCH 1_000_000

global main
extern printf

section .data
first_fmt:  db "first weight: %f, ", 0
second_fmt: db " second weight: %f", 0xA, 0
rand_seed:  dd 1
rand_max:   dd -2147483648 ; -2^31

section .bss
result:  resq 2 ; reserve 2 floats
PRN:     resq 2
w_hats:  resq 2
m_hats:  resq 2
xs:      resd 2
tanhs:   resd 2
sigms:   resd 2
tmp1:    resq 2
tmp2:    resq 2
weights: resq 1
err:     resq 2

section .text
main:
    mov ebx, EPOCH
.calc:
    cmp ebx, 0
    je .exit
    dec ebx
.init_rand:
    call rand
    fstp dword [xs]
    call rand
    fstp dword [xs + 4]
.tanhs_and_sigmoids:
    ;; first calculate tanhs and put those in tanhs
    finit
    fld dword [m_hats]
    call tanh
    fstp dword [tanhs]
    finit
    fld dword [m_hats + 4]
    call tanh
    fstp dword [tanhs + 4]
    ;; calculate sigmoids
    finit
    fld dword [w_hats]
    call sigmoid
    fstp dword [sigms]
    finit
    fld dword [w_hats + 4]
    call sigmoid
    fstp dword [sigms + 4]
.forward_pass:
    movdqu xmm0, [tanhs]    ; move 128 bits
    movdqu xmm1, [sigms]
    movq xmm2, [xs]         ; move 64 bits
    mulps xmm0, xmm1        ; tanh * sigmoid
    movdqu [weights], xmm0
    mulps xmm0, xmm2        ; tanh * sigmoid * xs
    haddps xmm0, xmm0       ; y_hat
    haddps xmm0, xmm0       ; horizontal add (sum)
%if USE_SUB
    hsubps xmm2, xmm2       ; y = x0 - x1
    hsubps xmm2, xmm2
%else
    haddps xmm2, xmm2       ; y = x0 + x1
    haddps xmm2, xmm2
%endif
.calc_error:
    subps xmm0, xmm2        ; xmm0 <- y_hat - y
    extractps eax, xmm0, 1
    mov [err], eax
.backpropagate:
    finit
    ;; m[0] -= err * x0 * sigm0 * dtanh(m[0]);
    fld dword [m_hats]      ; dtanh(m0)
    call dtanh
    fld dword [xs]          ; x0
    fmul
    fld dword [err]         ; err
    fmul
    fld dword [sigms]       ; sigm0
    fmul
    fld dword [m_hats]      ; m0
    fsubr
    fstp dword [m_hats]
    finit
    ;; m[1] -= err * x1 * sigm1 * dtanh(m[1]);
    fld dword [m_hats + 4]  ; dtanh(m1)
    call dtanh
    fld dword [xs + 4]      ; x1
    fmul
    fld dword [err]         ; err
    fmul
    fld dword [sigms + 4]   ; sigm1
    fmul
    fld dword [m_hats + 4]  ; m1
    fsubr
    fstp dword [m_hats + 4]
    finit
    ;; w[0] -= err * x0 * dsigmoid(w[0]) * tanh0;
    fld dword [w_hats]
    call dsigmoid
    fld dword [xs]
    fmul
    fld dword [err]
    fmul
    fld dword [tanhs]
    fmul
    fld dword [w_hats]
    fsubr
    fstp dword [w_hats]
    finit
    ;; w[1] -= err * x1 * dsigmoid(w[1]) * tanh1;
    fld dword [w_hats + 4]
    call dsigmoid
    fld dword [xs + 4]
    fmul
    fld dword [err]
    fmul
    fld dword [tanhs + 4]
    fmul
    fld dword [w_hats + 4]
    fsubr
    fstp dword [w_hats + 4]
.print:
    sub rsp, 8                ; reserve stack pointer
    movd xmm0, [weights]      ; pass result to printf via xmm0
    cvtps2pd xmm0, xmm0       ; convert float to double
    mov rdi, first_fmt        ; printf format string
    mov rax, 1                ; number of vector varargs
    call printf               ; call printf
    add rsp, 8                ; add stack pointer back
    sub rsp, 8                ; reserve stack pointer
    movd xmm0, [weights + 4]  ; pass result to printf via xmm0
    cvtps2pd xmm0, xmm0       ; convert float to double
    mov rdi, second_fmt       ; printf format string
    mov rax, 1                ; number of vector varargs
    call printf               ; call printf
    add rsp, 8                ; add stack pointer back
    jmp .calc
.exit:
    mov eax, 60               ; sys_exit is 60 (decimal)
    xor edi, edi
    syscall

tanh: ; (exp(x) - exp(-x)) / (exp(x) + exp(-x))
    fst dword [tmp1]          ; tmp1 <- x
    call exp                  ; exp(x)
    fst dword [tmp2]          ; tmp2 <- exp(x)
    fld dword [tmp1]
    fchs
    call exp
    fst dword [tmp1]          ; tmp1 <- exp(-x)
    fld dword [tmp2]
    fsubr
    fld dword [tmp2]          ; load exp(x) and exp(-x)
    fld dword [tmp1]
    fadd
    fdiv
    ret

dtanh: ; 1. - pow(tanh(x), 2.)
    call tanh
    fst dword [tmp1]          ; duplicate tanh on the stack
    fld dword [tmp1]
    fmul                      ; tanh(x) * tanh(x)
    fld1                      ; load 1
    fsubr                     ; 1 - tanh(x) ** 2
    ret

sigmoid: ; 1 / (1 + exp(-x))
    fchs                      ; -x
    call exp                  ; exp(-x)
    fld1                      ; load 1
    fadd
    fld1                      ; load 1
    fdivr                     ; 1 / ST(0)
    ret

dsigmoid: ; sigmoid(x) * (1. - sigmoid(x))
    call sigmoid
    fst dword [tmp1]          ; tmp1 <- sigmoid(x)
    fchs
    fld1
    fadd
    fld dword [tmp1]          ; st(0) <- sigmoid(x)
    fmul
    ret

exp:
    fldl2e
    fmulp st1, st0            ; st0 = x*log2(e) = tmp1
    fld1
    fscale                    ; st0 = 2^int(tmp1), st1 = tmp1
    fxch
    fld1
    fxch                      ; st0 = tmp1, st1 = 1, st2 = 2^int(tmp1)
    fprem                     ; st0 = fract(tmp1) = tmp2
    f2xm1                     ; st0 = 2^(tmp2) - 1 = tmp3
    faddp st1, st0            ; st0 = tmp3 + 1, st1 = 2^int(tmp1)
    fmulp st1, st0            ; st0 = 2^int(tmp1) * 2^fract(tmp1) = 2^(x*log2(e)) = exp(x)
    ret

rand:
    imul eax, dword [rand_seed], 16807 ; RandSeed *= 16807
    mov dword [rand_seed], eax
    fild dword [rand_seed]    ; load RandSeed as an integer
    fidiv dword [rand_max]    ; div by max int value (absolute) = eax / (-2^31)
    ret

If you run this, the first tanh * sigmoid weight goes to 1 and the second one goes to -1:

Epoch    l0              l1
0        0.0             0.0
50000    0.987506901824  -0.987548950867
100000   0.991264033674  -0.991189817923
150000   0.992845113954  -0.992861588357
200000   0.993821244128  -0.993813140853
250000   0.994479531604  -0.994470005826
300000   0.994956870738  -0.994965214447
350000   0.995335580972  -0.995335751094
400000   0.995641550629  -0.995639510579
450000   0.99588903762   -0.995888041575
500000   0.996102719885  -0.996098271471
550000   0.996282859485  -0.996286010814
600000   0.996444518075  -0.996441767134
650000   0.996583070776  -0.996582158171
700000   0.996711963875  -0.99670336452
750000   0.996820796932  -0.996818826574
800000   0.996921023282  -0.9969240341
850000   0.997012684359  -0.997014549213
900000   0.997100144072  -0.997097107772
950000   0.997177851616  -0.99717492668

Here is a runnable NAC toy example implemented in Python:

from random import random
import math

def tanh(x):
    return math.tanh(x)

def dtanh(x):
    return 1. - math.tanh(x) ** 2

def sigmoid(x):
    return 1 / (1 + math.exp(-x))

def dsigmoid(x):
    return sigmoid(x) * (1 - sigmoid(x))

m0 = m1 = w0 = w1 = 0.0

for i in range(1000000):
    x0 = random()
    x1 = random()
    y = x0 - x1

    # forward pass
    l0 = tanh(m0) * sigmoid(w0)
    l1 = tanh(m1) * sigmoid(w1)
    y_h = l0 * x0 + l1 * x1

    # calculate error
    e = y_h - y

    # backpropagation
    m0 -= e * x0 * sigmoid(w0) * dtanh(m0)
    m1 -= e * x1 * sigmoid(w1) * dtanh(m1)
    w0 -= e * x0 * dsigmoid(w0) * tanh(m0)
    w1 -= e * x1 * dsigmoid(w1) * tanh(m1)

    if not i % 50000:
        print(i, l0, l1)

You should see the neural net converge almost immediately: l0 heads to 1 and l1 to -1, recovering y = x0 - x1.
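And since the learned weights sit near +1 and -1 rather than memorizing the [0, 1) training range, they extrapolate; a quick check you can append to the script above:

# training inputs never left [0, 1), but the learned weights still extrapolate
x0, x1 = 1000.0, 3000.0
y_h = tanh(m0) * sigmoid(w0) * x0 + tanh(m1) * sigmoid(w1) * x1
print(y_h)  # close to -2000.0 (= 1000 - 3000), off by ~0.3% as the weights saturate toward +/-1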

Source