diff --git a/devolutions-gateway/src/interceptor/pcap_recording.rs b/devolutions-gateway/src/interceptor/pcap_recording.rs index eed620418..5474c7c86 100644 --- a/devolutions-gateway/src/interceptor/pcap_recording.rs +++ b/devolutions-gateway/src/interceptor/pcap_recording.rs @@ -3,15 +3,22 @@ use crate::plugin_manager::{PacketsParser, Recorder, PLUGIN_MANAGER}; use slog_scope::debug; use std::{ net::SocketAddr, - sync::{Arc, Mutex}, + sync::{Arc, Condvar, Mutex}, }; +#[derive(Debug, Clone, Copy)] +enum State { + Update, + Finish, +} + #[derive(Clone)] pub struct PcapRecordingInterceptor { server_info: Arc>, client_info: Arc>, packets_parser: Arc>>, recorder: Arc>>, + condition_timeout: Arc<(Mutex, Condvar)>, } impl PcapRecordingInterceptor { @@ -28,8 +35,40 @@ impl PcapRecordingInterceptor { client_info: Arc::new(Mutex::new(PeerInfo::new(client_addr))), packets_parser: Arc::new(Mutex::new(PLUGIN_MANAGER.lock().unwrap().get_parsing_packets_plugin())), recorder: Arc::new(Mutex::new(recording_plugin)), + condition_timeout: Arc::new((Mutex::new(State::Update), Condvar::new())), }; + let recorder = interceptor.recorder.clone(); + let condition_timeout = interceptor.condition_timeout.clone(); + std::thread::spawn(move || loop { + let mut timeout: u32 = 0; + + { + if let Some(recorder) = recorder.lock().unwrap().as_ref() { + timeout = recorder.get_timeout(); + } + } + + let (state, cond_var) = &*condition_timeout; + let result = cond_var.wait_timeout(state.lock().unwrap(), std::time::Duration::from_millis(timeout as u64)); + + match result { + Ok((state_result, timeout_result)) => match *state_result { + State::Update => { + if timeout_result.timed_out() { + if let Some(recorder) = recorder.lock().unwrap().as_ref() { + recorder.timeout(); + } + } + } + State::Finish => break, + }, + Err(e) => { + slog_scope::error!("Wait timeout failed with error! {}", e); + } + } + }); + interceptor } @@ -46,15 +85,26 @@ impl PacketInterceptor for PcapRecordingInterceptor { debug!("New packet intercepted. Packet size = {}", data.len()); let server_info = self.server_info.lock().unwrap(); + let is_from_server = source_addr.unwrap() == server_info.addr; + + if is_from_server { + let (state, cond_var) = &*self.condition_timeout.clone(); + let mut pending = state.lock().unwrap(); + *pending = State::Update; + cond_var.notify_one(); + } let option_parser = self.packets_parser.lock().unwrap(); let option_recorder = self.recorder.lock().unwrap(); - let is_from_server = source_addr.unwrap() == server_info.addr; if let Some(parser) = option_parser.as_ref() { let (status, message_id) = parser.parse_message(data, data.len(), is_from_server); + debug!( + "Returned from parse message with status: {} and message_id: {}", + status, message_id + ); - if !parser.is_message_constructed() { + if !parser.is_message_constructed(is_from_server) { return; } else if message_id == PacketsParser::NOW_UPDATE_MSG_ID { let size = parser.get_size(); @@ -78,3 +128,12 @@ impl PacketInterceptor for PcapRecordingInterceptor { Box::new(self.clone()) } } + +impl Drop for PcapRecordingInterceptor { + fn drop(&mut self) { + let (state, cond_var) = &*self.condition_timeout.clone(); + let mut pending = state.lock().unwrap(); + *pending = State::Finish; + cond_var.notify_one(); + } +} diff --git a/devolutions-gateway/src/plugin_manager/packets_parsing.rs b/devolutions-gateway/src/plugin_manager/packets_parsing.rs index 1e0d12b6c..58f2a7f41 100644 --- a/devolutions-gateway/src/plugin_manager/packets_parsing.rs +++ b/devolutions-gateway/src/plugin_manager/packets_parsing.rs @@ -46,7 +46,8 @@ pub struct PacketsParsingApi<'a> { surfaceSize: *mut u32, ) -> *mut u8, >, - NowPacketParser_IsMessageConstructed: Symbol<'a, unsafe extern "C" fn(ctx: NowPacketParser) -> bool>, + NowPacketParser_IsMessageConstructed: + Symbol<'a, unsafe extern "C" fn(ctx: NowPacketParser, isFromServer: bool) -> bool>, NowPacketParser_Free: Symbol<'a, unsafe extern "C" fn(ctx: NowPacketParser)>, } @@ -101,8 +102,8 @@ impl PacketsParser { (res, message_id) } - pub fn is_message_constructed(&self) -> bool { - unsafe { (self.api.NowPacketParser_IsMessageConstructed)(self.ctx) } + pub fn is_message_constructed(&self, is_from_server: bool) -> bool { + unsafe { (self.api.NowPacketParser_IsMessageConstructed)(self.ctx, is_from_server) } } pub fn get_image_data(&self) -> ImageUpdate { diff --git a/devolutions-gateway/src/plugin_manager/recording.rs b/devolutions-gateway/src/plugin_manager/recording.rs index 5d18f31ac..97f0a80dd 100644 --- a/devolutions-gateway/src/plugin_manager/recording.rs +++ b/devolutions-gateway/src/plugin_manager/recording.rs @@ -25,6 +25,9 @@ pub struct RecordingApi<'a> { surfaceStep: *const u32, ), >, + NowRecording_Timeout: Symbol<'a, unsafe extern "C" fn(ctx: RecordingContext)>, + NowRecording_GetTimeout: Symbol<'a, unsafe extern "C" fn(ctx: RecordingContext) -> u32>, + NowRecording_GetPath: Symbol<'a, unsafe extern "C" fn(ctx: RecordingContext, path: *mut c_char)>, NowRecording_Free: Symbol<'a, unsafe extern "C" fn(ctx: RecordingContext)>, } @@ -86,6 +89,25 @@ impl Recorder { } } } + + pub fn timeout(&self) { + unsafe { + (self.api.NowRecording_Timeout)(self.ctx); + } + } + + pub fn get_timeout(&self) -> u32 { + unsafe { (self.api.NowRecording_GetTimeout)(self.ctx) } + } + + pub fn get_filepath(&self) -> String { + let mut path_array = [0i8; 512]; + unsafe { + (self.api.NowRecording_GetPath)(self.ctx, path_array.as_mut_ptr()); + } + return String::from_utf8(path_array.iter().map(|element| *element as u8).collect()) + .map_or("".to_string(), |path| path); + } } impl Drop for Recorder {