From 25e5b89f97d652dc1bb497105988ba4adfecdd18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20Thei=C3=9Fen?= Date: Tue, 7 Sep 2021 02:00:22 +0200 Subject: [PATCH] Use linear time algorithm to inject stack height metering --- src/stack_height/mod.rs | 113 +++++++++++++++++++--------------------- 1 file changed, 54 insertions(+), 59 deletions(-) diff --git a/src/stack_height/mod.rs b/src/stack_height/mod.rs index d855cba..03c9da0 100644 --- a/src/stack_height/mod.rs +++ b/src/stack_height/mod.rs @@ -49,11 +49,11 @@ //! between the frames. //! - upon entry into the function entire stack frame is allocated. -use crate::std::{string::String, vec::Vec}; +use crate::std::{mem, string::String, vec::Vec}; use parity_wasm::{ builder, - elements::{self, Type}, + elements::{self, Instruction, Instructions, Type}, }; /// Macro to generate preamble and postamble. @@ -145,7 +145,7 @@ fn generate_stack_height_global(module: &mut elements::Module) -> u32 { .value_type() .i32() .mutable() - .init_expr(elements::Instruction::I32Const(0)) + .init_expr(Instruction::I32Const(0)) .build(); // Try to find an existing global section. @@ -253,75 +253,70 @@ fn instrument_functions(ctx: &mut Context, module: &mut elements::Module) -> Res /// /// drop /// ``` -fn instrument_function( - ctx: &mut Context, - instructions: &mut elements::Instructions, -) -> Result<(), Error> { - use parity_wasm::elements::Instruction::*; - - let mut cursor = 0; - loop { - if cursor >= instructions.elements().len() { - break - } +fn instrument_function(ctx: &mut Context, func: &mut Instructions) -> Result<(), Error> { + use Instruction::*; - enum Action { - InstrumentCall { callee_idx: u32, callee_stack_cost: u32 }, - Nop, - } + struct InstrumentCall { + offset: usize, + callee: u32, + cost: u32, + } - let action: Action = { - let instruction = &instructions.elements()[cursor]; - match instruction { - Call(callee_idx) => { - let callee_stack_cost = ctx.stack_cost(*callee_idx).ok_or_else(|| { - Error(format!("Call to function that out-of-bounds: {}", callee_idx)) - })?; - - // Instrument only calls to a functions which stack_cost is - // non-zero. - if callee_stack_cost > 0 { - Action::InstrumentCall { callee_idx: *callee_idx, callee_stack_cost } + let calls: Vec<_> = func + .elements() + .iter() + .enumerate() + .filter_map(|(offset, instruction)| { + if let Call(callee) = instruction { + ctx.stack_cost(*callee).and_then(|cost| { + if cost > 0 { + Some(InstrumentCall { callee: *callee, offset, cost }) } else { - Action::Nop + None } - }, - _ => Action::Nop, + }) + } else { + None } - }; - - match action { - // We need to wrap a `call idx` instruction - // with a code that adjusts stack height counter - // and then restores it. - Action::InstrumentCall { callee_idx, callee_stack_cost } => { + }) + .collect(); + + // The `instrumented_call!` contains the call itself. This is why we need to subtract one. + let len = func.elements().len() + calls.len() * (instrument_call!(0, 0, 0, 0).len() - 1); + let original_instrs = mem::replace(func.elements_mut(), Vec::with_capacity(len)); + let new_instrs = func.elements_mut(); + + let mut calls = calls.into_iter().peekable(); + for (original_pos, instr) in original_instrs.into_iter().enumerate() { + // whether there is some call instruction at this position that needs to be instrumented + let did_instrument = if let Some(call) = calls.peek() { + if call.offset == original_pos { let new_seq = instrument_call!( - callee_idx, - callee_stack_cost as i32, + call.callee, + call.cost as i32, ctx.stack_height_global_idx(), ctx.stack_limit() ); + new_instrs.extend(new_seq); + true + } else { + false + } + } else { + false + }; - // Replace the original `call idx` instruction with - // a wrapped call sequence. - // - // To splice actually take a place, we need to consume iterator - // splice returns. So we just `count()` it. - let _ = instructions - .elements_mut() - .splice(cursor..(cursor + 1), new_seq.iter().cloned()) - .count(); - - // Advance cursor to be after the inserted sequence. - cursor += new_seq.len(); - }, - // Do nothing for other instructions. - _ => { - cursor += 1; - }, + if did_instrument { + calls.next(); + } else { + new_instrs.push(instr); } } + if calls.next().is_some() { + return Err(Error("Not all calls were used".into())) + } + Ok(()) }