Skip to content
This repository has been archived by the owner on Jan 17, 2022. It is now read-only.

Commit

Permalink
Use linear time algorithm to inject stack height metering (#170)
Browse files Browse the repository at this point in the history
  • Loading branch information
athei authored Sep 7, 2021
1 parent 2f88f49 commit 2293760
Showing 1 changed file with 54 additions and 59 deletions.
113 changes: 54 additions & 59 deletions src/stack_height/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@
//! between the frames.
//! - upon entry into the function entire stack frame is allocated.

use crate::std::{string::String, vec::Vec};
use crate::std::{mem, string::String, vec::Vec};

use parity_wasm::{
builder,
elements::{self, Type},
elements::{self, Instruction, Instructions, Type},
};

/// Macro to generate preamble and postamble.
Expand Down Expand Up @@ -145,7 +145,7 @@ fn generate_stack_height_global(module: &mut elements::Module) -> u32 {
.value_type()
.i32()
.mutable()
.init_expr(elements::Instruction::I32Const(0))
.init_expr(Instruction::I32Const(0))
.build();

// Try to find an existing global section.
Expand Down Expand Up @@ -253,75 +253,70 @@ fn instrument_functions(ctx: &mut Context, module: &mut elements::Module) -> Res
///
/// drop
/// ```
fn instrument_function(
ctx: &mut Context,
instructions: &mut elements::Instructions,
) -> Result<(), Error> {
use parity_wasm::elements::Instruction::*;

let mut cursor = 0;
loop {
if cursor >= instructions.elements().len() {
break
}
fn instrument_function(ctx: &mut Context, func: &mut Instructions) -> Result<(), Error> {
use Instruction::*;

enum Action {
InstrumentCall { callee_idx: u32, callee_stack_cost: u32 },
Nop,
}
struct InstrumentCall {
offset: usize,
callee: u32,
cost: u32,
}

let action: Action = {
let instruction = &instructions.elements()[cursor];
match instruction {
Call(callee_idx) => {
let callee_stack_cost = ctx.stack_cost(*callee_idx).ok_or_else(|| {
Error(format!("Call to function that out-of-bounds: {}", callee_idx))
})?;

// Instrument only calls to a functions which stack_cost is
// non-zero.
if callee_stack_cost > 0 {
Action::InstrumentCall { callee_idx: *callee_idx, callee_stack_cost }
let calls: Vec<_> = func
.elements()
.iter()
.enumerate()
.filter_map(|(offset, instruction)| {
if let Call(callee) = instruction {
ctx.stack_cost(*callee).and_then(|cost| {
if cost > 0 {
Some(InstrumentCall { callee: *callee, offset, cost })
} else {
Action::Nop
None
}
},
_ => Action::Nop,
})
} else {
None
}
};

match action {
// We need to wrap a `call idx` instruction
// with a code that adjusts stack height counter
// and then restores it.
Action::InstrumentCall { callee_idx, callee_stack_cost } => {
})
.collect();

// The `instrumented_call!` contains the call itself. This is why we need to subtract one.
let len = func.elements().len() + calls.len() * (instrument_call!(0, 0, 0, 0).len() - 1);
let original_instrs = mem::replace(func.elements_mut(), Vec::with_capacity(len));
let new_instrs = func.elements_mut();

let mut calls = calls.into_iter().peekable();
for (original_pos, instr) in original_instrs.into_iter().enumerate() {
// whether there is some call instruction at this position that needs to be instrumented
let did_instrument = if let Some(call) = calls.peek() {
if call.offset == original_pos {
let new_seq = instrument_call!(
callee_idx,
callee_stack_cost as i32,
call.callee,
call.cost as i32,
ctx.stack_height_global_idx(),
ctx.stack_limit()
);
new_instrs.extend(new_seq);
true
} else {
false
}
} else {
false
};

// Replace the original `call idx` instruction with
// a wrapped call sequence.
//
// To splice actually take a place, we need to consume iterator
// splice returns. So we just `count()` it.
let _ = instructions
.elements_mut()
.splice(cursor..(cursor + 1), new_seq.iter().cloned())
.count();

// Advance cursor to be after the inserted sequence.
cursor += new_seq.len();
},
// Do nothing for other instructions.
_ => {
cursor += 1;
},
if did_instrument {
calls.next();
} else {
new_instrs.push(instr);
}
}

if calls.next().is_some() {
return Err(Error("Not all calls were used".into()))
}

Ok(())
}

Expand Down

0 comments on commit 2293760

Please sign in to comment.