Skip to content

Autodiff graph, in rust, for my final highschool programming project.

Notifications You must be signed in to change notification settings

unic0rn9k/autodiff

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

29 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

A library for automatic differentiation using GATs Rust

This library was developed as a part of a school projekt, therefore it also includes a repport written in Danish.

As a part of the projekt i also wrote an mnist classifier, which is described in the repport.

To run the mnist classifier example, clone the repository, and run cargo r --package mnist_classifier.

Example from unit test in src/mat.rs

    let x = mat(DMatrix::<f32>::new_random(3, 1)).symbol("input");

    let w1 = mat(DMatrix::<f32>::new_random(2, 3)).symbol("w1");
    let b1 = mat(DMatrix::<f32>::new_random(2, 1)).symbol("b1");

    let w2 = mat(DMatrix::<f32>::new_random(4, 2)).symbol("w2");
    let b2 = mat(DMatrix::<f32>::new_random(4, 1)).symbol("b2");

    let dl2 = mat(DMatrix::<f32>::repeat(4, 1, 1.));

    let l1 = &w1 * &x + &b1;
    let l2 = &w2 * &l1 + &b2;
    let [dw1, dw2] = l2.derivative(["w1", "w2"], &dl2);

    l2.eval();

    println!("dw1 = {dw1:?}");
    println!();
    println!("dw2 = {dw2:?}");
    println!();

    assert_eq!(dw2.eval().shape(), w2.eval().shape());
    assert_eq!(dw1.eval().shape(), w1.eval().shape());

Long compile times

The repport also details some of the issues with leveraging GATs for autodiff, which lead me to try to rewrite the projekt on a new branch (runtime-compiled-graphs), where I instead use a structure for representing expressions, which can then be compiled into a evaluatable graph, at run time. The graph can then be reused for identical functions, with different values. This also allows for optimization of the mathematical expression, caching of computed values, an pruning of identical graph segments.

One thing I wanted to work into this project, was GPU support, so runtime-compiled-graphs is build with multiple backends in mind, and is meant to have GPU support out of the box.

About

Autodiff graph, in rust, for my final highschool programming project.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages