Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support multi CU workers #2345

Merged
merged 13 commits into from
Aug 20, 2024
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 9 additions & 10 deletions crates/chain-connector/src/builtins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
use crate::types::{SubnetResolveResult, TxReceiptResult, TxStatus, Worker};
use crate::types::{OnChainWorkerId, SubnetResolveResult, SubnetWorker, TxReceiptResult, TxStatus};
use crate::{ChainConnector, HttpChainConnector};
use ccp_shared::types::CUID;
use futures::FutureExt;
use particle_args::{Args, JError};
use particle_builtins::{wrap, CustomService};
Expand Down Expand Up @@ -111,14 +110,14 @@ async fn register_worker_builtin(
let mut args = args.function_args.into_iter();
let deal_id: DealId = Args::next("deal_id", &mut args)?;
let worker_id: WorkerId = Args::next("worker_id", &mut args)?;
let cu_ids: Vec<CUID> = Args::next("cu_id", &mut args)?;
let onchain_worker_id: OnChainWorkerId = Args::next("onchain_worker_id", &mut args)?;

if cu_ids.len() != 1 {
return Err(JError::new("Only one cu_id is allowed"));
if onchain_worker_id.is_empty() {
return Err(JError::new("Invalid onchain_worker_id: empty"));
}

let tx_hash = connector
.register_worker(&deal_id, worker_id, cu_ids[0])
.register_worker(&deal_id, worker_id, onchain_worker_id)
.await
.map_err(|err| JError::new(format!("Failed to register worker: {err}")))?;
Ok(json!(tx_hash))
Expand Down Expand Up @@ -164,18 +163,18 @@ async fn resolve_subnet_builtin(
let deal_id: String = Args::next("deal_id", &mut args.function_args.into_iter())?;
let deal_id = DealId::from(deal_id);

let workers: eyre::Result<Vec<Worker>> = try {
let workers: eyre::Result<Vec<SubnetWorker>> = try {
if !deal_id.is_valid() {
Err(eyre::eyre!(
"Invalid deal id '{}': invalid length",
deal_id.as_str()
))?;
}

let units = connector.get_deal_compute_units(&deal_id).await?;
let workers: Result<Vec<Worker>, _> = units
let workers = connector.get_deal_workers(&deal_id).await?;
let workers: Result<Vec<SubnetWorker>, _> = workers
.into_iter()
.map(|unit| Worker::try_from(unit))
.map(|worker| SubnetWorker::try_from(worker))
.collect();
workers?
};
Expand Down
179 changes: 126 additions & 53 deletions crates/chain-connector/src/connector.rs

Large diffs are not rendered by default.

21 changes: 13 additions & 8 deletions crates/chain-connector/src/function/deal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/

use alloy_primitives::FixedBytes;
use alloy_sol_types::{sol, SolType};
use hex_utils::decode_hex;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -44,12 +44,13 @@ sol! {
SMALL_BALANCE
}

struct ComputeUnit {
bytes32 id;
bytes32 workerId;
struct Worker {
bytes32 offchainId;
bytes32 onchainId;
bytes32 peerId;
address provider;
uint256 joinedEpoch;
bytes32[] computeUnitIds;
}

/// @dev Returns the status of the deal
Expand All @@ -58,14 +59,18 @@ sol! {
/// @dev Returns the app CID
function appCID() external view returns (CIDV1 memory);

/// @dev Set worker ID for a compute unit. Compute unit can have only one worker ID
function setWorker(bytes32 computeUnitId, bytes32 workerId) external;
/// @dev Set offchain worker ID for a corresponding onchain worker for a deal
function activateWorker(bytes32 onchainId, bytes32 offchainId);

/// @dev Returns the compute units info by provider
function getComputeUnits() public view returns (ComputeUnit[] memory);
/// @dev Removes worker from the deal
function removeWorker(bytes32 onchainId) external;
/// @dev Returns workers
function getWorkers() external view returns (Worker[] memory);
}
}

pub type OnChainWorkerID = FixedBytes<32>;

impl CIDV1 {
pub fn from_hex(hex: &str) -> Result<Self, ConnectorError> {
let bytes = decode_hex(hex)?;
Expand Down
31 changes: 17 additions & 14 deletions crates/chain-connector/src/function/offer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,13 @@ sol! {
bytes32 id;
address deal;
uint256 startEpoch;
bytes32 onchainWorkerId;
}

/// @dev Returns the compute peer info
function getComputePeer(bytes32 peerId) external view returns (ComputePeer memory);
/// @dev Returns the compute units info of a peer
function getComputeUnits(bytes32 peerId) external view returns (ComputeUnit[] memory);

/// @dev Return the compute unit from a deal
function returnComputeUnitFromDeal(bytes32 unitId) external;
}
}

Expand Down Expand Up @@ -75,12 +73,12 @@ impl From<ComputeUnit> for PendingUnit {
mod tests {
use crate::Offer::ComputePeer;
use alloy_primitives::{hex, U256};
use alloy_sol_types::SolType;
use alloy_sol_types::SolValue;
use hex_utils::decode_hex;

#[tokio::test]
async fn decode_compute_unit() {
let data = "aa3046a12a1aac6e840625e6329d70b427328fec36dc8d273e5e6454b85633d50000000000000000000000005e3d0fde6f793b3115a9e7f5ebc195bbeed35d6c00000000000000000000000000000000000000000000000000000000000003e8";
#[test]
fn decode_compute_unit() {
justprosh marked this conversation as resolved.
Show resolved Hide resolved
let data = "aa3046a12a1aac6e840625e6329d70b427328fec36dc8d273e5e6454b85633d50000000000000000000000005e3d0fde6f793b3115a9e7f5ebc195bbeed35d6c00000000000000000000000000000000000000000000000000000000000003e8bb3046a12a1aac6e840625e6329d70b427328fec36dc8d273e5e6454b85633dd";
let compute_unit = super::ComputeUnit::abi_decode(&decode_hex(data).unwrap(), true);
assert!(compute_unit.is_ok());
let compute_unit = compute_unit.unwrap();
Expand All @@ -95,11 +93,15 @@ mod tests {
"0x5e3d0fde6f793b3115a9e7f5ebc195bbeed35d6c"
);
assert_eq!(compute_unit.startEpoch, U256::from(1000));
assert_eq!(
compute_unit.onchainWorkerId,
hex!("bb3046a12a1aac6e840625e6329d70b427328fec36dc8d273e5e6454b85633dd")
)
}

#[tokio::test]
async fn decode_compute_unit_no_deal() {
let data = "aa3046a12a1aac6e840625e6329d70b427328fec36dc8d273e5e6454b85633d5000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000003e8";
#[test]
fn decode_compute_unit_no_deal_no_worker() {
let data = "aa3046a12a1aac6e840625e6329d70b427328fec36dc8d273e5e6454b85633d5000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000003e80000000000000000000000000000000000000000000000000000000000000000";
let compute_unit = super::ComputeUnit::abi_decode(&decode_hex(data).unwrap(), true);
assert!(compute_unit.is_ok());
let compute_unit = compute_unit.unwrap();
Expand All @@ -109,10 +111,11 @@ mod tests {
);
assert!(compute_unit.deal.is_zero());
assert_eq!(compute_unit.startEpoch, U256::from(1000));
assert!(compute_unit.onchainWorkerId.is_zero())
}

#[tokio::test]
async fn decode_compute_peer_no_commitment() {
#[test]
fn decode_compute_peer_no_commitment() {
let data = "0xaa3046a12a1aac6e840625e6329d70b427328fec36dc8d273e5e6454b85633d5000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000020000000000000000000000005b73c5498c1e3b4dba84de0f1833c4a029d90519";
let compute_peer = ComputePeer::abi_decode(&decode_hex(data).unwrap(), true);
assert!(compute_peer.is_ok());
Expand All @@ -129,8 +132,8 @@ mod tests {
);
}

#[tokio::test]
async fn decode_compute_peer() {
#[test]
fn decode_compute_peer() {
let data = "0xaa3046a12a1aac6e840625e6329d70b427328fec36dc8d273e5e6454b85633d5aa3046a12a1aac6e840625e6329d70b427328feceedc8d273e5e6454b85633b5000000000000000000000000000000000000000000000000000000000000000a0000000000000000000000005b73c5498c1e3b4dba84de0f1833c4a029d90519";
let compute_peer = ComputePeer::abi_decode(&decode_hex(data).unwrap(), true);
assert!(compute_peer.is_ok());
Expand Down
32 changes: 19 additions & 13 deletions crates/chain-connector/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
*/
use crate::error::ConnectorError;
use crate::function::Deal;
use crate::Deal::ComputeUnit;
use alloy_primitives::U256;
use ccp_shared::types::{Difficulty, GlobalNonce, CUID};
use chain_data::parse_peer_id;
Expand Down Expand Up @@ -56,11 +55,14 @@ impl DealResult {
}
}

pub type OnChainWorkerId = Vec<u8>;

#[derive(Debug, Serialize, Deserialize)]
pub struct DealInfo {
pub status: Deal::Status,
pub unit_ids: Vec<CUID>,
pub cu_ids: Vec<CUID>,
pub app_cid: String,
pub onchain_worker_id: OnChainWorkerId,
}

#[derive(Debug, Serialize, Deserialize)]
Expand Down Expand Up @@ -142,28 +144,32 @@ impl RawTxReceipt {
}

#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct Worker {
pub struct SubnetWorker {
pub cu_ids: Vec<String>,
pub host_id: String,
pub worker_id: Vec<String>,
}

impl TryFrom<ComputeUnit> for Worker {
impl TryFrom<Deal::Worker> for SubnetWorker {
type Error = Report;
fn try_from(unit: ComputeUnit) -> eyre::Result<Self> {
fn try_from(deal_worker: Deal::Worker) -> eyre::Result<Self> {
let mut worker_id = vec![];
if !unit.workerId.is_zero() {
let w_id = parse_peer_id(&unit.workerId.0)
.map_err(|err| eyre!("Failed to parse unit.workerId: {err}"))?
if !deal_worker.offchainId.is_zero() {
let w_id = parse_peer_id(&deal_worker.offchainId.0)
.map_err(|err| eyre!("Failed to parse worker.offchainId: {err}"))?
.to_base58();
worker_id.push(w_id)
}
let cu_id = unit.id.to_string();
let peer_id = parse_peer_id(&unit.peerId.0)
.map_err(|err| eyre!("Failed to parse unit.peerId: {err}"))?;
let cu_ids = deal_worker
.computeUnitIds
.into_iter()
.map(|id| id.to_string())
.collect();
let peer_id = parse_peer_id(&deal_worker.peerId.0)
.map_err(|err| eyre!("Failed to parse worker.peerId: {err}"))?;

Ok(Self {
cu_ids: vec![cu_id],
cu_ids,
host_id: peer_id.to_base58(),
worker_id,
})
Expand All @@ -173,7 +179,7 @@ impl TryFrom<ComputeUnit> for Worker {
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct SubnetResolveResult {
pub success: bool,
pub workers: Vec<Worker>,
pub workers: Vec<SubnetWorker>,
pub error: Vec<String>,
}

Expand Down
2 changes: 1 addition & 1 deletion crates/chain-listener/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,4 @@ serde_with = { workspace = true }
[dev-dependencies]
jsonrpsee = { workspace = true, features = ["server"] }
tempfile = { workspace = true }

bs58 = { workspace = true }
Loading
Loading