Skip to content

Commit

Permalink
Updates
Browse files Browse the repository at this point in the history
  • Loading branch information
nhynes committed Oct 5, 2018
1 parent 56e8531 commit a1e4d02
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 45 deletions.
11 changes: 11 additions & 0 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
//! This crate is an implementation of the TVM runtime for modules compiled with `--system-lib`.
//! It's mainly useful for compiling to WebAssembly and SGX,
//! but also native if you prefer Rust to C++.
//!
//! For TVM graphs, the entrypoint to this crate is `runtime::GraphExecutor`.
//! Single-function modules are used via the `packed_func!` macro after obtaining
//! the function from `runtime::SystemLibModule`
//!
//! The main entrypoints to this crate are `GraphExecutor`
//! For examples of use, please refer to the multi-file tests in the `tests` directory.

#![feature(
alloc,
allocator_api,
Expand Down
6 changes: 3 additions & 3 deletions rust/src/runtime/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ impl<'a> TryFrom<&'a str> for Graph {
///
/// ```
/// use ndarray::Array;

///
/// let syslib = SystemLibModule::default(); // a provider of TVM functions
///
/// let mut params_bytes = Vec::new();
Expand All @@ -142,7 +142,7 @@ impl<'a> TryFrom<&'a str> for Graph {
///
/// let mut exec = GraphExecutor::new(graph, &syslib).unwrap();
/// exec.load_params(params);

///
/// let x = Array::from_vec(vec![1f32, 2., 3., 4.]);
/// exec.set_input("data", x.into());
/// exec.run();
Expand Down Expand Up @@ -189,7 +189,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
}
}).collect::<Result<Vec<DataType>>>()?;

