nn api hard to use with different devices #388

Closed
Tracked by #278
coreylowman opened this issue Jan 22, 2023 · 9 comments · Fixed by #405
@coreylowman (Owner)

Summary

Currently, a Device generic is added to every struct in nn that needs to store tensors, e.g. Linear<5, 10, Cpu> or Linear<5, 10, Cuda>.

As an example of the current API, a simple structure can be specified as:

type MLP = (Linear<5, 10>, ReLU, Linear<10, 1>);

However, this currently defaults to the Cpu device. In order to use this structure on a Cuda device, you'd have to go through and modify everything to set the new device generic:

type MLP<D: DeviceStorage> = (Linear<5, 10, D>, ReLU, Linear<10, 1, D>);

Additionally, when using dev.build_module(), you have to specify the device twice:

let dev: D = Default::default();
let m: MLP<D> = dev.build_module();

This is redundant, since the device should be identical to whatever device is building the module.

Further, if you happen to change the type of dev but don't update the type you specified on the MLP, you can get really hard-to-understand compile errors:

let dev: D1 = Default::default();
// really weird errors because D1 can't build modules for device type D2
let m: MLP<D2> = dev.build_module();

Ways forward

I see a few ways forward, each with its own advantages/disadvantages. There are probably more ways forward as well.

Add a 2nd layer of nn configuration that would be used to build modules on devices

This would look something like this:

struct Linear<const I: usize, const O: usize>;

struct DeviceLinear<const I: usize, const O: usize, D: DeviceStorage> {
    weight: Tensor<Rank2<I, O>, f32, D>,
    bias: Tensor<Rank1<O>, f32, D>,
}

type MLP = (Linear<5, 10>, ReLU, Linear<10, 1>);

fn main() {
    let cpu: Cpu = Default::default();
    let cuda: Cuda = Default::default();

    let mlp_on_cpu: (DeviceLinear<5, 10, Cpu>, ReLU, DeviceLinear<10, 1, Cpu>) = cpu.build_module::<MLP>();
    let mlp_on_gpu: (DeviceLinear<5, 10, Cuda>, ReLU, DeviceLinear<10, 1, Cuda>) = cuda.build_module::<MLP>();
}

Specifically, the existing structures would be repurposed as concrete types for already-built modules. Types like Linear<5, 10> or Conv2D<3, 2, 1> would essentially become configuration for device building.

The building device would tie these configuration structs together with the concrete built types:

trait BuildOn<D: DeviceStorage> {
    type Built;
}

impl<const I: usize, const O: usize, D: DeviceStorage> BuildOn<D> for Linear<I, O> {
    type Built = DeviceLinear<I, O, D>;
}

Pros

  1. This makes specifying structure simple; in fact it would be exactly the same as it is now
  2. You can re-use the same structure for different devices without changing the definition

Cons

  1. The constructed type is different from the type you use to specify the structure. So notably you have to specify the structure as a function generic instead of via the output type.
  2. Having both DeviceLinear and Linear would be confusing, and we'd need documentation on what the difference is and when to use it
  3. It is not obvious to users why this split exists

Add a .to(device) method to modules

This would mean you'd always construct modules on Cpu and then send them to whatever device you wanted:

type MLP = (Linear<5, 10>, ReLU, Linear<10, 1>);

fn main() {
    let cpu: Cpu = Default::default();
    let cuda: Cuda = Default::default();

    let mlp_on_cpu: MLP = cpu.build_module();
    let mlp_on_gpu: (Linear<5, 10, Cuda>, ReLU, Linear<10, 1, Cuda>) = mlp_on_cpu.to(cuda);
}
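
A minimal sketch of what the trait behind .to(device) could look like (the trait name, method, and signature here are illustrative assumptions, not existing API):

// Hypothetical sketch: a trait that lets a module copy/rebuild itself on
// another device. `DeviceStorage` is the existing dfdx trait; everything
// else in this snippet is an assumption for illustration.
trait ToDevice<D: DeviceStorage> {
    type Output;
    fn to(&self, device: &D) -> Self::Output;
}

Each module (and tuples of modules) would implement this by copying its tensors onto the target device.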

Pros

  1. Existing structure would remain the same
  2. .to(device) already exists in pytorch so people would be familiar with it
  3. Very little documentation would need to be added, and this approach is probably easy to understand
  4. Easy to implement with minimal changes

Cons

  1. This doesn't address the issues of having an additional device generic on the structures. For things like Conv2D or Transformer, you will have to specify all the parameters, even those with default values that you want to keep.
  2. This still has the same issue if you try to build with a different device than the one the type specifies.
@coreylowman (Owner)

Sketch of possible approach 1

trait LinearConfig {
    type Input: Dim;
    type Output: Dim;
}

struct Linear<const I: usize, const O: usize>;

impl<const I: usize, const O: usize> LinearConfig for Linear<I, O> {
    type Input = Const<I>;
    type Output = Const<O>;
}

struct DeviceLinear<Cfg: LinearConfig, D: Device> {
    weight: Tensor<(Cfg::Input, Cfg::Output), f32, D>,
    bias: Tensor<(Cfg::Output,), f32, D>,
}
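
This config could then be tied to its built type with something like the BuildOn idea from above; a rough sketch under the same assumptions (none of this is merged API):

// Illustrative sketch: associate the device-free config with its built type.
// Assumes BuildOn's device bound is the same Device trait used by
// DeviceLinear above.
impl<const I: usize, const O: usize, D: Device> BuildOn<D> for Linear<I, O> {
    type Built = DeviceLinear<Linear<I, O>, D>;
}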

@coreylowman (Owner)

Another potentially related issue is using modules with different dtypes (once that is supported).

For example, should (Linear<5, 10, f32>, Linear<5, 10, f64>) work?

This is another area where separating structure from device/dtype could be helpful:

let m32 = dev.build_module::<MLP, f32>();
let m16 = dev.build_module::<MLP, f16>();
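
A rough sketch of how a dtype generic could be threaded through the same config-to-built mapping (purely illustrative, since none of this exists yet):

// Illustrative only: the built type carries both dtype and device, while the
// config type (e.g. Linear<5, 10>) stays free of both. `Dtype` and `Device`
// are dfdx traits; DeviceLinear is the hypothetical struct from the sketch
// above, extended with a dtype parameter.
struct DeviceLinear<Cfg: LinearConfig, E: Dtype, D: Device<E>> {
    weight: Tensor<(Cfg::Input, Cfg::Output), E, D>,
    bias: Tensor<(Cfg::Output,), E, D>,
}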

@nkoppel (Contributor) commented Jan 23, 2023

Another option is to define a trait which generalizes the device parameter for modules. This could look like:

pub trait OnDeviceTrait<D> {
    type Output;
}

pub type OnDevice<M, D> = <M as OnDeviceTrait<D>>::Output;

With this, the MLP sequential module above could be defined as

type MLP<D> = OnDevice<(Linear<5, 10>, ReLU, Linear<10, 1>), D>;

This would even allow

type MLP = (Linear<5, 10>, ReLU, Linear<10, 1>);
type CudaMLP = OnDevice<MLP, Cuda>;

Here's a working implementation of this trait for Linear:

impl<const I: usize, const O: usize, D1: Device<f32>, D2: Device<f32>> OnDeviceTrait<D2> for Linear<I, O, D1> {
    type Output = Linear<I, O, D2>;
}
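
Tuples can be handled the same way. Roughly, for a 3-tuple (a hand-written sketch; a real implementation would likely cover all arities with a macro):

// Sketch: propagate OnDeviceTrait element-wise through a tuple, so that
// OnDevice<(A, B, C), D> = (OnDevice<A, D>, OnDevice<B, D>, OnDevice<C, D>).
impl<M1, M2, M3, D> OnDeviceTrait<D> for (M1, M2, M3)
where
    M1: OnDeviceTrait<D>,
    M2: OnDeviceTrait<D>,
    M3: OnDeviceTrait<D>,
{
    type Output = (M1::Output, M2::Output, M3::Output);
}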

'OnDeviceTrait' may not be the best name for the trait, but this solution should be fairly simple to implement for us and for end users.

@coreylowman (Owner)

Oh, interesting idea with the double device generics in the OnDeviceTrait impl; it's a big positive that it wouldn't require a separate struct definition. I don't see any issues with tuples either, so I'll try sketching that out!

@nkoppel (Contributor) commented Jan 24, 2023

I've implemented this for tuples already; I'll open a PR with my progress.

@nkoppel (Contributor) commented Jan 24, 2023

Additionally, when using dev.build_module(), you have to specify the device twice:

let dev: D = Default::default();
let m: MLP<D> = dev.build_module();

You can do

let dev: D = Default::default();
let m: MLP<_> = dev.build_module();

to not have to specify the device twice.

@coreylowman (Owner)

You can, but if you forget to specify it, or specify the wrong thing and forget to update it, you get really opaque errors (infinite type recursion at the moment).

I'm liking the idea of separating device specification from nn structure. Adding the OnDeviceTrait (or whatever it will be called) hopefully makes the error messages cleaner in these cases.

@coreylowman (Owner)

I'm also testing out using your approach with the two device generics in the ResetParams trait:

pub trait BuildModule<D: Device<E>, E: Dtype>: Sized {
    type Built: ResetParams<D, E>;

    /// Construct it on the device
    fn build(device: &D) -> Self::Built {
        Self::try_build(device).unwrap()
    }
    /// Fallible version of [BuildModule::build]
    fn try_build(device: &D) -> Result<Self::Built, D::Err>;
}

/// Something that can reset its parameters.
pub trait ResetParams<D: Device<E>, E: Dtype>: Sized {
    /// Mutates parameters. Each implementor
    /// of this trait decides how the parameters are initialized. In
    /// fact, some impls may not even use randomness.
    fn reset_params(&mut self) {
        self.try_reset_params().unwrap();
    }

    /// Fallible version of [ResetParams::reset_params].
    fn try_reset_params(&mut self) -> Result<(), D::Err>;
}

However, it seems like you can't use Linear<1, 2, _> with this approach, because Rust can't infer the type of the device for some reason.

@coreylowman (Owner) commented Jan 26, 2023

@nkoppel I'm currently thinking of the approach below. I'll take this issue over, since it'll involve a lot of documentation/test updates as well.

trait BuildModule<D: Device> {
    fn build_module(dev: &D) -> Self;
}

trait BuildOnDevice<D: Device> {
    type Built: BuildModule<D>;
    fn build_on_device(dev: &D) -> Self::Built {
        BuildModule::build_module(dev)
    }
}

where each module would implement both of these

impl<const I: usize, const O: usize, D: Device> BuildModule<D> for Linear<I, O, D> {
    fn build_module(dev: &D) -> Self {
        Self {
            weight: Default::default(),
            bias: Default::default(),
        }
    }
}
impl<const I: usize, const O: usize, Src: Device, Dst: Device> BuildOnDevice<Dst>
    for Linear<I, O, Src>
{
    type Built = Linear<I, O, Dst>;
}
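
Tuples would get element-wise impls as well; a rough hand-written sketch for 3-tuples (the real code would presumably use a macro over arities):

// Sketch only: a tuple of configs builds into a tuple of built modules on the
// target device.
impl<D: Device, M1: BuildModule<D>, M2: BuildModule<D>, M3: BuildModule<D>> BuildModule<D>
    for (M1, M2, M3)
{
    fn build_module(dev: &D) -> Self {
        (
            M1::build_module(dev),
            M2::build_module(dev),
            M3::build_module(dev),
        )
    }
}

impl<D: Device, M1: BuildOnDevice<D>, M2: BuildOnDevice<D>, M3: BuildOnDevice<D>> BuildOnDevice<D>
    for (M1, M2, M3)
{
    type Built = (M1::Built, M2::Built, M3::Built);
}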

Which gives us the ability to do:

fn main() {
    type Dev = Cuda;
    type Model = (Linear<3, 5>, ReLU, Linear<5, 3>);
    type DeviceModel<D> = (Linear<3, 5, D>, ReLU, Linear<5, 3, D>);

    let dev: Dev = Default::default();

    let q = Model::build_on_device(&dev);
    let q: DeviceModel<Dev> = Model::build_on_device(&dev);
    let q: DeviceModel<_> = Model::build_on_device(&dev);
    let q: (Linear<3, 5, _>, ReLU, Linear<5, 3, _>) = Model::build_on_device(&dev);
    let q: DeviceModel<Dev> = BuildModule::build_module(&dev);
}

I like this because people can use the trait approach with BuildModule if they want to create their NN structure with the device generic built in, OR they can go with the build_on_device call if they don't. The differences will be fairly straightforward to document, and I think the traits are named fairly clearly. It should be easy to implement as well (BuildOnDevice is basically your ToDevice, with the default method depending on BuildModule added in).

coreylowman self-assigned this Jan 26, 2023
coreylowman pushed a commit that referenced this issue Jan 26, 2023
* add OnDevice type alias and implement for Linear and tuples

* document OnDevice and OnDeviceTrait

* Implement OnDevice for activations, AddInto, and Tensor

* fix documentation

* Implement OnDevice for more nn modules

* Implement OnDevice for the rest of the nn modules

* run cargo fmt

* allow doctests to compile without cuda

* improve documentation

* rename OnDeviceTrait to ToDevice; add to_device method to ToDevice

* have to_device take &self

* run cargo fmt

* apply suggested changes

* allow dev.build to work like `dev.build::<OnDevice<X, _>>`

* revert previous changes

* update documentation

* documentation wording