I’m still developing intuition for backpropagation myself. One of the simplest explanations I’ve seen is NanoNeuron by Oleksii Trekhleb, so I remixed it. Let’s go!
Data
The linear function we’re going to learn is f(x)=1.8x+32 which converts temperatures from Celsius to Fahrenheit.
float celsiusToFahrenheit(float x) {
  float w = 1.8;
  float b = 32;
  float z = w * x + b;
  return z;
}
The training data set will consist of two arrays for x and y values.
float[] xTrain = new float[100];
float[] yTrain = new float[100];

void generateDataset() {
  for (int i = 0; i < xTrain.length; i++) {
    xTrain[i] = i;
    yTrain[i] = celsiusToFahrenheit(xTrain[i]);
  }
}
Prediction
The neurons in the network are simple linear units. During forward propagation, each Neuron object multiplies its input by a weight and adds a bias.
class Neuron {
  float w;
  float b;

  Neuron() {
    w = random(-1, 1);
    b = 0;
  }

  float forwardProp(float a) {
    float z = w * a + b;
    return z;
  }
}
The first network we’re going to train consists of a single Neuron we’ll call Layer1. I’m not going to apply an activation function like sigmoid, tanh, or ReLU, but I’m still going to call the output of Layer1 its activation a1.
class Network {
  Neuron Layer1;
  float a1 = 0;

  Network() {
    Layer1 = new Neuron();
  }

  float forwardProp(float x) {
    a1 = Layer1.forwardProp(x);
    return a1;
  }
}
Cost
OK, so we can call forwardProp on a Network object and make a prediction. But Layer1’s weight was randomly initialized and its bias is 0, so any prediction is probably trash. We need information on how badly a network is performing in order to nudge its weights and biases towards more optimal values. The measure of “badness” is called cost and we want to minimize it.
There are a handful of standard ways to calculate cost–we’ll use a riff on mean squared error starting with C=\frac{1}{2}(a_{1}-y)^{2} for each training example. Since we’re training on a data set with m=100 examples, the average cost after each epoch of training would be:
averageCost=\frac{1}{m} \displaystyle\sum_{i=1}^m{C_{i}}
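As a rough sketch of that formula, here is the same calculation in plain Java rather than a Processing sketch. The class name CostDemo, its method names, and the two made-up predictions are assumptions for illustration only, and doubles stand in for Processing's floats.

```java
public class CostDemo {
    // Cost of one prediction: C = 1/2 * (a1 - y)^2
    static double cost(double a1, double y) {
        double error = a1 - y;
        return 0.5 * error * error;
    }

    // Average cost over m examples: (1/m) * sum of C_i
    static double averageCost(double[] predictions, double[] targets) {
        double total = 0;
        for (int i = 0; i < predictions.length; i++) {
            total += cost(predictions[i], targets[i]);
        }
        return total / predictions.length;
    }

    public static void main(String[] args) {
        double[] predictions = {30.0, 52.0}; // made-up network outputs
        double[] targets = {32.0, 50.0};     // 0 and 10 degrees Celsius in Fahrenheit
        System.out.println(averageCost(predictions, targets)); // prints 2.0
    }
}
```

Each prediction here is off by 2 degrees, so each costs 0.5 * 2^2 = 2 and the average is 2 as well. Because the error is squared, overshooting and undershooting are penalized equally.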
Now how do we figure out which way to nudge our neurons? Cost is a function of the weights and biases in the network–and we want to minimize cost–so we need to move toward the minimum of the cost function.
Let’s pretend the cost function looks like the blue parabola below. If we were at the point (3,-3), then we’d need to move left towards the minimum.
[Figure: A tangent line. Courtesy OpenStax]
Imagine zooming into the point (3,-3). No matter how closely you look, the orange tangent line y=2x-9 only touches the blue curve at precisely the point (3,-3).
The slope of this particular tangent line is +2; this is our old friend rise over run \frac{\Delta{y}}{\Delta{x}}. At this particular point on the graph of f the slope is positive, so increasing the input x increases the output f(x). If we want to minimize cost, then we’d better decrease x by moving to the left.
Calculus!
It’s time for a little bit of calculus. The slope of the line tangent to the graph of a function f at a given point is known as the derivative of f at that point. The derivative of f with respect to a variable x is sometimes written \frac{df}{dx} or f'(x).
What is the derivative (slope) of a constant function like f(x)=5?
Quadratics
We know the derivative of a linear function f(x)=mx+b is \frac{df}{dx}=m. The first result from calculus I’ll use without any explanation whatsoever is this: the derivative of a quadratic function f(x)=ax^{2} is \frac{df}{dx}=2ax. The exponent 2 drops down and multiplies the base x along with whatever coefficient a happens to be out front.
If the function also had linear and constant terms f(x)=ax^{2}+bx+c then we’d be looking at f'(x)=2ax+b. The derivative of a sum is the sum of the derivatives.
Just for kicks, find the derivative of f(x)=x^{2}-4x, then input 3.
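If you’d like to sanity-check the 2ax+b rule numerically, here is a small sketch in plain Java (not a Processing sketch). The class name SlopeDemo, the step size, and the example function f(x)=3x^{2}+2x+1 are all assumptions made for illustration. It estimates the slope by taking rise over run across a very short interval.

```java
import java.util.function.DoubleUnaryOperator;

public class SlopeDemo {
    // Central-difference estimate of the slope of f at x:
    // rise over run across the short interval [x - h, x + h]
    static double slope(DoubleUnaryOperator f, double x) {
        double h = 1e-4;
        return (f.applyAsDouble(x + h) - f.applyAsDouble(x - h)) / (2 * h);
    }

    public static void main(String[] args) {
        // Hypothetical example: f(x) = 3x^2 + 2x + 1, so the rule gives f'(x) = 6x + 2
        DoubleUnaryOperator f = x -> 3 * x * x + 2 * x + 1;
        double x = 2.0;
        System.out.println("estimated slope: " + slope(f, x));
        System.out.println("rule 2ax + b:    " + (6 * x + 2));
    }
}
```

The two printed values should agree to several decimal places. The same trick comes in handy later for checking hand-derived gradients.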
The Chain Rule
The second result from calculus I’m just going to use is called the chain rule. Some functions f can be written as the composition of two other functions g and h, as in f(x)=g(h(x)). The original input x is first passed into h, then that function’s output h(x) becomes the input to g.
x -> h -> h(x) -> g -> g(h(x))
Given such a composition, the chain rule says:
f'(x)=g'(h(x))h'(x)
Yikes… let’s see that in action. Completing the square, I’ll rewrite
f(x)=x^{2}-4x
as
f(x)=(x-2)^{2}-4
Now if we set g(x)=x^{2}-4 and h(x)=x-2, we have f(x)=g(h(x))=(x-2)^{2}-4.
Alright, moment of truth.
g'(h(x))=2(x-2)
h'(x)=1
So what is g'(h(x))h'(x)? And does it equal the derivative you found a moment ago?
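For comparison, here is a different composition worked from start to finish. The function is my own pick for illustration, chosen so that the inner derivative isn’t 1.

```latex
\begin{aligned}
f(x) &= (3x+1)^{2}, \qquad g(u) = u^{2}, \qquad h(x) = 3x+1 \\
g'(u) &= 2u, \qquad h'(x) = 3 \\
f'(x) &= g'(h(x))\,h'(x) = 2(3x+1) \times 3 = 18x+6 \\
\text{Check: } f(x) &= 9x^{2}+6x+1 \;\Rightarrow\; f'(x) = 18x+6
\end{aligned}
```

Expanding first and differentiating term by term gives the same answer as the chain rule, which is reassuring.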
Backpropagation
Back to the Network: cost is a function of both w and b, so we need to figure out how cost changes with respect to each variable. More formally, we need to find the partial derivatives \frac{\partial{C}}{\partial{w}} and \frac{\partial{C}}{\partial{b}}. Let’s follow one training example through the Network.
a_{1}=wx+b
C=\frac{1}{2}(a_{1}-y)^{2}
It looks like C could easily be rewritten as C=\frac{1}{2}(wx+b-y)^{2}. Remember, x and y are known values from a training example, not variables. The quantities we’re actually changing are w and b, and we’ll consider one at a time, applying the chain rule with wx+b-y as the inner function.
\frac{\partial{C}}{\partial{w}}=2\times{\frac{1}{2}}\times{(wx+b-y)}\times{x}
\frac{\partial{C}}{\partial{b}}=2\times{\frac{1}{2}}\times{(wx+b-y)}\times{1}
Let’s simplify and streamline the notation a bit.
dw=(a_{1}-y)x
db=a_{1}-y
This is the heart of backpropagation. Here is a tiny implementation we can add to the Network class that uses a FloatDict for convenience.
FloatDict backProp(float x, float y) {
  FloatDict grad = new FloatDict();
  grad.set("dw", (a1 - y) * x);
  grad.set("db", a1 - y);
  return grad;
}
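One way to build trust in hand-derived formulas like these is a gradient check: nudge w and b by a tiny amount, see how much the cost changes, and compare that with what the formulas say. The sketch below does this in plain Java (not a Processing sketch). The class name GradientCheck, the step size, and the sample values for w, b, x, and y are assumptions for illustration.

```java
public class GradientCheck {
    // Cost for one example: C = 1/2 * (w*x + b - y)^2
    static double cost(double w, double b, double x, double y) {
        double error = w * x + b - y;
        return 0.5 * error * error;
    }

    // Analytic partial derivatives, same formulas as backProp: {dw, db}
    static double[] analytic(double w, double b, double x, double y) {
        double a1 = w * x + b;
        return new double[] {(a1 - y) * x, a1 - y};
    }

    // Central-difference estimates of the same two partial derivatives
    static double[] numeric(double w, double b, double x, double y) {
        double h = 1e-5;
        double dw = (cost(w + h, b, x, y) - cost(w - h, b, x, y)) / (2 * h);
        double db = (cost(w, b + h, x, y) - cost(w, b - h, x, y)) / (2 * h);
        return new double[] {dw, db};
    }

    public static void main(String[] args) {
        // Made-up weight and bias; (10, 50) is 10 degrees Celsius in Fahrenheit
        double w = 0.5, b = 0.0, x = 10.0, y = 50.0;
        double[] exact = analytic(w, b, x, y);
        double[] approx = numeric(w, b, x, y);
        System.out.println("dw: " + exact[0] + " vs " + approx[0]);
        System.out.println("db: " + exact[1] + " vs " + approx[1]);
    }
}
```

With these values the formulas give dw = -450 and db = -45, and the numeric estimates should land very close to both. If they ever disagree badly, the derivation (or the code) has a bug.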
We’ll calculate both partial derivatives for each training example (x_{i},y_{i}) and sum them up, then we’ll use their averages to update w and b like so.
w:=w-\alpha{dw}
b:=b-\alpha{db}
The partial derivatives form the Network’s gradient, which points “uphill” on the cost function. Since we’re trying to minimize cost, we nudge our neurons in the opposite direction by subtracting the partial derivatives.
And that \alpha term? It’s the learning rate–a multiplier that determines how fast we go downhill toward the minimum. Turns out this hyperparameter is a big deal.
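To make the update rule concrete, here is one step worked by hand. The numbers are made up for illustration: a starting weight of 0.5, a bias of 0, the single training example (10, 50), and \alpha=0.0005, treating that one example as the whole batch.

```latex
\begin{aligned}
a_{1} &= wx+b = 0.5 \times 10 + 0 = 5 \\
dw &= (a_{1}-y)x = (5-50) \times 10 = -450 \\
db &= a_{1}-y = 5-50 = -45 \\
w &:= w-\alpha\,dw = 0.5 - 0.0005 \times (-450) = 0.725 \\
b &:= b-\alpha\,db = 0 - 0.0005 \times (-45) = 0.0225
\end{aligned}
```

Both partial derivatives are negative, so subtracting them raises w and b. The new prediction for x=10 is 0.725 \times 10 + 0.0225 = 7.2725, a little closer to 50 than the original 5.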
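// Runs one epoch of training over the whole data set.
// averageCost is a field of Network; alpha is a global declared in the sketch.
void train(float[] x, float[] y) {
  int m = x.length;
  averageCost = 0;
  float dw = 0;
  float db = 0;
  // Accumulate cost and gradients over every training example
  for (int i = 0; i < m; i++) {
    forwardProp(x[i]);
    float predictionCost = 0.5 * pow(a1 - y[i], 2);
    FloatDict grad = backProp(x[i], y[i]);
    averageCost += predictionCost;
    dw += grad.get("dw");
    db += grad.get("db");
  }
  // Average, then step against the gradient
  averageCost /= m;
  dw /= m;
  db /= m;
  Layer1.w -= alpha * dw;
  Layer1.b -= alpha * db;
}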
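To get a feel for why the learning rate is such a big deal, here is a rough experiment in plain Java (not a Processing sketch). It repeats the same full-batch update on the same data set, x from 0 to 99, with two different values of alpha. The class name LearningRateDemo, the fixed starting point w = 0, b = 0 (instead of a random weight), and the choice of 50 epochs are assumptions for illustration.

```java
public class LearningRateDemo {
    // Runs full-batch gradient descent on x = 0..99, y = 1.8x + 32,
    // starting from w = 0, b = 0, and returns the average cost afterwards.
    static double costAfter(double alpha, int epochs) {
        int m = 100;
        double w = 0, b = 0;
        for (int e = 0; e < epochs; e++) {
            double dw = 0, db = 0;
            for (int i = 0; i < m; i++) {
                double x = i;
                double error = (w * x + b) - (1.8 * x + 32);
                dw += error * x;
                db += error;
            }
            // Average the gradients, then step against them
            w -= alpha * dw / m;
            b -= alpha * db / m;
        }
        // Measure the average cost with the final w and b
        double total = 0;
        for (int i = 0; i < m; i++) {
            double x = i;
            double error = (w * x + b) - (1.8 * x + 32);
            total += 0.5 * error * error;
        }
        return total / m;
    }

    public static void main(String[] args) {
        System.out.println("start:          " + costAfter(0.0005, 0));
        System.out.println("alpha = 0.0005: " + costAfter(0.0005, 50));
        System.out.println("alpha = 0.001:  " + costAfter(0.001, 50));
    }
}
```

In this setup the smaller rate brings the cost down, while doubling it makes every step overshoot the minimum by more than the last, so the cost blows up instead of shrinking. The inputs go up to 99, which makes the gradient with respect to w large, and that is why such a tiny alpha is needed here.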
All together now
Neuron
class Neuron {
  float w;
  float b;

  Neuron() {
    w = random(-1, 1);
    b = 0;
  }

  float forwardProp(float a) {
    float z = w * a + b;
    return z;
  }
}
Network
class Network {
  Neuron Layer1;
  float a1 = 0;
  float averageCost = 0;

  Network() {
    Layer1 = new Neuron();
  }

  float forwardProp(float x) {
    a1 = Layer1.forwardProp(x);
    return a1;
  }

  FloatDict backProp(float x, float y) {
    FloatDict grads = new FloatDict();
    grads.set("dw", (a1 - y) * x);
    grads.set("db", a1 - y);
    return grads;
  }

  void train(float[] x, float[] y) {
    int m = x.length;
    averageCost = 0;
    float dw = 0;
    float db = 0;
    for (int i = 0; i < m; i++) {
      forwardProp(x[i]);
      float predictionCost = 0.5 * pow(a1 - y[i], 2);
      FloatDict grads = backProp(x[i], y[i]);
      averageCost += predictionCost;
      dw += grads.get("dw");
      db += grads.get("db");
    }
    averageCost /= m;
    dw /= m;
    db /= m;
    Layer1.w -= alpha * dw;
    Layer1.b -= alpha * db;
  }
}
Sketch
// Based on NanoNeuron by Oleksii Trekhleb
// MIT License

float celsiusToFahrenheit(float x) {
  float w = 1.8;
  float b = 32;
  float z = w * x + b;
  return z;
}

float[] xTrain = new float[100];
float[] yTrain = new float[100];

void generateDataset() {
  for (int i = 0; i < xTrain.length; i++) {
    xTrain[i] = i;
    yTrain[i] = celsiusToFahrenheit(xTrain[i]);
  }
}

Network n = new Network();
float alpha = 0.0005;
int epoch = 0;
FloatList cost = new FloatList();

void setup() {
  size(100, 100);
  generateDataset();
  noLoop();
}

void draw() {
  background(0);
  for (int i = 0; i < 1000; i++) {
    epoch++;
    n.train(xTrain, yTrain);
    cost.append(n.averageCost); // icanhaz visualization?
  }
  println("Epoch: " + epoch);
  println("f(x) = " + n.Layer1.w + " * x + " + n.Layer1.b);
  float x = 10;
  float y = celsiusToFahrenheit(x);
  println("Guess: " + n.forwardProp(x));
  println("Expected: " + y);
  println();
}

void mousePressed() {
  redraw();
}