Complete from-scratch implementation of neural networks (in ≈400-500 lines of code) with examples for training on the MNIST handwritten digits dataset and CIFAR-10 object recognition dataset. See it in action here!
A neural network consists of layers of neurons. For example, in the diagram, there are three layers of neurons with 6,
4, and 3 neurons, respectively. Each neuron has weighted connections (whose value may signify the strength and
relationship of their connection) to each neuron in the next layer. Each neuron receives a
weighted input sum (denoted $z_j$ for neuron $j$) and passes it through an activation function, such as the sigmoid function $a_j = \sigma(z_j) = \frac{1}{1 + e^{-z_j}}$,
which constrains values to be between 0 and 1 (useful, e.g., when you want the network to output the probability that an input belongs to a given class). You may refer to other activation functions in the corresponding section below.
The resulting activation becomes the input for the neurons in
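the next layer. Notably, we denote by $w_{ij}$ the weight of the connection from neuron $i$ in the previous layer to neuron $j$ in the current layer.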
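Note that each layer tends to have a bias neuron, whose weight value $b_j$ is simply added to the weighted
sum, so that $z_j = \sum_i w_{ij} a_i + b_j$, where $a_i$ is the activation of neuron $i$ in the previous layer.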
In both the learning and classification stages, an example input array/vector is fed as input to the first layer. Through the weighted sum and activation processes described previously, it becomes the input vector for the succeeding layer, finally resulting in an output vector in the final layer.
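The forward pass for one layer can be sketched as follows. This is a minimal sketch that assumes the sigmoid activation and the `weights[i][j]` / `biases[j]` layout used in the backpropagation code later in this README. The class and method names (`ForwardPassSketch`, `forward`, `sigmoid`) are hypothetical and are not taken from the repository. The `derivative` flag only loosely mirrors the `activationFunction(weightedSum, true)` convention.

```java
import java.util.Arrays;

public class ForwardPassSketch {
    // Hypothetical helper: sigmoid activation; with derivative=true it returns
    // sigma'(z) = sigma(z) * (1 - sigma(z)) instead.
    static double sigmoid(double z, boolean derivative) {
        double s = 1.0 / (1.0 + Math.exp(-z));
        return derivative ? s * (1.0 - s) : s;
    }

    // One layer: weights[i][j] connects input i to neuron j; biases[j] is neuron j's bias.
    static double[] forward(double[] inputs, double[][] weights, double[] biases) {
        double[] activations = new double[biases.length];
        for (int j = 0; j < biases.length; j++) {
            double z = biases[j]; // z_j = sum_i w_ij * a_i + b_j
            for (int i = 0; i < inputs.length; i++) {
                z += weights[i][j] * inputs[i];
            }
            activations[j] = sigmoid(z, false); // a_j = sigma(z_j)
        }
        return activations;
    }

    public static void main(String[] args) {
        double[] x = {1.0, 0.5};
        double[][] w1 = {{0.2, -0.4, 0.1}, {0.7, 0.3, -0.5}}; // 2 inputs -> 3 neurons
        double[] b1 = {0.0, 0.1, -0.1};
        double[][] w2 = {{0.5}, {-0.6}, {0.9}};                  // 3 neurons -> 1 output
        double[] b2 = {0.05};
        double[] hidden = forward(x, w1, b1);   // output of layer 1 ...
        double[] output = forward(hidden, w2, b2); // ... is the input of layer 2
        System.out.println(Arrays.toString(output));
    }
}
```

Each layer's output array is passed directly as the next layer's `inputs`. This is the chaining described above.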
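For a network to learn, its adjustable parameters (the weights and biases) are altered based on the network's classification
errors. In particular, the goal in the learning process is to minimize the error $E$ (the loss between the network's output and the expected output) by repeatedly adjusting each parameter in the direction that decreases it (gradient descent).
For example, the partial derivative $\frac{\partial E}{\partial w_{ij}}$
represents the sensitivity of $E$ to changes in $w_{ij}$, so each weight is updated as $w_{ij} \leftarrow w_{ij} - \eta \frac{\partial E}{\partial w_{ij}}$,
with $\eta$ denoting the learning rate.
Notably, the negative sign ensures the weight is updated such that the error $E$ decreases.
By the chain rule, we have $\frac{\partial E}{\partial w_{ij}} = \delta_j a_i$ and $\frac{\partial E}{\partial b_j} = \delta_j$, where $a_i$ is the $i$-th input to the layer, $f$ is the activation function, and
$$\delta_j = \frac{\partial E}{\partial a_j} f'(z_j) \quad \text{(output layer)}, \qquad \delta_j = f'(z_j) \sum_l w_{jl} \delta_l \quad \text{(hidden layer, summing over neurons } l \text{ of the next layer)}.$$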
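Correspondingly, in code:
```java
if (isOutputLayer) {
    // Output layer: delta_j = dE/da_j * f'(z_j)
    for (int j = 0; j < weights[0].length; j++) {
        deltaWeights[j] = error[j] * activationFunction(weightedSumOutput[j], true);
    }
} else {
    // Hidden layer: delta_j = f'(z_j) * sum_l w_jl * delta_l, using the next layer's deltas
    for (int j = 0; j < nextLayer.weights.length; j++) {
        deltaWeights[j] = 0; // reset in case the array is reused between examples
        for (int l = 0; l < nextLayer.weights[0].length; l++) {
            deltaWeights[j] += nextLayer.weights[j][l] * nextLayer.deltaWeights[l];
        }
        deltaWeights[j] *= activationFunction(weightedSumOutput[j], true);
    }
}
// Accumulate weight adjustments: -learningRate * delta_j * a_i
for (int i = 0; i < weights.length; i++) {
    for (int j = 0; j < weights[i].length; j++) {
        weightsAdjustments[i][j] += deltaWeights[j] * inputs[i] * -learningRate;
    }
}
// Accumulate bias adjustments: -learningRate * delta_j
for (int j = 0; j < biases.length; j++) {
    biasesAdjustments[j] += deltaWeights[j] * -learningRate;
}
```
Here, `activationFunction(weightedSum, true)` returns the derivative of the activation function at `weightedSum`, and `error[j]` is the derivative of the error with respect to the activation of output neuron $j$. `weights` is a 2D array of size [i][j] and `biases` is an array of size [j], for a current layer of $j$ neurons receiving $i$ inputs.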
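To see why the negative sign in the update decreases the error, consider the smallest possible case. This is a sketch for a single sigmoid neuron $a = \sigma(wx + b)$, assuming a squared error $E = \frac{1}{2}(a - y)^2$. That loss is chosen only for illustration; it is not necessarily the repository's loss. All names (`GradientStepSketch`, `error`, `step`) are hypothetical.

```java
public class GradientStepSketch {
    static double sigmoid(double z) { return 1.0 / (1.0 + Math.exp(-z)); }

    // Assumed loss for illustration: E = 0.5 * (a - y)^2 with a = sigmoid(w * x + b).
    static double error(double w, double b, double x, double y) {
        double a = sigmoid(w * x + b);
        return 0.5 * (a - y) * (a - y);
    }

    // One gradient descent step on (w, b); returns {newW, newB}.
    static double[] step(double w, double b, double x, double y, double learningRate) {
        double a = sigmoid(w * x + b);
        double dEda = a - y;                        // dE/da for squared error
        double delta = dEda * a * (1.0 - a);        // delta = dE/da * sigma'(z)
        double newW = w - learningRate * delta * x; // dE/dw = delta * x
        double newB = b - learningRate * delta;     // dE/db = delta
        return new double[] {newW, newB};
    }

    public static void main(String[] args) {
        double w = 0.5, b = 0.0, x = 1.0, y = 1.0;
        double before = error(w, b, x, y);
        double[] updated = step(w, b, x, y, 0.1);
        double after = error(updated[0], updated[1], x, y);
        System.out.println(before + " -> " + after); // the error should shrink
    }
}
```

Comparing `(w - newW) / learningRate` against a finite-difference estimate of $\partial E / \partial w$ is also a cheap way to sanity-check a backpropagation implementation.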
Used when each example belongs to exactly one of several possible classes (one-hot encoded labels).
Notably used with the softmax activation function. Note that binary cross-entropy with two classes is simply a special
case of categorical cross-entropy
(source).
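The following is a minimal sketch of softmax followed by categorical cross-entropy, assuming one-hot targets. It also shows the standard simplification that makes this pair convenient: the gradient of the loss with respect to the weighted sums is simply $\frac{\partial E}{\partial z_j} = a_j - y_j$. The class and method names are hypothetical.

```java
import java.util.Arrays;

public class SoftmaxCrossEntropySketch {
    // Softmax: a_j = exp(z_j) / sum_k exp(z_k). Subtracting max(z) first avoids overflow.
    static double[] softmax(double[] z) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : z) max = Math.max(max, v);
        double[] a = new double[z.length];
        double sum = 0.0;
        for (int j = 0; j < z.length; j++) {
            a[j] = Math.exp(z[j] - max);
            sum += a[j];
        }
        for (int j = 0; j < z.length; j++) a[j] /= sum;
        return a;
    }

    // Categorical cross-entropy E = -sum_j y_j * log(a_j); zero-target terms are skipped.
    static double categoricalCrossEntropy(double[] a, double[] y) {
        double e = 0.0;
        for (int j = 0; j < a.length; j++) {
            if (y[j] != 0.0) e -= y[j] * Math.log(a[j]);
        }
        return e;
    }

    // For softmax followed by categorical cross-entropy, dE/dz_j simplifies to a_j - y_j.
    static double[] gradientWrtWeightedSums(double[] a, double[] y) {
        double[] g = new double[a.length];
        for (int j = 0; j < a.length; j++) g[j] = a[j] - y[j];
        return g;
    }

    public static void main(String[] args) {
        double[] z = {2.0, 1.0, 0.1};
        double[] y = {1.0, 0.0, 0.0}; // one-hot target: class 0
        double[] a = softmax(z);
        System.out.println("loss  = " + categoricalCrossEntropy(a, y));
        System.out.println("dE/dz = " + Arrays.toString(gradientWrtWeightedSums(a, y)));
    }
}
```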
Used when an example can belong to several classes at once (multi-label classification). Notably used with the sigmoid
activation function.
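The sketch below shows binary cross-entropy in a multi-label setting. It treats each output as an independent sigmoid "yes/no" decision and sums the per-output losses; summing rather than averaging is an assumption here. It also checks numerically the claim that two-class binary cross-entropy is a special case of categorical cross-entropy. All names are hypothetical.

```java
public class BinaryCrossEntropySketch {
    static double sigmoid(double z) { return 1.0 / (1.0 + Math.exp(-z)); }

    // Binary cross-entropy for one output p = sigmoid(z) with target y in {0, 1}.
    static double binaryCrossEntropy(double p, double y) {
        return -(y * Math.log(p) + (1.0 - y) * Math.log(1.0 - p));
    }

    // Multi-label loss (assumed here to be a sum): one independent sigmoid + BCE per class.
    static double multiLabelLoss(double[] z, double[] y) {
        double total = 0.0;
        for (int j = 0; j < z.length; j++) total += binaryCrossEntropy(sigmoid(z[j]), y[j]);
        return total;
    }

    // Categorical cross-entropy, used here only to check the two-class special case.
    static double categoricalCrossEntropy(double[] a, double[] y) {
        double e = 0.0;
        for (int j = 0; j < a.length; j++) e -= y[j] * Math.log(a[j]);
        return e;
    }

    public static void main(String[] args) {
        // An example tagged with classes 0 and 2, but not class 1.
        double[] z = {1.5, -0.5, 0.3};
        double[] y = {1.0, 0.0, 1.0};
        System.out.println("multi-label loss = " + multiLabelLoss(z, y));

        // Two classes: BCE on p equals CCE on [p, 1 - p] with target [y, 1 - y].
        double p = 0.8, target = 1.0;
        System.out.println(binaryCrossEntropy(p, target) + " vs "
                + categoricalCrossEntropy(new double[] {p, 1 - p}, new double[] {target, 1 - target}));
    }
}
```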
1e-8.
The default learning rate for momentum and Demon momentum is ~0.01, whilst for Adam-based updates it is ~0.001.
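The sketch below applies the optimizers to a single scalar parameter on a hypothetical toy objective $f(w) = (w - 3)^2$, using the default learning rates above. It uses textbook formulations, and the repository's exact update equations may differ. For Adam it assumes the commonly used $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$. For Demon it assumes the decay schedule from the Demon paper, $\beta_t = \beta_{init} \frac{1 - t/T}{(1 - \beta_{init}) + \beta_{init}(1 - t/T)}$. All names are hypothetical.

```java
public class OptimizerSketch {
    // Gradient of the toy objective f(w) = (w - 3)^2 (assumption for illustration).
    static double grad(double w) { return 2.0 * (w - 3.0); }

    // Demon schedule (assumed form): beta decays from betaInit at t = 0 to 0 at t = totalSteps.
    static double demonBeta(double betaInit, int t, int totalSteps) {
        double remaining = 1.0 - (double) t / totalSteps;
        return betaInit * remaining / ((1.0 - betaInit) + betaInit * remaining);
    }

    // Classical momentum (v <- beta * v - lr * g; w <- w + v), optionally with Demon decay.
    static double runMomentum(boolean demon, int steps) {
        double lr = 0.01, w = 0.0, v = 0.0;
        for (int t = 0; t < steps; t++) {
            double beta = demon ? demonBeta(0.9, t, steps) : 0.9;
            v = beta * v - lr * grad(w);
            w += v;
        }
        return w;
    }

    // Adam with bias-corrected first and second moment estimates.
    static double runAdam(int steps) {
        double lr = 0.001, beta1 = 0.9, beta2 = 0.999, eps = 1e-8;
        double w = 0.0, m = 0.0, v = 0.0;
        for (int t = 1; t <= steps; t++) {
            double g = grad(w);
            m = beta1 * m + (1 - beta1) * g;     // running mean of gradients
            v = beta2 * v + (1 - beta2) * g * g; // running mean of squared gradients
            double mHat = m / (1 - Math.pow(beta1, t));
            double vHat = v / (1 - Math.pow(beta2, t));
            w -= lr * mHat / (Math.sqrt(vHat) + eps);
        }
        return w;
    }

    public static void main(String[] args) {
        System.out.println("momentum: " + runMomentum(false, 1000));
        System.out.println("Demon:    " + runMomentum(true, 1000));
        System.out.println("Adam:     " + runAdam(20000));
    }
}
```

All three runs should approach the minimum at $w = 3$. Adam needs more steps here because its per-step movement is roughly bounded by its smaller learning rate.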
Where
for
Let
Note that the implementation only considers its use with categorical cross-entropy:
(Batch) gradient descent performs one parameter update per epoch, using the weight adjustments averaged over all training examples. Mini-batch gradient descent does the same over smaller batches, and stochastic gradient descent updates the parameters after each training example.
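The three variants differ only in batch size. The following sketch shows this on a hypothetical toy model $\hat{y} = wx$ fitted to $y = 2x$ with squared error. Adjustments are accumulated per example, and their average is applied once per batch. A batch size equal to the number of examples gives batch gradient descent, a batch size of 1 gives stochastic gradient descent, and anything in between gives mini-batch gradient descent. Shuffling each epoch is an assumption, and all names are hypothetical.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

public class BatchingSketch {
    // Shuffles example indices and splits them into batches of at most batchSize.
    static List<List<Integer>> makeBatches(int numExamples, int batchSize, Random rng) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < numExamples; i++) indices.add(i);
        Collections.shuffle(indices, rng);
        List<List<Integer>> batches = new ArrayList<>();
        for (int start = 0; start < numExamples; start += batchSize) {
            int end = Math.min(start + batchSize, numExamples);
            batches.add(new ArrayList<>(indices.subList(start, end)));
        }
        return batches;
    }

    // Accumulate -lr * dE/dw per example, then apply the batch average once per batch.
    static double train(double[] xs, double[] ys, int batchSize, int epochs, double learningRate) {
        Random rng = new Random(42);
        double w = 0.0;
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (List<Integer> batch : makeBatches(xs.length, batchSize, rng)) {
                double adjustment = 0.0;
                for (int i : batch) {
                    double error = w * xs[i] - ys[i];            // dE/dy_hat for E = 0.5 * (y_hat - y)^2
                    adjustment += error * xs[i] * -learningRate; // -lr * dE/dw
                }
                w += adjustment / batch.size();
            }
        }
        return w;
    }

    public static void main(String[] args) {
        double[] xs = new double[10], ys = new double[10];
        for (int i = 0; i < 10; i++) { xs[i] = i / 10.0; ys[i] = 2.0 * xs[i]; }
        System.out.println("batch GD:      " + train(xs, ys, 10, 200, 0.5));
        System.out.println("mini-batch GD: " + train(xs, ys, 4, 200, 0.5));
        System.out.println("SGD:           " + train(xs, ys, 1, 200, 0.5));
    }
}
```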
Simple parallelization is implemented through Java's streams API:
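```java
import java.util.stream.IntStream;

// Iterations may run on different threads, so they must be independent
// (e.g., each i writes only to its own row weightsAdjustments[i]).
IntStream.range(0, weights.length).parallel().forEach(i -> {
    // Perform operations
});
// For independent iterations, this gives the same result as the sequential loop:
for (int i = 0; i < weights.length; i++) {
    // Perform operations
}
```

- Backpropagation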
- Activation functions
- Optimizers
- More on optimizers
- Demon (Decaying momentum)
- Cross entropy and backpropagation: [1], [2]
- Based on Y8 me’s overcomplicated (and likely inaccurate) code.
Some other ideas to experiment with:
- Use of various techniques to reduce overfitting — Dropout; L1 & L2 Regularization; Data Preprocessing & Augmentation
- Different Neural Network architectures/variants — Generative Adversarial Networks (GANs); Convolutional Neural Networks (CNNs); Recurrent Neural Networks (RNNs) and their Long Short-Term Memory (LSTM) variant; Residual Neural Networks; Transformers
- Different learning methods — Supervised; Unsupervised; Reinforcement
- Consideration regarding the importance of data — is it possible to reduce the number of training examples yet achieve similar performance results?
- Consideration regarding the complexity of models — is it possible to reduce the computational resources needed to train/run models yet retain similar performance results?
- How can a general model be trained to achieve generalized intelligence? (think ARC-AGI benchmarks)
- Art of the Problem series on neural networks mentions accelerated learning for robots using reinforcement learning, achieved by running simulations where physical properties are altered. This allows trained models to better adapt to the natural environment, which often differs greatly from the training examples. Could a similar technique be applied to augment image data (e.g., randomly altering image color/positioning/dilation/skewing) and improve model generalization? (This would be similar to the dropout and data augmentation techniques used to reduce model overfitting.)
- Dive into Deep Learning — notably, references to CNNs, RNNs, LSTMs, GANs, Transformers, Computer Vision, Reinforcement Learning, Dropout
- Object detection and recognition — forms the basis of Optical Character Recognition (OCR) systems. Apart from being
able to recognize specific characters, an OCR system also needs a model that is able to detect the position of
characters in an input image (i.e., a model able to find the bounding boxes of the characters to be recognized).
- Models for bounding box detection — Single-shot multi-box detection, Region-based CNNs ([1], [2]), YOLO
