diff --git a/cosmwasm/ucs01-relay-api/src/middleware.rs b/cosmwasm/ucs01-relay-api/src/middleware.rs index 9720f778f3..49ffe2d672 100644 --- a/cosmwasm/ucs01-relay-api/src/middleware.rs +++ b/cosmwasm/ucs01-relay-api/src/middleware.rs @@ -74,6 +74,7 @@ pub struct InFlightPfmPacket { pub src_packet_timeout: IbcTimeout, pub forward_channel_id: String, pub forward_port_id: String, + pub original_protocol_version: String, } impl InFlightPfmPacket { @@ -83,6 +84,7 @@ impl InFlightPfmPacket { timeout: u64, forward_channel_id: String, forward_port_id: String, + original_protocol_version: String, ) -> Self { Self { original_sender_addr, @@ -96,6 +98,7 @@ impl InFlightPfmPacket { packet_sequence: original_packet.sequence, forward_channel_id, forward_port_id, + original_protocol_version, } } diff --git a/cosmwasm/ucs01-relay-api/src/protocol.rs b/cosmwasm/ucs01-relay-api/src/protocol.rs index 237b72c687..9205f8b72a 100644 --- a/cosmwasm/ucs01-relay-api/src/protocol.rs +++ b/cosmwasm/ucs01-relay-api/src/protocol.rs @@ -401,6 +401,12 @@ pub trait TransferProtocol { tokens: Vec, sequence: u64, ) -> Result>, Vec<(&str, String)>)>, Self::Error>; + + fn convert_foreign_protocol_ack( + &self, + foreign_protocol: &str, + ack: Binary, + ) -> Result; } #[cfg(test)] diff --git a/cosmwasm/ucs01-relay/src/error.rs b/cosmwasm/ucs01-relay/src/error.rs index 9b94dcce7c..7cbbbd389b 100644 --- a/cosmwasm/ucs01-relay/src/error.rs +++ b/cosmwasm/ucs01-relay/src/error.rs @@ -1,6 +1,6 @@ use std::string::FromUtf8Error; -use cosmwasm_std::{IbcOrder, OverflowError, StdError, SubMsgResult}; +use cosmwasm_std::{Binary, IbcOrder, OverflowError, StdError, SubMsgResult}; use cw_controllers::AdminError; use thiserror::Error; use ucs01_relay_api::{middleware::MiddlewareError, protocol::ProtocolError, types::EncodingError}; @@ -65,6 +65,9 @@ pub enum ContractError { #[error("{0}")] MiddlewareError(#[from] MiddlewareError), + + #[error("invalid ack ({0})")] + InvalidAck(Binary), } impl From for ContractError { diff --git a/cosmwasm/ucs01-relay/src/protocol.rs b/cosmwasm/ucs01-relay/src/protocol.rs index 00137dcffd..01bb4b4238 100644 --- a/cosmwasm/ucs01-relay/src/protocol.rs +++ b/cosmwasm/ucs01-relay/src/protocol.rs @@ -610,6 +610,7 @@ impl<'a> TransferProtocol for Ics20Protocol<'a> { timeout, forward.channel.value(), forward.port.value(), + Self::VERSION.to_string(), ); if let Some(reply_sub) = transfer @@ -653,6 +654,8 @@ impl<'a> TransferProtocol for Ics20Protocol<'a> { let (mut ack_msgs, mut ack_attr, ack_def) = match ack { Ok(value) => { + let value = self + .convert_foreign_protocol_ack(&refund_info.original_protocol_version, value)?; let value_string = value.to_string(); ( self.send_tokens_success(sender, &String::new(), tokens)?, @@ -713,6 +716,30 @@ impl<'a> TransferProtocol for Ics20Protocol<'a> { Ok(Some((ack_msgs, ack_attr))) } + + fn convert_foreign_protocol_ack( + &self, + foreign_protocol: &str, + ack: Binary, + ) -> Result { + match foreign_protocol { + Ucs01Protocol::VERSION => { + let ack: Ucs01Ack = ack + .clone() + .try_into() + .map_err(|_| ContractError::InvalidAck(ack))?; + + match ack { + Ucs01Ack::Failure => Self::ack_success(), + Ucs01Ack::Success => Self::ack_failure("ucs01 ack failure".to_string()), + } + .try_into() + .map_err(Into::into) + } + Ics20Protocol::VERSION => Ok(ack), + _ => Err(ContractError::InvalidAck(ack)), + } + } } pub struct Ucs01Protocol<'a> { @@ -977,6 +1004,7 @@ impl<'a> TransferProtocol for Ucs01Protocol<'a> { timeout, forward.channel.value(), forward.port.value(), + Self::VERSION.to_string(), ); if let Some(reply_sub) = transfer @@ -1021,6 +1049,8 @@ impl<'a> TransferProtocol for Ucs01Protocol<'a> { let (mut ack_msgs, mut ack_attr, ack_def) = match ack { Ok(value) => { let value_string = value.to_string(); + let value = self + .convert_foreign_protocol_ack(&refund_info.original_protocol_version, value)?; ( self.send_tokens_success(sender, &String::new().as_bytes().into(), tokens)?, Vec::from_iter( @@ -1080,6 +1110,30 @@ impl<'a> TransferProtocol for Ucs01Protocol<'a> { Ok(Some((ack_msgs, ack_attr))) } + + fn convert_foreign_protocol_ack( + &self, + foreign_protocol: &str, + ack: Binary, + ) -> Result { + match foreign_protocol { + Ucs01Protocol::VERSION => Ok(ack), + Ics20Protocol::VERSION => { + let ack: Ics20Ack = ack + .clone() + .try_into() + .map_err(|_| ContractError::InvalidAck(ack))?; + + match ack { + Ics20Ack::Result(_) => Self::ack_success(), + Ics20Ack::Error(_) => Self::ack_failure(String::new()), + } + .try_into() + .map_err(Into::into) + } + _ => Err(ContractError::InvalidAck(ack)), + } + } } #[cfg(test)]