From 90d1f2ce070ae6129ca22e819968c98349160bee Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Tue, 14 Aug 2018 03:40:07 +0000 Subject: [PATCH 01/10] Add rust runtime --- src/runtime/rust/.gitignore | 3 + src/runtime/rust/.rustfmt.toml | 57 +++ src/runtime/rust/.travis.yml | 5 + src/runtime/rust/Cargo.toml | 30 ++ src/runtime/rust/src/errors.rs | 30 ++ src/runtime/rust/src/lib.rs | 36 ++ src/runtime/rust/src/runtime/allocator.rs | 46 ++ src/runtime/rust/src/runtime/array.rs | 393 +++++++++++++++++ src/runtime/rust/src/runtime/graph.rs | 408 ++++++++++++++++++ src/runtime/rust/src/runtime/mod.rs | 19 + src/runtime/rust/src/runtime/module.rs | 46 ++ src/runtime/rust/src/runtime/packed_func.rs | 199 +++++++++ src/runtime/rust/src/runtime/threading.rs | 299 +++++++++++++ src/runtime/rust/src/runtime/workspace.rs | 103 +++++ src/runtime/rust/tests/.gitignore | 3 + src/runtime/rust/tests/test_graph_serde.rs | 39 ++ src/runtime/rust/tests/test_nnvm/Cargo.toml | 17 + src/runtime/rust/tests/test_nnvm/src/main.rs | 84 ++++ .../rust/tests/test_tvm_basic/Cargo.toml | 12 + .../rust/tests/test_tvm_basic/src/main.rs | 25 ++ 20 files changed, 1854 insertions(+) create mode 100644 src/runtime/rust/.gitignore create mode 100644 src/runtime/rust/.rustfmt.toml create mode 100644 src/runtime/rust/.travis.yml create mode 100644 src/runtime/rust/Cargo.toml create mode 100644 src/runtime/rust/src/errors.rs create mode 100644 src/runtime/rust/src/lib.rs create mode 100644 src/runtime/rust/src/runtime/allocator.rs create mode 100644 src/runtime/rust/src/runtime/array.rs create mode 100644 src/runtime/rust/src/runtime/graph.rs create mode 100644 src/runtime/rust/src/runtime/mod.rs create mode 100644 src/runtime/rust/src/runtime/module.rs create mode 100644 src/runtime/rust/src/runtime/packed_func.rs create mode 100644 src/runtime/rust/src/runtime/threading.rs create mode 100644 src/runtime/rust/src/runtime/workspace.rs create mode 100644 src/runtime/rust/tests/.gitignore create mode 100644 src/runtime/rust/tests/test_graph_serde.rs create mode 100644 src/runtime/rust/tests/test_nnvm/Cargo.toml create mode 100644 src/runtime/rust/tests/test_nnvm/src/main.rs create mode 100644 src/runtime/rust/tests/test_tvm_basic/Cargo.toml create mode 100644 src/runtime/rust/tests/test_tvm_basic/src/main.rs diff --git a/src/runtime/rust/.gitignore b/src/runtime/rust/.gitignore new file mode 100644 index 000000000000..230ab66104df --- /dev/null +++ b/src/runtime/rust/.gitignore @@ -0,0 +1,3 @@ +Cargo.lock +target/ +**/*.rs.bk diff --git a/src/runtime/rust/.rustfmt.toml b/src/runtime/rust/.rustfmt.toml new file mode 100644 index 000000000000..9b2cf0e1007d --- /dev/null +++ b/src/runtime/rust/.rustfmt.toml @@ -0,0 +1,57 @@ +max_width = 100 +hard_tabs = false +tab_spaces = 2 +newline_style = "Unix" +use_small_heuristics = true +indent_style = "Block" +wrap_comments = false +comment_width = 80 +normalize_comments = false +format_strings = true +empty_item_single_line = true +struct_lit_single_line = true +fn_single_line = false +where_single_line = false +imports_indent = "Block" +imports_layout = "Mixed" +merge_imports = true +reorder_imports = true +reorder_modules = true +reorder_impl_items = false +type_punctuation_density = "Wide" +space_before_colon = false +space_after_colon = true +spaces_around_ranges = false +spaces_within_parens_and_brackets = false +binop_separator = "Front" +remove_blank_lines_at_start_or_end_of_block = true +combine_control_expr = true +struct_field_align_threshold = 0 +match_arm_blocks = true 
+force_multiline_blocks = false +fn_args_density = "Tall" +brace_style = "SameLineWhere" +control_brace_style = "AlwaysSameLine" +trailing_semicolon = true +trailing_comma = "Vertical" +match_block_trailing_comma = false +blank_lines_upper_bound = 1 +blank_lines_lower_bound = 0 +merge_derives = true +use_try_shorthand = true +condense_wildcard_suffixes = true +force_explicit_abi = true +use_field_init_shorthand = false +write_mode = "Overwrite" +color = "Auto" +required_version = "0.6.1" +unstable_features = false +disable_all_formatting = false +skip_children = false +hide_parse_errors = false +error_on_line_overflow = true +error_on_unformatted = true +report_todo = "Never" +report_fixme = "Never" +ignore = [] +verbose_diff = false diff --git a/src/runtime/rust/.travis.yml b/src/runtime/rust/.travis.yml new file mode 100644 index 000000000000..63a3d0277c1b --- /dev/null +++ b/src/runtime/rust/.travis.yml @@ -0,0 +1,5 @@ +language: rust +rust: + - nightly +matrix: + fast_finish: true diff --git a/src/runtime/rust/Cargo.toml b/src/runtime/rust/Cargo.toml new file mode 100644 index 000000000000..1283c24dd8f7 --- /dev/null +++ b/src/runtime/rust/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "tvm" +version = "0.1.0" +license = "Apache-2.0" +description = "TVM Rust runtime" +repository = "https://github.com/nhynes/tvm-rs" +readme = "README.md" +keywords = ["tvm", "nnvm"] +categories = ["api-bindings", "science"] +authors = ["Nick Hynes "] + +[features] +par-launch-alloc = [] + +[dependencies] +bounded-spsc-queue = "0.4.0" +error-chain = { version = "0.12.0", default-features = false } +itertools = "0.7.8" +lazy_static = "1.0.0" +ndarray = "0.11.2" +nom = "4.0.0" +serde = "1.0.59" +serde_derive = "1.0.59" +serde_json = "1.0.17" + +[target.'cfg(not(target_env = "sgx"))'.dependencies] +num_cpus = "1.8.0" + +[build-dependencies] +bindgen = "0.37" diff --git a/src/runtime/rust/src/errors.rs b/src/runtime/rust/src/errors.rs new file mode 100644 index 000000000000..df6ee8f3c4e2 --- /dev/null +++ b/src/runtime/rust/src/errors.rs @@ -0,0 +1,30 @@ +use std::{alloc, num}; + +use ndarray; +use serde_json; + +error_chain! 
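+// NOTE: this `error_chain!` invocation (error-chain 0.12) expands into the crate's
+// `Error`, `ErrorKind`, and `Result` types; `foreign_links` below wraps the
+// allocation, JSON, integer-parse, and ndarray shape errors.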
{ + errors { + TryFromTVMRetValueError(expected: String, actual: i64) { + description("mismatched types while downcasting TVMRetValue") + display("invalid downcast: expected `{}` but was `{}`", expected, actual) + } + + GraphFormatError(msg: String) { + description("unable to load graph") + display("could not load graph json: {}", msg) + } + + LoadGraphParamsError(msg: String) { + description("unable to load graph params") + display("could not load graph params: {}", msg) + } + } + foreign_links { + Alloc(alloc::AllocErr); + Layout(alloc::LayoutErr); + GraphDeserialize(serde_json::Error); + ParseInt(num::ParseIntError); + ShapeError(ndarray::ShapeError); + } +} diff --git a/src/runtime/rust/src/lib.rs b/src/runtime/rust/src/lib.rs new file mode 100644 index 000000000000..b2801e1e8cc6 --- /dev/null +++ b/src/runtime/rust/src/lib.rs @@ -0,0 +1,36 @@ +#![feature(allocator_api, box_syntax, fn_traits, try_from, unboxed_closures)] + +extern crate bounded_spsc_queue; +#[macro_use] +extern crate error_chain; +#[macro_use] +extern crate itertools; +#[macro_use] +extern crate lazy_static; +extern crate ndarray; +#[macro_use] +extern crate nom; +#[cfg(not(target_env = "sgx"))] +extern crate num_cpus; +extern crate serde; +#[macro_use] +extern crate serde_derive; +extern crate serde_json; + +pub mod ffi { + #![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, unused)] + + pub mod runtime { + use std::os::raw::{c_char, c_int, c_void}; + + include!(concat!(env!("OUT_DIR"), "/c_runtime_api.rs")); + + pub type BackendPackedCFunc = + extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int; + } +} + +pub mod errors; +pub mod runtime; + +pub use errors::*; diff --git a/src/runtime/rust/src/runtime/allocator.rs b/src/runtime/rust/src/runtime/allocator.rs new file mode 100644 index 000000000000..d3dc772cd4cb --- /dev/null +++ b/src/runtime/rust/src/runtime/allocator.rs @@ -0,0 +1,46 @@ +use std::alloc::{self, Layout}; + +use errors::*; + +const DEFAULT_ALIGN_BYTES: usize = 4; + +#[derive(PartialEq, Eq)] +pub struct Allocation { + layout: Layout, + ptr: *mut u8, +} + +impl Allocation { + pub fn new(size: usize, align: Option) -> Result { + let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES); + let layout = Layout::from_size_align(size, alignment)?; + let ptr = unsafe { alloc::alloc(layout.clone()) }; + if ptr.is_null() { + alloc::handle_alloc_error(layout); + } + Ok(Self { + ptr: ptr, + layout: layout, + }) + } + + pub fn as_mut_ptr(&self) -> *mut u8 { + self.ptr + } + + pub fn size(&self) -> usize { + self.layout.size() + } + + pub fn align(&self) -> usize { + self.layout.align() + } +} + +impl Drop for Allocation { + fn drop(&mut self) { + unsafe { + alloc::dealloc(self.ptr, self.layout.clone()); + } + } +} diff --git a/src/runtime/rust/src/runtime/array.rs b/src/runtime/rust/src/runtime/array.rs new file mode 100644 index 000000000000..dff63fa92bb1 --- /dev/null +++ b/src/runtime/rust/src/runtime/array.rs @@ -0,0 +1,393 @@ +use std::{ + any::TypeId, + convert::TryFrom, + mem, + os::raw::{c_int, c_void}, + ptr, + slice, +}; + +use ndarray; + +use super::allocator::Allocation; +use errors::*; +use ffi::runtime::{ + DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt, + DLDeviceType_kDLCPU, DLTensor, +}; + +#[derive(PartialEq)] +pub enum Storage<'a> { + Owned(Allocation), + View(&'a mut [u8], usize), // ptr, align +} + +impl<'a> Storage<'a> { + pub fn new(size: usize, align: Option) -> Result> { + 
Ok(Storage::Owned(Allocation::new(size, align)?)) + } + + pub fn as_mut_ptr(&self) -> *mut u8 { + match self { + Storage::Owned(alloc) => alloc.as_mut_ptr(), + Storage::View(slice, _) => slice.as_ptr() as *mut u8, + } + } + + pub fn size(&self) -> usize { + match self { + Storage::Owned(alloc) => alloc.size(), + Storage::View(slice, _) => slice.len(), + } + } + + pub fn align(&self) -> usize { + match self { + Storage::Owned(alloc) => alloc.align(), + Storage::View(_, align) => *align, + } + } + + pub fn as_ptr(&self) -> *const u8 { + self.as_mut_ptr() as *const _ + } + + pub fn view(&self) -> Storage<'a> { + match self { + Storage::Owned(alloc) => Storage::View( + unsafe { slice::from_raw_parts_mut(alloc.as_mut_ptr(), self.size()) }, + self.align(), + ), + Storage::View(slice, _) => Storage::View( + unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), slice.len()) }, + self.align(), + ), + } + } + + pub fn is_owned(&self) -> bool { + match self { + Storage::Owned(_) => true, + _ => false, + } + } + + pub fn to_owned(&self) -> Storage<'static> { + let s = Storage::new(self.size(), Some(self.align())).unwrap(); + unsafe { + s.as_mut_ptr() + .copy_from_nonoverlapping(self.as_ptr(), self.size()) + } + s + } +} + +impl<'a, T> From<&'a [T]> for Storage<'a> { + fn from(data: &'a [T]) -> Self { + let data = unsafe { + slice::from_raw_parts_mut( + data.as_ptr() as *const u8 as *mut u8, + data.len() * mem::size_of::() as usize, + ) + }; + Storage::View(data, mem::align_of::()) + } +} + +#[derive(PartialEq)] +pub struct Tensor<'a> { + pub(super) data: Storage<'a>, + pub(super) ctx: TVMContext, + pub(super) dtype: DataType, + pub(super) shape: Vec, + pub(super) strides: Option>, + pub(super) byte_offset: isize, + pub(super) numel: usize, + pub(super) dshape: Vec, +} + +impl<'a> Tensor<'a> { + pub fn shape(&self) -> Vec { + self.shape.clone() + } + + pub fn to_vec(&self) -> Vec { + assert!(self.dtype.is_type::()); + let mut vec: Vec = Vec::with_capacity(self.numel * self.dtype.itemsize()); + unsafe { + vec.as_mut_ptr().copy_from_nonoverlapping( + self.data.as_ptr().offset(self.byte_offset) as *const T, + self.numel, + ); + vec.set_len(self.numel); + } + vec + } + + pub fn is_contiguous(&self) -> bool { + match self.strides { + None => true, + Some(ref strides) => { + self + .shape + .iter() + .zip(strides) + .rfold( + (true, 1), + |(is_contig, expected_stride), (shape, stride)| { + ( + is_contig && *stride == expected_stride, + expected_stride * shape, + ) + }, + ) + .0 + } + } + } + + pub fn copy(&mut self, other: &Tensor) { + assert!( + self.dtype == other.dtype && self.numel == other.numel, + "Tensor shape/dtype mismatch." 
+ ); + assert!( + self.is_contiguous() && other.is_contiguous(), + "copy currently requires contiguous tensors\n`self.strides = {:?}` `other.strides = {:?}`", + self.strides, + other.strides + ); + unsafe { + self + .data + .as_mut_ptr() + .offset(self.byte_offset as isize) + .copy_from_nonoverlapping( + other.data.as_mut_ptr().offset(other.byte_offset), + other.numel * other.dtype.itemsize(), + ); + } + } + + pub fn to_owned(&self) -> Tensor<'static> { + let t = Tensor { + data: self.data.to_owned(), + ctx: self.ctx.clone(), + dtype: self.dtype.clone(), + numel: self.numel.clone(), + shape: self.shape.clone(), + strides: None, + byte_offset: 0, + dshape: self.dshape.clone(), + }; + unsafe { mem::transmute::, Tensor<'static>>(t) } + } +} + +impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD { + type Error = Error; + fn try_from(tensor: &'a Tensor) -> Result> { + ensure!( + tensor.dtype == DTYPE_FLOAT32, + "Cannot convert Tensor with dtype {:?} to ndarray", + tensor.dtype + ); + Ok(ndarray::Array::from_shape_vec( + tensor.shape.clone(), + tensor.to_vec::(), + )?) + } +} + +impl DLTensor { + pub(super) fn from_tensor<'a>(tensor: &'a Tensor, flatten: bool) -> Self { + assert!(!flatten || tensor.is_contiguous()); + Self { + data: unsafe { tensor.data.as_mut_ptr().offset(tensor.byte_offset) } as *mut c_void, + ctx: DLContext::from(&tensor.ctx), + ndim: if flatten { 1 } else { tensor.shape.len() } as i32, + dtype: DLDataType::from(&tensor.dtype), + shape: if flatten { + &tensor.numel as *const _ as *mut i64 + } else { + // tensor.shape.as_ptr() + tensor.dshape.as_ptr() as *mut i64 + } as *mut i64, + strides: if flatten || tensor.is_contiguous() { + ptr::null_mut() + } else { + tensor.strides.as_ref().unwrap().as_ptr() + } as *mut i64, + byte_offset: 0, + } + } +} + +impl<'a, 't> From<&'a Tensor<'t>> for DLTensor { + fn from(tensor: &'a Tensor<'t>) -> Self { + DLTensor::from_tensor(tensor, false /* flatten */) + } +} + +impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor { + fn from(tensor: &'a mut Tensor<'t>) -> Self { + DLTensor::from_tensor(tensor, false /* flatten */) + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct DataType { + pub(super) code: usize, + pub(super) bits: usize, + pub(super) lanes: usize, +} + +impl DataType { + fn itemsize(&self) -> usize { + (self.bits * self.lanes) >> 3 + } + + fn is_type(&self) -> bool { + if self.lanes != 1 { + return false; + } + let typ = TypeId::of::(); + (typ == TypeId::of::() && self.code == 0 && self.bits == 32) + || (typ == TypeId::of::() && self.code == 0 && self.bits == 64) + || (typ == TypeId::of::() && self.code == 1 && self.bits == 32) + || (typ == TypeId::of::() && self.code == 1 && self.bits == 64) + || (typ == TypeId::of::() && self.code == 2 && self.bits == 32) + || (typ == TypeId::of::() && self.code == 2 && self.bits == 64) + } +} + +const DTYPE_FLOAT32: DataType = DataType { + code: DLDataTypeCode_kDLFloat as usize, + bits: 32, + lanes: 1, +}; + +impl<'a> From<&'a DataType> for DLDataType { + fn from(dtype: &'a DataType) -> Self { + Self { + code: dtype.code as u8, + bits: dtype.bits as u8, + lanes: dtype.lanes as u16, + } + } +} + +impl Default for DLContext { + fn default() -> Self { + DLContext { + device_type: DLDeviceType_kDLCPU, + device_id: 0, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct TVMContext { + pub(super) device_type: usize, + pub(super) device_id: usize, +} + +impl<'a> From<&'a TVMContext> for DLContext { + fn from(ctx: &'a TVMContext) -> Self { + Self { + device_type: 
ctx.device_type as u32, + device_id: ctx.device_id as i32, + } + } +} + +impl Default for TVMContext { + fn default() -> Self { + Self { + device_type: DLDeviceType_kDLCPU as usize, + device_id: 0, + } + } +} + +fn tensor_from_array_storage<'a, 's, T, D: ndarray::Dimension>( + arr: &ndarray::Array, + storage: Storage<'s>, + type_code: usize, +) -> Tensor<'s> { + let type_width = mem::size_of::() as usize; + Tensor { + data: storage, + ctx: TVMContext::default(), + dtype: DataType { + code: type_code, + bits: 8 * type_width, + lanes: 1, + }, + numel: arr.len(), + shape: arr.shape().iter().map(|&v| v as usize).collect(), + strides: Some(arr.strides().into_iter().map(|&v| v as usize).collect()), + byte_offset: 0, + dshape: arr.shape().iter().map(|&v| v as i64).collect(), + } +} + +macro_rules! impl_tensor_from_ndarray { + ($type:ty, $typecode:expr) => { + impl From> for Tensor<'static> { + fn from(arr: ndarray::Array<$type, D>) -> Self { + assert!(arr.is_standard_layout(), "Array must be contiguous."); + let numel = arr.len() * mem::size_of::<$type>() as usize; + let storage = + Storage::from(unsafe { slice::from_raw_parts(arr.as_ptr() as *const u8, numel) }); + tensor_from_array_storage(&arr, storage, $typecode as usize) + } + } + impl<'a, D: ndarray::Dimension> From<&'a ndarray::Array<$type, D>> for Tensor<'a> { + fn from(arr: &'a ndarray::Array<$type, D>) -> Self { + assert!(arr.is_standard_layout(), "Array must be contiguous."); + tensor_from_array_storage( + arr, + Storage::from(arr.as_slice().unwrap()), + $typecode as usize, + ) + } + } + }; +} + +macro_rules! impl_dltensor_from_ndarray { + ($type:ty, $typecode:expr) => { + impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor { + fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self { + DLTensor { + data: arr.as_mut_ptr() as *mut c_void, + ctx: DLContext::default(), + ndim: arr.ndim() as c_int, + dtype: DLDataType { + code: $typecode as u8, + bits: 8 * mem::size_of::<$type>() as u8, + lanes: 1, + }, + shape: arr.shape().as_ptr() as *const i64 as *mut i64, + strides: arr.strides().as_ptr() as *const isize as *mut i64, + byte_offset: 0, + } + } + } + }; +} + +impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat); +impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat); +impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt); +impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt); +impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt); +impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt); + +impl_tensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat); +impl_tensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat); +impl_tensor_from_ndarray!(i32, DLDataTypeCode_kDLInt); +impl_tensor_from_ndarray!(i64, DLDataTypeCode_kDLInt); +impl_tensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt); +impl_tensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt); diff --git a/src/runtime/rust/src/runtime/graph.rs b/src/runtime/rust/src/runtime/graph.rs new file mode 100644 index 000000000000..0289318d6574 --- /dev/null +++ b/src/runtime/rust/src/runtime/graph.rs @@ -0,0 +1,408 @@ +use std::{cmp, collections::HashMap, convert::TryFrom, iter::FromIterator, mem, str}; + +use nom::{alpha1, digit1, le_i32, le_i64, le_u16, le_u32, le_u64, le_u8, types::CompleteStr}; +use serde; +use serde_json; + +use super::{DataType, Module, Storage, TVMArgValue, TVMContext, Tensor}; +use errors::{Error, ErrorKind, Result}; +use ffi::runtime::{ + DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt, DLTensor, +}; + +const 
NDARRAY_MAGIC: u64 = 0xDD5E40F096B4A13F; // Magic number for NDArray file
+const NDARRAY_LIST_MAGIC: u64 = 0xF7E58D4F05049CB7; // Magic number for NDArray list file
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Graph {
+  pub nodes: Vec<Node>,
+  pub arg_nodes: Vec<usize>,
+  pub heads: Vec<Entry>,
+  pub node_row_ptr: Option<Vec<usize>>,
+  pub attrs: Option<HashMap<String, serde_json::Value>>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Entry {
+  pub id: usize,
+  pub index: usize,
+  pub version: usize,
+}
+
+impl Graph {
+  fn entry_index(&self, entry: &Entry) -> Result<usize> {
+    self
+      .node_row_ptr
+      .as_ref()
+      .map(|nrp| nrp[entry.id] + entry.index)
+      .ok_or("Missing node_row_ptr.".into())
+  }
+
+  fn get_attr<T: serde::de::DeserializeOwned>(&self, attr: &str) -> Result<T> {
+    Ok(serde_json::from_value::<T>(
+      self
+        .attrs
+        .as_ref()
+        .ok_or(ErrorKind::GraphFormatError(
+          "Missing graph attrs".to_string(),
+        ))?
+        .get(attr)
+        .ok_or(ErrorKind::GraphFormatError(format!(
+          "Missing {} attr",
+          attr
+        )))?
+        .to_owned(),
+    )?)
+  }
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Node {
+  pub op: String,
+  pub name: String,
+  pub inputs: Vec<Entry>,
+  pub attrs: Option<HashMap<String, String>>,
+  pub control_deps: Option<Vec<usize>>,
+}
+
+struct NodeAttrs {
+  func_name: String,
+  num_outputs: usize,
+  flatten_data: bool,
+}
+
+impl Node {
+  fn parse_attrs(&self) -> Result<NodeAttrs> {
+    let attrs = self
+      .attrs
+      .as_ref()
+      .ok_or(format!("Missing node.attrs for `{}`", self.name))?;
+    let func_name = attrs
+      .get("func_name")
+      .ok_or(format!("Node `{}` is missing attrs.func_name", self.name))?
+      .to_string();
+    let num_outputs = attrs
+      .get("num_outputs")
+      .ok_or(format!("Node `{}` is missing attrs.num_outputs", self.name))?
+      .parse::<usize>()?;
+    let flatten_data = attrs
+      .get("flatten_data")
+      .ok_or(format!(
+        "Node `{}` is missing attrs.flatten_data",
+        self.name
+      ))?
+      .parse::<u8>()? == 1;
+    Ok(NodeAttrs {
+      func_name,
+      num_outputs,
+      flatten_data,
+    })
+  }
+}
+
+impl<'a> TryFrom<&'a String> for Graph {
+  type Error = Error;
+  fn try_from(graph_json: &String) -> Result<Self> {
+    let graph = serde_json::from_str(graph_json)?;
+    Ok(graph)
+  }
+}
+
+impl<'a> TryFrom<&'a str> for Graph {
+  type Error = Error;
+  fn try_from(graph_json: &'a str) -> Result<Self> {
+    let graph = serde_json::from_str(graph_json)?;
+    Ok(graph)
+  }
+}
+
+pub struct GraphExecutor<'m, 't> {
+  graph: Graph,
+  op_execs: Vec<Box<Fn() + 'm>>,
+  tensors: Vec<Tensor<'t>>,
+}
+
+impl<'m, 't> GraphExecutor<'m, 't> {
+  pub fn new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self> {
+    let tensors = Self::setup_storages(&graph)?;
+    Ok(GraphExecutor {
+      op_execs: Self::setup_op_execs(&graph, lib, &tensors)?,
+      tensors: tensors,
+      graph: graph,
+    })
+  }
+
+  pub fn run(&self) {
+    self.op_execs.iter().for_each(|op_exec| {
+      op_exec();
+    });
+  }
+
+  fn setup_storages<'a>(graph: &'a Graph) -> Result<Vec<Tensor<'a>>> {
+    let storage_ids = graph.get_attr::<(String, Vec<usize>)>("storage_id")?.1;
+    let shapes = graph.get_attr::<(String, Vec<Vec<usize>>)>("shape")?.1;
+    let dtypes = graph
+      .get_attr::<(String, Vec<String>)>("dltype")?
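+      // NOTE: each graph attribute is serialized as a two-element JSON array
+      // `[tag, values]`, hence deserializing to `(String, Vec<_>)` and taking `.1`.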
+ .1 + .iter() + .map(|dltype| { + if let Ok((_, dtype)) = tvm_str_to_type(CompleteStr(dltype)) { + Ok(dtype) + } else { + Err(ErrorKind::GraphFormatError(format!("Invalid dltype: {}", dltype).to_string()).into()) + } + }) + .collect::>>()?; + + let align = dtypes.iter().map(|dtype| dtype.bits as usize >> 3).max(); + let mut storage_num_bytes = vec![0usize; *storage_ids.iter().max().unwrap_or(&1) + 1]; + for (i, &storage_id) in storage_ids.iter().enumerate() { + let dtype_size = dtypes[i].bits * dtypes[i].lanes >> 3; + let nbytes = dtype_size * shapes[i].iter().product::(); + storage_num_bytes[storage_id] = cmp::max(nbytes, storage_num_bytes[storage_id]); + } + + let mut storages: Vec = storage_num_bytes + .into_iter() + .map(|nbytes| Storage::new(nbytes, align)) + .collect::>>()?; + + let tensors = izip!(storage_ids, shapes, dtypes) + .map(|(storage_id, shape, dtype)| { + let storage = storages[storage_id].view(); + Tensor { + data: mem::replace(&mut storages[storage_id], storage), + ctx: TVMContext::default(), + dtype: dtype, + numel: shape.iter().product(), + dshape: shape.iter().map(|&v| v as i64).collect(), + shape: shape, + strides: None, + byte_offset: 0, + } + }) + .collect(); + + Ok(tensors) + } + + fn setup_op_execs( + graph: &Graph, + lib: &'m M, + tensors: &Vec>, + ) -> Result>> { + ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr."); + let node_row_ptr = graph.node_row_ptr.as_ref().unwrap(); + graph + .nodes + .iter() + .enumerate() + .filter(|(_i, node)| node.op != "null") + .map(|(i, node)| { + ensure!(node.op == "tvm_op", "Only TVM ops are supported."); + ensure!(node.attrs.is_some(), "Missing node_row_ptr."); + let attrs = node.parse_attrs()?; + let func = lib + .get_function(&attrs.func_name) + .ok_or(format!("Missing function {}", attrs.func_name))?; + let arg_indices = node + .inputs + .iter() + .map(|entry| graph.entry_index(entry)) + .chain((0..attrs.num_outputs).map(|oi| Ok(node_row_ptr[i].clone() + oi))); + + let dl_tensors = arg_indices + .map(|idx| { + let tensor = &tensors[idx?]; + Ok(if attrs.flatten_data { + DLTensor::from_tensor(tensor, true /* flatten */) + } else { + DLTensor::from(tensor) + }) + }) + .collect::>>() + .unwrap(); + let op: Box = box move || { + let args = dl_tensors + .iter() + .map(|t| t.into()) + .collect::>(); + func(args.as_slice()); + }; + Ok(op) + }) + .collect() + } + + pub fn load_params(&mut self, params: HashMap>) { + params.into_iter().for_each(|(name, param)| { + self.set_input(name, param); + }) + } + + pub fn set_input>(&mut self, name: S, value: Tensor<'t>) { + if let Some(idx) = self.get_input_index(name.as_ref()) { + // TODO: consider `new_with_params` to avoid ever allocating + let ptr = self.tensors[idx].data.as_ptr(); + let mut to_replace = self.tensors.iter_mut().filter(|t| t.data.as_ptr() == ptr); + let mut owner = to_replace.nth(0).unwrap(); + if value.data.is_owned() { + // FIXME: for no-copy, need setup_op_execs to not capture tensor ptr + // mem::replace(&mut (*owner), value); + // to_replace.for_each(|t| { + // panic!("replacing"); + // t.data = owner.data.view(); + // }); + owner.copy(&value); + } else { + owner.copy(&value); + } + } else { + println!("Unexpected input `{}`", name.as_ref()); + } + } + + pub fn get_input>(&mut self, name: S) -> Option<&Tensor> { + self + .get_input_index(name.as_ref()) + .and_then(move |idx| Some(&self.tensors[idx])) + } + + pub fn get_output(&self, idx: usize) -> Option<&Tensor> { + let graph = &self.graph; + graph.heads.get(idx).and_then(|entry| { + graph + 
.entry_index(entry) + .map(|idx| self.tensors.get(idx)) + .unwrap_or(None) + }) + } + + pub fn get_input_index>(&self, name: S) -> Option { + let graph = &self.graph; + (0..graph.nodes.len()) + .skip_while(|&i| graph.nodes[i].name != name.as_ref()) + .nth(0) + .and_then(|i| { + if graph.arg_nodes.iter().any(|&id| id == i) { + graph.node_row_ptr.as_ref().map(|nrp| nrp[i]) + } else { + None + } + }) + } +} + +named!( + tvm_str_to_type, + do_parse!( + type_name: alpha1 >> + bits: digit1 >> + lanes: opt!(tuple!(tag!("x"), digit1)) >> + (DataType { + code: match type_name { + CompleteStr("int") => DLDataTypeCode_kDLInt, + CompleteStr("uint") => DLDataTypeCode_kDLUInt, + CompleteStr("float") => DLDataTypeCode_kDLFloat, + _ => DLDataTypeCode_kDLFloat, + } as usize, + bits: bits.parse::().unwrap() as usize, + lanes: match lanes { + Some(lanes) => lanes.1.parse::().unwrap() as usize, + None => 1, + }, + }) + ) +); + +named!( + name, + map_res!(length_bytes!(le_u64), |b: &[u8]| { + String::from_utf8(b.to_vec()) + }) +); + +named!( + tvm_ctx<&[u8], TVMContext>, + do_parse!( + device_type: le_u32 >> + device_id: le_i32 >> + (TVMContext { device_type: device_type as usize, device_id: device_id as usize }) + ) +); + +named!( + data_type<&[u8], DataType>, + do_parse!( + code: le_u8 >> + bits: le_u8 >> + lanes: le_u16 >> + (DataType { code: code as usize, bits: bits as usize, lanes: lanes as usize }) + ) +); + +named!( + tensor, + do_parse!( + take!(8) >> bits!(tag_bits!(u64, 64, 0)) >> ctx: tvm_ctx >> ndim: le_u32 >> dtype: data_type + >> shape: count!(map!(le_i64, |sz| sz as usize), ndim as usize) >> length: le_i64 + >> data: take!(length) >> (Tensor { + data: Storage::from(data), + ctx: ctx, + dtype: dtype, + numel: shape.iter().product(), + dshape: shape.iter().map(|&v| v as i64).collect(), + shape: shape, + strides: None, + byte_offset: 0, + }) + ) +); + +named!( + parse_param_dict>, + do_parse!( + take!(8) >> bits!(tag_bits!(u64, 64, 0)) >> names: length_count!(le_u64, name) + >> tensors: length_count!(le_u64, tensor) + >> (HashMap::from_iter(names.into_iter().zip(tensors.into_iter()))) + ) +); + +pub fn load_param_dict(bytes: &[u8]) -> Result> { + if let Ok((remaining_bytes, param_dict)) = parse_param_dict(bytes) { + if remaining_bytes.len() > 0 { + bail!(ErrorKind::LoadGraphParamsError("extra input".to_string())) + } else { + Ok(param_dict) + } + } else { + bail!(ErrorKind::LoadGraphParamsError( + "invalid parameters file".to_string() + )) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_str_to_type() { + assert_eq!( + tvm_str_to_type(CompleteStr("float24")).unwrap().1, + DataType { + code: DLDataTypeCode_kDLFloat as usize, + bits: 24, + lanes: 1 + } + ); + assert_eq!( + tvm_str_to_type(CompleteStr("uint111x44")).unwrap().1, + DataType { + code: DLDataTypeCode_kDLUInt as usize, + bits: 111, + lanes: 44 + } + ); + } +} diff --git a/src/runtime/rust/src/runtime/mod.rs b/src/runtime/rust/src/runtime/mod.rs new file mode 100644 index 000000000000..871d7aff58d7 --- /dev/null +++ b/src/runtime/rust/src/runtime/mod.rs @@ -0,0 +1,19 @@ +mod allocator; +mod array; +mod module; +#[macro_use] +mod packed_func; +mod graph; +mod threading; +mod workspace; + +use std::{ffi::CStr, os::raw::c_char}; + +pub use self::{array::*, graph::*, module::*, packed_func::*, threading::*, workspace::*}; + +#[no_mangle] +pub extern "C" fn TVMAPISetLastError(cmsg: *const c_char) { + unsafe { + panic!(CStr::from_ptr(cmsg).to_str().unwrap()); + } +} diff --git a/src/runtime/rust/src/runtime/module.rs 
b/src/runtime/rust/src/runtime/module.rs
new file mode 100644
index 000000000000..e23c21fcd611
--- /dev/null
+++ b/src/runtime/rust/src/runtime/module.rs
@@ -0,0 +1,46 @@
+use std::{
+  collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex,
+};
+
+use ffi::runtime::BackendPackedCFunc;
+use runtime::packed_func::{wrap_backend_packed_func, PackedFunc};
+
+pub trait Module {
+  fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc>;
+}
+
+pub struct SystemLibModule {}
+
+lazy_static! {
+  static ref SYSTEM_LIB_FUNCTIONS: Mutex<HashMap<String, BackendPackedCFunc>> =
+    Mutex::new(HashMap::new());
+}
+
+impl Module for SystemLibModule {
+  fn get_function<S: AsRef<str>>(&self, name: S) -> Option<PackedFunc> {
+    SYSTEM_LIB_FUNCTIONS
+      .lock()
+      .unwrap()
+      .get(name.as_ref())
+      .map(|func| wrap_backend_packed_func(func.to_owned()))
+  }
+}
+
+impl Default for SystemLibModule {
+  fn default() -> Self {
+    SystemLibModule {}
+  }
+}
+
+#[no_mangle]
+pub extern "C" fn TVMBackendRegisterSystemLibSymbol(
+  cname: *const c_char,
+  func: BackendPackedCFunc,
+) -> i32 {
+  let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() };
+  SYSTEM_LIB_FUNCTIONS
+    .lock()
+    .unwrap()
+    .insert(name.to_string(), func);
+  return 0;
+}
diff --git a/src/runtime/rust/src/runtime/packed_func.rs b/src/runtime/rust/src/runtime/packed_func.rs
new file mode 100644
index 000000000000..8d28a20863f7
--- /dev/null
+++ b/src/runtime/rust/src/runtime/packed_func.rs
@@ -0,0 +1,199 @@
+use std::{any::Any, convert::TryFrom, marker::PhantomData, os::raw::c_void};
+
+use ffi::runtime::{
+  BackendPackedCFunc, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLTensor,
+  TVMTypeCode_kArrayHandle, TVMTypeCode_kHandle, TVMValue,
+};
+
+use errors::*;
+
+pub type PackedFunc = Box<Fn(&[TVMArgValue]) -> TVMRetValue>;
+
+#[macro_export]
+macro_rules! call_packed {
+  ($fn:expr, $($args:expr),+) => {
+    $fn(&[$($args.into(),)+])
+  };
+}
+
+#[derive(Clone, Copy)]
+pub struct TVMArgValue<'a> {
+  _lifetime: PhantomData<&'a ()>,
+  value: TVMValue,
+  type_code: i64,
+}
+
+macro_rules!
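+// Generates `From<$type> for TVMArgValue` impls that store a primitive value in the
+// given `TVMValue` union field along with the corresponding TVM type code.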
impl_prim_tvm_arg { + ($type:ty, $field:ident, $code:expr, $as:ty) => { + impl<'a> From<$type> for TVMArgValue<'a> { + fn from(val: $type) -> Self { + TVMArgValue { + value: TVMValue { $field: val as $as }, + type_code: $code as i64, + _lifetime: PhantomData, + } + } + } + }; + ($type:ty, $field:ident, $code:expr) => { + impl_prim_tvm_arg!($type, $field, $code, $type); + }; + ($type:ty,v_int64) => { + impl_prim_tvm_arg!($type, v_int64, DLDataTypeCode_kDLInt, i64); + }; + ($type:ty,v_float64) => { + impl_prim_tvm_arg!($type, v_float64, DLDataTypeCode_kDLFloat, f64); + }; +} + +impl_prim_tvm_arg!(f32, v_float64); +impl_prim_tvm_arg!(f64, v_float64); +impl_prim_tvm_arg!(i8, v_int64); +impl_prim_tvm_arg!(u8, v_int64); +impl_prim_tvm_arg!(i32, v_int64); +impl_prim_tvm_arg!(u32, v_int64); +impl_prim_tvm_arg!(i64, v_int64); +impl_prim_tvm_arg!(u64, v_int64); +impl_prim_tvm_arg!(bool, v_int64); + +impl<'a, T> From<*const T> for TVMArgValue<'a> { + fn from(ptr: *const T) -> Self { + TVMArgValue { + value: TVMValue { + v_handle: ptr as *mut T as *mut c_void, + }, + type_code: TVMTypeCode_kArrayHandle as i64, + _lifetime: PhantomData, + } + } +} + +impl<'a, T> From<*mut T> for TVMArgValue<'a> { + fn from(ptr: *mut T) -> Self { + TVMArgValue { + value: TVMValue { + v_handle: ptr as *mut c_void, + }, + type_code: TVMTypeCode_kHandle as i64, + _lifetime: PhantomData, + } + } +} + +impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> { + fn from(arr: &'a mut DLTensor) -> Self { + TVMArgValue { + value: TVMValue { + v_handle: arr as *mut _ as *mut c_void, + }, + type_code: TVMTypeCode_kArrayHandle as i64, + _lifetime: PhantomData, + } + } +} + +impl<'a> From<&'a DLTensor> for TVMArgValue<'a> { + fn from(arr: &'a DLTensor) -> Self { + TVMArgValue { + value: TVMValue { + v_handle: arr as *const _ as *mut DLTensor as *mut c_void, + }, + type_code: TVMTypeCode_kArrayHandle as i64, + _lifetime: PhantomData, + } + } +} + +pub struct TVMRetValue { + prim_value: u64, + box_value: Box, + type_code: i64, +} + +impl Default for TVMRetValue { + fn default() -> Self { + TVMRetValue { + prim_value: 0, + box_value: box (), + type_code: 0, + } + } +} + +macro_rules! impl_prim_ret_value { + ($type:ty, $code:expr) => { + impl From<$type> for TVMRetValue { + fn from(val: $type) -> Self { + TVMRetValue { + prim_value: val as u64, + box_value: box (), + type_code: $code, + } + } + } + impl TryFrom for $type { + type Error = Error; + fn try_from(ret: TVMRetValue) -> Result<$type> { + if ret.type_code == $code { + Ok(ret.prim_value as $type) + } else { + bail!(ErrorKind::TryFromTVMRetValueError( + stringify!($type).to_string(), + ret.type_code + )) + } + } + } + }; +} + +macro_rules! 
impl_boxed_ret_value { + ($type:ty, $code:expr) => { + impl From<$type> for TVMRetValue { + fn from(val: $type) -> Self { + TVMRetValue { + prim_value: 0, + box_value: box val, + type_code: $code, + } + } + } + impl TryFrom for $type { + type Error = Error; + fn try_from(ret: TVMRetValue) -> Result<$type> { + if let Ok(val) = ret.box_value.downcast::<$type>() { + Ok(*val) + } else { + bail!(ErrorKind::TryFromTVMRetValueError( + stringify!($type).to_string(), + ret.type_code + )) + } + } + } + }; +} + +impl_prim_ret_value!(i32, 0); +impl_prim_ret_value!(u32, 1); +impl_prim_ret_value!(f32, 2); +impl_boxed_ret_value!(String, 11); + +pub(super) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc { + box move |args: &[TVMArgValue]| { + func( + args + .iter() + .map(|ref arg| arg.value) + .collect::>() + .as_ptr(), + args + .iter() + .map(|ref arg| arg.type_code as i32) + .collect::>() + .as_ptr() as *const i32, + args.len() as i32, + ); + TVMRetValue::default() + } +} diff --git a/src/runtime/rust/src/runtime/threading.rs b/src/runtime/rust/src/runtime/threading.rs new file mode 100644 index 000000000000..84b409aa6b3e --- /dev/null +++ b/src/runtime/rust/src/runtime/threading.rs @@ -0,0 +1,299 @@ +use std::{ + os::raw::{c_int, c_void}, + sync::{ + atomic::{AtomicUsize, Ordering, ATOMIC_USIZE_INIT}, + Arc, + Barrier, + }, +}; + +#[cfg(not(target_env = "sgx"))] +use num_cpus; +#[cfg(not(target_env = "sgx"))] +use std::{ + env, + thread::{self, JoinHandle}, +}; + +#[cfg(target_env = "sgx")] +use std::{collections::VecDeque, sync::Mutex}; + +use bounded_spsc_queue::{self, Producer}; + +use super::super::errors::*; +use ffi::runtime::TVMParallelGroupEnv; + +type FTVMParallelLambda = + extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32; + +struct Job { + cb: FTVMParallelLambda, + cdata: *const c_void, + req_num_tasks: usize, + pending: Arc, +} + +impl Job { + fn tasks(&self, num_workers: usize) -> Vec { + let num_tasks = if self.req_num_tasks == 0 { + num_workers + } else { + self.req_num_tasks.min(num_workers) + }; + self.pending.store(num_tasks, Ordering::SeqCst); + + let barrier = Arc::new(Barrier::new(num_tasks)); + + (0..num_tasks) + .map(move |i| Task { + id: i, + flambda: self.cb, + penv: TVMParallelGroupEnv { + sync_handle: &Arc::clone(&barrier) as *const _ as *mut c_void, + num_task: num_tasks as i32, + }, + cdata: self.cdata, + pending: Arc::clone(&self.pending), + }) + .collect() + } + + fn wait(&self) -> Result<()> { + while self.pending.load(Ordering::Acquire) > 0 { + #[cfg(not(target_env = "sgx"))] + thread::yield_now(); + } + Ok(()) + } +} + +struct Task { + id: usize, + flambda: FTVMParallelLambda, + penv: TVMParallelGroupEnv, + cdata: *const c_void, + pending: Arc, +} +unsafe impl Send for Task {} +unsafe impl Sync for Task {} + +impl FnOnce<()> for Task { + type Output = i32; + extern "rust-call" fn call_once(self, _args: ()) -> Self::Output { + let status = (self.flambda)(self.id, &self.penv as *const _, self.cdata); + self.pending.fetch_sub(1, Ordering::AcqRel); + status + } +} + +#[derive(Default)] +struct Threads { + #[allow(unused)] + #[cfg(not(target_env = "sgx"))] + handles: Vec>, + queues: Vec>, +} + +impl<'a> Threads { + #[cfg(not(target_env = "sgx"))] + fn launch) + 'static + Copy>( + num_threads: usize, + cb: F, + ) -> Self { + let (handles, queues) = (0..num_threads) + .map(|_| { + let (p, c) = bounded_spsc_queue::make(2); + let handle = thread::spawn(move || cb(c.into())); + (handle, p) + }) + .unzip(); + Threads 
{
+      handles: handles,
+      queues: queues,
+    }
+  }
+
+  #[cfg(target_env = "sgx")]
+  fn launch<F: Sync + Send + Fn(Consumer<Task>) + 'static + Copy>(num: usize, _cb: F) -> Self {
+    let mut consumer_queues = SGX_QUEUES.lock().unwrap();
+    let queues = (0..num)
+      .map(|_| {
+        let (p, c) = bounded_spsc_queue::make(2);
+        consumer_queues.push_back(c.into());
+        p
+      })
+      .collect();
+    Threads { queues: queues }
+  }
+}
+
+struct ThreadPool {
+  num_workers: usize,
+  #[allow(unused)]
+  threads: Threads,
+}
+
+thread_local!(static THREAD_POOL: ThreadPool = ThreadPool::new());
+
+impl ThreadPool {
+  fn new() -> Self {
+    let num_workers = max_concurrency();
+    ThreadPool {
+      num_workers: num_workers,
+      threads: Threads::launch(num_workers, ThreadPool::run_worker),
+    }
+  }
+
+  fn launch(&self, job: Job) {
+    let tasks = job.tasks(self.num_workers);
+
+    let _: Vec<()> = tasks
+      .into_iter()
+      .zip(self.threads.queues.iter())
+      .map(|(task, q)| q.push(task))
+      .collect();
+
+    job.wait().unwrap();
+  }
+
+  fn run_worker(queue: Consumer<Task>) {
+    loop {
+      let task = queue.pop();
+      if task() != 0 {
+        panic!("Error running task.");
+      }
+    }
+  }
+}
+
+struct Consumer<T> {
+  consumer: bounded_spsc_queue::Consumer<T>,
+}
+impl<T> From<bounded_spsc_queue::Consumer<T>> for Consumer<T> {
+  fn from(c: bounded_spsc_queue::Consumer<T>) -> Self {
+    Consumer { consumer: c }
+  }
+}
+impl<T> Consumer<T> {
+  fn pop(&self) -> T {
+    self.consumer.pop()
+  }
+}
+unsafe impl<T> Send for Consumer<T> {}
+unsafe impl<T> Sync for Consumer<T> {}
+
+#[cfg(target_env = "sgx")]
+lazy_static! {
+  static ref SGX_QUEUES: Mutex<VecDeque<Consumer<Task>>> = Mutex::new(VecDeque::new());
+}
+
+#[cfg(all(not(target_arch = "wasm32"), not(target_env = "sgx")))]
+fn max_concurrency() -> usize {
+  if let Ok(threads_str) = env::var("TVM_NUM_THREADS").or(env::var("OMP_NUM_THREADS")) {
+    if let Ok(threads) = usize::from_str_radix(&threads_str, 10) {
+      return threads;
+    }
+  }
+  num_cpus::get_physical()
+}
+
+#[cfg(target_env = "sgx")]
+fn max_concurrency() -> usize {
+  usize::from_str_radix(env!("TVM_NUM_THREADS"), 10).unwrap_or(1)
+}
+
+#[cfg(target_arch = "wasm32")]
+fn max_concurrency() -> usize {
+  0
+}
+
+#[cfg(target_env = "sgx")]
+#[no_mangle]
+pub extern "C" fn tvm_ecall_run_worker() {
+  if let Some(q) = SGX_QUEUES.lock().unwrap().pop_front() {
+    ThreadPool::run_worker(q);
+  }
+}
+
+#[no_mangle]
+pub extern "C" fn TVMBackendParallelLaunch(
+  cb: FTVMParallelLambda,
+  cdata: *const c_void,
+  num_task: usize,
+) -> c_int {
+  if max_concurrency() == 0 {
+    let penv = TVMParallelGroupEnv {
+      sync_handle: 0 as *mut c_void,
+      num_task: 1,
+    };
+    cb(0, &penv as *const _, cdata);
+
+    #[cfg(feature = "par-launch-alloc")]
+    let break_the_heap: Vec<u8> = Vec::new(); // TODO: why does allocating break?
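+    // NOTE: serial fallback. With no worker threads available (max_concurrency() == 0,
+    // e.g. on wasm32), the lambda has already run inline above with num_task = 1.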
+ } else { + THREAD_POOL.with(|pool| { + pool.launch(Job { + cb: cb, + cdata: cdata, + req_num_tasks: num_task, + pending: Arc::new(ATOMIC_USIZE_INIT), + }); + }); + } + return 0; +} + +#[no_mangle] +pub extern "C" fn TVMBackendParallelBarrier(_task_id: usize, penv: *const TVMParallelGroupEnv) { + let barrier: &Arc = unsafe { &*((*penv).sync_handle as *const Arc) }; + barrier.wait(); +} + +#[cfg(test)] +mod tests { + use std::{ptr, thread, time::Duration}; + + use super::*; + + #[test] + fn test_max_concurrency() { + env::set_var("TVM_NUM_THREADS", "42"); + env::set_var("OMP_NUM_THREADS", "24"); + assert_eq!(max_concurrency(), 42); + env::remove_var("TVM_NUM_THREADS"); + assert_eq!(max_concurrency(), 24); + } + + extern "C" fn flambda( + task_id: usize, + penv: *const TVMParallelGroupEnv, + cdata: *const c_void, + ) -> i32 { + if cdata == ptr::null() { + return 0; + } + unsafe { + let &(ref counter, ref task_ids_sum) = &*(cdata as *const (AtomicUsize, AtomicUsize)); + thread::sleep(Duration::from_millis(50 * task_id as u64)); + counter.fetch_add(1, Ordering::SeqCst); + task_ids_sum.fetch_add(task_id, Ordering::SeqCst); + assert_eq!((*penv).num_task, 3); + } + 0 + } + + #[test] + fn test_parallel_launch() { + TVMBackendParallelLaunch(flambda, ptr::null(), 6); + let counter = ATOMIC_USIZE_INIT; + let task_ids_sum = ATOMIC_USIZE_INIT; + let cdata = (counter, task_ids_sum); + let num_tasks = 3; + TVMBackendParallelLaunch(flambda, &cdata as *const _ as *const c_void, num_tasks); + assert_eq!(cdata.0.load(Ordering::SeqCst), num_tasks); + assert_eq!( + cdata.1.load(Ordering::SeqCst), + (0..num_tasks).sum::() + ); + } +} diff --git a/src/runtime/rust/src/runtime/workspace.rs b/src/runtime/rust/src/runtime/workspace.rs new file mode 100644 index 000000000000..fe9d8550a32c --- /dev/null +++ b/src/runtime/rust/src/runtime/workspace.rs @@ -0,0 +1,103 @@ +use std::{ + cell::RefCell, + os::raw::{c_int, c_void}, + ptr, +}; + +use super::allocator::Allocation; +use errors::*; + +struct WorkspacePool { + workspaces: Vec, + free: Vec, + in_use: Vec, +} + +impl WorkspacePool { + fn new() -> Self { + WorkspacePool { + workspaces: Vec::new(), + free: Vec::new(), + in_use: Vec::new(), + } + } + + fn alloc(&mut self, size: usize) -> Result<*mut u8> { + if self.free.len() == 0 { + self.workspaces.push(Allocation::new(size, None)?); + self.free.push(self.workspaces.len() - 1); + Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr()) + } else { + let i = self.free.iter().fold(0, |cur_ws_idx, &idx| { + let cur_size = self.workspaces[cur_ws_idx].size(); + let ws_size = self.workspaces[idx].size(); + if ws_size < size || ws_size > cur_size { + cur_ws_idx + } else { + idx + } + }); + let idx = self.free.remove(i); + self.in_use.push(idx.clone()); + Ok(self.workspaces[idx].as_mut_ptr()) + } + } + + fn free(&mut self, ptr: *mut u8) -> Result<()> { + let mut ws_idx = None; + for i in 0..self.in_use.len() { + let idx = self.in_use[i]; + if self.workspaces[idx].as_mut_ptr() == ptr { + self.in_use.remove(i); + ws_idx = Some(idx); + break; + } + } + Ok( + self + .free + .push(ws_idx.ok_or("Tried to free nonexistent workspace.")?), + ) + } +} + +thread_local!(static WORKSPACE_POOL: RefCell = RefCell::new(WorkspacePool::new())); + +const WORKSPACE_PAGE_SIZE: usize = 4 << 10; + +#[no_mangle] +pub extern "C" fn TVMBackendAllocWorkspace( + _device_type: c_int, + _device_id: c_int, + size: u64, + _dtype_code_hint: c_int, + _dtype_bits_hint: c_int, +) -> *mut c_void { + let nbytes = if size == 0 { + WORKSPACE_PAGE_SIZE + } else 
{
+    size as usize
+  };
+  // Return the pointer produced by the pool (or null on allocation failure).
+  WORKSPACE_POOL.with(|pool_cell| {
+    (match pool_cell.borrow_mut().alloc(nbytes as usize) {
+      Ok(ptr) => ptr,
+      Err(_) => ptr::null_mut(),
+    }) as *mut c_void
+  })
+}
+
+#[no_mangle]
+pub extern "C" fn TVMBackendFreeWorkspace(
+  _device_type: c_int,
+  _device_id: c_int,
+  ptr: *mut c_void,
+) -> c_int {
+  // Return 0 on success, -1 if the pointer does not belong to the pool.
+  WORKSPACE_POOL.with(|pool_cell| {
+    (match pool_cell.borrow_mut().free(ptr as *mut u8) {
+      Ok(()) => 0,
+      Err(_) => -1,
+    }) as c_int
+  })
+}
diff --git a/src/runtime/rust/tests/.gitignore b/src/runtime/rust/tests/.gitignore
new file mode 100644
index 000000000000..811076739bfa
--- /dev/null
+++ b/src/runtime/rust/tests/.gitignore
@@ -0,0 +1,3 @@
+*.json
+*.params
+*.o
diff --git a/src/runtime/rust/tests/test_graph_serde.rs b/src/runtime/rust/tests/test_graph_serde.rs
new file mode 100644
index 000000000000..a3679812e3eb
--- /dev/null
+++ b/src/runtime/rust/tests/test_graph_serde.rs
@@ -0,0 +1,39 @@
+#![feature(fs_read_write, try_from)]
+
+extern crate serde;
+extern crate serde_json;
+
+extern crate tvm;
+
+use std::{collections::HashMap, convert::TryFrom, fs, io::Read};
+
+use tvm::runtime::Graph;
+
+#[test]
+fn test_load_graph() {
+  let mut params_bytes = Vec::new();
+  fs::File::open(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.params"))
+    .unwrap()
+    .read_to_end(&mut params_bytes)
+    .unwrap();
+  let params = tvm::runtime::load_param_dict(&params_bytes);
+
+  let graph =
+    Graph::try_from(
+      &fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.json")).unwrap(),
+    ).unwrap();
+
+  assert_eq!(graph.nodes[3].op, "tvm_op");
+  assert_eq!(
+    graph.nodes[3]
+      .attrs
+      .as_ref()
+      .unwrap()
+      .get("func_name")
+      .unwrap(),
+    "fuse_dense"
+  );
+  assert_eq!(graph.nodes[5].inputs[0].index, 0);
+  assert_eq!(graph.nodes[6].inputs[0].index, 1);
+  assert_eq!(graph.heads.len(), 2);
+}
diff --git a/src/runtime/rust/tests/test_nnvm/Cargo.toml b/src/runtime/rust/tests/test_nnvm/Cargo.toml
new file mode 100644
index 000000000000..978bdf9a428f
--- /dev/null
+++ b/src/runtime/rust/tests/test_nnvm/Cargo.toml
@@ -0,0 +1,17 @@
+[package]
+name = "test-nnvm"
+version = "0.0.0"
+license = "Apache-2.0"
+authors = ["Nick Hynes "]
+
+[features]
+par-launch-alloc = ["tvm/par-launch-alloc"]
+
+[dependencies]
+ndarray = "0.11.2"
+tvm = { path = "../../" }
+serde = "1.0.59"
+serde_json = "1.0.17"
+
+[build-dependencies]
+ar = "0.6.0"
diff --git a/src/runtime/rust/tests/test_nnvm/src/main.rs b/src/runtime/rust/tests/test_nnvm/src/main.rs
new file mode 100644
index 000000000000..9fc5ba5d3537
--- /dev/null
+++ b/src/runtime/rust/tests/test_nnvm/src/main.rs
@@ -0,0 +1,84 @@
+#![feature(fs_read_write, try_from)]
+
+#[macro_use]
+extern crate ndarray;
+extern crate serde;
+extern crate serde_json;
+
+extern crate tvm;
+use std::{collections::HashMap, convert::TryFrom, fs, io::Read};
+
+use ndarray::Array;
+use tvm::runtime::{Graph, GraphExecutor, SystemLibModule, Tensor};
+
+const BATCH_SIZE: usize = 4;
+const IN_DIM: usize = 8;
+
+macro_rules!
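+// Compares two arrays by their scalar sums to within 1e-2 — a cheap whole-tensor
+// equality check used for the graph inputs and outputs below.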
assert_sum_eq { + ($a:expr, $b:expr) => { + let a_sum = $a.scalar_sum(); + let b_sum = $b.scalar_sum(); + assert!((a_sum - b_sum).abs() < 1e-2, "{} != {}", a_sum, b_sum); + }; +} + +fn main() { + let syslib = SystemLibModule::default(); + + let mut params_bytes = Vec::new(); + fs::File::open(concat!(env!("OUT_DIR"), "/graph.params")) + .unwrap() + .read_to_end(&mut params_bytes) + .unwrap(); + let params = tvm::runtime::load_param_dict(¶ms_bytes) + .unwrap() + .into_iter() + .map(|(k, v)| (k, v.to_owned())) + .collect::>>(); + + let graph = + Graph::try_from(&fs::read_to_string(concat!(env!("OUT_DIR"), "/graph.json")).unwrap()).unwrap(); + let mut exec = GraphExecutor::new(graph, &syslib).unwrap(); + + let x = Array::from_shape_vec( + (BATCH_SIZE, IN_DIM), + (0..BATCH_SIZE * IN_DIM) + .map(|x| x as f32) + .collect::>(), + ).unwrap(); + let w = Array::try_from(params.get("dense0_weight").unwrap()) + .unwrap() + .into_shape((IN_DIM * 2, IN_DIM)) + .unwrap(); + let b = Array::try_from(params.get("dense0_bias").unwrap()).unwrap(); + let dense = x.dot(&w.t()) + &b; + let left = dense.slice(s![.., 0..IN_DIM]); + let right = dense.slice(s![.., IN_DIM..]); + let expected_o0 = &left + 1f32; + let expected_o1 = &right - 1f32; + + exec.load_params(params); + exec.set_input("data", x.clone().into()); + + assert_sum_eq!(Array::try_from(exec.get_input("data").unwrap()).unwrap(), x); + assert_sum_eq!( + Array::try_from(exec.get_input("dense0_weight").unwrap()).unwrap(), + w + ); + assert_sum_eq!( + Array::try_from(exec.get_input("dense0_bias").unwrap()).unwrap(), + b + ); + + exec.run(); + + assert_sum_eq!( + Array::try_from(exec.get_output(0).unwrap()).unwrap(), + expected_o0 + ); + assert_sum_eq!( + Array::try_from(exec.get_output(1).unwrap()).unwrap(), + expected_o1 + ); + assert_sum_eq!(Array::try_from(exec.get_output(2).unwrap()).unwrap(), dense); +} diff --git a/src/runtime/rust/tests/test_tvm_basic/Cargo.toml b/src/runtime/rust/tests/test_tvm_basic/Cargo.toml new file mode 100644 index 000000000000..bd4193bcb8fb --- /dev/null +++ b/src/runtime/rust/tests/test_tvm_basic/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "test-tvm-basic" +version = "0.0.0" +license = "Apache-2.0" +authors = ["Nick Hynes "] + +[dependencies] +ndarray = "0.11.2" +tvm = { path = "../../" } + +[build-dependencies] +ar = "0.6.0" diff --git a/src/runtime/rust/tests/test_tvm_basic/src/main.rs b/src/runtime/rust/tests/test_tvm_basic/src/main.rs new file mode 100644 index 000000000000..b6c11451d12a --- /dev/null +++ b/src/runtime/rust/tests/test_tvm_basic/src/main.rs @@ -0,0 +1,25 @@ +extern crate ndarray; +#[macro_use] +extern crate tvm; + +use ndarray::Array; +use tvm::{ + ffi::runtime::DLTensor, + runtime::{Module, SystemLibModule}, +}; + +fn main() { + let syslib = SystemLibModule::default(); + let add = syslib + .get_function("default_function") + .expect("main function not found"); + let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]); + let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]); + let mut c = Array::from_vec(vec![0f32; 4]); + let e = Array::from_vec(vec![2f32, 2., 4., 4.]); + let mut a_dl: DLTensor = (&mut a).into(); + let mut b_dl: DLTensor = (&mut b).into(); + let mut c_dl: DLTensor = (&mut c).into(); + call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl); + assert!(c.all_close(&e, 1e-8f32)); +} From 81bcb88a6643e503ae2cad58cd960fee5a1d74b8 Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Tue, 14 Aug 2018 04:07:20 +0000 Subject: [PATCH 02/10] Move files --- {src/runtime/rust => rust}/.gitignore | 0 
{src/runtime/rust => rust}/.rustfmt.toml | 0 {src/runtime/rust => rust}/.travis.yml | 0 {src/runtime/rust => rust}/Cargo.toml | 0 {src/runtime/rust => rust}/src/errors.rs | 0 {src/runtime/rust => rust}/src/lib.rs | 0 {src/runtime/rust => rust}/src/runtime/allocator.rs | 0 {src/runtime/rust => rust}/src/runtime/array.rs | 0 {src/runtime/rust => rust}/src/runtime/graph.rs | 0 {src/runtime/rust => rust}/src/runtime/mod.rs | 0 {src/runtime/rust => rust}/src/runtime/module.rs | 0 {src/runtime/rust => rust}/src/runtime/packed_func.rs | 0 {src/runtime/rust => rust}/src/runtime/threading.rs | 0 {src/runtime/rust => rust}/src/runtime/workspace.rs | 0 {src/runtime/rust => rust}/tests/.gitignore | 0 {src/runtime/rust => rust}/tests/test_graph_serde.rs | 0 {src/runtime/rust => rust}/tests/test_nnvm/Cargo.toml | 0 {src/runtime/rust => rust}/tests/test_nnvm/src/main.rs | 0 {src/runtime/rust => rust}/tests/test_tvm_basic/Cargo.toml | 0 {src/runtime/rust => rust}/tests/test_tvm_basic/src/main.rs | 0 20 files changed, 0 insertions(+), 0 deletions(-) rename {src/runtime/rust => rust}/.gitignore (100%) rename {src/runtime/rust => rust}/.rustfmt.toml (100%) rename {src/runtime/rust => rust}/.travis.yml (100%) rename {src/runtime/rust => rust}/Cargo.toml (100%) rename {src/runtime/rust => rust}/src/errors.rs (100%) rename {src/runtime/rust => rust}/src/lib.rs (100%) rename {src/runtime/rust => rust}/src/runtime/allocator.rs (100%) rename {src/runtime/rust => rust}/src/runtime/array.rs (100%) rename {src/runtime/rust => rust}/src/runtime/graph.rs (100%) rename {src/runtime/rust => rust}/src/runtime/mod.rs (100%) rename {src/runtime/rust => rust}/src/runtime/module.rs (100%) rename {src/runtime/rust => rust}/src/runtime/packed_func.rs (100%) rename {src/runtime/rust => rust}/src/runtime/threading.rs (100%) rename {src/runtime/rust => rust}/src/runtime/workspace.rs (100%) rename {src/runtime/rust => rust}/tests/.gitignore (100%) rename {src/runtime/rust => rust}/tests/test_graph_serde.rs (100%) rename {src/runtime/rust => rust}/tests/test_nnvm/Cargo.toml (100%) rename {src/runtime/rust => rust}/tests/test_nnvm/src/main.rs (100%) rename {src/runtime/rust => rust}/tests/test_tvm_basic/Cargo.toml (100%) rename {src/runtime/rust => rust}/tests/test_tvm_basic/src/main.rs (100%) diff --git a/src/runtime/rust/.gitignore b/rust/.gitignore similarity index 100% rename from src/runtime/rust/.gitignore rename to rust/.gitignore diff --git a/src/runtime/rust/.rustfmt.toml b/rust/.rustfmt.toml similarity index 100% rename from src/runtime/rust/.rustfmt.toml rename to rust/.rustfmt.toml diff --git a/src/runtime/rust/.travis.yml b/rust/.travis.yml similarity index 100% rename from src/runtime/rust/.travis.yml rename to rust/.travis.yml diff --git a/src/runtime/rust/Cargo.toml b/rust/Cargo.toml similarity index 100% rename from src/runtime/rust/Cargo.toml rename to rust/Cargo.toml diff --git a/src/runtime/rust/src/errors.rs b/rust/src/errors.rs similarity index 100% rename from src/runtime/rust/src/errors.rs rename to rust/src/errors.rs diff --git a/src/runtime/rust/src/lib.rs b/rust/src/lib.rs similarity index 100% rename from src/runtime/rust/src/lib.rs rename to rust/src/lib.rs diff --git a/src/runtime/rust/src/runtime/allocator.rs b/rust/src/runtime/allocator.rs similarity index 100% rename from src/runtime/rust/src/runtime/allocator.rs rename to rust/src/runtime/allocator.rs diff --git a/src/runtime/rust/src/runtime/array.rs b/rust/src/runtime/array.rs similarity index 100% rename from src/runtime/rust/src/runtime/array.rs 
rename to rust/src/runtime/array.rs diff --git a/src/runtime/rust/src/runtime/graph.rs b/rust/src/runtime/graph.rs similarity index 100% rename from src/runtime/rust/src/runtime/graph.rs rename to rust/src/runtime/graph.rs diff --git a/src/runtime/rust/src/runtime/mod.rs b/rust/src/runtime/mod.rs similarity index 100% rename from src/runtime/rust/src/runtime/mod.rs rename to rust/src/runtime/mod.rs diff --git a/src/runtime/rust/src/runtime/module.rs b/rust/src/runtime/module.rs similarity index 100% rename from src/runtime/rust/src/runtime/module.rs rename to rust/src/runtime/module.rs diff --git a/src/runtime/rust/src/runtime/packed_func.rs b/rust/src/runtime/packed_func.rs similarity index 100% rename from src/runtime/rust/src/runtime/packed_func.rs rename to rust/src/runtime/packed_func.rs diff --git a/src/runtime/rust/src/runtime/threading.rs b/rust/src/runtime/threading.rs similarity index 100% rename from src/runtime/rust/src/runtime/threading.rs rename to rust/src/runtime/threading.rs diff --git a/src/runtime/rust/src/runtime/workspace.rs b/rust/src/runtime/workspace.rs similarity index 100% rename from src/runtime/rust/src/runtime/workspace.rs rename to rust/src/runtime/workspace.rs diff --git a/src/runtime/rust/tests/.gitignore b/rust/tests/.gitignore similarity index 100% rename from src/runtime/rust/tests/.gitignore rename to rust/tests/.gitignore diff --git a/src/runtime/rust/tests/test_graph_serde.rs b/rust/tests/test_graph_serde.rs similarity index 100% rename from src/runtime/rust/tests/test_graph_serde.rs rename to rust/tests/test_graph_serde.rs diff --git a/src/runtime/rust/tests/test_nnvm/Cargo.toml b/rust/tests/test_nnvm/Cargo.toml similarity index 100% rename from src/runtime/rust/tests/test_nnvm/Cargo.toml rename to rust/tests/test_nnvm/Cargo.toml diff --git a/src/runtime/rust/tests/test_nnvm/src/main.rs b/rust/tests/test_nnvm/src/main.rs similarity index 100% rename from src/runtime/rust/tests/test_nnvm/src/main.rs rename to rust/tests/test_nnvm/src/main.rs diff --git a/src/runtime/rust/tests/test_tvm_basic/Cargo.toml b/rust/tests/test_tvm_basic/Cargo.toml similarity index 100% rename from src/runtime/rust/tests/test_tvm_basic/Cargo.toml rename to rust/tests/test_tvm_basic/Cargo.toml diff --git a/src/runtime/rust/tests/test_tvm_basic/src/main.rs b/rust/tests/test_tvm_basic/src/main.rs similarity index 100% rename from src/runtime/rust/tests/test_tvm_basic/src/main.rs rename to rust/tests/test_tvm_basic/src/main.rs From f57d622d228cc16b7208a72eb33e6eb92d43f7be Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Tue, 14 Aug 2018 21:14:38 +0000 Subject: [PATCH 03/10] Re-add build script --- .gitignore | 1 - rust/build.rs | 47 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) create mode 100644 rust/build.rs diff --git a/.gitignore b/.gitignore index 833eee1a0774..5816bcfeabd6 100644 --- a/.gitignore +++ b/.gitignore @@ -187,7 +187,6 @@ tvm_u.* tvm_t.* # Mac OS X .DS_Store -build* # Jetbrain .idea diff --git a/rust/build.rs b/rust/build.rs new file mode 100644 index 000000000000..f21c9e0c2c1d --- /dev/null +++ b/rust/build.rs @@ -0,0 +1,47 @@ +extern crate bindgen; + +use std::{env, path::PathBuf}; + +fn parse_clang_ver(raw_v: String) -> Vec { + raw_v + .split_whitespace() + .nth(2) + .unwrap() + .split('.') + .map(|v| v.parse::().unwrap()) + .collect() +} + +fn main() { + let clang_ver = parse_clang_ver(bindgen::clang_version().full); + let bindings = bindgen::Builder::default() + .header(concat!( + env!("CARGO_MANIFEST_DIR"), + 
"/../include/tvm/runtime/c_runtime_api.h" + )) + .header(concat!( + env!("CARGO_MANIFEST_DIR"), + "/../include/tvm/runtime/c_backend_api.h" + )) + .rust_target(bindgen::RustTarget::Nightly) + .clang_arg(concat!( + "-I", + env!("CARGO_MANIFEST_DIR"), + "/../dlpack/include" + )) + .clang_arg(format!("--target={}", env::var("HOST").unwrap())) + .clang_arg("-I/usr/include") + .clang_arg("-I/usr/local/include") + .clang_arg(format!( + "-I/usr/local/lib/clang/{}.{}.{}/include", + clang_ver[0], clang_ver[1], clang_ver[2] + )) + .layout_tests(false) + .generate() + .expect("Unable to generate bindings."); + + let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); + bindings + .write_to_file(out_path.join("c_runtime_api.rs")) + .expect("Unable to write bindings."); +} From f01824473cd80c40fcddeab1c61b342285459b65 Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Sun, 19 Aug 2018 23:15:22 +0000 Subject: [PATCH 04/10] Add now un-ignored files to tests --- .gitignore | 2 - rust/tests/build_model.py | 53 ++++++++++++++++ rust/tests/test_nnvm/build.rs | 28 +++++++++ rust/tests/test_nnvm/src/build_test_graph.py | 63 +++++++++++++++++++ rust/tests/test_tvm_basic/build.rs | 28 +++++++++ .../test_tvm_basic/src/build_test_lib.py | 21 +++++++ 6 files changed, 193 insertions(+), 2 deletions(-) create mode 100644 rust/tests/build_model.py create mode 100644 rust/tests/test_nnvm/build.rs create mode 100755 rust/tests/test_nnvm/src/build_test_graph.py create mode 100644 rust/tests/test_tvm_basic/build.rs create mode 100755 rust/tests/test_tvm_basic/src/build_test_lib.py diff --git a/.gitignore b/.gitignore index 5816bcfeabd6..368764941cec 100644 --- a/.gitignore +++ b/.gitignore @@ -91,10 +91,8 @@ ENV/ *~ *.pyc *~ -build config.mk config.cmake -build_* Win32 *.dir perf diff --git a/rust/tests/build_model.py b/rust/tests/build_model.py new file mode 100644 index 000000000000..e0b90495159f --- /dev/null +++ b/rust/tests/build_model.py @@ -0,0 +1,53 @@ +"""Builds a simple NNVM graph for testing.""" + +from os import path as osp + +import nnvm +from nnvm import sym +from nnvm.compiler import graph_util +from nnvm.testing import init +import numpy as np +import tvm + +CWD = osp.dirname(osp.abspath(osp.expanduser(__file__))) + + +def _get_model(dshape): + data = sym.Variable('data', shape=dshape) + fc1 = sym.dense(data, units=dshape[-1]*2, use_bias=True) + left, right = sym.split(fc1, indices_or_sections=2, axis=1) + return sym.Group(((left + 1), (right - 1))) + + +def _init_params(graph, input_shapes, initializer=init.Xavier(), seed=10): + if isinstance(graph, sym.Symbol): + graph = nnvm.graph.create(graph) + ishapes, _ = graph_util.infer_shape(graph, **input_shapes) + param_shapes = dict(zip(graph.index.input_names, ishapes)) + np.random.seed(seed) + params = {} + for param, shape in param_shapes.items(): + if param in {'data', 'label'} or not shape: + continue + init_value = np.empty(shape).astype('float32') + initializer(param, init_value) + params[param] = tvm.nd.array(init_value) + return params + +def main(): + dshape = (32, 16) + net = _get_model(dshape) + ishape_dict = {'data': dshape} + params = _init_params(net, ishape_dict) + graph, lib, params = nnvm.compiler.build(net, 'llvm', + shape=ishape_dict, + params=params, + dtype='float32') + + with open(osp.join(CWD, 'graph.json'), 'w') as f_resnet: + f_resnet.write(graph.json()) + with open(osp.join(CWD, 'graph.params'), 'wb') as f_params: + f_params.write(nnvm.compiler.save_param_dict(params)) + +if __name__ == '__main__': + main() diff --git 
a/rust/tests/test_nnvm/build.rs b/rust/tests/test_nnvm/build.rs new file mode 100644 index 000000000000..cb3a4e0d574d --- /dev/null +++ b/rust/tests/test_nnvm/build.rs @@ -0,0 +1,28 @@ +extern crate ar; + +use std::{env, path::PathBuf, process::Command}; + +use ar::Builder; +use std::fs::File; + +fn main() { + let out_dir = env::var("OUT_DIR").unwrap(); + + let output = Command::new(concat!( + env!("CARGO_MANIFEST_DIR"), + "/src/build_test_graph.py" + )).arg(&out_dir) + .output() + .expect("Failed to execute command"); + if output.stderr.len() > 0 { + panic!(String::from_utf8(output.stderr).unwrap()); + } + + let in_path: PathBuf = [&out_dir, "graph.o"].iter().collect(); + let out_path: PathBuf = [&out_dir, "libgraph.a"].iter().collect(); + let mut builder = Builder::new(File::create(out_path.to_str().unwrap()).unwrap()); + builder.append_path(in_path.to_str().unwrap()).unwrap(); + + println!("cargo:rustc-link-lib=static=graph"); + println!("cargo:rustc-link-search=native={}", out_dir); +} diff --git a/rust/tests/test_nnvm/src/build_test_graph.py b/rust/tests/test_nnvm/src/build_test_graph.py new file mode 100755 index 000000000000..429cc2128931 --- /dev/null +++ b/rust/tests/test_nnvm/src/build_test_graph.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 + +"""Builds a simple NNVM graph for testing.""" + +from os import path as osp +import sys + +import nnvm +from nnvm import sym +from nnvm.compiler import graph_util +from nnvm.testing import init +import numpy as np +import tvm + + +def _get_model(dshape): + data = sym.Variable('data', shape=dshape) + fc = sym.dense(data, units=dshape[-1]*2, use_bias=True) + left, right = sym.split(fc, indices_or_sections=2, axis=1) + return sym.Group(((left + 1), (right - 1), fc)) + + +def _init_params(graph, input_shapes, initializer=init.Xavier(), seed=10): + if isinstance(graph, sym.Symbol): + graph = nnvm.graph.create(graph) + ishapes, _ = graph_util.infer_shape(graph, **input_shapes) + param_shapes = dict(zip(graph.index.input_names, ishapes)) + np.random.seed(seed) + params = {} + for param, shape in param_shapes.items(): + if param in {'data', 'label'} or not shape: + continue + + init_value = np.arange(np.product(shape), 0, -1).reshape(*shape).astype('float32') + if param.endswith('_bias'): + params[param] = tvm.nd.array(init_value) + continue + + init_value = np.empty(shape).astype('float32') + initializer(param, init_value) + # init_value /= init_value.sum() + 1e-10 + params[param] = tvm.nd.array(init_value) + return params + +def main(): + dshape = (4, 8) + net = _get_model(dshape) + ishape_dict = {'data': dshape} + params = _init_params(net, ishape_dict) + graph, lib, params = nnvm.compiler.build(net, 'llvm --system-lib', + shape=ishape_dict, + params=params, + dtype='float32') + + out_dir = sys.argv[1] + lib.save(osp.join(sys.argv[1], 'graph.o')) + with open(osp.join(out_dir, 'graph.json'), 'w') as f_resnet: + f_resnet.write(graph.json()) + with open(osp.join(out_dir, 'graph.params'), 'wb') as f_params: + f_params.write(nnvm.compiler.save_param_dict(params)) + +if __name__ == '__main__': + main() diff --git a/rust/tests/test_tvm_basic/build.rs b/rust/tests/test_tvm_basic/build.rs new file mode 100644 index 000000000000..778dd1cab1ca --- /dev/null +++ b/rust/tests/test_tvm_basic/build.rs @@ -0,0 +1,28 @@ +extern crate ar; + +use std::{env, path::PathBuf, process::Command}; + +use ar::Builder; +use std::fs::File; + +fn main() { + let out_dir = env::var("OUT_DIR").unwrap(); + + let output = Command::new(concat!( + env!("CARGO_MANIFEST_DIR"), + 
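+    // runs the Python script below, which compiles and saves the test kernels into $OUT_DIR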
"/src/build_test_lib.py" + )).arg(&out_dir) + .output() + .expect("Failed to execute command"); + if output.stderr.len() > 0 { + panic!(String::from_utf8(output.stderr).unwrap()); + } + + let in_path: PathBuf = [&out_dir, "test.o"].iter().collect(); + let out_path: PathBuf = [&out_dir, "libtest.a"].iter().collect(); + let mut builder = Builder::new(File::create(out_path.to_str().unwrap()).unwrap()); + builder.append_path(in_path.to_str().unwrap()).unwrap(); + + println!("cargo:rustc-link-lib=static=test"); + println!("cargo:rustc-link-search=native={}", out_dir); +} diff --git a/rust/tests/test_tvm_basic/src/build_test_lib.py b/rust/tests/test_tvm_basic/src/build_test_lib.py new file mode 100755 index 000000000000..7289a778fcec --- /dev/null +++ b/rust/tests/test_tvm_basic/src/build_test_lib.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python3 + +"""Prepares a simple TVM library for testing.""" + +from os import path as osp +import sys + +import tvm + +def main(): + n = tvm.var('n') + A = tvm.placeholder((n,), name='A') + B = tvm.placeholder((n,), name='B') + C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') + s = tvm.create_schedule(C.op) + s[C].parallel(s[C].op.axis[0]) + print(tvm.lower(s, [A, B, C], simple_mode=True)) + tvm.build(s, [A, B, C], 'llvm --system-lib').save(osp.join(sys.argv[1], 'test.o')) + +if __name__ == '__main__': + main() From 72aa38b93c72d75cd0b498af66e1c1883bdc1c03 Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Sun, 19 Aug 2018 23:41:38 +0000 Subject: [PATCH 05/10] Add docs --- rust/src/runtime/allocator.rs | 3 ++ rust/src/runtime/array.rs | 64 ++++++++++++++++++++++++++----- rust/src/runtime/graph.rs | 67 ++++++++++++++++++++++++++++----- rust/src/runtime/packed_func.rs | 27 +++++++++++++ rust/src/runtime/threading.rs | 9 ++++- 5 files changed, 149 insertions(+), 21 deletions(-) diff --git a/rust/src/runtime/allocator.rs b/rust/src/runtime/allocator.rs index d3dc772cd4cb..9b029e80946f 100644 --- a/rust/src/runtime/allocator.rs +++ b/rust/src/runtime/allocator.rs @@ -11,6 +11,7 @@ pub struct Allocation { } impl Allocation { + /// Allocates a chunk of memory of `size` bytes with optional alignment. pub fn new(size: usize, align: Option) -> Result { let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES); let layout = Layout::from_size_align(size, alignment)?; @@ -28,10 +29,12 @@ impl Allocation { self.ptr } + /// Returns the size of the Allocation in bytes. pub fn size(&self) -> usize { self.layout.size() } + /// Returns the byte alignment of the Allocation. pub fn align(&self) -> usize { self.layout.align() } diff --git a/rust/src/runtime/array.rs b/rust/src/runtime/array.rs index dff63fa92bb1..9fec9ef7b1ef 100644 --- a/rust/src/runtime/array.rs +++ b/rust/src/runtime/array.rs @@ -16,9 +16,13 @@ use ffi::runtime::{ DLDeviceType_kDLCPU, DLTensor, }; +/// A `Storage` is a container which holds `Tensor` data. #[derive(PartialEq)] pub enum Storage<'a> { + /// A `Storage` which owns its contained bytes. Owned(Allocation), + + /// A view of an existing `Storage`. View(&'a mut [u8], usize), // ptr, align } @@ -52,6 +56,7 @@ impl<'a> Storage<'a> { self.as_mut_ptr() as *const _ } + /// Returns a `Storage::View` which points to an owned `Storage::Owned`. pub fn view(&self) -> Storage<'a> { match self { Storage::Owned(alloc) => Storage::View( @@ -72,6 +77,7 @@ impl<'a> Storage<'a> { } } + /// Returns an owned version of this storage via cloning. 
pub fn to_owned(&self) -> Storage<'static> { let s = Storage::new(self.size(), Some(self.align())).unwrap(); unsafe { @@ -94,24 +100,48 @@ impl<'a, T> From<&'a [T]> for Storage<'a> { } } +/// An n-dimensional array type which can be converted to/from `tvm::DLTensor` and `ndarray::Array`. +/// `Tensor` is primarily a holder of data which can be operated on via TVM (via `DLTensor`) or +/// converted to `ndarray::Array` for non-TVM processing. +/// +/// # Examples +/// +/// ``` +/// extern crate ndarray; +/// +/// let mut a_nd: ndarray::Array<f32, ndarray::Ix1> = ndarray::Array::from_vec(vec![1f32, 2., 3., 4.]); +/// let mut a: Tensor = a_nd.into(); +/// let mut a_dl: DLTensor = (&mut a).into(); +/// call_packed!(tvm_fn, &mut a_dl); +/// +/// // Tensor -> Array is mostly useful when post-processing TVM graph outputs. +/// let mut a_nd = ndarray::Array::try_from(&a).unwrap(); +/// ``` #[derive(PartialEq)] pub struct Tensor<'a> { + /// The bytes which contain the data this `Tensor` represents. pub(super) data: Storage<'a>, pub(super) ctx: TVMContext, pub(super) dtype: DataType, - pub(super) shape: Vec<usize>, + pub(super) shape: Vec<i64>, // not usize because `typedef int64_t tvm_index_t` in c_runtime_api.h + /// The `Tensor` strides. Can be `None` if the `Tensor` is contiguous. pub(super) strides: Option<Vec<usize>>, pub(super) byte_offset: isize, pub(super) numel: usize, - pub(super) dshape: Vec<i64>, } impl<'a> Tensor<'a> { - pub fn shape(&self) -> Vec<usize> { + pub fn shape(&self) -> Vec<i64> { self.shape.clone() } + /// Returns the data of this `Tensor` as a `Vec<T>`. + /// + /// # Panics + /// + /// Panics if the `Tensor` is not contiguous or does not contain elements of type `T`. pub fn to_vec<T>(&self) -> Vec<T> { + assert!(self.is_contiguous()); assert!(self.dtype.is_type::<T>()); let mut vec: Vec<T> = Vec::with_capacity(self.numel * self.dtype.itemsize()); unsafe { @@ -124,10 +154,12 @@ impl<'a> Tensor<'a> { vec } + /// Returns `true` iff this `Tensor` is represented by a contiguous region of memory. pub fn is_contiguous(&self) -> bool { match self.strides { None => true, Some(ref strides) => { + // check that stride for each dimension is the product of all trailing dimensions' shapes self .shape .iter() @@ -137,7 +169,7 @@ impl<'a> Tensor<'a> { |(is_contig, expected_stride), (shape, stride)| { ( is_contig && *stride == expected_stride, - expected_stride * shape, + expected_stride * (*shape as usize), ) }, ) @@ -146,6 +178,11 @@ impl<'a> Tensor<'a> { } } + /// Copies the data from `other` into this `Tensor`. + /// + /// # Panics + /// + /// Panics if the `Tensor`s do not have the same size and dtype, or if either is not contiguous. pub fn copy(&mut self, other: &Tensor) { assert!( self.dtype == other.dtype && self.numel == other.numel, "Tensor shape/dtype mismatch." ); assert!( @@ -169,6 +206,7 @@ impl<'a> Tensor<'a> { } } + /// Returns an owned version of this `Tensor` via cloning. pub fn to_owned(&self) -> Tensor<'static> { let t = Tensor { data: self.data.to_owned(), @@ -178,7 +216,6 @@ impl<'a> Tensor<'a> { shape: self.shape.clone(), strides: None, byte_offset: 0, - dshape: self.dshape.clone(), }; unsafe { mem::transmute::<Tensor<'a>, Tensor<'static>>(t) } } } @@ -193,7 +230,7 @@ impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<f32> { type Error = Error; fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<f32>> { ensure!( tensor.dtype == DTYPE_FLOAT32, "Cannot convert Tensor with dtype {:?} to ndarray", tensor.dtype ); Ok(ndarray::Array::from_shape_vec( - tensor.shape.clone(), + tensor.shape.iter().map(|s| *s as usize).collect::<Vec<usize>>(), tensor.to_vec::<f32>(), )?)
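+      // (shape elements are i64 per `tvm_index_t`, hence the cast to usize for ndarray)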
} @@ -210,8 +247,7 @@ impl DLTensor { shape: if flatten { &tensor.numel as *const _ as *mut i64 } else { - // tensor.shape.as_ptr() - tensor.dshape.as_ptr() as *mut i64 + tensor.shape.as_ptr() } as *mut i64, strides: if flatten || tensor.is_contiguous() { ptr::null_mut() @@ -243,10 +279,12 @@ pub struct DataType { } impl DataType { + /// Returns the number of bytes occupied by an element of this `DataType`. fn itemsize(&self) -> usize { (self.bits * self.lanes) >> 3 } + /// Returns whether this `DataType` represents primitive type `T`. fn is_type<T>(&self) -> bool { if self.lanes != 1 { return false; } @@ -325,13 +363,17 @@ fn tensor_from_array_storage<'a, 's, T, D: ndarray::Dimension>( lanes: 1, }, numel: arr.len(), - shape: arr.shape().iter().map(|&v| v as usize).collect(), + shape: arr.shape().iter().map(|&v| v as i64).collect(), strides: Some(arr.strides().into_iter().map(|&v| v as usize).collect()), byte_offset: 0, - dshape: arr.shape().iter().map(|&v| v as i64).collect(), } } +/// `From` conversions to `Tensor` for owned or borrowed `ndarray::Array`. +/// +/// # Panics +/// +/// Panics if the ndarray is not contiguous. macro_rules! impl_tensor_from_ndarray { ($type:ty, $typecode:expr) => { impl<D: ndarray::Dimension> From<ndarray::Array<$type, D>> for Tensor<'static> { fn from(arr: ndarray::Array<$type, D>) -> Self { @@ -356,6 +398,8 @@ macro_rules! impl_tensor_from_ndarray { }; } +/// `From` conversions to `DLTensor` for `ndarray::Array`. +/// Takes a reference to the `ndarray` since `DLTensor` is not owned. macro_rules! impl_dltensor_from_ndarray { ($type:ty, $typecode:expr) => { impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor { diff --git a/rust/src/runtime/graph.rs b/rust/src/runtime/graph.rs index 0289318d6574..edf1c77ebb52 100644 --- a/rust/src/runtime/graph.rs +++ b/rust/src/runtime/graph.rs @@ -10,9 +10,19 @@ use ffi::runtime::{ DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt, DLTensor, }; -const NDARRAY_MAGIC: u64 = 0xDD5E40F096B4A13F; // Magic number for NDArray file -const NDARRAY_LIST_MAGIC: u64 = 0xF7E58D4F05049CB7; // Magic number for NDArray list file - +// Magic number for NDArray file. @see `kTVMNDArrayMagic` in `ndarray.h` +const NDARRAY_MAGIC: u64 = 0xDD5E40F096B4A13F; +// Magic number for NDArray list file. @see `kTVMNDArrayListMagic` in `graph_runtime.h` +const NDARRAY_LIST_MAGIC: u64 = 0xF7E58D4F05049CB7; + +/// A TVM computation graph. +/// +/// # Examples +/// +/// ``` +/// let graph_json = fs::read_to_string("graph.json").unwrap(); +/// let graph = Graph::try_from(&graph_json).unwrap(); +/// ``` #[derive(Serialize, Deserialize, Debug)] pub struct Graph { pub nodes: Vec<Node>, @@ -38,6 +48,7 @@ impl Graph { .ok_or("Missing node_row_ptr.".into()) } + /// Attempt to deserialize a JSON attribute to a type `T`. fn get_attr<T: DeserializeOwned>(&self, attr: &str) -> Result<T> { Ok(serde_json::from_value::<T>( self @@ -116,6 +127,31 @@ impl<'a> TryFrom<&'a str> for Graph { } } +/// An executor for a TVM computation graph.
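+/// Pairs a deserialized `Graph` with the `Module` that provides its fused kernels,
+/// preallocating storage and wiring node inputs to outputs.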
+/// +/// # Examples +/// +/// ``` +/// use ndarray::Array; +/// let syslib = SystemLibModule::default(); // a provider of TVM functions +/// +/// let mut params_bytes = Vec::new(); +/// fs::File::open("graph.params").unwrap().read_to_end(&mut params_bytes).unwrap(); +/// let params = tvm::runtime::load_param_dict(&params_bytes).unwrap(); +/// +/// let graph = Graph::try_from(&fs::read_to_string("graph.json").unwrap()).unwrap(); +/// +/// let mut exec = GraphExecutor::new(graph, &syslib).unwrap(); +/// exec.load_params(params); +/// let x = Array::from_vec(vec![1f32, 2., 3., 4.]); +/// exec.set_input("data", x.into()); +/// exec.run(); +/// let output = exec.get_output(0).unwrap(); +/// +/// println!("{:#?}", Array::try_from(output).unwrap()); +/// ``` pub struct GraphExecutor<'m, 't> { graph: Graph, op_execs: Vec<Box<Fn() + 'm>>, @@ -132,15 +168,17 @@ impl<'m, 't> GraphExecutor<'m, 't> { }) } + /// Runs the computation graph. pub fn run(&self) { self.op_execs.iter().for_each(|op_exec| { op_exec(); }); } + /// Allocates `Storages` for each `storage_id` and returns `Tensor`s to hold each output. fn setup_storages<'a>(graph: &'a Graph) -> Result<Vec<Tensor<'a>>> { let storage_ids = graph.get_attr::<(String, Vec<usize>)>("storage_id")?.1; - let shapes = graph.get_attr::<(String, Vec<Vec<usize>>)>("shape")?.1; + let shapes = graph.get_attr::<(String, Vec<Vec<i64>>)>("shape")?.1; let dtypes = graph .get_attr::<(String, Vec<String>)>("dltype")? .1 @@ -158,7 +196,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { let mut storage_num_bytes = vec![0usize; *storage_ids.iter().max().unwrap_or(&1) + 1]; for (i, &storage_id) in storage_ids.iter().enumerate() { let dtype_size = dtypes[i].bits * dtypes[i].lanes >> 3; - let nbytes = dtype_size * shapes[i].iter().product::<usize>(); + let nbytes = dtype_size * shapes[i].iter().product::<i64>() as usize; storage_num_bytes[storage_id] = cmp::max(nbytes, storage_num_bytes[storage_id]); } @@ -174,8 +212,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { data: mem::replace(&mut storages[storage_id], storage), ctx: TVMContext::default(), dtype: dtype, - numel: shape.iter().product(), - dshape: shape.iter().map(|&v| v as i64).collect(), + numel: shape.iter().product::<i64>() as usize, shape: shape, strides: None, byte_offset: 0, @@ -186,6 +223,7 @@ Ok(tensors) } + /// Creates closures which represent the computation performed by this graph. fn setup_op_execs<M: Module>( graph: &Graph, lib: &'m M, @@ -262,12 +300,14 @@ impl<'m, 't> GraphExecutor<'m, 't> { } } + /// Returns the graph input with name `name`, if it exists. pub fn get_input<S: AsRef<str>>(&mut self, name: S) -> Option<&Tensor> { self .get_input_index(name.as_ref()) .and_then(move |idx| Some(&self.tensors[idx])) } + /// Returns the graph output with index `idx`, if it exists. pub fn get_output(&self, idx: usize) -> Option<&Tensor> { let graph = &self.graph; graph.heads.get(idx).and_then(|entry| { @@ -278,6 +318,7 @@ }) } + /// Returns the index for graph input with name `name`, if it exists. pub fn get_input_index<S: AsRef<str>>(&self, name: S) -> Option<usize> { let graph = &self.graph; (0..graph.nodes.len()) @@ -293,6 +334,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { } } +/// Converts a string to TVM DLDataTypeCode. @see `String2TVMType` in packed_func.h named!( tvm_str_to_type<CompleteStr, DataType>, do_parse!( @@ -315,6 +357,7 @@ named!( ) ); +/// Converts a length-prefixed byte string to a `String`.
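+/// The serialized form is a little-endian `u64` byte count followed by that many UTF-8 bytes.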
named!( name<String>, map_res!(length_bytes!(le_u64), |b: &[u8]| { @@ -322,6 +365,7 @@ named!( }) ); +/// Parses a TVMContext named!( tvm_ctx<&[u8], TVMContext>, do_parse!( @@ -331,6 +375,7 @@ named!( ) ); +/// Parses a DataType named!( data_type<&[u8], DataType>, do_parse!( @@ -341,17 +386,17 @@ named!( ) ); +/// Parses a Tensor from a TVM array file. named!( tensor<Tensor>, do_parse!( take!(8) >> bits!(tag_bits!(u64, 64, 0)) >> ctx: tvm_ctx >> ndim: le_u32 >> dtype: data_type - >> shape: count!(map!(le_i64, |sz| sz as usize), ndim as usize) >> length: le_i64 + >> shape: count!(map!(le_i64, |sz| sz as i64), ndim as usize) >> length: le_i64 >> data: take!(length) >> (Tensor { data: Storage::from(data), ctx: ctx, dtype: dtype, - numel: shape.iter().product(), - dshape: shape.iter().map(|&v| v as i64).collect(), + numel: shape.iter().product::<i64>() as usize, shape: shape, strides: None, byte_offset: 0, }) ) ); +/// Parses a graph params dict from a params binary file. named!( parse_param_dict<HashMap<String, Tensor>>, do_parse!( @@ -368,6 +414,7 @@ named!( ) ); +/// Loads a param dict saved using `nnvm.compiler.save_param_dict`. pub fn load_param_dict(bytes: &[u8]) -> Result<HashMap<String, Tensor>> { if let Ok((remaining_bytes, param_dict)) = parse_param_dict(bytes) { if remaining_bytes.len() > 0 { diff --git a/rust/src/runtime/packed_func.rs b/rust/src/runtime/packed_func.rs index 8d28a20863f7..2586f67643f1 100644 --- a/rust/src/runtime/packed_func.rs +++ b/rust/src/runtime/packed_func.rs @@ -9,6 +9,11 @@ use errors::*; pub type PackedFunc = Box<Fn(&[TVMArgValue]) -> TVMRetValue>; +/// Calls a packed function and returns a `TVMRetValue`. +/// +/// # Example +/// +/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)` #[macro_export] macro_rules! call_packed { ($fn:expr, $($args:expr),+) => { @@ -16,6 +21,8 @@ macro_rules! call_packed { }; } +/// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way +/// to obtain a `TVMArgValue` is automatically via `call_packed!`. #[derive(Clone, Copy)] pub struct TVMArgValue<'a> { _lifetime: PhantomData<&'a ()>, value: TVMValue, type_code: i64, } +/// Creates a conversion to a `TVMArgValue` for a primitive type and DLDataTypeCode. macro_rules! impl_prim_tvm_arg { ($type:ty, $field:ident, $code:expr, $as:ty) => { impl<'a> From<$type> for TVMArgValue<'a> { @@ -56,6 +64,7 @@ impl_prim_tvm_arg!(i64, v_int64); impl_prim_tvm_arg!(u64, v_int64); impl_prim_tvm_arg!(bool, v_int64); +/// Creates a conversion to a `TVMArgValue` for an object handle. impl<'a, T> From<*const T> for TVMArgValue<'a> { fn from(ptr: *const T) -> Self { TVMArgValue { @@ -68,6 +77,7 @@ impl<'a, T> From<*const T> for TVMArgValue<'a> { } } +/// Creates a conversion to a `TVMArgValue` for a mutable object handle. impl<'a, T> From<*mut T> for TVMArgValue<'a> { fn from(ptr: *mut T) -> Self { TVMArgValue { @@ -104,9 +114,25 @@ impl<'a> From<&'a DLTensor> for TVMArgValue<'a> { } } +/// An owned TVMPODValue. Can be converted from a variety of primitive and object types. +/// Can be downcasted using `try_from` if it contains the desired type. +/// +/// # Example +/// +/// ``` +/// let a = 42u32; +/// let b: i64 = TVMRetValue::from(a).try_into().unwrap(); +/// +/// let s = "hello, world!"; +/// let t: TVMRetValue = s.into(); +/// assert_eq!(String::try_from(t).unwrap(), s); +/// ``` pub struct TVMRetValue { + /// A primitive return value, if any. prim_value: u64, + /// An object return value, if any. box_value: Box<Any>, + /// The DLDataTypeCode which determines whether `prim_value` or `box_value` is in use.
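+  /// A code of `11` (`kStr`), for example, indicates that `box_value` holds a `String`.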
type_code: i64, } @@ -179,6 +205,7 @@ impl_prim_ret_value!(i32, 0); impl_prim_ret_value!(u32, 1); impl_prim_ret_value!(f32, 2); impl_boxed_ret_value!(String, 11); +// @see `WrapPackedFunc` in `llvm_module.cc`. pub(super) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc { box move |args: &[TVMArgValue]| { func( diff --git a/rust/src/runtime/threading.rs b/rust/src/runtime/threading.rs index 84b409aa6b3e..c5db5147632d 100644 --- a/rust/src/runtime/threading.rs +++ b/rust/src/runtime/threading.rs @@ -26,6 +26,7 @@ use ffi::runtime::TVMParallelGroupEnv; type FTVMParallelLambda = extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32; +/// Holds a parallel job request made by a TVM library function. struct Job { cb: FTVMParallelLambda, cdata: *const c_void, @@ -34,6 +35,7 @@ struct Job { } impl Job { + /// Splits this job into a number of `Task`s which can be scheduled. fn tasks(&self, num_workers: usize) -> Vec<Task> { let num_tasks = if self.req_num_tasks == 0 { num_workers @@ -58,6 +60,7 @@ impl Job { .collect() } + /// Waits for all tasks in this `Job` to be completed. fn wait(&self) -> Result<()> { while self.pending.load(Ordering::Acquire) > 0 { #[cfg(not(target_env = "sgx"))] @@ -67,6 +70,7 @@ impl Job { } } +/// A chunk of work requested by a TVM function. struct Task { id: usize, flambda: FTVMParallelLambda, @@ -166,6 +170,7 @@ impl ThreadPool { } } +// Send + Sync wrapper for bounded_spsc_queue::Consumer struct Consumer<T> { consumer: bounded_spsc_queue::Consumer<T>, } @@ -184,6 +189,7 @@ unsafe impl<T> Sync for Consumer<T> {} #[cfg(target_env = "sgx")] lazy_static! { + /// Holds tasks for untrusted threads which re-enter the enclave to execute. static ref SGX_QUEUES: Mutex<VecDeque<Consumer<Task>>> = Mutex::new(VecDeque::new()); } @@ -204,7 +210,7 @@ fn max_concurrency() -> usize { #[cfg(target_arch = "wasm32")] fn max_concurrency() -> usize { - 0 + 0 // wasm doesn't support threads yet } #[cfg(target_env = "sgx")] @@ -243,6 +249,7 @@ pub extern "C" fn TVMBackendParallelLaunch( return 0; } +// @see https://github.com/dmlc/tvm/issues/988 for information on why this function is used. #[no_mangle] pub extern "C" fn TVMBackendParallelBarrier(_task_id: usize, penv: *const TVMParallelGroupEnv) { let barrier: &Arc<Barrier> = unsafe { &*((*penv).sync_handle as *const Arc<Barrier>) }; From 0b515c98ba752c352f2d08abdda14581bb9bc9f8 Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Sun, 19 Aug 2018 23:46:04 +0000 Subject: [PATCH 06/10] Rename numel to size --- rust/src/runtime/array.rs | 22 +++++++++++----------- rust/src/runtime/graph.rs | 4 ++-- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/rust/src/runtime/array.rs b/rust/src/runtime/array.rs index 9fec9ef7b1ef..722687eebbfc 100644 --- a/rust/src/runtime/array.rs +++ b/rust/src/runtime/array.rs @@ -127,7 +127,7 @@ pub struct Tensor<'a> { /// The `Tensor` strides. Can be `None` if the `Tensor` is contiguous.
pub(super) strides: Option<Vec<usize>>, pub(super) byte_offset: isize, - pub(super) numel: usize, + pub(super) size: usize, } impl<'a> Tensor<'a> { @@ -143,13 +143,13 @@ impl<'a> Tensor<'a> { pub fn to_vec<T>(&self) -> Vec<T> { assert!(self.is_contiguous()); assert!(self.dtype.is_type::<T>()); - let mut vec: Vec<T> = Vec::with_capacity(self.numel * self.dtype.itemsize()); + let mut vec: Vec<T> = Vec::with_capacity(self.size * self.dtype.itemsize()); unsafe { vec.as_mut_ptr().copy_from_nonoverlapping( self.data.as_ptr().offset(self.byte_offset) as *const T, - self.numel, + self.size, ); - vec.set_len(self.numel); + vec.set_len(self.size); } vec } @@ -185,7 +185,7 @@ impl<'a> Tensor<'a> { pub fn copy(&mut self, other: &Tensor) { assert!( - self.dtype == other.dtype && self.numel == other.numel, + self.dtype == other.dtype && self.size == other.size, "Tensor shape/dtype mismatch." ); assert!( @@ -201,7 +201,7 @@ impl<'a> Tensor<'a> { .offset(self.byte_offset as isize) .copy_from_nonoverlapping( other.data.as_mut_ptr().offset(other.byte_offset), - other.numel * other.dtype.itemsize(), + other.size * other.dtype.itemsize(), ); } } @@ -212,7 +212,7 @@ impl<'a> Tensor<'a> { data: self.data.to_owned(), ctx: self.ctx.clone(), dtype: self.dtype.clone(), - numel: self.numel.clone(), + size: self.size.clone(), shape: self.shape.clone(), strides: None, byte_offset: 0, @@ -245,7 +245,7 @@ impl DLTensor { ndim: if flatten { 1 } else { tensor.shape.len() } as i32, dtype: DLDataType::from(&tensor.dtype), shape: if flatten { - &tensor.numel as *const _ as *mut i64 + &tensor.size as *const _ as *mut i64 } else { tensor.shape.as_ptr() } as *mut i64, strides: if flatten || tensor.is_contiguous() { ptr::null_mut() @@ -362,7 +362,7 @@ fn tensor_from_array_storage<'a, 's, T, D: ndarray::Dimension>( bits: 8 * type_width, lanes: 1, }, - numel: arr.len(), + size: arr.len(), shape: arr.shape().iter().map(|&v| v as i64).collect(), strides: Some(arr.strides().into_iter().map(|&v| v as usize).collect()), byte_offset: 0, @@ -379,9 +379,9 @@ macro_rules!
impl_tensor_from_ndarray { impl<D: ndarray::Dimension> From<ndarray::Array<$type, D>> for Tensor<'static> { fn from(arr: ndarray::Array<$type, D>) -> Self { assert!(arr.is_standard_layout(), "Array must be contiguous."); - let numel = arr.len() * mem::size_of::<$type>() as usize; + let size = arr.len() * mem::size_of::<$type>() as usize; let storage = - Storage::from(unsafe { slice::from_raw_parts(arr.as_ptr() as *const u8, numel) }); + Storage::from(unsafe { slice::from_raw_parts(arr.as_ptr() as *const u8, size) }); tensor_from_array_storage(&arr, storage, $typecode as usize) } } diff --git a/rust/src/runtime/graph.rs b/rust/src/runtime/graph.rs index edf1c77ebb52..805a8b232e09 100644 --- a/rust/src/runtime/graph.rs +++ b/rust/src/runtime/graph.rs @@ -212,7 +212,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { data: mem::replace(&mut storages[storage_id], storage), ctx: TVMContext::default(), dtype: dtype, - numel: shape.iter().product::<i64>() as usize, + size: shape.iter().product::<i64>() as usize, shape: shape, strides: None, byte_offset: 0, @@ -396,7 +396,7 @@ named!( data: Storage::from(data), ctx: ctx, dtype: dtype, - numel: shape.iter().product::<i64>() as usize, + size: shape.iter().product::<i64>() as usize, shape: shape, strides: None, byte_offset: 0, From ce067bed38c9e31ace30fb3e37a845be6c05955e Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Mon, 20 Aug 2018 00:01:58 +0000 Subject: [PATCH 07/10] Address code review comments --- rust/src/runtime/array.rs | 104 +++++++++++++++++++------------- rust/src/runtime/module.rs | 2 +- rust/src/runtime/packed_func.rs | 9 +++ 3 files changed, 72 insertions(+), 43 deletions(-) diff --git a/rust/src/runtime/array.rs b/rust/src/runtime/array.rs index 722687eebbfc..9f1cc5d7e484 100644 --- a/rust/src/runtime/array.rs +++ b/rust/src/runtime/array.rs @@ -219,23 +219,54 @@ impl<'a> Tensor<'a> { }; unsafe { mem::transmute::<Tensor<'a>, Tensor<'static>>(t) } } -} -impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<f32> { - type Error = Error; - fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<f32>> { - ensure!( - tensor.dtype == DTYPE_FLOAT32, - "Cannot convert Tensor with dtype {:?} to ndarray", - tensor.dtype - ); - Ok(ndarray::Array::from_shape_vec( - tensor.shape.iter().map(|s| *s as usize).collect::<Vec<usize>>(), - tensor.to_vec::<f32>(), - )?) + fn from_array_storage<'s, T, D: ndarray::Dimension>( + arr: &ndarray::Array<T, D>, + storage: Storage<'s>, + type_code: usize, + ) -> Tensor<'s> { + let type_width = mem::size_of::<T>() as usize; + Tensor { + data: storage, + ctx: TVMContext::default(), + dtype: DataType { + code: type_code, + bits: 8 * type_width, + lanes: 1, + }, + size: arr.len(), + shape: arr.shape().iter().map(|&v| v as i64).collect(), + strides: Some(arr.strides().into_iter().map(|&v| v as usize).collect()), + byte_offset: 0, + } } } +/// Conversions to `ndarray::Array` from `Tensor`, if the types match. macro_rules! impl_ndarray_try_from_tensor { ($type:ty, $dtype:expr) => { impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<$type> { type Error = Error; fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<$type>> { ensure!( tensor.dtype == $dtype, "Cannot convert Tensor with dtype {:?} to ndarray", tensor.dtype ); Ok(ndarray::Array::from_shape_vec( tensor .shape .iter() .map(|s| *s as usize) .collect::<Vec<usize>>(), tensor.to_vec::<$type>(), )?)
+ } + } + }; +} + +impl_ndarray_try_from_tensor!(i32, DTYPE_INT32); +impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32); +impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32); +impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64); + impl DLTensor { pub(super) fn from_tensor<'a>(tensor: &'a Tensor, flatten: bool) -> Self { assert!(!flatten || tensor.is_contiguous()); @@ -299,12 +330,6 @@ impl DataType { } } -const DTYPE_FLOAT32: DataType = DataType { - code: DLDataTypeCode_kDLFloat as usize, - bits: 32, - lanes: 1, -}; - impl<'a> From<&'a DataType> for DLDataType { fn from(dtype: &'a DataType) -> Self { Self { @@ -315,6 +340,22 @@ impl<'a> From<&'a DataType> for DLDataType { } } +macro_rules! make_dtype_const { + ($name: ident, $code: ident, $bits: expr, $lanes: expr) => { + const $name: DataType = DataType { + code: $code as usize, + bits: $bits, + lanes: $lanes, + }; + } +} + +make_dtype_const!(DTYPE_INT32, DLDataTypeCode_kDLInt, 32, 1); +make_dtype_const!(DTYPE_UINT32, DLDataTypeCode_kDLUInt, 32, 1); +make_dtype_const!(DTYPE_FLOAT16, DLDataTypeCode_kDLFloat, 16, 1); +make_dtype_const!(DTYPE_FLOAT32, DLDataTypeCode_kDLFloat, 32, 1); +make_dtype_const!(DTYPE_FLOAT64, DLDataTypeCode_kDLFloat, 64, 1); + impl Default for DLContext { fn default() -> Self { DLContext { @@ -348,27 +389,6 @@ impl Default for TVMContext { } } -fn tensor_from_array_storage<'a, 's, T, D: ndarray::Dimension>( - arr: &ndarray::Array, - storage: Storage<'s>, - type_code: usize, -) -> Tensor<'s> { - let type_width = mem::size_of::() as usize; - Tensor { - data: storage, - ctx: TVMContext::default(), - dtype: DataType { - code: type_code, - bits: 8 * type_width, - lanes: 1, - }, - size: arr.len(), - shape: arr.shape().iter().map(|&v| v as i64).collect(), - strides: Some(arr.strides().into_iter().map(|&v| v as usize).collect()), - byte_offset: 0, - } -} - /// `From` conversions to `Tensor` for owned or borrowed `ndarray::Array`. /// /// # Panics @@ -382,13 +402,13 @@ macro_rules! impl_tensor_from_ndarray { let size = arr.len() * mem::size_of::<$type>() as usize; let storage = Storage::from(unsafe { slice::from_raw_parts(arr.as_ptr() as *const u8, size) }); - tensor_from_array_storage(&arr, storage, $typecode as usize) + Tensor::from_array_storage(&arr, storage, $typecode as usize) } } impl<'a, D: ndarray::Dimension> From<&'a ndarray::Array<$type, D>> for Tensor<'a> { fn from(arr: &'a ndarray::Array<$type, D>) -> Self { assert!(arr.is_standard_layout(), "Array must be contiguous."); - tensor_from_array_storage( + Tensor::from_array_storage( arr, Storage::from(arr.as_slice().unwrap()), $typecode as usize, diff --git a/rust/src/runtime/module.rs b/rust/src/runtime/module.rs index e23c21fcd611..2594756d9885 100644 --- a/rust/src/runtime/module.rs +++ b/rust/src/runtime/module.rs @@ -9,7 +9,7 @@ pub trait Module { fn get_function>(&self, name: S) -> Option; } -pub struct SystemLibModule {} +pub struct SystemLibModule; lazy_static! { static ref SYSTEM_LIB_FUNCTIONS: Mutex> = diff --git a/rust/src/runtime/packed_func.rs b/rust/src/runtime/packed_func.rs index 2586f67643f1..ff5fd667af12 100644 --- a/rust/src/runtime/packed_func.rs +++ b/rust/src/runtime/packed_func.rs @@ -200,9 +200,18 @@ macro_rules! 
impl_boxed_ret_value { }; } +impl_prim_ret_value!(i8, 0); +impl_prim_ret_value!(u8, 1); +impl_prim_ret_value!(i16, 0); +impl_prim_ret_value!(u16, 1); impl_prim_ret_value!(i32, 0); impl_prim_ret_value!(u32, 1); impl_prim_ret_value!(f32, 2); +impl_prim_ret_value!(i64, 0); +impl_prim_ret_value!(u64, 1); +impl_prim_ret_value!(f64, 2); +impl_prim_ret_value!(isize, 0); +impl_prim_ret_value!(usize, 1); impl_boxed_ret_value!(String, 11); // @see `WrapPackedFunc` in `llvm_module.cc`. From 706c8de81a7e0d0e27e1774f7fddad2b21c08b0d Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Wed, 26 Sep 2018 00:38:18 -0500 Subject: [PATCH 08/10] Allow building SGX modules --- rust/.rustfmt.toml | 26 +- rust/Cargo.toml | 14 +- rust/build.rs | 47 -- rust/src/errors.rs | 13 +- rust/src/lib.rs | 27 +- rust/src/runtime/allocator.rs | 3 + rust/src/runtime/array.rs | 16 +- rust/src/runtime/c_runtime_api.rs | 770 ++++++++++++++++++++++++++++++ rust/src/runtime/graph.rs | 141 +++--- rust/src/runtime/mod.rs | 10 +- rust/src/runtime/packed_func.rs | 38 +- rust/src/runtime/sgx.rs | 78 +++ rust/src/runtime/threading.rs | 76 ++- rust/src/runtime/workspace.rs | 58 ++- rust/tests/test_graph_serde.rs | 7 +- 15 files changed, 1124 insertions(+), 200 deletions(-) delete mode 100644 rust/build.rs create mode 100644 rust/src/runtime/c_runtime_api.rs create mode 100644 rust/src/runtime/sgx.rs diff --git a/rust/.rustfmt.toml b/rust/.rustfmt.toml index 9b2cf0e1007d..df9a65dacfaa 100644 --- a/rust/.rustfmt.toml +++ b/rust/.rustfmt.toml @@ -1,13 +1,15 @@ max_width = 100 hard_tabs = false tab_spaces = 2 -newline_style = "Unix" -use_small_heuristics = true +newline_style = "Auto" +use_small_heuristics = "Default" indent_style = "Block" wrap_comments = false comment_width = 80 normalize_comments = false -format_strings = true +format_strings = false +format_macro_matchers = false +format_macro_bodies = true empty_item_single_line = true struct_lit_single_line = true fn_single_line = false @@ -22,9 +24,8 @@ type_punctuation_density = "Wide" space_before_colon = false space_after_colon = true spaces_around_ranges = false -spaces_within_parens_and_brackets = false binop_separator = "Front" -remove_blank_lines_at_start_or_end_of_block = true +remove_nested_parens = true combine_control_expr = true struct_field_align_threshold = 0 match_arm_blocks = true @@ -37,21 +38,22 @@ trailing_comma = "Vertical" match_block_trailing_comma = false blank_lines_upper_bound = 1 blank_lines_lower_bound = 0 +edition = "Edition2015" merge_derives = true use_try_shorthand = true -condense_wildcard_suffixes = true -force_explicit_abi = true use_field_init_shorthand = false -write_mode = "Overwrite" +force_explicit_abi = true +condense_wildcard_suffixes = false color = "Auto" -required_version = "0.6.1" +required_version = "0.99.4" unstable_features = false disable_all_formatting = false skip_children = false hide_parse_errors = false -error_on_line_overflow = true -error_on_unformatted = true +error_on_line_overflow = false +error_on_unformatted = false report_todo = "Never" report_fixme = "Never" ignore = [] -verbose_diff = false +emit_mode = "Files" +make_backup = false diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 1283c24dd8f7..0819e0c70023 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -3,28 +3,26 @@ name = "tvm" version = "0.1.0" license = "Apache-2.0" description = "TVM Rust runtime" -repository = "https://github.com/nhynes/tvm-rs" +repository = "https://github.com/dmlc/tvm" readme = "README.md" keywords = ["tvm", "nnvm"] categories = 
["api-bindings", "science"] authors = ["Nick Hynes "] [features] -par-launch-alloc = [] +default = ["nom/std"] +sgx = ["nom/alloc"] [dependencies] bounded-spsc-queue = "0.4.0" error-chain = { version = "0.12.0", default-features = false } itertools = "0.7.8" -lazy_static = "1.0.0" +lazy_static = "1.1.0" ndarray = "0.11.2" -nom = "4.0.0" +nom = {version = "4.0.0", default-features = false } serde = "1.0.59" -serde_derive = "1.0.59" +serde_derive = "1.0.79" serde_json = "1.0.17" [target.'cfg(not(target_env = "sgx"))'.dependencies] num_cpus = "1.8.0" - -[build-dependencies] -bindgen = "0.37" diff --git a/rust/build.rs b/rust/build.rs deleted file mode 100644 index f21c9e0c2c1d..000000000000 --- a/rust/build.rs +++ /dev/null @@ -1,47 +0,0 @@ -extern crate bindgen; - -use std::{env, path::PathBuf}; - -fn parse_clang_ver(raw_v: String) -> Vec { - raw_v - .split_whitespace() - .nth(2) - .unwrap() - .split('.') - .map(|v| v.parse::().unwrap()) - .collect() -} - -fn main() { - let clang_ver = parse_clang_ver(bindgen::clang_version().full); - let bindings = bindgen::Builder::default() - .header(concat!( - env!("CARGO_MANIFEST_DIR"), - "/../include/tvm/runtime/c_runtime_api.h" - )) - .header(concat!( - env!("CARGO_MANIFEST_DIR"), - "/../include/tvm/runtime/c_backend_api.h" - )) - .rust_target(bindgen::RustTarget::Nightly) - .clang_arg(concat!( - "-I", - env!("CARGO_MANIFEST_DIR"), - "/../dlpack/include" - )) - .clang_arg(format!("--target={}", env::var("HOST").unwrap())) - .clang_arg("-I/usr/include") - .clang_arg("-I/usr/local/include") - .clang_arg(format!( - "-I/usr/local/lib/clang/{}.{}.{}/include", - clang_ver[0], clang_ver[1], clang_ver[2] - )) - .layout_tests(false) - .generate() - .expect("Unable to generate bindings."); - - let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); - bindings - .write_to_file(out_path.join("c_runtime_api.rs")) - .expect("Unable to write bindings."); -} diff --git a/rust/src/errors.rs b/rust/src/errors.rs index df6ee8f3c4e2..f9da7180b8cc 100644 --- a/rust/src/errors.rs +++ b/rust/src/errors.rs @@ -1,4 +1,8 @@ -use std::{alloc, num}; +#[cfg(target_env = "sgx")] +use alloc::alloc; +#[cfg(not(target_env = "sgx"))] +use std::alloc; +use std::num; use ndarray; use serde_json; @@ -22,9 +26,14 @@ error_chain! 
{ } foreign_links { Alloc(alloc::AllocErr); - Layout(alloc::LayoutErr); GraphDeserialize(serde_json::Error); ParseInt(num::ParseIntError); ShapeError(ndarray::ShapeError); } } + +impl From<alloc::LayoutErr> for Error { + fn from(_err: alloc::LayoutErr) -> Error { + Error::from_kind(ErrorKind::Msg("Layout error".to_string())) + } +} diff --git a/rust/src/lib.rs b/rust/src/lib.rs index b2801e1e8cc6..fd83d5ee6c6d 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -1,6 +1,19 @@ -#![feature(allocator_api, box_syntax, fn_traits, try_from, unboxed_closures)] +#![feature( + alloc, + allocator_api, + box_syntax, + extern_prelude, + fn_traits, + try_from, + unboxed_closures, + vec_remove_item +)] +#[cfg(target_env = "sgx")] +extern crate alloc; extern crate bounded_spsc_queue; +#[cfg(target_env = "sgx")] +extern crate core; #[macro_use] extern crate error_chain; #[macro_use] @@ -18,12 +31,20 @@ extern crate serde_derive; extern crate serde_json; pub mod ffi { - #![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, unused)] + #![allow( + non_camel_case_types, + non_snake_case, + non_upper_case_globals, + unused + )] pub mod runtime { use std::os::raw::{c_char, c_int, c_void}; - include!(concat!(env!("OUT_DIR"), "/c_runtime_api.rs")); + include!(concat!( + env!("CARGO_MANIFEST_DIR"), + "/src/runtime/c_runtime_api.rs" + )); pub type BackendPackedCFunc = extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int; diff --git a/rust/src/runtime/allocator.rs b/rust/src/runtime/allocator.rs index 9b029e80946f..d704336bff1f 100644 --- a/rust/src/runtime/allocator.rs +++ b/rust/src/runtime/allocator.rs @@ -1,3 +1,6 @@ +#[cfg(target_env = "sgx")] +use alloc::alloc::{self, Layout}; +#[cfg(not(target_env = "sgx"))] use std::alloc::{self, Layout}; use errors::*; diff --git a/rust/src/runtime/array.rs b/rust/src/runtime/array.rs index 9f1cc5d7e484..6a9a60c9bb14 100644 --- a/rust/src/runtime/array.rs +++ b/rust/src/runtime/array.rs @@ -3,8 +3,7 @@ use std::{ convert::TryFrom, mem, os::raw::{c_int, c_void}, - ptr, - slice, + ptr, slice, }; use ndarray; @@ -172,8 +171,7 @@ impl<'a> Tensor<'a> { expected_stride * (*shape as usize), ) }, - ) - .0 + ).0 } } } @@ -254,7 +252,11 @@ macro_rules! impl_ndarray_try_from_tensor { tensor.dtype ); Ok(ndarray::Array::from_shape_vec( - tensor.shape.iter().map(|s| *s as usize).collect::<Vec<usize>>(), + tensor + .shape + .iter() + .map(|s| *s as usize) + .collect::<Vec<usize>>(), tensor.to_vec::<$type>(), )?) } } }; } @@ -347,12 +349,12 @@ macro_rules!
make_dtype_const { bits: $bits, lanes: $lanes, }; - } + }; } make_dtype_const!(DTYPE_INT32, DLDataTypeCode_kDLInt, 32, 1); make_dtype_const!(DTYPE_UINT32, DLDataTypeCode_kDLUInt, 32, 1); -make_dtype_const!(DTYPE_FLOAT16, DLDataTypeCode_kDLFloat, 16, 1); +// make_dtype_const!(DTYPE_FLOAT16, DLDataTypeCode_kDLFloat, 16, 1); make_dtype_const!(DTYPE_FLOAT32, DLDataTypeCode_kDLFloat, 32, 1); make_dtype_const!(DTYPE_FLOAT64, DLDataTypeCode_kDLFloat, 64, 1); diff --git a/rust/src/runtime/c_runtime_api.rs b/rust/src/runtime/c_runtime_api.rs new file mode 100644 index 000000000000..62cfa0d15451 --- /dev/null +++ b/rust/src/runtime/c_runtime_api.rs @@ -0,0 +1,770 @@ +/* automatically generated by rust-bindgen for TVM revision 6292c78 */ + +pub const TVM_VERSION: &'static [u8; 8usize] = b"0.5.dev\0"; +pub const DLPACK_VERSION: u32 = 8; +pub const _STDINT_H: u32 = 1; +pub const _FEATURES_H: u32 = 1; +pub const _DEFAULT_SOURCE: u32 = 1; +pub const __USE_ISOC11: u32 = 1; +pub const __USE_ISOC99: u32 = 1; +pub const __USE_ISOC95: u32 = 1; +pub const __USE_POSIX_IMPLICITLY: u32 = 1; +pub const _POSIX_SOURCE: u32 = 1; +pub const _POSIX_C_SOURCE: u32 = 200809; +pub const __USE_POSIX: u32 = 1; +pub const __USE_POSIX2: u32 = 1; +pub const __USE_POSIX199309: u32 = 1; +pub const __USE_POSIX199506: u32 = 1; +pub const __USE_XOPEN2K: u32 = 1; +pub const __USE_XOPEN2K8: u32 = 1; +pub const _ATFILE_SOURCE: u32 = 1; +pub const __USE_MISC: u32 = 1; +pub const __USE_ATFILE: u32 = 1; +pub const __USE_FORTIFY_LEVEL: u32 = 0; +pub const _STDC_PREDEF_H: u32 = 1; +pub const __STDC_IEC_559__: u32 = 1; +pub const __STDC_IEC_559_COMPLEX__: u32 = 1; +pub const __STDC_ISO_10646__: u32 = 201505; +pub const __STDC_NO_THREADS__: u32 = 1; +pub const __GNU_LIBRARY__: u32 = 6; +pub const __GLIBC__: u32 = 2; +pub const __GLIBC_MINOR__: u32 = 23; +pub const _SYS_CDEFS_H: u32 = 1; +pub const __WORDSIZE: u32 = 64; +pub const __WORDSIZE_TIME64_COMPAT32: u32 = 1; +pub const __SYSCALL_WORDSIZE: u32 = 64; +pub const _BITS_WCHAR_H: u32 = 1; +pub const INT8_MIN: i32 = -128; +pub const INT16_MIN: i32 = -32768; +pub const INT32_MIN: i32 = -2147483648; +pub const INT8_MAX: u32 = 127; +pub const INT16_MAX: u32 = 32767; +pub const INT32_MAX: u32 = 2147483647; +pub const UINT8_MAX: u32 = 255; +pub const UINT16_MAX: u32 = 65535; +pub const UINT32_MAX: u32 = 4294967295; +pub const INT_LEAST8_MIN: i32 = -128; +pub const INT_LEAST16_MIN: i32 = -32768; +pub const INT_LEAST32_MIN: i32 = -2147483648; +pub const INT_LEAST8_MAX: u32 = 127; +pub const INT_LEAST16_MAX: u32 = 32767; +pub const INT_LEAST32_MAX: u32 = 2147483647; +pub const UINT_LEAST8_MAX: u32 = 255; +pub const UINT_LEAST16_MAX: u32 = 65535; +pub const UINT_LEAST32_MAX: u32 = 4294967295; +pub const INT_FAST8_MIN: i32 = -128; +pub const INT_FAST16_MIN: i64 = -9223372036854775808; +pub const INT_FAST32_MIN: i64 = -9223372036854775808; +pub const INT_FAST8_MAX: u32 = 127; +pub const INT_FAST16_MAX: u64 = 9223372036854775807; +pub const INT_FAST32_MAX: u64 = 9223372036854775807; +pub const UINT_FAST8_MAX: u32 = 255; +pub const UINT_FAST16_MAX: i32 = -1; +pub const UINT_FAST32_MAX: i32 = -1; +pub const INTPTR_MIN: i64 = -9223372036854775808; +pub const INTPTR_MAX: u64 = 9223372036854775807; +pub const UINTPTR_MAX: i32 = -1; +pub const PTRDIFF_MIN: i64 = -9223372036854775808; +pub const PTRDIFF_MAX: u64 = 9223372036854775807; +pub const SIG_ATOMIC_MIN: i32 = -2147483648; +pub const SIG_ATOMIC_MAX: u32 = 2147483647; +pub const SIZE_MAX: i32 = -1; +pub const WINT_MIN: u32 = 0; +pub const WINT_MAX: u32 = 
4294967295; +pub type int_least8_t = ::std::os::raw::c_schar; +pub type int_least16_t = ::std::os::raw::c_short; +pub type int_least32_t = ::std::os::raw::c_int; +pub type int_least64_t = ::std::os::raw::c_long; +pub type uint_least8_t = ::std::os::raw::c_uchar; +pub type uint_least16_t = ::std::os::raw::c_ushort; +pub type uint_least32_t = ::std::os::raw::c_uint; +pub type uint_least64_t = ::std::os::raw::c_ulong; +pub type int_fast8_t = ::std::os::raw::c_schar; +pub type int_fast16_t = ::std::os::raw::c_long; +pub type int_fast32_t = ::std::os::raw::c_long; +pub type int_fast64_t = ::std::os::raw::c_long; +pub type uint_fast8_t = ::std::os::raw::c_uchar; +pub type uint_fast16_t = ::std::os::raw::c_ulong; +pub type uint_fast32_t = ::std::os::raw::c_ulong; +pub type uint_fast64_t = ::std::os::raw::c_ulong; +pub type intmax_t = ::std::os::raw::c_long; +pub type uintmax_t = ::std::os::raw::c_ulong; +pub type wchar_t = ::std::os::raw::c_int; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct max_align_t { + pub __clang_max_align_nonce1: ::std::os::raw::c_longlong, + pub __bindgen_padding_0: u64, + pub __clang_max_align_nonce2: f64, +} +pub const DLDeviceType_kDLCPU: DLDeviceType = 1; +pub const DLDeviceType_kDLGPU: DLDeviceType = 2; +pub const DLDeviceType_kDLCPUPinned: DLDeviceType = 3; +pub const DLDeviceType_kDLOpenCL: DLDeviceType = 4; +pub const DLDeviceType_kDLMetal: DLDeviceType = 8; +pub const DLDeviceType_kDLVPI: DLDeviceType = 9; +pub const DLDeviceType_kDLROCM: DLDeviceType = 10; +/// \brief The device type in DLContext. +pub type DLDeviceType = u32; +/// \brief A Device context for Tensor and operator. +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct DLContext { + /// \brief The device type used in the device. + pub device_type: DLDeviceType, + /// \brief The device index + pub device_id: ::std::os::raw::c_int, +} +pub const DLDataTypeCode_kDLInt: DLDataTypeCode = 0; +pub const DLDataTypeCode_kDLUInt: DLDataTypeCode = 1; +pub const DLDataTypeCode_kDLFloat: DLDataTypeCode = 2; +/// \brief The type code options DLDataType. +pub type DLDataTypeCode = u32; +/// \brief The data type the tensor can hold. +/// +/// Examples +/// - float: type_code = 2, bits = 32, lanes=1 +/// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4 +/// - int8: type_code = 0, bits = 8, lanes=1 +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct DLDataType { + /// \brief Type code of base types. + /// We keep it uint8_t instead of DLDataTypeCode for minimal memory + /// footprint, but the value should be one of DLDataTypeCode enum values. + /// + pub code: u8, + /// \brief Number of bits, common choices are 8, 16, 32. + pub bits: u8, + /// \brief Number of lanes in the type, used for vector types. + pub lanes: u16, +} +/// \brief Plain C Tensor object, does not manage memory. +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct DLTensor { + /// \brief The opaque data pointer points to the allocated data. + /// This will be CUDA device pointer or cl_mem handle in OpenCL. + /// This pointer is always aligns to 256 bytes as in CUDA. + pub data: *mut ::std::os::raw::c_void, + /// \brief The device context of the tensor + pub ctx: DLContext, + /// \brief Number of dimensions + pub ndim: ::std::os::raw::c_int, + /// \brief The data type of the pointer + pub dtype: DLDataType, + /// \brief The shape of the tensor + pub shape: *mut i64, + /// \brief strides of the tensor, + /// can be NULL, indicating tensor is compact. 
+ pub strides: *mut i64, + /// \brief The offset in bytes to the beginning pointer to data + pub byte_offset: u64, +} +/// \brief C Tensor object, manage memory of DLTensor. This data structure is +/// intended to facilitate the borrowing of DLTensor by another framework. It is +/// not meant to transfer the tensor. When the borrowing framework doesn't need +/// the tensor, it should call the deleter to notify the host that the resource +/// is no longer needed. +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct DLManagedTensor { + /// \brief DLTensor which is being memory managed + pub dl_tensor: DLTensor, + /// \brief the context of the original host framework of DLManagedTensor in + /// which DLManagedTensor is used in the framework. It can also be NULL. + pub manager_ctx: *mut ::std::os::raw::c_void, + /// \brief Destructor signature void (*)(void*) - this should be called + /// to destruct manager_ctx which holds the DLManagedTensor. It can be NULL + /// if there is no way for the caller to provide a reasonable destructor. + pub deleter: ::std::option::Option<unsafe extern "C" fn(self_: *mut DLManagedTensor)>, +} +/// \brief type of array index. +pub type tvm_index_t = i64; +pub const TVMDeviceExtType_kDLAOCL: TVMDeviceExtType = 5; +pub const TVMDeviceExtType_kDLSDAccel: TVMDeviceExtType = 6; +pub const TVMDeviceExtType_kDLVulkan: TVMDeviceExtType = 7; +pub const TVMDeviceExtType_kOpenGL: TVMDeviceExtType = 11; +pub const TVMDeviceExtType_kExtDev: TVMDeviceExtType = 12; +/// \brief Extension device types in TVM +pub type TVMDeviceExtType = u32; +pub const TVMTypeCode_kHandle: TVMTypeCode = 3; +pub const TVMTypeCode_kNull: TVMTypeCode = 4; +pub const TVMTypeCode_kTVMType: TVMTypeCode = 5; +pub const TVMTypeCode_kTVMContext: TVMTypeCode = 6; +pub const TVMTypeCode_kArrayHandle: TVMTypeCode = 7; +pub const TVMTypeCode_kNodeHandle: TVMTypeCode = 8; +pub const TVMTypeCode_kModuleHandle: TVMTypeCode = 9; +pub const TVMTypeCode_kFuncHandle: TVMTypeCode = 10; +pub const TVMTypeCode_kStr: TVMTypeCode = 11; +pub const TVMTypeCode_kBytes: TVMTypeCode = 12; +pub const TVMTypeCode_kNDArrayContainer: TVMTypeCode = 13; +pub const TVMTypeCode_kExtBegin: TVMTypeCode = 15; +pub const TVMTypeCode_kNNVMFirst: TVMTypeCode = 16; +pub const TVMTypeCode_kNNVMLast: TVMTypeCode = 20; +pub const TVMTypeCode_kExtReserveEnd: TVMTypeCode = 64; +pub const TVMTypeCode_kExtEnd: TVMTypeCode = 128; +/// \brief The type code in TVMType +/// \note TVMType is used in two places. +pub type TVMTypeCode = u32; +/// \brief The data type used in TVM Runtime. +/// +/// Examples +/// - float: type_code = 2, bits = 32, lanes=1 +/// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4 +/// - int8: type_code = 0, bits = 8, lanes=1 +/// +/// \note Arguments TVM API function always takes bits=64 and lanes=1 +pub type TVMType = DLDataType; +/// \brief The Device information, abstract away common device types. +pub type TVMContext = DLContext; +/// \brief The tensor array structure to TVM API. +pub type TVMArray = DLTensor; +/// \brief the array handle +pub type TVMArrayHandle = *mut TVMArray; +/// \brief Union type of values +/// being passed through API and function calls. +#[repr(C)] +#[derive(Copy, Clone)] +pub union TVMValue { + pub v_int64: i64, + pub v_float64: f64, + pub v_handle: *mut ::std::os::raw::c_void, + pub v_str: *const ::std::os::raw::c_char, + pub v_type: TVMType, + pub v_ctx: TVMContext, + _bindgen_union_align: u64, +} +/// \brief Byte array type used to pass in byte array +/// When kBytes is used as data type.
+#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct TVMByteArray { + pub data: *const ::std::os::raw::c_char, + pub size: usize, +} +/// \brief Handle to TVM runtime modules. +pub type TVMModuleHandle = *mut ::std::os::raw::c_void; +/// \brief Handle to packed function handle. +pub type TVMFunctionHandle = *mut ::std::os::raw::c_void; +/// \brief Handle to hold return value. +pub type TVMRetValueHandle = *mut ::std::os::raw::c_void; +/// \brief The stream that is specific to device +/// can be NULL, which indicates the default one. +pub type TVMStreamHandle = *mut ::std::os::raw::c_void; +extern "C" { + /// \brief Used for implementing C API function. + /// Set last error message before return. + /// \param msg The error message to be set. + pub fn TVMAPISetLastError(msg: *const ::std::os::raw::c_char); +} +extern "C" { + /// \brief return str message of the last error + /// all function in this file will return 0 when success + /// and -1 when an error occured, + /// TVMGetLastError can be called to retrieve the error + /// + /// this function is threadsafe and can be called by different thread + /// \return error info + pub fn TVMGetLastError() -> *const ::std::os::raw::c_char; +} +extern "C" { + /// \brief Load module from file. + /// \param file_name The file name to load the module from. + /// \param format The format of the module. + /// \param out The result module + /// + /// \return 0 when success, -1 when failure happens + /// \note The resulting module do not contain import relation. + /// It can be reconstructed by TVMModImport. + pub fn TVMModLoadFromFile( + file_name: *const ::std::os::raw::c_char, + format: *const ::std::os::raw::c_char, + out: *mut TVMModuleHandle, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Add dep to mod's dependency. + /// This allows functions in this module to use modules. + /// + /// \param mod The module handle. + /// \param dep The dependent module to be imported. + /// \return 0 when success, -1 when failure happens + pub fn TVMModImport(mod_: TVMModuleHandle, dep: TVMModuleHandle) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Get function from the module. + /// \param mod The module handle. + /// \param func_name The name of the function. + /// \param query_imports Whether to query imported modules + /// \param out The result function, can be NULL if it is not available. + /// \return 0 when no error is thrown, -1 when failure happens + pub fn TVMModGetFunction( + mod_: TVMModuleHandle, + func_name: *const ::std::os::raw::c_char, + query_imports: ::std::os::raw::c_int, + out: *mut TVMFunctionHandle, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Free front-end extension type resource. + /// \param handle The extension handle. + /// \param type_code The type of of the extension type. + /// \return 0 when success, -1 when failure happens + pub fn TVMExtTypeFree( + handle: *mut ::std::os::raw::c_void, + type_code: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Free the Module + /// \param mod The module to be freed. + /// + /// \note This may not free up the module's resources. + /// If there is active TVMFunctionHandle uses the module + /// Or if this module is imported by another active module. + /// + /// The all functions remains valid until TVMFuncFree is called. + /// \return 0 when success, -1 when failure happens + pub fn TVMModFree(mod_: TVMModuleHandle) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Free the function when it is no longer needed. 
+extern "C" {
+  /// \brief Free the function when it is no longer needed.
+  /// \param func The function handle
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMFuncFree(func: TVMFunctionHandle) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Call a Packed TVM Function.
+  ///
+  /// \param func Handle of the function.
+  /// \param arg_values The arguments
+  /// \param type_codes The type codes of the arguments
+  /// \param num_args Number of arguments.
+  ///
+  /// \param ret_val The return value.
+  /// \param ret_type_code the type code of return value.
+  ///
+  /// \return 0 when success, -1 when failure happens
+  /// \note API calls always exchange with type bits=64, lanes=1.
+  ///  If an API call returns container handles (e.g. FunctionHandle),
+  ///  these handles should be managed by the front-end.
+  ///  The front-end needs to call the free function (e.g. TVMFuncFree)
+  ///  to release these handles.
+  pub fn TVMFuncCall(
+    func: TVMFunctionHandle,
+    arg_values: *mut TVMValue,
+    type_codes: *mut ::std::os::raw::c_int,
+    num_args: ::std::os::raw::c_int,
+    ret_val: *mut TVMValue,
+    ret_type_code: *mut ::std::os::raw::c_int,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Set the return value of TVMPackedCFunc.
+  ///
+  ///  This function is called by TVMPackedCFunc to set the return value.
+  ///  When this function is not called, the function returns null by default.
+  ///
+  /// \param ret The return value handle, pass by ret in TVMPackedCFunc
+  /// \param value The value to be returned.
+  /// \param type_code The type of the value to be returned.
+  /// \param num_ret Number of return values, for now only 1 is supported.
+  pub fn TVMCFuncSetReturn(
+    ret: TVMRetValueHandle,
+    value: *mut TVMValue,
+    type_code: *mut ::std::os::raw::c_int,
+    num_ret: ::std::os::raw::c_int,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Translate a callback argument value to a return value in place.
+  ///  This is only needed for non-POD arguments.
+  ///
+  /// \param value The value to be translated.
+  /// \param code The type code to be translated.
+  /// \note This function will do a shallow copy when necessary.
+  ///
+  /// \return 0 when success, -1 when failure happens.
+  pub fn TVMCbArgToReturn(
+    value: *mut TVMValue,
+    code: ::std::os::raw::c_int,
+  ) -> ::std::os::raw::c_int;
+}
+/// \brief C type of packed function.
+///
+/// \param args The arguments
+/// \param type_codes The type codes of the arguments
+/// \param num_args Number of arguments.
+/// \param ret The return value handle.
+/// \param resource_handle An additional resource handle from the front-end.
+/// \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError.
+/// \sa TVMCFuncSetReturn
+pub type TVMPackedCFunc = ::std::option::Option<
+  unsafe extern "C" fn(
+    args: *mut TVMValue,
+    type_codes: *mut ::std::os::raw::c_int,
+    num_args: ::std::os::raw::c_int,
+    ret: TVMRetValueHandle,
+    resource_handle: *mut ::std::os::raw::c_void,
+  ) -> ::std::os::raw::c_int,
+>;
+/// \brief C callback to free the resource handle in C packed function.
+/// \param resource_handle An additional resource handle from the front-end.
+pub type TVMPackedCFuncFinalizer =
+  ::std::option::Option<unsafe extern "C" fn(resource_handle: *mut ::std::os::raw::c_void)>;
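
A sketch of driving TVMFuncCall from Rust for a function that is assumed to take a
single int argument and return an int (the handle would come from TVMModGetFunction
or TVMFuncGetGlobal):

    use std::os::raw::c_int;

    unsafe fn call_with_one_int(func: TVMFunctionHandle, x: i64) -> Result<i64, String> {
      let mut arg = TVMValue { v_int64: x };
      let mut type_code: c_int = 0; // kDLInt
      let mut ret_val = TVMValue { v_int64: 0 };
      let mut ret_type_code: c_int = 0;
      if TVMFuncCall(func, &mut arg, &mut type_code, 1, &mut ret_val, &mut ret_type_code) != 0 {
        return Err("TVMFuncCall failed".to_string());
      }
      Ok(ret_val.v_int64)
    }
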
+/// \brief Signature for extension function declarer.
+///
+///  TVM calls this function to get the extension functions.
+///  The declarer will call register_func to register the functions and their names.
+///
+/// \param register_func_handle The register function
+/// \return 0 if success, -1 if failure happens
+pub type TVMExtensionFuncDeclarer = ::std::option::Option<
+  unsafe extern "C" fn(register_func_handle: TVMFunctionHandle) -> ::std::os::raw::c_int,
+>;
+extern "C" {
+  /// \brief Wrap a TVMPackedCFunc to become a FunctionHandle.
+  ///
+  ///  The resource_handle will be managed by the TVM API until the function is no longer used.
+  ///
+  /// \param func The packed C function.
+  /// \param resource_handle The resource handle from front-end, can be NULL.
+  /// \param fin The finalizer on resource handle when the FunctionHandle gets freed, can be NULL
+  /// \param out the result function handle.
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMFuncCreateFromCFunc(
+    func: TVMPackedCFunc,
+    resource_handle: *mut ::std::os::raw::c_void,
+    fin: TVMPackedCFuncFinalizer,
+    out: *mut TVMFunctionHandle,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Register the function to runtime's global table.
+  ///
+  ///  The registered function can then be pulled by the backend by name.
+  ///
+  /// \param name The name of the function.
+  /// \param f The function to be registered.
+  /// \param override Whether to allow overriding an already registered function.
+  pub fn TVMFuncRegisterGlobal(
+    name: *const ::std::os::raw::c_char,
+    f: TVMFunctionHandle,
+    override_: ::std::os::raw::c_int,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Get a global function.
+  ///
+  /// \param name The name of the function.
+  /// \param out the result function pointer, NULL if it does not exist.
+  ///
+  /// \note The function handle of a global function is managed by the TVM runtime,
+  ///  so TVMFuncFree should not be called on it.
+  pub fn TVMFuncGetGlobal(
+    name: *const ::std::os::raw::c_char,
+    out: *mut TVMFunctionHandle,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief List all the globally registered function names
+  /// \param out_size The number of functions
+  /// \param out_array The array of function names.
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMFuncListGlobalNames(
+    out_size: *mut ::std::os::raw::c_int,
+    out_array: *mut *mut *const ::std::os::raw::c_char,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Allocate an nd-array's memory,
+  ///  including space for the shape, of the given spec.
+  ///
+  /// \param shape The shape of the array; the data content will be copied to out
+  /// \param ndim The number of dimensions of the array.
+  /// \param dtype_code The type code of the dtype
+  /// \param dtype_bits The number of bits of dtype
+  /// \param dtype_lanes The number of lanes in the dtype.
+  /// \param device_type The device type of context
+  /// \param device_id The device id of context.
+  /// \param out The output handle.
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMArrayAlloc(
+    shape: *const tvm_index_t,
+    ndim: ::std::os::raw::c_int,
+    dtype_code: ::std::os::raw::c_int,
+    dtype_bits: ::std::os::raw::c_int,
+    dtype_lanes: ::std::os::raw::c_int,
+    device_type: ::std::os::raw::c_int,
+    device_id: ::std::os::raw::c_int,
+    out: *mut TVMArrayHandle,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Free the TVM Array.
+  /// \param handle The array handle to be freed.
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMArrayFree(handle: TVMArrayHandle) -> ::std::os::raw::c_int;
+}
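
For example, allocating a CPU-resident f32 tensor pairs TVMArrayAlloc with a later
TVMArrayFree; a sketch using the DLPack codes (dtype 2 = kDLFloat, device type 1 = kDLCPU):

    use std::{os::raw::c_int, ptr};

    unsafe fn alloc_cpu_f32(shape: &[i64]) -> Result<TVMArrayHandle, String> {
      let mut out: TVMArrayHandle = ptr::null_mut();
      // dtype: float, 32 bits, 1 lane; device: CPU, id 0
      if TVMArrayAlloc(shape.as_ptr(), shape.len() as c_int, 2, 32, 1, 1, 0, &mut out) != 0 {
        return Err("TVMArrayAlloc failed".to_string());
      }
      // ... use the array, then release it with TVMArrayFree(out).
      Ok(out)
    }
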
+extern "C" {
+  /// \brief Copy array data from CPU byte array.
+  /// \param handle The array handle.
+  /// \param data The data pointer
+  /// \param nbytes The number of bytes to copy.
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMArrayCopyFromBytes(
+    handle: TVMArrayHandle,
+    data: *mut ::std::os::raw::c_void,
+    nbytes: usize,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Copy array data to CPU byte array.
+  /// \param handle The array handle.
+  /// \param data The data pointer
+  /// \param nbytes The number of bytes to copy.
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMArrayCopyToBytes(
+    handle: TVMArrayHandle,
+    data: *mut ::std::os::raw::c_void,
+    nbytes: usize,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Copy the array; both from and to must be valid during the copy.
+  /// \param from The array to be copied from.
+  /// \param to The target space.
+  /// \param stream The stream where the copy happens, can be NULL.
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMArrayCopyFromTo(
+    from: TVMArrayHandle,
+    to: TVMArrayHandle,
+    stream: TVMStreamHandle,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Produce an array from the DLManagedTensor that shares data memory
+  ///  with the DLManagedTensor.
+  /// \param from The source DLManagedTensor.
+  /// \param out The output array handle.
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMArrayFromDLPack(
+    from: *mut DLManagedTensor,
+    out: *mut TVMArrayHandle,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Produce a DLManagedTensor from the array that shares data memory with
+  ///  the array.
+  /// \param from The source array.
+  /// \param out The DLManagedTensor handle.
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMArrayToDLPack(
+    from: TVMArrayHandle,
+    out: *mut *mut DLManagedTensor,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Delete (free) a DLManagedTensor's data.
+  /// \param dltensor Pointer to the DLManagedTensor.
+  pub fn TVMDLManagedTensorCallDeleter(dltensor: *mut DLManagedTensor);
+}
+extern "C" {
+  /// \brief Create a new runtime stream.
+  ///
+  /// \param device_type The device type of context
+  /// \param device_id The device id of context
+  /// \param out The new stream handle
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMStreamCreate(
+    device_type: ::std::os::raw::c_int,
+    device_id: ::std::os::raw::c_int,
+    out: *mut TVMStreamHandle,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Free a created stream handle.
+  ///
+  /// \param device_type The device type of context
+  /// \param device_id The device id of context
+  /// \param stream The stream to be freed
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMStreamFree(
+    device_type: ::std::os::raw::c_int,
+    device_id: ::std::os::raw::c_int,
+    stream: TVMStreamHandle,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Set the runtime stream of the current thread to be stream.
+  ///  Subsequent calls to the same device_type
+  ///  will use the stream handle set here.
+  ///  The specific type of stream is runtime-device dependent.
+  ///
+  /// \param device_type The device type of context
+  /// \param device_id The device id of context.
+  /// \param handle The stream handle.
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMSetStream(
+    device_type: ::std::os::raw::c_int,
+    device_id: ::std::os::raw::c_int,
+    handle: TVMStreamHandle,
+  ) -> ::std::os::raw::c_int;
+}
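
A sketch of staging host data into such an array with TVMArrayCopyFromBytes; it
assumes the handle was allocated with a matching byte size and dtype:

    use std::{mem, os::raw::c_void};

    unsafe fn upload(handle: TVMArrayHandle, host: &[f32]) -> Result<(), String> {
      let nbytes = host.len() * mem::size_of::<f32>();
      if TVMArrayCopyFromBytes(handle, host.as_ptr() as *mut c_void, nbytes) != 0 {
        return Err("TVMArrayCopyFromBytes failed".to_string());
      }
      Ok(())
    }
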
+extern "C" {
+  /// \brief Wait until all computations on the stream complete.
+  ///
+  /// \param device_type The device type of context
+  /// \param device_id The device id of context.
+  /// \param stream The stream to be synchronized.
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMSynchronize(
+    device_type: ::std::os::raw::c_int,
+    device_id: ::std::os::raw::c_int,
+    stream: TVMStreamHandle,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Synchronize two streams of execution.
+  ///
+  /// \param device_type The device type of context
+  /// \param device_id The device id of context
+  /// \param src The source stream to synchronize.
+  /// \param dst The destination stream to synchronize.
+  /// \return 0 when success, -1 when failure happens
+  pub fn TVMStreamStreamSynchronize(
+    device_type: ::std::os::raw::c_int,
+    device_id: ::std::os::raw::c_int,
+    src: TVMStreamHandle,
+    dst: TVMStreamHandle,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Backend function for modules to get a function
+  ///  from its environment mod_node (its imports and global functions).
+  ///  The user should not call TVMFuncFree on func.
+  ///
+  /// \param mod_node The module handle.
+  /// \param func_name The name of the function.
+  /// \param out The result function.
+  /// \return 0 when no error is thrown, -1 when failure happens
+  pub fn TVMBackendGetFuncFromEnv(
+    mod_node: *mut ::std::os::raw::c_void,
+    func_name: *const ::std::os::raw::c_char,
+    out: *mut TVMFunctionHandle,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Backend function to register a system-wide library symbol.
+  ///
+  /// \param name The name of the symbol
+  /// \param ptr The symbol address.
+  /// \return 0 when no error is thrown, -1 when failure happens
+  pub fn TVMBackendRegisterSystemLibSymbol(
+    name: *const ::std::os::raw::c_char,
+    ptr: *mut ::std::os::raw::c_void,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Backend function to allocate temporary workspace.
+  ///
+  /// \note The resulting allocated space is guaranteed to be aligned to kTempAllocaAlignment.
+  ///
+  /// \param nbytes The size of the space requested.
+  /// \param device_type The device type on which the space will be allocated.
+  /// \param device_id The device id on which the space will be allocated.
+  /// \param dtype_code_hint The type code of the array elements. Only used in
+  ///  certain backends such as OpenGL.
+  /// \param dtype_bits_hint The type bits of the array elements. Only used in
+  ///  certain backends such as OpenGL.
+  /// \return nullptr when an error is thrown, a valid pointer on success
+  pub fn TVMBackendAllocWorkspace(
+    device_type: ::std::os::raw::c_int,
+    device_id: ::std::os::raw::c_int,
+    nbytes: u64,
+    dtype_code_hint: ::std::os::raw::c_int,
+    dtype_bits_hint: ::std::os::raw::c_int,
+  ) -> *mut ::std::os::raw::c_void;
+}
+extern "C" {
+  /// \brief Backend function to free temporary workspace.
+  ///
+  /// \param ptr The previously allocated space pointer.
+  /// \param device_type The device type on which the space was allocated.
+  /// \param device_id The device id on which the space was allocated.
+  /// \return 0 when no error is thrown, -1 when failure happens
+  ///
+  /// \sa TVMBackendAllocWorkspace
+  pub fn TVMBackendFreeWorkspace(
+    device_type: ::std::os::raw::c_int,
+    device_id: ::std::os::raw::c_int,
+    ptr: *mut ::std::os::raw::c_void,
+  ) -> ::std::os::raw::c_int;
+}
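
Generated kernels call the two workspace functions as a strict pair; a sketch of the
contract (device type 1 = kDLCPU; the dtype hints are only consulted by some backends):

    unsafe fn with_scratch(nbytes: u64) -> Result<(), String> {
      let ws = TVMBackendAllocWorkspace(1, 0, nbytes, 2, 32);
      if ws.is_null() {
        return Err("TVMBackendAllocWorkspace failed".to_string());
      }
      // ... generated code uses `ws` as scratch memory here ...
      if TVMBackendFreeWorkspace(1, 0, ws) != 0 {
        return Err("TVMBackendFreeWorkspace failed".to_string());
      }
      Ok(())
    }
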
+/// \brief Environment for TVM parallel task.
+#[repr(C)]
+#[derive(Debug, Copy, Clone)]
+pub struct TVMParallelGroupEnv {
+  /// \brief Auxiliary used for synchronization
+  pub sync_handle: *mut ::std::os::raw::c_void,
+  /// \brief total number of tasks
+  pub num_task: i32,
+}
+/// \brief The callback function to execute a parallel lambda
+/// \param task_id the task id of the function.
+/// \param penv The parallel environment backing the execution.
+/// \param cdata The supporting closure data.
+pub type FTVMParallelLambda = ::std::option::Option<
+  unsafe extern "C" fn(
+    task_id: ::std::os::raw::c_int,
+    penv: *mut TVMParallelGroupEnv,
+    cdata: *mut ::std::os::raw::c_void,
+  ) -> ::std::os::raw::c_int,
+>;
+extern "C" {
+  /// \brief Backend function for running parallel jobs.
+  ///
+  /// \param flambda The parallel function to be launched.
+  /// \param cdata The closure data.
+  /// \param num_task Number of tasks to launch; can be 0, meaning launch
+  ///  with all available threads.
+  ///
+  /// \return 0 when no error is thrown, -1 when failure happens
+  pub fn TVMBackendParallelLaunch(
+    flambda: FTVMParallelLambda,
+    cdata: *mut ::std::os::raw::c_void,
+    num_task: ::std::os::raw::c_int,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief BSP barrier between parallel threads
+  /// \param task_id the task id of the function.
+  /// \param penv The parallel environment backing the execution.
+  /// \return 0 when no error is thrown, -1 when failure happens
+  pub fn TVMBackendParallelBarrier(
+    task_id: ::std::os::raw::c_int,
+    penv: *mut TVMParallelGroupEnv,
+  ) -> ::std::os::raw::c_int;
+}
+extern "C" {
+  /// \brief Simple static initialization function.
+  ///  Run f once and set handle to be non-null.
+  ///  This function is mainly used for testing purposes.
+  ///
+  /// \param handle A global address used to indicate whether f has run
+  /// \param f The function to be run
+  /// \param cdata The closure data to pass to the function.
+  /// \param nbytes Number of bytes in the closure data.
+  /// \return 0 when no error is thrown, -1 when failure happens
+  pub fn TVMBackendRunOnce(
+    handle: *mut *mut ::std::os::raw::c_void,
+    f: ::std::option::Option<
+      unsafe extern "C" fn(arg1: *mut ::std::os::raw::c_void) -> ::std::os::raw::c_int,
+    >,
+    cdata: *mut ::std::os::raw::c_void,
+    nbytes: ::std::os::raw::c_int,
+  ) -> ::std::os::raw::c_int;
+}
diff --git a/rust/src/runtime/graph.rs b/rust/src/runtime/graph.rs
index 805a8b232e09..a74c07f4fc3c 100644
--- a/rust/src/runtime/graph.rs
+++ b/rust/src/runtime/graph.rs
@@ -11,9 +11,9 @@ use ffi::runtime::{
 };
 
 // Magic number for NDArray file. @see `kTVMNDArrayMagic` in `ndarray.h`
-const NDARRAY_MAGIC: u64 = 0xDD5E40F096B4A13F;
+const _NDARRAY_MAGIC: u64 = 0xDD5E40F096B4A13F;
 // Magic number for NDArray list file. @see `kTVMNDArrayListMagic` in `graph_runtime.h`
-const NDARRAY_LIST_MAGIC: u64 = 0xF7E58D4F05049CB7;
+const _NDARRAY_LIST_MAGIC: u64 = 0xF7E58D4F05049CB7;
 
 /// A TVM computation graph.
 ///
@@ -56,13 +56,11 @@ impl Graph {
         .as_ref()
         .ok_or(ErrorKind::GraphFormatError(
           "Missing graph attrs".to_string(),
-        ))?
-        .get(attr)
+        ))?.get(attr)
         .ok_or(ErrorKind::GraphFormatError(format!(
           "Missing {} attr",
           attr
-        )))?
-        .to_owned(),
+        )))?.to_owned(),
     )?)
   }
 }
@@ -101,8 +99,8 @@ impl Node {
       .ok_or(format!(
         "Node `{}` is missing attrs.flatten_data",
        self.name
-      ))?
-      .parse::<u8>()? == 1;
+      ))?.parse::<u8>()?
+ == 1; Ok(NodeAttrs { func_name, num_outputs, @@ -189,10 +187,9 @@ impl<'m, 't> GraphExecutor<'m, 't> { } else { Err(ErrorKind::GraphFormatError(format!("Invalid dltype: {}", dltype).to_string()).into()) } - }) - .collect::>>()?; + }).collect::>>()?; - let align = dtypes.iter().map(|dtype| dtype.bits as usize >> 3).max(); + let align = dtypes.iter().map(|dtype| dtype.bits as usize).max(); let mut storage_num_bytes = vec![0usize; *storage_ids.iter().max().unwrap_or(&1) + 1]; for (i, &storage_id) in storage_ids.iter().enumerate() { let dtype_size = dtypes[i].bits * dtypes[i].lanes >> 3; @@ -217,8 +214,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { strides: None, byte_offset: 0, } - }) - .collect(); + }).collect(); Ok(tensors) } @@ -231,45 +227,50 @@ impl<'m, 't> GraphExecutor<'m, 't> { ) -> Result>> { ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr."); let node_row_ptr = graph.node_row_ptr.as_ref().unwrap(); - graph - .nodes - .iter() - .enumerate() - .filter(|(_i, node)| node.op != "null") - .map(|(i, node)| { - ensure!(node.op == "tvm_op", "Only TVM ops are supported."); - ensure!(node.attrs.is_some(), "Missing node_row_ptr."); - let attrs = node.parse_attrs()?; - let func = lib - .get_function(&attrs.func_name) - .ok_or(format!("Missing function {}", attrs.func_name))?; - let arg_indices = node - .inputs - .iter() - .map(|entry| graph.entry_index(entry)) - .chain((0..attrs.num_outputs).map(|oi| Ok(node_row_ptr[i].clone() + oi))); - - let dl_tensors = arg_indices - .map(|idx| { - let tensor = &tensors[idx?]; - Ok(if attrs.flatten_data { - DLTensor::from_tensor(tensor, true /* flatten */) - } else { - DLTensor::from(tensor) - }) + + let mut op_execs = Vec::new(); + for (i, node) in graph.nodes.iter().enumerate() { + if node.op == "null" { + continue; + } + ensure!(node.op == "tvm_op", "Only TVM ops are supported."); + ensure!(node.attrs.is_some(), "Missing node attrs."); + + let attrs = node.parse_attrs()?; + + if attrs.func_name == "__nop" { + continue; + } + + let func = lib + .get_function(&attrs.func_name) + .ok_or(format!("Missing function {}", attrs.func_name))?; + let arg_indices = node + .inputs + .iter() + .map(|entry| graph.entry_index(entry)) + .chain((0..attrs.num_outputs).map(|oi| Ok(node_row_ptr[i].clone() + oi))); + + let dl_tensors = arg_indices + .map(|idx| { + let tensor = &tensors[idx?]; + Ok(if attrs.flatten_data { + DLTensor::from_tensor(tensor, true /* flatten */) + } else { + DLTensor::from(tensor) }) - .collect::>>() - .unwrap(); - let op: Box = box move || { - let args = dl_tensors - .iter() - .map(|t| t.into()) - .collect::>(); - func(args.as_slice()); - }; - Ok(op) - }) - .collect() + }).collect::>>() + .unwrap(); + let op: Box = box move || { + let args = dl_tensors + .iter() + .map(|t| t.into()) + .collect::>(); + func(args.as_slice()); + }; + op_execs.push(op); + } + Ok(op_execs) } pub fn load_params(&mut self, params: HashMap>) { @@ -360,9 +361,9 @@ named!( /// Converts a bytes to String. 
named!( name, - map_res!(length_bytes!(le_u64), |b: &[u8]| { - String::from_utf8(b.to_vec()) - }) + map_res!(length_bytes!(le_u64), |b: &[u8]| String::from_utf8( + b.to_vec() + )) ); /// Parses a TVMContext @@ -390,17 +391,23 @@ named!( named!( tensor, do_parse!( - take!(8) >> bits!(tag_bits!(u64, 64, 0)) >> ctx: tvm_ctx >> ndim: le_u32 >> dtype: data_type - >> shape: count!(map!(le_i64, |sz| sz as i64), ndim as usize) >> length: le_i64 - >> data: take!(length) >> (Tensor { - data: Storage::from(data), - ctx: ctx, - dtype: dtype, - size: shape.iter().product::() as usize, - shape: shape, - strides: None, - byte_offset: 0, - }) + take!(8) + >> bits!(tag_bits!(u64, 64, 0)) + >> ctx: tvm_ctx + >> ndim: le_u32 + >> dtype: data_type + >> shape: count!(map!(le_i64, |sz| sz as i64), ndim as usize) + >> length: le_i64 + >> data: take!(length) + >> (Tensor { + data: Storage::from(data), + ctx: ctx, + dtype: dtype, + size: shape.iter().product::() as usize, + shape: shape, + strides: None, + byte_offset: 0, + }) ) ); @@ -408,7 +415,9 @@ named!( named!( parse_param_dict>, do_parse!( - take!(8) >> bits!(tag_bits!(u64, 64, 0)) >> names: length_count!(le_u64, name) + take!(8) + >> bits!(tag_bits!(u64, 64, 0)) + >> names: length_count!(le_u64, name) >> tensors: length_count!(le_u64, tensor) >> (HashMap::from_iter(names.into_iter().zip(tensors.into_iter()))) ) diff --git a/rust/src/runtime/mod.rs b/rust/src/runtime/mod.rs index 871d7aff58d7..bdf7094113d8 100644 --- a/rust/src/runtime/mod.rs +++ b/rust/src/runtime/mod.rs @@ -4,16 +4,22 @@ mod module; #[macro_use] mod packed_func; mod graph; +#[cfg(target_env = "sgx")] +#[macro_use] +pub mod sgx; mod threading; mod workspace; -use std::{ffi::CStr, os::raw::c_char}; +use std::os::raw::c_char; pub use self::{array::*, graph::*, module::*, packed_func::*, threading::*, workspace::*}; #[no_mangle] pub extern "C" fn TVMAPISetLastError(cmsg: *const c_char) { + #[cfg(not(target_env = "sgx"))] unsafe { - panic!(CStr::from_ptr(cmsg).to_str().unwrap()); + panic!(std::ffi::CStr::from_ptr(cmsg).to_str().unwrap()); } + #[cfg(target_env = "sgx")] + ocall_packed!("__sgx_set_last_error__", cmsg); } diff --git a/rust/src/runtime/packed_func.rs b/rust/src/runtime/packed_func.rs index ff5fd667af12..0fd3cd8ef04b 100644 --- a/rust/src/runtime/packed_func.rs +++ b/rust/src/runtime/packed_func.rs @@ -7,7 +7,7 @@ use ffi::runtime::{ use errors::*; -pub type PackedFunc = Box TVMRetValue>; +pub type PackedFunc = Box TVMRetValue + Send + Sync>; /// Calls a packed function and returns a `TVMRetValue`. /// @@ -19,6 +19,9 @@ macro_rules! call_packed { ($fn:expr, $($args:expr),+) => { $fn(&[$($args.into(),)+]) }; + ($fn:expr) => { + $fn(&Vec::new()) + }; } /// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way @@ -26,8 +29,18 @@ macro_rules! call_packed { #[derive(Clone, Copy)] pub struct TVMArgValue<'a> { _lifetime: PhantomData<&'a ()>, - value: TVMValue, - type_code: i64, + pub(crate) value: TVMValue, + pub(crate) type_code: i64, +} + +impl<'a> TVMArgValue<'a> { + pub fn new(value: TVMValue, type_code: i64) -> Self { + TVMArgValue { + _lifetime: PhantomData, + value: value, + type_code: type_code, + } + } } /// Creates a conversion to a `TVMArgValue` for a primitive type and DLDataTypeCode. 
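
The hunks above make `PackedFunc` a boxed `Fn(&[TVMArgValue]) -> TVMRetValue` that is
also `Send + Sync`, and give `call_packed!` a zero-argument arm. A sketch of how the
two arms read at a call site, assuming the primitive `Into<TVMArgValue>` conversions
this file defines:

    fn demo(f: &PackedFunc) {
      // Multi-argument arm: each argument is converted with `.into()`.
      let _ret = call_packed!(f, 1i64, 2f64);
      // New zero-argument arm: passes an empty argument list.
      let _unit = call_packed!(f);
    }
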
@@ -136,6 +149,25 @@ pub struct TVMRetValue { type_code: i64, } +impl TVMRetValue { + #[cfg(target_env = "sgx")] + pub(crate) fn from_tvm_value(value: TVMValue, type_code: i64) -> Self { + unsafe { + Self { + prim_value: match type_code { + 0 | 1 => value.v_int64 as u64, + 2 => value.v_float64 as u64, + 3 | 7 | 8 | 9 | 10 => value.v_handle as u64, + 11 | 12 => value.v_str as u64, + _ => 0, + } as u64, + box_value: box (), + type_code: type_code, + } + } + } +} + impl Default for TVMRetValue { fn default() -> Self { TVMRetValue { diff --git a/rust/src/runtime/sgx.rs b/rust/src/runtime/sgx.rs new file mode 100644 index 000000000000..7f3804b8407a --- /dev/null +++ b/rust/src/runtime/sgx.rs @@ -0,0 +1,78 @@ +use std::{ + ffi::CString, + os::raw::{c_char, c_int}, +}; + +use errors::Result; +use ffi::runtime::TVMValue; +use runtime::{threading::sgx_join_threads, SystemLibModule, TVMArgValue, TVMRetValue}; + +pub use runtime::threading::tvm_run_worker as run_worker; + +#[macro_export] +macro_rules! tvm_ocall { + ($func: expr) => { + match $func { + 0 => Ok(()), + err => Err(format!("SGX error: {}", err)), + } + }; +} + +pub type SgxStatus = u32; + +#[cfg(target_env = "sgx")] +extern "C" { + fn tvm_ocall_packed_func( + name: *const c_char, + arg_values: *const TVMValue, + type_codes: *const c_int, + num_args: c_int, + ret_val: *mut TVMValue, + ret_type_code: *mut c_int, + ) -> SgxStatus; +} + +pub fn ocall_packed_func>(fn_name: S, args: &[TVMArgValue]) -> Result { + let mut ret_val = TVMValue { v_int64: 0 }; + let ret_type_code = 0i64; + unsafe { + tvm_ocall!(tvm_ocall_packed_func( + CString::new(fn_name.as_ref()).unwrap().as_ptr(), + args + .iter() + .map(|ref arg| arg.value) + .collect::>() + .as_ptr(), + args + .iter() + .map(|ref arg| arg.type_code as i32) + .collect::>() + .as_ptr() as *const i32, + args.len() as i32, + &mut ret_val as *mut TVMValue, + &mut (ret_type_code as i32) as *mut c_int, + ))?; + } + Ok(TVMRetValue::from_tvm_value(ret_val, ret_type_code as i64)) +} + +#[macro_export] +macro_rules! ocall_packed { + ($fn_name:expr, $($args:expr),+) => { + ::runtime::sgx::ocall_packed_func($fn_name, &[$($args.into(),)+]) + .expect(concat!("Error calling `", $fn_name, "`")) + }; + ($fn_name:expr) => { + ::runtime::sgx::ocall_packed_func($fn_name, &Vec::new()) + .expect(concat!("Error calling `", $fn_name, "`")) + } +} + +impl Drop for SystemLibModule { + fn drop(&mut self) { + if env!("TVM_NUM_THREADS") != "0" { + sgx_join_threads() + } + } +} diff --git a/rust/src/runtime/threading.rs b/rust/src/runtime/threading.rs index c5db5147632d..6bbaa92349b4 100644 --- a/rust/src/runtime/threading.rs +++ b/rust/src/runtime/threading.rs @@ -2,8 +2,7 @@ use std::{ os::raw::{c_int, c_void}, sync::{ atomic::{AtomicUsize, Ordering, ATOMIC_USIZE_INIT}, - Arc, - Barrier, + Arc, Barrier, }, }; @@ -16,13 +15,16 @@ use std::{ }; #[cfg(target_env = "sgx")] -use std::{collections::VecDeque, sync::Mutex}; +use std::{collections::VecDeque, ptr, sync::Mutex}; use bounded_spsc_queue::{self, Producer}; use super::super::errors::*; use ffi::runtime::TVMParallelGroupEnv; +#[cfg(target_env = "sgx")] +use super::{TVMArgValue, TVMRetValue}; + type FTVMParallelLambda = extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32; @@ -56,8 +58,7 @@ impl Job { }, cdata: self.cdata, pending: Arc::clone(&self.pending), - }) - .collect() + }).collect() } /// Waits for all tasks in this `Job` to be completed. 
@@ -109,8 +110,7 @@ impl<'a> Threads { let (p, c) = bounded_spsc_queue::make(2); let handle = thread::spawn(move || cb(c.into())); (handle, p) - }) - .unzip(); + }).unzip(); Threads { handles: handles, queues: queues, @@ -118,15 +118,18 @@ impl<'a> Threads { } #[cfg(target_env = "sgx")] - fn launch) + 'static + Copy>(num: usize, _cb: F) -> Self { + fn launch) + 'static + Copy>( + num_threads: usize, + _cb: F, + ) -> Self { let mut consumer_queues = SGX_QUEUES.lock().unwrap(); - let queues = (0..num) + let queues = (0..num_threads) .map(|_| { let (p, c) = bounded_spsc_queue::make(2); consumer_queues.push_back(c.into()); p - }) - .collect(); + }).collect(); + ocall_packed!("__sgx_thread_group_launch__", num_threads as u64); Threads { queues: queues } } } @@ -149,21 +152,23 @@ impl ThreadPool { } fn launch(&self, job: Job) { - let tasks = job.tasks(self.num_workers); + let mut tasks = job.tasks(self.num_workers + 1); - let _: Vec<()> = tasks - .into_iter() - .zip(self.threads.queues.iter()) - .map(|(task, q)| q.push(task)) - .collect(); + for (i, task) in tasks.split_off(1).into_iter().enumerate() { + self.threads.queues[i].push(task); + } + tasks.pop().unwrap()(); job.wait().unwrap(); } fn run_worker(queue: Consumer) { loop { let task = queue.pop(); - if task() != 0 { + let result = task(); + if result == ::min_value() { + break; + } else if result != 0 { panic!("Error running task."); } } @@ -214,11 +219,16 @@ fn max_concurrency() -> usize { } #[cfg(target_env = "sgx")] -#[no_mangle] -pub extern "C" fn tvm_ecall_run_worker() { - if let Some(q) = SGX_QUEUES.lock().unwrap().pop_front() { +pub fn tvm_run_worker(_args: &[TVMArgValue]) -> TVMRetValue { + let q = { + let mut qs = SGX_QUEUES.lock().unwrap(); + qs.pop_front() + // `qs: MutexGuard` needs to be dropped here since `run_worker` won't return + }; + if let Some(q) = q { ThreadPool::run_worker(q); } + TVMRetValue::default() } #[no_mangle] @@ -233,9 +243,6 @@ pub extern "C" fn TVMBackendParallelLaunch( num_task: 1, }; cb(0, &penv as *const _, cdata); - - #[cfg(feature = "par-launch-alloc")] - let break_the_heap: Vec = Vec::new(); // TODO: why does allocating break? } else { THREAD_POOL.with(|pool| { pool.launch(Job { @@ -249,6 +256,27 @@ pub extern "C" fn TVMBackendParallelLaunch( return 0; } +#[cfg(target_env = "sgx")] +pub(crate) fn sgx_join_threads() -> () { + extern "C" fn poison_pill( + _task_id: usize, + _penv: *const TVMParallelGroupEnv, + _cdata: *const c_void, + ) -> i32 { + ::min_value() + } + + THREAD_POOL.with(|pool| { + pool.launch(Job { + cb: poison_pill, + cdata: ptr::null(), + req_num_tasks: 0, + pending: Arc::new(ATOMIC_USIZE_INIT), + }); + }); + ocall_packed!("__sgx_thread_group_join__", 0); +} + // @see https://github.com/dmlc/tvm/issues/988 for information on why this function is used. 
#[no_mangle] pub extern "C" fn TVMBackendParallelBarrier(_task_id: usize, penv: *const TVMParallelGroupEnv) { diff --git a/rust/src/runtime/workspace.rs b/rust/src/runtime/workspace.rs index fe9d8550a32c..979e4d44ea20 100644 --- a/rust/src/runtime/workspace.rs +++ b/rust/src/runtime/workspace.rs @@ -1,7 +1,6 @@ use std::{ cell::RefCell, os::raw::{c_int, c_void}, - ptr, }; use super::allocator::Allocation; @@ -22,24 +21,40 @@ impl WorkspacePool { } } - fn alloc(&mut self, size: usize) -> Result<*mut u8> { + fn alloc_new(&mut self, size: usize, align: usize) -> Result<*mut u8> { + self.workspaces.push(Allocation::new(size, Some(align))?); + self.in_use.push(self.workspaces.len() - 1); + Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr()) + } + + fn alloc(&mut self, size: usize, align: usize) -> Result<*mut u8> { if self.free.len() == 0 { - self.workspaces.push(Allocation::new(size, None)?); - self.free.push(self.workspaces.len() - 1); - Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr()) - } else { - let i = self.free.iter().fold(0, |cur_ws_idx, &idx| { - let cur_size = self.workspaces[cur_ws_idx].size(); + return self.alloc_new(size, align); + } + let idx = self + .free + .iter() + .fold(None, |cur_ws_idx: Option, &idx| { let ws_size = self.workspaces[idx].size(); - if ws_size < size || ws_size > cur_size { - cur_ws_idx - } else { - idx + let ws_ok = ws_size >= size && self.workspaces[idx].align() == align; + if !ws_ok { + return cur_ws_idx; } + cur_ws_idx.or(Some(idx)).and_then(|cur_idx| { + let cur_size = self.workspaces[cur_idx].size(); + Some(match ws_size <= cur_size { + true => idx, + false => cur_idx, + }) + }) }); - let idx = self.free.remove(i); - self.in_use.push(idx.clone()); - Ok(self.workspaces[idx].as_mut_ptr()) + match idx { + Some(idx) => { + self.free.remove_item(&idx).unwrap(); + self.in_use.push(idx); + Ok(self.workspaces[idx].as_mut_ptr()) + } + None => self.alloc_new(size, align), } } @@ -71,7 +86,7 @@ pub extern "C" fn TVMBackendAllocWorkspace( _device_id: c_int, size: u64, _dtype_code_hint: c_int, - _dtype_bits_hint: c_int, + dtype_bits_hint: c_int, ) -> *mut c_void { let nbytes = if size == 0 { WORKSPACE_PAGE_SIZE @@ -79,12 +94,11 @@ pub extern "C" fn TVMBackendAllocWorkspace( size as usize }; WORKSPACE_POOL.with(|pool_cell| { - (match pool_cell.borrow_mut().alloc(nbytes as usize) { - Ok(ptr) => ptr, - Err(_) => ptr::null_mut(), - }) as *mut c_void - }); - return ptr::null_mut(); + pool_cell + .borrow_mut() + .alloc(nbytes as usize, dtype_bits_hint as usize) + .unwrap() as *mut c_void + }) } #[no_mangle] diff --git a/rust/tests/test_graph_serde.rs b/rust/tests/test_graph_serde.rs index a3679812e3eb..b08469615d6b 100644 --- a/rust/tests/test_graph_serde.rs +++ b/rust/tests/test_graph_serde.rs @@ -18,10 +18,9 @@ fn test_load_graph() { .unwrap(); let params = tvm::runtime::load_param_dict(¶ms_bytes); - let graph = - Graph::try_from( - &fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.json")).unwrap(), - ).unwrap(); + let graph = Graph::try_from( + &fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.json")).unwrap(), + ).unwrap(); assert_eq!(graph.nodes[3].op, "tvm_op"); assert_eq!( From a0d3d413b61d7ce7dcf6d792e342aca59ec213c5 Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Fri, 5 Oct 2018 00:06:12 +0000 Subject: [PATCH 09/10] Updates --- rust/src/lib.rs | 11 +++++++++ rust/src/runtime/graph.rs | 6 ++--- rust/src/runtime/workspace.rs | 27 +++++++++++--------- rust/tests/test_graph_serde.rs | 8 +++--- 
rust/tests/test_nnvm/Cargo.toml | 3 --- rust/tests/test_nnvm/src/main.rs | 42 +++++++++++++++----------------- 6 files changed, 52 insertions(+), 45 deletions(-) diff --git a/rust/src/lib.rs b/rust/src/lib.rs index fd83d5ee6c6d..4a70e428d37a 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -1,3 +1,14 @@ +//! This crate is an implementation of the TVM runtime for modules compiled with `--system-lib`. +//! It's mainly useful for compiling to WebAssembly and SGX, +//! but also native if you prefer Rust to C++. +//! +//! For TVM graphs, the entrypoint to this crate is `runtime::GraphExecutor`. +//! Single-function modules are used via the `packed_func!` macro after obtaining +//! the function from `runtime::SystemLibModule` +//! +//! The main entrypoints to this crate are `GraphExecutor` +//! For examples of use, please refer to the multi-file tests in the `tests` directory. + #![feature( alloc, allocator_api, diff --git a/rust/src/runtime/graph.rs b/rust/src/runtime/graph.rs index a74c07f4fc3c..e07e8370a459 100644 --- a/rust/src/runtime/graph.rs +++ b/rust/src/runtime/graph.rs @@ -131,7 +131,7 @@ impl<'a> TryFrom<&'a str> for Graph { /// /// ``` /// use ndarray::Array; - +/// /// let syslib = SystemLibModule::default(); // a provider of TVM functions /// /// let mut params_bytes = Vec::new(); @@ -142,7 +142,7 @@ impl<'a> TryFrom<&'a str> for Graph { /// /// let mut exec = GraphExecutor::new(graph, &syslib).unwrap(); /// exec.load_params(params); - +/// /// let x = Array::from_vec(vec![1f32, 2., 3., 4.]); /// exec.set_input("data", x.into()); /// exec.run(); @@ -189,7 +189,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { } }).collect::>>()?; - let align = dtypes.iter().map(|dtype| dtype.bits as usize).max(); + let align = dtypes.iter().map(|dtype| dtype.bits as usize >> 3).max(); let mut storage_num_bytes = vec![0usize; *storage_ids.iter().max().unwrap_or(&1) + 1]; for (i, &storage_id) in storage_ids.iter().enumerate() { let dtype_size = dtypes[i].bits * dtypes[i].lanes >> 3; diff --git a/rust/src/runtime/workspace.rs b/rust/src/runtime/workspace.rs index 979e4d44ea20..fa96651fac10 100644 --- a/rust/src/runtime/workspace.rs +++ b/rust/src/runtime/workspace.rs @@ -1,11 +1,14 @@ use std::{ cell::RefCell, os::raw::{c_int, c_void}, + ptr, }; use super::allocator::Allocation; use errors::*; +const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h` + struct WorkspacePool { workspaces: Vec, free: Vec, @@ -21,28 +24,28 @@ impl WorkspacePool { } } - fn alloc_new(&mut self, size: usize, align: usize) -> Result<*mut u8> { - self.workspaces.push(Allocation::new(size, Some(align))?); + fn alloc_new(&mut self, size: usize) -> Result<*mut u8> { + self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?); self.in_use.push(self.workspaces.len() - 1); Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr()) } - fn alloc(&mut self, size: usize, align: usize) -> Result<*mut u8> { + fn alloc(&mut self, size: usize) -> Result<*mut u8> { if self.free.len() == 0 { - return self.alloc_new(size, align); + return self.alloc_new(size); } let idx = self .free .iter() .fold(None, |cur_ws_idx: Option, &idx| { let ws_size = self.workspaces[idx].size(); - let ws_ok = ws_size >= size && self.workspaces[idx].align() == align; - if !ws_ok { + if !ws_size >= size { return cur_ws_idx; } - cur_ws_idx.or(Some(idx)).and_then(|cur_idx| { + cur_ws_idx.and_then(|cur_idx| { let cur_size = self.workspaces[cur_idx].size(); - Some(match ws_size <= cur_size { + Some(match ws_size < cur_size { + // is already ok 
true => idx, false => cur_idx, }) @@ -54,7 +57,7 @@ impl WorkspacePool { self.in_use.push(idx); Ok(self.workspaces[idx].as_mut_ptr()) } - None => self.alloc_new(size, align), + None => self.alloc_new(size), } } @@ -86,7 +89,7 @@ pub extern "C" fn TVMBackendAllocWorkspace( _device_id: c_int, size: u64, _dtype_code_hint: c_int, - dtype_bits_hint: c_int, + _dtype_bits_hint: c_int, ) -> *mut c_void { let nbytes = if size == 0 { WORKSPACE_PAGE_SIZE @@ -96,8 +99,8 @@ pub extern "C" fn TVMBackendAllocWorkspace( WORKSPACE_POOL.with(|pool_cell| { pool_cell .borrow_mut() - .alloc(nbytes as usize, dtype_bits_hint as usize) - .unwrap() as *mut c_void + .alloc(nbytes as usize) + .unwrap_or(ptr::null_mut()) as *mut c_void }) } diff --git a/rust/tests/test_graph_serde.rs b/rust/tests/test_graph_serde.rs index b08469615d6b..a596544212ca 100644 --- a/rust/tests/test_graph_serde.rs +++ b/rust/tests/test_graph_serde.rs @@ -1,11 +1,11 @@ -#![feature(fs_read_write, try_from)] +#![feature(try_from)] extern crate serde; extern crate serde_json; extern crate tvm; -use std::{collections::HashMap, convert::TryFrom, fs, io::Read}; +use std::{convert::TryFrom, fs, io::Read}; use tvm::runtime::Graph; @@ -13,10 +13,10 @@ use tvm::runtime::Graph; fn test_load_graph() { let mut params_bytes = Vec::new(); fs::File::open(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.params")) - .unwrap() + .expect("Could not find TVM graph. Did you run `tests/build_model.py`?") .read_to_end(&mut params_bytes) .unwrap(); - let params = tvm::runtime::load_param_dict(¶ms_bytes); + let _params = tvm::runtime::load_param_dict(¶ms_bytes); let graph = Graph::try_from( &fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.json")).unwrap(), diff --git a/rust/tests/test_nnvm/Cargo.toml b/rust/tests/test_nnvm/Cargo.toml index 978bdf9a428f..7e6ce5fb729c 100644 --- a/rust/tests/test_nnvm/Cargo.toml +++ b/rust/tests/test_nnvm/Cargo.toml @@ -4,9 +4,6 @@ version = "0.0.0" license = "Apache-2.0" authors = ["Nick Hynes "] -[features] -par-launch-alloc = ["tvm/par-launch-alloc"] - [dependencies] ndarray = "0.11.2" tvm = { path = "../../" } diff --git a/rust/tests/test_nnvm/src/main.rs b/rust/tests/test_nnvm/src/main.rs index 9fc5ba5d3537..0953ce2a2603 100644 --- a/rust/tests/test_nnvm/src/main.rs +++ b/rust/tests/test_nnvm/src/main.rs @@ -1,4 +1,4 @@ -#![feature(fs_read_write, try_from)] +#![feature(try_from)] #[macro_use] extern crate ndarray; @@ -14,10 +14,18 @@ use tvm::runtime::{Graph, GraphExecutor, SystemLibModule, Tensor}; const BATCH_SIZE: usize = 4; const IN_DIM: usize = 8; -macro_rules! assert_sum_eq { - ($a:expr, $b:expr) => { - let a_sum = $a.scalar_sum(); - let b_sum = $b.scalar_sum(); +macro_rules! 
check_sum { + ($e:expr, $a:ident, $b:ident) => { + let a = Array::try_from($e.get_input(stringify!($a)).unwrap()).unwrap(); + check_sum!(a, $b); + }; + ($e:expr, $a:expr, $b:ident) => { + let a = Array::try_from($e.get_output($a).unwrap()).unwrap(); + check_sum!(a, $b); + }; + ($a:ident, $b:ident) => { + let a_sum: f32 = $a.scalar_sum(); + let b_sum: f32 = $b.scalar_sum(); assert!((a_sum - b_sum).abs() < 1e-2, "{} != {}", a_sum, b_sum); }; } @@ -60,25 +68,13 @@ fn main() { exec.load_params(params); exec.set_input("data", x.clone().into()); - assert_sum_eq!(Array::try_from(exec.get_input("data").unwrap()).unwrap(), x); - assert_sum_eq!( - Array::try_from(exec.get_input("dense0_weight").unwrap()).unwrap(), - w - ); - assert_sum_eq!( - Array::try_from(exec.get_input("dense0_bias").unwrap()).unwrap(), - b - ); + check_sum!(exec, data, x); + check_sum!(exec, dense0_weight, w); + check_sum!(exec, dense0_bias, b); exec.run(); - assert_sum_eq!( - Array::try_from(exec.get_output(0).unwrap()).unwrap(), - expected_o0 - ); - assert_sum_eq!( - Array::try_from(exec.get_output(1).unwrap()).unwrap(), - expected_o1 - ); - assert_sum_eq!(Array::try_from(exec.get_output(2).unwrap()).unwrap(), dense); + check_sum!(exec, 0, expected_o0); + check_sum!(exec, 1, expected_o1); + check_sum!(exec, 2, dense); } From 88aff7ce3b62eec61d418237b09f02be30c54312 Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Thu, 4 Oct 2018 23:40:08 -0500 Subject: [PATCH 10/10] Updates --- rust/src/runtime/array.rs | 2 ++ rust/src/runtime/graph.rs | 4 +++- rust/src/runtime/packed_func.rs | 21 ++++++++++++++++++++- rust/src/runtime/sgx.rs | 10 +++++++--- rust/src/runtime/threading.rs | 2 +- rust/src/runtime/workspace.rs | 5 ++--- 6 files changed, 35 insertions(+), 9 deletions(-) diff --git a/rust/src/runtime/array.rs b/rust/src/runtime/array.rs index 6a9a60c9bb14..79d22e400cff 100644 --- a/rust/src/runtime/array.rs +++ b/rust/src/runtime/array.rs @@ -129,6 +129,8 @@ pub struct Tensor<'a> { pub(super) size: usize, } +unsafe impl<'a> Send for Tensor<'a> {} + impl<'a> Tensor<'a> { pub fn shape(&self) -> Vec { self.shape.clone() diff --git a/rust/src/runtime/graph.rs b/rust/src/runtime/graph.rs index e07e8370a459..6c53aeb9f6e9 100644 --- a/rust/src/runtime/graph.rs +++ b/rust/src/runtime/graph.rs @@ -156,6 +156,8 @@ pub struct GraphExecutor<'m, 't> { tensors: Vec>, } +unsafe impl<'m, 't> Send for GraphExecutor<'m, 't> {} + impl<'m, 't> GraphExecutor<'m, 't> { pub fn new(graph: Graph, lib: &'m M) -> Result { let tensors = Self::setup_storages(&graph)?; @@ -189,7 +191,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { } }).collect::>>()?; - let align = dtypes.iter().map(|dtype| dtype.bits as usize >> 3).max(); + let align = dtypes.iter().map(|dtype| dtype.bits as usize).max(); let mut storage_num_bytes = vec![0usize; *storage_ids.iter().max().unwrap_or(&1) + 1]; for (i, &storage_id) in storage_ids.iter().enumerate() { let dtype_size = dtypes[i].bits * dtypes[i].lanes >> 3; diff --git a/rust/src/runtime/packed_func.rs b/rust/src/runtime/packed_func.rs index 0fd3cd8ef04b..030d677329c0 100644 --- a/rust/src/runtime/packed_func.rs +++ b/rust/src/runtime/packed_func.rs @@ -149,8 +149,8 @@ pub struct TVMRetValue { type_code: i64, } +#[cfg(target_env = "sgx")] impl TVMRetValue { - #[cfg(target_env = "sgx")] pub(crate) fn from_tvm_value(value: TVMValue, type_code: i64) -> Self { unsafe { Self { @@ -166,6 +166,25 @@ impl TVMRetValue { } } } + + pub fn into_tvm_value(self) -> (TVMValue, i64) { + let val = match self.type_code { + 0 | 1 => TVMValue { + v_int64: 
self.prim_value.clone() as i64, + }, + 2 => TVMValue { + v_float64: self.prim_value.clone() as f64, + }, + 3 | 7 | 8 | 9 | 10 => TVMValue { + v_handle: Box::into_raw(self.box_value) as *mut c_void, + }, + 11 | 12 => TVMValue { + v_str: Box::into_raw(self.box_value) as *const _, + }, + _ => unreachable!(), + }; + (val, self.type_code) + } } impl Default for TVMRetValue { diff --git a/rust/src/runtime/sgx.rs b/rust/src/runtime/sgx.rs index 7f3804b8407a..bf9d54a4af65 100644 --- a/rust/src/runtime/sgx.rs +++ b/rust/src/runtime/sgx.rs @@ -69,10 +69,14 @@ macro_rules! ocall_packed { } } +pub fn shutdown() { + if env!("TVM_NUM_THREADS") != "0" { + sgx_join_threads() + } +} + impl Drop for SystemLibModule { fn drop(&mut self) { - if env!("TVM_NUM_THREADS") != "0" { - sgx_join_threads() - } + shutdown() } } diff --git a/rust/src/runtime/threading.rs b/rust/src/runtime/threading.rs index 6bbaa92349b4..c0d6221c91b7 100644 --- a/rust/src/runtime/threading.rs +++ b/rust/src/runtime/threading.rs @@ -257,7 +257,7 @@ pub extern "C" fn TVMBackendParallelLaunch( } #[cfg(target_env = "sgx")] -pub(crate) fn sgx_join_threads() -> () { +pub(crate) fn sgx_join_threads() { extern "C" fn poison_pill( _task_id: usize, _penv: *const TVMParallelGroupEnv, diff --git a/rust/src/runtime/workspace.rs b/rust/src/runtime/workspace.rs index fa96651fac10..d0e6d8c89255 100644 --- a/rust/src/runtime/workspace.rs +++ b/rust/src/runtime/workspace.rs @@ -42,10 +42,9 @@ impl WorkspacePool { if !ws_size >= size { return cur_ws_idx; } - cur_ws_idx.and_then(|cur_idx| { + cur_ws_idx.or(Some(idx)).and_then(|cur_idx| { let cur_size = self.workspaces[cur_idx].size(); - Some(match ws_size < cur_size { - // is already ok + Some(match ws_size <= cur_size { true => idx, false => cur_idx, })
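
One detail worth noting in this final hunk: in Rust, `!ws_size >= size` parses as
`(!ws_size) >= size`, a bitwise NOT of the size, so the guard presumably intends a
plain `ws_size < size` check. A self-contained sketch of the best-fit scan the fold
implements, written over a bare list of sizes (the names are illustrative, not part
of the patch):

    /// Returns the index of the smallest free workspace that still fits
    /// `want` bytes, or None if a fresh allocation is needed.
    fn best_fit(free: &[usize], sizes: &[usize], want: usize) -> Option<usize> {
      free.iter().fold(None, |best, &idx| {
        let size = sizes[idx];
        if size < want {
          return best; // too small to reuse
        }
        match best {
          Some(b) if sizes[b] <= size => Some(b), // keep the tighter fit
          _ => Some(idx),
        }
      })
    }
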