A Multivariate Gaussian Bayes classifier written using JAX

SalamanderXing/jax-gaussian-bayes

Multivariate Gaussian Bayes classifier in JAX

This is a simple implementation of a multivariate Gaussian Bayes classifier in JAX. The classifier is trained and tested on the CIFAR-10 dataset.
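To make the idea concrete, here is a minimal sketch of how such a classifier could be written in JAX. It is not taken from this repository: the function names (`fit`, `class_log_likelihood`, `predict`), the `reg` ridge term, and the `diagonal` switch are all assumptions made for illustration. Each class gets its own mean and covariance, and a sample is assigned to the class with the highest log prior plus Gaussian log-likelihood. Setting `diagonal=True` keeps only the per-feature variances, which corresponds to a Gaussian Naive Bayes model.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def fit(x, y, num_classes, reg=1e-3, diagonal=False):
    """Estimate log priors, means, and covariances per class.

    x: (n, d) float array; y: (n,) integer labels in [0, num_classes).
    reg is a hypothetical ridge term that keeps each covariance invertible.
    """
    priors, means, covs = [], [], []
    for k in range(num_classes):
        xk = x[y == k]  # boolean masking is fine here because fit is not jitted
        mu = xk.mean(axis=0)
        centered = xk - mu
        cov = centered.T @ centered / xk.shape[0]
        if diagonal:
            # Naive Bayes assumption: features are independent given the class.
            cov = jnp.diag(jnp.diag(cov))
        cov = cov + reg * jnp.eye(x.shape[1])
        priors.append(xk.shape[0] / x.shape[0])
        means.append(mu)
        covs.append(cov)
    return jnp.log(jnp.array(priors)), jnp.stack(means), jnp.stack(covs)


def class_log_likelihood(x, mean, cov):
    """Gaussian log-density of every row of x under one class."""
    d = mean.shape[0]
    chol = jnp.linalg.cholesky(cov)  # cov = L @ L.T
    z = solve_triangular(chol, (x - mean).T, lower=True)  # (d, n)
    maha = jnp.sum(z**2, axis=0)  # squared Mahalanobis distance per sample
    logdet = 2.0 * jnp.sum(jnp.log(jnp.diag(chol)))
    return -0.5 * (maha + logdet + d * jnp.log(2.0 * jnp.pi))


def predict(x, log_priors, means, covs):
    """Return the most probable class index for every row of x."""
    # vmap over classes: result has shape (num_classes, n).
    ll = jax.vmap(class_log_likelihood, in_axes=(None, 0, 0))(x, means, covs)
    return jnp.argmax(ll + log_priors[:, None], axis=0)
```

Under these assumptions, `predict(x_test, *fit(x_train, y_train, 10))` would return one predicted label per test image once the CIFAR-10 images are flattened to vectors. With 3072 features per image, each class covariance is a 3072 x 3072 matrix, so the regularization term matters in practice.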

Run it with:

python evaluate.py

It prints the accuracy of a random classifier, a Naive Bayes classifier, and a multivariate Gaussian Bayes classifier for comparison.

If the program fails, you probably need to download the dataset first. You can do so by running:

wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
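The downloaded archive still has to be unpacked and read. The sketch below shows one way to do that from Python. It relies on the documented layout of the CIFAR-10 Python archive: a `cifar-10-batches-py` directory whose batch files are pickled dictionaries with `data` and `labels` entries. The helper names `extract_archive` and `load_batch` are hypothetical, and where `evaluate.py` expects to find the files is not stated here, so treat the paths as assumptions. NumPy is used for the arrays; it is installed together with JAX.

```python
import pickle
import tarfile
from pathlib import Path

import numpy as np


def extract_archive(archive="cifar-10-python.tar.gz", dest="."):
    """Unpack the archive and return the directory holding the batch files."""
    with tarfile.open(archive, "r:gz") as tar:
        tar.extractall(dest)
    return Path(dest) / "cifar-10-batches-py"


def load_batch(path):
    """Load one CIFAR-10 batch file as (images, labels).

    images: (n, 3072) float32 array scaled to [0, 1]; labels: (n,) int32 array.
    """
    with open(path, "rb") as f:
        # The files were pickled under Python 2, so keys come back as bytes.
        batch = pickle.load(f, encoding="bytes")
    data = np.asarray(batch[b"data"], dtype=np.float32) / 255.0
    labels = np.asarray(batch[b"labels"], dtype=np.int32)
    return data, labels
```

For example, `load_batch(extract_archive() / "data_batch_1")` would return the first training batch, and `test_batch` in the same directory holds the test images.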

Dependencies

  • python >= 3.9
  • JAX