diff --git a/src/agent/onefuzz-agent/src/agent.rs b/src/agent/onefuzz-agent/src/agent.rs index d26ecb732a..17059332ae 100644 --- a/src/agent/onefuzz-agent/src/agent.rs +++ b/src/agent/onefuzz-agent/src/agent.rs @@ -12,7 +12,7 @@ use crate::reboot::*; use crate::scheduler::*; use crate::setup::*; use crate::work::IWorkQueue; -use crate::worker::IWorkerRunner; +use crate::worker::{IWorkerRunner, WorkerEvent}; const PENDING_COMMANDS_DELAY: time::Duration = time::Duration::from_secs(10); const BUSY_DELAY: time::Duration = time::Duration::from_secs(1); @@ -62,7 +62,7 @@ impl Agent { } } - pub async fn run(&mut self) -> Result<()> { + pub async fn run(self) -> Result<()> { let mut instant = time::Instant::now(); // Tell the service that the agent has started. @@ -78,42 +78,39 @@ impl Agent { let event = StateUpdateEvent::Init.into(); self.coordinator.emit_event(event).await?; } - - loop { - self.heartbeat.alive(); + let mut state = self; + let mut done = false; + while !done { + state.heartbeat.alive(); if instant.elapsed() >= PENDING_COMMANDS_DELAY { - self.execute_pending_commands().await?; + state = state.execute_pending_commands().await?; instant = time::Instant::now(); } - let done = self.update().await?; - - if done { - debug!("agent done, exiting loop"); - break; - } + (state, done) = state.update().await?; } + info!("agent done, exiting loop"); Ok(()) } - async fn update(&mut self) -> Result { + async fn update(mut self) -> Result<(Self, bool)> { let last = self.scheduler.take().ok_or_else(scheduler_error)?; let previous_state = NodeState::from(&last); let (next, done) = match last { - Scheduler::Free(s) => (self.free(s).await?, false), - Scheduler::SettingUp(s) => (self.setting_up(s).await?, false), - Scheduler::PendingReboot(s) => (self.pending_reboot(s).await?, false), - Scheduler::Ready(s) => (self.ready(s).await?, false), - Scheduler::Busy(s) => (self.busy(s).await?, false), - Scheduler::Done(s) => (self.done(s).await?, true), + Scheduler::Free(s) => (self.free(s, previous_state).await?, false), + Scheduler::SettingUp(s) => (self.setting_up(s, previous_state).await?, false), + Scheduler::PendingReboot(s) => (self.pending_reboot(s, previous_state).await?, false), + Scheduler::Ready(s) => (self.ready(s, previous_state).await?, false), + Scheduler::Busy(s) => (self.busy(s, previous_state).await?, false), + //todo: introduce a new prameter to allow the agent to restart after this point + Scheduler::Done(s) => (self.done(s, previous_state).await?, true), }; - self.previous_state = previous_state; - self.scheduler = Some(next); - Ok(done) + + Ok((next, done)) } - async fn emit_state_update_if_changed(&mut self, event: StateUpdateEvent) -> Result<()> { + async fn emit_state_update_if_changed(&self, event: StateUpdateEvent) -> Result<()> { match (&event, self.previous_state) { (StateUpdateEvent::Free, NodeState::Free) | (StateUpdateEvent::Busy, NodeState::Busy) @@ -129,7 +126,7 @@ impl Agent { Ok(()) } - async fn free(&mut self, state: State) -> Result { + async fn free(mut self, state: State, previous: NodeState) -> Result { self.emit_state_update_if_changed(StateUpdateEvent::Free) .await?; @@ -190,7 +187,7 @@ impl Agent { // Otherwise, the work was not stopped, but we still should not execute it. This is likely // our because agent version is out of date. Do nothing, so another node can see the work. // The service will eventually send us a stop command and reimage our node, if appropriate. - debug!( + info!( "not scheduling active work set, not dropping: {:?}", msg.work_set ); @@ -205,11 +202,15 @@ impl Agent { state.into() }; - Ok(next) + Ok(Self { + previous_state: previous, + scheduler: Some(next), + ..self + }) } - async fn setting_up(&mut self, state: State) -> Result { - debug!("agent setting up"); + async fn setting_up(mut self, state: State, previous: NodeState) -> Result { + info!("agent setting up"); let tasks = state.work_set().task_ids(); self.emit_state_update_if_changed(StateUpdateEvent::SettingUp { tasks }) @@ -221,11 +222,19 @@ impl Agent { SetupDone::Done(s) => s.into(), }; - Ok(scheduler) + Ok(Self { + previous_state: previous, + scheduler: Some(scheduler), + ..self + }) } - async fn pending_reboot(&mut self, state: State) -> Result { - debug!("agent pending reboot"); + async fn pending_reboot( + self, + state: State, + _previous: NodeState, + ) -> Result { + info!("agent pending reboot"); self.emit_state_update_if_changed(StateUpdateEvent::Rebooting) .await?; @@ -236,14 +245,18 @@ impl Agent { unreachable!() } - async fn ready(&mut self, state: State) -> Result { - debug!("agent ready"); + async fn ready(self, state: State, previous: NodeState) -> Result { + info!("agent ready"); self.emit_state_update_if_changed(StateUpdateEvent::Ready) .await?; - Ok(state.run().await?.into()) + Ok(Self { + previous_state: previous, + scheduler: Some(state.run().await?.into()), + ..self + }) } - async fn busy(&mut self, state: State) -> Result { + async fn busy(mut self, state: State, previous: NodeState) -> Result { self.emit_state_update_if_changed(StateUpdateEvent::Busy) .await?; @@ -255,7 +268,7 @@ impl Agent { // that is done, this sleep should be removed. time::sleep(BUSY_DELAY).await; - let mut events = vec![]; + let mut events: Vec = vec![]; let updated = state .update(&mut events, self.worker_runner.as_mut()) .await?; @@ -264,11 +277,15 @@ impl Agent { self.coordinator.emit_event(event.into()).await?; } - Ok(updated.into()) + Ok(Self { + previous_state: previous, + scheduler: Some(updated.into()), + ..self + }) } - async fn done(&mut self, state: State) -> Result { - debug!("agent done"); + async fn done(self, state: State, previous: NodeState) -> Result { + info!("agent done"); set_done_lock(self.machine_id).await?; let event = match state.cause() { @@ -287,23 +304,41 @@ impl Agent { self.emit_state_update_if_changed(event).await?; // `Done` is a final state. - Ok(state.into()) + Ok(Self { + previous_state: previous, + scheduler: Some(state.into()), + ..self + }) } - async fn execute_pending_commands(&mut self) -> Result<()> { + async fn execute_pending_commands(mut self) -> Result { let result = self.coordinator.poll_commands().await; match &result { - Ok(None) => {} + Ok(None) => Ok(Self { + last_poll_command: result, + ..self + }), Ok(Some(cmd)) => { info!("agent received node command: {:?}", cmd); let managed = self.managed; - self.scheduler()?.execute_command(cmd, managed).await?; + let scheduler = self.scheduler.take().ok_or_else(scheduler_error)?; + let new_scheduler = scheduler.execute_command(cmd.clone(), managed).await?; + + Ok(Self { + last_poll_command: result, + scheduler: Some(new_scheduler), + ..self + }) } Err(PollCommandError::RequestFailed(err)) => { // If we failed to request commands, this could be the service // could be down. Log it, but keep going. error!("error polling the service for commands: {:?}", err); + Ok(Self { + last_poll_command: result, + ..self + }) } Err(PollCommandError::RequestParseFailed(err)) => { bail!("poll commands failed: {:?}", err); @@ -321,22 +356,18 @@ impl Agent { bail!("repeated command claim attempt failures: {:?}", err); } error!("error claiming command from the service: {:?}", err); + Ok(Self { + last_poll_command: result, + ..self + }) } } - - self.last_poll_command = result; - - Ok(()) } - async fn sleep(&mut self) { + async fn sleep(&self) { let delay = time::Duration::from_secs(30); time::sleep(delay).await; } - - fn scheduler(&mut self) -> Result<&mut Scheduler> { - self.scheduler.as_mut().ok_or_else(scheduler_error) - } } // The agent owns a `Scheduler`, which it must consume when driving its state diff --git a/src/agent/onefuzz-agent/src/agent/tests.rs b/src/agent/onefuzz-agent/src/agent/tests.rs index 13929aad55..f94c40f6fe 100644 --- a/src/agent/onefuzz-agent/src/agent/tests.rs +++ b/src/agent/onefuzz-agent/src/agent/tests.rs @@ -83,12 +83,12 @@ impl Fixture { #[tokio::test] async fn test_update_free_no_work() { - let mut agent = Fixture.agent(); + let agent = Fixture.agent(); - let done = agent.update().await.unwrap(); + let (agent, done) = agent.update().await.unwrap(); assert!(!done); - assert!(matches!(agent.scheduler().unwrap(), Scheduler::Free(..))); + assert!(matches!(agent.scheduler.unwrap(), Scheduler::Free(..))); let double: &WorkQueueDouble = agent.work_queue.downcast_ref().unwrap(); let claimed_worksets = double @@ -109,13 +109,9 @@ async fn test_update_free_has_work() { .available .push(Fixture.message()); - let done = agent.update().await.unwrap(); + let (agent, done) = agent.update().await.unwrap(); assert!(!done); - - assert!(matches!( - agent.scheduler().unwrap(), - Scheduler::SettingUp(..) - )); + assert!(matches!(agent.scheduler.unwrap(), Scheduler::SettingUp(..))); let double: &WorkQueueDouble = agent.work_queue.downcast_ref().unwrap(); let claimed_worksets = double @@ -149,8 +145,10 @@ async fn test_emitted_state() { .available .push(Fixture.message()); + let mut done; for _i in 0..10 { - if agent.update().await.unwrap() { + (agent, done) = agent.update().await.unwrap(); + if done { break; } } @@ -181,8 +179,8 @@ async fn test_emitted_state() { }), ]; let coordinator: &CoordinatorDouble = agent.coordinator.downcast_ref().unwrap(); - let events = &coordinator.events; - assert_eq!(events, &expected_events); + let events = &coordinator.events.read().await; + assert_eq!(&events.to_vec(), &expected_events); } #[tokio::test] @@ -206,8 +204,10 @@ async fn test_emitted_state_failed_setup() { .available .push(Fixture.message()); + let mut done; for _i in 0..10 { - if agent.update().await.unwrap() { + (agent, done) = agent.update().await.unwrap(); + if done { break; } } @@ -223,7 +223,7 @@ async fn test_emitted_state_failed_setup() { }), ]; let coordinator: &CoordinatorDouble = agent.coordinator.downcast_ref().unwrap(); - let events = &coordinator.events; + let events = &coordinator.events.read().await.to_vec(); assert_eq!(events, &expected_events); // TODO: at some point, the underlying tests should be updated to not write diff --git a/src/agent/onefuzz-agent/src/commands.rs b/src/agent/onefuzz-agent/src/commands.rs index a613c8550a..8a62748c1a 100644 --- a/src/agent/onefuzz-agent/src/commands.rs +++ b/src/agent/onefuzz-agent/src/commands.rs @@ -25,7 +25,7 @@ const ONEFUZZ_SERVICE_USER: &str = "onefuzz"; #[cfg(target_family = "windows")] static SET_PERMISSION_ONCE: OnceCell<()> = OnceCell::const_new(); -#[derive(Debug, Deserialize, Eq, PartialEq, Serialize)] +#[derive(Debug, Deserialize, Eq, PartialEq, Serialize, Clone)] pub struct SshKeyInfo { pub public_key: Secret, } diff --git a/src/agent/onefuzz-agent/src/config.rs b/src/agent/onefuzz-agent/src/config.rs index d26d54e7f0..a9a8c650dd 100644 --- a/src/agent/onefuzz-agent/src/config.rs +++ b/src/agent/onefuzz-agent/src/config.rs @@ -319,13 +319,12 @@ impl Registration { pub async fn load_existing(config: StaticConfig) -> Result { let dynamic_config = DynamicConfig::load().await?; let machine_id = config.machine_identity.machine_id; - let mut registration = Self { + let registration = Self { config, dynamic_config, machine_id, }; - registration.renew().await?; - Ok(registration) + registration.renew().await } pub async fn create_managed(config: StaticConfig) -> Result { @@ -336,7 +335,7 @@ impl Registration { Self::create(config, false, DEFAULT_REGISTRATION_CREATE_TIMEOUT).await } - pub async fn renew(&mut self) -> Result<()> { + pub async fn renew(&self) -> Result { info!("renewing registration"); let token = self.config.credentials.access_token().await?; @@ -355,9 +354,13 @@ impl Registration { .await .context("Registration.renew request body")?; - self.dynamic_config = response.json().await?; - self.dynamic_config.save().await?; + let dynamic_config: DynamicConfig = response.json().await?; + dynamic_config.save().await?; - Ok(()) + Ok(Self { + dynamic_config, + config: self.config.clone(), + machine_id: self.machine_id, + }) } } diff --git a/src/agent/onefuzz-agent/src/coordinator.rs b/src/agent/onefuzz-agent/src/coordinator.rs index d940d46d8a..e63f250b11 100644 --- a/src/agent/onefuzz-agent/src/coordinator.rs +++ b/src/agent/onefuzz-agent/src/coordinator.rs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +use std::sync::Arc; + use anyhow::{Context, Error, Result}; use downcast_rs::Downcast; use onefuzz::{auth::AccessToken, http::ResponseExt, process::Output}; @@ -9,6 +11,7 @@ use reqwest_retry::{ is_auth_failure, RetryCheck, SendRetry, DEFAULT_RETRY_PERIOD, MAX_RETRY_ATTEMPTS, }; use serde::Serialize; +use tokio::sync::RwLock; use uuid::Uuid; use crate::commands::SshKeyInfo; @@ -16,12 +19,12 @@ use crate::config::Registration; use crate::work::{TaskId, WorkSet}; use crate::worker::WorkerEvent; -#[derive(Debug, Deserialize, Eq, PartialEq, Serialize)] +#[derive(Debug, Deserialize, Eq, PartialEq, Serialize, Clone)] pub struct StopTask { pub task_id: TaskId, } -#[derive(Debug, Deserialize, Eq, PartialEq, Serialize)] +#[derive(Debug, Deserialize, Eq, PartialEq, Serialize, Clone)] #[serde(rename_all = "snake_case")] pub enum NodeCommand { AddSshKey(SshKeyInfo), @@ -153,9 +156,9 @@ pub struct TaskInfo { pub trait ICoordinator: Downcast { async fn poll_commands(&mut self) -> Result, PollCommandError>; - async fn emit_event(&mut self, event: NodeEvent) -> Result<()>; + async fn emit_event(&self, event: NodeEvent) -> Result<()>; - async fn can_schedule(&mut self, work: &WorkSet) -> Result; + async fn can_schedule(&self, work: &WorkSet) -> Result; } impl_downcast!(ICoordinator); @@ -163,15 +166,15 @@ impl_downcast!(ICoordinator); #[async_trait] impl ICoordinator for Coordinator { async fn poll_commands(&mut self) -> Result, PollCommandError> { - self.poll_commands().await + Coordinator::poll_commands(self).await } - async fn emit_event(&mut self, event: NodeEvent) -> Result<()> { - self.emit_event(event).await + async fn emit_event(&self, event: NodeEvent) -> Result<()> { + Coordinator::emit_event(self, event).await } - async fn can_schedule(&mut self, work_set: &WorkSet) -> Result { - self.can_schedule(work_set).await + async fn can_schedule(&self, work_set: &WorkSet) -> Result { + Coordinator::can_schedule(self, work_set).await } } @@ -184,7 +187,7 @@ pub enum PollCommandError { pub struct Coordinator { client: Client, registration: Registration, - token: AccessToken, + token: Arc>, } impl Coordinator { @@ -195,7 +198,7 @@ impl Coordinator { Ok(Self { client, registration, - token, + token: Arc::new(RwLock::new(token)), }) } @@ -203,7 +206,7 @@ impl Coordinator { /// /// If the request fails due to an expired access token, we will retry once /// with a fresh one. - pub async fn poll_commands(&mut self) -> Result, PollCommandError> { + pub async fn poll_commands(&self) -> Result, PollCommandError> { let request = PollCommandsRequest { machine_id: self.registration.machine_id, }; @@ -241,7 +244,7 @@ impl Coordinator { } } - pub async fn emit_event(&mut self, event: NodeEvent) -> Result<()> { + pub async fn emit_event(&self, event: NodeEvent) -> Result<()> { let envelope = NodeEventEnvelope { event, machine_id: self.registration.machine_id, @@ -255,7 +258,7 @@ impl Coordinator { Ok(()) } - async fn can_schedule(&mut self, work_set: &WorkSet) -> Result { + async fn can_schedule(&self, work_set: &WorkSet) -> Result { // Temporary: assume one work unit per work set. // // In the future, we will probably want the same behavior, but we will @@ -283,11 +286,23 @@ impl Coordinator { Ok(can_schedule) } - async fn send_request(&mut self, request: RequestBuilder) -> Result { + async fn get_token(&self) -> Result { + let token = self.token.read().await; + Ok(token.clone()) + } + + async fn refresh_token(&self) -> Result { + let mut token = self.token.write().await; + *token = self.registration.config.credentials.access_token().await?; + Ok(token.clone()) + } + + async fn send_request(&self, request: RequestBuilder) -> Result { + let token = self.get_token().await?; let mut response = request .try_clone() .ok_or_else(|| anyhow!("unable to clone request"))? - .bearer_auth(self.token.secret().expose_ref()) + .bearer_auth(token.secret().expose_ref()) .send_retry( |code| match code { StatusCode::UNAUTHORIZED => RetryCheck::Fail, @@ -303,13 +318,13 @@ impl Coordinator { debug!("access token expired, renewing"); // If we didn't succeed due to authorization, refresh our token, - self.token = self.registration.config.credentials.access_token().await?; + let token = self.refresh_token().await?; debug!("retrying request after refreshing access token"); // And try one more time. response = request - .bearer_auth(self.token.secret().expose_ref()) + .bearer_auth(token.secret().expose_ref()) .send_retry_default() .await .context("Coordinator.send after refreshing access token"); diff --git a/src/agent/onefuzz-agent/src/coordinator/double.rs b/src/agent/onefuzz-agent/src/coordinator/double.rs index 6abbe77140..f7c488f2c2 100644 --- a/src/agent/onefuzz-agent/src/coordinator/double.rs +++ b/src/agent/onefuzz-agent/src/coordinator/double.rs @@ -5,22 +5,24 @@ use super::*; #[derive(Debug, Default)] pub struct CoordinatorDouble { - pub commands: Vec, - pub events: Vec, + pub commands: Arc>>, + pub events: Arc>>, } #[async_trait] impl ICoordinator for CoordinatorDouble { async fn poll_commands(&mut self) -> Result, PollCommandError> { - Ok(self.commands.pop()) + let mut commands = self.commands.write().await; + Ok(commands.pop()) } - async fn emit_event(&mut self, event: NodeEvent) -> Result<()> { - self.events.push(event); + async fn emit_event(&self, event: NodeEvent) -> Result<()> { + let mut events = self.events.write().await; + events.push(event); Ok(()) } - async fn can_schedule(&mut self, _work: &WorkSet) -> Result { + async fn can_schedule(&self, _work: &WorkSet) -> Result { Ok(CanSchedule { allowed: true, work_stopped: true, diff --git a/src/agent/onefuzz-agent/src/debug.rs b/src/agent/onefuzz-agent/src/debug.rs index 82072e7825..9df0d48592 100644 --- a/src/agent/onefuzz-agent/src/debug.rs +++ b/src/agent/onefuzz-agent/src/debug.rs @@ -180,7 +180,7 @@ fn debug_run_worker(opt: RunWorkerOpt) -> Result<()> { async fn run_worker(mut work_set: WorkSet) -> Result> { use crate::setup::SetupRunner; - let mut setup_runner = SetupRunner { + let setup_runner = SetupRunner { machine_id: Uuid::new_v4(), }; setup_runner.run(&work_set).await?; diff --git a/src/agent/onefuzz-agent/src/main.rs b/src/agent/onefuzz-agent/src/main.rs index 38ca6ef621..c2eda46675 100644 --- a/src/agent/onefuzz-agent/src/main.rs +++ b/src/agent/onefuzz-agent/src/main.rs @@ -306,7 +306,7 @@ async fn run_agent(config: StaticConfig, reset_node: bool) -> Result<()> { let mut coordinator = coordinator::Coordinator::new(registration.clone()).await?; debug!("initialized coordinator"); - let mut reboot = reboot::Reboot; + let reboot = reboot::Reboot; let reboot_context = reboot.load_context().await?; if reset_node { WorkSet::remove_context(config.machine_identity.machine_id).await?; @@ -331,7 +331,7 @@ async fn run_agent(config: StaticConfig, reset_node: bool) -> Result<()> { ), None => None, }; - let mut agent = agent::Agent::new( + let agent = agent::Agent::new( Box::new(coordinator), Box::new(reboot), scheduler, diff --git a/src/agent/onefuzz-agent/src/reboot.rs b/src/agent/onefuzz-agent/src/reboot.rs index 0ffd4ef0ac..7584cea858 100644 --- a/src/agent/onefuzz-agent/src/reboot.rs +++ b/src/agent/onefuzz-agent/src/reboot.rs @@ -12,26 +12,26 @@ use crate::work::*; #[async_trait] pub trait IReboot: Downcast { - async fn save_context(&mut self, ctx: RebootContext) -> Result<()>; + async fn save_context(&self, ctx: RebootContext) -> Result<()>; - async fn load_context(&mut self) -> Result>; + async fn load_context(&self) -> Result>; - fn invoke(&mut self) -> Result<()>; + fn invoke(&self) -> Result<()>; } impl_downcast!(IReboot); #[async_trait] impl IReboot for Reboot { - async fn save_context(&mut self, ctx: RebootContext) -> Result<()> { + async fn save_context(&self, ctx: RebootContext) -> Result<()> { self.save_context(ctx).await } - async fn load_context(&mut self) -> Result> { + async fn load_context(&self) -> Result> { self.load_context().await } - fn invoke(&mut self) -> Result<()> { + fn invoke(&self) -> Result<()> { self.invoke() } } @@ -39,7 +39,7 @@ impl IReboot for Reboot { pub struct Reboot; impl Reboot { - pub async fn save_context(&mut self, ctx: RebootContext) -> Result<()> { + pub async fn save_context(&self, ctx: RebootContext) -> Result<()> { let path = reboot_context_path()?; info!("saving reboot context to: {}", path.display()); @@ -54,7 +54,7 @@ impl Reboot { Ok(()) } - pub async fn load_context(&mut self) -> Result> { + pub async fn load_context(&self) -> Result> { use std::io::ErrorKind; let path = reboot_context_path()?; @@ -82,7 +82,7 @@ impl Reboot { } #[cfg(target_family = "unix")] - pub fn invoke(&mut self) -> Result<()> { + pub fn invoke(&self) -> Result<()> { info!("invoking local reboot command"); Command::new("reboot").arg("-f").status()?; @@ -91,7 +91,7 @@ impl Reboot { } #[cfg(target_family = "windows")] - pub fn invoke(&mut self) -> Result<()> { + pub fn invoke(&self) -> Result<()> { info!("invoking local reboot command"); Command::new("powershell.exe") diff --git a/src/agent/onefuzz-agent/src/reboot/double.rs b/src/agent/onefuzz-agent/src/reboot/double.rs index e63e474172..ab5027b948 100644 --- a/src/agent/onefuzz-agent/src/reboot/double.rs +++ b/src/agent/onefuzz-agent/src/reboot/double.rs @@ -1,27 +1,45 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +use std::sync::{ + atomic::{self, AtomicBool}, + Arc, +}; + +use tokio::sync::RwLock; + use super::*; -#[derive(Clone, Debug, Default)] +#[derive(Debug, Default)] pub struct RebootDouble { - pub saved: Vec, - pub invoked: bool, + pub saved: Arc>>, + pub invoked: AtomicBool, +} + +impl Clone for RebootDouble { + fn clone(&self) -> Self { + Self { + saved: self.saved.clone(), + invoked: AtomicBool::new(self.invoked.load(atomic::Ordering::SeqCst)), + } + } } #[async_trait] impl IReboot for RebootDouble { - async fn save_context(&mut self, ctx: RebootContext) -> Result<()> { - self.saved.push(ctx); + async fn save_context(&self, ctx: RebootContext) -> Result<()> { + let mut saved = self.saved.write().await; + saved.push(ctx); Ok(()) } - async fn load_context(&mut self) -> Result> { - Ok(self.saved.pop()) + async fn load_context(&self) -> Result> { + let mut saved = self.saved.write().await; + Ok(saved.pop()) } - fn invoke(&mut self) -> Result<()> { - self.invoked = true; + fn invoke(&self) -> Result<()> { + self.invoked.swap(true, atomic::Ordering::SeqCst); Ok(()) } } diff --git a/src/agent/onefuzz-agent/src/scheduler.rs b/src/agent/onefuzz-agent/src/scheduler.rs index 0283c06f2a..b7cca0da45 100644 --- a/src/agent/onefuzz-agent/src/scheduler.rs +++ b/src/agent/onefuzz-agent/src/scheduler.rs @@ -55,18 +55,22 @@ impl Scheduler { Self::default() } - pub async fn execute_command(&mut self, cmd: &NodeCommand, managed: bool) -> Result<()> { + pub async fn execute_command(self, cmd: NodeCommand, managed: bool) -> Result { match cmd { NodeCommand::AddSshKey(ssh_key_info) => { if managed { - add_ssh_key(ssh_key_info).await?; + add_ssh_key(&ssh_key_info).await?; } else { warn!("adding ssh keys only supported on managed nodes"); } + Ok(self) } NodeCommand::StopTask(stop_task) => { if let Scheduler::Busy(state) = self { - state.stop(stop_task.task_id)?; + let state = state.stop(stop_task.task_id)?; + Ok(state.into()) + } else { + Ok(self) } } NodeCommand::Stop {} => { @@ -74,7 +78,7 @@ impl Scheduler { let state = State { ctx: Done { cause }, }; - *self = state.into(); + Ok(state.into()) } NodeCommand::StopIfFree {} => { if let Scheduler::Free(_) = self { @@ -82,12 +86,12 @@ impl Scheduler { let state = State { ctx: Done { cause }, }; - *self = state.into(); + Ok(state.into()) + } else { + Ok(self) } } } - - Ok(()) } } @@ -187,7 +191,7 @@ pub enum SetupDone { } impl State { - pub async fn finish(self, runner: &mut dyn ISetupRunner) -> Result { + pub async fn finish(self, runner: &dyn ISetupRunner) -> Result { let work_set = self.ctx.work_set; let output = runner.run(&work_set).await; @@ -289,7 +293,7 @@ impl State { .all(|worker| worker.as_ref().unwrap().is_done()) } - pub fn stop(&mut self, task_id: TaskId) -> Result<()> { + pub fn stop(mut self, task_id: TaskId) -> Result { for worker in &mut self.ctx.workers { let worker = worker.as_mut().unwrap(); @@ -300,7 +304,7 @@ impl State { } } - Ok(()) + Ok(self) } } diff --git a/src/agent/onefuzz-agent/src/setup.rs b/src/agent/onefuzz-agent/src/setup.rs index fe9dc7bab8..a1caac3abc 100644 --- a/src/agent/onefuzz-agent/src/setup.rs +++ b/src/agent/onefuzz-agent/src/setup.rs @@ -24,14 +24,14 @@ pub type SetupOutput = Option; #[async_trait] pub trait ISetupRunner: Downcast { - async fn run(&mut self, work_set: &WorkSet) -> Result; + async fn run(&self, work_set: &WorkSet) -> Result; } impl_downcast!(ISetupRunner); #[async_trait] impl ISetupRunner for SetupRunner { - async fn run(&mut self, work_set: &WorkSet) -> Result { + async fn run(&self, work_set: &WorkSet) -> Result { self.run(work_set).await } } @@ -42,7 +42,7 @@ pub struct SetupRunner { } impl SetupRunner { - pub async fn run(&mut self, work_set: &WorkSet) -> Result { + pub async fn run(&self, work_set: &WorkSet) -> Result { info!("running setup for work set"); work_set.save_context(self.machine_id).await?; // Download the setup container. diff --git a/src/agent/onefuzz-agent/src/setup/double.rs b/src/agent/onefuzz-agent/src/setup/double.rs index 6fac515508..33dc1d0fe5 100644 --- a/src/agent/onefuzz-agent/src/setup/double.rs +++ b/src/agent/onefuzz-agent/src/setup/double.rs @@ -1,19 +1,24 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +use std::sync::Arc; + +use tokio::sync::RwLock; + use super::*; #[derive(Clone, Debug, Default)] pub struct SetupRunnerDouble { - pub ran: Vec, + pub ran: Arc>>, pub script: SetupOutput, pub error_message: Option, } #[async_trait] impl ISetupRunner for SetupRunnerDouble { - async fn run(&mut self, work_set: &WorkSet) -> Result { - self.ran.push(work_set.clone()); + async fn run(&self, work_set: &WorkSet) -> Result { + let mut ran = self.ran.write().await; + ran.push(work_set.clone()); if let Some(error) = self.error_message.clone() { anyhow::bail!(error); } diff --git a/src/agent/onefuzz-agent/src/work.rs b/src/agent/onefuzz-agent/src/work.rs index 34b93af494..c89693d950 100644 --- a/src/agent/onefuzz-agent/src/work.rs +++ b/src/agent/onefuzz-agent/src/work.rs @@ -1,14 +1,15 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -use std::io::ErrorKind; use std::path::PathBuf; +use std::{io::ErrorKind, sync::Arc}; use anyhow::{Context, Result}; use downcast_rs::Downcast; use onefuzz::{auth::Secret, blob::BlobContainerUrl, http::is_auth_error}; use storage_queue::{Message as QueueMessage, QueueClient}; use tokio::fs; +use tokio::sync::RwLock; use uuid::Uuid; use crate::config::Registration; @@ -145,7 +146,7 @@ pub struct Message { pub struct WorkQueue { queue: QueueClient, - registration: Registration, + registration: Arc>, } impl WorkQueue { @@ -155,16 +156,17 @@ impl WorkQueue { Ok(Self { queue, - registration, + registration: Arc::new(RwLock::new(registration)), }) } async fn renew(&mut self) -> Result<()> { - self.registration + let mut registration = self.registration.write().await; + *registration = registration .renew() .await .context("unable to renew registration in workqueue")?; - let url = self.registration.dynamic_config.work_queue.clone(); + let url = registration.dynamic_config.work_queue.clone(); self.queue = QueueClient::new(url)?; Ok(()) } @@ -207,7 +209,13 @@ impl WorkQueue { Err(err) => { if is_auth_error(&err) { self.renew().await.context("unable to renew registration")?; - let url = self.registration.dynamic_config.work_queue.clone(); + let url = self + .registration + .read() + .await + .dynamic_config + .work_queue + .clone(); queue_message .update_url(url) .delete() diff --git a/src/agent/onefuzz-agent/src/worker.rs b/src/agent/onefuzz-agent/src/worker.rs index c7f556fb8e..3e12523a6e 100644 --- a/src/agent/onefuzz-agent/src/worker.rs +++ b/src/agent/onefuzz-agent/src/worker.rs @@ -189,7 +189,7 @@ impl_from_state_for_worker!(Done); #[async_trait] pub trait IWorkerRunner: Downcast { - async fn run(&mut self, setup_dir: &Path, work: &WorkUnit) -> Result>; + async fn run(&self, setup_dir: &Path, work: &WorkUnit) -> Result>; } impl_downcast!(IWorkerRunner); @@ -214,7 +214,7 @@ impl WorkerRunner { #[async_trait] impl IWorkerRunner for WorkerRunner { - async fn run(&mut self, setup_dir: &Path, work: &WorkUnit) -> Result> { + async fn run(&self, setup_dir: &Path, work: &WorkUnit) -> Result> { let working_dir = work.working_dir(self.machine_identity.machine_id)?; debug!("worker working dir = {}", working_dir.display()); @@ -268,12 +268,12 @@ impl IWorkerRunner for WorkerRunner { } trait SuspendableChild { - fn suspend(&mut self) -> Result<()>; + fn suspend(&self) -> Result<()>; } #[cfg(target_os = "windows")] impl SuspendableChild for Child { - fn suspend(&mut self) -> Result<()> { + fn suspend(&self) -> Result<()> { // DebugActiveProcess suspends all threads in the process. // https://docs.microsoft.com/en-us/windows/win32/api/debugapi/nf-debugapi-debugactiveprocess#remarks let result = unsafe { winapi::um::debugapi::DebugActiveProcess(self.id()) }; @@ -286,7 +286,7 @@ impl SuspendableChild for Child { #[cfg(target_os = "linux")] impl SuspendableChild for Child { - fn suspend(&mut self) -> Result<()> { + fn suspend(&self) -> Result<()> { use nix::sys::signal; signal::kill( nix::unistd::Pid::from_raw(self.id() as _), diff --git a/src/agent/onefuzz-agent/src/worker/double.rs b/src/agent/onefuzz-agent/src/worker/double.rs index 8fa6a918a7..fb64555a76 100644 --- a/src/agent/onefuzz-agent/src/worker/double.rs +++ b/src/agent/onefuzz-agent/src/worker/double.rs @@ -10,7 +10,7 @@ pub struct WorkerRunnerDouble { #[async_trait] impl IWorkerRunner for WorkerRunnerDouble { - async fn run(&mut self, _setup_dir: &Path, _work: &WorkUnit) -> Result> { + async fn run(&self, _setup_dir: &Path, _work: &WorkUnit) -> Result> { Ok(Box::new(self.child.clone())) } } diff --git a/src/agent/onefuzz-agent/src/worker/tests.rs b/src/agent/onefuzz-agent/src/worker/tests.rs index 822e30e805..20f857858c 100644 --- a/src/agent/onefuzz-agent/src/worker/tests.rs +++ b/src/agent/onefuzz-agent/src/worker/tests.rs @@ -55,7 +55,7 @@ struct RunnerDouble { #[async_trait] impl IWorkerRunner for RunnerDouble { - async fn run(&mut self, _setup_dir: &Path, _work: &WorkUnit) -> Result> { + async fn run(&self, _setup_dir: &Path, _work: &WorkUnit) -> Result> { Ok(Box::new(self.child.clone())) } }