
Interactive Digit Classification Using Neural Network Trained on MNIST Data

Several years ago, I wrote a fully connected neural network from scratch in C as a learning exercise. I followed the first few chapters of Michael Nielsen's book 'Neural Networks and Deep Learning', which I highly recommend.

The network was designed to train on the MNIST dataset of handwritten digits, a well-known benchmark used in many machine learning examples. The goal is to classify each handwritten digit as a number from 0 to 9. The final network performed quite well, achieving 97.14% accuracy on the test dataset. Not bad for a bit of matrix algebra wrapped up in some C code.

Anyway, ever since then I've wanted to build a little browser widget that lets people use the model I trained interactively. Of course, I was beaten to the punch once, twice, and many more times, I'm sure. Even so, I wanted to see how well my model would perform at this task.

Homemade Fully Connected Neural Network

Before starting on the widget, I beefed up my neural network a little and trained one that reached 98.2% accuracy on the MNIST test data. I then used a web framework called Svelte to build a widget for drawing digits and getting predictions. Since my model is just simple linear algebra, exporting the weights from C and hard-coding them into JavaScript was not too much work, and libraries like Math.js made it easy to recreate everything. The final product is the widget you see below. It runs entirely client-side in the browser using my trained neural network.
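
Recreating a fully connected network from exported weights mostly comes down to a short feedforward loop. The minimal sketch below is written in Python for readability (the widget itself does the equivalent in JavaScript with Math.js) and is not the author's code. It assumes sigmoid activations as in Nielsen's book, weights stored as one nested list per layer with `W[j][k]` connecting input `k` to neuron `j`, and the 784 input pixels scaled to the range 0 to 1. The names `feedforward` and `predict` are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def sigmoid(z: float) -> float:
    """Numerically stable logistic function (avoids overflow in math.exp)."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def feedforward(weights: List[Matrix], biases: List[Vector], x: Vector) -> Vector:
    """Propagate x through each fully connected layer: a' = sigmoid(W a + b)."""
    a = x
    for W, b in zip(weights, biases):
        # Every neuron j takes a weighted sum over *all* inputs k (fully connected).
        a = [sigmoid(sum(w_jk * a_k for w_jk, a_k in zip(row, a)) + b_j)
             for row, b_j in zip(W, b)]
    return a


def predict(weights: List[Matrix], biases: List[Vector], pixels: Vector) -> int:
    """Return the index of the most activated output neuron (the digit 0-9)."""
    out = feedforward(weights, biases, pixels)
    return max(range(len(out)), key=lambda i: out[i])
```

With real exported weights, `weights` would hold one matrix per layer (for example 784 inputs to a hidden layer, then the hidden layer to 10 outputs), and the predicted digit is simply the index of the most strongly activated of the ten output neurons.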

* does not run on iOS Safari, possibly not on macOS Safari either *

If you tried a few numbers, you probably noticed that the predictions are often rather poor. I found that it has a really difficult time with '1's, '0's, and '9's. This was a bit disappointing: even with 98.2% accuracy on the test data, it still has a lot of trouble with new numbers. My guess is that, because the network is fully connected, it has a hard time generalizing to new data. For example, if a '1' is off to the side or at an angle that isn't present in the training data, it will be predicted incorrectly.
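
One way to probe the position part of this guess, and a common mitigation, is to preprocess drawings the way MNIST itself was prepared: the original MNIST digits were translated so that their center of mass sits in the middle of the 28x28 field. The post doesn't say whether the widget does this, so the Python sketch below is only an illustration. It assumes the drawing has already been reduced to a square grid of intensities (a list of rows), and `center_by_mass` is a hypothetical helper name.

```python
from typing import List

Image = List[List[float]]


def center_by_mass(img: Image) -> Image:
    """Shift a square grayscale image so its intensity-weighted center of mass
    lands on the middle of the grid. Pixels pushed past the border are dropped."""
    n = len(img)
    total = sum(sum(row) for row in img)
    if total == 0:
        return [row[:] for row in img]  # blank canvas: nothing to center
    # Center of mass in (row, column) pixel coordinates.
    cy = sum(y * v for y, row in enumerate(img) for v in row) / total
    cx = sum(x * v for row in img for x, v in enumerate(row)) / total
    mid = (n - 1) / 2  # 13.5 for a 28x28 grid
    dy, dx = round(mid - cy), round(mid - cx)  # integer translation
    out = [[0.0] * n for _ in range(n)]
    for y in range(n):
        for x in range(n):
            ny, nx = y + dy, x + dx
            if 0 <= ny < n and 0 <= nx < n:
                out[ny][nx] = img[y][x]
    return out
```

This only addresses off-center digits; a '1' drawn at an unusual angle would still look unfamiliar to the network.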

Keras/Tensorflow Convolutional Neural Network

Another type of network often used on MNIST data is a convolutional network. I won't go into the topic here, but this explanation was pretty helpful for my understanding. Convolutional networks work so well on MNIST that it's actually one of the 'getting started' examples for Keras.

I wanted to see how much the widget would improve with a convolutional network instead of my fully connected version. So I followed the Keras example and trained one in Python that reached 99.3% accuracy on the test data. Crucially, though, I believe it generalizes much better and is therefore more tolerant of digits that aren't presented in exactly the same way as in the training data. The results definitely show it: in my testing, it predicts the correct digit much more often than my homemade model.
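
To make the generalization argument a bit more concrete: a convolutional layer slides the same small kernel over every position in the image, so a stroke detector fires wherever the stroke appears, and pooling then keeps the strongest response in each neighborhood. The single-channel Python sketch below illustrates only that idea. It is not the Keras model from the post, which has many learned filters, biases, and more layers; the vertical-stroke kernel and the function names are illustrative assumptions. Like Keras `Conv2D`, it computes a cross-correlation, meaning the kernel is not flipped.

```python
from typing import List

Grid = List[List[float]]


def conv2d_valid(img: Grid, kernel: Grid) -> Grid:
    """Slide one shared kernel over the image (no padding, stride 1)."""
    kh, kw = len(kernel), len(kernel[0])
    oh, ow = len(img) - kh + 1, len(img[0]) - kw + 1
    return [[sum(kernel[i][j] * img[y + i][x + j]
                 for i in range(kh) for j in range(kw))
             for x in range(ow)]
            for y in range(oh)]


def relu(fm: Grid) -> Grid:
    """Zero out negative responses."""
    return [[max(0.0, v) for v in row] for row in fm]


def max_pool_2x2(fm: Grid) -> Grid:
    """Keep the strongest response in each non-overlapping 2x2 window."""
    return [[max(fm[y][x], fm[y][x + 1], fm[y + 1][x], fm[y + 1][x + 1])
             for x in range(0, len(fm[0]) - 1, 2)]
            for y in range(0, len(fm) - 1, 2)]


# Hypothetical 3x3 kernel that responds to a vertical stroke.
VERTICAL_STROKE = [[-1.0, 2.0, -1.0]] * 3
```

Feeding in a vertical stroke and the same stroke shifted two columns to the right produces feature maps that are shifted copies of each other, and the strongest pooled response is identical. In a fully connected layer, the same shift would route the stroke through a completely different set of weights.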

Again, the widget below runs entirely in the browser, this time using the TensorFlow.js library. TensorFlow.js allowed me to export the model from Python and import it directly into the Svelte widget.

* does not run on iOS Safari, possibly not on macOS Safari either *

Embedding Widgets

Because the widgets run entirely client-side, feel free to embed them anywhere on your own site using the code snippets below. They are web components that use a shadow DOM, so they should always look the same no matter where they are embedded. Kind of like an iframe, but for the modern age.

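<!-- Fully connected network widget -->
<script src="https://www.cluoma.com/js/mnist_widget.js"></script>
<div><mnist-checker-widget></mnist-checker-widget></div>

<!-- Convolutional network widget -->
<script src="https://www.cluoma.com/js/mnist_convolution_widget.js"></script>
<div><mnist-convolution-checker-widget></mnist-convolution-checker-widget></div>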

Full source code for this project is posted on my GitHub.

