Skip to content

Commit

Permalink
Adds Tensor::to_device method (#741)
Browse files Browse the repository at this point in the history
  • Loading branch information
coreylowman authored Apr 26, 2023
1 parent 4fac0f1 commit e085c11
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
2 changes: 1 addition & 1 deletion examples/02-ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ fn main() {
// these operations are equal across devices
#[cfg(feature = "cuda")]
{
use dfdx::{nn::ToDevice, tensor::Cpu};
use dfdx::tensor::Cpu;

let cpu = Cpu::default();

Expand Down
16 changes: 16 additions & 0 deletions src/tensor/storage_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,22 @@ pub trait TensorFromVec<E: Unit>: DeviceStorage {
) -> Result<Tensor<S, E, Self>, Self::Err>;
}

impl<S: Shape, E: Unit, D: DeviceStorage, T> Tensor<S, E, D, T> {
/// Clones the tensor onto a different device.
pub fn to_device<Dst: TensorFromVec<E>>(&self, device: &Dst) -> Tensor<S, E, Dst> {
self.try_to_device(device).unwrap()
}

/// Fallibly clones the tensor onto a different device.
pub fn try_to_device<Dst: TensorFromVec<E>>(
&self,
device: &Dst,
) -> Result<Tensor<S, E, Dst>, Dst::Err> {
let buf = self.as_vec();
device.try_tensor_from_vec(buf, self.shape)
}
}

/// Construct tensors from rust data
pub trait TensorFrom<Src, S: Shape, E: Unit>: DeviceStorage {
/// Create a tensor from rust data
Expand Down

0 comments on commit e085c11

Please sign in to comment.