Skip to content

Commit

Permalink
fix bug in reshape on cuda; fix 02-ops on cuda (#622)
Browse files Browse the repository at this point in the history
  • Loading branch information
nkoppel authored Mar 27, 2023
1 parent 67f4884 commit cf6e11b
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 2 deletions.
2 changes: 1 addition & 1 deletion examples/02-ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ fn main() {
// these operations are equal across devices
#[cfg(feature = "cuda")]
{
use dfdx::nn::ToDevice;
use dfdx::{nn::ToDevice, tensor::Cpu};

let cpu = Cpu::default();

Expand Down
2 changes: 1 addition & 1 deletion src/tensor_ops/reshape_to/cuda_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ where
self.dev.load_ptx(PTX_SRC.into(), Self::MOD, Self::FNS)?;
}

let numel = inp.data.len();
let numel = inp.shape.num_elements();
let mut storage = unsafe { self.dev.alloc::<E>(numel) }?;

let inp_dims = self.dev.htod_copy(inp.shape.concrete().into())?;
Expand Down
14 changes: 14 additions & 0 deletions src/tensor_ops/reshape_to/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,4 +187,18 @@ mod tests {
],
)
}

#[test]
fn test_reshape_broadcasted() {
let dev: TestDevice = Default::default();
let a: Tensor<Rank2<2, 3>, TestDtype, _> = dev.tensor([1., 2., 3.]).broadcast();
let b: Tensor<Rank2<3, 2>, TestDtype, _> = a.clone().reshape();

#[cfg(feature = "cuda")]
use cudarc::driver::DeviceSlice;

assert_eq!(b.data.len(), 6);
assert_eq!(a.as_vec(), b.as_vec());
assert_eq!(b.array(), [[1., 2.], [3., 1.], [2., 3.]]);
}
}

0 comments on commit cf6e11b

Please sign in to comment.