diff --git a/crates/core/src/host_component.rs b/crates/core/src/host_component.rs index f476961c72..462e201dcc 100644 --- a/crates/core/src/host_component.rs +++ b/crates/core/src/host_component.rs @@ -1,6 +1,11 @@ -use std::{any::Any, marker::PhantomData, sync::Arc}; +use std::{ + any::{type_name, Any, TypeId}, + collections::HashMap, + marker::PhantomData, + sync::Arc, +}; -use anyhow::Result; +use anyhow::{bail, Result}; use super::{Data, Linker}; @@ -131,6 +136,7 @@ type BoxHostComponent = Box; #[derive(Default)] pub struct HostComponentsBuilder { + handles: HashMap, host_components: Vec, } @@ -140,7 +146,17 @@ impl HostComponentsBuilder { linker: &mut Linker, host_component: HC, ) -> Result> { + let type_id = TypeId::of::(); + if self.handles.contains_key(&type_id) { + bail!( + "already have a host component of type {}", + type_name::() + ) + } + let handle = AnyHostComponentDataHandle(self.host_components.len()); + self.handles.insert(type_id, handle); + self.host_components.push(Box::new(host_component)); HC::add_to_linker(linker, move |data| { data.host_components_data @@ -156,12 +172,14 @@ impl HostComponentsBuilder { pub fn build(self) -> HostComponents { HostComponents { + handles: self.handles, host_components: Arc::new(self.host_components), } } } pub struct HostComponents { + handles: HashMap, host_components: Arc>, } @@ -180,6 +198,12 @@ impl HostComponents { host_components: self.host_components.clone(), } } + + pub fn find_handle(&self) -> Option> { + self.handles + .get(&TypeId::of::()) + .map(|handle| HostComponentDataHandle::from_any(*handle)) + } } type AnyData = Box; @@ -265,4 +289,17 @@ mod tests { hc_data.set(handle2, 1); assert_eq!(hc_data.get_or_insert(handle2), &1); } + + #[test] + fn find_handle() { + let engine = wasmtime::Engine::default(); + let mut linker: crate::Linker<()> = crate::Linker::new(&engine); + + let mut builder = HostComponents::builder(); + builder.add_host_component(&mut linker, TestHC).unwrap(); + let host_components = builder.build(); + let handle = host_components.find_handle::().unwrap(); + let mut hc_data = host_components.new_data(); + assert_eq!(hc_data.get_or_insert(handle), &0); + } } diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index cfa20b72e0..993a9d8e2b 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -262,6 +262,13 @@ impl Engine { let inner = Arc::new(self.module_linker.instantiate_pre(module)?); Ok(ModuleInstancePre { inner }) } + + /// Find the [`HostComponentDataHandle`] for a [`HostComponent`] if configured for this engine. + pub fn find_host_component_handle( + &self, + ) -> Option> { + self.host_components.find_handle() + } } impl AsRef for Engine { diff --git a/crates/core/tests/integration_test.rs b/crates/core/tests/integration_test.rs index 07ab4c7842..953d7ed927 100644 --- a/crates/core/tests/integration_test.rs +++ b/crates/core/tests/integration_test.rs @@ -132,12 +132,10 @@ async fn test_host_component() { #[tokio::test(flavor = "multi_thread")] async fn test_host_component_data_update() { - // Need to build Engine separately to get the HostComponentDataHandle - let mut engine_builder = Engine::builder(&test_config()).unwrap(); - let factor_data_handle = engine_builder - .add_host_component(MultiplierHostComponent) + let engine = test_engine(); + let multiplier_handle = engine + .find_host_component_handle::() .unwrap(); - let engine: Engine<()> = engine_builder.build(); let stdout = run_core_wasi_test_engine( &engine, @@ -145,7 +143,7 @@ async fn test_host_component_data_update() { |store_builder| { store_builder .host_components_data() - .set(factor_data_handle, Multiplier(100)); + .set(multiplier_handle, Multiplier(100)); }, |_| {}, )