let align = dtypes.iter().map(|dtype| dtype.bits as usize).max();
let align = dtypes.iter().map(|dtype| dtype.bits as usize >> 3).max();
let mut storage_num_bytes = vec![0usize; *storage_ids.iter().max().unwrap_or(&1) + 1];
for (i, &storage_id) in storage_ids.iter().enumerate() {
let dtype_size = dtypes[i].bits * dtypes[i].lanes >> 3;
Expand Down
27 changes: 15 additions & 12 deletions rust/src/runtime/workspace.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use std::{
cell::RefCell,
os::raw::{c_int, c_void},
ptr,
};

use super::allocator::Allocation;
use errors::*;

const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h`

struct WorkspacePool {
workspaces: Vec<Allocation>,
free: Vec<usize>,
Expand All @@ -21,28 +24,28 @@ impl WorkspacePool {
}
}

fn alloc_new(&mut self, size: usize, align: usize) -> Result<*mut u8> {
self.workspaces.push(Allocation::new(size, Some(align))?);
fn alloc_new(&mut self, size: usize) -> Result<*mut u8> {
self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?);
self.in_use.push(self.workspaces.len() - 1);
Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr())
}

fn alloc(&mut self, size: usize, align: usize) -> Result<*mut u8> {
fn alloc(&mut self, size: usize) -> Result<*mut u8> {
if self.free.len() == 0 {
return self.alloc_new(size, align);
return self.alloc_new(size);
}
let idx = self
.free
.iter()
.fold(None, |cur_ws_idx: Option<usize>, &idx| {
let ws_size = self.workspaces[idx].size();
let ws_ok = ws_size >= size && self.workspaces[idx].align() == align;
if !ws_ok {
if !ws_size >= size {
return cur_ws_idx;
}
cur_ws_idx.or(Some(idx)).and_then(|cur_idx| {
cur_ws_idx.and_then(|cur_idx| {
let cur_size = self.workspaces[cur_idx].size();
Some(match ws_size <= cur_size {
Some(match ws_size < cur_size {
// is already ok
true => idx,
false => cur_idx,
})
Expand All @@ -54,7 +57,7 @@ impl WorkspacePool {
self.in_use.push(idx);
Ok(self.workspaces[idx].as_mut_ptr())
}
None => self.alloc_new(size, align),
None => self.alloc_new(size),
}
}

Expand Down Expand Up @@ -86,7 +89,7 @@ pub extern "C" fn TVMBackendAllocWorkspace(
_device_id: c_int,
size: u64,
_dtype_code_hint: c_int,
dtype_bits_hint: c_int,
_dtype_bits_hint: c_int,
) -> *mut c_void {
let nbytes = if size == 0 {
WORKSPACE_PAGE_SIZE
Expand All @@ -96,8 +99,8 @@ pub extern "C" fn TVMBackendAllocWorkspace(
WORKSPACE_POOL.with(|pool_cell| {
pool_cell
.borrow_mut()
.alloc(nbytes as usize, dtype_bits_hint as usize)
.unwrap() as *mut c_void
.alloc(nbytes as usize)
.unwrap_or(ptr::null_mut()) as *mut c_void
})
}

Expand Down
8 changes: 4 additions & 4 deletions rust/tests/test_graph_serde.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
#![feature(fs_read_write, try_from)]
#![feature(try_from)]

extern crate serde;
extern crate serde_json;

extern crate tvm;

use std::{collections::HashMap, convert::TryFrom, fs, io::Read};
use std::{convert::TryFrom, fs, io::Read};

use tvm::runtime::Graph;

#[test]
fn test_load_graph() {
let mut params_bytes = Vec::new();
fs::File::open(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.params"))
.unwrap()
.expect("Could not find TVM graph. Did you run `tests/build_model.py`?")
.read_to_end(&mut params_bytes)
.unwrap();
let params = tvm::runtime::load_param_dict(&params_bytes);
let _params = tvm::runtime::load_param_dict(&params_bytes);

let graph = Graph::try_from(
&fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.json")).unwrap(),
Expand Down
3 changes: 0 additions & 3 deletions rust/tests/test_nnvm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@ version = "0.0.0"
license = "Apache-2.0"
authors = ["Nick Hynes <nhynes@berkeley.edu>"]

[features]
par-launch-alloc = ["tvm/par-launch-alloc"]

[dependencies]
ndarray = "0.11.2"
tvm = { path = "../../" }
Expand Down
42 changes: 19 additions & 23 deletions rust/tests/test_nnvm/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#![feature(fs_read_write, try_from)]
#![feature(try_from)]

#[macro_use]
extern crate ndarray;
Expand All @@ -14,10 +14,18 @@ use tvm::runtime::{Graph, GraphExecutor, SystemLibModule, Tensor};
const BATCH_SIZE: usize = 4;
const IN_DIM: usize = 8;

macro_rules! assert_sum_eq {
($a:expr, $b:expr) => {
let a_sum = $a.scalar_sum();
let b_sum = $b.scalar_sum();
macro_rules! check_sum {
($e:expr, $a:ident, $b:ident) => {
let a = Array::try_from($e.get_input(stringify!($a)).unwrap()).unwrap();
check_sum!(a, $b);
};
($e:expr, $a:expr, $b:ident) => {
let a = Array::try_from($e.get_output($a).unwrap()).unwrap();
check_sum!(a, $b);
};
($a:ident, $b:ident) => {
let a_sum: f32 = $a.scalar_sum();
let b_sum: f32 = $b.scalar_sum();
assert!((a_sum - b_sum).abs() < 1e-2, "{} != {}", a_sum, b_sum);
};
}
Expand Down Expand Up @@ -60,25 +68,13 @@ fn main() {
exec.load_params(params);
exec.set_input("data", x.clone().into());

assert_sum_eq!(Array::try_from(exec.get_input("data").unwrap()).unwrap(), x);
assert_sum_eq!(
Array::try_from(exec.get_input("dense0_weight").unwrap()).unwrap(),
w
);
assert_sum_eq!(
Array::try_from(exec.get_input("dense0_bias").unwrap()).unwrap(),
b
);
check_sum!(exec, data, x);
check_sum!(exec, dense0_weight, w);
check_sum!(exec, dense0_bias, b);

exec.run();

assert_sum_eq!(
Array::try_from(exec.get_output(0).unwrap()).unwrap(),
expected_o0
);
assert_sum_eq!(
Array::try_from(exec.get_output(1).unwrap()).unwrap(),
expected_o1
);
assert_sum_eq!(Array::try_from(exec.get_output(2).unwrap()).unwrap(), dense);
check_sum!(exec, 0, expected_o0);
check_sum!(exec, 1, expected_o1);
check_sum!(exec, 2, dense);
}

0 comments on commit a1e4d02

Please sign in to comment.