nn api hard to use with different devices #388
Sketch of possible approach 1:

```rust
trait LinearConfig {
    type Input: Dim;
    type Output: Dim;
}

struct Linear<const I: usize, const O: usize>;

impl<const I: usize, const O: usize> LinearConfig for Linear<I, O> {
    type Input = Const<I>;
    type Output = Const<O>;
}

struct DeviceLinear<Cfg: LinearConfig, D: Device> {
    weight: Tensor<(Cfg::Input, Cfg::Output), f32, D>,
    bias: Tensor<(Cfg::Output,), f32, D>,
}
```
Another potentially related issue is using modules with different dtypes (once that is supported). For example, should this work? This is another area where separating structure from device/dtype could be helpful:

```rust
let m32 = dev.build_module::<MLP, f32>();
let m16 = dev.build_module::<MLP, f16>();
```
Another option is to define a trait which generalizes the device parameter for modules. This could look like:

```rust
pub trait OnDeviceTrait<D> {
    type Output;
}

pub type OnDevice<M, D> = <M as OnDeviceTrait<D>>::Output;
```

With this, the MLP sequential module above could be defined as

```rust
type MLP<D> = OnDevice<(Linear<5, 10>, ReLU, Linear<10, 1>), D>;
```

This would even allow

```rust
type MLP = (Linear<5, 10>, ReLU, Linear<10, 1>);
type CudaMLP = OnDevice<MLP, Cuda>;
```

Here's a working implementation of this trait for Linear:

```rust
impl<const I: usize, const O: usize, D1: Device<f32>, D2: Device<f32>> OnDeviceTrait<D2>
    for Linear<I, O, D1>
{
    type Output = Linear<I, O, D2>;
}
```

`OnDeviceTrait` may not be the best name for the trait, but this solution should be fairly simple to implement for us and for end users.
Oh, interesting idea with the double device generics in the OnDeviceTrait impl; the big positive is that it wouldn't require a separate struct definition. I don't see any issues with tuples either, I'll try sketching that out!
I've implemented this for tuples already, I'll open a PR with my progress.
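For context, a minimal sketch of how an `OnDeviceTrait` impl for tuples might look (this is an assumption for illustration, not the actual PR code):

```rust
// Hypothetical: map each element of the tuple to its on-device counterpart.
impl<D, A: OnDeviceTrait<D>, B: OnDeviceTrait<D>> OnDeviceTrait<D> for (A, B) {
    type Output = (A::Output, B::Output);
}

impl<D, A: OnDeviceTrait<D>, B: OnDeviceTrait<D>, C: OnDeviceTrait<D>> OnDeviceTrait<D> for (A, B, C) {
    type Output = (A::Output, B::Output, C::Output);
}
```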
You can do

```rust
let dev: D = Default::default();
let m: MLP<_> = dev.build_module();
```

to not have to specify the device twice.
You can, but if you forget to specify it, or specify the wrong thing and forget to update it, you get really opaque errors (infinite type recursion at the moment). I'm liking the idea of separating device specification from nn structure. Adding the OnDeviceTrait (or whatever it ends up being called) hopefully makes the error messages cleaner in these cases.
I'm also testing out using your approach with the two device generics in the ResetParams trait:

```rust
pub trait BuildModule<D: Device<E>, E: Dtype>: Sized {
    type Built: ResetParams<D, E>;

    /// Construct it on the device
    fn build(device: &D) -> Self::Built {
        Self::try_build(device).unwrap()
    }

    /// Fallible version of [BuildModule::build]
    fn try_build(device: &D) -> Result<Self::Built, D::Err>;
}

/// Something that can reset its parameters.
pub trait ResetParams<D: Device<E>, E: Dtype>: Sized {
    /// Mutates parameters. Each implementor
    /// of this trait decides how the parameters are initialized. In
    /// fact, some impls may not even use randomness.
    fn reset_params(&mut self) {
        self.try_reset_params().unwrap();
    }

    /// Fallible version of [ResetParams::reset_params].
    fn try_reset_params(&mut self) -> Result<(), D::Err>;
}
```

However, it seems like you can't use
@nkoppel I'm currently thinking of the approach below. I'll take this issue over since it'll involve a lot of documentation/test updates as well.

```rust
trait BuildModule<D: Device> {
    fn build_module(dev: &D) -> Self;
}

trait BuildOnDevice<D: Device> {
    type Built: BuildModule<D>;
    fn build_on_device(dev: &D) -> Self::Built {
        BuildModule::build_module(dev)
    }
}
```

where each module would implement both of these:

```rust
impl<const I: usize, const O: usize, D: Device> BuildModule<D> for Linear<I, O, D> {
    fn build_module(dev: &D) -> Self {
        Self {
            weight: Default::default(),
            bias: Default::default(),
        }
    }
}

impl<const I: usize, const O: usize, Src: Device, Dst: Device> BuildOnDevice<Dst>
    for Linear<I, O, Src>
{
    type Built = Linear<I, O, Dst>;
}
```

Which gives us the ability to do:

```rust
fn main() {
    type Dev = Cuda;
    type Model = (Linear<3, 5>, ReLU, Linear<5, 3>);
    type DeviceModel<D> = (Linear<3, 5, D>, ReLU, Linear<5, 3, D>);

    let dev: Dev = Default::default();

    let q = Model::build_on_device(&dev);
    let q: DeviceModel<Dev> = Model::build_on_device(&dev);
    let q: DeviceModel<_> = Model::build_on_device(&dev);
    let q: (Linear<3, 5, _>, ReLU, Linear<5, 3, _>) = Model::build_on_device(&dev);
    let q: DeviceModel<Dev> = BuildModule::build_module(&dev);
}
```

I like this because people can use the trait approach with BuildModule if they want to create their NN structure with the device generic built in, OR they can go with the build_on_device call if they don't. The differences should be fairly straightforward to document, and I think the two are named fairly clearly. Should be easy to implement as well (BuildOnDevice is basically your ToDevice, with the default method depending on BuildModule added in).
* add OnDevice type alias and implement for Linear and tuples
* document OnDevice and OnDeviceTrait
* Implement OnDevice for activations, AddInto, and Tensor
* fix documentation
* Implement OnDevice for more nn modules
* Implement OnDevice for the rest of the nn modules
* run cargo fmt
* allow doctests to compile without cuda
* improve documentation
* rename OnDeviceTrait to ToDevice; add to_device method to ToDevice
* have to_device take &self
* run cargo fmt
* apply suggested changes
* allow dev.build to work like `dev.build::<OnDevice<X, _>>`
* revert previous changes
* update documentation
* documentation wording
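Based only on the commit messages above (rename `OnDeviceTrait` to `ToDevice`, add a `to_device` method taking `&self`), the trait presumably ended up looking roughly like this; a hedged reconstruction, not verified against the merged PR:

```rust
// Hedged reconstruction of the renamed trait from the commit list above.
pub trait ToDevice<D> {
    type Output;
    fn to_device(&self, device: &D) -> Self::Output;
}
```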
## Summary

Currently, a `Device` generic was added to all structs in `nn` that require storing tensors, e.g. `Linear<5, 10, Cpu>` or `Linear<5, 10, Cuda>`.

For an example of the current api for specifying a simple structure, you can specify a model like the sketch below:
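The original code block was lost here; a hedged reconstruction, based on the MLP type used in the comments above:

```rust
// Device generic omitted, so this defaults to Cpu.
type MLP = (Linear<5, 10>, ReLU, Linear<10, 1>);
```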
However, this currently defaults to the `Cpu` device. In order to use this structure on a Cuda device, you'd have to go through and modify everything to set the new device generic, as in the sketch below:
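Another hedged reconstruction of the missing example, reusing the same MLP:

```rust
// Every layer now has to name the device explicitly.
type MLP = (Linear<5, 10, Cuda>, ReLU, Linear<10, 1, Cuda>);
```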
Additionally, when using `dev.build_module()`, you have to specify the device twice (see the sketch below), even though the device should be identical to whatever device is building the module.
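A hedged sketch of the "specify the device twice" problem, assuming the device-generic MLP above:

```rust
type MLP<D> = (Linear<5, 10, D>, ReLU, Linear<10, 1, D>);

// The device shows up once in `dev`'s type and again in the module type.
let dev: Cuda = Default::default();
let m: MLP<Cuda> = dev.build_module();
```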
Further, if you happen to change the type of `dev`, you can get really hard-to-understand compile errors if you don't update the type you specified on the MLP, as in the sketch below:
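A hedged guess at the missing example: the device type and the module's device generic no longer match.

```rust
// `dev` was changed to Cpu, but MLP still names Cuda,
// which currently produces opaque (infinite type recursion) errors.
let dev: Cpu = Default::default();
let m: MLP<Cuda> = dev.build_module();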
## Ways forward

I see a few ways forward, each with its own advantages/disadvantages. There are probably more ways forward as well.
### Add a 2nd layer of nn configuration that would be used to build modules on devices

This would look something like the sketch below:
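The example that followed here didn't survive extraction; it presumably resembled the `LinearConfig`/`DeviceLinear` sketch from the first comment, roughly:

```rust
// Hypothetical reconstruction: Linear<I, O> becomes pure configuration,
// while DeviceLinear holds the actual tensors for a given device.
struct Linear<const I: usize, const O: usize>;

struct DeviceLinear<const I: usize, const O: usize, D: Device> {
    weight: Tensor<(Const<I>, Const<O>), f32, D>,
    bias: Tensor<(Const<O>,), f32, D>,
}
```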
Specifically, the existing structures would be repurposed as the concrete types for already-built modules. The types `Linear<5, 10>` or `Conv2D<3, 2, 1>` would turn into something closer to configuration for device building. Building on a device would tie these configuration structs together with the concrete built types, as in the sketch below:
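A minimal sketch of how a device might tie the config type to the built type; the trait and method names here are hypothetical, chosen only for illustration:

```rust
// Hypothetical: the device maps a configuration type to its built, tensor-holding type.
trait BuildConfig<Cfg> {
    type Built;
    fn build_config(&self) -> Self::Built;
}

impl<const I: usize, const O: usize, D: Device> BuildConfig<Linear<I, O>> for D {
    type Built = DeviceLinear<I, O, D>;
    fn build_config(&self) -> Self::Built {
        // allocate weight/bias on `self` (elided in this sketch)
        todo!()
    }
}
```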
#### Pros

#### Cons
### Add a `.to(device)` method to modules

This would mean you'd always construct modules on `Cpu` and then send them to whatever device you wanted, as in the sketch below:
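A minimal sketch of this alternative; the `.to` method is the proposal itself and does not exist in the current API:

```rust
// Hypothetical: always build on Cpu, then move the module to the target device.
let cpu: Cpu = Default::default();
let cuda: Cuda = Default::default();

let m: (Linear<5, 10, Cpu>, ReLU, Linear<10, 1, Cpu>) = cpu.build_module();
let m = m.to(&cuda); // now (Linear<5, 10, Cuda>, ReLU, Linear<10, 1, Cuda>)
```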
#### Pros

* `.to(device)` already exists in pytorch, so people would be familiar with it

#### Cons