Address code review comments
nhynes committed Aug 20, 2018
1 parent 9387ddc commit 2c7117c
Showing 3 changed files with 72 additions and 43 deletions.
104 changes: 62 additions & 42 deletions rust/src/runtime/array.rs
@@ -219,23 +219,54 @@ impl<'a> Tensor<'a> {
};
unsafe { mem::transmute::<Tensor<'a>, Tensor<'static>>(t) }
}
}

impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<f32> {
type Error = Error;
fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<f32>> {
ensure!(
tensor.dtype == DTYPE_FLOAT32,
"Cannot convert Tensor with dtype {:?} to ndarray",
tensor.dtype
);
Ok(ndarray::Array::from_shape_vec(
tensor.shape.iter().map(|s| *s as usize).collect::<Vec<usize>>(),
tensor.to_vec::<f32>(),
)?)
fn from_array_storage<'s, T, D: ndarray::Dimension>(
arr: &ndarray::Array<T, D>,
storage: Storage<'s>,
type_code: usize,
) -> Tensor<'s> {
let type_width = mem::size_of::<T>() as usize;
Tensor {
data: storage,
ctx: TVMContext::default(),
dtype: DataType {
code: type_code,
bits: 8 * type_width,
lanes: 1,
},
size: arr.len(),
shape: arr.shape().iter().map(|&v| v as i64).collect(),
strides: Some(arr.strides().into_iter().map(|&v| v as usize).collect()),
byte_offset: 0,
}
}
}

/// Conversions to `ndarray::Array` from `Tensor`, if the types match.
macro_rules! impl_ndarray_try_from_tensor {
($type:ty, $dtype:expr) => {
impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<$type> {
type Error = Error;
fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<$type>> {
ensure!(
tensor.dtype == $dtype,
"Cannot convert Tensor with dtype {:?} to ndarray",
tensor.dtype
);
Ok(ndarray::Array::from_shape_vec(
tensor.shape.iter().map(|s| *s as usize).collect::<Vec<usize>>(),
tensor.to_vec::<$type>(),
)?)
}
}
};
}

impl_ndarray_try_from_tensor!(i32, DTYPE_INT32);
impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32);
impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32);
impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64);

impl DLTensor {
pub(super) fn from_tensor<'a>(tensor: &'a Tensor, flatten: bool) -> Self {
assert!(!flatten || tensor.is_contiguous());
@@ -299,12 +330,6 @@ impl DataType {
}
}

const DTYPE_FLOAT32: DataType = DataType {
code: DLDataTypeCode_kDLFloat as usize,
bits: 32,
lanes: 1,
};

impl<'a> From<&'a DataType> for DLDataType {
fn from(dtype: &'a DataType) -> Self {
Self {
@@ -315,6 +340,22 @@ impl<'a> From<&'a DataType> for DLDataType {
}
}

macro_rules! make_dtype_const {
($name: ident, $code: ident, $bits: expr, $lanes: expr) => {
const $name: DataType = DataType {
code: $code as usize,
bits: $bits,
lanes: $lanes,
};
}
}

make_dtype_const!(DTYPE_INT32, DLDataTypeCode_kDLInt, 32, 1);
make_dtype_const!(DTYPE_UINT32, DLDataTypeCode_kDLUInt, 32, 1);
make_dtype_const!(DTYPE_FLOAT16, DLDataTypeCode_kDLFloat, 16, 1);
make_dtype_const!(DTYPE_FLOAT32, DLDataTypeCode_kDLFloat, 32, 1);
make_dtype_const!(DTYPE_FLOAT64, DLDataTypeCode_kDLFloat, 64, 1);

impl Default for DLContext {
fn default() -> Self {
DLContext {
@@ -348,27 +389,6 @@ impl Default for TVMContext {
}
}

fn tensor_from_array_storage<'a, 's, T, D: ndarray::Dimension>(
arr: &ndarray::Array<T, D>,
storage: Storage<'s>,
type_code: usize,
) -> Tensor<'s> {
let type_width = mem::size_of::<T>() as usize;
Tensor {
data: storage,
ctx: TVMContext::default(),
dtype: DataType {
code: type_code,
bits: 8 * type_width,
lanes: 1,
},
size: arr.len(),
shape: arr.shape().iter().map(|&v| v as i64).collect(),
strides: Some(arr.strides().into_iter().map(|&v| v as usize).collect()),
byte_offset: 0,
}
}

/// `From` conversions to `Tensor` for owned or borrowed `ndarray::Array`.
///
/// # Panics
@@ -382,13 +402,13 @@ macro_rules! impl_tensor_from_ndarray {
let size = arr.len() * mem::size_of::<$type>() as usize;
let storage =
Storage::from(unsafe { slice::from_raw_parts(arr.as_ptr() as *const u8, size) });
tensor_from_array_storage(&arr, storage, $typecode as usize)
Tensor::from_array_storage(&arr, storage, $typecode as usize)
}
}
impl<'a, D: ndarray::Dimension> From<&'a ndarray::Array<$type, D>> for Tensor<'a> {
fn from(arr: &'a ndarray::Array<$type, D>) -> Self {
assert!(arr.is_standard_layout(), "Array must be contiguous.");
tensor_from_array_storage(
Tensor::from_array_storage(
arr,
Storage::from(arr.as_slice().unwrap()),
$typecode as usize,
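
Taken together, the `From` impls generated by `impl_tensor_from_ndarray!` and the `TryFrom` impls generated by `impl_ndarray_try_from_tensor!` let callers round-trip between `ndarray` arrays and this module's `Tensor`. A minimal sketch, assuming `Tensor` is in scope from this module and that a `TryFrom` trait is importable (the exact import may differ on the 2018-era toolchain this crate targeted):

    use std::convert::TryFrom;
    use ndarray::{Array2, ArrayD};

    fn roundtrip(arr: &Array2<f32>) -> ArrayD<f32> {
        // Borrow the array's storage as a Tensor; this panics if `arr` is not
        // in standard (contiguous) layout, per the `# Panics` note above.
        let tensor = Tensor::from(arr);
        // The macro-generated TryFrom checks the dtype and rebuilds an ArrayD;
        // asking for a mismatched element type (e.g. f64 here) returns an Err.
        ArrayD::<f32>::try_from(&tensor).expect("dtype matches")
    }
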
2 changes: 1 addition & 1 deletion rust/src/runtime/module.rs
@@ -9,7 +9,7 @@ pub trait Module {
fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc>;
}

pub struct SystemLibModule {}
pub struct SystemLibModule;

lazy_static! {
static ref SYSTEM_LIB_FUNCTIONS: Mutex<HashMap<String, BackendPackedCFunc>> =
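
For context, a hypothetical call site for the now-unit `SystemLibModule` through the `Module` trait shown above; the function name is invented, and this assumes `SystemLibModule` implements `Module`, as the surrounding registry of `BackendPackedCFunc`s suggests:

    // A unit struct is constructed without braces after this change.
    let module = SystemLibModule;
    // Look up a packed function registered by the system library; `None`
    // means nothing was registered under that name.
    if let Some(func) = module.get_function("some_registered_fn") {
        // `func` is a PackedFunc and can be invoked with TVM argument values.
    }
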
9 changes: 9 additions & 0 deletions rust/src/runtime/packed_func.rs
@@ -200,9 +200,18 @@ macro_rules! impl_boxed_ret_value {
};
}

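// Type codes below follow DLPack's DLDataTypeCode, which TVM reuses for
// primitive argument/return values: 0 = kDLInt, 1 = kDLUInt, 2 = kDLFloat
// (the boxed String below uses TVM's string type code, 11).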
impl_prim_ret_value!(i8, 0);
impl_prim_ret_value!(u8, 1);
impl_prim_ret_value!(i16, 0);
impl_prim_ret_value!(u16, 1);
impl_prim_ret_value!(i32, 0);
impl_prim_ret_value!(u32, 1);
impl_prim_ret_value!(f32, 2);
impl_prim_ret_value!(i64, 0);
impl_prim_ret_value!(u64, 1);
impl_prim_ret_value!(f64, 2);
impl_prim_ret_value!(isize, 0);
impl_prim_ret_value!(usize, 1);
impl_boxed_ret_value!(String, 11);

// @see `WrapPackedFunc` in `llvm_module.cc`.
