From 2c7117c82bfd20e6457024b9eec03f1c3b5e9ce6 Mon Sep 17 00:00:00 2001
From: Nick Hynes
Date: Mon, 20 Aug 2018 00:01:58 +0000
Subject: [PATCH] Address code review comments

---
 rust/src/runtime/array.rs       | 104 +++++++++++++++++++-------
 rust/src/runtime/module.rs      |   2 +-
 rust/src/runtime/packed_func.rs |   9 +++
 3 files changed, 72 insertions(+), 43 deletions(-)

diff --git a/rust/src/runtime/array.rs b/rust/src/runtime/array.rs
index 722687eebbfcd..9f1cc5d7e4843 100644
--- a/rust/src/runtime/array.rs
+++ b/rust/src/runtime/array.rs
@@ -219,23 +219,54 @@ impl<'a> Tensor<'a> {
         };
         unsafe { mem::transmute::<Tensor<'a>, Tensor<'static>>(t) }
     }
-}
 
-impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<f32> {
-    type Error = Error;
-    fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<f32>> {
-        ensure!(
-            tensor.dtype == DTYPE_FLOAT32,
-            "Cannot convert Tensor with dtype {:?} to ndarray",
-            tensor.dtype
-        );
-        Ok(ndarray::Array::from_shape_vec(
-            tensor.shape.iter().map(|s| *s as usize).collect::<Vec<usize>>(),
-            tensor.to_vec::<f32>(),
-        )?)
+    fn from_array_storage<'s, T, D: ndarray::Dimension>(
+        arr: &ndarray::Array<T, D>,
+        storage: Storage<'s>,
+        type_code: usize,
+    ) -> Tensor<'s> {
+        let type_width = mem::size_of::<T>() as usize;
+        Tensor {
+            data: storage,
+            ctx: TVMContext::default(),
+            dtype: DataType {
+                code: type_code,
+                bits: 8 * type_width,
+                lanes: 1,
+            },
+            size: arr.len(),
+            shape: arr.shape().iter().map(|&v| v as i64).collect(),
+            strides: Some(arr.strides().into_iter().map(|&v| v as usize).collect()),
+            byte_offset: 0,
+        }
     }
 }
 
+/// Conversions to `ndarray::Array` from `Tensor`, if the types match.
+macro_rules! impl_ndarray_try_from_tensor {
+    ($type:ty, $dtype:expr) => {
+        impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<$type> {
+            type Error = Error;
+            fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<$type>> {
+                ensure!(
+                    tensor.dtype == $dtype,
+                    "Cannot convert Tensor with dtype {:?} to ndarray",
+                    tensor.dtype
+                );
+                Ok(ndarray::Array::from_shape_vec(
+                    tensor.shape.iter().map(|s| *s as usize).collect::<Vec<usize>>(),
+                    tensor.to_vec::<$type>(),
+                )?)
+            }
+        }
+    };
+}
+
+impl_ndarray_try_from_tensor!(i32, DTYPE_INT32);
+impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32);
+impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32);
+impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64);
+
 impl DLTensor {
     pub(super) fn from_tensor<'a>(tensor: &'a Tensor, flatten: bool) -> Self {
         assert!(!flatten || tensor.is_contiguous());
@@ -299,12 +330,6 @@ impl DataType {
     }
 }
 
