
Simple NN Showcase

By Michal Jagodzinski - April 6th, 2023

simple_nn is a simple neural network framework, not much more to it. Its functionality was inspired by PyTorch and tinygrad: I wanted low-level control over the training process, letting users specify exactly how their model is trained.

I did the majority of the work on this project a couple of months ago, but I did not feel it was complete enough to showcase on my Substack. I do want to showcase it here, though, and, similar to SAT, document the further work I do on it.
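
In practice, that means you write the training loop yourself: build a network, attach an optimizer, and call Forward and Backward! explicitly for every sample. Here is a minimal sketch of the pattern, using the same API calls demonstrated below, with a made-up squared-error loss purely for illustration (from the examples below, Backward! both computes the gradients via Zygote and applies the optimizer update):

net = CreateNetwork([DenseLayer((2, 1), sigmoid_activation)], datatype=Float32, init_distribution=Normal())
OptimizerSetup!(net, GradientDescentOptimizer!, learning_rate=0.01)

# Any differentiable function of the network output can serve as the loss
loss(x, y) = sum((Forward(net, x) .- y) .^ 2)

for step in 1:100
    x, y = Float32[0.0, 1.0], Float32[1.0]  # replace with your data
    Backward!(net, loss, x, y)              # compute gradients and update the weights
end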

Simple Example: The XOR Problem

using LinearAlgebra, Distributions, Plots, Random, Zygote, MLDatasets
Random.seed!(69420)

Defining the data and the one-hot targets (the first entry encodes XOR = 1, the second XOR = 0):

data = [
    [0.0, 0.0],
    [0.0, 1.0],
    [1.0, 0.0],
    [1.0, 1.0]
]

targets = [
    [0.0, 1.0],
    [1.0, 0.0],
    [1.0, 0.0],
    [0.0, 1.0]
]

Setting up the network and optimizer. The network is tiny: two inputs, a two-unit sigmoid hidden layer, and two sigmoid outputs matching the one-hot targets:

setup = [
    DenseLayer((2, 2), sigmoid_activation),
    DenseLayer((2, 2), sigmoid_activation)
]

xor_net = CreateNetwork(setup, datatype=Float32, init_distribution=Normal());

OptimizerSetup!(xor_net, GradientDescentOptimizer!, learning_rate=0.01)
Dict{String, Any} with 3 entries:
  "optimizer" => GradientDescentOptimizer!
  "optimizer_name" => "GradientDescentOptimizer!"
  "learning_rate" => 0.01

Training loop:

epochs = 250

xor_loss = 0.0
xor_losses = []

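# For reference, this is the standard categorical cross-entropy: for a one-hot
# target y and network output ŷ = Forward(xor_net, x_in), the loss is
# -Σᵢ yᵢ log(ŷᵢ), i.e. -log of the output assigned to the correct class.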
cross_entropy_loss(x_in, y_val) = -sum(y_val .* log.(Forward(xor_net, x_in)))

for epoch in 1:epochs
    for (i, input) in enumerate(data)
        x_in = convert(Vector{Float32}, input)
        y_val = targets[i]

        pred = Forward(xor_net, x_in)
        global xor_loss = cross_entropy_loss(x_in, y_val)

        Backward!(xor_net, cross_entropy_loss, x_in, y_val)

        if epoch % 50 == 0 && i == length(targets)
            println("Epoch $epoch\tPred: $(round(maximum(pred), digits=4))\tTarget: $(maximum(y_val))\tloss: $(round(xor_loss, digits=4))")
        end
    end
    push!(xor_losses, xor_loss)
end

preds = round.(maximum.([Forward(xor_net, convert(Vector{Float32}, x)) for x in data]))
passed = preds .== convert(Vector{Float32}, maximum.(targets))

println("\n$(sum(passed))/4 tests passed | Accuracy $(100 * sum(passed) / length(targets))%")
Epoch 50	Pred: 0.7142	Target: 1.0	loss: 0.8478
Epoch 100	Pred: 0.7681	Target: 1.0	loss: 0.5518
Epoch 150	Pred: 0.8103	Target: 1.0	loss: 0.3888
Epoch 200	Pred: 0.8428	Target: 1.0	loss: 0.2921
Epoch 250	Pred: 0.8681	Target: 1.0	loss: 0.2302

4/4 tests passed | Accuracy 100.0%
plot(
    xor_losses,
    xlabel="Epoch", ylabel="Cross-Entropy Loss", label="",
    size=(800,500), dpi=300
)

Comparing Optimizers

simple_nn has a few built-in optimizers. To compare them, we'll revisit the XOR problem. Defining the networks for each optimizer:

optimizers = [
    GradientDescentOptimizer!,
    MomentumOptimizer!,
    RMSpropOptimizer!,
    AdamOptimizer!
]

xor_networks = [
    CreateNetwork(setup, datatype=Float32, init_distribution=Normal()) for _ in 1:length(optimizers)
]

for (i, opt) in enumerate(optimizers)
    OptimizerSetup!(xor_networks[i], opt)
end
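
For context, these optimizers differ only in how the gradient ∇L is turned into a parameter update. Roughly (the exact update rules and hyperparameter defaults in simple_nn may differ slightly):

- Gradient descent: θ ← θ − η∇L
- Momentum: v ← γv + η∇L, then θ ← θ − v
- RMSprop: s ← ρs + (1 − ρ)(∇L)², then θ ← θ − η∇L / (√s + ε)
- Adam: keeps both a momentum-style average of ∇L and an RMSprop-style average of (∇L)², applies bias correction to each, and combines them for the update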

Training loop:

opt_loss = []
opt_losses = []

# One cross-entropy loss per network; each closure captures its own network
loss_funcs = [
    (x_in, y_val) -> -sum(y_val .* log.(Forward(net, x_in))) for net in xor_networks
]

for epoch in 1:epochs
    for (i, input) in enumerate(data)
        x_in = convert(Vector{Float32}, input)
        y_val = targets[i]

        global opt_loss = []

        for j in eachindex(optimizers)
            push!(opt_loss, loss_funcs[j](x_in, y_val))
            Backward!(xor_networks[j], loss_funcs[j], x_in, y_val)
        end
    end
    push!(opt_losses, opt_loss)
end

Plotting the losses of each optimizer:

loss_gd = [loss[1] for loss in opt_losses]
loss_mgd = [loss[2] for loss in opt_losses]
loss_rms = [loss[3] for loss in opt_losses]
loss_adam = [loss[4] for loss in opt_losses]

plot(loss_gd, xlabel="Epoch", ylabel="Cross-Entropy Loss", label="Gradient Descent", size=(800,500), dpi=300)
plot!(loss_mgd, label="Momentum")
plot!(loss_rms, label="RMSprop")
plot!(loss_adam, label="ADAM")

MNIST

Importing data:

train_x, train_y = MNIST(split=:train)[:]
train_x = Float32.(train_x)

test_x, test_y = MNIST(split=:test)[:]
test_x = Float32.(test_x)
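
As a quick sanity check of the shapes (this is just standard MNIST via MLDatasets: 28×28 images, 60,000 training and 10,000 test samples, labels 0 through 9):

@show size(train_x)    # (28, 28, 60000)
@show size(test_x)     # (28, 28, 10000)
@show extrema(train_y) # (0, 9)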

Defining some helper functions:

flatten(matrix) = vcat(matrix...)

function one_hot_encoding(target)
    return Float32.(target .== collect(0:9))
end
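
To make the shapes concrete, here is what these helpers produce for the first training image and label (just an illustration, nothing simple_nn-specific):

x1 = flatten(train_x[:, :, 1])     # 28×28 image -> 784-element Vector{Float32} (column-major)
y1 = one_hot_encoding(train_y[1])  # digit label -> 10-element one-hot Vector{Float32}
length(x1), length(y1)             # (784, 10)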

Defining the network:

mnist_network = CreateNetwork([
    DenseLayer((784, 128), sigmoid_activation),
    DenseLayer((128, 64), sigmoid_activation),
    DenseLayer((64, 10), softmax_activation)
], datatype=Float32, init_distribution=Normal())

OptimizerSetup!(mnist_network, AdamOptimizer!);
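
Assuming each DenseLayer holds a weight matrix plus a bias vector, this 784 → 128 → 64 → 10 network has 784·128 + 128·64 + 64·10 = 109,184 weights and 128 + 64 + 10 = 202 biases, so roughly 109k trainable parameters.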

Training loop. Note that each "epoch" here is really one randomly sampled mini-batch of 32 images, with a quick validation check on 32 random test images every 5 epochs:

epochs = 1000

mnist_loss = 0.0

batch_size = 32
batch_losses = []
validation_accuracies = []

cross_entropy_loss(x_in, y_val) = -sum(y_val .* log.(Forward(mnist_network, x_in)))

for epoch in 1:epochs
    batch_idx = rand((1:size(train_x, 3)), batch_size)

    batch_x = train_x[1:end, 1:end, batch_idx]
    batch_y = train_y[batch_idx]

    batch_loss = []

    for i in 1:batch_size
        x_in = convert(Vector{Float32}, flatten(batch_x[1:end, 1:end, i]))
        y_val = one_hot_encoding(batch_y[i])

        pred = Forward(mnist_network, x_in)
        global mnist_loss = cross_entropy_loss(x_in, y_val)

        Backward!(mnist_network, cross_entropy_loss, x_in, y_val)

        push!(batch_loss, mnist_loss)
    end

    push!(batch_losses, sum(batch_loss) / length(batch_loss))

    if epoch % 5 == 0
        val_batch = 32
        val_batch_idx = rand((1:size(test_x, 3)), val_batch)

        test_predictions = argmax.([
            Forward(mnist_network, convert(Vector{Float32}, flatten(test_x[1:end, 1:end, idx]))) for idx in val_batch_idx
        ]) .- 1

        test_correct = test_predictions .== test_y[val_batch_idx]
        val_accuracy = 100 * sum(test_correct) / val_batch

        push!(validation_accuracies, val_accuracy)
    end
end
p1 = plot(batch_losses, yaxis=:log, label="Batch Loss")
p2 = plot(validation_accuracies, label="Validation Accuracy")

plot(p1, p2, layout=(1,2), size=(800,400), dpi=300)

Classifying a small sample of images:

samps = rand(1:size(test_x, 3), 4)

test_preds = [
    Forward(mnist_network, convert(Vector{Float32}, flatten(test_x[1:end, 1:end, samp]))) for samp in samps
]

preds = [findmax(pred)[2] - 1 for pred in test_preds]

plots = []
for (i, pred) in enumerate(preds)
    temp = heatmap(
        test_x[1:end, 1:end, samps[i]]',
        yflip=true,
        title="Target = $(test_y[samps[i]]) | Prediction = $pred",
    )

    push!(plots, temp)
end

plot(plots..., layout=(2,2), size=(700,600), dpi=300)

Overall accuracy:

test_predictions = argmax.([
    Forward(mnist_network, convert(Vector{Float32}, flatten(test_x[1:end, 1:end, i]))) for i in 1:size(test_x, 3)
]) .- 1

test_correct = test_predictions .== test_y

println("Accuracy: $(round(100 * sum(test_correct) / length(test_correct), digits=2))%")

Wrapping Up

I hope you enjoyed this small showcase of simple_nn. I am planning to work on it some more; I already have a somewhat functional implementation of convolutional neural networks, but it does not play well with Zygote.jl, the automatic differentiation library I use. I created this project for my own learning and do not expect anyone to actually use it, but it was fun to work on, and it definitely helped me understand neural networks a lot better.

Thanks for reading! Until next time.
