Despite having worked and done research in the AI ecosystem for some time, until recently I had never stopped to think about backpropagation and gradient updates within neural networks. This article seeks to rectify that and will hopefully provide a thorough but easy-to-follow dive into the topic by implementing a simple (but somewhat powerful) neural network framework from scratch.
Basically, a neural network is just a mathematical function from our input space to our desired output space. In fact, we can effectively “unwrap” any neural network into a function. Consider, for example, the following simple neural network with two layers and one input:
Now we can build an equivalent function going layer by layer, starting from the input. Let's build up our final function one layer at a time:
- At the input, we start with the identity function: pred(x) = x
- After the first linear layer, we get pred(x) = w₁x+b₁
- The ReLU then gives us pred(x) = max(0, w₁x+b₁)
- After the final layer, we get pred(x) = w₂(max(0, w₁x+b₁)) + b₂
With more complicated networks, these functions of course become unwieldy, but the point is that we can construct such a representation for any neural network.
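The layer-by-layer construction above can be sketched directly in Python. This is only an illustration: the helper names (`linear`, `relu`, `pred`) and the parameter values are made up for the example and are not part of the framework built later in the article.

```python
def linear(w: float, b: float, x: float) -> float:
    """One-input linear layer: w * x + b."""
    return w * x + b


def relu(x: float) -> float:
    """ReLU activation: max(0, x)."""
    return max(0.0, x)


def pred(x: float, w1: float, b1: float, w2: float, b2: float) -> float:
    """The 'unwrapped' network: w2 * max(0, w1 * x + b1) + b2."""
    hidden = relu(linear(w1, b1, x))   # first linear layer followed by ReLU
    return linear(w2, b2, hidden)      # final linear layer


# Arbitrary example parameters (assumed values, chosen for easy arithmetic).
print(pred(3.0, w1=2.0, b1=-1.0, w2=0.5, b2=1.0))   # 0.5 * max(0, 5) + 1 = 3.5
print(pred(0.0, w1=2.0, b1=-1.0, w2=0.5, b2=1.0))   # ReLU clips -1 to 0, so 1.0
```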
However, we can go one step further: functions of this form are not especially convenient for computation, but we can parse them into a more useful form, namely a syntax tree. For our simple network, the tree would look like this:
In this kind of tree, our leaves are parameters, constants and inputs, and the other nodes are elementary operations whose arguments are their children. Of course, these elementary operations do not have to be binary: the sigmoid operation, for example, is unary (and so is ReLU if we do not represent it as a maximum of 0 and x), and we can choose to support multiplication and addition of more than two inputs.
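To make the tree idea concrete before any classes are involved, here is a minimal sketch that encodes the same two-layer network as nested tuples and evaluates it recursively. The tuple encoding, the operation names, and the `evaluate` helper are assumptions made for illustration only; they are not the representation the article goes on to use.

```python
import math
from typing import Dict, Tuple, Union

# A tree is either a leaf (a number or a variable name) or a tuple
# (operation_name, child_1, child_2, ...).
Tree = Union[float, str, Tuple]


def evaluate(tree: Tree, env: Dict[str, float]) -> float:
    """Recursively evaluate a syntax tree given values for its named leaves."""
    if isinstance(tree, str):            # named leaf: parameter or input
        return env[tree]
    if isinstance(tree, (int, float)):   # constant leaf
        return float(tree)
    op, *children = tree
    values = [evaluate(child, env) for child in children]
    if op == "add":
        return sum(values)
    if op == "mul":
        return math.prod(values)
    if op == "max":
        return max(values)
    raise ValueError(f"unknown operation: {op}")


# pred(x) = w2 * max(0, w1 * x + b1) + b2
network = ("add", ("mul", "w2", ("max", 0, ("add", ("mul", "w1", "x"), "b1"))), "b2")

env = {"x": 3.0, "w1": 2.0, "b1": -1.0, "w2": 0.5, "b2": 1.0}
print(evaluate(network, env))   # 3.5
```

Note that "add" and "mul" accept any number of children here, which mirrors the point that elementary operations need not be binary.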
By thinking of our network as a tree of these elementary operations, we can now do many things very easily with recursion, which will form the basis of our back-propagation and forward-propagation algorithms. In code, we can define a recursive neural network class similar to this:
from dataclasses import dataclass, field
from typing import List


@dataclass
class NeuralNetNode:
    """A node in our neural network tree"""
    children: List['NeuralNetNode'] = field(default_factory=list)

    def op(self, x: List[float]) -> float:
        """The operation that this node performs"""
        raise NotImplementedError

    def forward(self) -> float:
        """Evaluate this node on the given input"""
        return self.op([child.forward() for child in self.children])

    # This is just for convenience
    def __call__(self) -> float:
        return self.forward()

    def __repr__(self):
        return f'{self.__class__.__name__}({self.children})'
Now suppose we have a differentiable loss function for our neural network, say MSE. Recall that the MSE for a single sample is the squared difference between the true value y and the prediction: MSE = (y − pred(x))².
Now we want to update our parameters (the green circles in our tree representation) given the value of our loss. To do this, we need the derivative of our loss function with respect to each parameter. However, calculating this directly from the loss is extremely difficult; after all, our MSE is calculated in terms of the value predicted by our neural network, which can be an extraordinarily complicated function.
This is where a very useful piece of mathematics comes into play: the chain rule. Instead of being forced to compute our highly complex derivatives from scratch, we can compute a series of simpler derivatives.
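As a worked instance, consider a one-input linear model pred(x) = wx + b with the per-sample MSE loss. The symbol L for the loss is introduced here for convenience and is not the article's own notation. The chain rule splits each derivative into two easy factors:

```latex
\begin{aligned}
L &= (y - \mathrm{pred}(x))^2, \qquad \mathrm{pred}(x) = wx + b \\
\frac{\partial L}{\partial \mathrm{pred}} &= -2\,(y - \mathrm{pred}(x)) \\
\frac{\partial L}{\partial w} &= \frac{\partial L}{\partial \mathrm{pred}} \cdot \frac{\partial \mathrm{pred}}{\partial w} = -2\,(y - \mathrm{pred}(x)) \cdot x \\
\frac{\partial L}{\partial b} &= \frac{\partial L}{\partial \mathrm{pred}} \cdot \frac{\partial \mathrm{pred}}{\partial b} = -2\,(y - \mathrm{pred}(x)) \cdot 1
\end{aligned}
```

The first factor is shared by both parameters, so it only needs to be computed once and can then be passed down.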
It turns out that the chain rule fits very well with our recursive tree structure. The idea works as follows: assuming we have sufficiently simple elementary operations, each elementary operation knows the derivative of its output with respect to each of its arguments. Given the derivative of the loss with respect to a parent operation's output, we can calculate the derivative of the loss with respect to each child by simple multiplication. For a simple linear regression model using MSE, we can diagram it as follows:
Of course, some of our nodes don't do anything with their derivatives; only our leaf nodes care. But now each node can obtain the derivative of the loss function with respect to its output through this recursive process. So we can add the following methods to our NeuralNetNode class:
    # Added to the NeuralNetNode class
    def grad(self) -> List[float]:
        """The gradient of this node with respect to its inputs"""
        raise NotImplementedError

    def backward(self, derivative_from_parent: float):
        """Propagate the derivative from the parent to the children"""
        self.on_backward(derivative_from_parent)
        deriv_wrt_children = self.grad()
        for child, derivative_wrt_child in zip(self.children, deriv_wrt_children):
            child.backward(derivative_from_parent * derivative_wrt_child)

    def on_backward(self, derivative_from_parent: float):
        """Hook for subclasses to override. Things like updating parameters"""
        pass
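One way to convince yourself that this recursive multiplication is right is to compare it against a numerical derivative. The sketch below does not use the NeuralNetNode class. It hand-applies the chain rule to the two-layer network from the introduction with MSE loss, then checks the result with central finite differences. All names (`loss`, `chain_rule_grads`, `numeric_grad`) and the parameter values are assumptions chosen for illustration.

```python
from typing import Dict


def loss(p: Dict[str, float], x: float, y: float) -> float:
    """MSE of pred(x) = w2 * max(0, w1 * x + b1) + b2 against the target y."""
    pred = p["w2"] * max(0.0, p["w1"] * x + p["b1"]) + p["b2"]
    return (y - pred) ** 2


def chain_rule_grads(p: Dict[str, float], x: float, y: float) -> Dict[str, float]:
    """Walk from the loss down to the leaves, multiplying local derivatives."""
    pre = p["w1"] * x + p["b1"]                    # input to the ReLU
    hidden = max(0.0, pre)                         # output of the ReLU
    pred = p["w2"] * hidden + p["b2"]
    d_pred = -2.0 * (y - pred)                     # dLoss/dpred
    d_hidden = d_pred * p["w2"]                    # through the outer multiply
    d_pre = d_hidden * (1.0 if pre > 0 else 0.0)   # through the ReLU
    return {
        "w2": d_pred * hidden,
        "b2": d_pred,
        "w1": d_pre * x,
        "b1": d_pre,
    }


def numeric_grad(p: Dict[str, float], name: str, x: float, y: float,
                 h: float = 1e-6) -> float:
    """Central finite-difference estimate of dLoss/dp[name]."""
    up = dict(p, **{name: p[name] + h})
    down = dict(p, **{name: p[name] - h})
    return (loss(up, x, y) - loss(down, x, y)) / (2 * h)


params = {"w1": 0.8, "b1": 0.1, "w2": -0.5, "b2": 0.3}
analytic = chain_rule_grads(params, x=1.5, y=2.0)
for name, value in analytic.items():
    # The two columns should agree to several decimal places.
    print(name, value, numeric_grad(params, name, x=1.5, y=2.0))
```

Each `d_*` variable plays the role of `derivative_from_parent` at one level of the tree.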
Exercise 1: Try creating one of these trees for a simple linear regression model and do the recursive gradient updates by hand for a couple of steps.
Note: For the sake of simplicity, we require that each of our nodes has at most one parent. If a node is allowed to have multiple parents, our backward() algorithm becomes somewhat more complicated, since each child needs to sum the derivatives coming from all of its parents to compute its own. We can do this iteratively with a topological sort (e.g. see here) or still recursively, that is, with reverse accumulation (although in this case we would need a second pass to update all the parameters). This isn't extraordinarily difficult, so I'll leave it as an exercise for the reader (and I'll talk more about it in part 2, stay tuned).
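To make the note concrete, here is a minimal sketch of reverse accumulation on a graph where one node feeds several parents. It is deliberately separate from the NeuralNetNode class: the `Node` class, its fields, and the `add`, `mul`, `topological_order` and `backward` helpers are hypothetical names invented for this example. It is only one possible way to do it, not necessarily the approach part 2 will take.

```python
from typing import Callable, List, Optional, Sequence


class Node:
    """A graph node; `children` are its arguments, as in the article's trees."""

    def __init__(self, value: float = 0.0,
                 children: Sequence["Node"] = (),
                 local_grads: Optional[Callable[[List[float]], List[float]]] = None):
        self.value = value
        self.children = list(children)
        # Maps the children's values to d(output)/d(child) for each child.
        self.local_grads = local_grads
        self.grad = 0.0   # accumulated d(root)/d(this node)


def add(a: Node, b: Node) -> Node:
    return Node(a.value + b.value, [a, b], lambda v: [1.0, 1.0])


def mul(a: Node, b: Node) -> Node:
    return Node(a.value * b.value, [a, b], lambda v: [v[1], v[0]])


def topological_order(root: Node) -> List[Node]:
    """Return nodes so that every node appears after all of its children."""
    order: List[Node] = []
    seen = set()

    def visit(node: Node) -> None:
        if id(node) in seen:
            return
        seen.add(id(node))
        for child in node.children:
            visit(child)
        order.append(node)

    visit(root)
    return order


def backward(root: Node) -> None:
    """Process parents before children, summing contributions from every parent."""
    order = topological_order(root)
    for node in order:
        node.grad = 0.0
    root.grad = 1.0
    for node in reversed(order):
        if node.local_grads is None:      # leaf: nothing to propagate further
            continue
        child_values = [child.value for child in node.children]
        for child, local in zip(node.children, node.local_grads(child_values)):
            child.grad += node.grad * local


# y = x * x + x, so x has three incoming edges and dy/dx = 2x + 1.
x = Node(3.0)
y = add(mul(x, x), x)
backward(y)
print(y.value, x.grad)   # 12.0 7.0
```

A parameter update would then be a separate pass over the leaves, using each leaf's accumulated `grad`.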
Building models
The rest of our code really just involves implementing parameters, inputs and operations and of course running our training. Parameters and inputs are fairly simple constructs:
import random


@dataclass
class Input(NeuralNetNode):
    """A leaf node that represents an input to the network"""
    value: float = 0.0

    def op(self, x):
        return self.value

    def grad(self) -> List[float]:
        return [1.0]

    def __repr__(self):
        return f'{self.__class__.__name__}({self.value})'


@dataclass
class Parameter(NeuralNetNode):
    """A leaf node that represents a parameter to the network"""
    value: float = field(default_factory=lambda: random.uniform(-1, 1))
    learning_rate: float = 0.01

    def op(self, x):
        return self.value

    def grad(self):
        return [1.0]

    def on_backward(self, derivative_from_parent: float):
        # Gradient descent step: move against the derivative of the loss
        self.value -= derivative_from_parent * self.learning_rate

    def __repr__(self):
        return f'{self.__class__.__name__}({self.value})'
The operations are a little more complicated, although not too much: we just need to calculate their gradients correctly. Below are implementations of some useful operations:
import math


@dataclass
class Operation(NeuralNetNode):
    """A node that performs an operation on its inputs"""
    pass


@dataclass
class Add(Operation):
    """A node that adds its inputs"""
    def op(self, x):
        return sum(x)

    def grad(self):
        return [1.0] * len(self.children)


@dataclass
class Multiply(Operation):
    """A node that multiplies its inputs"""
    def op(self, x):
        return math.prod(x)

    def grad(self):
        # The derivative with respect to child i is the product of all other children
        grads = []
        for i in range(len(self.children)):
            cur_grad = 1
            for j in range(len(self.children)):
                if i == j:
                    continue
                cur_grad *= self.children[j].forward()
            grads.append(cur_grad)
        return grads


@dataclass
class ReLU(Operation):
    """
    A node that applies the ReLU function to its input.
    Note that this should only have one child.
    """
    def op(self, x):
        return max(0, x[0])

    def grad(self):
        return [1.0 if self.children[0].forward() > 0 else 0.0]


@dataclass
class Sigmoid(Operation):
    """
    A node that applies the sigmoid function to its input.
    Note that this should only have one child.
    """
    def op(self, x):
        return 1 / (1 + math.exp(-x[0]))

    def grad(self):
        return [self.forward() * (1 - self.forward())]
The Operation superclass is not useful yet, although we will need it later to find the inputs to our model more easily.
Notice how often function gradients require the values of their children, so we need to call the child's forward() method. We'll talk more about this in a moment.
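The gradient formulas used by Multiply and Sigmoid can be sanity-checked in isolation. The sketch below re-implements them as plain functions and compares each against a central finite-difference estimate. The function names are assumptions for this example, not part of the framework.

```python
import math
from typing import List


def product_grads(values: List[float]) -> List[float]:
    """d(prod(values))/d(values[i]) is the product of all the other values."""
    return [math.prod(v for j, v in enumerate(values) if j != i)
            for i in range(len(values))]


def sigmoid(x: float) -> float:
    return 1 / (1 + math.exp(-x))


def sigmoid_grad(x: float) -> float:
    """The sigmoid derivative expressed through its own output: s * (1 - s)."""
    s = sigmoid(x)
    return s * (1 - s)


def finite_difference(f, x: float, h: float = 1e-6) -> float:
    """Central finite-difference estimate of f'(x)."""
    return (f(x + h) - f(x - h)) / (2 * h)


values = [2.0, -3.0, 0.5]
print(product_grads(values))   # [-1.5, 1.0, -6.0]

# Numerical check for the second argument of the product (should be close to 1.0).
print(finite_difference(lambda v: math.prod([2.0, v, 0.5]), -3.0))

# Both values should be close to 0.25.
print(sigmoid_grad(0.0), finite_difference(sigmoid, 0.0))
```

Both formulas need the values flowing through the node, which is why the framework's grad() methods call forward().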
Defining a neural network in our framework is a bit verbose, but it is very similar to building a tree. Here, for example, is code for a simple linear model in our framework:
linear_classifier = Add([
    Multiply([
        Parameter(),
        Input()
    ]),
    Parameter()
])
Using our models
To run a prediction with our model, we must first fill in the inputs in our tree and then call forward() on the root. However, to fill in the inputs, we must first find them, so we add the following method to our Operation class (we don't add this to our NeuralNetNode class since the Input type is not yet defined there):
    # Added to the Operation class
    def find_input_nodes(self) -> List[Input]:
        """Find all of the input nodes in the subtree rooted at this node"""
        input_nodes = []
        for child in self.children:
            if isinstance(child, Input):
                input_nodes.append(child)
            elif isinstance(child, Operation):
                input_nodes.extend(child.find_input_nodes())
        return input_nodes
Now we can add the predict() method to the Operation class:
    # Added to the Operation class
    def predict(self, inputs: List[float]) -> float:
        """Evaluate the network on the given inputs"""
        input_nodes = self.find_input_nodes()
        assert len(input_nodes) == len(inputs)
        for input_node, value in zip(input_nodes, inputs):
            input_node.value = value
        return self.forward()
Exercise 2: The current way we implement predict() is somewhat inefficient, since we need to traverse the tree to find all the inputs every time we run predict(). Write a compile() method that caches the operation's inputs when it runs.
Training our models is now very simple:
from typing import Callable, Tuple


def train_model(
    model: Operation,
    loss_fn: Callable[[float, float], float],
    loss_grad_fn: Callable[[float, float], float],
    data: List[Tuple[List[float], float]],
    epochs: int = 1000,
    print_every: int = 100
):
    """Train the given model on the given data"""
    for epoch in range(epochs):
        total_loss = 0.0
        for x, y in data:
            prediction = model.predict(x)
            total_loss += loss_fn(y, prediction)
            # Parameters are updated during the backward pass
            model.backward(loss_grad_fn(y, prediction))
        if epoch % print_every == 0:
            print(f'Epoch {epoch}: loss={total_loss/len(data)}')
Here, for example, is how we would train a linear model to convert Fahrenheit to Celsius using our framework:
def mse_loss(y_true: float, y_pred: float) -> float:
    return (y_true - y_pred) ** 2


def mse_loss_grad(y_true: float, y_pred: float) -> float:
    # Derivative of the MSE with respect to y_pred
    return -2 * (y_true - y_pred)


def fahrenheit_to_celsius(x: float) -> float:
    return (x - 32) * 5 / 9


def generate_f_to_c_data() -> List[Tuple[List[float], float]]:
    data = []
    for _ in range(1000):
        f = random.uniform(-1, 1)
        data.append(([f], fahrenheit_to_celsius(f)))
    return data


linear_classifier = Add([
    Multiply([
        Parameter(),
        Input()
    ]),
    Parameter()
])
train_model(linear_classifier, mse_loss, mse_loss_grad, generate_f_to_c_data())
After running this, we get
print(linear_classifier)
print(linear_classifier.predict([32]))
>> Add(children=[Multiply(children=[Parameter(0.5555555555555556), Input(0.8930639016107234)]), Parameter(-17.777777777777782)])
>> -1.7763568394002505e-14
This correctly corresponds to a linear model with weight 0.56 and bias -17.78, which is the Fahrenheit-to-Celsius formula.
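To see where these numbers come from, expand the conversion formula: (x − 32) · 5/9 = (5/9)·x − 160/9. A quick check, with `fahrenheit_to_celsius` redefined so the snippet is self-contained (the `weight` and `bias` variable names are just for this example):

```python
def fahrenheit_to_celsius(x: float) -> float:
    return (x - 32) * 5 / 9


# For a line y = weight * x + bias, the bias is y(0) and the weight is y(1) - y(0).
bias = fahrenheit_to_celsius(0.0)
weight = fahrenheit_to_celsius(1.0) - bias

print(round(weight, 2), round(bias, 2))   # 0.56 -17.78
```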
Of course, we can also train much more complex models. For example, here is one that predicts whether a point (x, y) is above or below the line y = x:
def bce_loss(y_true: float, y_pred: float, eps: float = 0.00000001) -> float:
    # Clamp the prediction away from 0 and 1 to keep the logarithms finite
    y_pred = min(max(y_pred, eps), 1 - eps)
    return -y_true * math.log(y_pred) - (1 - y_true) * math.log(1 - y_pred)


def bce_loss_grad(y_true: float, y_pred: float, eps: float = 0.00000001) -> float:
    y_pred = min(max(y_pred, eps), 1 - eps)
    return (y_pred - y_true) / (y_pred * (1 - y_pred))


def generate_binary_data():
    data = []
    for _ in range(1000):
        x = random.uniform(-1, 1)
        y = random.uniform(-1, 1)
        data.append(([x, y], 1 if y > x else 0))
    return data


model_binary = Sigmoid(
    [
        Add(
            [
                Multiply(
                    [
                        Parameter(),
                        ReLU(
                            [
                                Add(
                                    [
                                        Multiply(
                                            [
                                                Parameter(),
                                                Input()
                                            ]
                                        ),
                                        Multiply(
                                            [
                                                Parameter(),
                                                Input()
                                            ]
                                        ),
                                        Parameter()
                                    ]
                                )
                            ]
                        )
                    ]
                ),
                Parameter()
            ]
        )
    ]
)
train_model(model_binary, bce_loss, bce_loss_grad, generate_binary_data())
The resulting predictions are reasonable:
print(model_binary.predict([1, 0]))
print(model_binary.predict([0, 1]))
print(model_binary.predict([0, 1000]))
print(model_binary.predict([-5, 3]))
print(model_binary.predict([0, 0]))
>> 3.7310797619230176e-66
>> 0.9997781079343139
>> 0.9997781079343139
>> 0.9997781079343139
>> 0.23791579184662365
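One detail worth noting in this model: the BCE gradient divides by y_pred · (1 − y_pred), while the Sigmoid node's gradient multiplies by exactly that quantity. The derivative that reaches the node below the sigmoid is therefore simply y_pred − y_true, as long as y_pred is not clamped by eps. The following sketch checks this numerically. It redefines `bce_loss_grad` from the listing above and uses a hypothetical `sigmoid` helper so that it runs on its own.

```python
import math


def sigmoid(z: float) -> float:
    return 1 / (1 + math.exp(-z))


def bce_loss_grad(y_true: float, y_pred: float, eps: float = 1e-8) -> float:
    y_pred = min(max(y_pred, eps), 1 - eps)
    return (y_pred - y_true) / (y_pred * (1 - y_pred))


z, y_true = 0.7, 1.0                      # arbitrary pre-sigmoid value and label
y_pred = sigmoid(z)

# Loss derivative multiplied by the sigmoid's local derivative s * (1 - s).
through_sigmoid = bce_loss_grad(y_true, y_pred) * y_pred * (1 - y_pred)

print(through_sigmoid, y_pred - y_true)   # the two values agree up to rounding
```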
Although training has a reasonable execution time, it is somewhat slower than we would expect. This is because we call forward() many times during the backward() call, recomputing the same node values over and over. To address this, try the following exercise:
Exercise 3: Add caching to our network. That is, a call to forward() should return the cached value from the previous call to forward() if and only if the inputs have not changed since that call. Be sure to run forward() again if the inputs have changed.
And that's it! We now have a functional neural network framework in which we can train many interesting models (though not networks in which a node feeds multiple other nodes; this is not too difficult to add, see the note in the chain rule discussion), even if it is a bit verbose. If you want to improve it, try some of the following:
Exercise 4: If you think about it, the more “complex” nodes in our network (e.g. linear layers) are actually just “macros” in some sense. That is, if we had a neural network tree that looked like the following:
what it is actually doing is this:
In other words, Linear(input) is actually just a macro for a tree that contains |input| + 1 parameters, the first |input| of which are weights in multiplications and the last of which is a bias. Every time we see Linear(input), we can replace it with an equivalent tree composed only of elementary operations.
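In formula terms, a Linear node over n inputs computes a weighted sum plus a bias, which is what the expanded tree of Multiply and Add nodes evaluates. A minimal sketch as a plain function follows; the name `linear` and its argument names are assumptions for illustration, and this is not the Macro class the exercise asks for.

```python
from typing import List


def linear(inputs: List[float], weights: List[float], bias: float) -> float:
    """Weighted sum of the inputs plus a bias: len(inputs) + 1 parameters in total."""
    assert len(inputs) == len(weights)
    return sum(w * x for w, x in zip(weights, inputs)) + bias


print(linear([1.0, 2.0], weights=[0.5, -1.0], bias=3.0))   # 0.5 - 2.0 + 3.0 = 1.5
```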
For this exercise, your job is to implement the Macro class. The class should be an Operation that is recursively replaced with elementary operations.
Note: This step can be done at any time, although it is probably easier to add a compile() method to the Operation class that you call before training (or add it to your existing method from Exercise 2). Of course, we can also implement more complex nodes in other (perhaps more efficient) ways, but it's still a good exercise.
Exercise 5: Although we never actually need internal nodes to produce anything more than a single number as output, sometimes it's nice for the root of our tree (i.e., our output layer) to produce something else (for example, a list of numbers in the case of a softmax). Implement the Output class and allow it to produce a List[float] instead of just a float. As a bonus, try implementing a Softmax output.
Note: There are a few ways to do this. You can have Output extend Operation and then modify the op() method of the NeuralNetNode class to return a List[float] instead of just a float. Alternatively, you can create a new Node superclass that both Output and Operation extend. This is probably easier.
Note also that although these outputs can produce lists, they will still only receive a single derivative from the loss function; the loss function will simply take a list of floats instead of a float (e.g. categorical cross-entropy loss).
Exercise 6: Remember that earlier in the article we said that neural networks are just mathematical functions composed of elementary operations? Add a funcify() method to the NeuralNetNode class that converts it to a function written in human-readable notation (add parentheses as you wish). For example, the neural network Add([Parameter(0.1), Parameter(0.2)]) should collapse to “0.1 + 0.2” (or “(0.1 + 0.2)”).
Note: For this to work, the inputs must have names. If you did Exercise 2, name your inputs in the compile() function. Otherwise, you'll have to find another way to name your inputs; writing a compile() function is still probably the easiest way.
Exercise 7: Modify our framework to allow nodes to have multiple parents. I will solve this in part 2.
That is all for now! If you want to check out the code, you can look at this Google Colab notebook, which has everything except the solutions to the exercises (only Exercise 6 is solved there, although I may add the others in part 2).
Contact me at [email protected] for any questions.
Unless otherwise specified, all images are the author's.