Neural Arithmetic Logic Units – getting backpropagation nets to extrapolate

Backpropagation nets have a problem doing math. You can get them to learn a multiplication table, but when you try to use the net on problems whose answers are higher or lower than the ones seen in training, it fails. In theory, these nets should be able to extrapolate, but in practice they memorize instead of learning the principles behind addition, multiplication, division, and so on.

A group at Google DeepMind in England solved this problem.
They did this by modifying the typical backprop neuron as follows:

  1. They removed the bias input
  2. They removed the nonlinear activation function
  3. Instead of just using one weight on each incoming connection to the neuron, they use two underlying parameters. Both are learned by gradient descent, but a sigmoid function is applied to one, a hyperbolic tangent (tanh) function is applied to the other, and the two results are multiplied together to form the effective weight. In standard nets, a sigmoid or tanh is never applied to the weights at all; these squashing functions are applied to the activation. Here the opposite is true.

Here is the equation for computing the weight matrix. W is the final weight matrix, and M̂ and Ŵ (the variables with the hat symbols) are the underlying values that are combined to create that final composite weight:

W = tanh(Ŵ) ⊙ σ(M̂)   (σ is the sigmoid function, and ⊙ means elementwise multiplication)
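To make this concrete, here is a minimal NumPy sketch of such a unit (the NAC); the parameter values at the bottom are just illustrative, not learned:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def nac(x, W_hat, M_hat):
    """Neural accumulator (NAC): no bias, no activation function.
    The effective weight matrix is tanh(W_hat) * sigmoid(M_hat) (elementwise),
    which pushes each entry toward -1, 0, or 1."""
    W = np.tanh(W_hat) * sigmoid(M_hat)
    return x @ W

# Toy example: parameters saturated so the effective weights are close to [1, -1],
# which makes the unit compute (roughly) x1 - x2 for inputs of any magnitude.
W_hat = np.array([[12.0], [-12.0]])   # tanh(12) ~ 1, tanh(-12) ~ -1
M_hat = np.array([[12.0], [12.0]])    # sigmoid(12) ~ 1
x = np.array([[800.0, 3.0]])
print(nac(x, W_hat, M_hat))           # approximately [[797.]], i.e. 800 - 3
```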

So what is the rationale behind all this?

First let's look at what a sigmoid function looks like:

[Figure: plot of the sigmoid function]

And now the hyperbolic tangent function (also known as ‘tanh’):

[Figure: plot of the tanh function]

We see that the sigmoid function ranges (on the Y axis) between 0 and 1, while tanh ranges from -1 to 1. Both functions have a high rate of change when their x-values are fairly close to zero, but that rate of change flattens out the farther x gets from that point.
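For reference, the definitions of the two functions:

```latex
\sigma(x) = \frac{1}{1 + e^{-x}} \in (0, 1),
\qquad
\tanh(x) = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}} \in (-1, 1)
```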

So if you multiply these two functions together, the most the product can be is 1 and the least is -1, and the composite weight is biased: it's less likely to be fractional, and more likely to end up at -1, 1, or 0.
Why the bias?
The reason is that near x = 0, where the derivative is large, gradient descent takes its biggest steps, so a parameter is unlikely to settle there; in the flat regions the steps shrink, so parameters tend to come to rest near the saturation points. Thus tanh is biased to learn its saturation points (-1 and 1), sigmoid is biased to learn its saturation points (0 and 1), and their elementwise product has saturation points at -1, 1, and 0.
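You can see where the composite weight lands with a quick script; the parameter values below are arbitrary, chosen only to illustrate the saturation:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Once the underlying parameters drift away from zero into the flat regions,
# the composite weight tanh(w_hat) * sigmoid(m_hat) gets stuck near -1, 0, or 1.
for w_hat, m_hat in [(0.1, 0.1), (1.0, 1.0), (4.0, 4.0), (-4.0, 4.0), (4.0, -4.0)]:
    weight = np.tanh(w_hat) * sigmoid(m_hat)
    print(f"w_hat={w_hat:+.1f}  m_hat={m_hat:+.1f}  ->  weight={weight:+.3f}")
```

The last three lines print weights of roughly +0.98, -0.98, and +0.02, i.e. values close to 1, -1, and 0.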

So why is this bias useful? As the authors explain:

Our first model is the neural accumulator (NAC), which is a special case of a linear (affine) layer whose transformation matrix W consists just of -1’s, 0’s, and 1’s; that is, its outputs are additions or subtractions (rather than arbitrary rescalings) of rows in the input vector. This prevents the layer from changing the scale of the representations of the numbers when mapping the input to the output, meaning that they are consistent throughout the model, no matter how many operations are chained together.

As an example, if you want the neuron to realize it has to add 5 and -7, you don't want those numbers multiplied by fractions; in this case you prefer weights of exactly 1 and -1. Likewise, the result of this neuron's addition could be fed into another neuron, and again you don't want it scaled by a fraction before it is combined with that neuron's other inputs.
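A quick worked comparison (the specific fractional weight of 0.5 is just for illustration):

```latex
\underbrace{1 \cdot 5 + 1 \cdot (-7)}_{\text{weights of } \pm 1} = -2
\qquad \text{vs.} \qquad
\underbrace{0.5 \cdot 5 + 0.5 \cdot (-7)}_{\text{fractional weights}} = -1
```

With fractional weights the answer comes out rescaled, and that rescaling compounds each time the result is passed through another layer.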

This isn't always the case, though: one of their experiments was learning to calculate the square root, which required a weight to train to the value of 0.5.

On my first read of the paper, I wasn't sure why the net worked, so I asked one of the authors, Andrew Trask, who replied that it works:

 

  1. because it encodes numbers as real values (instead of as distributed representations)
  2. because the functions it learns over numbers extrapolate inherently (aka… addition/multiplication/division/subtraction) – so learning an attention mechanism over these functions leads to neural nets which extrapolate

 

The first point is important because many models assume that any particular number is coded by many neurons, each with different weights. In this model, one neuron, without any nonlinear function applied to its result, does math such as addition and subtraction.

It is true that real neurons are limited in the values they can represent. In fact, neurons fire at a constant, fixed amplitude, and it's just the frequency of the pulses that increases when they get a stronger input.

But ignoring that point, the units they have can extrapolate, because they do simple addition and subtraction (point #2).

But wait a minute – what about multiplication and division?

For those operations they make use of a mathematical property of logarithms: the log of (X * Y) is equal to log(X) + log(Y). So if you take the logarithms of values before you feed them into an addition neuron, and then exponentiate the result (the inverse of the log), you have the equivalent of multiplication. Likewise, subtraction in log space gives you division, since log(X / Y) = log(X) - log(Y).

The log is differentiable, so the net can still learn by gradient descent.
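Here is a tiny NumPy sketch of that trick (not the paper's exact formulation, just the identity it relies on):

```python
import numpy as np

x, y = 6.0, 3.0

# Multiplication and division become addition and subtraction in log space:
product  = np.exp(np.log(x) + np.log(y))   # log(x*y) = log(x) + log(y)  ->  18.0
quotient = np.exp(np.log(x) - np.log(y))   # log(x/y) = log(x) - log(y)  ->   2.0
print(product, quotient)

# The paper's multiplicative path uses log(|x| + epsilon) so that zero and
# negative inputs don't make the logarithm blow up.
```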

So they now need to combine the addition/subtraction neurons with the multiplication/division neurons, and this diagram shows their method:

[Figure: the NAC cell and the NALU, which uses a learned gate to mix an add/subtract path with a multiply/divide path computed in log space]
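In equation form, as given in the paper: the NAC handles addition and subtraction, a second NAC applied in log space handles multiplication and division, and a learned gate g decides how to mix the two paths.

```latex
\begin{aligned}
\text{NAC:}\quad  & a = Wx, \qquad W = \tanh(\hat{W}) \odot \sigma(\hat{M}) \\
\text{NALU:}\quad & y = g \odot a + (1 - g) \odot m, \\
                  & m = \exp\!\big(W \log(|x| + \epsilon)\big), \qquad g = \sigma(Gx)
\end{aligned}
```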

This fairly simple but clever idea is a breakthrough:

Experiments show that NALU-enhanced neural networks can learn to track time, perform arithmetic over images of numbers, translate numerical language into real-valued scalars, execute computer code, and count objects in images. In contrast to conventional architectures, we obtain substantially better generalization both inside and outside of the range of numerical values encountered during training, often extrapolating orders of magnitude beyond trained numerical ranges.

Source:
Neural Arithmetic Logic Units – Andrew Trask, Felix Hill, Scott Reed, Jack Rae, Chris Dyer, Phil Blunsom – Google DeepMind
