Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TensorContainer trait to allow more argument types for TensorVisitors in #469 #472

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions src/nn/impl_tensor_container.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
use super::visitors::TensorContainer;

impl TensorContainer for &'static () {
type WithModule<'a, Mod: 'a> = &'a Mod;

fn get_field<'a, Mod, Field, GetRef, GetMut>(
module: &'a mut Self::WithModule<'_, Mod>,
get_ref: &mut GetRef,
_get_mut: &mut GetMut,
) -> Self::WithModule<'a, Field>
where
GetRef: FnMut(&Mod) -> &Field,
GetMut: FnMut(&mut Mod) -> &mut Field,
{
get_ref(*module)
}
}

impl TensorContainer for &'static mut () {
type WithModule<'a, Mod: 'a> = &'a mut Mod;

fn get_field<'a, Mod, Field, GetRef, GetMut>(
module: &'a mut Self::WithModule<'_, Mod>,
_get_ref: &mut GetRef,
get_mut: &mut GetMut,
) -> Self::WithModule<'a, Field>
where
GetRef: FnMut(&Mod) -> &Field,
GetMut: FnMut(&mut Mod) -> &mut Field,
{
get_mut(*module)
}
}

macro_rules! tuple_impls {
([$($name:ident),+] [$($idx:tt),+]) => {
impl<$($name: TensorContainer),+> TensorContainer for ($($name,)+) {
type WithModule<'a, Mod: 'a> = ($($name::WithModule<'a, Mod>,)+);

fn get_field<'a, Mod, Field, GetRef, GetMut>(
module: &'a mut Self::WithModule<'_, Mod>,
get_ref: &mut GetRef,
get_mut: &mut GetMut,
) -> Self::WithModule<'a, Field>
where
GetRef: FnMut(&Mod) -> &Field,
GetMut: FnMut(&mut Mod) -> &mut Field,
{
($($name::get_field(&mut module.$idx, get_ref, get_mut),)+)
}
}
}
}

tuple_impls!([M1][0]);
tuple_impls!([M1, M2] [0, 1]);
tuple_impls!([M1, M2, M3] [0, 1, 2]);
tuple_impls!([M1, M2, M3, M4] [0, 1, 2, 3]);
tuple_impls!([M1, M2, M3, M4, M5] [0, 1, 2, 3, 4]);
tuple_impls!([M1, M2, M3, M4, M5, M6] [0, 1, 2, 3, 4, 5]);

impl<T: TensorContainer> TensorContainer for std::vec::Vec<T> {
type WithModule<'a, Mod: 'a> = std::vec::Vec<T::WithModule<'a, Mod>>;

fn get_field<'a, Mod, Field, GetRef, GetMut>(
module: &'a mut Self::WithModule<'_, Mod>,
get_ref: &mut GetRef,
get_mut: &mut GetMut,
) -> Self::WithModule<'a, Field>
where
GetRef: FnMut(&Mod) -> &Field,
GetMut: FnMut(&mut Mod) -> &mut Field,
{
module
.iter_mut()
.map(|x| T::get_field(x, get_ref, get_mut))
.collect()
}
}

impl<T: TensorContainer> TensorContainer for Option<T> {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could see this being used for optional fields (e.g. bias if we go the Option route), nice!

type WithModule<'a, Mod: 'a> = Option<T::WithModule<'a, Mod>>;

fn get_field<'a, Mod, Field, GetRef, GetMut>(
module: &'a mut Self::WithModule<'_, Mod>,
get_ref: &mut GetRef,
get_mut: &mut GetMut,
) -> Self::WithModule<'a, Field>
where
GetRef: FnMut(&Mod) -> &Field,
GetMut: FnMut(&mut Mod) -> &mut Field,
{
module.as_mut().map(|x| T::get_field(x, get_ref, get_mut))
}
}
1 change: 1 addition & 0 deletions src/nn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ mod embedding;
mod flatten;
mod generalized_residual;
mod impl_module_for_tuples;
mod impl_tensor_container;
mod layer_norm;
mod linear;
mod module;
Expand Down
8 changes: 6 additions & 2 deletions src/nn/npz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,12 @@ pub trait LoadFromNpz<E: Dtype + NumpyDtype, D: CopySlice<E>>: TensorCollection<
}
impl<E: Dtype + NumpyDtype, D: CopySlice<E>, T: TensorCollection<E, D>> LoadFromNpz<E, D> for T {}

