From 9a7a5814ed76e9b610e834c03b16ce76b544e37d Mon Sep 17 00:00:00 2001 From: Ryan Levick Date: Tue, 19 Sep 2023 23:25:22 +0200 Subject: [PATCH] Do proper type checking for type handles. Instead of relying purely on the assumption that type handles can be compared cheaply by pointer equality, fallback to a more expensive walk of the type tree that recursively compares types structurally. This allows different components to call into each other as long as their types are structurally equivalent. --- crates/wasmtime/src/component/types.rs | 285 +++++++++++++++++++++++-- tests/all/component_model/import.rs | 61 ++++++ 2 files changed, 326 insertions(+), 20 deletions(-) diff --git a/crates/wasmtime/src/component/types.rs b/crates/wasmtime/src/component/types.rs index 43b25a13c388..462d2ade43bb 100644 --- a/crates/wasmtime/src/component/types.rs +++ b/crates/wasmtime/src/component/types.rs @@ -3,14 +3,15 @@ use crate::component::matching::InstanceType; use crate::component::values::{self, Val}; use anyhow::{anyhow, Result}; +use std::collections::{HashMap, HashSet}; use std::fmt; use std::mem; use std::ops::Deref; use std::sync::Arc; use wasmtime_environ::component::{ CanonicalAbiInfo, ComponentTypes, InterfaceType, ResourceIndex, TypeEnumIndex, TypeFlagsIndex, - TypeListIndex, TypeOptionIndex, TypeRecordIndex, TypeResultIndex, TypeTupleIndex, - TypeVariantIndex, + TypeListIndex, TypeOptionIndex, TypeRecordIndex, TypeResourceTableIndex, TypeResultIndex, + TypeTupleIndex, TypeVariantIndex, }; use wasmtime_environ::PrimaryMap; @@ -56,6 +57,29 @@ impl Handle { resources: &self.resources, } } + + fn equivalent<'a>( + &'a self, + other: &'a Self, + type_check: fn(&TypeChecker<'a>, T, T) -> bool, + ) -> bool + where + T: PartialEq + Copy, + { + (self.index == other.index + && Arc::ptr_eq(&self.types, &other.types) + && Arc::ptr_eq(&self.resources, &other.resources)) + || type_check( + &TypeChecker { + a_types: &self.types, + b_types: &other.types, + a_resource: &self.resources, + b_resource: &self.resources, + }, + self.index, + other.index, + ) + } } impl fmt::Debug for Handle { @@ -66,23 +90,188 @@ impl fmt::Debug for Handle { } } -impl PartialEq for Handle { - fn eq(&self, other: &Self) -> bool { - // FIXME: This is an overly-restrictive definition of equality in that it doesn't consider types to be - // equal unless they refer to the same declaration in the same component. It's a good shortcut for the - // common case, but we should also do a recursive structural equality test if the shortcut test fails. - self.index == other.index - && Arc::ptr_eq(&self.types, &other.types) - && Arc::ptr_eq(&self.resources, &other.resources) - } +/// Type checker for two +struct TypeChecker<'a> { + a_types: &'a ComponentTypes, + a_resource: &'a PrimaryMap, + b_types: &'a ComponentTypes, + b_resource: &'a PrimaryMap, } -impl Eq for Handle {} +impl TypeChecker<'_> { + fn interface_types_equal(&self, a: InterfaceType, b: InterfaceType) -> bool { + match (a, b) { + (InterfaceType::Own(o1), InterfaceType::Own(o2)) => self.resources_equal(o1, o2), + (InterfaceType::Own(_), _) => false, + (InterfaceType::Borrow(b1), InterfaceType::Borrow(b2)) => self.resources_equal(b1, b2), + (InterfaceType::Borrow(_), _) => false, + (InterfaceType::List(l1), InterfaceType::List(l2)) => self.lists_equal(l1, l2), + (InterfaceType::List(_), _) => false, + (InterfaceType::Record(r1), InterfaceType::Record(r2)) => self.records_equal(r1, r2), + (InterfaceType::Record(_), _) => false, + (InterfaceType::Variant(v1), InterfaceType::Variant(v2)) => self.variants_equal(v1, v2), + (InterfaceType::Variant(_), _) => false, + (InterfaceType::Result(r1), InterfaceType::Result(r2)) => self.results_equal(r1, r2), + (InterfaceType::Result(_), _) => false, + (InterfaceType::Option(o1), InterfaceType::Option(o2)) => self.options_equal(o1, o2), + (InterfaceType::Option(_), _) => false, + (InterfaceType::Enum(e1), InterfaceType::Enum(e2)) => self.enums_equal(e1, e2), + (InterfaceType::Enum(_), _) => false, + (InterfaceType::Tuple(t1), InterfaceType::Tuple(t2)) => self.tuples_equal(t1, t2), + (InterfaceType::Tuple(_), _) => false, + (InterfaceType::Flags(f1), InterfaceType::Flags(f2)) => self.flags_equal(f1, f2), + (InterfaceType::Flags(_), _) => false, + (InterfaceType::Bool, InterfaceType::Bool) => true, + (InterfaceType::Bool, _) => false, + (InterfaceType::U8, InterfaceType::U8) => true, + (InterfaceType::U8, _) => false, + (InterfaceType::U16, InterfaceType::U16) => true, + (InterfaceType::U16, _) => false, + (InterfaceType::U32, InterfaceType::U32) => true, + (InterfaceType::U32, _) => false, + (InterfaceType::U64, InterfaceType::U64) => true, + (InterfaceType::U64, _) => false, + (InterfaceType::S8, InterfaceType::S8) => true, + (InterfaceType::S8, _) => false, + (InterfaceType::S16, InterfaceType::S16) => true, + (InterfaceType::S16, _) => false, + (InterfaceType::S32, InterfaceType::S32) => true, + (InterfaceType::S32, _) => false, + (InterfaceType::S64, InterfaceType::S64) => true, + (InterfaceType::S64, _) => false, + (InterfaceType::Float32, InterfaceType::Float32) => true, + (InterfaceType::Float32, _) => false, + (InterfaceType::Float64, InterfaceType::Float64) => true, + (InterfaceType::Float64, _) => false, + (InterfaceType::String, InterfaceType::String) => true, + (InterfaceType::String, _) => false, + (InterfaceType::Char, InterfaceType::Char) => true, + (InterfaceType::Char, _) => false, + } + } + + fn lists_equal(&self, l1: TypeListIndex, l2: TypeListIndex) -> bool { + let a = &self.a_types[l1]; + let b = &self.b_types[l2]; + self.interface_types_equal(a.element, b.element) + } + + fn resources_equal(&self, o1: TypeResourceTableIndex, o2: TypeResourceTableIndex) -> bool { + let a = &self.a_types[o1]; + let b = &self.b_types[o2]; + self.a_resource[a.ty] == self.b_resource[b.ty] + } + + fn records_equal(&self, r1: TypeRecordIndex, r2: TypeRecordIndex) -> bool { + let a = &self.a_types[r1]; + let b = &self.b_types[r2]; + if a.fields.len() != b.fields.len() { + return false; + } + let b_fields = b + .fields + .iter() + .map(|f| (&f.name, &f.ty)) + .collect::>(); + a.fields.iter().all(|a_field| { + let Some(&&b_field_ty) = b_fields.get(&a_field.name) else { + return false; + }; + self.interface_types_equal(a_field.ty, b_field_ty) + }) + } + + fn variants_equal(&self, v1: TypeVariantIndex, v2: TypeVariantIndex) -> bool { + let a = &self.a_types[v1]; + let b = &self.b_types[v2]; + if a.cases.len() != b.cases.len() { + return false; + } + let b_cases = b + .cases + .iter() + .map(|f| (&f.name, &f.ty)) + .collect::>>(); + a.cases.iter().all(|a_case| { + let Some(&&b_case_ty) = b_cases.get(&a_case.name) else { + return false; + }; + match (a_case.ty, b_case_ty) { + (Some(a_case_ty), Some(b_case_ty)) => { + self.interface_types_equal(a_case_ty, b_case_ty) + } + (None, None) => true, + _ => false, + } + }) + } + + fn results_equal(&self, r1: TypeResultIndex, r2: TypeResultIndex) -> bool { + let a = &self.a_types[r1]; + let b = &self.b_types[r2]; + let oks = match (a.ok, b.ok) { + (Some(ok1), Some(ok2)) => self.interface_types_equal(ok1, ok2), + (None, None) => true, + _ => false, + }; + if !oks { + return false; + } + match (a.err, b.err) { + (Some(err1), Some(err2)) => self.interface_types_equal(err1, err2), + (None, None) => true, + _ => false, + } + } + + fn options_equal(&self, o1: TypeOptionIndex, o2: TypeOptionIndex) -> bool { + let a = &self.a_types[o1]; + let b = &self.b_types[o2]; + self.interface_types_equal(a.ty, b.ty) + } + + fn enums_equal(&self, e1: TypeEnumIndex, e2: TypeEnumIndex) -> bool { + let a = &self.a_types[e1]; + let b = &self.b_types[e2]; + if a.names.len() != b.names.len() { + return false; + } + + let b_names = b.names.iter().collect::>(); + a.names.iter().all(|a_name| b_names.contains(a_name)) + } + + fn tuples_equal(&self, t1: TypeTupleIndex, t2: TypeTupleIndex) -> bool { + let a = &self.a_types[t1]; + let b = &self.b_types[t2]; + if a.types.len() != b.types.len() { + return false; + } + a.types + .iter() + .zip(b.types.iter()) + .all(|(&a, &b)| self.interface_types_equal(a, b)) + } + + fn flags_equal(&self, f1: TypeFlagsIndex, f2: TypeFlagsIndex) -> bool { + let a = &self.a_types[f1]; + let b = &self.b_types[f2]; + a.names == b.names + } +} /// A `list` interface type -#[derive(Clone, PartialEq, Eq, Debug)] +#[derive(Clone, Debug)] pub struct List(Handle); +impl PartialEq for List { + fn eq(&self, other: &Self) -> bool { + self.0.equivalent(&other.0, TypeChecker::lists_equal) + } +} + +impl Eq for List {} + impl List { /// Instantiate this type with the specified `values`. pub fn new_val(&self, values: Box<[Val]>) -> Result { @@ -108,7 +297,7 @@ pub struct Field<'a> { } /// A `record` interface type -#[derive(Clone, PartialEq, Eq, Debug)] +#[derive(Clone, Debug)] pub struct Record(Handle); impl Record { @@ -130,8 +319,16 @@ impl Record { } } +impl PartialEq for Record { + fn eq(&self, other: &Self) -> bool { + self.0.equivalent(&other.0, TypeChecker::records_equal) + } +} + +impl Eq for Record {} + /// A `tuple` interface type -#[derive(Clone, PartialEq, Eq, Debug)] +#[derive(Clone, Debug)] pub struct Tuple(Handle); impl Tuple { @@ -153,6 +350,14 @@ impl Tuple { } } +impl PartialEq for Tuple { + fn eq(&self, other: &Self) -> bool { + self.0.equivalent(&other.0, TypeChecker::tuples_equal) + } +} + +impl Eq for Tuple {} + /// A case declaration belonging to a `variant` pub struct Case<'a> { /// The name of the case @@ -162,7 +367,7 @@ pub struct Case<'a> { } /// A `variant` interface type -#[derive(Clone, PartialEq, Eq, Debug)] +#[derive(Clone, Debug)] pub struct Variant(Handle); impl Variant { @@ -187,8 +392,16 @@ impl Variant { } } +impl PartialEq for Variant { + fn eq(&self, other: &Self) -> bool { + self.0.equivalent(&other.0, TypeChecker::variants_equal) + } +} + +impl Eq for Variant {} + /// An `enum` interface type -#[derive(Clone, PartialEq, Eq, Debug)] +#[derive(Clone, Debug)] pub struct Enum(Handle); impl Enum { @@ -210,8 +423,16 @@ impl Enum { } } +impl PartialEq for Enum { + fn eq(&self, other: &Self) -> bool { + self.0.equivalent(&other.0, TypeChecker::enums_equal) + } +} + +impl Eq for Enum {} + /// An `option` interface type -#[derive(Clone, PartialEq, Eq, Debug)] +#[derive(Clone, Debug)] pub struct OptionType(Handle); impl OptionType { @@ -230,8 +451,16 @@ impl OptionType { } } +impl PartialEq for OptionType { + fn eq(&self, other: &Self) -> bool { + self.0.equivalent(&other.0, TypeChecker::options_equal) + } +} + +impl Eq for OptionType {} + /// An `expected` interface type -#[derive(Clone, PartialEq, Eq, Debug)] +#[derive(Clone, Debug)] pub struct ResultType(Handle); impl ResultType { @@ -261,8 +490,16 @@ impl ResultType { } } +impl PartialEq for ResultType { + fn eq(&self, other: &Self) -> bool { + self.0.equivalent(&other.0, TypeChecker::results_equal) + } +} + +impl Eq for ResultType {} + /// A `flags` interface type -#[derive(Clone, PartialEq, Eq, Debug)] +#[derive(Clone, Debug)] pub struct Flags(Handle); impl Flags { @@ -288,6 +525,14 @@ impl Flags { } } +impl PartialEq for Flags { + fn eq(&self, other: &Self) -> bool { + self.0.equivalent(&other.0, TypeChecker::flags_equal) + } +} + +impl Eq for Flags {} + /// Represents a component model interface type #[derive(Clone, PartialEq, Eq, Debug)] #[allow(missing_docs)] diff --git a/tests/all/component_model/import.rs b/tests/all/component_model/import.rs index 38e23edd4f95..e77b9a73433b 100644 --- a/tests/all/component_model/import.rs +++ b/tests/all/component_model/import.rs @@ -924,3 +924,64 @@ fn no_actual_wasm_code() -> Result<()> { Ok(()) } + +#[test] +fn use_types_across_component_boundaries() -> Result<()> { + // Create a component that exports a function that returns a record + let engine = super::engine(); + let component = Component::new( + &engine, + r#"(component + (type (;0;) (record (field "a" u8) (field "b" string))) + (import "my-record" (type $my-record (eq 0))) + (core module $m + (memory $memory 17) + (export "memory" (memory $memory)) + (func (export "my-func") (result i32) + i32.const 4 + return)) + (core instance $instance (instantiate $m)) + (type $func-type (func (result $my-record))) + (alias core export $instance "my-func" (core func $my-func)) + (alias core export $instance "memory" (core memory $memory)) + (func $my-func (type $func-type) (canon lift (core func $my-func) (memory $memory) string-encoding=utf8)) + (export $export "my-func" (func $my-func)) + )"#, + )?; + let mut store = Store::new(&engine, 0); + let linker = Linker::new(&engine); + let instance = linker.instantiate(&mut store, &component)?; + let my_func = instance.get_func(&mut store, "my-func").unwrap(); + let mut results = vec![Val::Bool(false)]; + my_func.call(&mut store, &[], &mut results)?; + + // Create another component that exports a function that takes that record as an argument + let component = Component::new( + &engine, + format!( + r#"(component + (type (;0;) (record (field "a" u8) (field "b" string))) + (import "my-record" (type $my-record (eq 0))) + (core module $m + (memory $memory 17) + (export "memory" (memory $memory)) + {REALLOC_AND_FREE} + (func (export "my-func") (param i32 i32 i32))) + (core instance $instance (instantiate $m)) + (type $func-type (func (param "my-record" $my-record))) + (alias core export $instance "my-func" (core func $my-func)) + (alias core export $instance "memory" (core memory $memory)) + (func $my-func (type $func-type) (canon lift (core func $my-func) (memory $memory) string-encoding=utf8 (realloc (func $instance "realloc")))) + (export $export "my-func" (func $my-func)) + )"# + ), + )?; + let mut store = Store::new(&engine, 0); + let linker = Linker::new(&engine); + let instance = linker.instantiate(&mut store, &component)?; + let my_func = instance.get_func(&mut store, "my-func").unwrap(); + // Call the exported function with the return values of the call to the previous component's exported function + my_func.call(&mut store, &results, &mut [])?; + + Ok(()) +}