Notes on neural networks from scratch in Clojure

May 31, 2023   

I set aside some time to figure out how to build neural networks from scratch in Clojure, without external libraries.

After a couple of dependency-free versions, I ended up adding the neanderthal library to do faster matrix math. The different versions I wrote along the way are on GitHub, in case they’re helpful for anybody else who wants to do this in Clojure.

First impressions and hello world

Neural networks are surprisingly easy to get started with. There’s significantly more “magic” inside a good concurrent queue implementation, for example, than inside a basic neural network to recognize handwritten digits.

For example, here’s the “hello world” of neural networks, a widget to recognize a hand-drawn digit:



See this widget at github.com/matthewdowney/clojure-neural-networks-from-scratch/tree/main/mnist-scittle



And here’s the code for the pixel array -> digit computation1:

(defn sigmoid [n] (/ 1.0 (+ 1.0 (Math/exp (- n)))))

(defn feedforward [inputs weights biases]
  (for [[b ws :as _neuron] (map vector biases weights)]
    (let [weighted-input (reduce + (map * inputs ws))]
      (sigmoid (+ b weighted-input)))))

(defn argmax [numbers]
  (let [idx+val (map-indexed vector numbers)]
    (first (apply max-key second idx+val))))

(defn digit [pixels]
  (-> pixels (feedforward w0 b0) (feedforward w1 b1) argmax))
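
For a sense of the data shapes, here’s a toy call to feedforward with made-up numbers: weights is a sequence of per-neuron weight vectors, and biases holds one scalar per neuron. digit then just chains two such layers over the pixel inputs and takes the argmax of the final layer’s ten outputs.

(feedforward [0.5 0.1 0.9]                   ; three inputs
             [[0.2 -0.4 0.6] [1.0 1.0 -1.0]] ; one weight vector per neuron
             [0.0 -0.5])                     ; one bias per neuron
;; => two sigmoid outputs, one per neuron, roughly (0.646 0.310)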

It’s striking that such a complicated task can be solved without intricate code or underlying black-box libraries.2 I felt kind of dumb for not having known this already!

Resources

The three most helpful resources for me were:

In retrospect, to get started, I’d recommend reading the first part of Nielsen’s tutorial, skipping to the Andrej Karpathy video, and then solving MNIST from scratch using those two things as references, before coming back to the rest of Nielsen’s material.

I also went through Dragan Djuric’s impressive and erudite deep learning from scratch to GPU tutorial series, but I can’t say I’d recommend it as an introduction to neural networks.3

Approach in retrospect

I’m glad I decided to start from scratch without any external libraries, including ones for matrix math.

I do, however, wish I’d watched Andrej Karpathy’s video before getting so deep into Nielsen’s tutorial, especially because of the backprop calculus4, which I struggled with for a while. Karpathy’s REPL-based, algorithmic explanation was much more intuitive for me than the formal mathematical version.

My approach was to:

  1. First, build a neural network for the MNIST problem with no matrix math (nn_01.clj),
  2. Then, create a version with handwritten matrix math,
  3. Eventually, add the neanderthal library for matrix math in a third version,
  4. Finally, enhance performance with batch training in the fourth version.

The training time for one epoch of MNIST was 400 seconds in the first two versions, 5 seconds in the third (on par with the Python sample), and down to 1 second in the final version.

I’m glad I broke it down like this. Would do again.

Before implementing the backprop algorithm, I built some unit tests for calculating the weight and bias gradients given starting weights and biases and some training data, and this turned out to be enormously helpful. I used Nielsen’s sample Python code to generate the test vectors.
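
For illustration, the tests had roughly this shape. This is only a sketch: the backprop name, its argument order, and its return shape are assumptions rather than the actual code in the repo, and the zeroed-out expected values are placeholders for gradients dumped from Nielsen’s Python sample run on the same starting weights, biases, and training example.

(ns nn-test
  ;; nn/backprop stands in for the gradient function in the project; the name
  ;; and signature are assumptions for this sketch
  (:require [clojure.test :refer [deftest is]]
            [nn :refer [backprop]]))

(defn approx= [tolerance a b]
  (< (Math/abs (- (double a) (double b))) tolerance))

(defn gradients-approx= [expected actual]
  (every? true? (map (partial approx= 1e-6)
                     (flatten expected)
                     (flatten actual))))

(deftest backprop-gradients
  (let [weights  [[[0.1 0.2] [0.3 0.4]]] ; starting weights for a tiny 2->2 layer
        biases   [[0.5 -0.5]]
        input    [1.0 0.0]
        target   [0.0 1.0]
        ;; placeholders: paste in the gradients computed by the Python sample
        expected-nabla-w [[[0.0 0.0] [0.0 0.0]]]
        expected-nabla-b [[0.0 0.0]]
        {:keys [nabla-w nabla-b]} (backprop weights biases input target)]
    (is (gradients-approx= expected-nabla-w nabla-w))
    (is (gradients-approx= expected-nabla-b nabla-b))))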

Finally, invoking numpy via libpython-clj at the REPL was useful for figuring out the equivalent neanderthal expressions.
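
As a rough illustration of what that looked like (this assumes libpython-clj2 on the classpath and a local numpy install; the exact setup is not shown in this post), a side-by-side REPL comparison went something like:

;; numpy, via libpython-clj
(require '[libpython-clj2.require :refer [require-python]])
(require-python '[numpy :as np])

(np/dot (np/array [[1 3] [2 4]]) (np/array [1 1]))
;; => array([4, 6])

;; the equivalent neanderthal expression
(require '[uncomplicate.neanderthal.core :refer [mv]]
         '[uncomplicate.neanderthal.native :refer [dge dv]])

;; dge fills column-major, so [1 2 3 4] becomes the matrix [[1 3] [2 4]]
(mv (dge 2 2 [1 2 3 4]) (dv 1 1))
;; => a vector with entries [4.0 6.0]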

Basic things that I should have already known but didn’t

  • A neuron in a neural network is just a function [inputs] -> scalar output, where the output is a linear combination of the inputs and the neuron’s weights, summed together with the neuron’s bias, and passed to an activation function.

  • Much of the magic inside of neural network libraries has less to do with cleverer algorithms and more to do with vectorized SIMD instructions and/or being parsimonious with GPU memory usage and communication back and forth with main memory.

  • Neural networks can, in theory, approximate any function. And a more readily believable fact: with linear activation functions, no matter how many layers you add to a neural network, the whole thing collapses into a single linear transformation (see the sketch after this list).

  • But, the activation function is not necessarily all that squiggly — ReLU is just max(0, x) and it’s widely used.
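
To make the linear-activation point concrete, here’s a small self-contained sketch (the names and numbers are made up for illustration): a variant of feedforward with the identity activation, showing that two stacked linear layers give exactly the same outputs as one collapsed layer.

;; feedforward from above, minus the sigmoid
(defn linear-feedforward [inputs weights biases]
  (for [[b ws] (map vector biases weights)]
    (+ b (reduce + (map * inputs ws)))))

(def layer1-w [[1.0 2.0] [3.0 4.0]])
(def layer1-b [0.5 -0.5])
(def layer2-w [[1.0 -1.0]])
(def layer2-b [2.0])

;; collapse the two layers into one: W2·W1 and W2·b1 + b2
(def collapsed-w [[(+ (* 1.0 1.0) (* -1.0 3.0))    ; -2.0
                   (+ (* 1.0 2.0) (* -1.0 4.0))]]) ; -2.0
(def collapsed-b [(+ (* 1.0 0.5) (* -1.0 -0.5) 2.0)]) ; 3.0

;; same output either way, for any input
(-> [1.0 1.0]
    (linear-feedforward layer1-w layer1-b)
    (linear-feedforward layer2-w layer2-b))
;; => (-1.0)

(linear-feedforward [1.0 1.0] collapsed-w collapsed-b)
;; => (-1.0)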

  1. Since I used Scittle to embed the Clojure code in this page, you can browse the source file directly. 

  2. And sure, this is a rudimentary network architecture, and there’s a sense in which “the real program” is the magic weights and biases numbers in the w0, w1, b0, and b1 vectors, but it turns out that you can also write the training code from scratch to find those vectors without too much trouble. 

  3. It is definitely an introduction to memory reuse tricks and GPU programming, for someone who already has a strong grasp of linear algebra, and wants to reinforce or deepen existing understanding of neural networks and relevant performance optimization. Which is crucial for deep learning in practice, but is a lot to take in at first. 

  4. Also, on the indexes in Nielsen’s neural network backpropagation algorithm: the style in the Python sample starting on line 101 was hard for me to parse, with its negative indexes and three different indexes per loop iteration. I found it helpful to rewrite it like this:

    # compute the difference between the output and the expected output
    # this is the last layer's error
    error = self.cost_derivative(activations[-1], y) * sigmoid_prime(zs[-1])
    
    # weight and bias gradient vectors, same shape as the network layers
    nabla_w = []
    nabla_b = [error]
    
    # the activations list has inputs prepended, so it's longer by 1
    activations_for_layer = lambda layer_idx: activations[layer_idx+1]
    
    # iterate backwards through the layers (0-indexed, last layer first)
    for layer_idx in xrange(len(self.weights) - 1, -1, -1):
    
        # compute a change in weights using the previous layer's activation
        prev_activation = activations_for_layer(layer_idx-1)
        nabla_w.insert(0, np.dot(error, prev_activation.transpose()))
    
        # if there is a previous layer, compute its error
        if layer_idx > 0:
            this_layer_weights = self.weights[layer_idx]
            prev_layer_weighted_inputs = zs[layer_idx-1]
    
            sp = sigmoid_prime(prev_layer_weighted_inputs)
            error = np.dot(this_layer_weights.transpose(), error) * sp
            nabla_b.insert(0, error)
    
    return (nabla_b, nabla_w)