impl<W: Write + Seek, E: Dtype + NumpyDtype, D: CopySlice<E>> VisitTensorRef<E, D>
impl<W: Write + Seek, E: Dtype + NumpyDtype, D: CopySlice<E>> VisitTensors<E, D>
for zip::ZipWriter<W>
{
type Container = &'static ();
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah lets go with your suggestion in the PR description about using the unconstructable enums instead of these, purely for readability. I know i'd forget in a couple months what this means haha.

type Container = TensorRef;

makes so much sense!

type Err = ZipError;

fn visit<S: Shape>(
&mut self,
full_path: String,
Expand All @@ -126,10 +128,12 @@ impl<W: Write + Seek, E: Dtype + NumpyDtype, D: CopySlice<E>> VisitTensorRef<E,
}
}

impl<R: Read + Seek, E: Dtype + NumpyDtype, D: CopySlice<E>> VisitTensorMut<E, D>
impl<R: Read + Seek, E: Dtype + NumpyDtype, D: CopySlice<E>> VisitTensors<E, D>
for zip::ZipArchive<R>
{
type Container = &'static mut ();
type Err = NpzError;

fn visit<S: Shape>(
&mut self,
full_path: String,
Expand Down
6 changes: 4 additions & 2 deletions src/nn/num_params.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use super::visitors::{RecursiveWalker, TensorCollection, TensorOptions, VisitTensorRef};
use super::visitors::{RecursiveWalker, TensorCollection, TensorOptions, VisitTensors};

use crate::{shapes::*, tensor::*};

use std::{string::String, vec::Vec};

struct Counter(usize);
impl<E: Dtype, D: DeviceStorage> VisitTensorRef<E, D> for Counter {
impl<E: Dtype, D: DeviceStorage> VisitTensors<E, D> for Counter {
type Container = &'static ();
type Err = D::Err;

fn visit<S: Shape>(
&mut self,
_: String,
Expand Down
6 changes: 4 additions & 2 deletions src/nn/reset_params.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use super::visitors::{RecursiveWalker, TensorCollection, TensorOptions, VisitTensorMut};
use super::visitors::{RecursiveWalker, TensorCollection, TensorOptions, VisitTensors};

use crate::{shapes::*, tensor::*};

use std::{string::String, vec::Vec};

struct Resetter;
impl<E: Dtype, D: DeviceStorage> VisitTensorMut<E, D> for Resetter {
impl<E: Dtype, D: DeviceStorage> VisitTensors<E, D> for Resetter {
type Container = &'static mut ();
type Err = D::Err;

fn visit<S: Shape>(
&mut self,
_: String,
Expand Down
141 changes: 30 additions & 111 deletions src/nn/visitors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,34 +45,33 @@ impl<S: Shape, E: Dtype, D: DeviceStorage> TensorOptions<S, E, D> {
}
}

pub trait VisitTensorRef<E: Dtype, D: DeviceStorage> {
pub trait VisitTensors<E: Dtype, D: DeviceStorage> {
type Container: TensorContainer;
type Err;
fn visit<S: Shape>(
&mut self,
full_path: String,
opts: TensorOptions<S, E, D>,
t: &Tensor<S, E, D>,
) -> Result<(), Self::Err>;
}

pub trait VisitTensorMut<E: Dtype, D: DeviceStorage> {
type Err;
fn visit<S: Shape>(
&mut self,
full_path: String,
opts: TensorOptions<S, E, D>,
t: &mut Tensor<S, E, D>,
t: <Self::Container as TensorContainer>::WithModule<'_, Tensor<S, E, D>>,
) -> Result<(), Self::Err>;
}

pub trait VisitTensorMutRef<E: Dtype, D: DeviceStorage> {
type Err;
fn visit<S: Shape>(
&mut self,
full_path: String,
opts: TensorOptions<S, E, D>,
ts: (&mut Tensor<S, E, D>, &Tensor<S, E, D>),
) -> Result<(), Self::Err>;
type ContainerWithModule<'a, C, M> = <C as TensorContainer>::WithModule<'a, M>;

pub trait TensorContainer: 'static {
type WithModule<'a, Mod: 'a>
where
Self: 'a;

fn get_field<'a, Mod, Field, GetRef, GetMut>(
module: &'a mut Self::WithModule<'_, Mod>,
get_ref: &mut GetRef,
get_mut: &mut GetMut,
) -> Self::WithModule<'a, Field>
where
GetRef: FnMut(&Mod) -> &Field,
GetMut: FnMut(&mut Mod) -> &mut Field;
}

pub trait TensorCollection<E: Dtype, D: DeviceStorage>: Sized {
Expand Down Expand Up @@ -124,131 +123,51 @@ pub(crate) struct RecursiveWalker<'a, M, F> {
pub(crate) path: &'a mut Vec<String>,
}

impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensorRef<E, D>> TensorVisitor<M, E, D>
for RecursiveWalker<'a, &'a M, F>
{
type Err = F::Err;
fn visit_module<Field, GetRef, GetMut>(
&mut self,
mut get_refs: GetRef,
_: GetMut,
name: &str,
) -> Result<(), Self::Err>
where
GetRef: FnMut(&M) -> &Field,
GetMut: FnMut(&mut M) -> &mut Field,
Field: TensorCollection<E, D>,
{
self.path.push(name.into());
let mut walker = RecursiveWalker {
m: get_refs(self.m),
f: self.f,
path: self.path,
};
Field::iter_tensors(&mut walker)?;
self.path.pop();
Ok(())
}
fn visit_tensor<S: Shape, GetRef, GetMut>(
&mut self,
mut get_refs: GetRef,
_: GetMut,
name: &str,
opts: TensorOptions<S, E, D>,
) -> Result<(), F::Err>
where
GetRef: FnMut(&M) -> &Tensor<S, E, D>,
GetMut: FnMut(&mut M) -> &mut Tensor<S, E, D>,
{
self.path.push(name.into());
self.f.visit(self.path.join("."), opts, get_refs(self.m))?;
self.path.pop();
Ok(())
}
}

impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensorMut<E, D>> TensorVisitor<M, E, D>
for RecursiveWalker<'a, &'a mut M, F>
impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensors<E, D>> TensorVisitor<M, E, D>
for RecursiveWalker<'a, ContainerWithModule<'a, F::Container, M>, F>
{
type Err = F::Err;
fn visit_module<Field, GetRef, GetMut>(
&mut self,
_: GetRef,
mut get_muts: GetMut,
name: &str,
) -> Result<(), F::Err>
where
GetRef: FnMut(&M) -> &Field,
GetMut: FnMut(&mut M) -> &mut Field,
Field: TensorCollection<E, D>,
{
self.path.push(name.into());
let mut walker = RecursiveWalker {
m: get_muts(self.m),
f: self.f,
path: self.path,
};
Field::iter_tensors(&mut walker)?;
self.path.pop();
Ok(())
}
fn visit_tensor<S: Shape, GetRef, GetMut>(
&mut self,
_: GetRef,
mut get_muts: GetMut,
name: &str,
opts: TensorOptions<S, E, D>,
) -> Result<(), F::Err>
where
GetRef: FnMut(&M) -> &Tensor<S, E, D>,
GetMut: FnMut(&mut M) -> &mut Tensor<S, E, D>,
{
self.path.push(name.into());
self.f.visit(self.path.join("."), opts, get_muts(self.m))?;
self.path.pop();
Ok(())
}
}

impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensorMutRef<E, D>> TensorVisitor<M, E, D>
for RecursiveWalker<'a, (&'a mut M, &'a M), F>
{
type Err = F::Err;
fn visit_module<Field, GetRef, GetMut>(
&mut self,
mut get_refs: GetRef,
mut get_muts: GetMut,
name: &str,
) -> Result<(), F::Err>
) -> Result<(), Self::Err>
where
GetRef: FnMut(&M) -> &Field,
GetMut: FnMut(&mut M) -> &mut Field,
Field: TensorCollection<E, D>,
{
self.path.push(name.into());
let mut walker = RecursiveWalker {
m: (get_muts(self.m.0), get_refs(self.m.1)),
m: F::Container::get_field(&mut self.m, &mut get_refs, &mut get_muts),
f: self.f,
path: self.path,
};
Field::iter_tensors(&mut walker)?;
std::mem::drop(walker);
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this necessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, if you remove it, rust will consider self.path to be mutably borrowed on the next line.

self.path.pop();
Ok(())
}

fn visit_tensor<S: Shape, GetRef, GetMut>(
&mut self,
mut get_refs: GetRef,
mut get_muts: GetMut,
name: &str,
opts: TensorOptions<S, E, D>,
) -> Result<(), F::Err>
) -> Result<(), Self::Err>
where
GetRef: FnMut(&M) -> &Tensor<S, E, D>,
GetMut: FnMut(&mut M) -> &mut Tensor<S, E, D>,
{
self.path.push(name.into());
let tensors = (get_muts(self.m.0), get_refs(self.m.1));
self.f.visit(self.path.join("."), opts, tensors)?;
self.f.visit(
self.path.join("."),
opts,
F::Container::get_field(&mut self.m, &mut get_refs, &mut get_muts),
)?;
self.path.pop();
Ok(())
}
Expand Down
4 changes: 3 additions & 1 deletion src/optim/adam/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,10 @@ pub(super) trait AdamKernel<E: Dtype>: DeviceStorage {
) -> Result<(), Self::Err>;
}

impl<M, D: AdamKernel<E>, E: Dtype> VisitTensorMut<E, D> for Adam<M, E> {
impl<M, D: AdamKernel<E>, E: Dtype> VisitTensors<E, D> for Adam<M, E> {
type Container = &'static mut ();
type Err = D::Err;

fn visit<S: Shape>(
&mut self,
_: alloc::string::String,
Expand Down
4 changes: 3 additions & 1 deletion src/optim/rmsprop/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,10 @@ pub(super) trait RMSpropKernel<E: Dtype>: DeviceStorage {
) -> Result<(), Self::Err>;
}

impl<M, E: Dtype, D: RMSpropKernel<E> + OneFillStorage<E>> VisitTensorMut<E, D> for RMSprop<M, E> {
impl<M, E: Dtype, D: RMSpropKernel<E> + OneFillStorage<E>> VisitTensors<E, D> for RMSprop<M, E> {
type Container = &'static mut ();
type Err = D::Err;

fn visit<S: Shape>(
&mut self,
_: alloc::string::String,
Expand Down
4 changes: 3 additions & 1 deletion src/optim/sgd/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,10 @@ pub(super) trait SgdKernel<E: Dtype>: DeviceStorage {
) -> Result<(), Self::Err>;
}

impl<E: Dtype, D: SgdKernel<E>, M> VisitTensorMut<E, D> for Sgd<M, E> {
impl<E: Dtype, D: SgdKernel<E>, M> VisitTensors<E, D> for Sgd<M, E> {
type Container = &'static mut ();
type Err = D::Err;

fn visit<S: Shape>(
&mut self,
_: alloc::string::String,
Expand Down