
MNIST from Scratch

Dec 28, 2023

So I had this idea that I should really understand what PyTorch is doing under the hood. Like actually understand it, not just know that .backward() computes gradients somehow.

The plan was to write the exact same neural network twice. Once with PyTorch like a normal person, once with pure NumPy where I compute every single gradient by hand. See if I could match the performance.

Spoiler: I beat PyTorch.


The Data

Flatten the 28x28 images into 784-pixel vectors, split into train/val, shuffle everything.

!wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
!wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
!wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
!wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
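
The wget calls leave four gzip files on disk. The snippets below assume they have been read into byte strings; a minimal sketch, using the filenames from the downloads above:

import gzip
import numpy as np

# read the raw gzipped IDX files into byte strings for np.frombuffer below
train_images = open('train-images-idx3-ubyte.gz', 'rb').read()
train_labels = open('train-labels-idx1-ubyte.gz', 'rb').read()
test_images = open('t10k-images-idx3-ubyte.gz', 'rb').read()
test_labels = open('t10k-labels-idx1-ubyte.gz', 'rb').read()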

# IDX format: skip the 16-byte header on image files, 8-byte header on label files
X = np.frombuffer(gzip.decompress(train_images), dtype=np.uint8)[16:].reshape(-1, 784)
X_test = np.frombuffer(gzip.decompress(test_images), dtype=np.uint8)[16:].reshape(-1, 784)
Y = np.frombuffer(gzip.decompress(train_labels), dtype=np.uint8)[8:]
Y_test = np.frombuffer(gzip.decompress(test_labels), dtype=np.uint8)[8:]
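
Both training loops below pull minibatches from a get_batch helper that never appears in the post. A plausible reconstruction of the shuffle / train-val split / batching described above; the 90/10 split, the batch size, and the seed are all my assumptions:

rng = np.random.default_rng(1337)           # seed chosen arbitrarily
perm = rng.permutation(len(X))              # shuffle everything
n_val = len(X) // 10                        # assumed 90/10 train/val split
X_train, Y_train = X[perm[n_val:]], Y[perm[n_val:]]
X_val, Y_val = X[perm[:n_val]], Y[perm[:n_val]]

def get_batch(split, batch_size=128):       # batch size is a guess
    Xs, Ys = (X_train, Y_train) if split == 'train' else (X_val, Y_val)
    idx = rng.integers(0, len(Xs), batch_size)
    return Xs[idx], Ys[idx].astype(np.int64)   # int64 labels keep torch's cross-entropy happy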

The Easy Way

Built the most basic 2-layer network you can imagine:

class BroNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(28 * 28, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, 10),
        )

    def forward(self, x, targets):
        # the training loop expects (logits, loss) back from the model
        logits = self.block(x)
        loss = nn.functional.cross_entropy(logits, targets)
        return logits, loss

784 inputs → 128 hidden → 10 outputs. ReLU in the middle. That’s it.
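
The loop below also leans on a few names that never show up in the snippet (n_hidden, max_iters, optim; get_batch is sketched back in the data section). A minimal setup, where the learning rate is my guess rather than anything from the post:

n_hidden = 128                       # the 784 -> 128 -> 10 shape above
max_iters = 10000                    # "10,000 iterations with Adam"
model = BroNet()
optim = torch.optim.Adam(model.parameters(), lr=1e-3)   # lr assumed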

for iter in (t := trange(max_iters)):
    xb, yb = get_batch('train')
    xb, yb = torch.tensor(xb).float(), torch.tensor(yb)
    pred, loss = model(xb, yb)
    # batch accuracy, just for monitoring
    cat = torch.argmax(pred, dim=1)
    accuracy = (cat == yb).float().mean()
    optim.zero_grad(set_to_none=True)
    loss.backward()
    optim.step()

10,000 iterations with Adam. Got 96.4% test accuracy. Pretty standard result, nothing exciting.
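
For the record, that 96.4% comes from running the trained model over the held-out test set; a minimal sketch of such an evaluation (not the post's exact code):

with torch.no_grad():
    xt, yt = torch.tensor(X_test).float(), torch.tensor(Y_test).long()
    pred, _ = model(xt, yt)
    test_acc = (pred.argmax(dim=1) == yt).float().mean().item()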


The Hard Way

Now for the fun part. No PyTorch, no automatic differentiation. Just NumPy and the chain rule.

First, initialize the weights.

# He init for the first layer: scale by sqrt(2 / fan_in)
w1 = np.random.randn(784, 128).astype(np.float32) * (2 / 784) ** 0.5
b1 = np.random.randn(128).astype(np.float32) * 0.01
# small random weights and zero bias for the output layer
w2 = np.random.randn(128, 10).astype(np.float32) * 0.01
b2 = np.zeros(10, dtype=np.float32)

Forward pass is straight matmul.

xw1 = xb.dot(w1) + b1       # (batch, 784) @ (784, 128) -> (batch, 128)
act = np.maximum(0, xw1)    # ReLU
logits = act.dot(w2) + b2   # (batch, 128) @ (128, 10) -> (batch, 10)

Write softmax and cross-entropy to compute the loss.

counts = np.exp(logits)                             # exponentiate to get unnormalized "counts"
counts_sum = counts.sum(axis=1, keepdims=True)
counts_sum_inv = 1 / counts_sum
probs = counts * counts_sum_inv                     # softmax probabilities, each row sums to 1
llh = np.log(probs)                                 # log-probabilities
loss = -llh[np.arange(probs.shape[0]), yb].mean()   # mean negative log-likelihood of the correct class
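
One caveat: np.exp will overflow once the logits get large. The usual fix subtracts the per-row max before exponentiating, which leaves probs (and therefore the loss and its gradient) unchanged because softmax is shift-invariant:

# numerically safer drop-in for the first line above
counts = np.exp(logits - logits.max(axis=1, keepdims=True))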

This is where you would usually just call .backward() and move on with your life.


The backward pass is where you really learn what’s happening: start from the loss and work backwards.

Cross-entropy gradients are the most annoying.

# d(loss)/d(llh): only the picked entries matter, each weighted -1/batch_size
dllh = np.zeros_like(llh)
dllh[np.arange(probs.shape[0]), yb] = -1 / probs.shape[0]
# llh = log(probs), so d(llh)/d(probs) = 1/probs
dprobs = 1 / probs * dllh
# probs = counts * counts_sum_inv, with counts_sum_inv broadcast across each row
dcounts_sum_inv = (counts * dprobs).sum(1, keepdims=True)
dcounts = counts_sum_inv * dprobs
# counts_sum_inv = 1 / counts_sum
dcounts_sum = (-1 / counts_sum**2) * dcounts_sum_inv
# counts also feeds counts_sum, so its gradient picks up a second (broadcast) term
dcounts += dcounts_sum
# counts = exp(logits), whose derivative is itself
dlogits = counts * dcounts
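
All of that machinery collapses to the textbook result: for mean cross-entropy over a batch, the gradient with respect to the logits is just (probs - one_hot) / batch_size. A quick sanity check against the chain-rule version above:

# closed-form softmax + cross-entropy gradient
dlogits_ref = probs.copy()
dlogits_ref[np.arange(probs.shape[0]), yb] -= 1
dlogits_ref /= probs.shape[0]
assert np.allclose(dlogits, dlogits_ref, atol=1e-6)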

Then backprop through the second layer.

dact = dlogits.dot(w2.T)   # (batch, 10) @ (10, 128) -> (batch, 128)
dw2 = act.T.dot(dlogits)   # (128, batch) @ (batch, 10) -> (128, 10)
db2 = dlogits.sum(0)       # sum over the batch dimension

The ReLU derivative is the simplest: just zero out the gradient wherever the pre-activation was negative, i.e. wherever the ReLU output is zero.

dxw1 = (act>0).astype(np.float32) * dact  # ReLU derivative

Finally, gradients for the first layer:

dw1 = xb.T.dot(dxw1)   # (784, batch) @ (batch, 128) -> (784, 128)
db1 = dxw1.sum(0)      # sum over the batch dimension
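
A cheap way to convince yourself these formulas are right is a finite-difference check on a handful of entries. A sketch, assuming a hypothetical loss_fn that reruns the forward pass above on the same batch and returns the scalar loss:

def numerical_grad(param, i, j, eps=1e-3):
    # central-difference estimate of d(loss)/d(param[i, j])
    old = param[i, j]
    param[i, j] = old + eps
    loss_plus = loss_fn(xb, yb)    # loss_fn: hypothetical helper, not from the post
    param[i, j] = old - eps
    loss_minus = loss_fn(xb, yb)
    param[i, j] = old
    return (loss_plus - loss_minus) / (2 * eps)

print(numerical_grad(w2, 0, 0), dw2[0, 0])   # should agree to a few decimal places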

Update the weights with basic SGD:

w1 += -lr * dw1
w2 += -lr * dw2
b1 += -lr * db1
b2 += -lr * db2
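
Strung together, the forward pass, the backward pass, and this update run inside the same get_batch loop as the PyTorch version. A rough outline, where train_step is a hypothetical wrapper around the blocks above and lr = 0.01 is a guess rather than the post's value:

lr = 0.01                          # assumed learning rate, not from the post
for step in trange(max_iters):
    xb, yb = get_batch('train')
    xb = xb.astype(np.float32)     # pixels come in as uint8
    train_step(xb, yb)             # hypothetical: forward + backward + SGD update above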

Results

After 10,000 iterations, my hand-coded implementation got 97.7% test accuracy:

# full forward pass over the test set with the trained weights
logits = np.maximum(0, X_test.dot(w1) + b1).dot(w2) + b2
(np.argmax(logits, axis=1) == Y_test).mean()
## loss.backward()     :  0.964 
## backprop by hand    :  0.977 #swole_doge

Why?

Look, I’m not saying you should implement backprop by hand for production models. That would be insane. The point was to understand exactly what’s happening at each step, what PyTorch actually does when you call .backward(). It isn’t magic, just math. No black boxes. You see that “deep learning” is really just careful bookkeeping of matmuls and derivatives.