nitish1295/IITGN

IITGN

This project is part of the screening process for working on AI/ML projects under the guidance of professors at IIT Gandhinagar. The process consisted of the following 4 questions, which had to be completed over the course of roughly 5 days:

  1. Bivariate Normal Distribution Plot ✅ - Plot the bivariate normal as shown in the image below:

    [Image: reference bivariate normal distribution plot]

    The plot must include interactive sliders that let users vary the parameters of the distribution. To try it, run the notebook on Binder (run all cells).

  2. Sampling from MVN ✅ - Write a sampling method from scratch that produces samples from a Multivariate Normal Distribution with a randomly generated mean and covariance matrix. The method must work for any number of dimensions, and you may not use any built-in library function to generate normally distributed values. To try it, open the notebook in Colab.

  3. Implement a Neural Network from Scratch ❎ - Write a neural network from scratch that works with the MNIST dataset. The network must optimize itself using gradient descent, also written from scratch. Finally, plot graphs for the training and test sets and evaluate the model using several classification metrics.

  4. Implement Bayesian Regression from Scratch ❎ - Write Bayesian Linear Regression from scratch and plot the learned predictive mean with a band of 2 standard deviations around it. Use your own noisy 1-D dataset.
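Task 3 asks for evaluation with several classification metrics. As a minimal sketch, assuming integer class labels and integer predictions (for MNIST, the argmax of the network's outputs), the confusion matrix, accuracy, and per-class precision, recall, and F1 can all be computed with plain jax.numpy. The function names below are hypothetical and not taken from the repository, and the toy labels are illustrative, not MNIST results.

```python
import jax.numpy as jnp


def confusion_matrix(y_true, y_pred, num_classes):
    """Entry [i, j] counts samples of true class i that were predicted as j."""
    flat_index = y_true * num_classes + y_pred
    counts = jnp.bincount(flat_index, length=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)


def classification_metrics(y_true, y_pred, num_classes):
    """Accuracy plus per-class and macro-averaged precision, recall and F1."""
    cm = confusion_matrix(y_true, y_pred, num_classes)
    tp = jnp.diag(cm)
    # Guard empty columns/rows so classes never predicted (or never seen) score 0
    precision = tp / jnp.maximum(cm.sum(axis=0), 1)
    recall = tp / jnp.maximum(cm.sum(axis=1), 1)
    f1 = 2 * precision * recall / jnp.maximum(precision + recall, 1e-12)
    return {
        "confusion_matrix": cm,
        "accuracy": tp.sum() / cm.sum(),
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "macro_f1": f1.mean(),
    }


# Toy labels for illustration only
y_true = jnp.array([0, 0, 1, 1, 2, 2])
y_pred = jnp.array([0, 1, 1, 1, 2, 0])
metrics = classification_metrics(y_true, y_pred, num_classes=3)
print(metrics["confusion_matrix"])  # rows: [1, 1, 0], [0, 2, 0], [1, 0, 1]
print(float(metrics["accuracy"]))   # 4 / 6, about 0.667
```

Building the confusion matrix once and deriving every other metric from it keeps the evaluation consistent across metrics and avoids looping over classes in Python.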
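For task 4, one standard from-scratch formulation (the conjugate Gaussian model, assumed here rather than taken from the repository) places a prior w ~ N(0, α⁻¹I) on the weights and assumes Gaussian noise with known precision β. With design matrix Φ and targets t, the posterior is N(m_N, S_N) with S_N⁻¹ = αI + βΦᵀΦ and m_N = βS_NΦᵀt; the predictive distribution at a new input x has mean φ(x)ᵀm_N and variance 1/β + φ(x)ᵀS_Nφ(x). The sketch uses a hypothetical cubic polynomial basis and a hypothetical noisy sin(πx) dataset; all names are illustrative.

```python
import jax
import jax.numpy as jnp


def poly_features(x, degree):
    """Design matrix with columns 1, x, x**2, ..., x**degree."""
    return jnp.stack([x ** k for k in range(degree + 1)], axis=1)


def fit_posterior(phi, t, alpha, beta):
    """Posterior N(m_n, s_n) over weights for prior N(0, alpha^-1 I), noise precision beta."""
    s_n_inv = alpha * jnp.eye(phi.shape[1]) + beta * phi.T @ phi
    s_n = jnp.linalg.inv(s_n_inv)
    m_n = beta * s_n @ phi.T @ t
    return m_n, s_n


def predict(phi_star, m_n, s_n, beta):
    """Predictive mean and standard deviation at each row of phi_star."""
    mean = phi_star @ m_n
    # Row-wise phi^T S_N phi, plus the irreducible noise variance 1 / beta
    var = 1.0 / beta + jnp.sum((phi_star @ s_n) * phi_star, axis=1)
    return mean, jnp.sqrt(var)


# Hypothetical noisy 1-D dataset: t = sin(pi x) + Gaussian noise
key_x, key_noise = jax.random.split(jax.random.PRNGKey(0))
noise_std = 0.2
x = jax.random.uniform(key_x, (30,), minval=-1.0, maxval=1.0)
t = jnp.sin(jnp.pi * x) + noise_std * jax.random.normal(key_noise, (30,))

degree, alpha, beta = 3, 1.0, 1.0 / noise_std ** 2
m_n, s_n = fit_posterior(poly_features(x, degree), t, alpha, beta)

x_star = jnp.linspace(-1.5, 1.5, 200)
mean, std = predict(poly_features(x_star, degree), m_n, s_n, beta)
lower, upper = mean - 2 * std, mean + 2 * std  # the +/- 2 standard deviation band
```

The band can then be drawn with matplotlib, for example `plt.plot(x_star, mean)` together with `plt.fill_between(x_star, lower, upper, alpha=0.3)`. Because the features grow quickly outside [-1, 1], the band widens where there is no training data.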


Constraints:

  • You are only allowed to use JAX unless stated otherwise.
  • Your code must adhere to PEP 8.

Learning Outcomes and Challenges:

  1. Bivariate Normal Distribution Plot: I had never used JAX before, so starting with it was a bit of a challenge initially, but once I learned that JAX works much like NumPy in many respects, it was smooth sailing ⛵. I did, however, get caught up later on with JAX's immutable arrays and with ipywidgets functionality. From a mathematical standpoint, I had to understand positive definite matrices and how to generate them so that JAX's multivariate normal function returns valid values for a given covariance matrix (FYI: I only fully understood the idea behind this after completing the second question below).

  2. Sampling from MVN: In this task I understood that the Central Limit Theorem can be used to generate standard normals, but not how that would work with a random mean and covariance matrix in a multivariate setting. After reading about it for a while, I discovered a method that uses the Cholesky decomposition to create samples from a multivariate normal.

  3. Implement a Neural Network from Scratch: I am still working on this neural network. So far I have only implemented the forward and backward passes for the dense layers, which is pretty rudimentary, so I haven't included it here. The task is challenging but exciting at the same time, and since I haven't implemented a neural net from scratch before, I plan to keep working on it.
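The positive-definiteness requirement from the first learning outcome can be illustrated with a short sketch. One common construction (an assumption here, not necessarily what the notebook does) is Σ = AAᵀ + εI: AAᵀ is positive semi-definite for any square A because vᵀAAᵀv = ‖Aᵀv‖² ≥ 0, and the small jitter εI lifts every eigenvalue strictly above zero. The result can be passed to `jax.scipy.stats.multivariate_normal.pdf` to evaluate the bivariate density on a grid for plotting. The helper names are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


def random_spd_matrix(key, dim, jitter=1e-3):
    """Random symmetric positive definite matrix A @ A.T + jitter * I."""
    a = jax.random.uniform(key, (dim, dim), minval=-1.0, maxval=1.0)
    return a @ a.T + jitter * jnp.eye(dim)


def is_positive_definite(cov):
    """True if cov is symmetric with strictly positive eigenvalues."""
    symmetric = bool(jnp.allclose(cov, cov.T))
    return symmetric and bool(jnp.all(jnp.linalg.eigvalsh(cov) > 0))


def bivariate_pdf_grid(mean, cov, half_width=3.0, num=100):
    """Evaluate the 2-D normal density on a square grid centred on the mean."""
    xs = jnp.linspace(mean[0] - half_width, mean[0] + half_width, num)
    ys = jnp.linspace(mean[1] - half_width, mean[1] + half_width, num)
    grid_x, grid_y = jnp.meshgrid(xs, ys)
    points = jnp.stack([grid_x, grid_y], axis=-1)  # shape (num, num, 2)
    return grid_x, grid_y, multivariate_normal.pdf(points, mean, cov)


cov = random_spd_matrix(jax.random.PRNGKey(42), 2)
mean = jnp.array([0.0, 1.0])
grid_x, grid_y, density = bivariate_pdf_grid(mean, cov)
```

The three 2-D arrays can be handed to a Plotly surface or a matplotlib contour plot; in an interactive version, slider callbacks would rebuild `mean` and `cov` and call `bivariate_pdf_grid` again.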
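The second learning outcome combines two steps: approximate standard normals via the Central Limit Theorem, then map them through a Cholesky factor. If Σ = LLᵀ and z ~ N(0, I), then x = μ + Lz has mean μ and covariance L I Lᵀ = Σ. This sketch assumes that drawing Uniform(0, 1) values with `jax.random.uniform` and calling `jnp.linalg.cholesky` are acceptable under the task's "no normal generator" rule. With 12 uniforms per value, the sum has mean 6 and variance 12 × 1/12 = 1, so subtracting 6 gives an approximately standard normal value (with tails truncated at ±6). Function names are hypothetical, not the notebook's.

```python
import jax
import jax.numpy as jnp


def clt_standard_normals(key, shape, n_uniforms=12):
    """Approximate N(0, 1) draws as standardized sums of Uniform(0, 1) draws."""
    u = jax.random.uniform(key, shape + (n_uniforms,))
    # The sum of n uniforms has mean n / 2 and variance n / 12
    return (u.sum(axis=-1) - n_uniforms / 2.0) / jnp.sqrt(n_uniforms / 12.0)


def sample_mvn(key, mean, cov, num_samples):
    """Draw num_samples rows from N(mean, cov) for any dimension len(mean)."""
    chol = jnp.linalg.cholesky(cov)  # lower-triangular L with cov = L @ L.T
    z = clt_standard_normals(key, (num_samples, mean.shape[0]))
    # x = mean + L z for each sample, written for a batch of row vectors
    return mean + z @ chol.T


# Randomly generated mean and SPD covariance in d = 3 dimensions
key_mean, key_cov, key_sample = jax.random.split(jax.random.PRNGKey(0), 3)
d = 3
mean = jax.random.uniform(key_mean, (d,), minval=-5.0, maxval=5.0)
a = jax.random.uniform(key_cov, (d, d), minval=-1.0, maxval=1.0)
cov = a @ a.T + 0.1 * jnp.eye(d)
samples = sample_mvn(key_sample, mean, cov, 100_000)
```

Only the Cholesky factor and a matrix product depend on d, so the same function works unchanged for any number of dimensions; comparing `samples.mean(axis=0)` and `jnp.cov(samples, rowvar=False)` against `mean` and `cov` is a quick sanity check.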
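The forward and backward passes for dense layers mentioned in the third learning outcome can be sketched as follows. This is a minimal illustration under assumed choices (one ReLU hidden layer, softmax cross-entropy over 10 MNIST-style classes, Glorot-uniform initialization, plain gradient descent), not the author's in-progress code. For a layer y = xW + b with upstream gradient G = ∂L/∂y, the backward pass gives ∂L/∂W = xᵀG, ∂L/∂b = the column sums of G, and ∂L/∂x = GWᵀ; for mean softmax cross-entropy over a batch of size B, the gradient with respect to the logits is (softmax(logits) - onehot) / B. All names are hypothetical.

```python
import jax
import jax.numpy as jnp


def dense_init(key, n_in, n_out):
    """Glorot-uniform weights and zero biases for one dense layer."""
    limit = jnp.sqrt(6.0 / (n_in + n_out))
    w = jax.random.uniform(key, (n_in, n_out), minval=-limit, maxval=limit)
    return {"w": w, "b": jnp.zeros((n_out,))}


def dense_forward(params, x):
    """Affine map y = x @ W + b for a batch x of shape (batch, n_in)."""
    return x @ params["w"] + params["b"]


def dense_backward(params, x, grad_y):
    """Given dL/dy, return dL/dx and the gradients for W and b."""
    grads = {"w": x.T @ grad_y, "b": grad_y.sum(axis=0)}
    return grad_y @ params["w"].T, grads


def softmax_cross_entropy(logits, labels_onehot):
    """Mean cross-entropy over the batch and its gradient w.r.t. the logits."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - jnp.log(jnp.exp(shifted).sum(axis=1, keepdims=True))
    loss = -jnp.mean(jnp.sum(labels_onehot * log_probs, axis=1))
    grad_logits = (jnp.exp(log_probs) - labels_onehot) / logits.shape[0]
    return loss, grad_logits


def loss_and_grads(params, x, labels_onehot):
    """Forward pass, then manual backward pass, for dense -> ReLU -> dense."""
    hidden_pre = dense_forward(params[0], x)
    hidden = jnp.maximum(hidden_pre, 0.0)
    logits = dense_forward(params[1], hidden)
    loss, grad_logits = softmax_cross_entropy(logits, labels_onehot)
    grad_hidden, grads_out = dense_backward(params[1], hidden, grad_logits)
    grad_pre = grad_hidden * (hidden_pre > 0)  # ReLU passes gradient where input > 0
    _, grads_in = dense_backward(params[0], x, grad_pre)
    return loss, [grads_in, grads_out]


def gradient_descent_step(params, grads, lr):
    """Return new parameters; JAX arrays are immutable, so nothing changes in place."""
    return [{k: p[k] - lr * g[k] for k in p} for p, g in zip(params, grads)]


# Zero-centred random stand-in batch with MNIST-like shapes (784 inputs, 10 classes)
key_w0, key_w1, key_x, key_y = jax.random.split(jax.random.PRNGKey(0), 4)
params = [dense_init(key_w0, 784, 32), dense_init(key_w1, 32, 10)]
x = jax.random.uniform(key_x, (64, 784), minval=-1.0, maxval=1.0)
labels = jax.nn.one_hot(jax.random.randint(key_y, (64,), 0, 10), 10)

initial_loss, _ = loss_and_grads(params, x, labels)
for _ in range(50):
    _, grads = loss_and_grads(params, x, labels)
    params = gradient_descent_step(params, grads, lr=0.1)
final_loss, _ = loss_and_grads(params, x, labels)
```

Since the gradients are derived by hand, comparing them with `jax.grad` applied to the loss alone is a cheap way to catch mistakes in the backward pass before training on real MNIST data.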


Technologies/Libraries:

  • Python Libraries
    • JAX
    • Plotly
    • matplotlib
