From 866ec46613a16f85979618b7d1291168a40a8ae3 Mon Sep 17 00:00:00 2001
From: Alex Crichton
Date: Mon, 8 Aug 2022 13:01:45 -0500
Subject: [PATCH] Implement roundtrip fuzzing of component adapters (#4640)

* Improve the `component_api` fuzzer on a few dimensions

  * Update the generated component to use an adapter module. This involves
    two core wasm instances communicating with each other to test that data
    flows through everything correctly. The intention here is to fuzz the
    fused adapter compiler. String encoding options have been plumbed here
    to exercise differences in string encodings.

  * Use `Cow<'static, ...>` and `static` declarations for each static test
    case to try to cut down on rustc codegen time.

  * Add `Copy` to the derivation of fuzzed enums to make `derive(Clone)`
    smaller.

  * Use `Store>` to try to cut down on codegen by monomorphizing fewer
    `Store` implementations.

  * Add debug logging to print out what's flowing in and what's flowing out
    for debugging failures.

  * Improve the `Debug` representation of dynamic value types to more
    closely match their Rust counterparts.

* Fix a variant issue with adapter trampolines

  Previously the offset of the payload was calculated as the discriminant
  size aligned up to the alignment of a single case, but it instead needs to
  be aligned up to the maximum alignment across all cases to ensure that
  every case's payload starts at the same location.

* Fix a copy/paste error when copying masked integers

  A 32-bit load was actually doing a 16-bit load by accident since it was
  copied from the 16-bit load-and-mask case.

* Fix f32/i64 conversions in adapter modules

  The adapter previously erroneously converted the f32 to f64 and then to
  i64, where instead it should go from f32 to i32 to i64.

* Fix zero-sized flags in adapter modules

  This commit corrects the size calculation for zero-sized flags in adapter
  modules.

  cc #4592

* Fix a variant size calculation bug in adapters

  This fixes the same issue found with variants during normal host-side
  fuzzing earlier: the size of a variant needs to be the sum of the
  discriminant size and the maximum case size, aligned up to the variant's
  alignment.

* Implement memory growth in libc bump realloc

  Some fuzz-generated test cases are copying lists large enough to exceed
  one page of memory, so bake a `memory.grow` into the bump allocator as
  well.

* Avoid adapters of exponential size

  This commit is an attempt to avoid adapters being exponentially sized with
  respect to the type hierarchy of the input. Previously all adaptation was
  done inline within each adapter, which meant that if something was
  structured as a `tuple` of N `T`s, the translation of `T` would be inlined
  N times. For very deeply nested types this can quickly create an
  exponentially sized adapter with types of the form:

      (type $t0 (list u8))
      (type $t1 (tuple $t0 $t0))
      (type $t2 (tuple $t1 $t1))
      (type $t3 (tuple $t2 $t2))
      ;; ...

  where the translation of `$t3` contains 8 different copies of the
  translation of `$t0`. This commit changes the translation of types through
  memory to almost always go through a helper function. The hope here is
  that it doesn't lose too much performance because the types already reside
  in memory.

  This can still lead to exponentially sized adapter modules to a lesser
  degree: if the translation all happens on the "stack", e.g. via `variant`s
  and their flat representation, then many copies of one translation could
  still be made. For now this commit at least gets the problem under control
  for fuzzing, where the fuzzer no longer trivially finds type hierarchies
  that take over a minute to codegen the adapter module.
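  As a rough illustration of that deduplication (a minimal standalone sketch
  with hypothetical stand-in types, not the code added by this patch), helper
  functions are memoized by their source/destination type pair so that each
  translation body is emitted at most once per adapter module:

      use std::collections::HashMap;

      // Hypothetical stand-ins for `InterfaceType`, `FunctionId`, and
      // `Module` from `crates/environ/src/fact.rs`.
      #[derive(Clone, PartialEq, Eq, Hash)]
      enum Ty {
          U8,
          List(Box<Ty>),
      }

      #[derive(Copy, Clone)]
      struct FunctionId(u32);

      #[derive(Default)]
      struct Module {
          // Generated helper bodies (placeholders here).
          funcs: Vec<String>,
          translate_mem_funcs: HashMap<(Ty, Ty), FunctionId>,
      }

      impl Module {
          // Return the helper that copies a `src`-typed value in memory to a
          // `dst`-typed value, generating its body only the first time this
          // (src, dst) pair is requested.
          fn translate_helper(&mut self, src: &Ty, dst: &Ty) -> FunctionId {
              let key = (src.clone(), dst.clone());
              if let Some(id) = self.translate_mem_funcs.get(&key) {
                  return *id;
              }
              let id = FunctionId(self.funcs.len() as u32);
              self.funcs.push(String::from("..body translating src to dst.."));
              self.translate_mem_funcs.insert(key, id);
              id
          }
      }

      fn main() {
          let mut m = Module::default();
          let list_u8 = Ty::List(Box::new(Ty::U8));
          // Both halves of a `(tuple $t0 $t0)` request the same `(list u8)`
          // helper, but only one copy of its body is generated.
          m.translate_helper(&list_u8, &list_u8);
          m.translate_helper(&list_u8, &list_u8);
          assert_eq!(m.funcs.len(), 1);
      }

  In the actual implementation the memoization key also includes the lift and
  lower `Options` for each side, since the two instances may differ in string
  encoding, memory, and pointer width.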
One of the main tricky parts of this implementation is that when a function is generated the index that it will be placed at in the final module is not known at that time. To solve this the encoded form of the `Call` instruction is saved in a relocation-style format where the `Call` isn't encoded but instead saved into a different area for encoding later. When the entire adapter module is encoded to wasm these pseudo-`Call` instructions are encoded as real instructions at that time. * Fix some memory64 issues with string encodings Introduced just before #4623 I had a few mistakes related to 64-bit memories and mixing 32/64-bit memories. * Actually insert into the `translate_mem_funcs` map This... was the whole point of having the map! * Assert memory growth succeeds in bump allocator --- crates/component-util/src/lib.rs | 34 +- crates/environ/src/fact.rs | 348 +++++++--- crates/environ/src/fact/signature.rs | 69 +- crates/environ/src/fact/trampoline.rs | 611 ++++++++++++------ crates/environ/src/fact/transcode.rs | 36 +- .../fuzzing/src/generators/component_types.rs | 37 +- crates/fuzzing/src/oracles.rs | 19 +- crates/misc/component-fuzz-util/src/lib.rs | 127 +++- crates/wasmtime/src/component/values.rs | 63 +- fuzz/build.rs | 18 +- .../misc_testsuite/component-model/fused.wast | 8 +- 11 files changed, 918 insertions(+), 452 deletions(-) diff --git a/crates/component-util/src/lib.rs b/crates/component-util/src/lib.rs index 2c4a2a461dc9..827b168cb6cd 100644 --- a/crates/component-util/src/lib.rs +++ b/crates/component-util/src/lib.rs @@ -123,11 +123,35 @@ pub const REALLOC_AND_FREE: &str = r#" ;; save the current value of `$last` as the return value global.get $last - local.tee $ret + local.set $ret + + ;; bump our pointer + (global.set $last + (i32.add + (global.get $last) + (local.get $new_size))) + + ;; while `memory.size` is less than `$last`, grow memory + ;; by one page + (loop $loop + (if + (i32.lt_u + (i32.mul (memory.size) (i32.const 65536)) + (global.get $last)) + (then + i32.const 1 + memory.grow + ;; test to make sure growth succeeded + i32.const -1 + i32.eq + if unreachable end + + br $loop))) + ;; ensure anything necessary is set to valid data by spraying a bit ;; pattern that is invalid - global.get $last + local.get $ret i32.const 0xde local.get $new_size memory.fill @@ -142,10 +166,6 @@ pub const REALLOC_AND_FREE: &str = r#" memory.copy end - ;; bump our pointer - (global.set $last - (i32.add - (global.get $last) - (local.get $new_size))) + local.get $ret ) "#; diff --git a/crates/environ/src/fact.rs b/crates/environ/src/fact.rs index 3cb16a68c800..19f8a899be87 100644 --- a/crates/environ/src/fact.rs +++ b/crates/environ/src/fact.rs @@ -19,10 +19,13 @@ //! that. use crate::component::dfg::CoreDef; -use crate::component::{Adapter, AdapterOptions, ComponentTypes, StringEncoding, TypeFuncIndex}; -use crate::{FuncIndex, GlobalIndex, MemoryIndex, PrimaryMap}; +use crate::component::{ + Adapter, AdapterOptions as AdapterOptionsDfg, ComponentTypes, InterfaceType, StringEncoding, + TypeFuncIndex, +}; +use crate::fact::transcode::Transcoder; +use crate::{EntityRef, FuncIndex, GlobalIndex, MemoryIndex, PrimaryMap}; use std::collections::HashMap; -use std::mem; use wasm_encoder::*; mod core_types; @@ -50,26 +53,28 @@ pub struct Module<'a> { /// Final list of imports that this module ended up using, in the same order /// as the imports in the import section. imports: Vec, - /// Intern'd imports and what index they were assigned. 
- imported: HashMap, - imported_memories: PrimaryMap, + /// Intern'd imports and what index they were assigned. Note that this map + /// covers all the index spaces for imports, not just one. + imported: HashMap, + /// Intern'd transcoders and what index they were assigned. + imported_transcoders: HashMap, // Current status of index spaces from the imports generated so far. - core_funcs: u32, - core_memories: u32, - core_globals: u32, + imported_funcs: PrimaryMap>, + imported_memories: PrimaryMap, + imported_globals: PrimaryMap, - /// Adapters which will be compiled once they're all registered. - adapters: Vec, + funcs: PrimaryMap, + translate_mem_funcs: HashMap<(InterfaceType, InterfaceType, Options, Options), FunctionId>, } struct AdapterData { /// Export name of this adapter name: String, /// Options specified during the `canon lift` operation - lift: Options, + lift: AdapterOptions, /// Options specified during the `canon lower` operation - lower: Options, + lower: AdapterOptions, /// The core wasm function that this adapter will be calling (the original /// function that was `canon lift`'d) callee: FuncIndex, @@ -78,14 +83,38 @@ struct AdapterData { called_as_export: bool, } -struct Options { +/// Configuration options which apply at the "global adapter" level. +/// +/// These options are typically unique per-adapter and generally aren't needed +/// when translating recursive types within an adapter. +struct AdapterOptions { + /// The ascribed type of this adapter. ty: TypeFuncIndex, - string_encoding: StringEncoding, + /// The global that represents the instance flags for where this adapter + /// came from. flags: GlobalIndex, + /// The configured post-return function, if any. + post_return: Option, + /// Other, more general, options configured. + options: Options, +} + +/// This type is split out of `AdapterOptions` and is specifically used to +/// deduplicate translation functions within a module. Consequently this has +/// as few fields as possible to minimize the number of functions generated +/// within an adapter module. +#[derive(PartialEq, Eq, Hash, Copy, Clone)] +struct Options { + /// The encoding that strings use from this adapter. + string_encoding: StringEncoding, + /// Whether or not the `memory` field, if present, is a 64-bit memory. memory64: bool, + /// An optionally-specified memory where values may travel through for + /// types like lists. memory: Option, + /// An optionally-specified function to be used to allocate space for + /// types such as strings as they go into a module. realloc: Option, - post_return: Option, } enum Context { @@ -102,12 +131,13 @@ impl<'a> Module<'a> { core_types: Default::default(), core_imports: Default::default(), imported: Default::default(), - adapters: Default::default(), imports: Default::default(), + imported_transcoders: Default::default(), + imported_funcs: PrimaryMap::new(), imported_memories: PrimaryMap::new(), - core_funcs: 0, - core_memories: 0, - core_globals: 0, + imported_globals: PrimaryMap::new(), + funcs: PrimaryMap::new(), + translate_mem_funcs: HashMap::new(), } } @@ -128,7 +158,7 @@ impl<'a> Module<'a> { // Import the core wasm function which was lifted using its appropriate // signature since the exported function this adapter generates will // call the lifted function. 
- let signature = self.signature(&lift, Context::Lift); + let signature = self.types.signature(&lift, Context::Lift); let ty = self .core_types .function(&signature.params, &signature.results); @@ -141,19 +171,24 @@ impl<'a> Module<'a> { self.import_func("post_return", name, ty, func.clone()) }); - self.adapters.push(AdapterData { - name: name.to_string(), - lift, - lower, - callee, - // FIXME(#4185) should be plumbed and handled as part of the new - // reentrance rules not yet implemented here. - called_as_export: true, - }); + // This will internally create the adapter as specified and append + // anything necessary to `self.funcs`. + trampoline::compile( + self, + &AdapterData { + name: name.to_string(), + lift, + lower, + callee, + // FIXME(#4185) should be plumbed and handled as part of the new + // reentrance rules not yet implemented here. + called_as_export: true, + }, + ); } - fn import_options(&mut self, ty: TypeFuncIndex, options: &AdapterOptions) -> Options { - let AdapterOptions { + fn import_options(&mut self, ty: TypeFuncIndex, options: &AdapterOptionsDfg) -> AdapterOptions { + let AdapterOptionsDfg { instance, string_encoding, memory, @@ -192,23 +227,24 @@ impl<'a> Module<'a> { let ty = self.core_types.function(&[ptr, ptr, ptr, ptr], &[ptr]); self.import_func("realloc", "", ty, func.clone()) }); - Options { + + AdapterOptions { ty, - string_encoding: *string_encoding, flags, - memory64: *memory64, - memory, - realloc, post_return: None, + options: Options { + string_encoding: *string_encoding, + memory64: *memory64, + memory, + realloc, + }, } } fn import_func(&mut self, module: &str, name: &str, ty: u32, def: CoreDef) -> FuncIndex { - FuncIndex::from_u32( - self.import(module, name, EntityType::Function(ty), def, |m| { - &mut m.core_funcs - }), - ) + self.import(module, name, EntityType::Function(ty), def, |m| { + &mut m.imported_funcs + }) } fn import_global( @@ -218,9 +254,9 @@ impl<'a> Module<'a> { ty: GlobalType, def: CoreDef, ) -> GlobalIndex { - GlobalIndex::from_u32(self.import(module, name, EntityType::Global(ty), def, |m| { - &mut m.core_globals - })) + self.import(module, name, EntityType::Global(ty), def, |m| { + &mut m.imported_globals + }) } fn import_memory( @@ -230,82 +266,113 @@ impl<'a> Module<'a> { ty: MemoryType, def: CoreDef, ) -> MemoryIndex { - MemoryIndex::from_u32(self.import(module, name, EntityType::Memory(ty), def, |m| { - &mut m.core_memories - })) + self.import(module, name, EntityType::Memory(ty), def, |m| { + &mut m.imported_memories + }) } - fn import( + fn import>( &mut self, module: &str, name: &str, ty: EntityType, def: CoreDef, - new: impl FnOnce(&mut Self) -> &mut u32, - ) -> u32 { + map: impl FnOnce(&mut Self) -> &mut PrimaryMap, + ) -> K { if let Some(prev) = self.imported.get(&def) { - return *prev; + return K::new(*prev); } - let cnt = new(self); - *cnt += 1; - let ret = *cnt - 1; + let idx = map(self).push(def.clone().into()); self.core_imports.import(module, name, ty); - self.imported.insert(def.clone(), ret); - if let EntityType::Memory(_) = ty { - self.imported_memories.push(def.clone()); - } + self.imported.insert(def.clone(), idx.index()); self.imports.push(Import::CoreDef(def)); - ret + idx + } + + fn import_transcoder(&mut self, transcoder: transcode::Transcoder) -> FuncIndex { + *self + .imported_transcoders + .entry(transcoder) + .or_insert_with(|| { + // Add the import to the core wasm import section... 
+ let name = transcoder.name(); + let ty = transcoder.ty(&mut self.core_types); + self.core_imports.import("transcode", &name, ty); + + // ... and also record the metadata for what this import + // corresponds to. + let from = self.imported_memories[transcoder.from_memory].clone(); + let to = self.imported_memories[transcoder.to_memory].clone(); + self.imports.push(Import::Transcode { + op: transcoder.op, + from, + from64: transcoder.from_memory64, + to, + to64: transcoder.to_memory64, + }); + + self.imported_funcs.push(None) + }) } /// Encodes this module into a WebAssembly binary. pub fn encode(&mut self) -> Vec { - let mut types = mem::take(&mut self.core_types); - let mut transcoders = transcode::Transcoders::new(self.core_funcs); - let mut adapter_funcs = Vec::new(); - for adapter in self.adapters.iter() { - adapter_funcs.push(trampoline::compile( - self, - &mut types, - &mut transcoders, - adapter, - )); - } - - // If any string transcoding imports were needed add imported items - // associated with them. - for (module, name, ty, transcoder) in transcoders.imports() { - self.core_imports.import(module, name, ty); - let from = self.imported_memories[transcoder.from_memory].clone(); - let to = self.imported_memories[transcoder.to_memory].clone(); - self.imports.push(Import::Transcode { - op: transcoder.op, - from, - from64: transcoder.from_memory64, - to, - to64: transcoder.to_memory64, - }); - self.core_funcs += 1; + // Build the function/export sections of the wasm module in a first pass + // which will assign a final `FuncIndex` to all functions defined in + // `self.funcs`. + let mut funcs = FunctionSection::new(); + let mut exports = ExportSection::new(); + let mut id_to_index = PrimaryMap::::new(); + for (id, func) in self.funcs.iter() { + assert!(func.filled_in); + let idx = FuncIndex::from_u32(self.imported_funcs.next_key().as_u32() + id.as_u32()); + let id2 = id_to_index.push(idx); + assert_eq!(id2, id); + + funcs.function(func.ty); + + if let Some(name) = &func.export { + exports.export(name, ExportKind::Func, idx.as_u32()); + } } - // Now that all functions are known as well as all imports the actual - // bodies of all adapters are assembled into a final module. - let mut funcs = FunctionSection::new(); + // With all functions numbered the fragments of the body of each + // function can be assigned into one final adapter function. let mut code = CodeSection::new(); - let mut exports = ExportSection::new(); let mut traps = traps::TrapSection::default(); - for (adapter, (function, func_traps)) in self.adapters.iter().zip(adapter_funcs) { - let idx = self.core_funcs + funcs.len(); - exports.export(&adapter.name, ExportKind::Func, idx); - - let signature = self.signature(&adapter.lower, Context::Lower); - let ty = types.function(&signature.params, &signature.results); - funcs.function(ty); - - code.raw(&function); - traps.append(idx, func_traps); + for (id, func) in self.funcs.iter() { + let mut func_traps = Vec::new(); + let mut body = Vec::new(); + + // Encode all locals used for this function + func.locals.len().encode(&mut body); + for (count, ty) in func.locals.iter() { + count.encode(&mut body); + ty.encode(&mut body); + } + + // Then encode each "chunk" of a body which may have optional traps + // specified within it. Traps get offset by the current length of + // the body and otherwise our `Call` instructions are "relocated" + // here to the final function index. 
+ for chunk in func.body.iter() { + match chunk { + Body::Raw(code, traps) => { + let start = body.len(); + body.extend_from_slice(code); + for (offset, trap) in traps { + func_traps.push((start + offset, *trap)); + } + } + Body::Call(id) => { + Instruction::Call(id_to_index[*id].as_u32()).encode(&mut body); + } + } + } + code.raw(&body); + traps.append(id_to_index[id].as_u32(), func_traps); } - self.core_types = types; + let traps = traps.finish(); let mut result = wasm_encoder::Module::new(); @@ -367,3 +434,82 @@ impl Options { } } } + +/// Temporary index which is not the same as `FuncIndex`. +/// +/// This represents the nth generated function in the adapter module where the +/// final index of the function is not known at the time of generation since +/// more imports may be discovered (specifically string transcoders). +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +struct FunctionId(u32); +cranelift_entity::entity_impl!(FunctionId); + +/// A generated function to be added to an adapter module. +/// +/// At least one function is created per-adapter and dependeing on the type +/// hierarchy multiple functions may be generated per-adapter. +struct Function { + /// Whether or not the `body` has been finished. + /// + /// Functions are added to a `Module` before they're defined so this is used + /// to assert that the function was in fact actually filled in by the + /// time we reach `Module::encode`. + filled_in: bool, + + /// The type signature that this function has, as an index into the core + /// wasm type index space of the generated adapter module. + ty: u32, + + /// The locals that are used by this function, organized by the number of + /// types of each local. + locals: Vec<(u32, ValType)>, + + /// If specified, the export name of this function. + export: Option, + + /// The contents of the function. + /// + /// See `Body` for more information, and the `Vec` here represents the + /// concatentation of all the `Body` fragments. + body: Vec, +} + +/// Representation of a fragment of the body of a core wasm function generated +/// for adapters. +/// +/// This variant comes in one of two flavors: +/// +/// 1. First a `Raw` variant is used to contain general instructions for the +/// wasm function. This is populated by `Compiler::instruction` primarily. +/// This also comes with a list of traps. and the byte offset within the +/// first vector of where the trap information applies to. +/// +/// 2. A `Call` instruction variant for a `FunctionId` where the final +/// `FuncIndex` isn't known until emission time. +/// +/// The purpose of this representation is the `Body::Call` variant. This can't +/// be encoded as an instruction when it's generated due to not knowing the +/// final index of the function being called. During `Module::encode`, however, +/// all indices are known and `Body::Call` is turned into a final +/// `Instruction::Call`. +/// +/// One other possible representation in the future would be to encode a `Call` +/// instruction with a 5-byte leb to fill in later, but for now this felt +/// easier to represent. A 5-byte leb may be more efficient at compile-time if +/// necessary, however. 
+enum Body { + Raw(Vec, Vec<(usize, traps::Trap)>), + Call(FunctionId), +} + +impl Function { + fn new(export: Option, ty: u32) -> Function { + Function { + filled_in: false, + ty, + locals: Vec::new(), + export, + body: Vec::new(), + } + } +} diff --git a/crates/environ/src/fact/signature.rs b/crates/environ/src/fact/signature.rs index 27313f13c956..f6b8b3fb7345 100644 --- a/crates/environ/src/fact/signature.rs +++ b/crates/environ/src/fact/signature.rs @@ -1,8 +1,9 @@ //! Size, align, and flattening information about component model types. -use crate::component::{InterfaceType, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS}; -use crate::fact::{Context, Module, Options}; +use crate::component::{ComponentTypes, InterfaceType, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS}; +use crate::fact::{AdapterOptions, Context, Options}; use wasm_encoder::ValType; +use wasmtime_component_util::{DiscriminantSize, FlagsSize}; /// Metadata about a core wasm signature which is created for a component model /// signature. @@ -27,25 +28,25 @@ pub(crate) fn align_to(n: usize, align: usize) -> usize { (n + (align - 1)) & !(align - 1) } -impl Module<'_> { +impl ComponentTypes { /// Calculates the core wasm function signature for the component function /// type specified within `Context`. /// /// This is used to generate the core wasm signatures for functions that are /// imported (matching whatever was `canon lift`'d) and functions that are /// exported (matching the generated function from `canon lower`). - pub(super) fn signature(&self, options: &Options, context: Context) -> Signature { - let ty = &self.types[options.ty]; - let ptr_ty = options.ptr(); + pub(super) fn signature(&self, options: &AdapterOptions, context: Context) -> Signature { + let ty = &self[options.ty]; + let ptr_ty = options.options.ptr(); - let mut params = self.flatten_types(options, ty.params.iter().map(|(_, ty)| *ty)); + let mut params = self.flatten_types(&options.options, ty.params.iter().map(|(_, ty)| *ty)); let mut params_indirect = false; if params.len() > MAX_FLAT_PARAMS { params = vec![ptr_ty]; params_indirect = true; } - let mut results = self.flatten_types(options, [ty.result]); + let mut results = self.flatten_types(&options.options, [ty.result]); let mut results_indirect = false; if results.len() > MAX_FLAT_RESULTS { results_indirect = true; @@ -108,17 +109,17 @@ impl Module<'_> { dst.push(opts.ptr()); } InterfaceType::Record(r) => { - for field in self.types[*r].fields.iter() { + for field in self[*r].fields.iter() { self.push_flat(opts, &field.ty, dst); } } InterfaceType::Tuple(t) => { - for ty in self.types[*t].types.iter() { + for ty in self[*t].types.iter() { self.push_flat(opts, ty, dst); } } InterfaceType::Flags(f) => { - let flags = &self.types[*f]; + let flags = &self[*f]; let nflags = align_to(flags.names.len(), 32) / 32; for _ in 0..nflags { dst.push(ValType::I32); @@ -127,13 +128,13 @@ impl Module<'_> { InterfaceType::Enum(_) => dst.push(ValType::I32), InterfaceType::Option(t) => { dst.push(ValType::I32); - self.push_flat(opts, &self.types[*t], dst); + self.push_flat(opts, &self[*t], dst); } InterfaceType::Variant(t) => { dst.push(ValType::I32); let pos = dst.len(); let mut tmp = Vec::new(); - for case in self.types[*t].cases.iter() { + for case in self[*t].cases.iter() { self.push_flat_variant(opts, &case.ty, pos, &mut tmp, dst); } } @@ -141,13 +142,13 @@ impl Module<'_> { dst.push(ValType::I32); let pos = dst.len(); let mut tmp = Vec::new(); - for ty in self.types[*t].types.iter() { + for ty in self[*t].types.iter() { 
self.push_flat_variant(opts, ty, pos, &mut tmp, dst); } } InterfaceType::Expected(t) => { dst.push(ValType::I32); - let e = &self.types[*t]; + let e = &self[*t]; let pos = dst.len(); let mut tmp = Vec::new(); self.push_flat_variant(opts, &e.ok, pos, &mut tmp, dst); @@ -208,26 +209,26 @@ impl Module<'_> { } InterfaceType::Record(r) => { - self.record_size_align(opts, self.types[*r].fields.iter().map(|f| &f.ty)) + self.record_size_align(opts, self[*r].fields.iter().map(|f| &f.ty)) } - InterfaceType::Tuple(t) => self.record_size_align(opts, self.types[*t].types.iter()), - InterfaceType::Flags(f) => match self.types[*f].names.len() { - n if n <= 8 => (1, 1), - n if n <= 16 => (2, 2), - n if n <= 32 => (4, 4), - n => (4 * (align_to(n, 32) / 32), 4), + InterfaceType::Tuple(t) => self.record_size_align(opts, self[*t].types.iter()), + InterfaceType::Flags(f) => match FlagsSize::from_count(self[*f].names.len()) { + FlagsSize::Size0 => (0, 1), + FlagsSize::Size1 => (1, 1), + FlagsSize::Size2 => (2, 2), + FlagsSize::Size4Plus(n) => (n * 4, 4), }, - InterfaceType::Enum(t) => self.discrim_size_align(self.types[*t].names.len()), + InterfaceType::Enum(t) => self.discrim_size_align(self[*t].names.len()), InterfaceType::Option(t) => { - let ty = &self.types[*t]; + let ty = &self[*t]; self.variant_size_align(opts, [&InterfaceType::Unit, ty].into_iter()) } InterfaceType::Variant(t) => { - self.variant_size_align(opts, self.types[*t].cases.iter().map(|c| &c.ty)) + self.variant_size_align(opts, self[*t].cases.iter().map(|c| &c.ty)) } - InterfaceType::Union(t) => self.variant_size_align(opts, self.types[*t].types.iter()), + InterfaceType::Union(t) => self.variant_size_align(opts, self[*t].types.iter()), InterfaceType::Expected(t) => { - let e = &self.types[*t]; + let e = &self[*t]; self.variant_size_align(opts, [&e.ok, &e.err].into_iter()) } } @@ -260,14 +261,18 @@ impl Module<'_> { payload_size = payload_size.max(csize); align = align.max(calign); } - (align_to(discrim_size, align) + payload_size, align) + ( + align_to(align_to(discrim_size, align) + payload_size, align), + align, + ) } fn discrim_size_align<'a>(&self, cases: usize) -> (usize, usize) { - match cases { - n if n <= u8::MAX as usize => (1, 1), - n if n <= u16::MAX as usize => (2, 2), - _ => (4, 4), + match DiscriminantSize::from_count(cases) { + Some(DiscriminantSize::Size1) => (1, 1), + Some(DiscriminantSize::Size2) => (2, 2), + Some(DiscriminantSize::Size4) => (4, 4), + None => unreachable!(), } } } diff --git a/crates/environ/src/fact/trampoline.rs b/crates/environ/src/fact/trampoline.rs index b5337229a2f8..bc9095129f5d 100644 --- a/crates/environ/src/fact/trampoline.rs +++ b/crates/environ/src/fact/trampoline.rs @@ -16,16 +16,15 @@ //! can be somewhat arbitrary, an intentional decision. 
use crate::component::{ - InterfaceType, StringEncoding, TypeEnumIndex, TypeExpectedIndex, TypeFlagsIndex, - TypeInterfaceIndex, TypeRecordIndex, TypeTupleIndex, TypeUnionIndex, TypeVariantIndex, - FLAG_MAY_ENTER, FLAG_MAY_LEAVE, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS, + ComponentTypes, InterfaceType, StringEncoding, TypeEnumIndex, TypeExpectedIndex, + TypeFlagsIndex, TypeInterfaceIndex, TypeRecordIndex, TypeTupleIndex, TypeUnionIndex, + TypeVariantIndex, FLAG_MAY_ENTER, FLAG_MAY_LEAVE, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS, }; -use crate::fact::core_types::CoreTypes; use crate::fact::signature::{align_to, Signature}; -use crate::fact::transcode::{FixedEncoding as FE, Transcode, Transcoder, Transcoders}; +use crate::fact::transcode::{FixedEncoding as FE, Transcode, Transcoder}; use crate::fact::traps::Trap; -use crate::fact::{AdapterData, Context, Module, Options}; -use crate::GlobalIndex; +use crate::fact::{AdapterData, Body, Context, Function, FunctionId, Module, Options}; +use crate::{FuncIndex, GlobalIndex}; use std::collections::HashMap; use std::mem; use std::ops::Range; @@ -36,28 +35,13 @@ const MAX_STRING_BYTE_LENGTH: u32 = 1 << 31; const UTF16_TAG: u32 = 1 << 31; struct Compiler<'a, 'b> { - /// The module that the adapter will eventually be inserted into. - module: &'a Module<'a>, - - /// The type section of `module` - types: &'b mut CoreTypes, - - /// Imported functions to transcode between various string encodings. - transcoders: &'b mut Transcoders, - - /// Metadata about the adapter that is being compiled. - adapter: &'a AdapterData, + types: &'a ComponentTypes, + module: &'b mut Module<'a>, + result: FunctionId, /// The encoded WebAssembly function body so far, not including locals. code: Vec, - /// Generated locals that this function will use. - /// - /// The first entry in the tuple is the number of locals and the second - /// entry is the type of those locals. This is pushed during compilation as - /// locals become necessary. - locals: Vec<(u32, ValType)>, - /// Total number of locals generated so far. nlocals: u32, @@ -66,36 +50,90 @@ struct Compiler<'a, 'b> { /// well. traps: Vec<(usize, Trap)>, - /// The function signature of the lowered half of this trampoline, or the - /// signature of the function that's being generated. - lower_sig: &'a Signature, - - /// The function signature of the lifted half of this trampoline, or the - /// signature of the function that's imported the trampoline will call. - lift_sig: &'a Signature, + /// Indicates whether this call to `translate` is a "top level" on where + /// it's the first call from the root of the generated function. This is + /// used as a heuristic to know when to split helpers out to a separate + /// function. 
+ top_level_translate: bool, } -pub(super) fn compile( - module: &Module<'_>, - types: &mut CoreTypes, - transcoders: &mut Transcoders, - adapter: &AdapterData, -) -> (Vec, Vec<(usize, Trap)>) { - let lower_sig = &module.signature(&adapter.lower, Context::Lower); - let lift_sig = &module.signature(&adapter.lift, Context::Lift); +pub(super) fn compile(module: &mut Module<'_>, adapter: &AdapterData) { + let lower_sig = module.types.signature(&adapter.lower, Context::Lower); + let lift_sig = module.types.signature(&adapter.lift, Context::Lift); + let ty = module + .core_types + .function(&lower_sig.params, &lower_sig.results); + let result = module + .funcs + .push(Function::new(Some(adapter.name.clone()), ty)); Compiler { + types: module.types, module, - types, - adapter, - transcoders, code: Vec::new(), - locals: Vec::new(), nlocals: lower_sig.params.len() as u32, traps: Vec::new(), - lower_sig, - lift_sig, + result, + top_level_translate: true, } - .compile() + .compile_adapter(adapter, &lower_sig, &lift_sig) +} + +/// Compiles a helper function which is used to translate `src` to `dst` +/// in-memory. +/// +/// The generated function takes two arguments: the source pointer and +/// destination pointer. The conversion operation is configured by the +/// `src_opts` and `dst_opts` specified as well. +fn compile_translate_mem( + module: &mut Module<'_>, + src: InterfaceType, + src_opts: &Options, + dst: InterfaceType, + dst_opts: &Options, +) -> FunctionId { + // If a helper for this translation has already been generated then reuse + // that. Note that this is key to this function where by doing this it + // prevents an exponentially sized output given any particular input type. + let key = (src, dst, *src_opts, *dst_opts); + if module.translate_mem_funcs.contains_key(&key) { + return module.translate_mem_funcs[&key]; + } + + // Generate a fresh `Function` with a unique id for what we're about to + // generate. + let ty = module + .core_types + .function(&[src_opts.ptr(), dst_opts.ptr()], &[]); + let result = module.funcs.push(Function::new(None, ty)); + module.translate_mem_funcs.insert(key, result); + let mut compiler = Compiler { + types: module.types, + module, + code: Vec::new(), + nlocals: 2, + traps: Vec::new(), + result, + top_level_translate: true, + }; + // This function only does one thing which is to translate between memory, + // so only one call to `translate` is necessary. Note that the `addr_local` + // values come from the function arguments. + compiler.translate( + &src, + &Source::Memory(Memory { + opts: src_opts, + addr_local: 0, + offset: 0, + }), + &dst, + &Destination::Memory(Memory { + opts: dst_opts, + addr_local: 1, + offset: 0, + }), + ); + compiler.finish(); + result } /// Possible ways that a interface value is represented in the core wasm @@ -150,19 +188,24 @@ struct Memory<'a> { } impl Compiler<'_, '_> { - fn compile(&mut self) -> (Vec, Vec<(usize, Trap)>) { + fn compile_adapter( + mut self, + adapter: &AdapterData, + lower_sig: &Signature, + lift_sig: &Signature, + ) { // Check the instance flags required for this trampoline. // // This inserts the initial check required by `canon_lower` that the // caller instance can be left and additionally checks the // flags on the callee if necessary whether it can be entered. 
- self.trap_if_not_flag(self.adapter.lower.flags, FLAG_MAY_LEAVE, Trap::CannotLeave); - if self.adapter.called_as_export { - self.trap_if_not_flag(self.adapter.lift.flags, FLAG_MAY_ENTER, Trap::CannotEnter); - self.set_flag(self.adapter.lift.flags, FLAG_MAY_ENTER, false); + self.trap_if_not_flag(adapter.lower.flags, FLAG_MAY_LEAVE, Trap::CannotLeave); + if adapter.called_as_export { + self.trap_if_not_flag(adapter.lift.flags, FLAG_MAY_ENTER, Trap::CannotEnter); + self.set_flag(adapter.lift.flags, FLAG_MAY_ENTER, false); } else if self.module.debug { self.assert_not_flag( - self.adapter.lift.flags, + adapter.lift.flags, FLAG_MAY_ENTER, "may_enter should be unset", ); @@ -180,23 +223,22 @@ impl Compiler<'_, '_> { // TODO: if translation doesn't actually call any functions in either // instance then there's no need to set/clear the flag here and that can // be optimized away. - self.set_flag(self.adapter.lift.flags, FLAG_MAY_LEAVE, false); - let param_locals = self - .lower_sig + self.set_flag(adapter.lift.flags, FLAG_MAY_LEAVE, false); + let param_locals = lower_sig .params .iter() .enumerate() .map(|(i, ty)| (i as u32, *ty)) .collect::>(); - self.translate_params(¶m_locals); - self.set_flag(self.adapter.lift.flags, FLAG_MAY_LEAVE, true); + self.translate_params(adapter, ¶m_locals); + self.set_flag(adapter.lift.flags, FLAG_MAY_LEAVE, true); // With all the arguments on the stack the actual target function is // now invoked. The core wasm results of the function are then placed // into locals for result translation afterwards. - self.instruction(Call(self.adapter.callee.as_u32())); - let mut result_locals = Vec::with_capacity(self.lift_sig.results.len()); - for ty in self.lift_sig.results.iter().rev() { + self.instruction(Call(adapter.callee.as_u32())); + let mut result_locals = Vec::with_capacity(lift_sig.results.len()); + for ty in lift_sig.results.iter().rev() { let local = self.gen_local(*ty); self.instruction(LocalSet(local)); result_locals.push((local, *ty)); @@ -211,77 +253,75 @@ impl Compiler<'_, '_> { // // TODO: like above the management of the `MAY_LEAVE` flag can probably // be elided here for "simple" results. - self.set_flag(self.adapter.lower.flags, FLAG_MAY_LEAVE, false); - self.translate_results(¶m_locals, &result_locals); - self.set_flag(self.adapter.lower.flags, FLAG_MAY_LEAVE, true); + self.set_flag(adapter.lower.flags, FLAG_MAY_LEAVE, false); + self.translate_results(adapter, ¶m_locals, &result_locals); + self.set_flag(adapter.lower.flags, FLAG_MAY_LEAVE, true); // And finally post-return state is handled here once all results/etc // are all translated. 
- if let Some(func) = self.adapter.lift.post_return { + if let Some(func) = adapter.lift.post_return { for (result, _) in result_locals.iter() { self.instruction(LocalGet(*result)); } self.instruction(Call(func.as_u32())); } - if self.adapter.called_as_export { - self.set_flag(self.adapter.lift.flags, FLAG_MAY_ENTER, true); + if adapter.called_as_export { + self.set_flag(adapter.lift.flags, FLAG_MAY_ENTER, true); } self.finish() } - fn translate_params(&mut self, param_locals: &[(u32, ValType)]) { - let src_tys = &self.module.types[self.adapter.lower.ty].params; + fn translate_params(&mut self, adapter: &AdapterData, param_locals: &[(u32, ValType)]) { + let src_tys = &self.types[adapter.lower.ty].params; let src_tys = src_tys.iter().map(|(_, ty)| *ty).collect::>(); - let dst_tys = &self.module.types[self.adapter.lift.ty].params; + let dst_tys = &self.types[adapter.lift.ty].params; let dst_tys = dst_tys.iter().map(|(_, ty)| *ty).collect::>(); + let lift_opts = &adapter.lift.options; + let lower_opts = &adapter.lower.options; // TODO: handle subtyping assert_eq!(src_tys.len(), dst_tys.len()); let src_flat = self - .module - .flatten_types(&self.adapter.lower, src_tys.iter().copied()); - let dst_flat = self - .module - .flatten_types(&self.adapter.lift, dst_tys.iter().copied()); + .types + .flatten_types(lower_opts, src_tys.iter().copied()); + let dst_flat = self.types.flatten_types(lift_opts, dst_tys.iter().copied()); let src = if src_flat.len() <= MAX_FLAT_PARAMS { Source::Stack(Stack { locals: ¶m_locals[..src_flat.len()], - opts: &self.adapter.lower, + opts: lower_opts, }) } else { // If there are too many parameters then that means the parameters // are actually a tuple stored in linear memory addressed by the // first parameter local. let (addr, ty) = param_locals[0]; - assert_eq!(ty, self.adapter.lower.ptr()); + assert_eq!(ty, lower_opts.ptr()); let align = src_tys .iter() - .map(|t| self.module.align(&self.adapter.lower, t)) + .map(|t| self.types.align(lower_opts, t)) .max() .unwrap_or(1); - Source::Memory(self.memory_operand(&self.adapter.lower, addr, align)) + Source::Memory(self.memory_operand(lower_opts, addr, align)) }; let dst = if dst_flat.len() <= MAX_FLAT_PARAMS { - Destination::Stack(&dst_flat, &self.adapter.lift) + Destination::Stack(&dst_flat, lift_opts) } else { // If there are too many parameters then space is allocated in the // destination module for the parameters via its `realloc` function. 
- let (size, align) = self - .module - .record_size_align(&self.adapter.lift, dst_tys.iter()); + let (size, align) = self.types.record_size_align(lift_opts, dst_tys.iter()); let size = MallocSize::Const(size); - Destination::Memory(self.malloc(&self.adapter.lift, size, align)) + Destination::Memory(self.malloc(lift_opts, size, align)) }; let srcs = src - .record_field_srcs(self.module, src_tys.iter().copied()) + .record_field_srcs(self.types, src_tys.iter().copied()) .zip(src_tys.iter()); let dsts = dst - .record_field_dsts(self.module, dst_tys.iter().copied()) + .record_field_dsts(self.types, dst_tys.iter().copied()) .zip(dst_tys.iter()); for ((src, src_ty), (dst, dst_ty)) in srcs.zip(dsts) { self.translate(&src_ty, &src, &dst_ty, &dst); @@ -297,42 +337,45 @@ impl Compiler<'_, '_> { fn translate_results( &mut self, + adapter: &AdapterData, param_locals: &[(u32, ValType)], result_locals: &[(u32, ValType)], ) { - let src_ty = self.module.types[self.adapter.lift.ty].result; - let dst_ty = self.module.types[self.adapter.lower.ty].result; + let src_ty = self.types[adapter.lift.ty].result; + let dst_ty = self.types[adapter.lower.ty].result; + let lift_opts = &adapter.lift.options; + let lower_opts = &adapter.lower.options; - let src_flat = self.module.flatten_types(&self.adapter.lift, [src_ty]); - let dst_flat = self.module.flatten_types(&self.adapter.lower, [dst_ty]); + let src_flat = self.types.flatten_types(lift_opts, [src_ty]); + let dst_flat = self.types.flatten_types(lower_opts, [dst_ty]); let src = if src_flat.len() <= MAX_FLAT_RESULTS { Source::Stack(Stack { locals: result_locals, - opts: &self.adapter.lift, + opts: lift_opts, }) } else { // The original results to read from in this case come from the // return value of the function itself. The imported function will // return a linear memory address at which the values can be read // from. - let align = self.module.align(&self.adapter.lift, &src_ty); + let align = self.types.align(lift_opts, &src_ty); assert_eq!(result_locals.len(), 1); let (addr, ty) = result_locals[0]; - assert_eq!(ty, self.adapter.lift.ptr()); - Source::Memory(self.memory_operand(&self.adapter.lift, addr, align)) + assert_eq!(ty, lift_opts.ptr()); + Source::Memory(self.memory_operand(lift_opts, addr, align)) }; let dst = if dst_flat.len() <= MAX_FLAT_RESULTS { - Destination::Stack(&dst_flat, &self.adapter.lower) + Destination::Stack(&dst_flat, lower_opts) } else { // This is slightly different than `translate_params` where the // return pointer was provided by the caller of this function // meaning the last parameter local is a pointer into linear memory. - let align = self.module.align(&self.adapter.lower, &dst_ty); + let align = self.types.align(lower_opts, &dst_ty); let (addr, ty) = *param_locals.last().expect("no retptr"); - assert_eq!(ty, self.adapter.lower.ptr()); - Destination::Memory(self.memory_operand(&self.adapter.lower, addr, align)) + assert_eq!(ty, lower_opts.ptr()); + Destination::Memory(self.memory_operand(lower_opts, addr, align)) }; self.translate(&src_ty, &src, &dst_ty, &dst); @@ -351,6 +394,107 @@ impl Compiler<'_, '_> { if let Destination::Memory(mem) = dst { self.assert_aligned(dst_ty, mem); } + + // Classify the source type as "primitive" or not as a heuristic to + // whether the translation should be split out into a helper function. 
+ let src_primitive = match src_ty { + InterfaceType::Unit + | InterfaceType::Bool + | InterfaceType::U8 + | InterfaceType::S8 + | InterfaceType::U16 + | InterfaceType::S16 + | InterfaceType::U32 + | InterfaceType::S32 + | InterfaceType::U64 + | InterfaceType::S64 + | InterfaceType::Float32 + | InterfaceType::Float64 + | InterfaceType::Char + | InterfaceType::Flags(_) => true, + + InterfaceType::String + | InterfaceType::List(_) + | InterfaceType::Record(_) + | InterfaceType::Tuple(_) + | InterfaceType::Variant(_) + | InterfaceType::Union(_) + | InterfaceType::Enum(_) + | InterfaceType::Option(_) + | InterfaceType::Expected(_) => false, + }; + let top_level = mem::replace(&mut self.top_level_translate, false); + + // Use a number of heuristics to determine whether this translation + // should be split out into a helper function rather than translated + // inline. The goal of this heuristic is to avoid a function that is + // exponential in the size of a type. For example if everything + // were translated inline then this could get arbitrarily large + // + // (type $level0 (list u8)) + // (type $level1 (expected $level0 $level0)) + // (type $level2 (expected $level1 $level1)) + // (type $level3 (expected $level2 $level2)) + // (type $level4 (expected $level3 $level3)) + // ;; ... + // + // If everything we inlined then translation of `$level0` would appear + // in 2^n different locations depending on the depth of the type. By + // splitting out the translation to a helper function, though, it + // means there could be one function for each level, keeping the size + // of translation on par with the size of the module itself. + // + // The heuristics which go into this splitting currently are: + // + // * Both the source and destination must be memory. This skips "top + // level" translation for adapters where arguments/results come from + // direct parameters or get placed on the stack. + // + // * Primitive types are skipped here since they have no need to be + // split out. This is for types like integers and floats. + // + // * The "top level" of a function is also skipped. That basically + // means that the first call to `translate` will never split out + // a helper function (since if we're already in a helper function + // that could cause infinite recursion in the wasm). Otherwise + // this keeps the top-level list of types in adapters nice and inline + // too while only possibly considering splitting out deeper types. + // + // This heuristic may need tweaking over time naturally as more modules + // in the wild are seen and performance measurements are taken. For now + // this keeps the fuzzers happy by avoiding exponentially-sized output + // given an input. + if let (Source::Memory(src), Destination::Memory(dst)) = (src, dst) { + if !src_primitive && !top_level { + // Compile the helper function which will translate the source + // type to the destination type. The two parameters to this + // function are the source/destination pointers which are + // calculated here to pass through. Our own function then + // grows a `Body::Call` to the function generated. Note that + // `Body::Call` is used here instead of `Instruction::Call` + // because we don't know the final index of the generated + // function yet. It's filled in at the end of adapter module + // translation. + let helper = + compile_translate_mem(self.module, *src_ty, src.opts, *dst_ty, dst.opts); + + // TODO: overflow checks? 
+ self.instruction(LocalGet(src.addr_local)); + if src.offset != 0 { + self.ptr_uconst(src.opts, src.offset); + self.ptr_add(src.opts); + } + self.instruction(LocalGet(dst.addr_local)); + if dst.offset != 0 { + self.ptr_uconst(dst.opts, dst.offset); + self.ptr_add(dst.opts); + } + self.flush_code(); + self.module.funcs[self.result].body.push(Body::Call(helper)); + self.top_level_translate = true; + return; + } + } match src_ty { InterfaceType::Unit => self.translate_unit(src, dst_ty, dst), InterfaceType::Bool => self.translate_bool(src, dst_ty, dst), @@ -376,6 +520,8 @@ impl Compiler<'_, '_> { InterfaceType::Option(t) => self.translate_option(*t, src, dst_ty, dst), InterfaceType::Expected(t) => self.translate_expected(*t, src, dst_ty, dst), } + + self.top_level_translate = top_level; } fn translate_unit(&mut self, src: &Source<'_>, dst_ty: &InterfaceType, dst: &Destination) { @@ -504,7 +650,7 @@ impl Compiler<'_, '_> { fn convert_u32_mask(&mut self, src: &Source<'_>, dst: &Destination<'_>, mask: u32) { self.push_dst_addr(dst); match src { - Source::Memory(mem) => self.i32_load16u(mem), + Source::Memory(mem) => self.i32_load(mem), Source::Stack(stack) => self.stack_get(stack, ValType::I32), } if mask != 0xffffffff { @@ -854,7 +1000,7 @@ impl Compiler<'_, '_> { self.instruction(LocalGet(src.ptr)); self.instruction(LocalGet(src.len)); self.instruction(LocalGet(dst.ptr)); - self.instruction(Call(transcode)); + self.instruction(Call(transcode.as_u32())); dst } @@ -923,7 +1069,7 @@ impl Compiler<'_, '_> { self.instruction(LocalGet(src.len)); self.instruction(LocalGet(dst.ptr)); self.instruction(LocalGet(dst_byte_len)); - self.instruction(Call(transcode)); + self.instruction(Call(transcode.as_u32())); self.instruction(LocalSet(dst.len)); let src_len_tmp = self.gen_local(src.opts.ptr()); self.instruction(LocalSet(src_len_tmp)); @@ -978,7 +1124,7 @@ impl Compiler<'_, '_> { self.instruction(LocalGet(dst_byte_len)); self.instruction(LocalGet(dst.len)); self.ptr_sub(dst.opts); - self.instruction(Call(transcode)); + self.instruction(Call(transcode.as_u32())); // Add the second result, the amount of destination units encoded, // to `dst_len` so it's an accurate reflection of the final size of @@ -1075,7 +1221,7 @@ impl Compiler<'_, '_> { self.instruction(LocalGet(src.ptr)); self.instruction(LocalGet(src.len)); self.instruction(LocalGet(dst.ptr)); - self.instruction(Call(transcode)); + self.instruction(Call(transcode.as_u32())); self.instruction(LocalSet(dst.len)); // If the number of code units returned by transcode is not @@ -1137,14 +1283,18 @@ impl Compiler<'_, '_> { } }; - self.validate_string_inbounds(src, dst_byte_len); + let src_byte_len = self.gen_local(src.opts.ptr()); + self.convert_src_len_to_dst(dst_byte_len, dst.opts.ptr(), src.opts.ptr()); + self.instruction(LocalSet(src_byte_len)); + + self.validate_string_inbounds(src, src.len); self.validate_string_inbounds(&dst, dst_byte_len); let transcode = self.transcoder(src, &dst, Transcode::Utf16ToCompactProbablyUtf16); self.instruction(LocalGet(src.ptr)); self.instruction(LocalGet(src.len)); self.instruction(LocalGet(dst.ptr)); - self.instruction(Call(transcode)); + self.instruction(Call(transcode.as_u32())); self.instruction(LocalSet(dst.len)); // Assert that the untagged code unit length is the same as the @@ -1222,7 +1372,7 @@ impl Compiler<'_, '_> { self.instruction(LocalGet(src.ptr)); self.instruction(LocalGet(src.len)); self.instruction(LocalGet(dst.ptr)); - self.instruction(Call(transcode_latin1)); + 
self.instruction(Call(transcode_latin1.as_u32())); self.instruction(LocalSet(dst.len)); let src_len_tmp = self.gen_local(src.opts.ptr()); self.instruction(LocalSet(src_len_tmp)); @@ -1289,7 +1439,7 @@ impl Compiler<'_, '_> { self.instruction(LocalGet(dst.ptr)); self.convert_src_len_to_dst(src.len, src.opts.ptr(), dst.opts.ptr()); self.instruction(LocalGet(dst.len)); - self.instruction(Call(transcode_utf16)); + self.instruction(Call(transcode_utf16.as_u32())); self.instruction(LocalSet(dst.len)); // If the returned number of code units written to the destination @@ -1340,17 +1490,19 @@ impl Compiler<'_, '_> { self.instruction(End); } - fn transcoder(&mut self, src: &WasmString<'_>, dst: &WasmString<'_>, op: Transcode) -> u32 { - self.transcoders.import( - self.types, - Transcoder { - from_memory: src.opts.memory.unwrap(), - from_memory64: src.opts.memory64, - to_memory: dst.opts.memory.unwrap(), - to_memory64: dst.opts.memory64, - op, - }, - ) + fn transcoder( + &mut self, + src: &WasmString<'_>, + dst: &WasmString<'_>, + op: Transcode, + ) -> FuncIndex { + self.module.import_transcoder(Transcoder { + from_memory: src.opts.memory.unwrap(), + from_memory64: src.opts.memory64, + to_memory: dst.opts.memory.unwrap(), + to_memory64: dst.opts.memory64, + op, + }) } fn validate_string_inbounds(&mut self, s: &WasmString<'_>, byte_len: u32) { @@ -1386,7 +1538,7 @@ impl Compiler<'_, '_> { self.instruction(LocalTee(tmp)); self.instruction(LocalGet(s.ptr)); self.ptr_lt_u(s.opts); - self.ptr_br_if(s.opts, 0); + self.instruction(BrIf(0)); self.instruction(LocalGet(tmp)); } @@ -1408,15 +1560,15 @@ impl Compiler<'_, '_> { dst_ty: &InterfaceType, dst: &Destination, ) { - let src_element_ty = &self.module.types[src_ty]; + let src_element_ty = &self.types[src_ty]; let dst_element_ty = match dst_ty { - InterfaceType::List(r) => &self.module.types[*r], + InterfaceType::List(r) => &self.types[*r], _ => panic!("expected a list"), }; let src_opts = src.opts(); let dst_opts = dst.opts(); - let (src_size, src_align) = self.module.size_align(src_opts, src_element_ty); - let (dst_size, dst_align) = self.module.size_align(dst_opts, dst_element_ty); + let (src_size, src_align) = self.types.size_align(src_opts, src_element_ty); + let (dst_size, dst_align) = self.types.size_align(dst_opts, dst_element_ty); // Load the pointer/length of this list into temporary locals. These // will be referenced a good deal so this just makes it easier to deal @@ -1791,9 +1943,9 @@ impl Compiler<'_, '_> { dst_ty: &InterfaceType, dst: &Destination, ) { - let src_ty = &self.module.types[src_ty]; + let src_ty = &self.types[src_ty]; let dst_ty = match dst_ty { - InterfaceType::Record(r) => &self.module.types[*r], + InterfaceType::Record(r) => &self.types[*r], _ => panic!("expected a record"), }; @@ -1805,7 +1957,7 @@ impl Compiler<'_, '_> { // fields' names let mut src_fields = HashMap::new(); for (i, src) in src - .record_field_srcs(self.module, src_ty.fields.iter().map(|f| f.ty)) + .record_field_srcs(self.types, src_ty.fields.iter().map(|f| f.ty)) .enumerate() { let field = &src_ty.fields[i]; @@ -1821,7 +1973,7 @@ impl Compiler<'_, '_> { // // TODO: should that lookup be fallible with subtyping? 
for (i, dst) in dst - .record_field_dsts(self.module, dst_ty.fields.iter().map(|f| f.ty)) + .record_field_dsts(self.types, dst_ty.fields.iter().map(|f| f.ty)) .enumerate() { let field = &dst_ty.fields[i]; @@ -1837,9 +1989,9 @@ impl Compiler<'_, '_> { dst_ty: &InterfaceType, dst: &Destination, ) { - let src_ty = &self.module.types[src_ty]; + let src_ty = &self.types[src_ty]; let dst_ty = match dst_ty { - InterfaceType::Flags(r) => &self.module.types[*r], + InterfaceType::Flags(r) => &self.types[*r], _ => panic!("expected a record"), }; @@ -1863,8 +2015,8 @@ impl Compiler<'_, '_> { self.convert_u16_mask(src, dst, mask); } FlagsSize::Size4Plus(n) => { - let srcs = src.record_field_srcs(self.module, (0..n).map(|_| InterfaceType::U32)); - let dsts = dst.record_field_dsts(self.module, (0..n).map(|_| InterfaceType::U32)); + let srcs = src.record_field_srcs(self.types, (0..n).map(|_| InterfaceType::U32)); + let dsts = dst.record_field_dsts(self.types, (0..n).map(|_| InterfaceType::U32)); for (i, (src, dst)) in srcs.zip(dsts).enumerate() { let mask = if i == n - 1 && (cnt % 32 != 0) { (1 << (cnt % 32)) - 1 @@ -1884,9 +2036,9 @@ impl Compiler<'_, '_> { dst_ty: &InterfaceType, dst: &Destination, ) { - let src_ty = &self.module.types[src_ty]; + let src_ty = &self.types[src_ty]; let dst_ty = match dst_ty { - InterfaceType::Tuple(t) => &self.module.types[*t], + InterfaceType::Tuple(t) => &self.types[*t], _ => panic!("expected a tuple"), }; @@ -1894,10 +2046,10 @@ impl Compiler<'_, '_> { assert_eq!(src_ty.types.len(), dst_ty.types.len()); let srcs = src - .record_field_srcs(self.module, src_ty.types.iter().copied()) + .record_field_srcs(self.types, src_ty.types.iter().copied()) .zip(src_ty.types.iter()); let dsts = dst - .record_field_dsts(self.module, dst_ty.types.iter().copied()) + .record_field_dsts(self.types, dst_ty.types.iter().copied()) .zip(dst_ty.types.iter()); for ((src, src_ty), (dst, dst_ty)) in srcs.zip(dsts) { self.translate(src_ty, &src, dst_ty, &dst); @@ -1911,14 +2063,14 @@ impl Compiler<'_, '_> { dst_ty: &InterfaceType, dst: &Destination, ) { - let src_ty = &self.module.types[src_ty]; + let src_ty = &self.types[src_ty]; let dst_ty = match dst_ty { - InterfaceType::Variant(t) => &self.module.types[*t], + InterfaceType::Variant(t) => &self.types[*t], _ => panic!("expected a variant"), }; - let src_disc_size = DiscriminantSize::from_count(src_ty.cases.len()).unwrap(); - let dst_disc_size = DiscriminantSize::from_count(dst_ty.cases.len()).unwrap(); + let src_info = VariantInfo::new(self.types, src.opts(), src_ty.cases.iter().map(|c| c.ty)); + let dst_info = VariantInfo::new(self.types, dst.opts(), dst_ty.cases.iter().map(|c| c.ty)); let iter = src_ty.cases.iter().enumerate().map(|(src_i, src_case)| { let dst_i = dst_ty @@ -1936,7 +2088,7 @@ impl Compiler<'_, '_> { dst_ty: &dst_case.ty, } }); - self.convert_variant(src, src_disc_size, dst, dst_disc_size, iter); + self.convert_variant(src, &src_info, dst, &dst_info, iter); } fn translate_union( @@ -1946,18 +2098,20 @@ impl Compiler<'_, '_> { dst_ty: &InterfaceType, dst: &Destination, ) { - let src_ty = &self.module.types[src_ty]; + let src_ty = &self.types[src_ty]; let dst_ty = match dst_ty { - InterfaceType::Union(t) => &self.module.types[*t], + InterfaceType::Union(t) => &self.types[*t], _ => panic!("expected an option"), }; assert_eq!(src_ty.types.len(), dst_ty.types.len()); + let src_info = VariantInfo::new(self.types, src.opts(), src_ty.types.iter().copied()); + let dst_info = VariantInfo::new(self.types, dst.opts(), 
dst_ty.types.iter().copied()); self.convert_variant( src, - DiscriminantSize::Size1, + &src_info, dst, - DiscriminantSize::Size1, + &dst_info, src_ty .types .iter() @@ -1982,18 +2136,28 @@ impl Compiler<'_, '_> { dst_ty: &InterfaceType, dst: &Destination, ) { - let src_ty = &self.module.types[src_ty]; + let src_ty = &self.types[src_ty]; let dst_ty = match dst_ty { - InterfaceType::Enum(t) => &self.module.types[*t], + InterfaceType::Enum(t) => &self.types[*t], _ => panic!("expected an option"), }; + let src_info = VariantInfo::new( + self.types, + src.opts(), + src_ty.names.iter().map(|_| InterfaceType::Unit), + ); + let dst_info = VariantInfo::new( + self.types, + dst.opts(), + dst_ty.names.iter().map(|_| InterfaceType::Unit), + ); let unit = &InterfaceType::Unit; self.convert_variant( src, - DiscriminantSize::from_count(src_ty.names.len()).unwrap(), + &src_info, dst, - DiscriminantSize::from_count(dst_ty.names.len()).unwrap(), + &dst_info, src_ty.names.iter().enumerate().map(|(src_i, src_name)| { let dst_i = dst_ty.names.iter().position(|n| n == src_name).unwrap(); let src_i = u32::try_from(src_i).unwrap(); @@ -2015,17 +2179,20 @@ impl Compiler<'_, '_> { dst_ty: &InterfaceType, dst: &Destination, ) { - let src_ty = &self.module.types[src_ty]; + let src_ty = &self.types[src_ty]; let dst_ty = match dst_ty { - InterfaceType::Option(t) => &self.module.types[*t], + InterfaceType::Option(t) => &self.types[*t], _ => panic!("expected an option"), }; + let src_info = VariantInfo::new(self.types, src.opts(), [InterfaceType::Unit, *src_ty]); + let dst_info = VariantInfo::new(self.types, dst.opts(), [InterfaceType::Unit, *dst_ty]); + self.convert_variant( src, - DiscriminantSize::Size1, + &src_info, dst, - DiscriminantSize::Size1, + &dst_info, [ VariantCase { src_i: 0, @@ -2051,17 +2218,20 @@ impl Compiler<'_, '_> { dst_ty: &InterfaceType, dst: &Destination, ) { - let src_ty = &self.module.types[src_ty]; + let src_ty = &self.types[src_ty]; let dst_ty = match dst_ty { - InterfaceType::Expected(t) => &self.module.types[*t], + InterfaceType::Expected(t) => &self.types[*t], _ => panic!("expected an expected"), }; + let src_info = VariantInfo::new(self.types, src.opts(), [src_ty.ok, src_ty.err]); + let dst_info = VariantInfo::new(self.types, dst.opts(), [dst_ty.ok, dst_ty.err]); + self.convert_variant( src, - DiscriminantSize::Size1, + &src_info, dst, - DiscriminantSize::Size1, + &dst_info, [ VariantCase { src_i: 0, @@ -2083,9 +2253,9 @@ impl Compiler<'_, '_> { fn convert_variant<'a>( &mut self, src: &Source<'_>, - src_disc_size: DiscriminantSize, + src_info: &VariantInfo, dst: &Destination, - dst_disc_size: DiscriminantSize, + dst_info: &VariantInfo, src_cases: impl ExactSizeIterator>, ) { // The outermost block is special since it has the result type of the @@ -2095,7 +2265,7 @@ impl Compiler<'_, '_> { 0 => BlockType::Empty, 1 => BlockType::Result(dst_flat[0]), _ => { - let ty = self.types.function(&[], &dst_flat); + let ty = self.module.core_types.function(&[], &dst_flat); BlockType::FunctionType(ty) } }, @@ -2120,7 +2290,7 @@ impl Compiler<'_, '_> { // Load the discriminant match src { Source::Stack(s) => self.stack_get(&s.slice(0..1), ValType::I32), - Source::Memory(mem) => match src_disc_size { + Source::Memory(mem) => match src_info.size { DiscriminantSize::Size1 => self.i32_load8u(mem), DiscriminantSize::Size2 => self.i32_load16u(mem), DiscriminantSize::Size4 => self.i32_load(mem), @@ -2158,7 +2328,7 @@ impl Compiler<'_, '_> { self.instruction(I32Const(dst_i as i32)); match dst { 
Destination::Stack(stack, _) => self.stack_set(&stack[..1], ValType::I32), - Destination::Memory(mem) => match dst_disc_size { + Destination::Memory(mem) => match dst_info.size { DiscriminantSize::Size1 => self.i32_store8(mem), DiscriminantSize::Size2 => self.i32_store16(mem), DiscriminantSize::Size4 => self.i32_store(mem), @@ -2167,8 +2337,8 @@ impl Compiler<'_, '_> { // Translate the payload of this case using the various types from // the dst/src. - let src_payload = src.payload_src(self.module, src_disc_size, src_ty); - let dst_payload = dst.payload_dst(self.module, dst_disc_size, dst_ty); + let src_payload = src.payload_src(self.types, src_info, src_ty); + let dst_payload = dst.payload_dst(self.types, dst_info, dst_ty); self.translate(src_ty, &src_payload, dst_ty, &dst_payload); // If the results of this translation were placed on the stack then @@ -2251,7 +2421,7 @@ impl Compiler<'_, '_> { if !self.module.debug { return; } - let align = self.module.align(mem.opts, ty); + let align = self.types.align(mem.opts, ty); if align == 1 { return; } @@ -2299,9 +2469,10 @@ impl Compiler<'_, '_> { fn gen_local(&mut self, ty: ValType) -> u32 { // TODO: see if local reuse is necessary, right now this always // generates a new local. - match self.locals.last_mut() { + let locals = &mut self.module.funcs[self.result].locals; + match locals.last_mut() { Some((cnt, prev_ty)) if ty == *prev_ty => *cnt += 1, - _ => self.locals.push((1, ty)), + _ => locals.push((1, ty)), } self.nlocals += 1; self.nlocals - 1 @@ -2316,27 +2487,29 @@ impl Compiler<'_, '_> { self.instruction(Unreachable); } - fn finish(&mut self) -> (Vec, Vec<(usize, Trap)>) { - self.instruction(End); - - let mut bytes = Vec::new(); - - // Encode all locals used for this function - self.locals.len().encode(&mut bytes); - for (count, ty) in self.locals.iter() { - count.encode(&mut bytes); - ty.encode(&mut bytes); + /// Flushes out the current `code` instructions (and `traps` if there are + /// any) into the destination function. + /// + /// This is a noop if no instructions have been encoded yet. + fn flush_code(&mut self) { + if self.code.is_empty() { + return; } + self.module.funcs[self.result].body.push(Body::Raw( + mem::take(&mut self.code), + mem::take(&mut self.traps), + )); + } - // Factor in the size of the encodings of locals into the offsets of - // traps. - for (offset, _) in self.traps.iter_mut() { - *offset += bytes.len(); - } + fn finish(mut self) { + // Append the final `end` instruction which all functions require, and + // then empty out the temporary buffer in `Compiler`. + self.instruction(End); + self.flush_code(); - // Then append the function we built and return - bytes.extend_from_slice(&self.code); - (bytes, mem::take(&mut self.traps)) + // Flag the function as "done" which helps with an assert later on in + // emission that everything was eventually finished. 
+ self.module.funcs[self.result].filled_in = true; } /// Fetches the value contained with the local specified by `stack` and @@ -2361,8 +2534,8 @@ impl Compiler<'_, '_> { (ValType::I64, ValType::F64) => self.instruction(F64ReinterpretI64), (ValType::F64, ValType::F32) => self.instruction(F32DemoteF64), (ValType::I64, ValType::F32) => { - self.instruction(F64ReinterpretI64); - self.instruction(F32DemoteF64); + self.instruction(I32WrapI64); + self.instruction(F32ReinterpretI32); } // should not be possible given the `join` function for variants @@ -2405,8 +2578,8 @@ impl Compiler<'_, '_> { (ValType::F64, ValType::I64) => self.instruction(I64ReinterpretF64), (ValType::F32, ValType::F64) => self.instruction(F64PromoteF32), (ValType::F32, ValType::I64) => { - self.instruction(F64PromoteF32); - self.instruction(I64ReinterpretF64); + self.instruction(I32ReinterpretF32); + self.instruction(I64ExtendI32U); } // should not be possible given the `join` function for variants @@ -2654,7 +2827,7 @@ impl<'a> Source<'a> { /// offset for each memory-based type. fn record_field_srcs<'b>( &'b self, - module: &'b Module, + types: &'b ComponentTypes, fields: impl IntoIterator + 'b, ) -> impl Iterator> + 'b where @@ -2663,11 +2836,11 @@ impl<'a> Source<'a> { let mut offset = 0; fields.into_iter().map(move |ty| match self { Source::Memory(mem) => { - let mem = next_field_offset(&mut offset, module, &ty, mem); + let mem = next_field_offset(&mut offset, types, &ty, mem); Source::Memory(mem) } Source::Stack(stack) => { - let cnt = module.flatten_types(stack.opts, [ty]).len(); + let cnt = types.flatten_types(stack.opts, [ty]).len(); offset += cnt; Source::Stack(stack.slice(offset - cnt..offset)) } @@ -2677,17 +2850,17 @@ impl<'a> Source<'a> { /// Returns the corresponding discriminant source and payload source f fn payload_src( &self, - module: &Module, - size: DiscriminantSize, + types: &ComponentTypes, + info: &VariantInfo, case: &InterfaceType, ) -> Source<'a> { match self { Source::Stack(s) => { - let flat_len = module.flatten_types(s.opts, [*case]).len(); + let flat_len = types.flatten_types(s.opts, [*case]).len(); Source::Stack(s.slice(1..s.locals.len()).slice(0..flat_len)) } Source::Memory(mem) => { - let mem = payload_offset(size, module, case, mem); + let mem = info.payload_offset(case, mem); Source::Memory(mem) } } @@ -2705,7 +2878,7 @@ impl<'a> Destination<'a> { /// Same as `Source::record_field_srcs` but for destinations. 
fn record_field_dsts<'b>( &'b self, - module: &'b Module, + types: &'b ComponentTypes, fields: impl IntoIterator + 'b, ) -> impl Iterator + 'b where @@ -2714,11 +2887,11 @@ impl<'a> Destination<'a> { let mut offset = 0; fields.into_iter().map(move |ty| match self { Destination::Memory(mem) => { - let mem = next_field_offset(&mut offset, module, &ty, mem); + let mem = next_field_offset(&mut offset, types, &ty, mem); Destination::Memory(mem) } Destination::Stack(s, opts) => { - let cnt = module.flatten_types(opts, [ty]).len(); + let cnt = types.flatten_types(opts, [ty]).len(); offset += cnt; Destination::Stack(&s[offset - cnt..offset], opts) } @@ -2728,17 +2901,17 @@ impl<'a> Destination<'a> { /// Returns the corresponding discriminant source and payload source f fn payload_dst( &self, - module: &Module, - size: DiscriminantSize, + types: &ComponentTypes, + info: &VariantInfo, case: &InterfaceType, ) -> Destination { match self { Destination::Stack(s, opts) => { - let flat_len = module.flatten_types(opts, [*case]).len(); + let flat_len = types.flatten_types(opts, [*case]).len(); Destination::Stack(&s[1..][..flat_len], opts) } Destination::Memory(mem) => { - let mem = payload_offset(size, module, case, mem); + let mem = info.payload_offset(case, mem); Destination::Memory(mem) } } @@ -2754,23 +2927,37 @@ impl<'a> Destination<'a> { fn next_field_offset<'a>( offset: &mut usize, - module: &Module, + types: &ComponentTypes, field: &InterfaceType, mem: &Memory<'a>, ) -> Memory<'a> { - let (size, align) = module.size_align(mem.opts, field); + let (size, align) = types.size_align(mem.opts, field); *offset = align_to(*offset, align) + size; mem.bump(*offset - size) } -fn payload_offset<'a>( - disc_size: DiscriminantSize, - module: &Module, - case: &InterfaceType, - mem: &Memory<'a>, -) -> Memory<'a> { - let align = module.align(mem.opts, case); - mem.bump(align_to(disc_size.into(), align)) +struct VariantInfo { + size: DiscriminantSize, + align: usize, +} + +impl VariantInfo { + fn new(types: &ComponentTypes, options: &Options, iter: I) -> VariantInfo + where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + let iter = iter.into_iter(); + let size = DiscriminantSize::from_count(iter.len()).unwrap(); + VariantInfo { + size, + align: usize::from(size).max(iter.map(|i| types.align(options, &i)).max().unwrap_or(1)), + } + } + + fn payload_offset<'a>(&self, _case: &InterfaceType, mem: &Memory<'a>) -> Memory<'a> { + mem.bump(align_to(self.size.into(), self.align)) + } } impl<'a> Memory<'a> { diff --git a/crates/environ/src/fact/transcode.rs b/crates/environ/src/fact/transcode.rs index 865fef316ed6..7d72413050f5 100644 --- a/crates/environ/src/fact/transcode.rs +++ b/crates/environ/src/fact/transcode.rs @@ -1,15 +1,8 @@ use crate::fact::core_types::CoreTypes; use crate::MemoryIndex; use serde::{Deserialize, Serialize}; -use std::collections::HashMap; use wasm_encoder::{EntityType, ValType}; -pub struct Transcoders { - imported: HashMap, - prev_func_imports: u32, - imports: Vec<(String, EntityType, Transcoder)>, -} - #[derive(Copy, Clone, Hash, Eq, PartialEq)] pub struct Transcoder { pub from_memory: MemoryIndex, @@ -46,33 +39,8 @@ pub enum FixedEncoding { Latin1, } -impl Transcoders { - pub fn new(prev_func_imports: u32) -> Transcoders { - Transcoders { - imported: HashMap::new(), - prev_func_imports, - imports: Vec::new(), - } - } - - pub fn import(&mut self, types: &mut CoreTypes, transcoder: Transcoder) -> u32 { - *self.imported.entry(transcoder).or_insert_with(|| { - let idx = 
self.prev_func_imports + (self.imports.len() as u32); - self.imports - .push((transcoder.name(), transcoder.ty(types), transcoder)); - idx - }) - } - - pub fn imports(&self) -> impl Iterator { - self.imports - .iter() - .map(|(name, ty, transcoder)| ("transcode", &name[..], *ty, transcoder)) - } -} - impl Transcoder { - fn name(&self) -> String { + pub fn name(&self) -> String { format!( "{} (mem{} => mem{})", self.op.desc(), @@ -81,7 +49,7 @@ impl Transcoder { ) } - fn ty(&self, types: &mut CoreTypes) -> EntityType { + pub fn ty(&self, types: &mut CoreTypes) -> EntityType { let from_ptr = if self.from_memory64 { ValType::I64 } else { diff --git a/crates/fuzzing/src/generators/component_types.rs b/crates/fuzzing/src/generators/component_types.rs index 2d93f29d726a..bd3c64755cbe 100644 --- a/crates/fuzzing/src/generators/component_types.rs +++ b/crates/fuzzing/src/generators/component_types.rs @@ -8,6 +8,7 @@ use arbitrary::{Arbitrary, Unstructured}; use component_fuzz_util::{Declarations, EXPORT_FUNCTION, IMPORT_FUNCTION}; +use std::any::Any; use std::fmt::Debug; use std::ops::ControlFlow; use wasmtime::component::{self, Component, Lift, Linker, Lower, Val}; @@ -141,25 +142,29 @@ macro_rules! define_static_api_test { let mut config = Config::new(); config.wasm_component_model(true); let engine = Engine::new(&config).unwrap(); - let component = Component::new( - &engine, - declarations.make_component().as_bytes() - ).unwrap(); + let wat = declarations.make_component(); + let wat = wat.as_bytes(); + crate::oracles::log_wasm(wat); + let component = Component::new(&engine, wat).unwrap(); let mut linker = Linker::new(&engine); linker .root() .func_wrap( IMPORT_FUNCTION, - |cx: StoreContextMut<'_, ($(Option<$param>,)* Option)>, + |cx: StoreContextMut<'_, Box>, $($param_name: $param,)*| { - let ($($param_expected_name,)* result) = cx.data(); - $(assert_eq!($param_name, *$param_expected_name.as_ref().unwrap());)* - Ok(result.as_ref().unwrap().clone()) + log::trace!("received parameters {:?}", ($(&$param_name,)*)); + let data: &($($param,)* R,) = + cx.data().downcast_ref().unwrap(); + let ($($param_expected_name,)* result,) = data; + $(assert_eq!($param_name, *$param_expected_name);)* + log::trace!("returning result {:?}", result); + Ok(result.clone()) }, ) .unwrap(); - let mut store = Store::new(&engine, Default::default()); + let mut store: Store> = Store::new(&engine, Box::new(())); let instance = linker.instantiate(&mut store, &component).unwrap(); let func = instance .get_typed_func::<($($param,)*), R, _>(&mut store, EXPORT_FUNCTION) @@ -168,9 +173,17 @@ macro_rules! define_static_api_test { while input.arbitrary()? 
{ $(let $param_name = input.arbitrary::<$param>()?;)* let result = input.arbitrary::()?; - *store.data_mut() = ($(Some($param_name.clone()),)* Some(result.clone())); - - assert_eq!(func.call(&mut store, ($($param_name,)*)).unwrap(), result); + *store.data_mut() = Box::new(( + $($param_name.clone(),)* + result.clone(), + )); + log::trace!( + "passing in parameters {:?}", + ($(&$param_name,)*), + ); + let actual = func.call(&mut store, ($($param_name,)*)).unwrap(); + log::trace!("got result {:?}", actual); + assert_eq!(actual, result); func.post_return(&mut store).unwrap(); } diff --git a/crates/fuzzing/src/oracles.rs b/crates/fuzzing/src/oracles.rs index 4e7d090c4ff3..bdaca94ae51e 100644 --- a/crates/fuzzing/src/oracles.rs +++ b/crates/fuzzing/src/oracles.rs @@ -1089,20 +1089,25 @@ pub fn dynamic_component_api_target(input: &mut arbitrary::Unstructured) -> arbi let engine = component_test_util::engine(); let mut store = Store::new(&engine, (Box::new([]) as Box<[Val]>, None)); - let component = - Component::new(&engine, case.declarations().make_component().as_bytes()).unwrap(); + let wat = case.declarations().make_component(); + let wat = wat.as_bytes(); + log_wasm(wat); + let component = Component::new(&engine, wat).unwrap(); let mut linker = Linker::new(&engine); linker .root() .func_new(&component, IMPORT_FUNCTION, { move |cx: StoreContextMut<'_, (Box<[Val]>, Option)>, args: &[Val]| -> Result { + log::trace!("received arguments {args:?}"); let (expected_args, result) = cx.data(); assert_eq!(args.len(), expected_args.len()); for (expected, actual) in expected_args.iter().zip(args) { assert_eq!(expected, actual); } - Ok(result.as_ref().unwrap().clone()) + let result = result.as_ref().unwrap().clone(); + log::trace!("returning result {result:?}"); + Ok(result) } }) .unwrap(); @@ -1122,10 +1127,10 @@ pub fn dynamic_component_api_target(input: &mut arbitrary::Unstructured) -> arbi *store.data_mut() = (args.clone(), Some(result.clone())); - assert_eq!( - func.call_and_post_return(&mut store, &args).unwrap(), - result - ); + log::trace!("passing args {args:?}"); + let actual = func.call_and_post_return(&mut store, &args).unwrap(); + log::trace!("received return {actual:?}"); + assert_eq!(actual, result); } Ok(()) diff --git a/crates/misc/component-fuzz-util/src/lib.rs b/crates/misc/component-fuzz-util/src/lib.rs index 9b14266dcd92..2d5b75fd9254 100644 --- a/crates/misc/component-fuzz-util/src/lib.rs +++ b/crates/misc/component-fuzz-util/src/lib.rs @@ -8,7 +8,8 @@ use arbitrary::{Arbitrary, Unstructured}; use proc_macro2::{Ident, TokenStream}; -use quote::{format_ident, quote}; +use quote::{format_ident, quote, ToTokens}; +use std::borrow::Cow; use std::fmt::{self, Debug, Write}; use std::iter; use std::ops::Deref; @@ -328,7 +329,7 @@ fn variant_size_and_alignment<'a>( } } -fn make_import_and_export(params: &[Type], result: &Type) -> Box { +fn make_import_and_export(params: &[Type], result: &Type) -> String { let params_lowered = params .iter() .flat_map(|ty| ty.lowered()) @@ -400,7 +401,6 @@ fn make_import_and_export(params: &[Type], result: &Type) -> Box { )"# ) } - .into() } fn make_rust_name(name_counter: &mut u32) -> Ident { @@ -509,7 +509,7 @@ pub fn rust_type(ty: &Type, name_counter: &mut u32, declarations: &mut TokenStre let name = make_rust_name(name_counter); declarations.extend(quote! 
{ - #[derive(ComponentType, Lift, Lower, PartialEq, Debug, Clone, Arbitrary)] + #[derive(ComponentType, Lift, Lower, PartialEq, Debug, Copy, Clone, Arbitrary)] #[component(enum)] enum #name { #cases @@ -677,13 +677,17 @@ fn write_component_type( #[derive(Debug)] pub struct Declarations { /// Type declarations (if any) referenced by `params` and/or `result` - pub types: Box, + pub types: Cow<'static, str>, /// Parameter declarations used for the imported and exported functions - pub params: Box, + pub params: Cow<'static, str>, /// Result declaration used for the imported and exported functions - pub result: Box, + pub result: Cow<'static, str>, /// A WAT fragment representing the core function import and export to use for testing - pub import_and_export: Box, + pub import_and_export: Cow<'static, str>, + /// String encoding to use for host -> component + pub encoding1: StringEncoding, + /// String encoding to use for component -> host + pub encoding2: StringEncoding, } impl Declarations { @@ -694,7 +698,44 @@ impl Declarations { params, result, import_and_export, + encoding1, + encoding2, } = self; + let mk_component = |name: &str, encoding: StringEncoding| { + format!( + r#" + (component ${name} + (import "echo" (func $f (type $sig))) + + (core instance $libc (instantiate $libc)) + + (core func $f_lower (canon lower + (func $f) + (memory $libc "memory") + (realloc (func $libc "realloc")) + string-encoding={encoding} + )) + + (core instance $i (instantiate $m + (with "libc" (instance $libc)) + (with "host" (instance (export "{IMPORT_FUNCTION}" (func $f_lower)))) + )) + + (func (export "echo") (type $sig) + (canon lift + (core func $i "echo") + (memory $libc "memory") + (realloc (func $libc "realloc")) + string-encoding={encoding} + ) + ) + ) + "# + ) + }; + + let c1 = mk_component("c1", *encoding2); + let c2 = mk_component("c2", *encoding1); format!( r#" @@ -704,18 +745,6 @@ impl Declarations { {REALLOC_AND_FREE} ) - (core instance $libc (instantiate $libc)) - - {types} - - (import "{IMPORT_FUNCTION}" (func $f {params} {result})) - - (core func $f_lower (canon lower - (func $f) - (memory $libc "memory") - (realloc (func $libc "realloc")) - )) - (core module $m (memory (import "libc" "memory") 1) (func $realloc (import "libc" "realloc") (param i32 i32 i32 i32) (result i32)) @@ -723,18 +752,16 @@ impl Declarations { {import_and_export} ) - (core instance $i (instantiate $m - (with "libc" (instance $libc)) - (with "host" (instance (export "{IMPORT_FUNCTION}" (func $f_lower)))) - )) + {types} - (func (export "echo") {params} {result} - (canon lift - (core func $i "echo") - (memory $libc "memory") - (realloc (func $libc "realloc")) - ) - ) + (type $sig (func {params} {result})) + (import "{IMPORT_FUNCTION}" (func $f (type $sig))) + + {c1} + {c2} + (instance $c1 (instantiate $c1 (with "echo" (func $f)))) + (instance $c2 (instantiate $c2 (with "echo" (func $c1 "echo")))) + (export "echo" (func $c2 "echo")) )"#, ) .into() @@ -748,6 +775,10 @@ pub struct TestCase { pub params: Box<[Type]>, /// The type of the result to be returned by the function pub result: Type, + /// String encoding to use from host-to-component. + pub encoding1: StringEncoding, + /// String encoding to use from component-to-host. 
+ pub encoding2: StringEncoding, } impl TestCase { @@ -781,7 +812,9 @@ impl TestCase { types: types.into(), params, result, - import_and_export, + import_and_export: import_and_export.into(), + encoding1: self.encoding1, + encoding2: self.encoding2, } } } @@ -795,6 +828,36 @@ impl<'a> Arbitrary<'a> for TestCase { .take(MAX_ARITY) .collect::>>()?, result: input.arbitrary()?, + encoding1: input.arbitrary()?, + encoding2: input.arbitrary()?, }) } } + +#[derive(Copy, Clone, Debug, Arbitrary)] +pub enum StringEncoding { + Utf8, + Utf16, + Latin1OrUtf16, +} + +impl fmt::Display for StringEncoding { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + StringEncoding::Utf8 => fmt::Display::fmt(&"utf8", f), + StringEncoding::Utf16 => fmt::Display::fmt(&"utf16", f), + StringEncoding::Latin1OrUtf16 => fmt::Display::fmt(&"latin1+utf16", f), + } + } +} + +impl ToTokens for StringEncoding { + fn to_tokens(&self, tokens: &mut TokenStream) { + let me = match self { + StringEncoding::Utf8 => quote!(Utf8), + StringEncoding::Utf16 => quote!(Utf16), + StringEncoding::Latin1OrUtf16 => quote!(Latin1OrUtf16), + }; + tokens.extend(quote!(component_fuzz_util::StringEncoding::#me)); + } +} diff --git a/crates/wasmtime/src/component/values.rs b/crates/wasmtime/src/component/values.rs index 7c3f5501522a..65fd2280b986 100644 --- a/crates/wasmtime/src/component/values.rs +++ b/crates/wasmtime/src/component/values.rs @@ -4,12 +4,13 @@ use crate::store::StoreOpaque; use crate::{AsContextMut, StoreContextMut, ValRaw}; use anyhow::{anyhow, bail, Context, Error, Result}; use std::collections::HashMap; +use std::fmt; use std::iter; use std::mem::MaybeUninit; use std::ops::Deref; use wasmtime_component_util::{DiscriminantSize, FlagsSize}; -#[derive(Debug, PartialEq, Eq, Clone)] +#[derive(PartialEq, Eq, Clone)] pub struct List { ty: types::List, values: Box<[Val]>, @@ -45,7 +46,17 @@ impl Deref for List { } } -#[derive(Debug, PartialEq, Eq, Clone)] +impl fmt::Debug for List { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut f = f.debug_list(); + for val in self.iter() { + f.entry(val); + } + f.finish() + } +} + +#[derive(PartialEq, Eq, Clone)] pub struct Record { ty: types::Record, values: Box<[Val]>, @@ -105,6 +116,16 @@ impl Record { } } +impl fmt::Debug for Record { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut f = f.debug_struct("Record"); + for (name, val) in self.fields() { + f.field(name, val); + } + f.finish() + } +} + #[derive(Debug, PartialEq, Eq, Clone)] pub struct Tuple { ty: types::Tuple, @@ -144,7 +165,7 @@ impl Tuple { } } -#[derive(Debug, PartialEq, Eq, Clone)] +#[derive(PartialEq, Eq, Clone)] pub struct Variant { ty: types::Variant, discriminant: u32, @@ -197,6 +218,14 @@ impl Variant { } } +impl fmt::Debug for Variant { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple(self.discriminant()) + .field(self.payload()) + .finish() + } +} + #[derive(Debug, PartialEq, Eq, Clone)] pub struct Enum { ty: types::Enum, @@ -273,7 +302,7 @@ impl Union { } } -#[derive(Debug, PartialEq, Eq, Clone)] +#[derive(PartialEq, Eq, Clone)] pub struct Option { ty: types::Option, discriminant: u32, @@ -313,7 +342,13 @@ impl Option { } } -#[derive(Debug, PartialEq, Eq, Clone)] +impl fmt::Debug for Option { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.value().fmt(f) + } +} + +#[derive(PartialEq, Eq, Clone)] pub struct Expected { ty: types::Expected, discriminant: u32, @@ -358,7 +393,13 @@ impl Expected { } } 
-#[derive(Debug, PartialEq, Eq, Clone)] +impl fmt::Debug for Expected { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.value().fmt(f) + } +} + +#[derive(PartialEq, Eq, Clone)] pub struct Flags { ty: types::Flags, count: u32, @@ -408,6 +449,16 @@ impl Flags { } } +impl fmt::Debug for Flags { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut set = f.debug_set(); + for flag in self.flags() { + set.entry(&flag); + } + set.finish() + } +} + /// Represents possible runtime values which a component function can either consume or produce #[derive(Debug, PartialEq, Eq, Clone)] pub enum Val { diff --git a/fuzz/build.rs b/fuzz/build.rs index b8b45a36f45f..97c6e64ddffd 100644 --- a/fuzz/build.rs +++ b/fuzz/build.rs @@ -77,6 +77,8 @@ mod component { params, result, import_and_export, + encoding1, + encoding2, } = case.declarations(); let test = format_ident!("static_api_test{}", case.params.len()); @@ -95,11 +97,16 @@ mod component { let test = quote!(#index => component_types::#test::<#rust_params #rust_result>( input, - &Declarations { - types: #types.into(), - params: #params.into(), - result: #result.into(), - import_and_export: #import_and_export.into() + { + static DECLS: Declarations = Declarations { + types: Cow::Borrowed(#types), + params: Cow::Borrowed(#params), + result: Cow::Borrowed(#result), + import_and_export: Cow::Borrowed(#import_and_export), + encoding1: #encoding1, + encoding2: #encoding2, + }; + &DECLS } ),); @@ -116,6 +123,7 @@ mod component { use std::sync::{Arc, Once}; use wasmtime::component::{ComponentType, Lift, Lower}; use wasmtime_fuzzing::generators::component_types; + use std::borrow::Cow; const SEED: u64 = #seed; diff --git a/tests/misc_testsuite/component-model/fused.wast b/tests/misc_testsuite/component-model/fused.wast index 6de762471ef2..77dca93edc3d 100644 --- a/tests/misc_testsuite/component-model/fused.wast +++ b/tests/misc_testsuite/component-model/fused.wast @@ -925,7 +925,7 @@ (i32.eqz (local.get 0)) if (if (i32.ne (local.get 1) (i32.const 0)) (unreachable)) - (if (f64.ne (f64.reinterpret_i64 (local.get 2)) (f64.const 8)) (unreachable)) + (if (f32.ne (f32.reinterpret_i32 (i32.wrap_i64 (local.get 2))) (f32.const 8)) (unreachable)) else (if (i32.ne (local.get 1) (i32.const 1)) (unreachable)) (if (f64.ne (f64.reinterpret_i64 (local.get 2)) (f64.const 9)) (unreachable)) @@ -935,7 +935,7 @@ (i32.eqz (local.get 0)) if (if (i32.ne (local.get 1) (i32.const 0)) (unreachable)) - (if (f64.ne (f64.reinterpret_i64 (local.get 2)) (f64.const 10)) (unreachable)) + (if (f32.ne (f32.reinterpret_i32 (i32.wrap_i64 (local.get 2))) (f32.const 10)) (unreachable)) else (if (i32.ne (local.get 1) (i32.const 1)) (unreachable)) (if (i64.ne (local.get 2) (i64.const 11)) (unreachable)) @@ -983,10 +983,10 @@ (call $c (i32.const 0) (i32.const 0) (i64.const 6)) (call $c (i32.const 1) (i32.const 1) (i64.reinterpret_f64 (f64.const 7))) - (call $d (i32.const 0) (i32.const 0) (i64.reinterpret_f64 (f64.const 8))) + (call $d (i32.const 0) (i32.const 0) (i64.extend_i32_u (i32.reinterpret_f32 (f32.const 8)))) (call $d (i32.const 1) (i32.const 1) (i64.reinterpret_f64 (f64.const 9))) - (call $e (i32.const 0) (i32.const 0) (i64.reinterpret_f64 (f64.const 10))) + (call $e (i32.const 0) (i32.const 0) (i64.extend_i32_u (i32.reinterpret_f32 (f32.const 10)))) (call $e (i32.const 1) (i32.const 1) (i64.const 11)) ) (start $start)
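A note on the f32 <-> i64 handling exercised by the `$d`/`$e` cases above: the fused adapter now moves an `f32` into a 64-bit flat slot by reinterpreting it as an `i32` and zero-extending (`i32.reinterpret_f32` + `i64.extend_i32_u`), and recovers it with `i32.wrap_i64` + `f32.reinterpret_i32`, instead of promoting through `f64` as the previous instruction sequence did. A minimal standalone Rust sketch (illustration only, not part of the patch; the `main` harness is added here and the value 8 is borrowed from the wast test) of why the 32-bit pattern, not the numeric value, has to be preserved:

    fn main() {
        let x: f32 = 8.0;

        // Fixed behavior: keep the f32's 32-bit pattern and zero-extend it
        // into the 64-bit slot, then wrap and reinterpret to get it back.
        let slot: u64 = u64::from(x.to_bits());
        assert_eq!(f32::from_bits(slot as u32), x);

        // Old behavior: promote to f64 and take those bits. The low 32 bits
        // of the f64 encoding of 8.0 are zero, not the f32 encoding of 8.0,
        // so wrapping and reinterpreting does not round-trip the value.
        let bad_slot: u64 = f64::from(x).to_bits();
        assert_ne!(f32::from_bits(bad_slot as u32), x);
    }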