diff --git a/cas/scheduler/action_messages.rs b/cas/scheduler/action_messages.rs index fe1f21d98..dfe756e60 100644 --- a/cas/scheduler/action_messages.rs +++ b/cas/scheduler/action_messages.rs @@ -71,7 +71,13 @@ impl ActionInfoHashKey { /// Returns the salt used for cache busting/hashing. #[inline] pub fn action_name(&self) -> String { - format!("{}/{}/{:X}", self.instance_name, self.digest.str(), self.salt) + format!( + "{}/{}-{}/{:X}", + self.instance_name, + self.digest.str(), + self.digest.size_bytes, + self.salt + ) } } @@ -758,18 +764,49 @@ impl TryFrom for ActionStage { } } +// TODO: Should be able to remove this after tokio-rs/prost#299 +trait TypeUrl: Message { + const TYPE_URL: &'static str; +} + +impl TypeUrl for ExecuteResponse { + const TYPE_URL: &'static str = "type.googleapis.com/build.bazel.remote.execution.v2.ExecuteResponse"; +} + +impl TypeUrl for ExecuteOperationMetadata { + const TYPE_URL: &'static str = "type.googleapis.com/build.bazel.remote.execution.v2.ExecuteOperationMetadata"; +} + +fn from_any(message: &Any) -> Result +where + T: TypeUrl + Default, +{ + error_if!( + message.type_url != T::TYPE_URL, + "Incorrect type when decoding Any. {} != {}", + message.type_url, + T::TYPE_URL.to_string() + ); + Ok(T::decode(message.value.as_slice())?) +} + +fn to_any(message: &T) -> Any +where + T: TypeUrl, +{ + Any { + type_url: T::TYPE_URL.to_string(), + value: message.encode_to_vec(), + } +} + impl TryFrom for ActionState { type Error = Error; fn try_from(operation: Operation) -> Result { - let metadata_data = operation.metadata.err_tip(|| "No metadata in upstream operation")?; - error_if!( - metadata_data.type_url != "type.googleapis.com/build.bazel.remote.execution.v2.ExecuteResponse", - "Incorrect metadata structure in upstream operation. {} != type.googleapis.com/build.bazel.remote.execution.v2.ExecuteResponse", - metadata_data.type_url - ); - let metadata = ExecuteOperationMetadata::decode(metadata_data.value.as_slice()) - .err_tip(|| "Could not decode metadata in upstream operation")?; + let metadata = + from_any::(&operation.metadata.err_tip(|| "No metadata in upstream operation")?) + .err_tip(|| "Could not decode metadata in upstream operation")?; let action_digest = metadata .action_digest @@ -792,17 +829,15 @@ impl TryFrom for ActionState { LongRunningResult::Error(error) => ActionStage::Error((error.into(), ActionResult::default())), LongRunningResult::Response(response) => { // Could be Completed, CompletedFromCache or Error. - error_if!( - response.type_url != "type.googleapis.com/build.bazel.remote.execution.v2.ExecuteResponse", - "Incorrect result structure for completed upstream action. {} != type.googleapis.com/build.bazel.remote.execution.v2.ExecuteResponse", - response.type_url - ); - ExecuteResponse::decode(response.value.as_slice())?.try_into()? + from_any::(&response) + .err_tip(|| "Could not decode result structure for completed upstream action")? + .try_into()? } } } }; + println!("Operation name: {}", operation.name); let unique_qualifier = if let Ok(v) = operation.name.as_str().try_into() { v } else { @@ -842,14 +877,13 @@ impl ActionState { impl From for Operation { fn from(val: ActionState) -> Self { - let has_action_result = val.stage.has_action_result(); let stage = Into::::into(&val.stage) as i32; - let execute_response: ExecuteResponse = val.stage.into(); - let serialized_response = if has_action_result { - execute_response.encode_to_vec() + let result = if val.stage.has_action_result() { + let execute_response: ExecuteResponse = val.stage.into(); + Some(LongRunningResult::Response(to_any(&execute_response))) } else { - vec![] + None }; let metadata = ExecuteOperationMetadata { @@ -862,15 +896,9 @@ impl From for Operation { Self { name: val.unique_qualifier.action_name(), - metadata: Some(Any { - type_url: "type.googleapis.com/build.bazel.remote.execution.v2.ExecuteOperationMetadata".to_string(), - value: metadata.encode_to_vec(), - }), - done: has_action_result, - result: Some(LongRunningResult::Response(Any { - type_url: "type.googleapis.com/build.bazel.remote.execution.v2.ExecuteResponse".to_string(), - value: serialized_response, - })), + metadata: Some(to_any(&metadata)), + done: result.is_some(), + result, } } } diff --git a/cas/scheduler/tests/action_messages_test.rs b/cas/scheduler/tests/action_messages_test.rs index b36fcebd2..035aab368 100644 --- a/cas/scheduler/tests/action_messages_test.rs +++ b/cas/scheduler/tests/action_messages_test.rs @@ -39,17 +39,18 @@ mod action_messages_tests { #[tokio::test] async fn action_state_any_url_test() -> Result<(), Error> { - let operation: Operation = ActionState { + let action_state = ActionState { unique_qualifier: ActionInfoHashKey { instance_name: "foo_instance".to_string(), digest: DigestInfo::new([1u8; 32], 5), salt: 0, }, - stage: ActionStage::Unknown, - } - .into(); + // Result is only populated if has_action_result. + stage: ActionStage::Completed(ActionResult::default()), + }; + let operation: Operation = action_state.clone().into(); - match operation.result { + match &operation.result { Some(operation::Result::Response(any)) => assert_eq!( any.type_url, "type.googleapis.com/build.bazel.remote.execution.v2.ExecuteResponse" @@ -57,6 +58,9 @@ mod action_messages_tests { other => panic!("Expected Some(Result(Any)), got: {other:?}"), } + let action_state_round_trip: ActionState = operation.try_into()?; + assert_eq!(action_state, action_state_round_trip); + Ok(()) }