diff --git a/crates/component-macro/tests/codegen/simple-wasi.wit b/crates/component-macro/tests/codegen/simple-wasi.wit index e2c2cee3514d..02c47dc9b8f4 100644 --- a/crates/component-macro/tests/codegen/simple-wasi.wit +++ b/crates/component-macro/tests/codegen/simple-wasi.wit @@ -10,6 +10,12 @@ interface wasi-filesystem { stat: func() -> result } +interface wall-clock { + record wall-clock { + } +} + default world wasi { import wasi-filesystem: self.wasi-filesystem + import wall-clock: self.wall-clock } diff --git a/crates/wit-bindgen/src/lib.rs b/crates/wit-bindgen/src/lib.rs index 21f02961318a..eb4753f0f627 100644 --- a/crates/wit-bindgen/src/lib.rs +++ b/crates/wit-bindgen/src/lib.rs @@ -1,4 +1,4 @@ -use crate::rust::{to_rust_ident, RustGenerator, TypeMode}; +use crate::rust::{to_rust_ident, to_rust_upper_camel_case, RustGenerator, TypeMode}; use crate::types::{TypeInfo, Types}; use heck::*; use std::collections::BTreeMap; @@ -162,7 +162,7 @@ impl Wasmtime { gen.generate_trappable_error_types(TypeOwner::Interface(*id)); let iface = &resolve.interfaces[*id]; - let camel = name.to_upper_camel_case(); + let camel = to_rust_upper_camel_case(name); uwriteln!(gen.src, "pub struct {camel} {{"); for (_, func) in iface.functions.iter() { uwriteln!( @@ -237,7 +237,7 @@ impl Wasmtime { } fn finish(&mut self, resolve: &Resolve, world: WorldId) -> String { - let camel = resolve.worlds[world].name.to_upper_camel_case(); + let camel = to_rust_upper_camel_case(&resolve.worlds[world].name); uwriteln!(self.src, "pub struct {camel} {{"); for (name, (ty, _)) in self.exports.fields.iter() { uwriteln!(self.src, "{name}: {ty},"); @@ -363,7 +363,7 @@ impl Wasmtime { return; } - let world_camel = resolve.worlds[world].name.to_upper_camel_case(); + let world_camel = to_rust_upper_camel_case(&resolve.worlds[world].name); if self.opts.async_ { uwriteln!(self.src, "#[wasmtime::component::__internal::async_trait]") } @@ -401,11 +401,11 @@ impl Wasmtime { where U: \ " ); - let world_camel = resolve.worlds[world].name.to_upper_camel_case(); + let world_camel = to_rust_upper_camel_case(&resolve.worlds[world].name); let world_trait = format!("{world_camel}Imports"); for (i, name) in interfaces .iter() - .map(|n| format!("{n}::{}", n.to_upper_camel_case())) + .map(|n| format!("{n}::Host")) .chain(if functions.is_empty() { None } else { @@ -591,7 +591,7 @@ impl<'a> InterfaceGenerator<'a> { self.rustdoc(docs); self.src.push_str("wasmtime::component::flags!(\n"); self.src - .push_str(&format!("{} {{\n", name.to_upper_camel_case())); + .push_str(&format!("{} {{\n", to_rust_upper_camel_case(name))); for flag in flags.flags.iter() { // TODO wasmtime-component-macro doesnt support docs for flags rn uwrite!( @@ -657,7 +657,7 @@ impl<'a> InterfaceGenerator<'a> { let info = self.info(id); for (name, mode) in self.modes_of(id) { - let name = name.to_upper_camel_case(); + let name = to_rust_upper_camel_case(&name); self.rustdoc(docs); let lt = self.lifetime_for(&info, mode); @@ -783,14 +783,14 @@ impl<'a> InterfaceGenerator<'a> { fn type_enum(&mut self, id: TypeId, name: &str, enum_: &Enum, docs: &Docs) { let info = self.info(id); - let name = name.to_upper_camel_case(); + let name = to_rust_upper_camel_case(name); self.rustdoc(docs); self.push_str("#[derive(wasmtime::component::ComponentType)]\n"); self.push_str("#[derive(wasmtime::component::Lift)]\n"); self.push_str("#[derive(wasmtime::component::Lower)]\n"); self.push_str("#[component(enum)]\n"); self.push_str("#[derive(Clone, Copy, PartialEq, Eq)]\n"); - self.push_str(&format!("pub enum {} {{\n", name.to_upper_camel_case())); + self.push_str(&format!("pub enum {} {{\n", name)); for case in enum_.cases.iter() { self.rustdoc(&case.docs); self.push_str(&format!("#[component(name = \"{}\")]", case.name)); @@ -953,7 +953,6 @@ impl<'a> InterfaceGenerator<'a> { fn generate_add_to_linker(&mut self, id: InterfaceId, name: &str) { let iface = &self.resolve.interfaces[id]; - let camel = name.to_upper_camel_case(); let owner = TypeOwner::Interface(id); if self.gen.opts.async_ { @@ -961,16 +960,16 @@ impl<'a> InterfaceGenerator<'a> { } // Generate the `pub trait` which represents the host functionality for // this import. - uwriteln!(self.src, "pub trait {camel}: Sized {{"); + uwriteln!(self.src, "pub trait Host: Sized {{"); for (_, func) in iface.functions.iter() { self.generate_function_trait_sig(owner, func); } uwriteln!(self.src, "}}"); let where_clause = if self.gen.opts.async_ { - format!("T: Send, U: {camel} + Send") + format!("T: Send, U: Host + Send") } else { - format!("U: {camel}") + format!("U: Host") }; uwriteln!( self.src, diff --git a/crates/wit-bindgen/src/rust.rs b/crates/wit-bindgen/src/rust.rs index 7253106b3bed..6a8c4a9b1aad 100644 --- a/crates/wit-bindgen/src/rust.rs +++ b/crates/wit-bindgen/src/rust.rs @@ -368,6 +368,7 @@ pub trait RustGenerator<'a> { } } +/// Translate `name` to a Rust `snake_case` identifier. pub fn to_rust_ident(name: &str) -> String { match name { // Escape Rust keywords. @@ -425,3 +426,13 @@ pub fn to_rust_ident(name: &str) -> String { s => s.to_snake_case(), } } + +/// Translate `name` to a Rust `UpperCamelCase` identifier. +pub fn to_rust_upper_camel_case(name: &str) -> String { + match name { + // We use `Host` as the name of the trait for host implementations + // to fill in, so rename it if "Host" is used as a regular identifier. + "host" => "Host_".into(), + s => s.to_upper_camel_case(), + } +} diff --git a/tests/all/component_model/bindgen.rs b/tests/all/component_model/bindgen.rs index ddadec8918e9..70681b6bfb22 100644 --- a/tests/all/component_model/bindgen.rs +++ b/tests/all/component_model/bindgen.rs @@ -97,7 +97,7 @@ mod one_import { hit: bool, } - impl foo::Foo for MyImports { + impl foo::Host for MyImports { fn foo(&mut self) -> Result<()> { self.hit = true; Ok(()) diff --git a/tests/all/component_model/bindgen/results.rs b/tests/all/component_model/bindgen/results.rs index 321eed9fdccf..0626af86f81c 100644 --- a/tests/all/component_model/bindgen/results.rs +++ b/tests/all/component_model/bindgen/results.rs @@ -60,7 +60,7 @@ mod empty_error { #[derive(Default)] struct MyImports {} - impl imports::Imports for MyImports { + impl imports::Host for MyImports { fn empty_error(&mut self, a: f64) -> Result, Error> { if a == 0.0 { Ok(Ok(a)) @@ -171,7 +171,7 @@ mod string_error { #[derive(Default)] struct MyImports {} - impl imports::Imports for MyImports { + impl imports::Host for MyImports { fn string_error(&mut self, a: f64) -> Result, Error> { if a == 0.0 { Ok(Ok(a)) @@ -313,7 +313,7 @@ mod enum_error { #[derive(Default)] struct MyImports {} - impl imports::Imports for MyImports { + impl imports::Host for MyImports { fn enum_error(&mut self, a: f64) -> Result { if a == 0.0 { Ok(a) @@ -440,7 +440,7 @@ mod record_error { #[derive(Default)] struct MyImports {} - impl imports::Imports for MyImports { + impl imports::Host for MyImports { fn record_error(&mut self, a: f64) -> Result { if a == 0.0 { Ok(a) @@ -576,7 +576,7 @@ mod variant_error { #[derive(Default)] struct MyImports {} - impl imports::Imports for MyImports { + impl imports::Host for MyImports { fn variant_error(&mut self, a: f64) -> Result { if a == 0.0 { Ok(a)