diff --git a/crates/libs/interface/src/lib.rs b/crates/libs/interface/src/lib.rs index 090581fd37..908e38a137 100644 --- a/crates/libs/interface/src/lib.rs +++ b/crates/libs/interface/src/lib.rs @@ -211,6 +211,17 @@ impl Interface { let parent_vtable_generics = if self.parent_is_iunknown() { quote!(Identity, OFFSET) } else { quote!(Identity, Impl, OFFSET) }; let parent_vtable = self.parent_vtable(); + // or_parent_matches will be `|| parent::matches(iid)` if this interface inherits from another + // interface (except for IUnknown) or will be empty if this is not applicable. This is what allows + // QueryInterface to work correctly for all interfaces in an inheritance chain, e.g. + // IFoo3 derives from IFoo2 derives from IFoo. + // + // We avoid matching IUnknown because object identity depends on the uniqueness of the IUnknown pointer. + let or_parent_matches = match parent_vtable.as_ref() { + Some(parent) if !self.parent_is_iunknown() => quote! (|| <#parent>::matches(iid)), + _ => quote!(), + }; + let functions = self .methods .iter() @@ -287,8 +298,10 @@ impl Interface { Self { base__: #parent_vtable::new::<#parent_vtable_generics>(), #(#entries),* } } - pub fn matches(iid: &windows_core::GUID) -> bool { - iid == &<#name as ::windows_core::Interface>::IID + #[inline(always)] + pub fn matches(iid: &::windows_core::GUID) -> bool { + *iid == <#name as ::windows_core::Interface>::IID + #or_parent_matches } } } diff --git a/crates/tests/implement_core/src/com_chain.rs b/crates/tests/implement_core/src/com_chain.rs new file mode 100644 index 0000000000..5124b390e4 --- /dev/null +++ b/crates/tests/implement_core/src/com_chain.rs @@ -0,0 +1,50 @@ +use windows_core::*; + +#[interface("cccccccc-0000-0000-0000-000000000001")] +unsafe trait IFoo: IUnknown {} + +#[interface("cccccccc-0000-0000-0000-000000000002")] +unsafe trait IFoo2: IFoo {} + +#[interface("cccccccc-0000-0000-0000-000000000003")] +unsafe trait IFoo3: IFoo2 {} + +// ObjectA implements a single interface chain, which consists of 3 different +// interfaces: IFoo3, IFoo2, and IFoo. You do not need to explicitly list all +// of the interfaces in the interface chain. Listing all of the interfaces is +// less efficient because it generates redundant interface chains (pointer +// fields in the the generated ObjectA_Impl type), which will never be used. +#[implement(IFoo3)] +struct ObjectWithChains {} + +impl IFoo_Impl for ObjectWithChains {} +impl IFoo2_Impl for ObjectWithChains {} +impl IFoo3_Impl for ObjectWithChains {} + +#[test] +fn interface_chain_query() { + let object = ComObject::new(ObjectWithChains {}); + let unknown: IUnknown = object.to_interface(); + let _foo: IFoo = unknown.cast().expect("QueryInterface for IFoo"); + let _foo2: IFoo2 = unknown.cast().expect("QueryInterface for IFoo2"); + let _foo3: IFoo3 = unknown.cast().expect("QueryInterface for IFoo3"); +} + +// ObjectRedundantChains implements the same interfaces as ObjectWithChains, +// but it defines more than one interface chain. This is unnecessary because it +// is redundant, but we are verifying that this works. +#[implement(IFoo3, IFoo2, IFoo)] +struct ObjectRedundantChains {} + +impl IFoo_Impl for ObjectRedundantChains {} +impl IFoo2_Impl for ObjectRedundantChains {} +impl IFoo3_Impl for ObjectRedundantChains {} + +#[test] +fn redundant_interface_chains() { + let object = ComObject::new(ObjectRedundantChains {}); + let unknown: IUnknown = object.to_interface(); + let _foo: IFoo = unknown.cast().expect("QueryInterface for IFoo"); + let _foo2: IFoo2 = unknown.cast().expect("QueryInterface for IFoo2"); + let _foo3: IFoo3 = unknown.cast().expect("QueryInterface for IFoo3"); +} diff --git a/crates/tests/implement_core/src/lib.rs b/crates/tests/implement_core/src/lib.rs index e083fca649..aa8f3bec53 100644 --- a/crates/tests/implement_core/src/lib.rs +++ b/crates/tests/implement_core/src/lib.rs @@ -3,4 +3,5 @@ #![cfg(test)] +mod com_chain; mod com_object;