-
-
Notifications
You must be signed in to change notification settings - Fork 101
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add TensorContainer trait to allow more argument types for TensorVisitors in #469 #472
Changes from all commits
a5ff5f2
f2721e0
bbdf197
d2fa72d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
use super::visitors::TensorContainer; | ||
|
||
impl TensorContainer for &'static () { | ||
type WithModule<'a, Mod: 'a> = &'a Mod; | ||
|
||
fn get_field<'a, Mod, Field, GetRef, GetMut>( | ||
module: &'a mut Self::WithModule<'_, Mod>, | ||
get_ref: &mut GetRef, | ||
_get_mut: &mut GetMut, | ||
) -> Self::WithModule<'a, Field> | ||
where | ||
GetRef: FnMut(&Mod) -> &Field, | ||
GetMut: FnMut(&mut Mod) -> &mut Field, | ||
{ | ||
get_ref(*module) | ||
} | ||
} | ||
|
||
impl TensorContainer for &'static mut () { | ||
type WithModule<'a, Mod: 'a> = &'a mut Mod; | ||
|
||
fn get_field<'a, Mod, Field, GetRef, GetMut>( | ||
module: &'a mut Self::WithModule<'_, Mod>, | ||
_get_ref: &mut GetRef, | ||
get_mut: &mut GetMut, | ||
) -> Self::WithModule<'a, Field> | ||
where | ||
GetRef: FnMut(&Mod) -> &Field, | ||
GetMut: FnMut(&mut Mod) -> &mut Field, | ||
{ | ||
get_mut(*module) | ||
} | ||
} | ||
|
||
macro_rules! tuple_impls { | ||
([$($name:ident),+] [$($idx:tt),+]) => { | ||
impl<$($name: TensorContainer),+> TensorContainer for ($($name,)+) { | ||
type WithModule<'a, Mod: 'a> = ($($name::WithModule<'a, Mod>,)+); | ||
|
||
fn get_field<'a, Mod, Field, GetRef, GetMut>( | ||
module: &'a mut Self::WithModule<'_, Mod>, | ||
get_ref: &mut GetRef, | ||
get_mut: &mut GetMut, | ||
) -> Self::WithModule<'a, Field> | ||
where | ||
GetRef: FnMut(&Mod) -> &Field, | ||
GetMut: FnMut(&mut Mod) -> &mut Field, | ||
{ | ||
($($name::get_field(&mut module.$idx, get_ref, get_mut),)+) | ||
} | ||
} | ||
} | ||
} | ||
|
||
tuple_impls!([M1][0]); | ||
tuple_impls!([M1, M2] [0, 1]); | ||
tuple_impls!([M1, M2, M3] [0, 1, 2]); | ||
tuple_impls!([M1, M2, M3, M4] [0, 1, 2, 3]); | ||
tuple_impls!([M1, M2, M3, M4, M5] [0, 1, 2, 3, 4]); | ||
tuple_impls!([M1, M2, M3, M4, M5, M6] [0, 1, 2, 3, 4, 5]); | ||
|
||
impl<T: TensorContainer> TensorContainer for std::vec::Vec<T> { | ||
type WithModule<'a, Mod: 'a> = std::vec::Vec<T::WithModule<'a, Mod>>; | ||
|
||
fn get_field<'a, Mod, Field, GetRef, GetMut>( | ||
module: &'a mut Self::WithModule<'_, Mod>, | ||
get_ref: &mut GetRef, | ||
get_mut: &mut GetMut, | ||
) -> Self::WithModule<'a, Field> | ||
where | ||
GetRef: FnMut(&Mod) -> &Field, | ||
GetMut: FnMut(&mut Mod) -> &mut Field, | ||
{ | ||
module | ||
.iter_mut() | ||
.map(|x| T::get_field(x, get_ref, get_mut)) | ||
.collect() | ||
} | ||
} | ||
|
||
impl<T: TensorContainer> TensorContainer for Option<T> { | ||
type WithModule<'a, Mod: 'a> = Option<T::WithModule<'a, Mod>>; | ||
|
||
fn get_field<'a, Mod, Field, GetRef, GetMut>( | ||
module: &'a mut Self::WithModule<'_, Mod>, | ||
get_ref: &mut GetRef, | ||
get_mut: &mut GetMut, | ||
) -> Self::WithModule<'a, Field> | ||
where | ||
GetRef: FnMut(&Mod) -> &Field, | ||
GetMut: FnMut(&mut Mod) -> &mut Field, | ||
{ | ||
module.as_mut().map(|x| T::get_field(x, get_ref, get_mut)) | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -112,10 +112,12 @@ pub trait LoadFromNpz<E: Dtype + NumpyDtype, D: CopySlice<E>>: TensorCollection< | |
} | ||
impl<E: Dtype + NumpyDtype, D: CopySlice<E>, T: TensorCollection<E, D>> LoadFromNpz<E, D> for T {} | ||
|
||
impl<W: Write + Seek, E: Dtype + NumpyDtype, D: CopySlice<E>> VisitTensorRef<E, D> | ||
impl<W: Write + Seek, E: Dtype + NumpyDtype, D: CopySlice<E>> VisitTensors<E, D> | ||
for zip::ZipWriter<W> | ||
{ | ||
type Container = &'static (); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah lets go with your suggestion in the PR description about using the unconstructable enums instead of these, purely for readability. I know i'd forget in a couple months what this means haha. type Container = TensorRef; makes so much sense! |
||
type Err = ZipError; | ||
|
||
fn visit<S: Shape>( | ||
&mut self, | ||
full_path: String, | ||
|
@@ -126,10 +128,12 @@ impl<W: Write + Seek, E: Dtype + NumpyDtype, D: CopySlice<E>> VisitTensorRef<E, | |
} | ||
} | ||
|
||
impl<R: Read + Seek, E: Dtype + NumpyDtype, D: CopySlice<E>> VisitTensorMut<E, D> | ||
impl<R: Read + Seek, E: Dtype + NumpyDtype, D: CopySlice<E>> VisitTensors<E, D> | ||
for zip::ZipArchive<R> | ||
{ | ||
type Container = &'static mut (); | ||
type Err = NpzError; | ||
|
||
fn visit<S: Shape>( | ||
&mut self, | ||
full_path: String, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,34 +45,33 @@ impl<S: Shape, E: Dtype, D: DeviceStorage> TensorOptions<S, E, D> { | |
} | ||
} | ||
|
||
pub trait VisitTensorRef<E: Dtype, D: DeviceStorage> { | ||
pub trait VisitTensors<E: Dtype, D: DeviceStorage> { | ||
type Container: TensorContainer; | ||
type Err; | ||
fn visit<S: Shape>( | ||
&mut self, | ||
full_path: String, | ||
opts: TensorOptions<S, E, D>, | ||
t: &Tensor<S, E, D>, | ||
) -> Result<(), Self::Err>; | ||
} | ||
|
||
pub trait VisitTensorMut<E: Dtype, D: DeviceStorage> { | ||
type Err; | ||
fn visit<S: Shape>( | ||
&mut self, | ||
full_path: String, | ||
opts: TensorOptions<S, E, D>, | ||
t: &mut Tensor<S, E, D>, | ||
t: <Self::Container as TensorContainer>::WithModule<'_, Tensor<S, E, D>>, | ||
) -> Result<(), Self::Err>; | ||
} | ||
|
||
pub trait VisitTensorMutRef<E: Dtype, D: DeviceStorage> { | ||
type Err; | ||
fn visit<S: Shape>( | ||
&mut self, | ||
full_path: String, | ||
opts: TensorOptions<S, E, D>, | ||
ts: (&mut Tensor<S, E, D>, &Tensor<S, E, D>), | ||
) -> Result<(), Self::Err>; | ||
type ContainerWithModule<'a, C, M> = <C as TensorContainer>::WithModule<'a, M>; | ||
|
||
pub trait TensorContainer: 'static { | ||
type WithModule<'a, Mod: 'a> | ||
where | ||
Self: 'a; | ||
|
||
fn get_field<'a, Mod, Field, GetRef, GetMut>( | ||
module: &'a mut Self::WithModule<'_, Mod>, | ||
get_ref: &mut GetRef, | ||
get_mut: &mut GetMut, | ||
) -> Self::WithModule<'a, Field> | ||
where | ||
GetRef: FnMut(&Mod) -> &Field, | ||
GetMut: FnMut(&mut Mod) -> &mut Field; | ||
} | ||
|
||
pub trait TensorCollection<E: Dtype, D: DeviceStorage>: Sized { | ||
|
@@ -124,131 +123,51 @@ pub(crate) struct RecursiveWalker<'a, M, F> { | |
pub(crate) path: &'a mut Vec<String>, | ||
} | ||
|
||
impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensorRef<E, D>> TensorVisitor<M, E, D> | ||
for RecursiveWalker<'a, &'a M, F> | ||
{ | ||
type Err = F::Err; | ||
fn visit_module<Field, GetRef, GetMut>( | ||
&mut self, | ||
mut get_refs: GetRef, | ||
_: GetMut, | ||
name: &str, | ||
) -> Result<(), Self::Err> | ||
where | ||
GetRef: FnMut(&M) -> &Field, | ||
GetMut: FnMut(&mut M) -> &mut Field, | ||
Field: TensorCollection<E, D>, | ||
{ | ||
self.path.push(name.into()); | ||
let mut walker = RecursiveWalker { | ||
m: get_refs(self.m), | ||
f: self.f, | ||
path: self.path, | ||
}; | ||
Field::iter_tensors(&mut walker)?; | ||
self.path.pop(); | ||
Ok(()) | ||
} | ||
fn visit_tensor<S: Shape, GetRef, GetMut>( | ||
&mut self, | ||
mut get_refs: GetRef, | ||
_: GetMut, | ||
name: &str, | ||
opts: TensorOptions<S, E, D>, | ||
) -> Result<(), F::Err> | ||
where | ||
GetRef: FnMut(&M) -> &Tensor<S, E, D>, | ||
GetMut: FnMut(&mut M) -> &mut Tensor<S, E, D>, | ||
{ | ||
self.path.push(name.into()); | ||
self.f.visit(self.path.join("."), opts, get_refs(self.m))?; | ||
self.path.pop(); | ||
Ok(()) | ||
} | ||
} | ||
|
||
impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensorMut<E, D>> TensorVisitor<M, E, D> | ||
for RecursiveWalker<'a, &'a mut M, F> | ||
impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensors<E, D>> TensorVisitor<M, E, D> | ||
for RecursiveWalker<'a, ContainerWithModule<'a, F::Container, M>, F> | ||
{ | ||
type Err = F::Err; | ||
fn visit_module<Field, GetRef, GetMut>( | ||
&mut self, | ||
_: GetRef, | ||
mut get_muts: GetMut, | ||
name: &str, | ||
) -> Result<(), F::Err> | ||
where | ||
GetRef: FnMut(&M) -> &Field, | ||
GetMut: FnMut(&mut M) -> &mut Field, | ||
Field: TensorCollection<E, D>, | ||
{ | ||
self.path.push(name.into()); | ||
let mut walker = RecursiveWalker { | ||
m: get_muts(self.m), | ||
f: self.f, | ||
path: self.path, | ||
}; | ||
Field::iter_tensors(&mut walker)?; | ||
self.path.pop(); | ||
Ok(()) | ||
} | ||
fn visit_tensor<S: Shape, GetRef, GetMut>( | ||
&mut self, | ||
_: GetRef, | ||
mut get_muts: GetMut, | ||
name: &str, | ||
opts: TensorOptions<S, E, D>, | ||
) -> Result<(), F::Err> | ||
where | ||
GetRef: FnMut(&M) -> &Tensor<S, E, D>, | ||
GetMut: FnMut(&mut M) -> &mut Tensor<S, E, D>, | ||
{ | ||
self.path.push(name.into()); | ||
self.f.visit(self.path.join("."), opts, get_muts(self.m))?; | ||
self.path.pop(); | ||
Ok(()) | ||
} | ||
} | ||
|
||
impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensorMutRef<E, D>> TensorVisitor<M, E, D> | ||
for RecursiveWalker<'a, (&'a mut M, &'a M), F> | ||
{ | ||
type Err = F::Err; | ||
fn visit_module<Field, GetRef, GetMut>( | ||
&mut self, | ||
mut get_refs: GetRef, | ||
mut get_muts: GetMut, | ||
name: &str, | ||
) -> Result<(), F::Err> | ||
) -> Result<(), Self::Err> | ||
where | ||
GetRef: FnMut(&M) -> &Field, | ||
GetMut: FnMut(&mut M) -> &mut Field, | ||
Field: TensorCollection<E, D>, | ||
{ | ||
self.path.push(name.into()); | ||
let mut walker = RecursiveWalker { | ||
m: (get_muts(self.m.0), get_refs(self.m.1)), | ||
m: F::Container::get_field(&mut self.m, &mut get_refs, &mut get_muts), | ||
f: self.f, | ||
path: self.path, | ||
}; | ||
Field::iter_tensors(&mut walker)?; | ||
std::mem::drop(walker); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this necessary? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, if you remove it, rust will consider |
||
self.path.pop(); | ||
Ok(()) | ||
} | ||
|
||
fn visit_tensor<S: Shape, GetRef, GetMut>( | ||
&mut self, | ||
mut get_refs: GetRef, | ||
mut get_muts: GetMut, | ||
name: &str, | ||
opts: TensorOptions<S, E, D>, | ||
) -> Result<(), F::Err> | ||
) -> Result<(), Self::Err> | ||
where | ||
GetRef: FnMut(&M) -> &Tensor<S, E, D>, | ||
GetMut: FnMut(&mut M) -> &mut Tensor<S, E, D>, | ||
{ | ||
self.path.push(name.into()); | ||
let tensors = (get_muts(self.m.0), get_refs(self.m.1)); | ||
self.f.visit(self.path.join("."), opts, tensors)?; | ||
self.f.visit( | ||
self.path.join("."), | ||
opts, | ||
F::Container::get_field(&mut self.m, &mut get_refs, &mut get_muts), | ||
)?; | ||
self.path.pop(); | ||
Ok(()) | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I could see this being used for optional fields (e.g. bias if we go the Option route), nice!