-
-
Notifications
You must be signed in to change notification settings - Fork 99
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adds TensorCollection #469
Conversation
@nkoppel thoughts? was able to impl optimizers with this approach as well |
One thing I don't like about the current implementation is the repeated code in pub trait VisitTensors<E: Dtype, D: DeviceStorage> {
type Argument<'a, M>: VisitTensorArgument<M> where M: 'a;
type Err;
fn visit<S: Shape>(
&mut self,
full_path: String,
opts: TensorOptions<S, E, D>,
t: Self::Argument<'_, Tensor<S, E, D>>,
) -> Result<(), Self::Err>;
}
pub trait VisitTensorArgument<M> {
type WithModule<'a, Mod>: VisitTensorArgument<Mod> where Self: 'a;
fn get_field<Field, GetRef, GetMut>(
&mut self,
get_ref: GetRef,
get_mut: GetMut,
) -> Self::WithModule<'_, Field>
where
GetRef: FnMut(&M) -> &Field,
GetMut: FnMut(&mut M) -> &mut Field,
Field: TensorCollection<E, D>,
} This would remove the need to have three different |
I totally agree, and nice sketch. Want to open a PR into this branch with that implemented? I think as long as we can keep that isolated to the implementation for RecursiveWalker (and not infect the user facing TensorCollection/TensorVisitor) it sounds like a great add |
@nkoppel I'm pretty happy with where this is, any other organization/naming stuff you think we should do before merge? |
Co-authored-by: nkoppel <nathankoppel0@gmail.com>
Resolves #435
Related to #460
TODOs