Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds TensorCollection #469

Merged
merged 21 commits into from
Feb 22, 2023
Merged

Adds TensorCollection #469

merged 21 commits into from
Feb 22, 2023

Conversation

coreylowman
Copy link
Owner

@coreylowman coreylowman commented Feb 21, 2023

Resolves #435
Related to #460

  • Removes trait GradientUpdate
  • Removes trait ParamUpdater
  • optimizers now implement VisitTensors and use RecursiveWalker in their Optimizer impl
  • Moves trait ResetParams to use TensorCollection, and removes all nn impls
  • Adds NumParams using TensorCollection
  • Refactor SaveToNpz/LoadFromNpz to use TensorCollection
  • Removes all missing gradients tests in nn
    • Now that all nn layers use TensorCollection, testing npz/update/reset_params is enough to cover module walking
    • Unused tensor tests are now added to optimizers to ensure they properly capture unused
  • Adds traits VisitsTensors
  • Adds trait ModuleVisitor
  • Adds trait TensorCollection

TODOs

  • add documentation
  • improve use statements
  • Move visitors somewhere else other than tensor? top level?

@coreylowman
Copy link
Owner Author

@nkoppel thoughts? was able to impl optimizers with this approach as well

src/tensor/mod.rs Outdated Show resolved Hide resolved
@nkoppel
Copy link
Contributor

nkoppel commented Feb 21, 2023

One thing I don't like about the current implementation is the repeated code in visitors/base.rs. A lot of this repetition could be removed by generalizing the behavior of accessing fields in &T, &mut T, and (&mut T, &T) with something like the following:

pub trait VisitTensors<E: Dtype, D: DeviceStorage> {
    type Argument<'a, M>: VisitTensorArgument<M> where M: 'a;
    type Err;

    fn visit<S: Shape>(
        &mut self,
        full_path: String,
        opts: TensorOptions<S, E, D>,
        t: Self::Argument<'_, Tensor<S, E, D>>,
    ) -> Result<(), Self::Err>;
}

pub trait VisitTensorArgument<M> {
    type WithModule<'a, Mod>: VisitTensorArgument<Mod> where Self: 'a;

    fn get_field<Field, GetRef, GetMut>(
        &mut self,
        get_ref: GetRef,
        get_mut: GetMut,
    ) -> Self::WithModule<'_, Field>
     where
        GetRef: FnMut(&M) -> &Field,
        GetMut: FnMut(&mut M) -> &mut Field,
        Field: TensorCollection<E, D>,
}

This would remove the need to have three different VisitTensor traits, but would require that we add extra constraints to the RecursiveWalker struct to ensure that it has F::Argument<M> as it's m field. Thoughts?

@coreylowman
Copy link
Owner Author

I totally agree, and nice sketch. Want to open a PR into this branch with that implemented? I think as long as we can keep that isolated to the implementation for RecursiveWalker (and not infect the user facing TensorCollection/TensorVisitor) it sounds like a great add

src/nn/mod.rs Outdated Show resolved Hide resolved
src/nn/visitors.rs Outdated Show resolved Hide resolved
@coreylowman
Copy link
Owner Author

@nkoppel I'm pretty happy with where this is, any other organization/naming stuff you think we should do before merge?

@nkoppel
Copy link
Contributor

nkoppel commented Feb 22, 2023

@usbalbin merging this will break parts of #437. You'll need to remove the GradientUpdate implementation and replace it with a TensorCollection implementation.

@coreylowman coreylowman merged commit db408c8 into main Feb 22, 2023
@coreylowman coreylowman deleted the tensor-collection branch February 22, 2023 17:12
usbalbin added a commit to usbalbin/dfdx that referenced this pull request Feb 22, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Parameter Count
2 participants