From 4f6662e0aa497982c51c68f79919b0fe059e9b19 Mon Sep 17 00:00:00 2001 From: KCarretto Date: Sat, 10 Feb 2024 16:53:59 -0500 Subject: [PATCH] [bug] Fix file download (#567) * fix file download * fix nits * fix tests * fix tests --- implants/imix/src/task.rs | 48 ++--- implants/lib/c2/Cargo.toml | 1 + implants/lib/c2/src/grpc.rs | 39 +++- implants/lib/eldritch/src/assets/copy_impl.rs | 198 +++++++++++------- tavern/internal/c2/c2test/grpc.go | 7 +- 5 files changed, 174 insertions(+), 119 deletions(-) diff --git a/implants/imix/src/task.rs b/implants/imix/src/task.rs index 7b3babe91..55085a5b4 100644 --- a/implants/imix/src/task.rs +++ b/implants/imix/src/task.rs @@ -147,56 +147,32 @@ impl TaskHandle { tavern: &mut impl Transport, req: FileRequest, ) -> Result<()> { - let (ch_file_chunk, file_chunk) = channel::(); + let (tx, rx) = channel::(); + tavern - .download_file(DownloadFileRequest { name: req.name() }, ch_file_chunk) + .download_file(DownloadFileRequest { name: req.name() }, tx) .await?; - let task_id = self.id; - let handle = tokio::task::spawn(async move { - loop { - let resp = match file_chunk.recv() { - Ok(r) => r, - Err(_err) => { - match _err.to_string().as_str() { - "receiving on a closed channel" => {} - _ => { - #[cfg(debug_assertions)] - log::error!( - "failed to download file chunk: task_id={}, name={}: {}", - task_id, - req.name(), - _err - ); - } - } - return; - } - }; - - #[cfg(debug_assertions)] - log::info!( - "downloaded file chunk: task_id={}, name={}, size={}", - task_id, - req.name(), - resp.chunk.len() - ); - - match req.send_chunk(resp.chunk) { + let handle = tokio::task::spawn_blocking(move || { + for r in rx { + match req.send_chunk(r.chunk) { Ok(_) => {} Err(_err) => { #[cfg(debug_assertions)] log::error!( - "failed to send downloaded file chunk: task_id={}, name={}: {}", - task_id, + "failed to send downloaded file chunk: {}: {}", req.name(), _err ); + return; } - }; + } } + #[cfg(debug_assertions)] + log::info!("file download completed: {}", req.name()); }); + self.download_handles.push(handle); Ok(()) } diff --git a/implants/lib/c2/Cargo.toml b/implants/lib/c2/Cargo.toml index bce39588c..d01e0aa86 100644 --- a/implants/lib/c2/Cargo.toml +++ b/implants/lib/c2/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" [dependencies] eldritch = { workspace = true } +log = { workspace = true } tonic = { workspace = true, features = ["tls-roots"] } prost = { workspace = true} prost-types = { workspace = true } diff --git a/implants/lib/c2/src/grpc.rs b/implants/lib/c2/src/grpc.rs index 6b6caa70f..5bb6abee5 100644 --- a/implants/lib/c2/src/grpc.rs +++ b/implants/lib/c2/src/grpc.rs @@ -34,13 +34,44 @@ impl crate::Transport for GRPC { async fn download_file( &mut self, request: crate::pb::DownloadFileRequest, - sender: Sender, + tx: Sender, ) -> Result<()> { + #[cfg(debug_assertions)] + let filename = request.name.clone(); + let resp = self.download_file_impl(request).await?; let mut stream = resp.into_inner(); - while let Some(file_chunk) = stream.message().await? { - sender.send(file_chunk)?; - } + tokio::spawn(async move { + loop { + let msg = match stream.message().await { + Ok(maybe_msg) => match maybe_msg { + Some(msg) => msg, + None => { + break; + } + }, + Err(_err) => { + #[cfg(debug_assertions)] + log::error!("failed to download file: {}: {}", filename, _err); + + return; + } + }; + match tx.send(msg) { + Ok(_) => {} + Err(_err) => { + #[cfg(debug_assertions)] + log::error!( + "failed to send downloaded file chunk: {}: {}", + filename, + _err + ); + + return; + } + } + } + }); Ok(()) } diff --git a/implants/lib/eldritch/src/assets/copy_impl.rs b/implants/lib/eldritch/src/assets/copy_impl.rs index e50beabcc..36c178db9 100644 --- a/implants/lib/eldritch/src/assets/copy_impl.rs +++ b/implants/lib/eldritch/src/assets/copy_impl.rs @@ -1,6 +1,8 @@ use crate::runtime::Client; use anyhow::{Context, Result}; use starlark::{eval::Evaluator, values::list::ListRef}; +use std::fs::OpenOptions; +use std::io::Write; use std::{fs, sync::mpsc::Receiver}; fn copy_local(src: String, dst: String) -> Result<()> { @@ -15,36 +17,42 @@ fn copy_local(src: String, dst: String) -> Result<()> { } } -fn copy_remote(file_reciever: Receiver>, dst: String) -> Result<()> { - loop { - let val = match file_reciever.recv() { - Ok(v) => v, - Err(err) => { - match err.to_string().as_str() { - "channel is empty and sending half is closed" => { - break; - } - "timed out waiting on channel" => { - continue; - } - _ => { - #[cfg(debug_assertions)] - log::debug!("failed to drain channel: {}", err) - } - } - break; - } - }; - match fs::write(dst.clone(), val) { - Ok(_) => {} - Err(local_err) => return Err(local_err.try_into()?), - }; +fn copy_remote(rx: Receiver>, dst_path: String) -> Result<()> { + // Truncate file + let mut dst = OpenOptions::new() + .create(true) + .truncate(true) + .write(true) + .open(&dst_path) + .context(format!( + "failed to truncate destination file: {}", + &dst_path + ))?; + dst.flush() + .context(format!("failed to flush file truncation: {}", &dst_path))?; + + // Reopen file for writing + let mut dst = OpenOptions::new() + .create(true) + .append(true) + .open(&dst_path) + .context(format!("failed to open file for writing: {}", &dst_path))?; + + // Listen for downloaded chunks and write them + for chunk in rx { + dst.write_all(&chunk) + .context(format!("failed to write file chunk: {}", &dst_path))?; } + // Ensure all chunks gets written + dst.flush() + .context(format!("failed to flush file: {}", &dst_path))?; + Ok(()) } -pub fn copy(starlark_eval: &mut Evaluator<'_, '_>, src: String, dst: String) -> Result<()> { +// #[allow(clippy::needless_pass_by_ref_mut)] +pub fn copy(starlark_eval: &Evaluator<'_, '_>, src: String, dst: String) -> Result<()> { let remote_assets = starlark_eval.module().get("remote_assets"); if let Some(assets) = remote_assets { @@ -63,64 +71,98 @@ pub fn copy(starlark_eval: &mut Evaluator<'_, '_>, src: String, dst: String) -> #[cfg(test)] mod tests { + use crate::assets::copy_impl::copy_remote; use crate::Runtime; + use std::sync::mpsc::channel; use std::{collections::HashMap, io::prelude::*}; use tempfile::NamedTempFile; - // fn init_log() { - // pretty_env_logger::formatted_timed_builder() - // .filter_level(log::LevelFilter::Info) - // .parse_env("IMIX_LOG") - // .init(); - // } - - // #[tokio::test] - // async fn test_remote_copy() -> anyhow::Result<()> { - // // Create files - // let mut tmp_file_dst = NamedTempFile::new()?; - // let path_dst = String::from(tmp_file_dst.path().to_str().unwrap()); - - // let (sender, reciver) = channel::>(); - // sender.send("Hello from a remote asset".as_bytes().to_vec())?; - - // copy_remote(reciver, path_dst)?; - - // let mut contents = String::new(); - // tmp_file_dst.read_to_string(&mut contents)?; - // assert!(contents.contains("Hello from a remote asset")); - // Ok(()) - // } - - // #[tokio::test] - // async fn test_remote_copy_full() -> anyhow::Result<()> { - // init_log(); - // log::debug!("Testing123"); - - // // Create files - // let mut tmp_file_dst = NamedTempFile::new()?; - // let path_dst = String::from(tmp_file_dst.path().to_str().unwrap()); - - // let (runtime, broker) = Runtime::new(); - // let handle = tokio::task::spawn_blocking(move || { - // runtime.run(crate::pb::Tome { - // eldritch: r#"assets.copy("test_tome/test_file.txt", input_params['test_output'])"# - // .to_owned(), - // parameters: HashMap::from([("test_output".to_string(), path_dst)]), - // file_names: Vec::from(["test_tome/test_file.txt".to_string()]), - // }) - // }); - // handle.await?; - // println!("{:?}", broker.collect_file_requests().len()); - // assert!(broker.collect_errors().is_empty()); // No errors even though the remote asset is inaccessible - - // let mut contents = String::new(); - // tmp_file_dst.read_to_string(&mut contents)?; - // // Compare - Should be empty basically just didn't error - // assert!(contents.contains("")); - - // Ok(()) - // } + #[tokio::test] + async fn test_remote_copy() -> anyhow::Result<()> { + // Create files + let mut tmp_file_dst = NamedTempFile::new()?; + let path_dst = String::from(tmp_file_dst.path().to_str().unwrap()); + + let (ch_data, data) = channel::>(); + let handle = tokio::task::spawn_blocking(|| { + copy_remote(data, path_dst).expect("copy_remote failed") + }); + + ch_data.send("Hello from a remote asset".as_bytes().to_vec())?; + ch_data.send("Goodbye from a remote asset".as_bytes().to_vec())?; + + // Drop the Sender, to indicate no more data will be sent (channel closed) + drop(ch_data); + + handle.await?; + + let mut contents = String::new(); + tmp_file_dst.read_to_string(&mut contents)?; + assert!(contents.contains("Hello from a remote asset")); + assert!(contents.contains("Goodbye from a remote asset")); + Ok(()) + } + + #[tokio::test] + async fn test_remote_copy_full() -> anyhow::Result<()> { + // Create files + let mut tmp_file_dst = NamedTempFile::new()?; + let path_dst = String::from(tmp_file_dst.path().to_str().unwrap()); + + // Create a runtime + let (runtime, broker) = Runtime::new(); + + // Execute eldritch in it's own thread + let handle = tokio::task::spawn_blocking(move || { + runtime.run(crate::pb::Tome { + eldritch: r#"assets.copy("test_tome/test_file.txt", input_params['test_output'])"# + .to_owned(), + parameters: HashMap::from([("test_output".to_string(), path_dst)]), + file_names: Vec::from(["test_tome/test_file.txt".to_string()]), + }) + }); + + // We now mock the agent, looping until eldritch requests a file + // We omit the sleep performed by the agent, just to save test time + loop { + // The broker only returns the data that is currently available + // So this may return an empty vec if our eldritch tokio task has not yet been scheduled + let mut reqs = broker.collect_file_requests(); + + // If no file request is yet available, just continue looping + if reqs.is_empty() { + continue; + } + + // Ensure the right file was requested + assert!(reqs.len() == 1); + let req = reqs.pop().expect("no file request received!"); + assert!(req.name() == "test_tome/test_file.txt"); + + // Now, we provide the file to eldritch (as a series of chunks) + req.send_chunk("chunk1\n".as_bytes().to_vec()) + .expect("failed to send file chunk to eldritch"); + req.send_chunk("chunk2\n".as_bytes().to_vec()) + .expect("failed to send file chunk to eldritch"); + + // We've finished providing the file, so we stop looping + // This will drop `req`, which consequently drops the underlying `Sender` for the file channel + // This will cause the next `recv()` to error with "channel is empty and sending half is closed" + // which is what tells eldritch that there are no more file chunks to wait for + break; + } + + // Now that we've finished writing data, we wait for eldritch to finish + handle.await?; + + // Lastly, assert the file was written correctly + let mut contents = String::new(); + tmp_file_dst.read_to_string(&mut contents)?; + assert_eq!("chunk1\nchunk2\n", contents.as_str()); + + Ok(()) + } #[test] fn test_embedded_copy() -> anyhow::Result<()> { diff --git a/tavern/internal/c2/c2test/grpc.go b/tavern/internal/c2/c2test/grpc.go index 150d3ebaa..f3ffe6e96 100644 --- a/tavern/internal/c2/c2test/grpc.go +++ b/tavern/internal/c2/c2test/grpc.go @@ -2,6 +2,7 @@ package c2test import ( "context" + "errors" "net" "testing" @@ -34,8 +35,9 @@ func New(t *testing.T) (c2pb.C2Client, *ent.Client, func()) { baseSrv := grpc.NewServer() c2pb.RegisterC2Server(baseSrv, c2.New(graph)) + grpcErrCh := make(chan error, 1) go func() { - require.NoError(t, baseSrv.Serve(lis), "failed to serve grpc") + grpcErrCh <- baseSrv.Serve(lis) }() conn, err := grpc.DialContext( @@ -52,5 +54,8 @@ func New(t *testing.T) (c2pb.C2Client, *ent.Client, func()) { assert.NoError(t, lis.Close()) baseSrv.Stop() assert.NoError(t, graph.Close()) + if err := <-grpcErrCh; err != nil && !errors.Is(err, grpc.ErrServerStopped) { + t.Fatalf("failed to serve grpc") + } } }