By Michal Jagodzinski - April 6th, 2023
simple_nn is a simple neural network framework, not much more to it. Its functionality was inspired by PyTorch and tinygrad: I wanted low-level control of the training process, letting users specify exactly how their models are trained.
I did the majority of the work on this project a couple of months ago, but it did not feel complete enough to showcase on my Substack. I do want to show it off here though, and, similar to SAT, document the further work I do on it.
using LinearAlgebra, Distributions, Plots, Random, Zygote, MLDatasets
# simple_nn itself (DenseLayer, CreateNetwork, Forward, Backward!, the optimizers, ...)
# is assumed to already be loaded into the session
Random.seed!(69420)
Defining the XOR inputs and their one-hot targets (the first slot is "XOR = 1", the second "XOR = 0"):
data = [
[0.0, 0.0],
[0.0, 1.0],
[1.0, 0.0],
[1.0, 1.0]
]
targets = [
[0.0, 1.0],
[1.0, 0.0],
[1.0, 0.0],
[0.0, 1.0]
]
Setting up the network and optimizer:
setup = [
DenseLayer((2, 2), sigmoid_activation),
DenseLayer((2, 2), sigmoid_activation)
]
xor_net = CreateNetwork(setup, datatype=Float32, init_distribution=Normal());
OptimizerSetup!(xor_net, GradientDescentOptimizer!, learning_rate=0.01)
Dict{String, Any} with 3 entries:
"optimizer" => GradientDescentOptimizer!
"optimizer_name" => "GradientDescentOptimizer!"
"learning_rate" => 0.01
Training loop (updates are applied one sample at a time, and only the last sample's loss in each epoch is recorded in xor_losses):
epochs = 250
xor_loss = 0.0
xor_losses = []
cross_entropy_loss(x_in, y_val) = -sum(y_val .* log.(Forward(xor_net, x_in)))
for epoch in 1:epochs
for (i, input) in enumerate(data)
x_in = convert(Vector{Float32}, input)
y_val = targets[i]
pred = Forward(xor_net, x_in)
global xor_loss = cross_entropy_loss(x_in, y_val)
Backward!(xor_net, cross_entropy_loss, x_in, y_val)
if epoch % 50 == 0 && i == length(targets)
println("Epoch $epoch\tPred: $(round(maximum(pred), digits=4))\tTarget: $(maximum(y_val))\tloss: $(round(xor_loss, digits=4))")
end
end
push!(xor_losses, xor_loss)
end
preds = round.(maximum.([Forward(xor_net, convert(Vector{Float32}, x)) for x in data]))
passed = preds .== convert(Vector{Float32}, maximum.(targets))
println("\n$(sum(passed))/4 tests passed | Accuracy $(100 * sum(passed) / length(targets))%")
plot(
xor_losses,
xlabel="Epoch", ylabel="Cross-Entropy Loss", label="",
size=(800,500), dpi=300
)
simple_nn has a few built-in optimizers. To compare them, we'll revisit the XOR problem. Defining a network for each optimizer:
optimizers = [
GradientDescentOptimizer!,
MomentumOptimizer!,
RMSpropOptimizer!,
AdamOptimizer!
]
xor_networks = [
CreateNetwork(setup, datatype=Float32, init_distribution=Normal()) for _ in 1:length(optimizers)
]
for (i, opt) in enumerate(optimizers)
OptimizerSetup!(xor_networks[i], opt)
end
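For reference, these optimizers differ only in how a gradient g = ∇θL gets turned into an update of a parameter θ with learning rate η. The textbook update rules are sketched below; simple_nn's exact defaults for the extra hyperparameters (γ, ρ, β₁, β₂, ε) are not shown here, so treat the symbols as placeholders:

$$
\begin{aligned}
\text{Gradient descent:}\quad & \theta \leftarrow \theta - \eta g \\
\text{Momentum:}\quad & v \leftarrow \gamma v + \eta g, \quad \theta \leftarrow \theta - v \\
\text{RMSprop:}\quad & s \leftarrow \rho s + (1 - \rho)\, g^2, \quad \theta \leftarrow \theta - \frac{\eta g}{\sqrt{s} + \epsilon} \\
\text{Adam:}\quad & m \leftarrow \beta_1 m + (1 - \beta_1)\, g, \quad v \leftarrow \beta_2 v + (1 - \beta_2)\, g^2, \quad \theta \leftarrow \theta - \frac{\eta \hat{m}}{\sqrt{\hat{v}} + \epsilon}
\end{aligned}
$$

where m̂ and v̂ are the bias-corrected first and second moment estimates.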
Training loop (as before, only the losses from the last data sample of each epoch get recorded):
opt_loss = []
opt_losses = []
loss_funcs = [
(x_in, y_val) -> -sum(y_val .* log.(Forward(net, x_in))) for net in xor_networks
]
for epoch in 1:epochs
for (i, input) in enumerate(data)
x_in = convert(Vector{Float32}, input)
y_val = targets[i]
global opt_loss = []
for j in 1:length(optimizers)  # j indexes the optimizers, not the data samples
push!(opt_loss, loss_funcs[j](x_in, y_val))
Backward!(xor_networks[j], loss_funcs[j], x_in, y_val)
end
end
push!(opt_losses, opt_loss)
end
Plotting the losses of each optimizer:
loss_gd = [loss[1] for loss in opt_losses]
loss_mgd = [loss[2] for loss in opt_losses]
loss_rms = [loss[3] for loss in opt_losses]
loss_adam = [loss[4] for loss in opt_losses]
plot(loss_gd, xlabel="Epoch", ylabel="Cross-Entropy Loss", label="Gradient Descent", size=(800,500), dpi=300)
plot!(loss_mgd, label="Momentum")
plot!(loss_rms, label="RMSprop")
plot!(loss_adam, label="ADAM")
Now for something more substantial: classifying handwritten digits from MNIST. Importing the data:
train_x, train_y = MNIST(split=:train)[:]
train_x = Float32.(train_x)
test_x, test_y = MNIST(split=:test)[:]
test_x = Float32.(test_x)
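MLDatasets returns the images as a 28×28×N array of grayscale values and the labels as plain integer digits. The quick checks below (expected values shown as comments) assume the standard 60,000/10,000 train/test split:

size(train_x)      # (28, 28, 60000)
size(test_x)       # (28, 28, 10000)
extrema(train_y)   # (0, 9)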
Defining some helper functions to flatten each 28×28 image into a 784-element vector and to one-hot encode each digit label:
flatten(matrix) = vcat(matrix...)
function one_hot_encoding(target)
return Float32.(target .== collect(0:9))
end
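A quick sanity check of the helpers, with the expected results shown as comments:

length(flatten(train_x[:, :, 1]))   # 784 == 28 * 28
one_hot_encoding(3)                 # 10-element vector with a single 1.0 in the fourth slot (digit 3)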
Defining the network: a fully connected 784 → 128 → 64 → 10 net with sigmoid hidden layers and a softmax output:
mnist_network = CreateNetwork([
DenseLayer((784, 128), sigmoid_activation),
DenseLayer((128, 64), sigmoid_activation),
DenseLayer((64, 10), softmax_activation)
], datatype=Float32, init_distribution=Normal())
OptimizerSetup!(mnist_network, AdamOptimizer!);
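Assuming sigmoid_activation and softmax_activation follow the standard definitions,

$$
\sigma(z) = \frac{1}{1 + e^{-z}}, \qquad \mathrm{softmax}(z)_i = \frac{e^{z_i}}{\sum_j e^{z_j}},
$$

the ten softmax outputs sum to one and can be read as class probabilities, which is what makes the cross-entropy loss well-defined here.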
Training loop (each "epoch" here is a single randomly drawn mini-batch of 32 images, so 1000 epochs cover roughly 32,000 samples, about half a pass over the 60,000 training images):
epochs = 1000
mnist_loss = 0.0
batch_size = 32
batch_losses = []
validation_accuracies = []
cross_entropy_loss(x_in, y_val) = -sum(y_val .* log.(Forward(mnist_network, x_in)))
for epoch in 1:epochs
batch_idx = rand((1:size(train_x, 3)), batch_size)
batch_x = train_x[1:end, 1:end, batch_idx]
batch_y = train_y[batch_idx]
batch_loss = []
for i in 1:batch_size
x_in = convert(Vector{Float32}, flatten(batch_x[1:end, 1:end, i]))
y_val = one_hot_encoding(batch_y[i])
pred = Forward(mnist_network, x_in)
global mnist_loss = cross_entropy_loss(x_in, y_val)
Backward!(mnist_network, cross_entropy_loss, x_in, y_val)
push!(batch_loss, mnist_loss)
end
push!(batch_losses, sum(batch_loss) / length(batch_loss))
if epoch % 5 == 0
val_batch = 32
val_batch_idx = rand((1:size(test_x, 3)), val_batch)
test_predictions = argmax.([
Forward(mnist_network, convert(Vector{Float32}, flatten(test_x[1:end, 1:end, idx]))) for idx in val_batch_idx
]) .- 1
test_correct = test_predictions .== test_y[val_batch_idx]
val_accuracy = 100 * sum(test_correct) / val_batch
push!(validation_accuracies, val_accuracy)
end
end
Plotting the batch losses and validation accuracies:
p1 = plot(batch_losses, yaxis=:log, label="Batch Loss")
p2 = plot(validation_accuracies, label="Validation Accuracy")
plot(p1, p2, layout=(1,2), size=(800,400), dpi=300)
Classifying a small sample of images:
samps = rand(1:size(test_x, 3), 4)
test_preds = [
Forward(mnist_network, convert(Vector{Float32}, flatten(test_x[1:end, 1:end, samp]))) for samp in samps
]
preds = [findmax(pred)[2] - 1 for pred in test_preds]
plots = []
for (i, pred) in enumerate(preds)
temp = heatmap(
test_x[1:end, 1:end, samps[i]]',
yflip=true,
title="Target = $(test_y[samps[i]]) | Prediction = $pred",
)
push!(plots, temp)
end
plot(plots..., layout=(2,2), size=(700,600), dpi=300)
Overall accuracy:
test_predictions = argmax.([
Forward(mnist_network, convert(Vector{Float32}, flatten(test_x[1:end, 1:end, i]))) for i in 1:size(test_x, 3)
]) .- 1
test_correct = test_predictions .== test_y
println("Accuracy: $(round(100 * sum(test_correct) / length(test_correct), digits=2))%")
Hope you enjoyed this small showcase of simple_nn. I am planning on working on it some more; I already have a somewhat functional implementation of convolutional neural networks, but it does not yet play well with Zygote.jl, the autodifferentiation library I use. I created this project for my own learning, and I do not expect anyone to actually use it, but it was fun to work on and it definitely helped me understand neural networks a lot better.
Thanks for reading! Until next time.