K-fold cross-validation implemented from scratch to aid in the analysis of the MNIST dataset
The left sides of the digits 3 and 8 were found to be the key features for differentiating between handwritten 3s and 8s. The Hard2ClassifyData folders show that samples with light handwriting are typically hard to classify, as are samples with "scrunched" handwriting, where the lower part of the 3 overextends until it almost looks like an 8. An example is below:
Important pixels for differentiating between a 3 and an 8, as determined by Logistic Regression (brighter pixels are more important)
Important pixels for differentiating between a 3 and an 8, as determined by Linear SVM (brighter pixels are more important)
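Both Logistic Regression and a Linear SVM learn one weight per pixel, so an importance image can be built by placing each weight back at its pixel position. The sketch below assumes importance is the absolute weight, scaled so the largest becomes 1 (the brightest pixel); the repository's exact scaling is not stated, and `importance_map` is a hypothetical helper. MNIST images are 28×28, so a binary model has 784 weights (for example `model.coef_[0]` from a fitted scikit-learn `LogisticRegression` or `LinearSVC`).

```python
def importance_map(weights, side=28):
    """Reshape a flattened weight vector into a side x side grid.

    Each cell holds |w| divided by the largest |w|, so values lie in [0, 1]
    and 1.0 marks the most important pixel. (Assumed scaling, for illustration.)
    """
    if len(weights) != side * side:
        raise ValueError(f"expected {side * side} weights, got {len(weights)}")
    # Guard against an all-zero weight vector to avoid dividing by zero.
    peak = max(abs(w) for w in weights) or 1.0
    return [
        [abs(weights[r * side + c]) / peak for c in range(side)]
        for r in range(side)
    ]
```

Rendering the grid with `matplotlib.pyplot.imshow(grid, cmap="gray")` maps larger values to brighter pixels, matching the "brighter is more important" convention of the figures.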
The folders Easy2ClassifyData[MLmodel] and Hard2ClassifyData[MLmodel] contain handwriting samples that the given ML model classifies easily and with difficulty, respectively.
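The README does not say how samples were sorted into these folders. One common approach, assumed here for illustration only, is to rank samples by how far a linear model's score w·x + b lies from the decision boundary: scores near zero are "hard", large magnitudes are "easy". `rank_by_margin` is a hypothetical helper; with scikit-learn, `decision_function` returns the same score for `LogisticRegression` and `LinearSVC`.

```python
def rank_by_margin(samples, weights, bias=0.0):
    """Return sample indices sorted from hardest to easiest.

    The linear score w . x + b is proportional to the signed distance
    from the decision boundary; a small |score| means the model is
    least certain about that sample.
    """
    def margin(x):
        return abs(sum(w * xi for w, xi in zip(weights, x)) + bias)

    return sorted(range(len(samples)), key=lambda i: margin(samples[i]))
```

With `order = rank_by_margin(X, w, b)`, the first few indices (`order[:10]`) would be candidates for a "hard" folder and the last few (`order[-10:]`) for an "easy" one. Misclassified samples could reasonably be treated as hard as well; this sketch ignores the labels.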
- Unzip MNIST.zip to obtain the MNIST dataset, and keep the extracted folder in the same directory as MNIST.zip.
- To replicate the results, open any of the scripts below and run it.
- cv_builtin.py fits Logistic Regression and Linear SVM models to the data and prints the hyperparameters tuned with scikit-learn's built-in methods.
- cv_scratch.py runs the cross-validation implemented from scratch. To experiment with the method, change the number of folds, K, which is defined as a global variable.
- important_features.py uses the best lambda found by cv_scratch.py (this value can be changed in the code) to find the pixels in an image that are important for differentiating between a 3 and an 8.
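As a rough illustration of K-fold cross-validation for choosing lambda (not the code in cv_scratch.py), the procedure can be sketched in plain Python. The names `k_fold_splits`, `cv_error` and `best_lambda`, the `fit(X, y, lam)` interface that returns a prediction function, and the use of misclassification rate as the validation metric are all assumptions made for this sketch.

```python
import random
from statistics import mean


def k_fold_splits(n, k, seed=0):
    """Yield (train_idx, val_idx) pairs for K-fold cross-validation.

    The n indices are shuffled once and dealt into k contiguous folds whose
    sizes differ by at most one; each fold is the validation set exactly once.
    """
    if not 2 <= k <= n:
        raise ValueError("k must satisfy 2 <= k <= n")
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    fold_sizes = [n // k + (1 if i < n % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        val = idx[start:start + size]
        train = idx[:start] + idx[start + size:]
        yield train, val
        start += size


def cv_error(fit, X, y, lam, k, seed=0):
    """Mean validation misclassification rate over the k folds."""
    errors = []
    for train, val in k_fold_splits(len(y), k, seed):
        # Train on k-1 folds; fit returns a function mapping one sample to a label.
        predict = fit([X[i] for i in train], [y[i] for i in train], lam)
        wrong = sum(predict(X[i]) != y[i] for i in val)
        errors.append(wrong / len(val))
    return mean(errors)


def best_lambda(fit, X, y, lambdas, k, seed=0):
    """Return the candidate lambda with the lowest cross-validated error."""
    return min(lambdas, key=lambda lam: cv_error(fit, X, y, lam, k, seed))
```

Reusing the same seed for every candidate keeps the folds identical across lambda values, so each candidate is scored on the same splits. For the 3-vs-8 task, `fit` could wrap an L2-regularized logistic regression; note that scikit-learn's `LogisticRegression` takes `C`, the inverse of the regularization strength, so a lambda would correspond to `C = 1 / lam`.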