-const DTYPE_FLOAT32: DataType = DataType {
-    code: DLDataTypeCode_kDLFloat as usize,
-    bits: 32,
-    lanes: 1,
-};
-
 impl<'a> From<&'a DataType> for DLDataType {
     fn from(dtype: &'a DataType) -> Self {
         Self {
@@ -315,6 +340,22 @@ impl<'a> From<&'a DataType> for DLDataType {
     }
 }
 
+macro_rules! make_dtype_const {
+    ($name: ident, $code: ident, $bits: expr, $lanes: expr) => {
+        const $name: DataType = DataType {
+            code: $code as usize,
+            bits: $bits,
+            lanes: $lanes,
+        };
+    }
+}
+
+make_dtype_const!(DTYPE_INT32, DLDataTypeCode_kDLInt, 32, 1);
+make_dtype_const!(DTYPE_UINT32, DLDataTypeCode_kDLUInt, 32, 1);
+make_dtype_const!(DTYPE_FLOAT16, DLDataTypeCode_kDLFloat, 16, 1);
+make_dtype_const!(DTYPE_FLOAT32, DLDataTypeCode_kDLFloat, 32, 1);
+make_dtype_const!(DTYPE_FLOAT64, DLDataTypeCode_kDLFloat, 64, 1);
+
 impl Default for DLContext {
     fn default() -> Self {
         DLContext {
@@ -348,27 +389,6 @@ impl Default for TVMContext {
     }
 }
 
-fn tensor_from_array_storage<'a, 's, T, D: ndarray::Dimension>(
-    arr: &ndarray::Array<T, D>,
-    storage: Storage<'s>,
-    type_code: usize,
-) -> Tensor<'s> {
-    let type_width = mem::size_of::<T>() as usize;
-    Tensor {
-        data: storage,
-        ctx: TVMContext::default(),
-        dtype: DataType {
-            code: type_code,
-            bits: 8 * type_width,
-            lanes: 1,
-        },
-        size: arr.len(),
-        shape: arr.shape().iter().map(|&v| v as i64).collect(),
-        strides: Some(arr.strides().into_iter().map(|&v| v as usize).collect()),
-        byte_offset: 0,
-    }
-}
-
 /// `From` conversions to `Tensor` for owned or borrowed `ndarray::Array`.
 ///
 /// # Panics
@@ -382,13 +402,13 @@ macro_rules! impl_tensor_from_ndarray {
                 let size = arr.len() * mem::size_of::<$type>() as usize;
                 let storage =
                     Storage::from(unsafe { slice::from_raw_parts(arr.as_ptr() as *const u8, size) });
-                tensor_from_array_storage(&arr, storage, $typecode as usize)
+                Tensor::from_array_storage(&arr, storage, $typecode as usize)
             }
         }
         impl<'a, D: ndarray::Dimension> From<&'a ndarray::Array<$type, D>> for Tensor<'a> {
            fn from(arr: &'a ndarray::Array<$type, D>) -> Self {
                assert!(arr.is_standard_layout(), "Array must be contiguous.");
-                tensor_from_array_storage(
+                Tensor::from_array_storage(
                     arr,
                     Storage::from(arr.as_slice().unwrap()),
                     $typecode as usize,
diff --git a/rust/src/runtime/module.rs b/rust/src/runtime/module.rs
index e23c21fcd6116..2594756d98854 100644
--- a/rust/src/runtime/module.rs
+++ b/rust/src/runtime/module.rs
@@ -9,7 +9,7 @@ pub trait Module {
     fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc>;
 }
 
-pub struct SystemLibModule {}
+pub struct SystemLibModule;
 
 lazy_static! {
     static ref SYSTEM_LIB_FUNCTIONS: Mutex<HashMap<String, PackedFunc>> =
diff --git a/rust/src/runtime/packed_func.rs b/rust/src/runtime/packed_func.rs
index 2586f67643f13..ff5fd667af126 100644
--- a/rust/src/runtime/packed_func.rs
+++ b/rust/src/runtime/packed_func.rs
@@ -200,9 +200,18 @@ macro_rules! impl_boxed_ret_value {
     };
 }
 
+impl_prim_ret_value!(i8, 0);
+impl_prim_ret_value!(u8, 1);
+impl_prim_ret_value!(i16, 0);
+impl_prim_ret_value!(u16, 1);
 impl_prim_ret_value!(i32, 0);
 impl_prim_ret_value!(u32, 1);
 impl_prim_ret_value!(f32, 2);
+impl_prim_ret_value!(i64, 0);
+impl_prim_ret_value!(u64, 1);
+impl_prim_ret_value!(f64, 2);
+impl_prim_ret_value!(isize, 0);
+impl_prim_ret_value!(usize, 1);
 impl_boxed_ret_value!(String, 11);
 
 // @see `WrapPackedFunc` in `llvm_module.cc`.