Skip to content

Commit

Permalink
Updates
Browse files Browse the repository at this point in the history
  • Loading branch information
nhynes committed Oct 6, 2018
1 parent a0d3d41 commit 88aff7c
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 9 deletions.
2 changes: 2 additions & 0 deletions rust/src/runtime/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ pub struct Tensor<'a> {
pub(super) size: usize,
}

unsafe impl<'a> Send for Tensor<'a> {}

impl<'a> Tensor<'a> {
pub fn shape(&self) -> Vec<i64> {
self.shape.clone()
Expand Down
4 changes: 3 additions & 1 deletion rust/src/runtime/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ pub struct GraphExecutor<'m, 't> {
tensors: Vec<Tensor<'t>>,
}

unsafe impl<'m, 't> Send for GraphExecutor<'m, 't> {}

impl<'m, 't> GraphExecutor<'m, 't> {
pub fn new<M: 'm + Module>(graph: Graph, lib: &'m M) -> Result<Self> {
let tensors = Self::setup_storages(&graph)?;
Expand Down Expand Up @@ -189,7 +191,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
}
}).collect::<Result<Vec<DataType>>>()?;

let align = dtypes.iter().map(|dtype| dtype.bits as usize >> 3).max();
let align = dtypes.iter().map(|dtype| dtype.bits as usize).max();
let mut storage_num_bytes = vec![0usize; *storage_ids.iter().max().unwrap_or(&1) + 1];
for (i, &storage_id) in storage_ids.iter().enumerate() {
let dtype_size = dtypes[i].bits * dtypes[i].lanes >> 3;
Expand Down
21 changes: 20 additions & 1 deletion rust/src/runtime/packed_func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ pub struct TVMRetValue {
type_code: i64,
}

#[cfg(target_env = "sgx")]
impl TVMRetValue {
#[cfg(target_env = "sgx")]
pub(crate) fn from_tvm_value(value: TVMValue, type_code: i64) -> Self {
unsafe {
Self {
Expand All @@ -166,6 +166,25 @@ impl TVMRetValue {
}
}
}

pub fn into_tvm_value(self) -> (TVMValue, i64) {
let val = match self.type_code {
0 | 1 => TVMValue {
v_int64: self.prim_value.clone() as i64,
},
2 => TVMValue {
v_float64: self.prim_value.clone() as f64,
},
3 | 7 | 8 | 9 | 10 => TVMValue {
v_handle: Box::into_raw(self.box_value) as *mut c_void,
},
11 | 12 => TVMValue {
v_str: Box::into_raw(self.box_value) as *const _,
},
_ => unreachable!(),
};
(val, self.type_code)
}
}

impl Default for TVMRetValue {
Expand Down
10 changes: 7 additions & 3 deletions rust/src/runtime/sgx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,14 @@ macro_rules! ocall_packed {
}
}

pub fn shutdown() {
if env!("TVM_NUM_THREADS") != "0" {
sgx_join_threads()
}
}

impl Drop for SystemLibModule {
fn drop(&mut self) {
if env!("TVM_NUM_THREADS") != "0" {
sgx_join_threads()
}
shutdown()
}
}
2 changes: 1 addition & 1 deletion rust/src/runtime/threading.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ pub extern "C" fn TVMBackendParallelLaunch(
}

#[cfg(target_env = "sgx")]
pub(crate) fn sgx_join_threads() -> () {
pub(crate) fn sgx_join_threads() {
extern "C" fn poison_pill(
_task_id: usize,
_penv: *const TVMParallelGroupEnv,
Expand Down
5 changes: 2 additions & 3 deletions rust/src/runtime/workspace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,9 @@ impl WorkspacePool {
if !ws_size >= size {
return cur_ws_idx;
}
cur_ws_idx.and_then(|cur_idx| {
cur_ws_idx.or(Some(idx)).and_then(|cur_idx| {
let cur_size = self.workspaces[cur_idx].size();
Some(match ws_size < cur_size {
// is already ok
Some(match ws_size <= cur_size {
true => idx,
false => cur_idx,
})
Expand Down

0 comments on commit 88aff7c

Please sign in to comment.