I set aside some time to figure out how to build neural networks from scratch in Clojure, without external libraries.
After a couple of dependency-free versions, I ended up adding the neanderthal library to do faster matrix math. The different versions I wrote along the way are on GitHub, in case they’re helpful for anybody else who wants to do this in Clojure.
Neural networks are surprisingly easy to get started with. There’s significantly more “magic” inside a good concurrent queue implementation, for example, than inside a basic neural network to recognize handwritten digits.
Here’s the “hello world” of neural networks, a widget that recognizes a hand-drawn digit:
See this widget at github.com/matthewdowney/clojure-neural-networks-from-scratch/tree/main/mnist-scittle
And here’s the code for the pixel array -> digit computation^{1}:
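;; sigmoid activation: squashes any number into the range (0, 1)
(defn sigmoid [n] (/ 1.0 (+ 1.0 (Math/exp (- n)))))

;; one layer: each neuron adds its bias to the weighted sum of the inputs,
;; then applies the sigmoid
(defn feedforward [inputs weights biases]
  (for [[b ws :as _neuron] (map vector biases weights)]
    (let [weighted-input (reduce + (map * inputs ws))]
      (sigmoid (+ b weighted-input)))))

;; index of the largest number, i.e. the most strongly activated output neuron
(defn argmax [numbers]
  (let [idx+val (map-indexed vector numbers)]
    (first (apply max-key second idx+val))))

;; w0, b0, w1, b1 are the trained weights and biases, defined elsewhere
(defn digit [pixels]
  (-> pixels (feedforward w0 b0) (feedforward w1 b1) argmax))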
It’s striking that such a complicated task works without intricate code or underlying black-box libraries.^{2} I felt kind of dumb for not having known this already!
The three most helpful resources for me were:
3Blue1Brown’s video series on neural networks, with visualizations and intuitive explanations. Good initial context.
Michael Nielsen’s neural networks and deep learning tutorial, which uses Python and numpy.
Andrej Karpathy’s intro to neural networks and backpropagation, which is pure Python (no numpy), and was kind of a lifesaver for understanding backpropagation.
In retrospect, to get started, I’d recommend reading the first part of Nielsen’s tutorial, skipping to the Andrej Karpathy video, and then solving MNIST from scratch using those two things as references, before coming back to the rest of Nielsen’s material.
I also went through Dragan Djuric’s impressive and erudite deep learning from scratch to GPU tutorial series, but I can’t say I’d recommend it as an introduction to neural networks.^{3}
I’m glad I decided to start from scratch without any external libraries, including ones for matrix math.
I do, however, wish I’d watched Andrej Karpathy’s video before getting so deep into Nielsen’s tutorial, especially because of the backprop calculus^{4}, which I struggled with for a while. Karpathy’s REPL-based, algorithmic explanation was much more intuitive for me than the formal mathematical version.
My approach was to:
The training time for one epoch of MNIST was 400 seconds in the first two versions, 5 seconds in the third (on par with the Python sample), and down to 1 second in the final version.
I’m glad I broke it down like this. Would do again.
Before implementing the backprop algorithm, I built some unit tests for calculating the weight and bias gradients given starting weights and biases and some training data, and this turned out to be enormously helpful. I used Nielsen’s sample Python code to generate the test vectors.
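A complementary way to get expected gradients without a reference implementation is a finite-difference gradient check. This is a different technique from generating test vectors with Nielsen’s code. You nudge each weight and bias up and down by a small ε, measure how the cost changes, and compare that slope against what backprop returns. The sketch below is pure Python and rests on a few assumptions: sigmoid activations, Nielsen’s quadratic cost C = ½‖a − y‖², and weights stored as lists of rows (one row per neuron). The tiny 3-4-2 network, its seeded random parameters, and every function name are hypothetical; none of them come from the author’s repository.

```python
import math
import random

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def sigmoid_prime(z):
    s = sigmoid(z)
    return s * (1.0 - s)

def forward(weights, biases, x):
    """Return (activations, zs); activations[0] is the input, so it is one longer."""
    activations, zs = [x], []
    for W, b in zip(weights, biases):
        z = [sum(w * a for w, a in zip(row, activations[-1])) + bj
             for row, bj in zip(W, b)]
        zs.append(z)
        activations.append([sigmoid(zj) for zj in z])
    return activations, zs

def cost(weights, biases, x, y):
    """Quadratic cost 0.5 * ||a - y||^2 for a single training example."""
    output = forward(weights, biases, x)[0][-1]
    return 0.5 * sum((a - yj) ** 2 for a, yj in zip(output, y))

def backprop(weights, biases, x, y):
    """Analytic gradients of `cost`, shaped like (biases, weights)."""
    activations, zs = forward(weights, biases, x)
    # last layer's error: dC/da * sigmoid'(z), with dC/da = a - y
    error = [(a - yj) * sigmoid_prime(z)
             for a, yj, z in zip(activations[-1], y, zs[-1])]
    nabla_b, nabla_w = [error], []
    for layer in reversed(range(len(weights))):  # last layer first, 0-indexed
        prev_a = activations[layer]  # the input this layer received
        # weight gradient is the outer product error[j] * prev_a[k]
        nabla_w.insert(0, [[e * a for a in prev_a] for e in error])
        if layer > 0:
            W = weights[layer]
            # previous layer's error: (W^T . error) * sigmoid'(z_prev)
            error = [sum(W[j][k] * error[j] for j in range(len(error)))
                     * sigmoid_prime(zs[layer - 1][k])
                     for k in range(len(prev_a))]
            nabla_b.insert(0, error)
    return nabla_b, nabla_w

def numerical_gradients(weights, biases, x, y, eps=1e-6):
    """Central differences (C(p + eps) - C(p - eps)) / (2 eps) for every parameter."""
    def partial(container, idx):
        saved = container[idx]
        container[idx] = saved + eps
        up = cost(weights, biases, x, y)
        container[idx] = saved - eps
        down = cost(weights, biases, x, y)
        container[idx] = saved  # restore the original value
        return (up - down) / (2 * eps)
    grad_b = [[partial(b, j) for j in range(len(b))] for b in biases]
    grad_w = [[[partial(row, k) for k in range(len(row))] for row in W]
              for W in weights]
    return grad_b, grad_w

def max_abs_diff(a, b):
    """Largest elementwise difference between two equally shaped nested lists."""
    if isinstance(a, list):
        assert len(a) == len(b), "shape mismatch"
        return max(max_abs_diff(p, q) for p, q in zip(a, b))
    return abs(a - b)

# hypothetical tiny network: 3 inputs -> 4 hidden -> 2 outputs, seeded random params
rng = random.Random(0)
sizes = [3, 4, 2]
weights = [[[rng.gauss(0, 1) for _ in range(n_in)] for _ in range(n_out)]
           for n_in, n_out in zip(sizes[:-1], sizes[1:])]
biases = [[rng.gauss(0, 1) for _ in range(n)] for n in sizes[1:]]
x, y = [0.5, -0.2, 0.9], [1.0, 0.0]

def test_backprop_matches_finite_differences():
    nabla_b, nabla_w = backprop(weights, biases, x, y)
    grad_b, grad_w = numerical_gradients(weights, biases, x, y)
    assert max_abs_diff(nabla_b, grad_b) < 1e-7
    assert max_abs_diff(nabla_w, grad_w) < 1e-7

if __name__ == "__main__":
    test_backprop_matches_finite_differences()
    print("backprop gradients match finite differences")
```

With ε = 1e-6, the two sets of gradients should agree far more closely than the 1e-7 tolerance. An indexing mistake in the backward pass typically produces either an IndexError or errors many orders of magnitude larger, so the check catches it.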
Finally, invoking numpy via libpython-clj at the REPL was useful for figuring out the equivalent neanderthal expressions.
A neuron in a neural network is just a function [inputs] -> scalar output, where the output is the weighted sum of the inputs (each input multiplied by one of the neuron’s weights), plus the neuron’s bias, passed through an activation function.
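Written out as a formula (the notation here is an assumption, chosen to match the sigmoid in the digit code), a single neuron with inputs $x_i$, weights $w_i$, and bias $b$ computes the first line below. Stacking every neuron’s weights in layer $l$ as the rows of a matrix $W^{(l)}$ gives the layer form on the second line, with $\sigma$ applied elementwise. That second line is what `feedforward` computes, one neuron at a time.

$$
a = \sigma\!\left(\sum_i w_i x_i + b\right), \qquad \sigma(z) = \frac{1}{1 + e^{-z}}
$$

$$
a^{(l)} = \sigma\!\left(W^{(l)} a^{(l-1)} + b^{(l)}\right)
$$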
Much of the magic inside neural network libraries has less to do with cleverer algorithms than with vectorized SIMD instructions and/or being parsimonious about GPU memory usage and about communication back and forth between the GPU and main memory.
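Neural networks can, in theory, approximate any continuous function as closely as you like (the universal approximation theorem). And a more readily believable fact: with linear activation functions, no matter how many layers you add to a neural network, it simplifies to a single linear (strictly speaking, affine, because of the biases) transformation.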
But the activation function doesn’t need to be especially squiggly: ReLU is just max(0, x), and it’s widely used.
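The point about linear activations can be made concrete with a small pure-Python sketch. The parameters are hand-picked toy values and every name is hypothetical. Two stacked layers with no activation function produce exactly the same outputs as one affine layer with weights $W_2 W_1$ and bias $W_2 b_1 + b_2$, because $W_2(W_1 x + b_1) + b_2 = (W_2 W_1)x + (W_2 b_1 + b_2)$. Putting ReLU on the hidden layer breaks this. The ReLU version fails a midpoint test that every affine function passes.

```python
# Pure-Python sketch (no numpy); matrices are lists of rows.

def matvec(W, x):
    """W @ x for a list-of-rows matrix W."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def matmul(A, B):
    """A @ B for list-of-rows matrices."""
    cols = list(zip(*B))
    return [[sum(a * b for a, b in zip(row, col)) for col in cols] for row in A]

def vadd(u, v):
    return [a + b for a, b in zip(u, v)]

def relu(v):
    return [max(0.0, a) for a in v]

# hypothetical toy parameters: 2 inputs -> 3 hidden -> 2 outputs
W1 = [[1.0, -2.0], [0.5, 1.0], [-1.0, 0.0]]
b1 = [0.0, -1.0, 2.0]
W2 = [[1.0, 0.0, -1.0], [2.0, 1.0, 0.5]]
b2 = [0.5, -0.5]

def two_linear_layers(x):
    """Two layers with the identity activation, i.e. no activation at all."""
    return vadd(matvec(W2, vadd(matvec(W1, x), b1)), b2)

# the single affine layer they collapse into: W2 (W1 x + b1) + b2 = W x + b
W = matmul(W2, W1)
b = vadd(matvec(W2, b1), b2)

def one_affine_layer(x):
    return vadd(matvec(W, x), b)

def two_relu_layers(x):
    """Same parameters, with ReLU applied to the hidden layer."""
    return vadd(matvec(W2, relu(vadd(matvec(W1, x), b1))), b2)

def midpoint_gap(f, x, y):
    """Zero for any affine f, since f((x + y) / 2) == (f(x) + f(y)) / 2."""
    mid = [(a + c) / 2 for a, c in zip(x, y)]
    avg = [(a + c) / 2 for a, c in zip(f(x), f(y))]
    return max(abs(p - q) for p, q in zip(f(mid), avg))

print(two_linear_layers([1.0, 1.0]), one_affine_layer([1.0, 1.0]))  # both [-1.5, -1.5]
print(midpoint_gap(two_linear_layers, [1.0, 1.0], [-1.0, -1.0]))  # 0.0
print(midpoint_gap(two_relu_layers, [1.0, 1.0], [-1.0, -1.0]))  # 1.25
```

The midpoint test is just one easy witness that the ReLU network is no longer affine. Any other pair of points that straddles a ReLU "kink" would show the same thing.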
^{1} Since I used Scittle to embed the Clojure code in this page, you can browse the source file directly.
^{2} And sure, this is a rudimentary network architecture, and there’s a sense in which “the real program” is the magic numbers in the w0, w1, b0, and b1 weight and bias vectors, but it turns out that you can also write the training code from scratch to find those numbers without too much trouble.
^{3} It is definitely an introduction to memory reuse tricks and GPU programming for someone who already has a strong grasp of linear algebra and wants to reinforce or deepen an existing understanding of neural networks and the relevant performance optimizations, which are crucial for deep learning in practice but a lot to take in at first.
^{4} Also, on the indexes in Nielsen’s neural network backpropagation algorithm: the style in the Python sample starting on line 101 was hard for me to parse, with negative indexes and iterations using three indexes each. I found it helpful to rewrite it like this:
# backward pass only; `activations` and `zs` come from the forward pass above it

# compute the difference between the output and the expected output
# this is the last layer's error
error = self.cost_derivative(activations[-1], y) * sigmoid_prime(zs[-1])

# weight and bias gradient vectors, same shape as the network layers
nabla_w = []
nabla_b = [error]

# the activations list has inputs prepended, so it's longer by 1
activations_for_layer = lambda layer_idx: activations[layer_idx+1]

# iterate backwards through the layers (0-indexed, last layer first)
for layer_idx in reversed(range(len(self.weights))):
    # compute a change in weights using the previous layer's activation
    # (for layer 0 this is activations[0], the network's input)
    prev_activation = activations_for_layer(layer_idx-1)
    nabla_w.insert(0, np.dot(error, prev_activation.transpose()))

    # if there is a previous layer, compute its error
    if layer_idx > 0:
        this_layer_weights = self.weights[layer_idx]
        prev_layer_weighted_inputs = zs[layer_idx-1]
        sp = sigmoid_prime(prev_layer_weighted_inputs)
        error = np.dot(this_layer_weights.transpose(), error) * sp
        nabla_b.insert(0, error)

return (nabla_b, nabla_w)