diff --git a/.gitignore b/.gitignore index 833eee1a0774..368764941cec 100644 --- a/.gitignore +++ b/.gitignore @@ -91,10 +91,8 @@ ENV/ *~ *.pyc *~ -build config.mk config.cmake -build_* Win32 *.dir perf @@ -187,7 +185,6 @@ tvm_u.* tvm_t.* # Mac OS X .DS_Store -build* # Jetbrain .idea diff --git a/rust/.gitignore b/rust/.gitignore new file mode 100644 index 000000000000..230ab66104df --- /dev/null +++ b/rust/.gitignore @@ -0,0 +1,3 @@ +Cargo.lock +target/ +**/*.rs.bk diff --git a/rust/.rustfmt.toml b/rust/.rustfmt.toml new file mode 100644 index 000000000000..df9a65dacfaa --- /dev/null +++ b/rust/.rustfmt.toml @@ -0,0 +1,59 @@ +max_width = 100 +hard_tabs = false +tab_spaces = 2 +newline_style = "Auto" +use_small_heuristics = "Default" +indent_style = "Block" +wrap_comments = false +comment_width = 80 +normalize_comments = false +format_strings = false +format_macro_matchers = false +format_macro_bodies = true +empty_item_single_line = true +struct_lit_single_line = true +fn_single_line = false +where_single_line = false +imports_indent = "Block" +imports_layout = "Mixed" +merge_imports = true +reorder_imports = true +reorder_modules = true +reorder_impl_items = false +type_punctuation_density = "Wide" +space_before_colon = false +space_after_colon = true +spaces_around_ranges = false +binop_separator = "Front" +remove_nested_parens = true +combine_control_expr = true +struct_field_align_threshold = 0 +match_arm_blocks = true +force_multiline_blocks = false +fn_args_density = "Tall" +brace_style = "SameLineWhere" +control_brace_style = "AlwaysSameLine" +trailing_semicolon = true +trailing_comma = "Vertical" +match_block_trailing_comma = false +blank_lines_upper_bound = 1 +blank_lines_lower_bound = 0 +edition = "Edition2015" +merge_derives = true +use_try_shorthand = true +use_field_init_shorthand = false +force_explicit_abi = true +condense_wildcard_suffixes = false +color = "Auto" +required_version = "0.99.4" +unstable_features = false +disable_all_formatting = 
false +skip_children = false +hide_parse_errors = false +error_on_line_overflow = false +error_on_unformatted = false +report_todo = "Never" +report_fixme = "Never" +ignore = [] +emit_mode = "Files" +make_backup = false diff --git a/rust/.travis.yml b/rust/.travis.yml new file mode 100644 index 000000000000..63a3d0277c1b --- /dev/null +++ b/rust/.travis.yml @@ -0,0 +1,5 @@ +language: rust +rust: + - nightly +matrix: + fast_finish: true diff --git a/rust/Cargo.toml b/rust/Cargo.toml new file mode 100644 index 000000000000..0819e0c70023 --- /dev/null +++ b/rust/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "tvm" +version = "0.1.0" +license = "Apache-2.0" +description = "TVM Rust runtime" +repository = "https://github.com/dmlc/tvm" +readme = "README.md" +keywords = ["tvm", "nnvm"] +categories = ["api-bindings", "science"] +authors = ["Nick Hynes "] + +[features] +default = ["nom/std"] +sgx = ["nom/alloc"] + +[dependencies] +bounded-spsc-queue = "0.4.0" +error-chain = { version = "0.12.0", default-features = false } +itertools = "0.7.8" +lazy_static = "1.1.0" +ndarray = "0.11.2" +nom = {version = "4.0.0", default-features = false } +serde = "1.0.59" +serde_derive = "1.0.79" +serde_json = "1.0.17" + +[target.'cfg(not(target_env = "sgx"))'.dependencies] +num_cpus = "1.8.0" diff --git a/rust/src/errors.rs b/rust/src/errors.rs new file mode 100644 index 000000000000..f9da7180b8cc --- /dev/null +++ b/rust/src/errors.rs @@ -0,0 +1,39 @@ +#[cfg(target_env = "sgx")] +use alloc::alloc; +#[cfg(not(target_env = "sgx"))] +use std::alloc; +use std::num; + +use ndarray; +use serde_json; + +error_chain! 
{ + errors { + TryFromTVMRetValueError(expected: String, actual: i64) { + description("mismatched types while downcasting TVMRetValue") + display("invalid downcast: expected `{}` but was `{}`", expected, actual) + } + + GraphFormatError(msg: String) { + description("unable to load graph") + display("could not load graph json: {}", msg) + } + + LoadGraphParamsError(msg: String) { + description("unable to load graph params") + display("could not load graph params: {}", msg) + } + } + foreign_links { + Alloc(alloc::AllocErr); + GraphDeserialize(serde_json::Error); + ParseInt(num::ParseIntError); + ShapeError(ndarray::ShapeError); + } +} + +impl From<alloc::LayoutErr> for Error { + fn from(_err: alloc::LayoutErr) -> Error { + Error::from_kind(ErrorKind::Msg("Layout error".to_string())) + } +} diff --git a/rust/src/lib.rs b/rust/src/lib.rs new file mode 100644 index 000000000000..4a70e428d37a --- /dev/null +++ b/rust/src/lib.rs @@ -0,0 +1,68 @@ +//! This crate is an implementation of the TVM runtime for modules compiled with `--system-lib`. +//! It's mainly useful for compiling to WebAssembly and SGX, +//! but also native if you prefer Rust to C++. +//! +//! For TVM graphs, the entrypoint to this crate is `runtime::GraphExecutor`. +//! Single-function modules are used via the `packed_func!` macro after obtaining +//! the function from `runtime::SystemLibModule` +//! +//! The main entrypoints to this crate are `GraphExecutor` +//! For examples of use, please refer to the multi-file tests in the `tests` directory.
+ +#![feature( + alloc, + allocator_api, + box_syntax, + extern_prelude, + fn_traits, + try_from, + unboxed_closures, + vec_remove_item +)] + +#[cfg(target_env = "sgx")] +extern crate alloc; +extern crate bounded_spsc_queue; +#[cfg(target_env = "sgx")] +extern crate core; +#[macro_use] +extern crate error_chain; +#[macro_use] +extern crate itertools; +#[macro_use] +extern crate lazy_static; +extern crate ndarray; +#[macro_use] +extern crate nom; +#[cfg(not(target_env = "sgx"))] +extern crate num_cpus; +extern crate serde; +#[macro_use] +extern crate serde_derive; +extern crate serde_json; + +pub mod ffi { + #![allow( + non_camel_case_types, + non_snake_case, + non_upper_case_globals, + unused + )] + + pub mod runtime { + use std::os::raw::{c_char, c_int, c_void}; + + include!(concat!( + env!("CARGO_MANIFEST_DIR"), + "/src/runtime/c_runtime_api.rs" + )); + + pub type BackendPackedCFunc = + extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int; + } +} + +pub mod errors; +pub mod runtime; + +pub use errors::*; diff --git a/rust/src/runtime/allocator.rs b/rust/src/runtime/allocator.rs new file mode 100644 index 000000000000..d704336bff1f --- /dev/null +++ b/rust/src/runtime/allocator.rs @@ -0,0 +1,52 @@ +#[cfg(target_env = "sgx")] +use alloc::alloc::{self, Layout}; +#[cfg(not(target_env = "sgx"))] +use std::alloc::{self, Layout}; + +use errors::*; + +const DEFAULT_ALIGN_BYTES: usize = 4; + +#[derive(PartialEq, Eq)] +pub struct Allocation { + layout: Layout, + ptr: *mut u8, +} + +impl Allocation { + /// Allocates a chunk of memory of `size` bytes with optional alignment. 
+ pub fn new(size: usize, align: Option<usize>) -> Result<Self> { + let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES); + let layout = Layout::from_size_align(size, alignment)?; + let ptr = unsafe { alloc::alloc(layout.clone()) }; + if ptr.is_null() { + alloc::handle_alloc_error(layout); + } + Ok(Self { + ptr: ptr, + layout: layout, + }) + } + + pub fn as_mut_ptr(&self) -> *mut u8 { + self.ptr + } + + /// Returns the size of the Allocation in bytes. + pub fn size(&self) -> usize { + self.layout.size() + } + + /// Returns the byte alignment of the Allocation. + pub fn align(&self) -> usize { + self.layout.align() + } +} + +impl Drop for Allocation { + fn drop(&mut self) { + unsafe { + alloc::dealloc(self.ptr, self.layout.clone()); + } + } +} diff --git a/rust/src/runtime/array.rs b/rust/src/runtime/array.rs new file mode 100644 index 000000000000..79d22e400cff --- /dev/null +++ b/rust/src/runtime/array.rs @@ -0,0 +1,461 @@ +use std::{ + any::TypeId, + convert::TryFrom, + mem, + os::raw::{c_int, c_void}, + ptr, slice, +}; + +use ndarray; + +use super::allocator::Allocation; +use errors::*; +use ffi::runtime::{ + DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt, + DLDeviceType_kDLCPU, DLTensor, +}; + +/// A `Storage` is a container which holds `Tensor` data. +#[derive(PartialEq)] +pub enum Storage<'a> { + /// A `Storage` which owns its contained bytes. + Owned(Allocation), + + /// A view of an existing `Storage`.
+ View(&'a mut [u8], usize), // ptr, align +} + +impl<'a> Storage<'a> { + pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'static>> { + Ok(Storage::Owned(Allocation::new(size, align)?)) + } + + pub fn as_mut_ptr(&self) -> *mut u8 { + match self { + Storage::Owned(alloc) => alloc.as_mut_ptr(), + Storage::View(slice, _) => slice.as_ptr() as *mut u8, + } + } + + pub fn size(&self) -> usize { + match self { + Storage::Owned(alloc) => alloc.size(), + Storage::View(slice, _) => slice.len(), + } + } + + pub fn align(&self) -> usize { + match self { + Storage::Owned(alloc) => alloc.align(), + Storage::View(_, align) => *align, + } + } + + pub fn as_ptr(&self) -> *const u8 { + self.as_mut_ptr() as *const _ + } + + /// Returns a `Storage::View` which points to an owned `Storage::Owned`. + pub fn view(&self) -> Storage<'a> { + match self { + Storage::Owned(alloc) => Storage::View( + unsafe { slice::from_raw_parts_mut(alloc.as_mut_ptr(), self.size()) }, + self.align(), + ), + Storage::View(slice, _) => Storage::View( + unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), slice.len()) }, + self.align(), + ), + } + } + + pub fn is_owned(&self) -> bool { + match self { + Storage::Owned(_) => true, + _ => false, + } + } + + /// Returns an owned version of this storage via cloning. + pub fn to_owned(&self) -> Storage<'static> { + let s = Storage::new(self.size(), Some(self.align())).unwrap(); + unsafe { + s.as_mut_ptr() + .copy_from_nonoverlapping(self.as_ptr(), self.size()) + } + s + } +} + +impl<'a, T> From<&'a [T]> for Storage<'a> { + fn from(data: &'a [T]) -> Self { + let data = unsafe { + slice::from_raw_parts_mut( + data.as_ptr() as *const u8 as *mut u8, + data.len() * mem::size_of::<T>() as usize, + ) + }; + Storage::View(data, mem::align_of::<T>()) + } +} + +/// A n-dimensional array type which can be converted to/from `tvm::DLTensor` and `ndarray::Array`.
+/// `Tensor` is primarily a holder of data which can be operated on via TVM (via `DLTensor`) or +/// converted to `ndarray::Array` for non-TVM processing. +/// +/// # Examples +/// +/// ``` +/// extern crate ndarray; +/// +/// let mut a_nd: ndarray::Array = ndarray::Array::from_vec(vec![1f32, 2., 3., 4.]); +/// let mut a: Tensor = a_nd.into(); +/// let mut a_dl: DLTensor = (&mut t).into(); +/// call_packed!(tvm_fn, &mut a_dl); +/// +/// // Array -> Tensor is mostly useful when post-processing TVM graph outputs. +/// let mut a_nd = ndarray::Array::try_from(&a).unwrap(); +/// ``` +#[derive(PartialEq)] +pub struct Tensor<'a> { + /// The bytes which contain the data this `Tensor` represents. + pub(super) data: Storage<'a>, + pub(super) ctx: TVMContext, + pub(super) dtype: DataType, + pub(super) shape: Vec<i64>, // not usize because `typedef int64_t tvm_index_t` in c_runtime_api.h + /// The `Tensor` strides. Can be `None` if the `Tensor` is contiguous. + pub(super) strides: Option<Vec<usize>>, + pub(super) byte_offset: isize, + pub(super) size: usize, +} + +unsafe impl<'a> Send for Tensor<'a> {} + +impl<'a> Tensor<'a> { + pub fn shape(&self) -> Vec<i64> { + self.shape.clone() + } + + /// Returns the data of this `Tensor` as a `Vec`. + /// + /// # Panics + /// + /// Panics if the `Tensor` is not contiguous or does not contain elements of type `T`. + pub fn to_vec<T: 'static>(&self) -> Vec<T> { + assert!(self.is_contiguous()); + assert!(self.dtype.is_type::<T>()); + let mut vec: Vec<T> = Vec::with_capacity(self.size * self.dtype.itemsize()); + unsafe { + vec.as_mut_ptr().copy_from_nonoverlapping( + self.data.as_ptr().offset(self.byte_offset) as *const T, + self.size, + ); + vec.set_len(self.size); + } + vec + } + + /// Returns `true` iff this `Tensor` is represented by a contiguous region of memory.
+ pub fn is_contiguous(&self) -> bool { + match self.strides { + None => true, + Some(ref strides) => { + // check that stride for each dimension is the product of all trailing dimensons' shapes + self + .shape + .iter() + .zip(strides) + .rfold( + (true, 1), + |(is_contig, expected_stride), (shape, stride)| { + ( + is_contig && *stride == expected_stride, + expected_stride * (*shape as usize), + ) + }, + ).0 + } + } + } + + /// Returns a clone of this `Tensor`. + /// + /// # Panics + /// + /// Panics if the `Tensor` is not contiguous or does not contain elements of type `T`. + pub fn copy(&mut self, other: &Tensor) { + assert!( + self.dtype == other.dtype && self.size == other.size, + "Tensor shape/dtype mismatch." + ); + assert!( + self.is_contiguous() && other.is_contiguous(), + "copy currently requires contiguous tensors\n`self.strides = {:?}` `other.strides = {:?}`", + self.strides, + other.strides + ); + unsafe { + self + .data + .as_mut_ptr() + .offset(self.byte_offset as isize) + .copy_from_nonoverlapping( + other.data.as_mut_ptr().offset(other.byte_offset), + other.size * other.dtype.itemsize(), + ); + } + } + + /// Returns an owned version of this `Tensor` via cloning. 
+ pub fn to_owned(&self) -> Tensor<'static> { + let t = Tensor { + data: self.data.to_owned(), + ctx: self.ctx.clone(), + dtype: self.dtype.clone(), + size: self.size.clone(), + shape: self.shape.clone(), + strides: None, + byte_offset: 0, + }; + unsafe { mem::transmute::<Tensor<'a>, Tensor<'static>>(t) } + } + + fn from_array_storage<'s, T, D: ndarray::Dimension>( + arr: &ndarray::Array<T, D>, + storage: Storage<'s>, + type_code: usize, + ) -> Tensor<'s> { + let type_width = mem::size_of::<T>() as usize; + Tensor { + data: storage, + ctx: TVMContext::default(), + dtype: DataType { + code: type_code, + bits: 8 * type_width, + lanes: 1, + }, + size: arr.len(), + shape: arr.shape().iter().map(|&v| v as i64).collect(), + strides: Some(arr.strides().into_iter().map(|&v| v as usize).collect()), + byte_offset: 0, + } + } +} + +/// Conversions to `ndarray::Array` from `Tensor`, if the types match. +macro_rules! impl_ndarray_try_from_tensor { + ($type:ty, $dtype:expr) => { + impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<$type> { + type Error = Error; + fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<$type>> { + ensure!( + tensor.dtype == $dtype, + "Cannot convert Tensor with dtype {:?} to ndarray", + tensor.dtype + ); + Ok(ndarray::Array::from_shape_vec( + tensor + .shape + .iter() + .map(|s| *s as usize) + .collect::<Vec<usize>>(), + tensor.to_vec::<$type>(), + )?)
+ } + } + }; +} + +impl_ndarray_try_from_tensor!(i32, DTYPE_INT32); +impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32); +impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32); +impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64); + +impl DLTensor { + pub(super) fn from_tensor<'a>(tensor: &'a Tensor, flatten: bool) -> Self { + assert!(!flatten || tensor.is_contiguous()); + Self { + data: unsafe { tensor.data.as_mut_ptr().offset(tensor.byte_offset) } as *mut c_void, + ctx: DLContext::from(&tensor.ctx), + ndim: if flatten { 1 } else { tensor.shape.len() } as i32, + dtype: DLDataType::from(&tensor.dtype), + shape: if flatten { + &tensor.size as *const _ as *mut i64 + } else { + tensor.shape.as_ptr() + } as *mut i64, + strides: if flatten || tensor.is_contiguous() { + ptr::null_mut() + } else { + tensor.strides.as_ref().unwrap().as_ptr() + } as *mut i64, + byte_offset: 0, + } + } +} + +impl<'a, 't> From<&'a Tensor<'t>> for DLTensor { + fn from(tensor: &'a Tensor<'t>) -> Self { + DLTensor::from_tensor(tensor, false /* flatten */) + } +} + +impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor { + fn from(tensor: &'a mut Tensor<'t>) -> Self { + DLTensor::from_tensor(tensor, false /* flatten */) + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct DataType { + pub(super) code: usize, + pub(super) bits: usize, + pub(super) lanes: usize, +} + +impl DataType { + /// Returns the number of bytes occupied by an element of this `DataType`. + fn itemsize(&self) -> usize { + (self.bits * self.lanes) >> 3 + } + + /// Returns whether this `DataType` represents primitive type `T`. 
+ fn is_type<T: 'static>(&self) -> bool { + if self.lanes != 1 { + return false; + } + let typ = TypeId::of::<T>(); + (typ == TypeId::of::<i32>() && self.code == 0 && self.bits == 32) + || (typ == TypeId::of::<i64>() && self.code == 0 && self.bits == 64) + || (typ == TypeId::of::<u32>() && self.code == 1 && self.bits == 32) + || (typ == TypeId::of::<u64>() && self.code == 1 && self.bits == 64) + || (typ == TypeId::of::<f32>() && self.code == 2 && self.bits == 32) + || (typ == TypeId::of::<f64>() && self.code == 2 && self.bits == 64) + } +} + +impl<'a> From<&'a DataType> for DLDataType { + fn from(dtype: &'a DataType) -> Self { + Self { + code: dtype.code as u8, + bits: dtype.bits as u8, + lanes: dtype.lanes as u16, + } + } +} + +macro_rules! make_dtype_const { + ($name: ident, $code: ident, $bits: expr, $lanes: expr) => { + const $name: DataType = DataType { + code: $code as usize, + bits: $bits, + lanes: $lanes, + }; + }; +} + +make_dtype_const!(DTYPE_INT32, DLDataTypeCode_kDLInt, 32, 1); +make_dtype_const!(DTYPE_UINT32, DLDataTypeCode_kDLUInt, 32, 1); +// make_dtype_const!(DTYPE_FLOAT16, DLDataTypeCode_kDLFloat, 16, 1); +make_dtype_const!(DTYPE_FLOAT32, DLDataTypeCode_kDLFloat, 32, 1); +make_dtype_const!(DTYPE_FLOAT64, DLDataTypeCode_kDLFloat, 64, 1); + +impl Default for DLContext { + fn default() -> Self { + DLContext { + device_type: DLDeviceType_kDLCPU, + device_id: 0, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct TVMContext { + pub(super) device_type: usize, + pub(super) device_id: usize, +} + +impl<'a> From<&'a TVMContext> for DLContext { + fn from(ctx: &'a TVMContext) -> Self { + Self { + device_type: ctx.device_type as u32, + device_id: ctx.device_id as i32, + } + } +} + +impl Default for TVMContext { + fn default() -> Self { + Self { + device_type: DLDeviceType_kDLCPU as usize, + device_id: 0, + } + } +} + +/// `From` conversions to `Tensor` for owned or borrowed `ndarray::Array`. +/// +/// # Panics +/// +/// Panics if the ndarray is not contiguous. +macro_rules!
impl_tensor_from_ndarray { + ($type:ty, $typecode:expr) => { + impl<D: ndarray::Dimension> From<ndarray::Array<$type, D>> for Tensor<'static> { + fn from(arr: ndarray::Array<$type, D>) -> Self { + assert!(arr.is_standard_layout(), "Array must be contiguous."); + let size = arr.len() * mem::size_of::<$type>() as usize; + let storage = + Storage::from(unsafe { slice::from_raw_parts(arr.as_ptr() as *const u8, size) }); + Tensor::from_array_storage(&arr, storage, $typecode as usize) + } + } + impl<'a, D: ndarray::Dimension> From<&'a ndarray::Array<$type, D>> for Tensor<'a> { + fn from(arr: &'a ndarray::Array<$type, D>) -> Self { + assert!(arr.is_standard_layout(), "Array must be contiguous."); + Tensor::from_array_storage( + arr, + Storage::from(arr.as_slice().unwrap()), + $typecode as usize, + ) + } + } + }; +} + +/// `From` conversions to `DLTensor` for `ndarray::Array`. +/// Takes a reference to the `ndarray` since `DLTensor` is not owned. +macro_rules! impl_dltensor_from_ndarray { + ($type:ty, $typecode:expr) => { + impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor { + fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self { + DLTensor { + data: arr.as_mut_ptr() as *mut c_void, + ctx: DLContext::default(), + ndim: arr.ndim() as c_int, + dtype: DLDataType { + code: $typecode as u8, + bits: 8 * mem::size_of::<$type>() as u8, + lanes: 1, + }, + shape: arr.shape().as_ptr() as *const i64 as *mut i64, + strides: arr.strides().as_ptr() as *const isize as *mut i64, + byte_offset: 0, + } + } + } + }; +} + +impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat); +impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat); +impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt); +impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt); +impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt); +impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt); + +impl_tensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat); +impl_tensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
+impl_tensor_from_ndarray!(i32, DLDataTypeCode_kDLInt); +impl_tensor_from_ndarray!(i64, DLDataTypeCode_kDLInt); +impl_tensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt); +impl_tensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt); diff --git a/rust/src/runtime/c_runtime_api.rs b/rust/src/runtime/c_runtime_api.rs new file mode 100644 index 000000000000..62cfa0d15451 --- /dev/null +++ b/rust/src/runtime/c_runtime_api.rs @@ -0,0 +1,770 @@ +/* automatically generated by rust-bindgen for TVM revision 6292c78 */ + +pub const TVM_VERSION: &'static [u8; 8usize] = b"0.5.dev\0"; +pub const DLPACK_VERSION: u32 = 8; +pub const _STDINT_H: u32 = 1; +pub const _FEATURES_H: u32 = 1; +pub const _DEFAULT_SOURCE: u32 = 1; +pub const __USE_ISOC11: u32 = 1; +pub const __USE_ISOC99: u32 = 1; +pub const __USE_ISOC95: u32 = 1; +pub const __USE_POSIX_IMPLICITLY: u32 = 1; +pub const _POSIX_SOURCE: u32 = 1; +pub const _POSIX_C_SOURCE: u32 = 200809; +pub const __USE_POSIX: u32 = 1; +pub const __USE_POSIX2: u32 = 1; +pub const __USE_POSIX199309: u32 = 1; +pub const __USE_POSIX199506: u32 = 1; +pub const __USE_XOPEN2K: u32 = 1; +pub const __USE_XOPEN2K8: u32 = 1; +pub const _ATFILE_SOURCE: u32 = 1; +pub const __USE_MISC: u32 = 1; +pub const __USE_ATFILE: u32 = 1; +pub const __USE_FORTIFY_LEVEL: u32 = 0; +pub const _STDC_PREDEF_H: u32 = 1; +pub const __STDC_IEC_559__: u32 = 1; +pub const __STDC_IEC_559_COMPLEX__: u32 = 1; +pub const __STDC_ISO_10646__: u32 = 201505; +pub const __STDC_NO_THREADS__: u32 = 1; +pub const __GNU_LIBRARY__: u32 = 6; +pub const __GLIBC__: u32 = 2; +pub const __GLIBC_MINOR__: u32 = 23; +pub const _SYS_CDEFS_H: u32 = 1; +pub const __WORDSIZE: u32 = 64; +pub const __WORDSIZE_TIME64_COMPAT32: u32 = 1; +pub const __SYSCALL_WORDSIZE: u32 = 64; +pub const _BITS_WCHAR_H: u32 = 1; +pub const INT8_MIN: i32 = -128; +pub const INT16_MIN: i32 = -32768; +pub const INT32_MIN: i32 = -2147483648; +pub const INT8_MAX: u32 = 127; +pub const INT16_MAX: u32 = 32767; +pub const INT32_MAX: u32 = 
2147483647; +pub const UINT8_MAX: u32 = 255; +pub const UINT16_MAX: u32 = 65535; +pub const UINT32_MAX: u32 = 4294967295; +pub const INT_LEAST8_MIN: i32 = -128; +pub const INT_LEAST16_MIN: i32 = -32768; +pub const INT_LEAST32_MIN: i32 = -2147483648; +pub const INT_LEAST8_MAX: u32 = 127; +pub const INT_LEAST16_MAX: u32 = 32767; +pub const INT_LEAST32_MAX: u32 = 2147483647; +pub const UINT_LEAST8_MAX: u32 = 255; +pub const UINT_LEAST16_MAX: u32 = 65535; +pub const UINT_LEAST32_MAX: u32 = 4294967295; +pub const INT_FAST8_MIN: i32 = -128; +pub const INT_FAST16_MIN: i64 = -9223372036854775808; +pub const INT_FAST32_MIN: i64 = -9223372036854775808; +pub const INT_FAST8_MAX: u32 = 127; +pub const INT_FAST16_MAX: u64 = 9223372036854775807; +pub const INT_FAST32_MAX: u64 = 9223372036854775807; +pub const UINT_FAST8_MAX: u32 = 255; +pub const UINT_FAST16_MAX: i32 = -1; +pub const UINT_FAST32_MAX: i32 = -1; +pub const INTPTR_MIN: i64 = -9223372036854775808; +pub const INTPTR_MAX: u64 = 9223372036854775807; +pub const UINTPTR_MAX: i32 = -1; +pub const PTRDIFF_MIN: i64 = -9223372036854775808; +pub const PTRDIFF_MAX: u64 = 9223372036854775807; +pub const SIG_ATOMIC_MIN: i32 = -2147483648; +pub const SIG_ATOMIC_MAX: u32 = 2147483647; +pub const SIZE_MAX: i32 = -1; +pub const WINT_MIN: u32 = 0; +pub const WINT_MAX: u32 = 4294967295; +pub type int_least8_t = ::std::os::raw::c_schar; +pub type int_least16_t = ::std::os::raw::c_short; +pub type int_least32_t = ::std::os::raw::c_int; +pub type int_least64_t = ::std::os::raw::c_long; +pub type uint_least8_t = ::std::os::raw::c_uchar; +pub type uint_least16_t = ::std::os::raw::c_ushort; +pub type uint_least32_t = ::std::os::raw::c_uint; +pub type uint_least64_t = ::std::os::raw::c_ulong; +pub type int_fast8_t = ::std::os::raw::c_schar; +pub type int_fast16_t = ::std::os::raw::c_long; +pub type int_fast32_t = ::std::os::raw::c_long; +pub type int_fast64_t = ::std::os::raw::c_long; +pub type uint_fast8_t = ::std::os::raw::c_uchar; +pub 
type uint_fast16_t = ::std::os::raw::c_ulong; +pub type uint_fast32_t = ::std::os::raw::c_ulong; +pub type uint_fast64_t = ::std::os::raw::c_ulong; +pub type intmax_t = ::std::os::raw::c_long; +pub type uintmax_t = ::std::os::raw::c_ulong; +pub type wchar_t = ::std::os::raw::c_int; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct max_align_t { + pub __clang_max_align_nonce1: ::std::os::raw::c_longlong, + pub __bindgen_padding_0: u64, + pub __clang_max_align_nonce2: f64, +} +pub const DLDeviceType_kDLCPU: DLDeviceType = 1; +pub const DLDeviceType_kDLGPU: DLDeviceType = 2; +pub const DLDeviceType_kDLCPUPinned: DLDeviceType = 3; +pub const DLDeviceType_kDLOpenCL: DLDeviceType = 4; +pub const DLDeviceType_kDLMetal: DLDeviceType = 8; +pub const DLDeviceType_kDLVPI: DLDeviceType = 9; +pub const DLDeviceType_kDLROCM: DLDeviceType = 10; +/// \brief The device type in DLContext. +pub type DLDeviceType = u32; +/// \brief A Device context for Tensor and operator. +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct DLContext { + /// \brief The device type used in the device. + pub device_type: DLDeviceType, + /// \brief The device index + pub device_id: ::std::os::raw::c_int, +} +pub const DLDataTypeCode_kDLInt: DLDataTypeCode = 0; +pub const DLDataTypeCode_kDLUInt: DLDataTypeCode = 1; +pub const DLDataTypeCode_kDLFloat: DLDataTypeCode = 2; +/// \brief The type code options DLDataType. +pub type DLDataTypeCode = u32; +/// \brief The data type the tensor can hold. +/// +/// Examples +/// - float: type_code = 2, bits = 32, lanes=1 +/// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4 +/// - int8: type_code = 0, bits = 8, lanes=1 +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct DLDataType { + /// \brief Type code of base types. + /// We keep it uint8_t instead of DLDataTypeCode for minimal memory + /// footprint, but the value should be one of DLDataTypeCode enum values. + /// + pub code: u8, + /// \brief Number of bits, common choices are 8, 16, 32. 
+ pub bits: u8, + /// \brief Number of lanes in the type, used for vector types. + pub lanes: u16, +} +/// \brief Plain C Tensor object, does not manage memory. +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct DLTensor { + /// \brief The opaque data pointer points to the allocated data. + /// This will be CUDA device pointer or cl_mem handle in OpenCL. + /// This pointer is always aligns to 256 bytes as in CUDA. + pub data: *mut ::std::os::raw::c_void, + /// \brief The device context of the tensor + pub ctx: DLContext, + /// \brief Number of dimensions + pub ndim: ::std::os::raw::c_int, + /// \brief The data type of the pointer + pub dtype: DLDataType, + /// \brief The shape of the tensor + pub shape: *mut i64, + /// \brief strides of the tensor, + /// can be NULL, indicating tensor is compact. + pub strides: *mut i64, + /// \brief The offset in bytes to the beginning pointer to data + pub byte_offset: u64, +} +/// \brief C Tensor object, manage memory of DLTensor. This data structure is +/// intended to faciliate the borrowing of DLTensor by another framework. It is +/// not meant to transfer the tensor. When the borrowing framework doesn't need +/// the tensor, it should call the deleter to notify the host that the resource +/// is no longer needed. +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct DLManagedTensor { + /// \brief DLTensor which is being memory managed + pub dl_tensor: DLTensor, + /// \brief the context of the original host framework of DLManagedTensor in + /// which DLManagedTensor is used in the framework. It can also be NULL. + pub manager_ctx: *mut ::std::os::raw::c_void, + /// \brief Destructor signature void (*)(void*) - this should be called + /// to destruct manager_ctx which holds the DLManagedTensor. It can be NULL + /// if there is no way for the caller to provide a reasonable destructor. + pub deleter: ::std::option::Option<unsafe extern "C" fn(self_: *mut DLManagedTensor)>, +} +/// \brief type of array index.
+pub type tvm_index_t = i64; +pub const TVMDeviceExtType_kDLAOCL: TVMDeviceExtType = 5; +pub const TVMDeviceExtType_kDLSDAccel: TVMDeviceExtType = 6; +pub const TVMDeviceExtType_kDLVulkan: TVMDeviceExtType = 7; +pub const TVMDeviceExtType_kOpenGL: TVMDeviceExtType = 11; +pub const TVMDeviceExtType_kExtDev: TVMDeviceExtType = 12; +/// \brief Extension device types in TVM +pub type TVMDeviceExtType = u32; +pub const TVMTypeCode_kHandle: TVMTypeCode = 3; +pub const TVMTypeCode_kNull: TVMTypeCode = 4; +pub const TVMTypeCode_kTVMType: TVMTypeCode = 5; +pub const TVMTypeCode_kTVMContext: TVMTypeCode = 6; +pub const TVMTypeCode_kArrayHandle: TVMTypeCode = 7; +pub const TVMTypeCode_kNodeHandle: TVMTypeCode = 8; +pub const TVMTypeCode_kModuleHandle: TVMTypeCode = 9; +pub const TVMTypeCode_kFuncHandle: TVMTypeCode = 10; +pub const TVMTypeCode_kStr: TVMTypeCode = 11; +pub const TVMTypeCode_kBytes: TVMTypeCode = 12; +pub const TVMTypeCode_kNDArrayContainer: TVMTypeCode = 13; +pub const TVMTypeCode_kExtBegin: TVMTypeCode = 15; +pub const TVMTypeCode_kNNVMFirst: TVMTypeCode = 16; +pub const TVMTypeCode_kNNVMLast: TVMTypeCode = 20; +pub const TVMTypeCode_kExtReserveEnd: TVMTypeCode = 64; +pub const TVMTypeCode_kExtEnd: TVMTypeCode = 128; +/// \brief The type code in TVMType +/// \note TVMType is used in two places. +pub type TVMTypeCode = u32; +/// \brief The data type used in TVM Runtime. +/// +/// Examples +/// - float: type_code = 2, bits = 32, lanes=1 +/// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4 +/// - int8: type_code = 0, bits = 8, lanes=1 +/// +/// \note Arguments TVM API function always takes bits=64 and lanes=1 +pub type TVMType = DLDataType; +/// \brief The Device information, abstract away common device types. +pub type TVMContext = DLContext; +/// \brief The tensor array stucture to TVM API. 
/// Alias of the DLPack tensor type used throughout the TVM C API.
pub type TVMArray = DLTensor;
/// \brief the array handle
pub type TVMArrayHandle = *mut TVMArray;
/// \brief Union type of values
/// being passed through API and function calls.
#[repr(C)]
#[derive(Copy, Clone)]
pub union TVMValue {
  pub v_int64: i64,
  pub v_float64: f64,
  pub v_handle: *mut ::std::os::raw::c_void,
  pub v_str: *const ::std::os::raw::c_char,
  pub v_type: TVMType,
  pub v_ctx: TVMContext,
  // bindgen-emitted padding field forcing 8-byte alignment of the union.
  _bindgen_union_align: u64,
}
/// \brief Byte array type used to pass in byte array
/// When kBytes is used as data type.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct TVMByteArray {
  pub data: *const ::std::os::raw::c_char,
  pub size: usize,
}
/// \brief Handle to TVM runtime modules.
pub type TVMModuleHandle = *mut ::std::os::raw::c_void;
/// \brief Handle to packed function handle.
pub type TVMFunctionHandle = *mut ::std::os::raw::c_void;
/// \brief Handle to hold return value.
pub type TVMRetValueHandle = *mut ::std::os::raw::c_void;
/// \brief The stream that is specific to device
/// can be NULL, which indicates the default one.
pub type TVMStreamHandle = *mut ::std::os::raw::c_void;
extern "C" {
  /// \brief Used for implementing C API function.
  /// Set last error message before return.
  /// \param msg The error message to be set.
  pub fn TVMAPISetLastError(msg: *const ::std::os::raw::c_char);
}
extern "C" {
  /// \brief return str message of the last error
  /// all function in this file will return 0 when success
  /// and -1 when an error occurred,
  /// TVMGetLastError can be called to retrieve the error
  ///
  /// this function is threadsafe and can be called by different thread
  /// \return error info
  pub fn TVMGetLastError() -> *const ::std::os::raw::c_char;
}
extern "C" {
  /// \brief Load module from file.
  /// \param file_name The file name to load the module from.
  /// \param format The format of the module.
  /// \param out The result module
  ///
  /// \return 0 when success, -1 when failure happens
  /// \note The resulting module does not contain import relation.
  /// It can be reconstructed by TVMModImport.
  pub fn TVMModLoadFromFile(
    file_name: *const ::std::os::raw::c_char,
    format: *const ::std::os::raw::c_char,
    out: *mut TVMModuleHandle,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Add dep to mod's dependency.
  /// This allows functions in this module to use modules.
  ///
  /// \param mod The module handle.
  /// \param dep The dependent module to be imported.
  /// \return 0 when success, -1 when failure happens
  pub fn TVMModImport(mod_: TVMModuleHandle, dep: TVMModuleHandle) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Get function from the module.
  /// \param mod The module handle.
  /// \param func_name The name of the function.
  /// \param query_imports Whether to query imported modules
  /// \param out The result function, can be NULL if it is not available.
  /// \return 0 when no error is thrown, -1 when failure happens
  pub fn TVMModGetFunction(
    mod_: TVMModuleHandle,
    func_name: *const ::std::os::raw::c_char,
    query_imports: ::std::os::raw::c_int,
    out: *mut TVMFunctionHandle,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Free front-end extension type resource.
  /// \param handle The extension handle.
  /// \param type_code The type code of the extension type.
  /// \return 0 when success, -1 when failure happens
  pub fn TVMExtTypeFree(
    handle: *mut ::std::os::raw::c_void,
    type_code: ::std::os::raw::c_int,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Free the Module
  /// \param mod The module to be freed.
  ///
  /// \note This may not free up the module's resources.
  /// If there is an active TVMFunctionHandle using the module,
  /// or if this module is imported by another active module.
  ///
  /// All functions remain valid until TVMFuncFree is called.
  /// \return 0 when success, -1 when failure happens
  pub fn TVMModFree(mod_: TVMModuleHandle) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Free the function when it is no longer needed.
  /// \param func The function handle
  /// \return 0 when success, -1 when failure happens
  pub fn TVMFuncFree(func: TVMFunctionHandle) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Call a Packed TVM Function.
  ///
  /// \param func node handle of the function.
  /// \param arg_values The arguments
  /// \param type_codes The type codes of the arguments
  /// \param num_args Number of arguments.
  ///
  /// \param ret_val The return value.
  /// \param ret_type_code the type code of return value.
  ///
  /// \return 0 when success, -1 when failure happens
  /// \note TVM calls always exchange with type bits=64, lanes=1
  ///
  /// \note API calls always exchange with type bits=64, lanes=1
  /// If API call returns container handles (e.g. FunctionHandle)
  /// these handles should be managed by the front-end.
  /// The front-end needs to call the free function (e.g. TVMFuncFree)
  /// to free these handles.
  pub fn TVMFuncCall(
    func: TVMFunctionHandle,
    arg_values: *mut TVMValue,
    type_codes: *mut ::std::os::raw::c_int,
    num_args: ::std::os::raw::c_int,
    ret_val: *mut TVMValue,
    ret_type_code: *mut ::std::os::raw::c_int,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Set the return value of TVMPackedCFunc.
  ///
  /// This function is called by TVMPackedCFunc to set the return value.
  /// When this function is not called, the function returns null by default.
  ///
  /// \param ret The return value handle, pass by ret in TVMPackedCFunc
  /// \param value The value to be returned.
  /// \param type_code The type of the value to be returned.
  /// \param num_ret Number of return values, for now only 1 is supported.
  pub fn TVMCFuncSetReturn(
    ret: TVMRetValueHandle,
    value: *mut TVMValue,
    type_code: *mut ::std::os::raw::c_int,
    num_ret: ::std::os::raw::c_int,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Inplace translate callback argument value to return value.
  /// This is only needed for non-POD arguments.
  ///
  /// \param value The value to be translated.
  /// \param code The type code to be translated.
  /// \note This function will do a shallow copy when necessary.
  ///
  /// \return 0 when success, -1 when failure happens.
  pub fn TVMCbArgToReturn(
    value: *mut TVMValue,
    code: ::std::os::raw::c_int,
  ) -> ::std::os::raw::c_int;
}
/// \brief C type of packed function.
///
/// \param args The arguments
/// \param type_codes The type codes of the arguments
/// \param num_args Number of arguments.
/// \param ret The return value handle.
/// \param resource_handle The additional resource handle from the front-end.
/// \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError.
/// \sa TVMCFuncSetReturn
pub type TVMPackedCFunc = ::std::option::Option<
  unsafe extern "C" fn(
    args: *mut TVMValue,
    type_codes: *mut ::std::os::raw::c_int,
    num_args: ::std::os::raw::c_int,
    ret: TVMRetValueHandle,
    resource_handle: *mut ::std::os::raw::c_void,
  ) -> ::std::os::raw::c_int,
>;
/// \brief C callback to free the resource handle in C packed function.
/// \param resource_handle The additional resource handle from the front-end.
pub type TVMPackedCFuncFinalizer =
  ::std::option::Option<unsafe extern "C" fn(resource_handle: *mut ::std::os::raw::c_void)>;
/// \brief Signature for extension function declarer.
///
/// TVM calls this function to get the extension functions.
/// The declarer will call register_func to register functions and their names.
///
/// \param register_func_handle The register function
/// \return 0 if success, -1 if failure happens
pub type TVMExtensionFuncDeclarer = ::std::option::Option<
  unsafe extern "C" fn(register_func_handle: TVMFunctionHandle) -> ::std::os::raw::c_int,
>;
extern "C" {
  /// \brief Wrap a TVMPackedCFunc to become a FunctionHandle.
  ///
  /// The resource_handle will be managed by TVM API, until the function is no longer used.
  ///
  /// \param func The packed C function.
  /// \param resource_handle The resource handle from front-end, can be NULL.
  /// \param fin The finalizer on resource handle when the FunctionHandle gets freed, can be NULL
  /// \param out the result function handle.
  /// \return 0 when success, -1 when failure happens
  pub fn TVMFuncCreateFromCFunc(
    func: TVMPackedCFunc,
    resource_handle: *mut ::std::os::raw::c_void,
    fin: TVMPackedCFuncFinalizer,
    out: *mut TVMFunctionHandle,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Register the function to runtime's global table.
  ///
  /// The registered function then can be pulled by the backend by the name.
  ///
  /// \param name The name of the function.
  /// \param f The function to be registered.
  /// \param override Whether to allow overriding an already registered function.
  pub fn TVMFuncRegisterGlobal(
    name: *const ::std::os::raw::c_char,
    f: TVMFunctionHandle,
    override_: ::std::os::raw::c_int,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Get a global function.
  ///
  /// \param name The name of the function.
  /// \param out the result function pointer, NULL if it does not exist.
  ///
  /// \note The function handle of a global function is managed by the TVM runtime,
  /// so TVMFuncFree should not be called when it gets deleted.
  pub fn TVMFuncGetGlobal(
    name: *const ::std::os::raw::c_char,
    out: *mut TVMFunctionHandle,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief List all the globally registered function names
  /// \param out_size The number of functions
  /// \param out_array The array of function names.
  /// \return 0 when success, -1 when failure happens
  pub fn TVMFuncListGlobalNames(
    out_size: *mut ::std::os::raw::c_int,
    out_array: *mut *mut *const ::std::os::raw::c_char,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Allocate a nd-array's memory,
  /// including space of shape, of given spec.
  ///
  /// \param shape The shape of the array, the data content will be copied to out
  /// \param ndim The number of dimension of the array.
  /// \param dtype_code The type code of the dtype
  /// \param dtype_bits The number of bits of dtype
  /// \param dtype_lanes The number of lanes in the dtype.
  /// \param device_type The device type of context
  /// \param device_id The device id of context.
  /// \param out The output handle.
  /// \return 0 when success, -1 when failure happens
  pub fn TVMArrayAlloc(
    shape: *const tvm_index_t,
    ndim: ::std::os::raw::c_int,
    dtype_code: ::std::os::raw::c_int,
    dtype_bits: ::std::os::raw::c_int,
    dtype_lanes: ::std::os::raw::c_int,
    device_type: ::std::os::raw::c_int,
    device_id: ::std::os::raw::c_int,
    out: *mut TVMArrayHandle,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Free the TVM Array.
  /// \param handle The array handle to be freed.
  /// \return 0 when success, -1 when failure happens
  pub fn TVMArrayFree(handle: TVMArrayHandle) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Copy array data from CPU byte array.
  /// \param handle The array handle.
  /// \param data the data pointer
  /// \param nbytes The number of bytes to copy.
  /// \return 0 when success, -1 when failure happens
  pub fn TVMArrayCopyFromBytes(
    handle: TVMArrayHandle,
    data: *mut ::std::os::raw::c_void,
    nbytes: usize,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Copy array data to CPU byte array.
  /// \param handle The array handle.
  /// \param data the data pointer
  /// \param nbytes The number of bytes to copy.
  /// \return 0 when success, -1 when failure happens
  pub fn TVMArrayCopyToBytes(
    handle: TVMArrayHandle,
    data: *mut ::std::os::raw::c_void,
    nbytes: usize,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Copy the array, both from and to must be valid during the copy.
  /// \param from The array to be copied from.
  /// \param to The target space.
  /// \param stream The stream where the copy happens, can be NULL.
  /// \return 0 when success, -1 when failure happens
  pub fn TVMArrayCopyFromTo(
    from: TVMArrayHandle,
    to: TVMArrayHandle,
    stream: TVMStreamHandle,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Produce an array from the DLManagedTensor that shares data memory
  /// with the DLManagedTensor.
  /// \param from The source DLManagedTensor.
  /// \param out The output array handle.
  /// \return 0 when success, -1 when failure happens
  pub fn TVMArrayFromDLPack(
    from: *mut DLManagedTensor,
    out: *mut TVMArrayHandle,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Produce a DLManagedTensor from the array that shares data memory with
  /// the array.
  /// \param from The source array.
  /// \param out The DLManagedTensor handle.
  /// \return 0 when success, -1 when failure happens
  pub fn TVMArrayToDLPack(
    from: TVMArrayHandle,
    out: *mut *mut DLManagedTensor,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Delete (free) a DLManagedTensor's data.
  /// \param dltensor Pointer to the DLManagedTensor.
  pub fn TVMDLManagedTensorCallDeleter(dltensor: *mut DLManagedTensor);
}
extern "C" {
  /// \brief Create a new runtime stream.
  ///
  /// \param device_type The device type of context
  /// \param device_id The device id of context
  /// \param out The new stream handle
  /// \return 0 when success, -1 when failure happens
  pub fn TVMStreamCreate(
    device_type: ::std::os::raw::c_int,
    device_id: ::std::os::raw::c_int,
    out: *mut TVMStreamHandle,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Free a created stream handle.
  ///
  /// \param device_type The device type of context
  /// \param device_id The device id of context
  /// \param stream The stream to be freed
  /// \return 0 when success, -1 when failure happens
  pub fn TVMStreamFree(
    device_type: ::std::os::raw::c_int,
    device_id: ::std::os::raw::c_int,
    stream: TVMStreamHandle,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Set the runtime stream of current thread to be stream.
  /// The subsequent calls to the same device_type
  /// will use the stream handle set here.
  /// The specific type of stream is runtime device dependent.
  ///
  /// \param device_type The device type of context
  /// \param device_id The device id of context.
  /// \param handle The stream handle.
  /// \return 0 when success, -1 when failure happens
  pub fn TVMSetStream(
    device_type: ::std::os::raw::c_int,
    device_id: ::std::os::raw::c_int,
    handle: TVMStreamHandle,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Wait until all computations on stream complete.
  ///
  /// \param device_type The device type of context
  /// \param device_id The device id of context.
  /// \param stream The stream to be synchronized.
  /// \return 0 when success, -1 when failure happens
  pub fn TVMSynchronize(
    device_type: ::std::os::raw::c_int,
    device_id: ::std::os::raw::c_int,
    stream: TVMStreamHandle,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Synchronize two streams of execution.
  ///
  /// \param device_type The device type of context
  /// \param device_id The device id of context
  /// \param src The source stream to synchronize.
  /// \param dst The destination stream to synchronize.
  /// \return 0 when success, -1 when failure happens
  pub fn TVMStreamStreamSynchronize(
    device_type: ::std::os::raw::c_int,
    device_id: ::std::os::raw::c_int,
    src: TVMStreamHandle,
    dst: TVMStreamHandle,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Backend function for modules to get a function
  /// from its environment mod_node (its imports and global functions).
  /// The user should not call TVMFuncFree on func.
  ///
  /// \param mod_node The module handle.
  /// \param func_name The name of the function.
  /// \param out The result function.
  /// \return 0 when no error is thrown, -1 when failure happens
  pub fn TVMBackendGetFuncFromEnv(
    mod_node: *mut ::std::os::raw::c_void,
    func_name: *const ::std::os::raw::c_char,
    out: *mut TVMFunctionHandle,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Backend function to register a system-wide library symbol.
  ///
  /// \param name The name of the symbol
  /// \param ptr The symbol address.
  /// \return 0 when no error is thrown, -1 when failure happens
  pub fn TVMBackendRegisterSystemLibSymbol(
    name: *const ::std::os::raw::c_char,
    ptr: *mut ::std::os::raw::c_void,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Backend function to allocate temporal workspace.
  ///
  /// \note The resulting allocated space is guaranteed to be aligned to kTempAllocaAlignment.
  ///
  /// \param nbytes The size of the space requested.
  /// \param device_type The device type which the space will be allocated.
  /// \param device_id The device id which the space will be allocated.
  /// \param dtype_code_hint The type code of the array elements. Only used in
  /// certain backends such as OpenGL.
  /// \param dtype_bits_hint The type bits of the array elements. Only used in
  /// certain backends such as OpenGL.
  /// \return nullptr when error is thrown, a valid ptr if success
  pub fn TVMBackendAllocWorkspace(
    device_type: ::std::os::raw::c_int,
    device_id: ::std::os::raw::c_int,
    nbytes: u64,
    dtype_code_hint: ::std::os::raw::c_int,
    dtype_bits_hint: ::std::os::raw::c_int,
  ) -> *mut ::std::os::raw::c_void;
}
extern "C" {
  /// \brief Backend function to free temporal workspace.
  ///
  /// \param ptr The result allocated space pointer.
  /// \param device_type The device type which the space will be allocated.
  /// \param device_id The device id which the space will be allocated.
  /// \return 0 when no error is thrown, -1 when failure happens
  ///
  /// \sa TVMBackendAllocWorkspace
  pub fn TVMBackendFreeWorkspace(
    device_type: ::std::os::raw::c_int,
    device_id: ::std::os::raw::c_int,
    ptr: *mut ::std::os::raw::c_void,
  ) -> ::std::os::raw::c_int;
}
/// \brief Environment for TVM parallel task.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct TVMParallelGroupEnv {
  /// \brief Auxiliary used for synchronization
  pub sync_handle: *mut ::std::os::raw::c_void,
  /// \brief total amount of tasks
  pub num_task: i32,
}
/// \brief The callback function to execute a parallel lambda
/// \param task_id the task id of the function.
/// \param penv The parallel environment backs the execution.
/// \param cdata The supporting closure data.
pub type FTVMParallelLambda = ::std::option::Option<
  unsafe extern "C" fn(
    task_id: ::std::os::raw::c_int,
    penv: *mut TVMParallelGroupEnv,
    cdata: *mut ::std::os::raw::c_void,
  ) -> ::std::os::raw::c_int,
>;
extern "C" {
  /// \brief Backend function for running parallel jobs.
  ///
  /// \param flambda The parallel function to be launched.
  /// \param cdata The closure data.
  /// \param num_task Number of tasks to launch, can be 0, meaning launch
  /// with all available threads.
  ///
  /// \return 0 when no error is thrown, -1 when failure happens
  pub fn TVMBackendParallelLaunch(
    flambda: FTVMParallelLambda,
    cdata: *mut ::std::os::raw::c_void,
    num_task: ::std::os::raw::c_int,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief BSP barrier between parallel threads
  /// \param task_id the task id of the function.
  /// \param penv The parallel environment backs the execution.
  /// \return 0 when no error is thrown, -1 when failure happens
  pub fn TVMBackendParallelBarrier(
    task_id: ::std::os::raw::c_int,
    penv: *mut TVMParallelGroupEnv,
  ) -> ::std::os::raw::c_int;
}
extern "C" {
  /// \brief Simple static initialization function.
  /// Run f once and set handle to be not null.
  /// This function is mainly used for test purposes.
  ///
  /// \param handle A global address to indicate f
  /// \param f The function to be run
  /// \param cdata The closure data to pass to the function.
  /// \param nbytes Number of bytes in the closure data.
  /// \return 0 when no error is thrown, -1 when failure happens
  pub fn TVMBackendRunOnce(
    handle: *mut *mut ::std::os::raw::c_void,
    f: ::std::option::Option<
      unsafe extern "C" fn(arg1: *mut ::std::os::raw::c_void) -> ::std::os::raw::c_int,
    >,
    cdata: *mut ::std::os::raw::c_void,
    nbytes: ::std::os::raw::c_int,
  ) -> ::std::os::raw::c_int;
}

// ---- file boundary: rust/src/runtime/graph.rs ----

use std::{cmp, collections::HashMap, convert::TryFrom, iter::FromIterator, mem, str};

use nom::{alpha1, digit1, le_i32, le_i64, le_u16, le_u32, le_u64, le_u8, types::CompleteStr};
use serde;
use serde_json;

use super::{DataType, Module, Storage, TVMArgValue, TVMContext, Tensor};
use errors::{Error, ErrorKind, Result};
use ffi::runtime::{
  DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt, DLTensor,
};

// Magic number for NDArray file. @see `kTVMNDArrayMagic` in `ndarray.h`
const _NDARRAY_MAGIC: u64 = 0xDD5E40F096B4A13F;
// Magic number for NDArray list file. @see `kTVMNDArrayListMagic` in `graph_runtime.h`
const _NDARRAY_LIST_MAGIC: u64 = 0xF7E58D4F05049CB7;

/// A TVM computation graph.
///
/// # Examples
///
/// ```
/// let graph_json = fs::read_to_string("graph.json").unwrap();
/// let graph = Graph::try_from(&graph_json).unwrap();
/// ```
#[derive(Serialize, Deserialize, Debug)]
pub struct Graph {
  pub nodes: Vec<Node>,
  pub arg_nodes: Vec<usize>,
  pub heads: Vec<Entry>,
  pub node_row_ptr: Option<Vec<usize>>,
  pub attrs: Option<HashMap<String, serde_json::Value>>,
}

/// A reference to one output of a node: `id` indexes `Graph::nodes`,
/// `index` selects the output slot.
#[derive(Serialize, Deserialize, Debug)]
pub struct Entry {
  pub id: usize,
  pub index: usize,
  pub version: usize,
}

impl Graph {
  // Maps a graph `Entry` to its flat tensor index using `node_row_ptr`.
  fn entry_index(&self, entry: &Entry) -> Result<usize> {
    self
      .node_row_ptr
      .as_ref()
      .map(|nrp| nrp[entry.id] + entry.index)
      .ok_or("Missing node_row_ptr.".into())
  }

  /// Attempt to deserialize a JSON attribute to a type `T`.
  fn get_attr<'a, T: serde::Deserialize<'a>>(&'a self, attr: &str) -> Result<T> {
    Ok(serde_json::from_value::<T>(
      self
        .attrs
        .as_ref()
        .ok_or(ErrorKind::GraphFormatError(
          "Missing graph attrs".to_string(),
        ))?.get(attr)
        .ok_or(ErrorKind::GraphFormatError(format!(
          "Missing {} attr",
          attr
        )))?.to_owned(),
    )?)
  }
}

/// A single operator (or input placeholder, when `op == "null"`) in the graph.
#[derive(Serialize, Deserialize, Debug)]
pub struct Node {
  pub op: String,
  pub name: String,
  pub inputs: Vec<Entry>,
  pub attrs: Option<HashMap<String, String>>,
  pub control_deps: Option<Vec<Entry>>,
}

// Parsed, strongly-typed view of the stringly-typed `Node::attrs` map.
struct NodeAttrs {
  func_name: String,
  num_outputs: usize,
  flatten_data: bool,
}

impl Node {
  // Extracts `func_name`/`num_outputs`/`flatten_data` from `attrs`,
  // failing with a descriptive error if any key is missing or malformed.
  fn parse_attrs(&self) -> Result<NodeAttrs> {
    let attrs = self
      .attrs
      .as_ref()
      .ok_or(format!("Missing node.attrs for `{}`", self.name))?;
    let func_name = attrs
      .get("func_name")
      .ok_or(format!("Node `{}` is missing attrs.func_name", self.name))?
      .to_string();
    let num_outputs = attrs
      .get("num_outputs")
      .ok_or(format!("Node `{}` is missing attrs.num_outputs", self.name))?
      .parse::<usize>()?;
    let flatten_data = attrs
      .get("flatten_data")
      .ok_or(format!(
        "Node `{}` is missing attrs.flatten_data",
        self.name
      ))?.parse::<u8>()?
      == 1;
    Ok(NodeAttrs {
      func_name,
      num_outputs,
      flatten_data,
    })
  }
}

impl<'a> TryFrom<&'a String> for Graph {
  type Error = Error;
  fn try_from(graph_json: &String) -> Result<Self> {
    let graph = serde_json::from_str(graph_json)?;
    Ok(graph)
  }
}

impl<'a> TryFrom<&'a str> for Graph {
  type Error = Error;
  fn try_from(graph_json: &'a str) -> Result<Self> {
    let graph = serde_json::from_str(graph_json)?;
    Ok(graph)
  }
}
/// An executor for a TVM computation graph.
///
/// # Examples
///
/// ```
/// use ndarray::Array;
///
/// let syslib = SystemLibModule::default(); // a provider of TVM functions
///
/// let mut params_bytes = Vec::new();
/// fs::File::open("graph.params").unwrap().read_to_end(&mut params_bytes).unwrap();
/// let params = tvm::runtime::load_param_dict(&params_bytes).unwrap();
///
/// let graph = Graph::try_from(&fs::read_to_string("graph.json").unwrap()).unwrap();
///
/// let mut exec = GraphExecutor::new(graph, &syslib).unwrap();
/// exec.load_params(params);
///
/// let x = Array::from_vec(vec![1f32, 2., 3., 4.]);
/// exec.set_input("data", x.into());
/// exec.run();
/// let output = exec.get_output(0).unwrap();
///
/// println!("{:#?}", Array::try_from(output).unwrap());
/// ```
pub struct GraphExecutor<'m, 't> {
  graph: Graph,
  // One closure per executable op node, in topological (node) order.
  op_execs: Vec<Box<Fn() + 'm>>,
  // One tensor per graph entry, backed by shared `Storage` views.
  tensors: Vec<Tensor<'t>>,
}

// NOTE(review): asserts the closures/tensors may move across threads;
// soundness depends on the underlying TVM functions — confirm before relying on it.
unsafe impl<'m, 't> Send for GraphExecutor<'m, 't> {}

impl<'m, 't> GraphExecutor<'m, 't> {
  // Builds an executor: allocates storages/tensors, then binds one closure per op.
  pub fn new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self> {
    let tensors = Self::setup_storages(&graph)?;
    Ok(GraphExecutor {
      op_execs: Self::setup_op_execs(&graph, lib, &tensors)?,
      tensors: tensors,
      graph: graph,
    })
  }

  /// Runs the computation graph.
  pub fn run(&self) {
    self.op_execs.iter().for_each(|op_exec| {
      op_exec();
    });
  }

  /// Allocates `Storages` for each `storage_id` and returns `Tensor`s to hold each output.
  fn setup_storages<'a>(graph: &'a Graph) -> Result<Vec<Tensor<'a>>> {
    let storage_ids = graph.get_attr::<(String, Vec<usize>)>("storage_id")?.1;
    let shapes = graph.get_attr::<(String, Vec<Vec<i64>>)>("shape")?.1;
    let dtypes = graph
      .get_attr::<(String, Vec<String>)>("dltype")?
      .1
      .iter()
      .map(|dltype| {
        if let Ok((_, dtype)) = tvm_str_to_type(CompleteStr(dltype)) {
          Ok(dtype)
        } else {
          Err(ErrorKind::GraphFormatError(format!("Invalid dltype: {}", dltype).to_string()).into())
        }
      }).collect::<Result<Vec<DataType>>>()?;

    // Align every storage to the widest element width present in the graph.
    let align = dtypes.iter().map(|dtype| dtype.bits as usize).max();
    // Each storage is sized to the largest tensor that will ever occupy it.
    let mut storage_num_bytes = vec![0usize; *storage_ids.iter().max().unwrap_or(&1) + 1];
    for (i, &storage_id) in storage_ids.iter().enumerate() {
      let dtype_size = dtypes[i].bits * dtypes[i].lanes >> 3;
      let nbytes = dtype_size * shapes[i].iter().product::<i64>() as usize;
      storage_num_bytes[storage_id] = cmp::max(nbytes, storage_num_bytes[storage_id]);
    }

    let mut storages: Vec<Storage> = storage_num_bytes
      .into_iter()
      .map(|nbytes| Storage::new(nbytes, align))
      .collect::<Result<Vec<Storage>>>()?;

    // Hand each tensor a view of its storage; the first tensor for a given
    // storage_id takes ownership via mem::replace, later ones share views.
    let tensors = izip!(storage_ids, shapes, dtypes)
      .map(|(storage_id, shape, dtype)| {
        let storage = storages[storage_id].view();
        Tensor {
          data: mem::replace(&mut storages[storage_id], storage),
          ctx: TVMContext::default(),
          dtype: dtype,
          size: shape.iter().product::<i64>() as usize,
          shape: shape,
          strides: None,
          byte_offset: 0,
        }
      }).collect();

    Ok(tensors)
  }
  /// Creates closures which represent the computation performed by this graph.
  fn setup_op_execs<M: 'm + Module>(
    graph: &Graph,
    lib: &'m M,
    tensors: &Vec<Tensor<'t>>,
  ) -> Result<Vec<Box<Fn() + 'm>>> {
    ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr.");
    let node_row_ptr = graph.node_row_ptr.as_ref().unwrap();

    let mut op_execs = Vec::new();
    for (i, node) in graph.nodes.iter().enumerate() {
      // "null" nodes are graph inputs; they have no computation.
      if node.op == "null" {
        continue;
      }
      ensure!(node.op == "tvm_op", "Only TVM ops are supported.");
      ensure!(node.attrs.is_some(), "Missing node attrs.");

      let attrs = node.parse_attrs()?;

      if attrs.func_name == "__nop" {
        continue;
      }

      let func = lib
        .get_function(&attrs.func_name)
        .ok_or(format!("Missing function {}", attrs.func_name))?;
      // Argument tensors: the node's inputs followed by its own output slots.
      let arg_indices = node
        .inputs
        .iter()
        .map(|entry| graph.entry_index(entry))
        .chain((0..attrs.num_outputs).map(|oi| Ok(node_row_ptr[i].clone() + oi)));

      let dl_tensors = arg_indices
        .map(|idx| {
          let tensor = &tensors[idx?];
          Ok(if attrs.flatten_data {
            DLTensor::from_tensor(tensor, true /* flatten */)
          } else {
            DLTensor::from(tensor)
          })
        }).collect::<Result<Vec<DLTensor>>>()
        .unwrap();
      // The closure captures raw DLTensor views; see the FIXME in `set_input`.
      let op: Box<Fn() + 'm> = box move || {
        let args = dl_tensors
          .iter()
          .map(|t| t.into())
          .collect::<Vec<TVMArgValue>>();
        func(args.as_slice());
      };
      op_execs.push(op);
    }
    Ok(op_execs)
  }

  // Copies every named parameter into the matching graph input tensor.
  pub fn load_params(&mut self, params: HashMap<String, Tensor>) {
    params.into_iter().for_each(|(name, param)| {
      self.set_input(name, param);
    })
  }

  // Copies `value` into the input tensor named `name`; logs and ignores unknown names.
  pub fn set_input<S: AsRef<str>>(&mut self, name: S, value: Tensor<'t>) {
    if let Some(idx) = self.get_input_index(name.as_ref()) {
      // TODO: consider `new_with_params` to avoid ever allocating
      let ptr = self.tensors[idx].data.as_ptr();
      let mut to_replace = self.tensors.iter_mut().filter(|t| t.data.as_ptr() == ptr);
      let mut owner = to_replace.nth(0).unwrap();
      if value.data.is_owned() {
        // FIXME: for no-copy, need setup_op_execs to not capture tensor ptr
        // mem::replace(&mut (*owner), value);
        // to_replace.for_each(|t| {
        //   panic!("replacing");
        //   t.data = owner.data.view();
        // });
        owner.copy(&value);
      } else {
        owner.copy(&value);
      }
    } else {
      println!("Unexpected input `{}`", name.as_ref());
    }
  }

  /// Returns the graph input with name `name`, if it exists.
  pub fn get_input<S: AsRef<str>>(&mut self, name: S) -> Option<&Tensor> {
    self
      .get_input_index(name.as_ref())
      .and_then(move |idx| Some(&self.tensors[idx]))
  }

  /// Returns the graph output with index `index`, if it exists.
  pub fn get_output(&self, idx: usize) -> Option<&Tensor> {
    let graph = &self.graph;
    graph.heads.get(idx).and_then(|entry| {
      graph
        .entry_index(entry)
        .map(|idx| self.tensors.get(idx))
        .unwrap_or(None)
    })
  }

  /// Returns the index for graph input with name `name`, if it exists.
  pub fn get_input_index<S: AsRef<str>>(&self, name: S) -> Option<usize> {
    let graph = &self.graph;
    (0..graph.nodes.len())
      .skip_while(|&i| graph.nodes[i].name != name.as_ref())
      .nth(0)
      .and_then(|i| {
        if graph.arg_nodes.iter().any(|&id| id == i) {
          graph.node_row_ptr.as_ref().map(|nrp| nrp[i])
        } else {
          None
        }
      })
  }
}

/// Converts a string to TVM DLDataTypeCode. @see `String2TVMType` in packed_func.h
named!(
  tvm_str_to_type<CompleteStr, DataType>,
  do_parse!(
    type_name: alpha1 >>
    bits: digit1 >>
    lanes: opt!(tuple!(tag!("x"), digit1)) >>
    (DataType {
      code: match type_name {
        CompleteStr("int") => DLDataTypeCode_kDLInt,
        CompleteStr("uint") => DLDataTypeCode_kDLUInt,
        CompleteStr("float") => DLDataTypeCode_kDLFloat,
        _ => DLDataTypeCode_kDLFloat,
      } as usize,
      bits: bits.parse::<u8>().unwrap() as usize,
      lanes: match lanes {
        Some(lanes) => lanes.1.parse::<u16>().unwrap() as usize,
        None => 1,
      },
    })
  )
);

/// Converts a bytes to String.
named!(
  name<String>,
  map_res!(length_bytes!(le_u64), |b: &[u8]| String::from_utf8(
    b.to_vec()
  ))
);

/// Parses a TVMContext
named!(
  tvm_ctx<&[u8], TVMContext>,
  do_parse!(
    device_type: le_u32 >>
    device_id: le_i32 >>
    (TVMContext { device_type: device_type as usize, device_id: device_id as usize })
  )
);

/// Parses a DataType
named!(
  data_type<&[u8], DataType>,
  do_parse!(
    code: le_u8 >>
    bits: le_u8 >>
    lanes: le_u16 >>
    (DataType { code: code as usize, bits: bits as usize, lanes: lanes as usize })
  )
);

/// Parses a Tensor from a TVM array file.
named!(
  tensor<Tensor>,
  do_parse!(
    take!(8)
      >> bits!(tag_bits!(u64, 64, 0))
      >> ctx: tvm_ctx
      >> ndim: le_u32
      >> dtype: data_type
      >> shape: count!(map!(le_i64, |sz| sz as i64), ndim as usize)
      >> length: le_i64
      >> data: take!(length)
      >> (Tensor {
        data: Storage::from(data),
        ctx: ctx,
        dtype: dtype,
        size: shape.iter().product::<i64>() as usize,
        shape: shape,
        strides: None,
        byte_offset: 0,
      })
  )
);

/// Parses a graph params dict from a params binary file.
named!(
  parse_param_dict<HashMap<String, Tensor>>,
  do_parse!(
    take!(8)
      >> bits!(tag_bits!(u64, 64, 0))
      >> names: length_count!(le_u64, name)
      >> tensors: length_count!(le_u64, tensor)
      >> (HashMap::from_iter(names.into_iter().zip(tensors.into_iter())))
  )
);
+pub fn load_param_dict(bytes: &[u8]) -> Result> { + if let Ok((remaining_bytes, param_dict)) = parse_param_dict(bytes) { + if remaining_bytes.len() > 0 { + bail!(ErrorKind::LoadGraphParamsError("extra input".to_string())) + } else { + Ok(param_dict) + } + } else { + bail!(ErrorKind::LoadGraphParamsError( + "invalid parameters file".to_string() + )) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_str_to_type() { + assert_eq!( + tvm_str_to_type(CompleteStr("float24")).unwrap().1, + DataType { + code: DLDataTypeCode_kDLFloat as usize, + bits: 24, + lanes: 1 + } + ); + assert_eq!( + tvm_str_to_type(CompleteStr("uint111x44")).unwrap().1, + DataType { + code: DLDataTypeCode_kDLUInt as usize, + bits: 111, + lanes: 44 + } + ); + } +} diff --git a/rust/src/runtime/mod.rs b/rust/src/runtime/mod.rs new file mode 100644 index 000000000000..bdf7094113d8 --- /dev/null +++ b/rust/src/runtime/mod.rs @@ -0,0 +1,25 @@ +mod allocator; +mod array; +mod module; +#[macro_use] +mod packed_func; +mod graph; +#[cfg(target_env = "sgx")] +#[macro_use] +pub mod sgx; +mod threading; +mod workspace; + +use std::os::raw::c_char; + +pub use self::{array::*, graph::*, module::*, packed_func::*, threading::*, workspace::*}; + +#[no_mangle] +pub extern "C" fn TVMAPISetLastError(cmsg: *const c_char) { + #[cfg(not(target_env = "sgx"))] + unsafe { + panic!(std::ffi::CStr::from_ptr(cmsg).to_str().unwrap()); + } + #[cfg(target_env = "sgx")] + ocall_packed!("__sgx_set_last_error__", cmsg); +} diff --git a/rust/src/runtime/module.rs b/rust/src/runtime/module.rs new file mode 100644 index 000000000000..2594756d9885 --- /dev/null +++ b/rust/src/runtime/module.rs @@ -0,0 +1,46 @@ +use std::{ + collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex, +}; + +use ffi::runtime::BackendPackedCFunc; +use runtime::packed_func::{wrap_backend_packed_func, PackedFunc}; + +pub trait Module { + fn get_function>(&self, name: S) -> Option; +} + +pub struct 
SystemLibModule; + +lazy_static! { + static ref SYSTEM_LIB_FUNCTIONS: Mutex> = + Mutex::new(HashMap::new()); +} + +impl Module for SystemLibModule { + fn get_function>(&self, name: S) -> Option { + SYSTEM_LIB_FUNCTIONS + .lock() + .unwrap() + .get(name.as_ref()) + .map(|func| wrap_backend_packed_func(func.to_owned())) + } +} + +impl Default for SystemLibModule { + fn default() -> Self { + SystemLibModule {} + } +} + +#[no_mangle] +pub extern "C" fn TVMBackendRegisterSystemLibSymbol( + cname: *const c_char, + func: BackendPackedCFunc, +) -> i32 { + let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() }; + SYSTEM_LIB_FUNCTIONS + .lock() + .unwrap() + .insert(name.to_string(), func); + return 0; +} diff --git a/rust/src/runtime/packed_func.rs b/rust/src/runtime/packed_func.rs new file mode 100644 index 000000000000..030d677329c0 --- /dev/null +++ b/rust/src/runtime/packed_func.rs @@ -0,0 +1,286 @@ +use std::{any::Any, convert::TryFrom, marker::PhantomData, os::raw::c_void}; + +use ffi::runtime::{ + BackendPackedCFunc, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLTensor, + TVMTypeCode_kArrayHandle, TVMTypeCode_kHandle, TVMValue, +}; + +use errors::*; + +pub type PackedFunc = Box TVMRetValue + Send + Sync>; + +/// Calls a packed function and returns a `TVMRetValue`. +/// +/// # Example +/// +/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)` +#[macro_export] +macro_rules! call_packed { + ($fn:expr, $($args:expr),+) => { + $fn(&[$($args.into(),)+]) + }; + ($fn:expr) => { + $fn(&Vec::new()) + }; +} + +/// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way +/// to obtain a `TVMArgValue` is automatically via `call_packed!`. 
/// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way
/// to obtain a `TVMArgValue` is automatically via `call_packed!`.
#[derive(Clone, Copy)]
pub struct TVMArgValue<'a> {
  // Ties the argument's validity to the lifetime of the value it borrows.
  _lifetime: PhantomData<&'a ()>,
  pub(crate) value: TVMValue,
  pub(crate) type_code: i64,
}

impl<'a> TVMArgValue<'a> {
  pub fn new(value: TVMValue, type_code: i64) -> Self {
    TVMArgValue {
      _lifetime: PhantomData,
      value: value,
      type_code: type_code,
    }
  }
}

/// Creates a conversion to a `TVMArgValue` for a primitive type and DLDataTypeCode.
macro_rules! impl_prim_tvm_arg {
  ($type:ty, $field:ident, $code:expr, $as:ty) => {
    impl<'a> From<$type> for TVMArgValue<'a> {
      fn from(val: $type) -> Self {
        TVMArgValue {
          value: TVMValue { $field: val as $as },
          type_code: $code as i64,
          _lifetime: PhantomData,
        }
      }
    }
  };
  ($type:ty, $field:ident, $code:expr) => {
    impl_prim_tvm_arg!($type, $field, $code, $type);
  };
  ($type:ty,v_int64) => {
    impl_prim_tvm_arg!($type, v_int64, DLDataTypeCode_kDLInt, i64);
  };
  ($type:ty,v_float64) => {
    impl_prim_tvm_arg!($type, v_float64, DLDataTypeCode_kDLFloat, f64);
  };
}

impl_prim_tvm_arg!(f32, v_float64);
impl_prim_tvm_arg!(f64, v_float64);
impl_prim_tvm_arg!(i8, v_int64);
impl_prim_tvm_arg!(u8, v_int64);
impl_prim_tvm_arg!(i32, v_int64);
impl_prim_tvm_arg!(u32, v_int64);
impl_prim_tvm_arg!(i64, v_int64);
impl_prim_tvm_arg!(u64, v_int64);
impl_prim_tvm_arg!(bool, v_int64);

/// Creates a conversion to a `TVMArgValue` for an object handle.
impl<'a, T> From<*const T> for TVMArgValue<'a> {
  fn from(ptr: *const T) -> Self {
    TVMArgValue {
      value: TVMValue {
        v_handle: ptr as *mut T as *mut c_void,
      },
      type_code: TVMTypeCode_kArrayHandle as i64,
      _lifetime: PhantomData,
    }
  }
}

/// Creates a conversion to a `TVMArgValue` for a mutable object handle.
impl<'a, T> From<*mut T> for TVMArgValue<'a> {
  fn from(ptr: *mut T) -> Self {
    TVMArgValue {
      value: TVMValue {
        v_handle: ptr as *mut c_void,
      },
      type_code: TVMTypeCode_kHandle as i64,
      _lifetime: PhantomData,
    }
  }
}

impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> {
  fn from(arr: &'a mut DLTensor) -> Self {
    TVMArgValue {
      value: TVMValue {
        v_handle: arr as *mut _ as *mut c_void,
      },
      type_code: TVMTypeCode_kArrayHandle as i64,
      _lifetime: PhantomData,
    }
  }
}

impl<'a> From<&'a DLTensor> for TVMArgValue<'a> {
  fn from(arr: &'a DLTensor) -> Self {
    TVMArgValue {
      value: TVMValue {
        v_handle: arr as *const _ as *mut DLTensor as *mut c_void,
      },
      type_code: TVMTypeCode_kArrayHandle as i64,
      _lifetime: PhantomData,
    }
  }
}

/// An owned TVMPODValue. Can be converted from a variety of primitive and object types.
/// Can be downcasted using `try_from` if it contains the desired type.
///
/// # Example
///
/// ```
/// let a = 42u32;
/// let b: i64 = TVMRetValue::from(a).try_into().unwrap();
///
/// let s = "hello, world!";
/// let t: TVMRetValue = s.into();
/// assert_eq!(String::try_from(t).unwrap(), s);
/// ```
pub struct TVMRetValue {
  /// A primitive return value, if any.
  prim_value: u64,
  /// An object return value, if any.
  box_value: Box<Any>,
  /// The DLDataTypeCode which determines whether `prim_value` or `box_value` is in use.
  type_code: i64,
}

#[cfg(target_env = "sgx")]
impl TVMRetValue {
  // Converts a raw TVMValue + type code from the C side into an owned value.
  // NOTE(review): the float case (code 2) uses a numeric `as u64` cast, not a
  // bit-cast, so the f64 -> u64 -> f64 round trip through into_tvm_value
  // truncates fractional values — confirm this is intended for the SGX path.
  pub(crate) fn from_tvm_value(value: TVMValue, type_code: i64) -> Self {
    unsafe {
      Self {
        prim_value: match type_code {
          0 | 1 => value.v_int64 as u64,
          2 => value.v_float64 as u64,
          3 | 7 | 8 | 9 | 10 => value.v_handle as u64,
          11 | 12 => value.v_str as u64,
          _ => 0,
        } as u64,
        box_value: box (),
        type_code: type_code,
      }
    }
  }

  // Converts back to a raw (TVMValue, type_code) pair for the C side;
  // handle/string variants leak the box to transfer ownership across the FFI.
  pub fn into_tvm_value(self) -> (TVMValue, i64) {
    let val = match self.type_code {
      0 | 1 => TVMValue {
        v_int64: self.prim_value.clone() as i64,
      },
      2 => TVMValue {
        v_float64: self.prim_value.clone() as f64,
      },
      3 | 7 | 8 | 9 | 10 => TVMValue {
        v_handle: Box::into_raw(self.box_value) as *mut c_void,
      },
      11 | 12 => TVMValue {
        v_str: Box::into_raw(self.box_value) as *const _,
      },
      _ => unreachable!(),
    };
    (val, self.type_code)
  }
}

impl Default for TVMRetValue {
  fn default() -> Self {
    TVMRetValue {
      prim_value: 0,
      box_value: box (),
      type_code: 0,
    }
  }
}

// Implements From<$type> and TryFrom<TVMRetValue> for a primitive stored in `prim_value`.
macro_rules! impl_prim_ret_value {
  ($type:ty, $code:expr) => {
    impl From<$type> for TVMRetValue {
      fn from(val: $type) -> Self {
        TVMRetValue {
          prim_value: val as u64,
          box_value: box (),
          type_code: $code,
        }
      }
    }
    impl TryFrom<TVMRetValue> for $type {
      type Error = Error;
      fn try_from(ret: TVMRetValue) -> Result<$type> {
        if ret.type_code == $code {
          Ok(ret.prim_value as $type)
        } else {
          bail!(ErrorKind::TryFromTVMRetValueError(
            stringify!($type).to_string(),
            ret.type_code
          ))
        }
      }
    }
  };
}

// NOTE(review): the next macro definition continues beyond this chunk.
macro_rules!
impl_boxed_ret_value { + ($type:ty, $code:expr) => { + impl From<$type> for TVMRetValue { + fn from(val: $type) -> Self { + TVMRetValue { + prim_value: 0, + box_value: box val, + type_code: $code, + } + } + } + impl TryFrom for $type { + type Error = Error; + fn try_from(ret: TVMRetValue) -> Result<$type> { + if let Ok(val) = ret.box_value.downcast::<$type>() { + Ok(*val) + } else { + bail!(ErrorKind::TryFromTVMRetValueError( + stringify!($type).to_string(), + ret.type_code + )) + } + } + } + }; +} + +impl_prim_ret_value!(i8, 0); +impl_prim_ret_value!(u8, 1); +impl_prim_ret_value!(i16, 0); +impl_prim_ret_value!(u16, 1); +impl_prim_ret_value!(i32, 0); +impl_prim_ret_value!(u32, 1); +impl_prim_ret_value!(f32, 2); +impl_prim_ret_value!(i64, 0); +impl_prim_ret_value!(u64, 1); +impl_prim_ret_value!(f64, 2); +impl_prim_ret_value!(isize, 0); +impl_prim_ret_value!(usize, 1); +impl_boxed_ret_value!(String, 11); + +// @see `WrapPackedFunc` in `llvm_module.cc`. +pub(super) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc { + box move |args: &[TVMArgValue]| { + func( + args + .iter() + .map(|ref arg| arg.value) + .collect::>() + .as_ptr(), + args + .iter() + .map(|ref arg| arg.type_code as i32) + .collect::>() + .as_ptr() as *const i32, + args.len() as i32, + ); + TVMRetValue::default() + } +} diff --git a/rust/src/runtime/sgx.rs b/rust/src/runtime/sgx.rs new file mode 100644 index 000000000000..bf9d54a4af65 --- /dev/null +++ b/rust/src/runtime/sgx.rs @@ -0,0 +1,82 @@ +use std::{ + ffi::CString, + os::raw::{c_char, c_int}, +}; + +use errors::Result; +use ffi::runtime::TVMValue; +use runtime::{threading::sgx_join_threads, SystemLibModule, TVMArgValue, TVMRetValue}; + +pub use runtime::threading::tvm_run_worker as run_worker; + +#[macro_export] +macro_rules! 
tvm_ocall { + ($func: expr) => { + match $func { + 0 => Ok(()), + err => Err(format!("SGX error: {}", err)), + } + }; +} + +pub type SgxStatus = u32; + +#[cfg(target_env = "sgx")] +extern "C" { + fn tvm_ocall_packed_func( + name: *const c_char, + arg_values: *const TVMValue, + type_codes: *const c_int, + num_args: c_int, + ret_val: *mut TVMValue, + ret_type_code: *mut c_int, + ) -> SgxStatus; +} + +pub fn ocall_packed_func>(fn_name: S, args: &[TVMArgValue]) -> Result { + let mut ret_val = TVMValue { v_int64: 0 }; + let ret_type_code = 0i64; + unsafe { + tvm_ocall!(tvm_ocall_packed_func( + CString::new(fn_name.as_ref()).unwrap().as_ptr(), + args + .iter() + .map(|ref arg| arg.value) + .collect::>() + .as_ptr(), + args + .iter() + .map(|ref arg| arg.type_code as i32) + .collect::>() + .as_ptr() as *const i32, + args.len() as i32, + &mut ret_val as *mut TVMValue, + &mut (ret_type_code as i32) as *mut c_int, + ))?; + } + Ok(TVMRetValue::from_tvm_value(ret_val, ret_type_code as i64)) +} + +#[macro_export] +macro_rules! 
ocall_packed { + ($fn_name:expr, $($args:expr),+) => { + ::runtime::sgx::ocall_packed_func($fn_name, &[$($args.into(),)+]) + .expect(concat!("Error calling `", $fn_name, "`")) + }; + ($fn_name:expr) => { + ::runtime::sgx::ocall_packed_func($fn_name, &Vec::new()) + .expect(concat!("Error calling `", $fn_name, "`")) + } +} + +pub fn shutdown() { + if env!("TVM_NUM_THREADS") != "0" { + sgx_join_threads() + } +} + +impl Drop for SystemLibModule { + fn drop(&mut self) { + shutdown() + } +} diff --git a/rust/src/runtime/threading.rs b/rust/src/runtime/threading.rs new file mode 100644 index 000000000000..c0d6221c91b7 --- /dev/null +++ b/rust/src/runtime/threading.rs @@ -0,0 +1,334 @@ +use std::{ + os::raw::{c_int, c_void}, + sync::{ + atomic::{AtomicUsize, Ordering, ATOMIC_USIZE_INIT}, + Arc, Barrier, + }, +}; + +#[cfg(not(target_env = "sgx"))] +use num_cpus; +#[cfg(not(target_env = "sgx"))] +use std::{ + env, + thread::{self, JoinHandle}, +}; + +#[cfg(target_env = "sgx")] +use std::{collections::VecDeque, ptr, sync::Mutex}; + +use bounded_spsc_queue::{self, Producer}; + +use super::super::errors::*; +use ffi::runtime::TVMParallelGroupEnv; + +#[cfg(target_env = "sgx")] +use super::{TVMArgValue, TVMRetValue}; + +type FTVMParallelLambda = + extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32; + +/// Holds a parallel job request made by a TVM library function. +struct Job { + cb: FTVMParallelLambda, + cdata: *const c_void, + req_num_tasks: usize, + pending: Arc, +} + +impl Job { + /// Splits this job into a number of `Task`s which can be scheduled. 
+ fn tasks(&self, num_workers: usize) -> Vec { + let num_tasks = if self.req_num_tasks == 0 { + num_workers + } else { + self.req_num_tasks.min(num_workers) + }; + self.pending.store(num_tasks, Ordering::SeqCst); + + let barrier = Arc::new(Barrier::new(num_tasks)); + + (0..num_tasks) + .map(move |i| Task { + id: i, + flambda: self.cb, + penv: TVMParallelGroupEnv { + sync_handle: &Arc::clone(&barrier) as *const _ as *mut c_void, + num_task: num_tasks as i32, + }, + cdata: self.cdata, + pending: Arc::clone(&self.pending), + }).collect() + } + + /// Waits for all tasks in this `Job` to be completed. + fn wait(&self) -> Result<()> { + while self.pending.load(Ordering::Acquire) > 0 { + #[cfg(not(target_env = "sgx"))] + thread::yield_now(); + } + Ok(()) + } +} + +/// A chunk of work requested by a TVM function. +struct Task { + id: usize, + flambda: FTVMParallelLambda, + penv: TVMParallelGroupEnv, + cdata: *const c_void, + pending: Arc, +} +unsafe impl Send for Task {} +unsafe impl Sync for Task {} + +impl FnOnce<()> for Task { + type Output = i32; + extern "rust-call" fn call_once(self, _args: ()) -> Self::Output { + let status = (self.flambda)(self.id, &self.penv as *const _, self.cdata); + self.pending.fetch_sub(1, Ordering::AcqRel); + status + } +} + +#[derive(Default)] +struct Threads { + #[allow(unused)] + #[cfg(not(target_env = "sgx"))] + handles: Vec>, + queues: Vec>, +} + +impl<'a> Threads { + #[cfg(not(target_env = "sgx"))] + fn launch) + 'static + Copy>( + num_threads: usize, + cb: F, + ) -> Self { + let (handles, queues) = (0..num_threads) + .map(|_| { + let (p, c) = bounded_spsc_queue::make(2); + let handle = thread::spawn(move || cb(c.into())); + (handle, p) + }).unzip(); + Threads { + handles: handles, + queues: queues, + } + } + + #[cfg(target_env = "sgx")] + fn launch) + 'static + Copy>( + num_threads: usize, + _cb: F, + ) -> Self { + let mut consumer_queues = SGX_QUEUES.lock().unwrap(); + let queues = (0..num_threads) + .map(|_| { + let (p, c) = 
bounded_spsc_queue::make(2); + consumer_queues.push_back(c.into()); + p + }).collect(); + ocall_packed!("__sgx_thread_group_launch__", num_threads as u64); + Threads { queues: queues } + } +} + +struct ThreadPool { + num_workers: usize, + #[allow(unused)] + threads: Threads, +} + +thread_local!(static THREAD_POOL: ThreadPool = ThreadPool::new()); + +impl ThreadPool { + fn new() -> Self { + let num_workers = max_concurrency(); + ThreadPool { + num_workers: num_workers, + threads: Threads::launch(num_workers, ThreadPool::run_worker), + } + } + + fn launch(&self, job: Job) { + let mut tasks = job.tasks(self.num_workers + 1); + + for (i, task) in tasks.split_off(1).into_iter().enumerate() { + self.threads.queues[i].push(task); + } + + tasks.pop().unwrap()(); + job.wait().unwrap(); + } + + fn run_worker(queue: Consumer) { + loop { + let task = queue.pop(); + let result = task(); + if result == ::min_value() { + break; + } else if result != 0 { + panic!("Error running task."); + } + } + } +} + +// Send + Sync wrapper for bounded_spsc_queue::Consumer +struct Consumer { + consumer: bounded_spsc_queue::Consumer, +} +impl From> for Consumer { + fn from(c: bounded_spsc_queue::Consumer) -> Self { + Consumer { consumer: c } + } +} +impl Consumer { + fn pop(&self) -> T { + self.consumer.pop() + } +} +unsafe impl Send for Consumer {} +unsafe impl Sync for Consumer {} + +#[cfg(target_env = "sgx")] +lazy_static! { + /// Holds tasks for untrusted threads which re-enter the enclave to execute. 
+ static ref SGX_QUEUES: Mutex>> = Mutex::new(VecDeque::new()); +} + +#[cfg(all(not(target_arch = "wasm32"), not(target_env = "sgx")))] +fn max_concurrency() -> usize { + if let Ok(threads_str) = env::var("TVM_NUM_THREADS").or(env::var("OMP_NUM_THREADS")) { + if let Ok(threads) = usize::from_str_radix(&threads_str, 10) { + return threads; + } + } + num_cpus::get_physical() +} + +#[cfg(target_env = "sgx")] +fn max_concurrency() -> usize { + usize::from_str_radix(env!("TVM_NUM_THREADS"), 10).unwrap_or(1) +} + +#[cfg(target_arch = "wasm32")] +fn max_concurrency() -> usize { + 0 // wasm doesn't support threads yet +} + +#[cfg(target_env = "sgx")] +pub fn tvm_run_worker(_args: &[TVMArgValue]) -> TVMRetValue { + let q = { + let mut qs = SGX_QUEUES.lock().unwrap(); + qs.pop_front() + // `qs: MutexGuard` needs to be dropped here since `run_worker` won't return + }; + if let Some(q) = q { + ThreadPool::run_worker(q); + } + TVMRetValue::default() +} + +#[no_mangle] +pub extern "C" fn TVMBackendParallelLaunch( + cb: FTVMParallelLambda, + cdata: *const c_void, + num_task: usize, +) -> c_int { + if max_concurrency() == 0 { + let penv = TVMParallelGroupEnv { + sync_handle: 0 as *mut c_void, + num_task: 1, + }; + cb(0, &penv as *const _, cdata); + } else { + THREAD_POOL.with(|pool| { + pool.launch(Job { + cb: cb, + cdata: cdata, + req_num_tasks: num_task, + pending: Arc::new(ATOMIC_USIZE_INIT), + }); + }); + } + return 0; +} + +#[cfg(target_env = "sgx")] +pub(crate) fn sgx_join_threads() { + extern "C" fn poison_pill( + _task_id: usize, + _penv: *const TVMParallelGroupEnv, + _cdata: *const c_void, + ) -> i32 { + ::min_value() + } + + THREAD_POOL.with(|pool| { + pool.launch(Job { + cb: poison_pill, + cdata: ptr::null(), + req_num_tasks: 0, + pending: Arc::new(ATOMIC_USIZE_INIT), + }); + }); + ocall_packed!("__sgx_thread_group_join__", 0); +} + +// @see https://github.com/dmlc/tvm/issues/988 for information on why this function is used. 
+#[no_mangle] +pub extern "C" fn TVMBackendParallelBarrier(_task_id: usize, penv: *const TVMParallelGroupEnv) { + let barrier: &Arc = unsafe { &*((*penv).sync_handle as *const Arc) }; + barrier.wait(); +} + +#[cfg(test)] +mod tests { + use std::{ptr, thread, time::Duration}; + + use super::*; + + #[test] + fn test_max_concurrency() { + env::set_var("TVM_NUM_THREADS", "42"); + env::set_var("OMP_NUM_THREADS", "24"); + assert_eq!(max_concurrency(), 42); + env::remove_var("TVM_NUM_THREADS"); + assert_eq!(max_concurrency(), 24); + } + + extern "C" fn flambda( + task_id: usize, + penv: *const TVMParallelGroupEnv, + cdata: *const c_void, + ) -> i32 { + if cdata == ptr::null() { + return 0; + } + unsafe { + let &(ref counter, ref task_ids_sum) = &*(cdata as *const (AtomicUsize, AtomicUsize)); + thread::sleep(Duration::from_millis(50 * task_id as u64)); + counter.fetch_add(1, Ordering::SeqCst); + task_ids_sum.fetch_add(task_id, Ordering::SeqCst); + assert_eq!((*penv).num_task, 3); + } + 0 + } + + #[test] + fn test_parallel_launch() { + TVMBackendParallelLaunch(flambda, ptr::null(), 6); + let counter = ATOMIC_USIZE_INIT; + let task_ids_sum = ATOMIC_USIZE_INIT; + let cdata = (counter, task_ids_sum); + let num_tasks = 3; + TVMBackendParallelLaunch(flambda, &cdata as *const _ as *const c_void, num_tasks); + assert_eq!(cdata.0.load(Ordering::SeqCst), num_tasks); + assert_eq!( + cdata.1.load(Ordering::SeqCst), + (0..num_tasks).sum::() + ); + } +} diff --git a/rust/src/runtime/workspace.rs b/rust/src/runtime/workspace.rs new file mode 100644 index 000000000000..d0e6d8c89255 --- /dev/null +++ b/rust/src/runtime/workspace.rs @@ -0,0 +1,119 @@ +use std::{ + cell::RefCell, + os::raw::{c_int, c_void}, + ptr, +}; + +use super::allocator::Allocation; +use errors::*; + +const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h` + +struct WorkspacePool { + workspaces: Vec, + free: Vec, + in_use: Vec, +} + +impl WorkspacePool { + fn new() -> Self { + WorkspacePool { 
+ workspaces: Vec::new(), + free: Vec::new(), + in_use: Vec::new(), + } + } + + fn alloc_new(&mut self, size: usize) -> Result<*mut u8> { + self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?); + self.in_use.push(self.workspaces.len() - 1); + Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr()) + } + + fn alloc(&mut self, size: usize) -> Result<*mut u8> { + if self.free.len() == 0 { + return self.alloc_new(size); + } + let idx = self + .free + .iter() + .fold(None, |cur_ws_idx: Option, &idx| { + let ws_size = self.workspaces[idx].size(); + if !ws_size >= size { + return cur_ws_idx; + } + cur_ws_idx.or(Some(idx)).and_then(|cur_idx| { + let cur_size = self.workspaces[cur_idx].size(); + Some(match ws_size <= cur_size { + true => idx, + false => cur_idx, + }) + }) + }); + match idx { + Some(idx) => { + self.free.remove_item(&idx).unwrap(); + self.in_use.push(idx); + Ok(self.workspaces[idx].as_mut_ptr()) + } + None => self.alloc_new(size), + } + } + + fn free(&mut self, ptr: *mut u8) -> Result<()> { + let mut ws_idx = None; + for i in 0..self.in_use.len() { + let idx = self.in_use[i]; + if self.workspaces[idx].as_mut_ptr() == ptr { + self.in_use.remove(i); + ws_idx = Some(idx); + break; + } + } + Ok( + self + .free + .push(ws_idx.ok_or("Tried to free nonexistent workspace.")?), + ) + } +} + +thread_local!(static WORKSPACE_POOL: RefCell = RefCell::new(WorkspacePool::new())); + +const WORKSPACE_PAGE_SIZE: usize = 4 << 10; + +#[no_mangle] +pub extern "C" fn TVMBackendAllocWorkspace( + _device_type: c_int, + _device_id: c_int, + size: u64, + _dtype_code_hint: c_int, + _dtype_bits_hint: c_int, +) -> *mut c_void { + let nbytes = if size == 0 { + WORKSPACE_PAGE_SIZE + } else { + size as usize + }; + WORKSPACE_POOL.with(|pool_cell| { + pool_cell + .borrow_mut() + .alloc(nbytes as usize) + .unwrap_or(ptr::null_mut()) as *mut c_void + }) +} + +#[no_mangle] +pub extern "C" fn TVMBackendFreeWorkspace( + _device_type: c_int, + _device_id: c_int, + ptr: *mut c_void, +) 
-> c_int { + WORKSPACE_POOL.with(|pool_cell| { + (match pool_cell.borrow_mut().free(ptr as *mut u8) { + Ok(()) => 0, + Err(_) => -1, + }) as c_int + }); + return 0; +} diff --git a/rust/tests/.gitignore b/rust/tests/.gitignore new file mode 100644 index 000000000000..811076739bfa --- /dev/null +++ b/rust/tests/.gitignore @@ -0,0 +1,3 @@ +*.json +*.params +*.o diff --git a/rust/tests/build_model.py b/rust/tests/build_model.py new file mode 100644 index 000000000000..e0b90495159f --- /dev/null +++ b/rust/tests/build_model.py @@ -0,0 +1,53 @@ +"""Builds a simple NNVM graph for testing.""" + +from os import path as osp + +import nnvm +from nnvm import sym +from nnvm.compiler import graph_util +from nnvm.testing import init +import numpy as np +import tvm + +CWD = osp.dirname(osp.abspath(osp.expanduser(__file__))) + + +def _get_model(dshape): + data = sym.Variable('data', shape=dshape) + fc1 = sym.dense(data, units=dshape[-1]*2, use_bias=True) + left, right = sym.split(fc1, indices_or_sections=2, axis=1) + return sym.Group(((left + 1), (right - 1))) + + +def _init_params(graph, input_shapes, initializer=init.Xavier(), seed=10): + if isinstance(graph, sym.Symbol): + graph = nnvm.graph.create(graph) + ishapes, _ = graph_util.infer_shape(graph, **input_shapes) + param_shapes = dict(zip(graph.index.input_names, ishapes)) + np.random.seed(seed) + params = {} + for param, shape in param_shapes.items(): + if param in {'data', 'label'} or not shape: + continue + init_value = np.empty(shape).astype('float32') + initializer(param, init_value) + params[param] = tvm.nd.array(init_value) + return params + +def main(): + dshape = (32, 16) + net = _get_model(dshape) + ishape_dict = {'data': dshape} + params = _init_params(net, ishape_dict) + graph, lib, params = nnvm.compiler.build(net, 'llvm', + shape=ishape_dict, + params=params, + dtype='float32') + + with open(osp.join(CWD, 'graph.json'), 'w') as f_resnet: + f_resnet.write(graph.json()) + with open(osp.join(CWD, 'graph.params'), 
'wb') as f_params: + f_params.write(nnvm.compiler.save_param_dict(params)) + +if __name__ == '__main__': + main() diff --git a/rust/tests/test_graph_serde.rs b/rust/tests/test_graph_serde.rs new file mode 100644 index 000000000000..a596544212ca --- /dev/null +++ b/rust/tests/test_graph_serde.rs @@ -0,0 +1,38 @@ +#![feature(try_from)] + +extern crate serde; +extern crate serde_json; + +extern crate tvm; + +use std::{convert::TryFrom, fs, io::Read}; + +use tvm::runtime::Graph; + +#[test] +fn test_load_graph() { + let mut params_bytes = Vec::new(); + fs::File::open(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.params")) + .expect("Could not find TVM graph. Did you run `tests/build_model.py`?") + .read_to_end(&mut params_bytes) + .unwrap(); + let _params = tvm::runtime::load_param_dict(¶ms_bytes); + + let graph = Graph::try_from( + &fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.json")).unwrap(), + ).unwrap(); + + assert_eq!(graph.nodes[3].op, "tvm_op"); + assert_eq!( + graph.nodes[3] + .attrs + .as_ref() + .unwrap() + .get("func_name") + .unwrap(), + "fuse_dense" + ); + assert_eq!(graph.nodes[5].inputs[0].index, 0); + assert_eq!(graph.nodes[6].inputs[0].index, 1); + assert_eq!(graph.heads.len(), 2); +} diff --git a/rust/tests/test_nnvm/Cargo.toml b/rust/tests/test_nnvm/Cargo.toml new file mode 100644 index 000000000000..7e6ce5fb729c --- /dev/null +++ b/rust/tests/test_nnvm/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "test-nnvm" +version = "0.0.0" +license = "Apache-2.0" +authors = ["Nick Hynes "] + +[dependencies] +ndarray = "0.11.2" +tvm = { path = "../../" } +serde = "1.0.59" +serde_json = "1.0.17" + +[build-dependencies] +ar = "0.6.0" diff --git a/rust/tests/test_nnvm/build.rs b/rust/tests/test_nnvm/build.rs new file mode 100644 index 000000000000..cb3a4e0d574d --- /dev/null +++ b/rust/tests/test_nnvm/build.rs @@ -0,0 +1,28 @@ +extern crate ar; + +use std::{env, path::PathBuf, process::Command}; + +use ar::Builder; +use std::fs::File; + 
+fn main() { + let out_dir = env::var("OUT_DIR").unwrap(); + + let output = Command::new(concat!( + env!("CARGO_MANIFEST_DIR"), + "/src/build_test_graph.py" + )).arg(&out_dir) + .output() + .expect("Failed to execute command"); + if output.stderr.len() > 0 { + panic!(String::from_utf8(output.stderr).unwrap()); + } + + let in_path: PathBuf = [&out_dir, "graph.o"].iter().collect(); + let out_path: PathBuf = [&out_dir, "libgraph.a"].iter().collect(); + let mut builder = Builder::new(File::create(out_path.to_str().unwrap()).unwrap()); + builder.append_path(in_path.to_str().unwrap()).unwrap(); + + println!("cargo:rustc-link-lib=static=graph"); + println!("cargo:rustc-link-search=native={}", out_dir); +} diff --git a/rust/tests/test_nnvm/src/build_test_graph.py b/rust/tests/test_nnvm/src/build_test_graph.py new file mode 100755 index 000000000000..429cc2128931 --- /dev/null +++ b/rust/tests/test_nnvm/src/build_test_graph.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 + +"""Builds a simple NNVM graph for testing.""" + +from os import path as osp +import sys + +import nnvm +from nnvm import sym +from nnvm.compiler import graph_util +from nnvm.testing import init +import numpy as np +import tvm + + +def _get_model(dshape): + data = sym.Variable('data', shape=dshape) + fc = sym.dense(data, units=dshape[-1]*2, use_bias=True) + left, right = sym.split(fc, indices_or_sections=2, axis=1) + return sym.Group(((left + 1), (right - 1), fc)) + + +def _init_params(graph, input_shapes, initializer=init.Xavier(), seed=10): + if isinstance(graph, sym.Symbol): + graph = nnvm.graph.create(graph) + ishapes, _ = graph_util.infer_shape(graph, **input_shapes) + param_shapes = dict(zip(graph.index.input_names, ishapes)) + np.random.seed(seed) + params = {} + for param, shape in param_shapes.items(): + if param in {'data', 'label'} or not shape: + continue + + init_value = np.arange(np.product(shape), 0, -1).reshape(*shape).astype('float32') + if param.endswith('_bias'): + params[param] = 
tvm.nd.array(init_value) + continue + + init_value = np.empty(shape).astype('float32') + initializer(param, init_value) + # init_value /= init_value.sum() + 1e-10 + params[param] = tvm.nd.array(init_value) + return params + +def main(): + dshape = (4, 8) + net = _get_model(dshape) + ishape_dict = {'data': dshape} + params = _init_params(net, ishape_dict) + graph, lib, params = nnvm.compiler.build(net, 'llvm --system-lib', + shape=ishape_dict, + params=params, + dtype='float32') + + out_dir = sys.argv[1] + lib.save(osp.join(sys.argv[1], 'graph.o')) + with open(osp.join(out_dir, 'graph.json'), 'w') as f_resnet: + f_resnet.write(graph.json()) + with open(osp.join(out_dir, 'graph.params'), 'wb') as f_params: + f_params.write(nnvm.compiler.save_param_dict(params)) + +if __name__ == '__main__': + main() diff --git a/rust/tests/test_nnvm/src/main.rs b/rust/tests/test_nnvm/src/main.rs new file mode 100644 index 000000000000..0953ce2a2603 --- /dev/null +++ b/rust/tests/test_nnvm/src/main.rs @@ -0,0 +1,80 @@ +#![feature(try_from)] + +#[macro_use] +extern crate ndarray; +extern crate serde; +extern crate serde_json; + +extern crate tvm; +use std::{collections::HashMap, convert::TryFrom, fs, io::Read}; + +use ndarray::Array; +use tvm::runtime::{Graph, GraphExecutor, SystemLibModule, Tensor}; + +const BATCH_SIZE: usize = 4; +const IN_DIM: usize = 8; + +macro_rules! 
check_sum { + ($e:expr, $a:ident, $b:ident) => { + let a = Array::try_from($e.get_input(stringify!($a)).unwrap()).unwrap(); + check_sum!(a, $b); + }; + ($e:expr, $a:expr, $b:ident) => { + let a = Array::try_from($e.get_output($a).unwrap()).unwrap(); + check_sum!(a, $b); + }; + ($a:ident, $b:ident) => { + let a_sum: f32 = $a.scalar_sum(); + let b_sum: f32 = $b.scalar_sum(); + assert!((a_sum - b_sum).abs() < 1e-2, "{} != {}", a_sum, b_sum); + }; +} + +fn main() { + let syslib = SystemLibModule::default(); + + let mut params_bytes = Vec::new(); + fs::File::open(concat!(env!("OUT_DIR"), "/graph.params")) + .unwrap() + .read_to_end(&mut params_bytes) + .unwrap(); + let params = tvm::runtime::load_param_dict(¶ms_bytes) + .unwrap() + .into_iter() + .map(|(k, v)| (k, v.to_owned())) + .collect::>>(); + + let graph = + Graph::try_from(&fs::read_to_string(concat!(env!("OUT_DIR"), "/graph.json")).unwrap()).unwrap(); + let mut exec = GraphExecutor::new(graph, &syslib).unwrap(); + + let x = Array::from_shape_vec( + (BATCH_SIZE, IN_DIM), + (0..BATCH_SIZE * IN_DIM) + .map(|x| x as f32) + .collect::>(), + ).unwrap(); + let w = Array::try_from(params.get("dense0_weight").unwrap()) + .unwrap() + .into_shape((IN_DIM * 2, IN_DIM)) + .unwrap(); + let b = Array::try_from(params.get("dense0_bias").unwrap()).unwrap(); + let dense = x.dot(&w.t()) + &b; + let left = dense.slice(s![.., 0..IN_DIM]); + let right = dense.slice(s![.., IN_DIM..]); + let expected_o0 = &left + 1f32; + let expected_o1 = &right - 1f32; + + exec.load_params(params); + exec.set_input("data", x.clone().into()); + + check_sum!(exec, data, x); + check_sum!(exec, dense0_weight, w); + check_sum!(exec, dense0_bias, b); + + exec.run(); + + check_sum!(exec, 0, expected_o0); + check_sum!(exec, 1, expected_o1); + check_sum!(exec, 2, dense); +} diff --git a/rust/tests/test_tvm_basic/Cargo.toml b/rust/tests/test_tvm_basic/Cargo.toml new file mode 100644 index 000000000000..bd4193bcb8fb --- /dev/null +++ 
b/rust/tests/test_tvm_basic/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "test-tvm-basic" +version = "0.0.0" +license = "Apache-2.0" +authors = ["Nick Hynes "] + +[dependencies] +ndarray = "0.11.2" +tvm = { path = "../../" } + +[build-dependencies] +ar = "0.6.0" diff --git a/rust/tests/test_tvm_basic/build.rs b/rust/tests/test_tvm_basic/build.rs new file mode 100644 index 000000000000..778dd1cab1ca --- /dev/null +++ b/rust/tests/test_tvm_basic/build.rs @@ -0,0 +1,28 @@ +extern crate ar; + +use std::{env, path::PathBuf, process::Command}; + +use ar::Builder; +use std::fs::File; + +fn main() { + let out_dir = env::var("OUT_DIR").unwrap(); + + let output = Command::new(concat!( + env!("CARGO_MANIFEST_DIR"), + "/src/build_test_lib.py" + )).arg(&out_dir) + .output() + .expect("Failed to execute command"); + if output.stderr.len() > 0 { + panic!(String::from_utf8(output.stderr).unwrap()); + } + + let in_path: PathBuf = [&out_dir, "test.o"].iter().collect(); + let out_path: PathBuf = [&out_dir, "libtest.a"].iter().collect(); + let mut builder = Builder::new(File::create(out_path.to_str().unwrap()).unwrap()); + builder.append_path(in_path.to_str().unwrap()).unwrap(); + + println!("cargo:rustc-link-lib=static=test"); + println!("cargo:rustc-link-search=native={}", out_dir); +} diff --git a/rust/tests/test_tvm_basic/src/build_test_lib.py b/rust/tests/test_tvm_basic/src/build_test_lib.py new file mode 100755 index 000000000000..7289a778fcec --- /dev/null +++ b/rust/tests/test_tvm_basic/src/build_test_lib.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python3 + +"""Prepares a simple TVM library for testing.""" + +from os import path as osp +import sys + +import tvm + +def main(): + n = tvm.var('n') + A = tvm.placeholder((n,), name='A') + B = tvm.placeholder((n,), name='B') + C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') + s = tvm.create_schedule(C.op) + s[C].parallel(s[C].op.axis[0]) + print(tvm.lower(s, [A, B, C], simple_mode=True)) + tvm.build(s, [A, B, C], 'llvm 
--system-lib').save(osp.join(sys.argv[1], 'test.o')) + +if __name__ == '__main__': + main() diff --git a/rust/tests/test_tvm_basic/src/main.rs b/rust/tests/test_tvm_basic/src/main.rs new file mode 100644 index 000000000000..b6c11451d12a --- /dev/null +++ b/rust/tests/test_tvm_basic/src/main.rs @@ -0,0 +1,25 @@ +extern crate ndarray; +#[macro_use] +extern crate tvm; + +use ndarray::Array; +use tvm::{ + ffi::runtime::DLTensor, + runtime::{Module, SystemLibModule}, +}; + +fn main() { + let syslib = SystemLibModule::default(); + let add = syslib + .get_function("default_function") + .expect("main function not found"); + let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]); + let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]); + let mut c = Array::from_vec(vec![0f32; 4]); + let e = Array::from_vec(vec![2f32, 2., 4., 4.]); + let mut a_dl: DLTensor = (&mut a).into(); + let mut b_dl: DLTensor = (&mut b).into(); + let mut c_dl: DLTensor = (&mut c).into(); + call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl); + assert!(c.all_close(&e, 1e-8f32)); +}