-
-
Notifications
You must be signed in to change notification settings - Fork 99
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding Cuda device and skeleton cuda kernel impls (#322)
* WIP commit for cuda device * Updating to latest cudarc, using take_async * Rework Cuda allocation to use cpu * Use TestDevice in tensor * Using SampleTensor * Clean up cuda device * Adding cuda kernel to all ops * Adding cuda kernels to optims and fixing nn TestDeviceUsage * Remove cuda from default features * Update cudarc version * Adding Unpin to Unit * Adding std feature for cudarc dependency * Updating cudarc to 0.5.0
- Loading branch information
1 parent
bb36005
commit cb2e687
Showing
86 changed files
with
1,565 additions
and
65 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
use crate::{shapes::Shape, tensor::Cuda}; | ||
|
||
impl super::AdamKernel<f32> for Cuda { | ||
fn update<S: Shape>( | ||
t: i32, | ||
cfg: &super::AdamConfig<f32>, | ||
param: &mut Self::Storage<S, f32>, | ||
moment1: &mut Self::Storage<S, f32>, | ||
moment2: &mut Self::Storage<S, f32>, | ||
grad: Self::Storage<S, f32>, | ||
) { | ||
todo!() | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,8 @@ | ||
mod cpu_kernel; | ||
|
||
#[cfg(feature = "cuda")] | ||
mod cuda_kernel; | ||
|
||
use std::marker::PhantomData; | ||
|
||
use crate::{ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
use crate::tensor::Cuda; | ||
|
||
impl super::RMSpropKernel<f32> for Cuda { | ||
fn update<S: crate::shapes::Shape>( | ||
cfg: &super::RMSpropConfig<f32>, | ||
param: &mut Self::Storage<S, f32>, | ||
momentum: &mut Self::Storage<S, f32>, | ||
square_avg: &mut Self::Storage<S, f32>, | ||
grad_avg: &mut Self::Storage<S, f32>, | ||
grad: Self::Storage<S, f32>, | ||
) { | ||
todo!() | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,8 @@ | ||
mod cpu_kernel; | ||
|
||
#[cfg(feature = "cuda")] | ||
mod cuda_kernel; | ||
|
||
use std::marker::PhantomData; | ||
|
||
use crate::{ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
use crate::{shapes::*, tensor::Cuda}; | ||
|
||
impl<E: Dtype> super::SgdKernel<E> for Cuda { | ||
fn update<S: Shape>( | ||
cfg: &super::SgdConfig<E>, | ||
param: &mut Self::Storage<S, E>, | ||
velocity: &mut Self::Storage<S, E>, | ||
grad: Self::Storage<S, E>, | ||
) { | ||
todo!() | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.