diff --git a/src/agent/coverage/examples/block_coverage.rs b/src/agent/coverage/examples/block_coverage.rs index b490a3d3a2..7912f2e16d 100644 --- a/src/agent/coverage/examples/block_coverage.rs +++ b/src/agent/coverage/examples/block_coverage.rs @@ -14,6 +14,9 @@ struct Opt { #[structopt(min_values = 1)] cmd: Vec, + + #[structopt(short, long, default_value = "5")] + timeout: u64, } impl Opt { @@ -40,7 +43,8 @@ fn main() -> Result<()> { let mut cmd = Command::new(&opt.cmd[0]); cmd.args(&opt.cmd[1..]); - let coverage = coverage::block::windows::record(cmd, filter)?; + let timeout = std::time::Duration::from_secs(opt.timeout); + let coverage = coverage::block::windows::record(cmd, filter, timeout)?; for (module, cov) in coverage.iter() { let total = cov.blocks.len(); @@ -63,7 +67,7 @@ fn main() -> Result<()> { let filter = opt.load_filter_or_default()?; let mut cmd = Command::new(&opt.cmd[0]); - cmd.stdin(Stdio::null()).args(&opt.cmd[1..]); + cmd.stdin(std::process::Stdio::null()).args(&opt.cmd[1..]); let mut cache = ModuleCache::default(); let mut recorder = Recorder::new(&mut cache, filter); diff --git a/src/agent/coverage/src/block/windows.rs b/src/agent/coverage/src/block/windows.rs index e97b880bd4..c11461f575 100644 --- a/src/agent/coverage/src/block/windows.rs +++ b/src/agent/coverage/src/block/windows.rs @@ -6,19 +6,15 @@ use std::process::Command; use std::time::{Duration, Instant}; use anyhow::Result; -use debugger::{ - debugger::{BreakpointId, BreakpointType, DebugEventHandler, Debugger}, - target::Module, -}; +use debugger::{BreakpointId, BreakpointType, DebugEventHandler, Debugger, ModuleLoadInfo}; use crate::block::CommandBlockCov; use crate::cache::ModuleCache; use crate::code::{CmdFilter, ModulePath}; -pub fn record(cmd: Command, filter: CmdFilter) -> Result { +pub fn record(cmd: Command, filter: CmdFilter, timeout: Duration) -> Result { let mut cache = ModuleCache::default(); let mut recorder = Recorder::new(&mut cache, filter); - let timeout = Duration::from_secs(5); let mut handler = RecorderEventHandler::new(&mut recorder, timeout); handler.run(cmd)?; Ok(recorder.into_coverage()) @@ -107,10 +103,11 @@ impl<'c> Recorder<'c> { self.coverage } - pub fn on_create_process(&mut self, dbg: &mut Debugger, module: &Module) -> Result<()> { + pub fn on_create_process(&mut self, dbg: &mut Debugger, module: &ModuleLoadInfo) -> Result<()> { log::debug!("process created: {}", module.path().display()); - if let Err(err) = dbg.target().sym_initialize() { + // TODO: we should avoid loading symbols if the module is in the cache. + if let Err(err) = dbg.target().maybe_sym_initialize() { log::error!( "unable to initialize symbol handler for new process {}: {:?}", module.path().display(), @@ -121,9 +118,11 @@ impl<'c> Recorder<'c> { self.insert_module(dbg, module) } - pub fn on_load_dll(&mut self, dbg: &mut Debugger, module: &Module) -> Result<()> { + pub fn on_load_dll(&mut self, dbg: &mut Debugger, module: &ModuleLoadInfo) -> Result<()> { log::debug!("DLL loaded: {}", module.path().display()); + // TODO: we should load symbols if the module is not in the cache (see on_create_process). + self.insert_module(dbg, module) } @@ -163,7 +162,7 @@ impl<'c> Recorder<'c> { Ok(()) } - fn insert_module(&mut self, dbg: &mut Debugger, module: &Module) -> Result<()> { + fn insert_module(&mut self, dbg: &mut Debugger, module: &ModuleLoadInfo) -> Result<()> { let path = ModulePath::new(module.path().to_owned())?; if !self.filter.includes_module(&path) { @@ -197,13 +196,13 @@ impl<'c> Recorder<'c> { } impl<'r, 'c> DebugEventHandler for RecorderEventHandler<'r, 'c> { - fn on_create_process(&mut self, dbg: &mut Debugger, module: &Module) { + fn on_create_process(&mut self, dbg: &mut Debugger, module: &ModuleLoadInfo) { if self.recorder.on_create_process(dbg, module).is_err() { self.stop(dbg); } } - fn on_load_dll(&mut self, dbg: &mut Debugger, module: &Module) { + fn on_load_dll(&mut self, dbg: &mut Debugger, module: &ModuleLoadInfo) { if self.recorder.on_load_dll(dbg, module).is_err() { self.stop(dbg); } @@ -244,17 +243,18 @@ impl Breakpoints { pub fn set( &mut self, dbg: &mut Debugger, - module: &Module, + module: &ModuleLoadInfo, offsets: impl Iterator, ) -> Result<()> { - // From the `target::Module`, create and save a `ModulePath`. + // From the `debugger::ModuleLoadInfo`, create and save a `ModulePath`. let module_path = ModulePath::new(module.path().to_owned())?; let module_index = self.modules.len(); self.modules.push(module_path); for offset in offsets { // Register the breakpoint in the running target address space. - let id = dbg.register_breakpoint(module.name(), offset as u64, BreakpointType::OneTime); + let id = + dbg.new_rva_breakpoint(module.name(), offset as u64, BreakpointType::OneTime)?; // Associate the opaque `BreakpointId` with the module and offset. self.registered.insert(id, (module_index, offset)); diff --git a/src/agent/debugger/src/breakpoint.rs b/src/agent/debugger/src/breakpoint.rs new file mode 100644 index 0000000000..7f1eae25df --- /dev/null +++ b/src/agent/debugger/src/breakpoint.rs @@ -0,0 +1,290 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::{ + collections::{btree_map::Range, BTreeMap}, + ops::RangeBounds, +}; + +use anyhow::Result; +use win_util::process; +use winapi::um::winnt::HANDLE; + +use crate::debugger::{BreakpointId, BreakpointType}; + +pub(crate) enum ExtraInfo { + Rva(u64), + Function(String), +} + +pub(crate) struct UnresolvedBreakpoint { + id: BreakpointId, + kind: BreakpointType, + module: String, + extra_info: ExtraInfo, +} + +impl UnresolvedBreakpoint { + pub(crate) fn from_symbol( + id: BreakpointId, + kind: BreakpointType, + module: impl ToString, + function: impl ToString, + ) -> Self { + UnresolvedBreakpoint { + id, + kind, + module: module.to_string(), + extra_info: ExtraInfo::Function(function.to_string()), + } + } + + pub(crate) fn from_rva( + id: BreakpointId, + kind: BreakpointType, + module: impl ToString, + rva: u64, + ) -> Self { + UnresolvedBreakpoint { + id, + kind, + module: module.to_string(), + extra_info: ExtraInfo::Rva(rva), + } + } + + pub(crate) fn id(&self) -> BreakpointId { + self.id + } + + pub(crate) fn kind(&self) -> BreakpointType { + self.kind + } + + pub(crate) fn module(&self) -> &str { + &self.module + } + + pub(crate) fn extra_info(&self) -> &ExtraInfo { + &self.extra_info + } +} + +pub struct ResolvedBreakpoint { + id: BreakpointId, + kind: BreakpointType, + + // We use a counter to handle multiple threads hitting the breakpoint at the same time. + // Each thread will increment the disable count and the breakpoint won't be restored + // until an equivalent number of threads enable the breakpoint. + disabled: u32, + hit_count: u64, + address: u64, + original_byte: Option, +} + +impl ResolvedBreakpoint { + pub fn new(id: BreakpointId, kind: BreakpointType, address: u64) -> Self { + ResolvedBreakpoint { + id, + kind, + disabled: 0, + hit_count: 0, + address, + original_byte: None, + } + } + + pub fn id(&self) -> BreakpointId { + self.id + } + + pub fn kind(&self) -> BreakpointType { + self.kind + } + + #[allow(unused)] + pub fn is_enabled(&self) -> bool { + !self.is_disabled() + } + + pub fn is_disabled(&self) -> bool { + self.disabled > 0 + } + + fn is_applied(&self) -> bool { + self.original_byte.is_some() + } + + pub(crate) fn disable(&mut self, process_handle: HANDLE) -> Result<()> { + self.disabled = self.disabled.saturating_add(1); + + if self.is_disabled() { + if let Some(original_byte) = self.original_byte.take() { + write_instruction_byte(process_handle, self.address, original_byte)?; + } + } + + Ok(()) + } + + pub fn enable(&mut self, process_handle: HANDLE) -> Result<()> { + self.disabled = self.disabled.saturating_sub(1); + + let new_original_byte = process::read_memory(process_handle, self.address as _)?; + self.original_byte = Some(new_original_byte); + write_instruction_byte(process_handle, self.address, 0xcc)?; + + Ok(()) + } + + #[allow(unused)] + pub fn hit_count(&self) -> u64 { + self.hit_count + } + + pub fn increment_hit_count(&mut self) { + self.hit_count = self.hit_count.saturating_add(1); + } + + pub(crate) fn get_original_byte(&self) -> Option { + self.original_byte + } + + fn set_original_byte(&mut self, byte: u8) { + self.original_byte = Some(byte); + } +} + +pub(crate) struct BreakpointCollection { + breakpoints: BTreeMap, + min_breakpoint_addr: u64, + max_breakpoint_addr: u64, +} + +impl BreakpointCollection { + pub fn new() -> Self { + BreakpointCollection { + breakpoints: BTreeMap::default(), + min_breakpoint_addr: u64::MAX, + max_breakpoint_addr: u64::MIN, + } + } + + pub fn insert( + &mut self, + address: u64, + breakpoint: ResolvedBreakpoint, + ) -> Option { + self.min_breakpoint_addr = std::cmp::min(self.min_breakpoint_addr, address); + self.max_breakpoint_addr = std::cmp::max(self.max_breakpoint_addr, address); + self.breakpoints.insert(address, breakpoint) + } + + pub fn contains_key(&self, address: u64) -> bool { + self.breakpoints.contains_key(&address) + } + + pub fn get_mut(&mut self, address: u64) -> Option<&mut ResolvedBreakpoint> { + self.breakpoints.get_mut(&address) + } + + pub fn breakpoints_for_range( + &self, + range: impl RangeBounds, + ) -> Range { + self.breakpoints.range(range) + } + + #[allow(unused)] + pub fn remove_all(&mut self, process_handle: HANDLE) -> Result<()> { + for (address, breakpoint) in self.breakpoints.iter() { + if let Some(original_byte) = breakpoint.get_original_byte() { + write_instruction_byte(process_handle, *address, original_byte)?; + } + } + + Ok(()) + } + + #[allow(unused)] + pub fn bulk_remove_all(&mut self, process_handle: HANDLE) -> Result<()> { + if self.breakpoints.is_empty() { + return Ok(()); + } + + let mut buffer = self.bulk_read_process_memory(process_handle)?; + + for (address, breakpoint) in self.breakpoints.iter() { + if let Some(original_byte) = breakpoint.get_original_byte() { + let idx = (*address - self.min_breakpoint_addr) as usize; + buffer[idx] = original_byte; + } + } + + self.bulk_write_process_memory(process_handle, &buffer) + } + + pub fn apply_all(&mut self, process_handle: HANDLE) -> Result<()> { + // No module, so we can't use the trick of reading and writing + // a single large range of memory. + for (address, breakpoint) in self.breakpoints.iter_mut() { + if !breakpoint.is_applied() { + let original_byte = process::read_memory(process_handle, *address as _)?; + breakpoint.set_original_byte(original_byte); + write_instruction_byte(process_handle, *address, 0xcc)?; + } + } + + Ok(()) + } + + pub fn bulk_apply_all(&mut self, process_handle: HANDLE) -> Result<()> { + if self.breakpoints.is_empty() { + return Ok(()); + } + + let mut buffer = self.bulk_read_process_memory(process_handle)?; + + for (address, breakpoint) in self.breakpoints.iter_mut() { + if !breakpoint.is_applied() { + let idx = (*address - self.min_breakpoint_addr) as usize; + breakpoint.set_original_byte(buffer[idx]); + buffer[idx] = 0xcc; + } + } + + self.bulk_write_process_memory(process_handle, &buffer) + } + + fn bulk_region_size(&self) -> usize { + (self.max_breakpoint_addr - self.min_breakpoint_addr + 1) as usize + } + + fn bulk_read_process_memory(&self, process_handle: HANDLE) -> Result> { + let mut buffer: Vec = Vec::with_capacity(self.bulk_region_size()); + unsafe { + buffer.set_len(self.bulk_region_size()); + } + process::read_memory_array(process_handle, self.min_breakpoint_addr as _, &mut buffer)?; + Ok(buffer) + } + + fn bulk_write_process_memory(&self, process_handle: HANDLE, buffer: &[u8]) -> Result<()> { + process::write_memory_slice(process_handle, self.min_breakpoint_addr as _, &buffer)?; + process::flush_instruction_cache( + process_handle, + self.min_breakpoint_addr as _, + self.bulk_region_size(), + )?; + Ok(()) + } +} + +fn write_instruction_byte(process_handle: HANDLE, ip: u64, b: u8) -> Result<()> { + let orig_byte = [b; 1]; + let remote_address = ip as _; + process::write_memory_slice(process_handle, remote_address, &orig_byte)?; + process::flush_instruction_cache(process_handle, remote_address, orig_byte.len())?; + Ok(()) +} diff --git a/src/agent/debugger/src/dbghelp.rs b/src/agent/debugger/src/dbghelp.rs index beb6affd93..ff5ad14c09 100644 --- a/src/agent/debugger/src/dbghelp.rs +++ b/src/agent/debugger/src/dbghelp.rs @@ -14,6 +14,7 @@ use std::{ cmp, ffi::{OsStr, OsString}, mem::{size_of, MaybeUninit}, + num::NonZeroU64, path::{Path, PathBuf}, sync::Once, }; @@ -417,9 +418,9 @@ impl ModuleInfo { #[derive(Clone, Debug, Hash, PartialEq)] pub struct SymInfo { - symbol: String, - address: u64, - displacement: u64, + pub symbol: String, + pub address: u64, + pub displacement: u64, } impl SymInfo { @@ -523,6 +524,15 @@ impl DebugHelpGuard { } } + pub fn get_module_base(&self, process_handle: HANDLE, addr: DWORD64) -> Result { + if let Some(base) = NonZeroU64::new(unsafe { SymGetModuleBase64(process_handle, addr) }) { + Ok(base) + } else { + let last_error = std::io::Error::last_os_error(); + Err(last_error.into()) + } + } + pub fn stackwalk_ex bool>( &self, process_handle: HANDLE, diff --git a/src/agent/debugger/src/debugger.rs b/src/agent/debugger/src/debugger.rs index dd802fd0cc..4df5a40af6 100644 --- a/src/agent/debugger/src/debugger.rs +++ b/src/agent/debugger/src/debugger.rs @@ -9,7 +9,6 @@ #![allow(clippy::redundant_closure)] #![allow(clippy::redundant_clone)] use std::{ - collections::HashMap, ffi::OsString, mem::MaybeUninit, os::windows::process::CommandExt, @@ -18,7 +17,7 @@ use std::{ }; use anyhow::{Context, Result}; -use log::{debug, error, trace}; +use log::{error, trace}; use win_util::{check_winapi, last_os_error, process}; use winapi::{ shared::{ @@ -35,11 +34,11 @@ use winapi::{ }, }; -use crate::target::{Module, Target}; use crate::{ dbghelp::{self, ModuleInfo, SymInfo, SymLineInfo}, debug_event::{DebugEvent, DebugEventInfo}, stack, + target::Target, }; // When debugging a WoW64 process, we see STATUS_WX86_BREAKPOINT in addition to EXCEPTION_BREAKPOINT @@ -49,133 +48,12 @@ const STATUS_WX86_BREAKPOINT: u32 = ::winapi::shared::ntstatus::STATUS_WX86_BREA #[derive(Copy, Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct BreakpointId(pub u64); -#[derive(Copy, Clone)] -pub(crate) enum StepState { - RemoveBreakpoint { pc: u64, original_byte: u8 }, - SingleStep, -} - #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub enum BreakpointType { Counter, OneTime, } -pub(crate) struct ModuleBreakpoint { - rva: u64, - kind: BreakpointType, - id: BreakpointId, -} - -impl ModuleBreakpoint { - pub fn new(rva: u64, kind: BreakpointType, id: BreakpointId) -> Self { - ModuleBreakpoint { rva, kind, id } - } - - pub fn rva(&self) -> u64 { - self.rva - } - - pub fn kind(&self) -> BreakpointType { - self.kind - } - - pub fn id(&self) -> BreakpointId { - self.id - } -} - -#[allow(unused)] -struct UnresolvedBreakpoint { - sym: String, - kind: BreakpointType, - id: BreakpointId, -} - -/// A breakpoint for a specific target. We can say it is bound because we know exactly -/// where to set it, but it might disabled. -#[derive(Clone)] -pub struct Breakpoint { - /// The address of the breakpoint. - ip: u64, - - kind: BreakpointType, - - /// Currently active? - enabled: bool, - - /// Holds the original byte at the location. - original_byte: Option, - - hit_count: usize, - - id: BreakpointId, -} - -impl Breakpoint { - pub fn new( - ip: u64, - kind: BreakpointType, - enabled: bool, - original_byte: Option, - hit_count: usize, - id: BreakpointId, - ) -> Self { - Breakpoint { - ip, - kind, - enabled, - original_byte, - hit_count, - id, - } - } - - pub fn ip(&self) -> u64 { - self.ip - } - - pub fn kind(&self) -> BreakpointType { - self.kind - } - - pub(crate) fn set_kind(&mut self, kind: BreakpointType) { - self.kind = kind; - } - - pub fn enabled(&self) -> bool { - self.enabled - } - - pub fn set_enabled(&mut self, enabled: bool) { - self.enabled = enabled; - } - - pub fn original_byte(&self) -> Option { - self.original_byte - } - - pub(crate) fn set_original_byte(&mut self, original_byte: Option) { - self.original_byte = original_byte; - } - - pub fn hit_count(&self) -> usize { - self.hit_count - } - - pub(crate) fn increment_hit_count(&mut self) { - self.hit_count += 1; - } - - pub fn id(&self) -> BreakpointId { - self.id - } - - pub(crate) fn set_id(&mut self, id: BreakpointId) { - self.id = id; - } -} - pub struct StackFrame { return_address: u64, stack_pointer: u64, @@ -198,6 +76,33 @@ impl StackFrame { } } +pub struct ModuleLoadInfo { + path: PathBuf, + base_address: u64, +} + +impl ModuleLoadInfo { + pub fn new(path: impl AsRef, base_address: u64) -> Self { + ModuleLoadInfo { + path: path.as_ref().into(), + base_address, + } + } + + pub fn path(&self) -> &Path { + &self.path + } + + pub fn base_address(&self) -> u64 { + self.base_address + } + + pub fn name(&self) -> &Path { + // Unwrap guaranteed by construction, we always have a filename. + self.path.file_stem().unwrap().as_ref() + } +} + #[rustfmt::skip] #[allow(clippy::trivially_copy_pass_by_ref)] pub trait DebugEventHandler { @@ -205,11 +110,11 @@ pub trait DebugEventHandler { // Continue normal exception handling processing DBG_EXCEPTION_NOT_HANDLED } - fn on_create_process(&mut self, _debugger: &mut Debugger, _module: &Module) {} + fn on_create_process(&mut self, _debugger: &mut Debugger, _module: &ModuleLoadInfo) {} fn on_create_thread(&mut self, _debugger: &mut Debugger) {} fn on_exit_process(&mut self, _debugger: &mut Debugger, _exit_code: u32) {} fn on_exit_thread(&mut self, _debugger: &mut Debugger, _exit_code: u32) {} - fn on_load_dll(&mut self, _debugger: &mut Debugger, _module: &Module) {} + fn on_load_dll(&mut self, _debugger: &mut Debugger, _module: &ModuleLoadInfo) {} fn on_unload_dll(&mut self, _debugger: &mut Debugger, _base_address: u64) {} fn on_output_debug_string(&mut self, _debugger: &mut Debugger, _message: String) {} fn on_output_debug_os_string(&mut self, _debugger: &mut Debugger, _message: OsString) {} @@ -228,8 +133,6 @@ struct ContinueDebugEventArguments { pub struct Debugger { target: Target, continue_args: Option, - registered_breakpoints: HashMap>, - symbolic_breakpoints: HashMap>, breakpoint_count: u64, } @@ -276,8 +179,6 @@ impl Debugger { let mut debugger = Debugger { target, continue_args: None, - registered_breakpoints: HashMap::default(), - symbolic_breakpoints: HashMap::default(), breakpoint_count: 0, }; callbacks.on_create_process(&mut debugger, &module); @@ -303,7 +204,7 @@ impl Debugger { id } - pub fn register_symbolic_breakpoint( + pub fn new_symbolic_breakpoint( &mut self, sym: &str, kind: BreakpointType, @@ -313,57 +214,28 @@ impl Debugger { } else { anyhow::bail!("no module name specified for breakpoint {}", sym); }; - let id = self.next_breakpoint_id(); - - if self.target.saw_initial_bp() { - self.target - .set_symbolic_breakpoint(module, func, kind, id)?; - } else { - // Defer setting the breakpoint until seeing the initial breakpoint. - let values = self - .symbolic_breakpoints - .entry(module.into()) - .or_insert_with(|| Vec::new()); - - values.push(UnresolvedBreakpoint { - kind, - id, - sym: func.into(), - }); - } - - Ok(id) + self.target.new_symbolic_breakpoint(id, module, func, kind) } - pub fn register_breakpoint( + pub fn new_rva_breakpoint( &mut self, module: &Path, rva: u64, kind: BreakpointType, - ) -> BreakpointId { + ) -> Result { let id = self.next_breakpoint_id(); - - let module_breakpoints = self - .registered_breakpoints - .entry(module.into()) - .or_insert_with(|| vec![]); - - module_breakpoints.push(ModuleBreakpoint::new(rva, kind, id)); - id + let module = format!("{}", module.display()); + self.target.new_rva_breakpoint(id, module, rva, kind) } - pub fn register_absolute_breakpoint( + pub fn new_address_breakpoint( &mut self, address: u64, kind: BreakpointType, ) -> Result { let id = self.next_breakpoint_id(); - - self.target.apply_absolute_breakpoint(address, kind, id)?; - // TODO: find module the address belongs to and add to registered_breakpoints - - Ok(id) + self.target.new_absolute_breakpoint(id, address, kind) } /// Return true if an event was process, false if timing out, or an error. @@ -460,14 +332,6 @@ impl Debugger { match self.target.load_module(info.hFile, info.lpBaseOfDll as u64) { Ok(Some(module)) => { callbacks.on_load_dll(self, &module); - - // We must defer adding any breakpoints until we've seen the initial - // breakpoint notification from the OS. Otherwise we may set - // breakpoints in startup code before the debugger is properly - // initialized. - if self.target.saw_initial_bp() { - self.apply_module_breakpoints(module.name(), module.base_address()) - } } Ok(None) => {} Err(e) => { @@ -548,25 +412,14 @@ impl Debugger { match is_debugger_notification( info.ExceptionRecord.ExceptionCode, info.ExceptionRecord.ExceptionAddress as u64, - &self.target, + &mut self.target, ) { Some(DebuggerNotification::InitialBreak) => { - let modules = { - self.target.set_saw_initial_bp(); - let load_symbols = !self.symbolic_breakpoints.is_empty(); - self.target.initial_bp(load_symbols)?; - self.target - .modules() - .map(|(addr, module)| (*addr, (*module).name().to_owned())) - .collect::>() - }; - for (base_address, module) in modules { - self.apply_module_breakpoints(module, base_address) - } + self.target.initial_bp()?; Ok(DBG_CONTINUE) } Some(DebuggerNotification::InitialWow64Break) => { - self.target.set_saw_initial_wow64_bp(); + self.target.initial_wow64_bp(); Ok(DBG_CONTINUE) } Some(DebuggerNotification::Clr) => Ok(DBG_CONTINUE), @@ -577,7 +430,7 @@ impl Debugger { Ok(DBG_CONTINUE) } Some(DebuggerNotification::SingleStep { thread_id }) => { - self.target.complete_single_step(thread_id); + self.target.complete_single_step(thread_id)?; Ok(DBG_CONTINUE) } None => { @@ -587,60 +440,6 @@ impl Debugger { } } - fn apply_module_breakpoints(&mut self, module_name: impl AsRef, base_address: u64) { - // We remove because we only need to resolve the RVA once even if the dll is loaded - // multiple times (e.g. in the same process via LoadLibrary/FreeLibrary) or if the - // same dll is loaded in different processes. - if let Some(unresolved_breakpoints) = self.symbolic_breakpoints.remove(module_name.as_ref()) - { - let cloned_module_name = PathBuf::from(module_name.as_ref()).to_owned(); - let rva_breakpoints = self - .registered_breakpoints - .entry(cloned_module_name) - .or_insert_with(|| Vec::new()); - - match dbghelp::lock() { - Ok(dbghelp) => { - for bp in unresolved_breakpoints { - match dbghelp.sym_from_name( - self.target.process_handle(), - module_name.as_ref(), - &bp.sym, - ) { - Ok(sym) => { - rva_breakpoints.push(ModuleBreakpoint::new( - sym.address() - base_address, - bp.kind, - bp.id, - )); - } - Err(e) => { - debug!( - "Can't set symbolic breakpoint {}!{}: {}", - module_name.as_ref().display(), - bp.sym, - e - ); - } - } - } - } - Err(e) => { - error!("Can't set symbolic breakpoints: {}", e); - } - } - } - - if let Some(breakpoints) = self.registered_breakpoints.get(module_name.as_ref()) { - if let Err(e) = self - .target - .apply_module_breakpoints(base_address, breakpoints) - { - error!("Error applying breakpoints: {}", e); - } - } - } - pub fn get_current_stack(&mut self) -> Result { // If we fail to initialize symbols, we'll skip collecting symbols // when walking the stack. Note that if we see multiple exceptions @@ -648,7 +447,7 @@ impl Debugger { // We could retry in a loop (apparently it can fail but later // succeed), but symbols aren't strictly necessary, so we won't // be too aggressive in dealing with failures. - let resolve_symbols = self.target.sym_initialize().is_ok(); + let resolve_symbols = self.target.maybe_sym_initialize().is_ok(); return stack::get_stack( self.target.process_handle(), self.target.current_thread_handle(), @@ -687,13 +486,8 @@ impl Debugger { self.target.read_flags_register() } - pub fn get_current_target_memory( - &self, - remote_address: LPCVOID, - buf: &mut [T], - ) -> Result<()> { - process::read_memory_array(self.target.process_handle(), remote_address, buf)?; - Ok(()) + pub fn read_memory(&mut self, remote_address: LPCVOID, buf: &mut [impl Copy]) -> Result<()> { + self.target.read_memory(remote_address, buf) } pub fn get_current_frame(&self) -> Result { @@ -734,7 +528,7 @@ enum DebuggerNotification { fn is_debugger_notification( exception_code: u32, exception_address: u64, - target: &Target, + target: &mut Target, ) -> Option { // The CLR debugger notification exception is not a crash: // https://github.com/dotnet/coreclr/blob/9ee6b8a33741cc5f3eb82e990646dd3a81de996a/src/debug/inc/dbgipcevents.h#L37 diff --git a/src/agent/debugger/src/lib.rs b/src/agent/debugger/src/lib.rs index 4a440cb4b7..8ed5290eb2 100644 --- a/src/agent/debugger/src/lib.rs +++ b/src/agent/debugger/src/lib.rs @@ -3,8 +3,15 @@ #![cfg(windows)] +mod breakpoint; pub mod dbghelp; -pub mod debug_event; -pub mod debugger; +mod debug_event; +mod debugger; +mod module; pub mod stack; -pub mod target; +mod target; + +pub use self::{ + debug_event::DebugEvent, + debugger::{BreakpointId, BreakpointType, DebugEventHandler, Debugger, ModuleLoadInfo}, +}; diff --git a/src/agent/debugger/src/module.rs b/src/agent/debugger/src/module.rs new file mode 100644 index 0000000000..1af7bebec8 --- /dev/null +++ b/src/agent/debugger/src/module.rs @@ -0,0 +1,191 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::{ + collections::btree_map::Range, + fs, + ops::RangeBounds, + path::{Path, PathBuf}, +}; + +use anyhow::Result; +use log::error; +use win_util::{file, handle::Handle}; +use winapi::um::{ + handleapi::INVALID_HANDLE_VALUE, + winnt::{HANDLE, IMAGE_FILE_MACHINE_AMD64, IMAGE_FILE_MACHINE_I386}, +}; + +use crate::{ + breakpoint::{BreakpointCollection, ResolvedBreakpoint}, + dbghelp, + debugger::{BreakpointId, BreakpointType}, +}; + +pub const UNKNOWN_MODULE_BASE_ADDRESS: u64 = u64::MAX; +pub const UNKNOWN_MODULE_NAME: &str = "*unknown module*"; + +pub struct Module { + path: PathBuf, + file_handle: Handle, + base_address: u64, + image_size: u32, + machine: Machine, + + breakpoints: BreakpointCollection, + + // Track if we need to call SymLoadModule for the dll. + sym_module_loaded: bool, +} + +impl Module { + pub fn new(module_handle: HANDLE, base_address: u64) -> Result { + let path = file::get_path_from_handle(module_handle).unwrap_or_else(|e| { + error!("Error getting path from file handle: {}", e); + "???".into() + }); + + let image_details = get_image_details(&path)?; + + Ok(Module { + path, + file_handle: Handle(module_handle), + base_address, + image_size: image_details.image_size, + machine: image_details.machine, + sym_module_loaded: false, + breakpoints: BreakpointCollection::new(), + }) + } + + pub fn new_fake_module() -> Self { + Module { + path: UNKNOWN_MODULE_NAME.into(), + file_handle: Handle(INVALID_HANDLE_VALUE), + base_address: UNKNOWN_MODULE_BASE_ADDRESS, + image_size: 0, + machine: Machine::Unknown, + breakpoints: BreakpointCollection::new(), + sym_module_loaded: true, + } + } + + pub fn sym_load_module(&mut self, process_handle: HANDLE) -> Result<()> { + if !self.sym_module_loaded { + let dbghelp = dbghelp::lock()?; + + dbghelp.sym_load_module( + process_handle, + self.file_handle.0, + &self.path, + self.base_address, + self.image_size, + )?; + + self.sym_module_loaded = true; + } + + Ok(()) + } + + pub fn path(&self) -> &Path { + &self.path + } + + pub fn base_address(&self) -> u64 { + self.base_address + } + + pub fn machine(&self) -> Machine { + self.machine + } + + #[allow(unused)] + pub fn image_size(&self) -> u32 { + self.image_size + } + + pub fn name(&self) -> &Path { + // Unwrap guaranteed by construction, we always have a filename. + self.path.file_stem().unwrap().as_ref() + } + + pub fn new_breakpoint( + &mut self, + id: BreakpointId, + kind: BreakpointType, + address: u64, + process_handle: HANDLE, + ) -> Result<()> { + let mut breakpoint = ResolvedBreakpoint::new(id, kind, address); + breakpoint.enable(process_handle)?; + self.breakpoints.insert(address, breakpoint); + Ok(()) + } + + #[allow(unused)] + pub fn remove_breakpoints(&mut self, process_handle: HANDLE) -> Result<()> { + if self.base_address == UNKNOWN_MODULE_BASE_ADDRESS { + self.breakpoints.remove_all(process_handle) + } else { + self.breakpoints.bulk_remove_all(process_handle) + } + } + + pub fn apply_breakpoints(&mut self, process_handle: HANDLE) -> Result<()> { + if self.base_address == UNKNOWN_MODULE_BASE_ADDRESS { + self.breakpoints.apply_all(process_handle) + } else { + self.breakpoints.bulk_apply_all(process_handle) + } + } + + pub fn contains_breakpoint(&self, address: u64) -> bool { + self.breakpoints.contains_key(address) + } + + pub fn get_breakpoint_mut(&mut self, address: u64) -> Option<&mut ResolvedBreakpoint> { + self.breakpoints.get_mut(address) + } + + pub fn breakpoints_for_range( + &self, + range: impl RangeBounds, + ) -> Range { + self.breakpoints.breakpoints_for_range(range) + } +} + +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum Machine { + Unknown, + X64, + X86, +} + +struct ImageDetails { + image_size: u32, + machine: Machine, +} + +fn get_image_details(path: &Path) -> Result { + let file = fs::File::open(path)?; + let map = unsafe { memmap2::Mmap::map(&file)? }; + + let header = goblin::pe::header::Header::parse(&map)?; + let image_size = header + .optional_header + .map(|h| h.windows_fields.size_of_image) + .ok_or_else(|| anyhow::anyhow!("Missing optional header in PE image"))?; + + let machine = match header.coff_header.machine { + IMAGE_FILE_MACHINE_AMD64 => Machine::X64, + IMAGE_FILE_MACHINE_I386 => Machine::X86, + _ => Machine::Unknown, + }; + + Ok(ImageDetails { + image_size, + machine, + }) +} diff --git a/src/agent/debugger/src/target.rs b/src/agent/debugger/src/target.rs index 38cc2084f2..7619ff361a 100644 --- a/src/agent/debugger/src/target.rs +++ b/src/agent/debugger/src/target.rs @@ -3,36 +3,36 @@ #![allow(clippy::single_match)] -use std::{ - collections::hash_map, - fs, - path::{Path, PathBuf}, -}; +use std::{num::NonZeroU64, path::Path}; use anyhow::Result; -use log::{error, trace}; +use log::{debug, error, trace}; use rand::{thread_rng, Rng}; -use win_util::{file, handle::Handle, last_os_error, process}; +use win_util::{last_os_error, process}; use winapi::{ - shared::minwindef::{DWORD, LPVOID}, + shared::minwindef::{DWORD, LPCVOID}, um::{ processthreadsapi::{ResumeThread, SuspendThread}, winbase::Wow64SuspendThread, - winnt::{HANDLE, IMAGE_FILE_MACHINE_AMD64, IMAGE_FILE_MACHINE_I386}, + winnt::HANDLE, }, }; use crate::{ + breakpoint::{self, ResolvedBreakpoint, UnresolvedBreakpoint}, dbghelp::{self, FrameContext}, - debugger::{Breakpoint, BreakpointId, BreakpointType, ModuleBreakpoint, StepState}, + debugger::{BreakpointId, BreakpointType, ModuleLoadInfo}, + module::{self, Machine, Module}, }; -struct CodeByteToUpdate { - address: u64, - byte: u8, +#[derive(Copy, Clone)] +pub(crate) enum StepState { + RestoreBreakpointAfterStep { pc: u64 }, + SingleStep, } struct ThreadInfo { + #[allow(unused)] id: u32, handle: HANDLE, suspended: bool, @@ -59,7 +59,6 @@ impl ThreadInfo { Err(last_os_error()) } else { self.suspended = false; - trace!("Resume {:x} - suspend_count: {}", self.id, suspend_count); Ok(()) } } @@ -79,77 +78,16 @@ impl ThreadInfo { Err(last_os_error()) } else { self.suspended = true; - trace!("Suspend {:x} - suspend_count: {}", self.id, suspend_count); Ok(()) } } } -#[derive(Clone)] -pub struct Module { - path: PathBuf, - file_handle: Handle, - base_address: u64, - image_size: u32, - machine: Machine, - - // Track if we need to call SymLoadModule for the dll. - sym_module_loaded: bool, -} - -impl Module { - fn new(module_handle: HANDLE, base_address: u64) -> Result { - let path = file::get_path_from_handle(module_handle).unwrap_or_else(|e| { - error!("Error getting path from file handle: {}", e); - "???".into() - }); - - let image_details = get_image_details(&path)?; - - Ok(Self { - path, - file_handle: Handle(module_handle), - base_address, - image_size: image_details.image_size, - machine: image_details.machine, - sym_module_loaded: false, - }) - } - - fn sym_load_module(&mut self, process_handle: HANDLE) -> Result<()> { - if !self.sym_module_loaded { - let dbghelp = dbghelp::lock()?; - - dbghelp.sym_load_module( - process_handle, - self.file_handle.0, - &self.path, - self.base_address, - self.image_size, - )?; - - self.sym_module_loaded = true; - } - - Ok(()) - } - - pub fn path(&self) -> &Path { - &self.path - } - - pub fn base_address(&self) -> u64 { - self.base_address - } - - pub fn image_size(&self) -> u32 { - self.image_size - } - - pub fn name(&self) -> &Path { - // Unwrap guaranteed by construction, we always have a filename. - self.path.file_stem().unwrap().as_ref() - } +#[derive(Copy, Clone, PartialEq)] +enum SymInitalizeState { + NotInitialized, + InitializeNeeded, + Initialized, } pub struct Target { @@ -163,7 +101,7 @@ pub struct Target { // Track if we need to call SymInitialize for the process and if we need to notify // dbghelp about loaded/unloaded dlls. - sym_initialized: bool, + sym_initialize_state: SymInitalizeState, exited: bool, thread_info: fnv::FnvHashMap, @@ -177,12 +115,16 @@ pub struct Target { // Key is base address (which also happens to be the HANDLE). modules: fnv::FnvHashMap, - breakpoints: fnv::FnvHashMap, + // Breakpoints that are not yet resolved to a virtual address, so either an RVA or symbol. + unresolved_breakpoints: Vec, // Map of thread to stepping state (e.g. breakpoint address to restore breakpoints) single_step: fnv::FnvHashMap, - code_byte_to_update: Option, + // When stepping after hitting a breakpoint, we need to restore the breakpoint. + // We track the address of the breakpoint to restore. 1 is sufficient because we + // can only hit a single breakpoint between calls to where we restore the breakpoint. + restore_breakpoint_pc: Option, } impl Target { @@ -204,14 +146,14 @@ impl Target { saw_initial_bp: false, saw_initial_wow64_bp: false, wow64, - sym_initialized: false, + sym_initialize_state: SymInitalizeState::NotInitialized, exited: false, thread_info: thread_handles, current_context: None, modules: fnv::FnvHashMap::default(), - breakpoints: fnv::FnvHashMap::default(), + unresolved_breakpoints: vec![], single_step: fnv::FnvHashMap::default(), - code_byte_to_update: None, + restore_breakpoint_pc: None, } } @@ -252,78 +194,231 @@ impl Target { self.saw_initial_wow64_bp } - pub fn set_saw_initial_wow64_bp(&mut self) { - self.saw_initial_wow64_bp = true; - } - pub fn saw_initial_bp(&self) -> bool { self.saw_initial_bp } - pub fn set_saw_initial_bp(&mut self) { - self.saw_initial_bp = true; - } - pub fn exited(&self) -> bool { self.exited } - pub fn modules(&self) -> hash_map::Iter { - self.modules.iter() + pub fn initial_bp(&mut self) -> Result<()> { + self.saw_initial_bp = true; + + if !self.unresolved_breakpoints.is_empty() { + self.maybe_sym_initialize()?; + self.try_resolve_all_unresolved_breakpoints(); + for module in self.modules.values_mut() { + module.apply_breakpoints(self.process_handle)?; + } + } + + Ok(()) } - pub fn initial_bp(&mut self, load_symbols: bool) -> Result<()> { - self.saw_initial_bp = true; + pub fn initial_wow64_bp(&mut self) { + self.saw_initial_wow64_bp = true; + } - if load_symbols || !self.breakpoints.is_empty() { - self.sym_initialize()?; + fn try_resolve_all_unresolved_breakpoints(&mut self) { + // borrowck - take ownership from self so we call `try_resolve_unresolved_breakpoint`. + let mut unresolved_breakpoints = std::mem::take(&mut self.unresolved_breakpoints); + unresolved_breakpoints.retain(|bp| match self.try_resolve_unresolved_breakpoint(bp) { + Ok(resolved) => !resolved, + Err(err) => { + debug!("Error resolving breakpoint: {:?}", err); + true + } + }); + assert!(self.unresolved_breakpoints.is_empty()); + self.unresolved_breakpoints = unresolved_breakpoints; + } + + /// Try to resolve a single unresolved breakpoint, returning true if the breakpoint + /// was successfully resolved. + fn try_resolve_unresolved_breakpoint( + &mut self, + breakpoint: &UnresolvedBreakpoint, + ) -> Result { + if !self.saw_initial_bp { + return Ok(false); + } - for (_, module) in self.modules.iter_mut() { - if let Err(e) = module.sym_load_module(self.process_handle) { - error!("Error loading symbols: {}", e); + let process_handle = self.process_handle; // borrowck + let mut resolved = false; + match breakpoint.extra_info() { + breakpoint::ExtraInfo::Rva(rva) => { + if let Some(module) = self.module_from_name_mut(breakpoint.module()) { + let address = module.base_address() + *rva; + module.new_breakpoint( + breakpoint.id(), + breakpoint.kind(), + address, + process_handle, + )?; + resolved = true; + } + } + breakpoint::ExtraInfo::Function(func) => { + if let Some(module) = self.module_from_name_mut(breakpoint.module()) { + match dbghelp::lock() { + Ok(dbghelp) => { + match dbghelp.sym_from_name(process_handle, module.name(), func) { + Ok(sym) => { + module.new_breakpoint( + breakpoint.id(), + breakpoint.kind(), + sym.address(), + process_handle, + )?; + resolved = true; + } + Err(_) => { + debug!("unknown symbol {}!{}", module.name().display(), func); + } + } + } + Err(e) => { + error!("Can't set symbolic breakpoints: {:?}", e); + } + } } } } - Ok(()) + Ok(resolved) + } + + pub fn new_symbolic_breakpoint( + &mut self, + id: BreakpointId, + module: impl AsRef, + func: impl AsRef, + kind: BreakpointType, + ) -> Result { + self.maybe_sym_initialize()?; + let bp = UnresolvedBreakpoint::from_symbol(id, kind, module.as_ref(), func.as_ref()); + if !self.try_resolve_unresolved_breakpoint(&bp)? { + self.unresolved_breakpoints.push(bp); + } + Ok(id) + } + + pub fn new_rva_breakpoint( + &mut self, + id: BreakpointId, + module: impl AsRef, + rva: u64, + kind: BreakpointType, + ) -> Result { + let bp = UnresolvedBreakpoint::from_rva(id, kind, module.as_ref(), rva); + if !self.try_resolve_unresolved_breakpoint(&bp)? { + self.unresolved_breakpoints.push(bp); + } + Ok(id) } - pub fn breakpoint_set_at_addr(&self, address: u64) -> bool { - self.breakpoints.contains_key(&address) + pub fn new_absolute_breakpoint( + &mut self, + id: BreakpointId, + address: u64, + kind: BreakpointType, + ) -> Result { + let process_handle = self.process_handle; // borrowck + self.module_from_address(address) + .new_breakpoint(id, kind, address, process_handle)?; + Ok(id) + } + + fn module_base_from_address(&self, address: u64) -> Result { + let dbghelp = dbghelp::lock().expect("can't lock dbghelp to find module"); + dbghelp.get_module_base(self.process_handle, address) + } + + fn module_from_address(&mut self, address: u64) -> &mut Module { + let module_base = self + .module_base_from_address(address) + .unwrap_or(unsafe { NonZeroU64::new_unchecked(module::UNKNOWN_MODULE_BASE_ADDRESS) }); + + self.modules + .entry(module_base.get()) + .or_insert_with(Module::new_fake_module) + } + + fn module_from_name_mut(&mut self, name: &str) -> Option<&mut Module> { + let name = Path::new(name); + self.modules + .values_mut() + .find(|module| module.name() == name) + } + + fn get_breakpoint_for_address(&mut self, address: u64) -> Option<&mut ResolvedBreakpoint> { + self.module_from_address(address) + .get_breakpoint_mut(address) + } + + pub fn breakpoint_set_at_addr(&mut self, address: u64) -> bool { + self.module_from_address(address) + .contains_breakpoint(address) } pub(crate) fn expecting_single_step(&self, thread_id: DWORD) -> bool { self.single_step.contains_key(&thread_id) } - pub(crate) fn complete_single_step(&mut self, thread_id: DWORD) { + pub(crate) fn complete_single_step(&mut self, thread_id: DWORD) -> Result<()> { + // We now re-enable the breakpoint so that the next time we step, the breakpoint + // will be restored. + if let Some(restore_breakpoint_pc) = self.restore_breakpoint_pc.take() { + let process_handle = self.process_handle; // borrowck + if let Some(breakpoint) = self.get_breakpoint_for_address(restore_breakpoint_pc) { + trace!("Restoring breakpoint at 0x{:x}", restore_breakpoint_pc); + breakpoint.enable(process_handle)?; + } + } + self.single_step.remove(&thread_id); + + Ok(()) } - pub fn sym_initialize(&mut self) -> Result<()> { - if !self.sym_initialized { - let dbghelp = dbghelp::lock()?; - if let Err(e) = dbghelp.sym_initialize(self.process_handle) { - error!("Error in SymInitializeW: {}", e); + pub fn maybe_sym_initialize(&mut self) -> Result<()> { + if self.sym_initialize_state == SymInitalizeState::Initialized { + return Ok(()); + } - if let Err(e) = dbghelp.sym_cleanup(self.process_handle) { - error!("Error in SymCleanup: {}", e); - } + if self.sym_initialize_state == SymInitalizeState::NotInitialized { + self.sym_initialize_state = SymInitalizeState::InitializeNeeded; + } - return Err(e); - } + if self.saw_initial_bp && self.sym_initialize_state == SymInitalizeState::InitializeNeeded { + self.sym_initialize()?; + self.sym_initialize_state = SymInitalizeState::Initialized; + } - for (_, module) in self.modules.iter_mut() { - if let Err(e) = module.sym_load_module(self.process_handle) { - error!( - "Error loading symbols for module {}: {}", - module.path.display(), - e - ); - } + Ok(()) + } + + fn sym_initialize(&mut self) -> Result<()> { + let dbghelp = dbghelp::lock()?; + if let Err(e) = dbghelp.sym_initialize(self.process_handle) { + error!("Error in SymInitializeW: {:?}", e); + + if let Err(e) = dbghelp.sym_cleanup(self.process_handle) { + error!("Error in SymCleanup: {:?}", e); } - self.sym_initialized = true; + return Err(e); + } + + for (_, module) in self.modules.iter_mut() { + if let Err(e) = module.sym_load_module(self.process_handle) { + error!( + "Error loading symbols for module {}: {:?}", + module.path().display(), + e + ); + } } Ok(()) @@ -335,7 +430,7 @@ impl Target { &mut self, module_handle: HANDLE, base_address: u64, - ) -> Result> { + ) -> Result> { let mut module = Module::new(module_handle, base_address)?; trace!( @@ -344,151 +439,32 @@ impl Target { base_address ); - if module.machine == Machine::X64 && process::is_wow64_process(self.process_handle) { + if module.machine() == Machine::X64 && process::is_wow64_process(self.process_handle) { // We ignore native dlls in wow64 processes. return Ok(None); } - if self.sym_initialized { + if self.sym_initialize_state == SymInitalizeState::Initialized { if let Err(e) = module.sym_load_module(self.process_handle) { - error!("Error loading symbols: {}", e); + error!("Error loading symbols: {:?}", e); } } - let base_address = module.base_address; - if let Some(old_value) = self.modules.insert(base_address, module.clone()) { + let module_load_info = ModuleLoadInfo::new(module.path(), base_address); + let base_address = module.base_address(); + if let Some(old_value) = self.modules.insert(base_address, module) { error!( "Existing module {} replace at base_address {}", - old_value.path.display(), + old_value.path().display(), base_address ); } - Ok(Some(module)) + Ok(Some(module_load_info)) } pub fn unload_module(&mut self, base_address: u64) { - // Drop the module and remove any breakpoints. - if let Some(module) = self.modules.remove(&base_address) { - let image_size = module.image_size as u64; - self.breakpoints - .retain(|&ip, _| ip < base_address || ip >= base_address + image_size); - } - } - - pub(crate) fn set_symbolic_breakpoint( - &mut self, - module_name: &str, - func: &str, - kind: BreakpointType, - id: BreakpointId, - ) -> Result<()> { - match dbghelp::lock() { - Ok(dbghelp) => match dbghelp.sym_from_name(self.process_handle, module_name, func) { - Ok(sym) => { - self.apply_absolute_breakpoint(sym.address(), kind, id)?; - } - Err(_) => { - anyhow::bail!("unknown symbol {}!{}", module_name, func); - } - }, - Err(e) => { - error!("Can't set symbolic breakpoints: {}", e); - } - } - - Ok(()) - } - - pub fn apply_absolute_breakpoint( - &mut self, - address: u64, - kind: BreakpointType, - id: BreakpointId, - ) -> Result<()> { - let original_byte: u8 = process::read_memory(self.process_handle, address as LPVOID)?; - - self.breakpoints - .entry(address) - .and_modify(|bp| { - bp.set_kind(kind); - bp.set_enabled(true); - bp.set_original_byte(Some(original_byte)); - bp.set_id(id); - }) - .or_insert_with(|| { - Breakpoint::new( - address, - kind, - /*enabled*/ true, - /*original_byte*/ Some(original_byte), - /*hit_count*/ 0, - id, - ) - }); - - write_instruction_byte(self.process_handle, address, 0xcc)?; - - Ok(()) - } - - pub(crate) fn apply_module_breakpoints( - &mut self, - base_address: u64, - breakpoints: &[ModuleBreakpoint], - ) -> Result<()> { - if breakpoints.is_empty() { - return Ok(()); - } - - // We want to set every breakpoint for the module at once. We'll read the just the - // memory we need to do that, so find the min and max rva to compute how much memory - // to read and update in the remote process. - let (min, max) = breakpoints - .iter() - .fold((u64::max_value(), u64::min_value()), |acc, bp| { - (acc.0.min(bp.rva()), acc.1.max(bp.rva())) - }); - - // Add 1 to include the final byte. - let region_size = (max - min) - .checked_add(1) - .ok_or_else(|| anyhow::anyhow!("overflow in region size trying to set breakpoints"))? - as usize; - let remote_address = base_address.checked_add(min).ok_or_else(|| { - anyhow::anyhow!("overflow in remote address calculation trying to set breakpoints") - })? as LPVOID; - - let mut buffer: Vec = Vec::with_capacity(region_size); - unsafe { - buffer.set_len(region_size); - } - process::read_memory_array(self.process_handle, remote_address, &mut buffer[..])?; - - for mbp in breakpoints { - let ip = base_address + mbp.rva(); - let offset = (mbp.rva() - min) as usize; - - trace!("Setting breakpoint at {:x}", ip); - - let bp = Breakpoint::new( - ip, - mbp.kind(), - /*enabled*/ true, - Some(buffer[offset]), - /*hit_count*/ 0, - mbp.id(), - ); - - buffer[offset] = 0xcc; - - self.breakpoints.insert(ip, bp); - } - - process::write_memory_slice(self.process_handle, remote_address, &buffer[..])?; - process::flush_instruction_cache(self.process_handle, remote_address, region_size)?; - - Ok(()) + self.modules.remove(&base_address); } pub fn prepare_to_resume(&mut self) -> Result<()> { @@ -504,18 +480,9 @@ impl Target { // // To avoid these possible races, when resuming, we only let a single thread go **if** // we're single stepping any threads. - // - // First, if we last stepped because of hitting a breakpoint, we restore the breakpoint - // so that whichever thread is resumed, it can't miss the breakpoint. - if let Some(CodeByteToUpdate { address, byte }) = self.code_byte_to_update.take() { - trace!("Updating breakpoint at 0x{:x}", address); - write_instruction_byte(self.process_handle, address, byte)?; - } if self.single_step.is_empty() { // Resume all threads if we aren't waiting for any threads to single step. - trace!("Resuming all threads"); - for thread_info in self.thread_info.values_mut() { thread_info.resume_thread()?; } @@ -534,17 +501,11 @@ impl Target { thread_info.resume_thread()?; // We may also need to remove a breakpoint. - if let StepState::RemoveBreakpoint { pc, original_byte } = step_state { - trace!("Restoring original byte at 0x{:x}", *pc); - write_instruction_byte(self.process_handle, *pc, *original_byte)?; - + if let StepState::RestoreBreakpointAfterStep { pc } = step_state { // We are stepping to remove the breakpoint. After we've stepped, // we must restore the breakpoint (which is done on the subsequent // call to this function). - self.code_byte_to_update = Some(CodeByteToUpdate { - address: *pc, - byte: 0xcc, - }); + self.restore_breakpoint_pc = Some(*pc); } } @@ -604,37 +565,47 @@ impl Target { Ok(current_context.get_flags()) } + pub fn read_memory(&mut self, remote_address: LPCVOID, buf: &mut [impl Copy]) -> Result<()> { + process::read_memory_array(self.process_handle, remote_address, buf)?; + + // We don't remove breakpoints when processing an event, so it's possible that the + // memory we read contains **our** breakpoints instead of the original code. + let remote_address = remote_address as u64; + let module = self.module_from_address(remote_address); + let range = remote_address..(remote_address + buf.len() as u64); + + let u8_buf = unsafe { + std::slice::from_raw_parts_mut(buf.as_mut_ptr() as *mut u8, std::mem::size_of_val(buf)) + }; + for (address, breakpoint) in module.breakpoints_for_range(range) { + if let Some(original_byte) = breakpoint.get_original_byte() { + let idx = *address - remote_address; + u8_buf[idx as usize] = original_byte; + } + } + + Ok(()) + } + /// Handle a breakpoint that we set (as opposed to a breakpoint in user code, e.g. /// assertion.) /// /// Return the breakpoint id if it should be reported to the client. pub fn handle_breakpoint(&mut self, pc: u64) -> Result> { + let process_handle = self.process_handle; // borrowck let id; - let original_byte; - let mut restore_breakpoint = true; + let mut renable_after_step = true; { - let bp = self.breakpoints.get_mut(&pc).unwrap(); + let bp = self + .get_breakpoint_for_address(pc) + .expect("debugger should have checked already"); id = bp.id(); bp.increment_hit_count(); - - original_byte = bp.original_byte().unwrap(); + bp.disable(process_handle)?; if let BreakpointType::OneTime = bp.kind() { - // This assertion simplifies `prepare_to_resume` slightly. - // If it fires, we could restore the original code byte here, otherwise we - // would need an additional Option to test in `prepare_to_resume`. - assert!(self.code_byte_to_update.is_none()); - - self.code_byte_to_update = Some(CodeByteToUpdate { - address: bp.ip(), - byte: original_byte, - }); - - bp.set_enabled(false); - bp.set_original_byte(None); - - restore_breakpoint = false; + renable_after_step = false; } } @@ -646,18 +617,18 @@ impl Target { let context = self.get_current_context_mut()?; context.set_program_counter(pc); - if restore_breakpoint { + if renable_after_step { context.set_single_step(true); } context.set_thread_context(current_thread_handle)?; - if restore_breakpoint { + if renable_after_step { // Remember that on the current thread, we need to restore the original byte. // When resuming, if we pick the current thread, we'll remove the breakpoint. self.single_step.insert( self.current_thread_id, - StepState::RemoveBreakpoint { pc, original_byte }, + StepState::RestoreBreakpointAfterStep { pc }, ); } @@ -667,52 +638,10 @@ impl Target { pub fn set_exited(&mut self) -> Result<()> { self.exited = true; - if self.sym_initialized { + if self.sym_initialize_state == SymInitalizeState::Initialized { let dbghelp = dbghelp::lock()?; dbghelp.sym_cleanup(self.process_handle)?; } Ok(()) } } - -#[derive(Copy, Clone, Debug, PartialEq)] -enum Machine { - Unknown, - X64, - X86, -} - -struct ImageDetails { - image_size: u32, - machine: Machine, -} - -fn get_image_details(path: &Path) -> Result { - let file = fs::File::open(path)?; - let map = unsafe { memmap2::Mmap::map(&file)? }; - - let header = goblin::pe::header::Header::parse(&map)?; - let image_size = header - .optional_header - .map(|h| h.windows_fields.size_of_image) - .ok_or_else(|| anyhow::anyhow!("Missing optional header in PE image"))?; - - let machine = match header.coff_header.machine { - IMAGE_FILE_MACHINE_AMD64 => Machine::X64, - IMAGE_FILE_MACHINE_I386 => Machine::X86, - _ => Machine::Unknown, - }; - - Ok(ImageDetails { - image_size, - machine, - }) -} - -fn write_instruction_byte(process_handle: HANDLE, ip: u64, b: u8) -> Result<()> { - let orig_byte = [b; 1]; - let remote_address = ip as LPVOID; - process::write_memory_slice(process_handle, remote_address, &orig_byte)?; - process::flush_instruction_cache(process_handle, remote_address, orig_byte.len())?; - Ok(()) -} diff --git a/src/agent/input-tester/src/crash_detector.rs b/src/agent/input-tester/src/crash_detector.rs index bb4d84ae11..2d315f889e 100644 --- a/src/agent/input-tester/src/crash_detector.rs +++ b/src/agent/input-tester/src/crash_detector.rs @@ -15,10 +15,7 @@ use std::{ use anyhow::Result; use coverage::{block::windows::Recorder as BlockCoverageRecorder, cache::ModuleCache}; -use debugger::{ - debugger::{BreakpointId, DebugEventHandler, Debugger}, - target::Module, -}; +use debugger::{BreakpointId, DebugEventHandler, Debugger, ModuleLoadInfo}; use log::{debug, error, trace}; use win_util::{ pipe_handle::{pipe, PipeReaderNonBlocking}, @@ -270,7 +267,7 @@ impl<'a> DebugEventHandler for CrashDetectorEventHandler<'a> { } } - fn on_create_process(&mut self, dbg: &mut Debugger, module: &Module) { + fn on_create_process(&mut self, dbg: &mut Debugger, module: &ModuleLoadInfo) { if let Some(coverage) = &mut self.coverage { if let Err(err) = coverage.on_create_process(dbg, module) { error!("error recording coverage on create process: {:?}", err); @@ -279,7 +276,7 @@ impl<'a> DebugEventHandler for CrashDetectorEventHandler<'a> { } } - fn on_load_dll(&mut self, dbg: &mut Debugger, module: &Module) { + fn on_load_dll(&mut self, dbg: &mut Debugger, module: &ModuleLoadInfo) { if let Some(coverage) = &mut self.coverage { if let Err(err) = coverage.on_load_dll(dbg, module) { error!("error recording coverage on load DLL: {:?}", err);