An experimental deep learning library written in pure Rust. Breakage is expected on each release in the short term. See `mnist.rs` in the examples, or Rusty_SR, for usage samples.
The key types are `Node` and `Ops`, which are `Rc`-like references to components of a shared mutable `Graph`; the graph is extended gradually with new tensors and operations via construction functions. Facilities for reverse-mode automatic differentiation are included in operations, extending the graph as necessary.
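
The `Rc`-like behaviour can be pictured as cheap handles over a reference-counted, interior-mutable graph. Below is a minimal, self-contained sketch of that pattern; `GraphInner`, `new_in`, and their fields are illustrative stand-ins, not alumina's actual types.

```rust
use std::cell::RefCell;
use std::rc::Rc;

// Illustrative stand-in for the library's graph internals.
struct GraphInner {
    node_names: Vec<String>,
}

// A cheap-to-clone handle: cloning a Node clones the Rc, not the graph.
#[derive(Clone)]
struct Node {
    graph: Rc<RefCell<GraphInner>>,
    index: usize,
}

impl Node {
    // Construction functions extend the shared graph and return a new handle into it.
    fn new_in(graph: &Rc<RefCell<GraphInner>>, name: &str) -> Node {
        let mut inner = graph.borrow_mut();
        inner.node_names.push(name.to_string());
        Node {
            graph: Rc::clone(graph),
            index: inner.node_names.len() - 1,
        }
    }
}

fn main() {
    let graph = Rc::new(RefCell::new(GraphInner { node_names: Vec::new() }));
    let a = Node::new_in(&graph, "input");
    let b = Node::new_in(&graph, "weights");
    // Both handles alias the same graph, which now holds both nodes.
    assert!(Rc::ptr_eq(&a.graph, &b.graph));
    println!("node {} and node {} share one graph", a.index, b.index);
}
```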
Typical graph construction and differentiation is shown below:
```rust
// 1. Build an MLP neural net graph - 98% @ 10 epochs
let input = Node::new(&[-1, 28, 28, 1]).set_name("input");
let labels = Node::new(&[-1, 10]).set_name("labels");

let layer1 = elu(affine(&input, 256, msra(1.0))).set_name("layer1");
let layer2 = elu(affine(&layer1, 256, msra(1.0))).set_name("layer2");
let logits = linear(&layer2, 10, msra(1.0)).set_name("logits");

let training_loss = add(
    reduce_sum(softmax_cross_entropy(&logits, &labels, -1), &[], false).set_name("loss"),
    scale(l2(logits.graph().nodes_tagged(NodeTag::Parameter)), 1e-3).set_name("regularisation"),
)
.set_name("training_loss");

let accuracy = equal(argmax(&logits, -1), argmax(&labels, -1)).set_name("accuracy");

let parameters = accuracy.graph().nodes_tagged(NodeTag::Parameter);
let grads = Grad::of(training_loss).wrt(parameters).build()?;
```
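
In conventional notation, and assuming `l2` denotes the usual sum-of-squares penalty, the `training_loss` built above corresponds to:

```math
\mathcal{L} = \sum_i \mathrm{CE}\left(\mathrm{softmax}(z_i),\, y_i\right)
\; + \; \lambda \sum_{\theta \in \mathrm{parameters}} \lVert \theta \rVert_2^2,
\qquad \lambda = 10^{-3}
```

where `z_i` are the logits and `y_i` the labels for example *i*, summed over the batch.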
Current work focuses on improving the high-level graph construction API and on better supporting dynamic/define-by-run graphs.
Issues are a great place for discussion, problems, and requests.
Documentation will be patchy until library API experimentation ends, particularly until the graph construction API is finalised.
- Computation hypergraph
- NN
  - Dense Connection and Bias operations
  - N-dimensional Convolution
    - Arbitrary padding
    - Strides
  - Reflection padding
  - Categorical Cross Entropy
  - Binary Cross Entropy
- Boolean
  - Equal
  - Greater_Equal
  - Greater_Than
  - Less_Equal
  - Less_Than
  - Not
- Elementwise
  - Abs
  - Ceil
  - Cos
  - Div
  - Elu
  - Exp
  - Floor
  - Identity
  - Leaky_relu
  - Ln
  - Logistic
  - Max
  - Min
  - Mul
  - Negative
  - Offset
  - Reciprocal
  - Relu
  - Robust
  - Round
  - Scale
  - Sign
  - Sin
  - SoftPlus
  - SoftSign
  - Sqr
  - Sqrt
  - Srgb
  - Subtract
  - Tanh
- Grad
  - Stop_grad
- Manip
  - Concat
  - Slice
  - Permute_axes
  - Expand_dims
  - Remove_dims
- Math
  - Argmax
  - Broadcast
- Pooling
  - N-dimensional Avg_Pool
  - Max pool
  - N-dimensional spaxel shuffling for "Sub-pixel Convolution"
  - N-dimensional Linear-Interpolation
  - Global Pooling
- Reduce
  - Reduce_Prod
  - Reduce_Sum
- Regularisation
  - L1
  - L2
  - Hoyer_squared
  - Robust
- Shapes
  - Shape inference and constraint propagation
- Data Loading
  - Mnist
  - Cifar
  - Image Folders
  - Imagenet (ILSVRC)
- SGD
- RMSProp
- ADAM (the SGD and Adam update rules are sketched after this list)
- Basic numerical tests
- Limit Optimiser evaluation batch size to stay within memory limits
- Selectively disable calculation of forward values, node derivatives and parameter derivatives
- Builder patterns for operation construction
- Split Graph struct into mutable GraphBuilder and immutable Sub-Graphs
- Replace 'accidentally quadratic' graph algorithms
- Replace up-front allocation with Sub-Graph optimised allocation/deallocation patterns based on liveness analysis of nodes
- Overhaul data ingestion, particularly buffering input processing/reads.
- Move tensor format to bluss' ndarray
- Improve naming inter/intra-library consistency
- Operator overloading for simple ops
- Complete documentation
- Reduce ability to express illegal states in API
- Move from panics to error-chain
- Move from error-chain to thiserror
- Guard unsafe code rigorously
- Comprehensive tests
- Optionally typed tensors
- Arrayfire as an option for sgemm on APUs
- Graph optimisation passes and inplace operations
- Support for both dynamic and static graphs
- RNNs
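
The optimiser entries above (SGD, RMSProp, ADAM) all apply variants of the gradient-descent update to the parameter tensors. As a reference, here is a minimal sketch of the plain SGD and Adam update rules over flat `f32` slices; it is deliberately independent of alumina's optimiser API.

```rust
/// Plain SGD: theta <- theta - lr * grad.
fn sgd_step(theta: &mut [f32], grad: &[f32], lr: f32) {
    for (t, g) in theta.iter_mut().zip(grad) {
        *t -= lr * g;
    }
}

/// Adam (Kingma & Ba, 2015): SGD scaled by bias-corrected running
/// estimates of the gradient's first and second moments.
fn adam_step(
    theta: &mut [f32],
    grad: &[f32],
    m: &mut [f32], // first-moment (mean) running average
    v: &mut [f32], // second-moment (uncentred variance) running average
    step: i32,     // 1-based update count, used for bias correction
    lr: f32,
) {
    let (b1, b2, eps) = (0.9_f32, 0.999_f32, 1e-8_f32);
    for i in 0..theta.len() {
        m[i] = b1 * m[i] + (1.0 - b1) * grad[i];
        v[i] = b2 * v[i] + (1.0 - b2) * grad[i] * grad[i];
        let m_hat = m[i] / (1.0 - b1.powi(step)); // bias-corrected mean
        let v_hat = v[i] / (1.0 - b2.powi(step)); // bias-corrected variance
        theta[i] -= lr * m_hat / (v_hat.sqrt() + eps);
    }
}
```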
License: MIT