Implementing abs/exp/div/sum_to cuda kernels #331

Merged
merged 18 commits on Jan 7, 2023
Conversation

coreylowman (Owner) commented Jan 6, 2023

Also adds some necessary scaffolding for building/implementing them.

Resolves #184

coreylowman added the gpu (Related to GPU support) label on Jan 6, 2023
auto tmp = inp[inp_strided_i];

unsigned int out_strided_i = get_strided_index(i, num_dims, dims, out_strides);
atomicAdd(out + out_strided_i, tmp);
coreylowman (Owner, Author) commented:
While this is okay to start with, performance will be really bad when reducing large tensors down to 1 element. Need to add an issue to create a special kernel for that case (there are standard ways to write high-performance versions of those).
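
For reference, the standard trick for that case is a block-level tree reduction in shared memory, so each block issues a single atomicAdd instead of one per element. A minimal sketch, assuming a fixed 256-thread block; the kernel name and signature are illustrative and not part of this PR:

extern "C" __global__ void sum_all_f32(const float *inp, float *out, const unsigned int numel) {
    // Assumes blockDim.x == 256 (a power of two) so the halving loop below is exact.
    __shared__ float buf[256];
    unsigned int tid = threadIdx.x;
    unsigned int i = blockIdx.x * blockDim.x + tid;

    // Each thread loads one element, zero-padding past the end. A grid-stride loop
    // would let a fixed-size grid cover arbitrarily large inputs.
    buf[tid] = (i < numel) ? inp[i] : 0.0f;
    __syncthreads();

    // Tree reduction in shared memory: halve the number of active threads each step.
    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            buf[tid] += buf[tid + s];
        }
        __syncthreads();
    }

    // One atomicAdd per block into the single output element.
    if (tid == 0) {
        atomicAdd(out, buf[0]);
    }
}

Compared to the per-element atomicAdd above, contention on the output drops by a factor of the block size.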

Comment on lines +34 to +35
let out_strides: Src::Concrete =
    BroadcastStridesTo::<Src, Ax>::broadcast_strides(&dst, dst.strides());
coreylowman (Owner, Author) commented:
The only difference between the normal forward impl and this one is this call, where the strides are broadcast. Maybe there is some way around that?

let mut storage = self.dev.alloc_zeros_async::<f32>(numel)?;

let fwd_fn = self.dev.get_func(K::MODULE_NAME, K::FWD_FN_NAME).unwrap();
let cfg = LaunchConfig::for_num_elems(numel as u32);
coreylowman (Owner, Author) commented:
Should probably make a helper method for computing a good version of this launch config - it needs to take advantage of both threads & blocks.
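
For what it's worth, the arithmetic such a helper usually encodes is just a ceil-division of the element count by a chosen threads-per-block value. Sketched as host-side CUDA C for illustration; the helper name and the 256-thread default are assumptions, and the real helper would presumably live on the Rust side next to LaunchConfig::for_num_elems:

// Illustrative only: derive a 1-D grid that covers `numel` elements with a
// fixed number of threads per block, rounding the block count up.
static void launch_config_for(unsigned int numel, dim3 *grid_dim, dim3 *block_dim) {
    const unsigned int threads = 256;                             // assumed threads per block
    const unsigned int blocks = (numel + threads - 1) / threads;  // ceil(numel / threads)
    *block_dim = dim3(threads, 1, 1);
    *grid_dim = dim3(blocks, 1, 1);
}

Kernels launched this way still need an if (i < numel) guard, since the last block can overshoot the element count.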

#[cfg(feature = "cuda")]
mod cuda {
    pub fn build_ptx() {
        // TODO build ptx file in source tree and don't call nvcc if so
coreylowman (Owner, Author) commented:
I think this can be done later; I'm not even sure it's necessary. Once we have all the kernels in place we can see, but there's no need to complicate something that's pretty simple at the moment.

Comment on lines +150 to +153
let dims: CudaSlice<usize> = self.dev.take_async(lhs.shape.concrete().into())?;
let lhs_strides: CudaSlice<usize> = self.dev.take_async(lhs.strides.into())?;
let rhs_strides: CudaSlice<usize> = self.dev.take_async(rhs.strides.into())?;
let out_strides: CudaSlice<usize> = self.dev.take_async(grad_out.strides.into())?;
coreylowman (Owner, Author) commented:
These same values were also allocated in the forward call - a potential future improvement is pre-allocating them. Though these are only used in binary ops - if a tensor is only ever used in unary ops, it doesn't need to allocate these at all.

Successfully merging this pull request may close these issues.

CUDA kernels JIT vs compile time compilation