diff --git a/Cargo.toml b/Cargo.toml index 2c9664af..d15f58b2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,7 +32,7 @@ lua52 = ["ffi/lua52"] lua51 = ["ffi/lua51"] luajit = ["ffi/luajit"] luajit52 = ["luajit", "ffi/luajit52"] -luau = ["ffi/luau"] +luau = ["ffi/luau", "libloading"] luau-jit = ["luau", "ffi/luau-codegen"] luau-vector4 = ["luau", "ffi/luau-vector4"] vendored = ["ffi/vendored"] @@ -57,6 +57,9 @@ parking_lot = { version = "0.12", optional = true } ffi = { package = "mlua-sys", version = "0.3.2", path = "mlua-sys" } +[target.'cfg(unix)'.dependencies] +libloading = { version = "0.8", optional = true } + [dev-dependencies] rustyline = "12.0" criterion = { version = "0.5", features = ["async_tokio"] } diff --git a/mlua-sys/build/main_inner.rs b/mlua-sys/build/main_inner.rs index 9b8b15d5..668f40b3 100644 --- a/mlua-sys/build/main_inner.rs +++ b/mlua-sys/build/main_inner.rs @@ -9,8 +9,8 @@ cfg_if::cfg_if! { } fn main() { - #[cfg(all(feature = "luau", feature = "module"))] - compile_error!("Luau does not support `module` mode"); + #[cfg(all(feature = "luau", feature = "module", windows))] + compile_error!("Luau does not support `module` mode on Windows"); #[cfg(all(feature = "module", feature = "vendored"))] compile_error!("`vendored` and `module` features are mutually exclusive"); diff --git a/src/lua.rs b/src/lua.rs index 1f4f9a69..07051ebb 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -222,7 +222,6 @@ impl LuaOptions { #[cfg(feature = "async")] pub(crate) static ASYNC_POLL_PENDING: u8 = 0; -#[cfg(not(feature = "luau"))] pub(crate) static EXTRA_REGISTRY_KEY: u8 = 0; const WRAPPED_FAILURE_POOL_SIZE: usize = 64; @@ -359,23 +358,22 @@ impl Lua { /// /// [`StdLib`]: crate::StdLib pub unsafe fn unsafe_new_with(libs: StdLib, options: LuaOptions) -> Lua { + // Workaround to avoid stripping a few unused Lua symbols that could be imported + // by C modules in unsafe mode + let mut _symbols: Vec<*const extern "C-unwind" fn()> = + vec![ffi::lua_isuserdata as _, ffi::lua_tocfunction as _]; + #[cfg(not(feature = "luau"))] + _symbols.extend_from_slice(&[ + ffi::lua_atpanic as _, + ffi::luaL_loadstring as _, + ffi::luaL_openlibs as _, + ]); + #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] { - // Workaround to avoid stripping a few unused Lua symbols that could be imported - // by C modules in unsafe mode - let mut _symbols: Vec<*const extern "C-unwind" fn()> = vec![ - ffi::lua_atpanic as _, - ffi::lua_isuserdata as _, - ffi::lua_tocfunction as _, - ffi::luaL_loadstring as _, - ffi::luaL_openlibs as _, - ]; - #[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52"))] - { - _symbols.push(ffi::lua_getglobal as _); - _symbols.push(ffi::lua_setglobal as _); - _symbols.push(ffi::luaL_setfuncs as _); - } + _symbols.push(ffi::lua_getglobal as _); + _symbols.push(ffi::lua_setglobal as _); + _symbols.push(ffi::luaL_setfuncs as _); } Self::inner_new(libs, options) @@ -3232,22 +3230,13 @@ impl<'a> Drop for StateGuard<'a> { } } -#[cfg(feature = "luau")] unsafe fn extra_data(state: *mut ffi::lua_State) -> *mut ExtraData { - (*ffi::lua_callbacks(state)).userdata as *mut ExtraData -} - -#[cfg(feature = "luau")] -unsafe fn set_extra_data( - state: *mut ffi::lua_State, - extra: &Arc>, -) -> Result<()> { - (*ffi::lua_callbacks(state)).userdata = extra.get() as *mut _; - Ok(()) -} + #[cfg(feature = "luau")] + if cfg!(not(feature = "module")) { + // In the main app we can use `lua_callbacks` to access ExtraData + return (*ffi::lua_callbacks(state)).userdata as *mut _; + } -#[cfg(not(feature = "luau"))] -unsafe fn extra_data(state: *mut ffi::lua_State) -> *mut ExtraData { let extra_key = &EXTRA_REGISTRY_KEY as *const u8 as *const c_void; if ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, extra_key) != ffi::LUA_TUSERDATA { // `ExtraData` can be null only when Lua state is foreign. @@ -3260,11 +3249,16 @@ unsafe fn extra_data(state: *mut ffi::lua_State) -> *mut ExtraData { (*extra_ptr).get() } -#[cfg(not(feature = "luau"))] unsafe fn set_extra_data( state: *mut ffi::lua_State, extra: &Arc>, ) -> Result<()> { + #[cfg(feature = "luau")] + if cfg!(not(feature = "module")) { + (*ffi::lua_callbacks(state)).userdata = extra.get() as *mut _; + return Ok(()); + } + push_gc_userdata(state, Arc::clone(extra), true)?; protect_lua!(state, 1, 0, fn(state) { let extra_key = &EXTRA_REGISTRY_KEY as *const u8 as *const c_void; diff --git a/src/luau.rs b/src/luau.rs index 61b846b1..c1f279d9 100644 --- a/src/luau.rs +++ b/src/luau.rs @@ -12,11 +12,42 @@ use crate::table::Table; use crate::types::RegistryKey; use crate::value::{IntoLua, Value}; +#[cfg(unix)] +use {libloading::Library, rustc_hash::FxHashMap}; + // Since Luau has some missing standard function, we re-implement them here +#[cfg(unix)] +const TARGET_MLUA_LUAU_ABI_VERSION: u32 = 1; + +#[cfg(all(unix, feature = "module"))] +#[no_mangle] +#[used] +pub static MLUA_LUAU_ABI_VERSION: u32 = TARGET_MLUA_LUAU_ABI_VERSION; + // We keep reference to the `package` table in registry under this key struct PackageKey(RegistryKey); +// We keep reference to the loaded dylibs in application data +#[cfg(unix)] +struct LoadedDylibs(FxHashMap); + +#[cfg(unix)] +impl std::ops::Deref for LoadedDylibs { + type Target = FxHashMap; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[cfg(unix)] +impl std::ops::DerefMut for LoadedDylibs { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + impl Lua { pub(crate) unsafe fn prepare_luau_state(&self) -> Result<()> { let globals = self.globals(); @@ -162,6 +193,22 @@ fn create_package_table(lua: &Lua) -> Result { } package.raw_set("path", search_path)?; + // Set `package.cpath` + #[cfg(unix)] + { + let mut search_cpath = env::var("LUAU_CPATH") + .or_else(|_| env::var("LUA_CPATH")) + .unwrap_or_default(); + if search_cpath.is_empty() { + if cfg!(any(target_os = "macos", target_os = "ios")) { + search_cpath = "?.dylib".to_string(); + } else { + search_cpath = "?.so".to_string(); + } + } + package.raw_set("cpath", search_cpath)?; + } + // Set `package.loaded` (table with a list of loaded modules) let loaded = lua.create_table()?; package.raw_set("loaded", loaded.clone())?; @@ -170,6 +217,12 @@ fn create_package_table(lua: &Lua) -> Result
{ // Set `package.loaders` let loaders = lua.create_sequence_from([lua.create_function(lua_loader)?])?; package.raw_set("loaders", loaders.clone())?; + #[cfg(unix)] + { + loaders.push(lua.create_function(dylib_loader)?)?; + let loaded_dylibs = LoadedDylibs(FxHashMap::default()); + lua.set_app_data(loaded_dylibs); + } lua.set_named_registry_value("_LOADERS", loaders)?; Ok(package) @@ -225,3 +278,54 @@ fn lua_loader(lua: &Lua, modname: StdString) -> Result { Ok(Value::Nil) } + +/// Tries to load a dynamic library +#[cfg(unix)] +fn dylib_loader(lua: &Lua, modname: StdString) -> Result { + let package = { + let key = lua.app_data_ref::().unwrap(); + lua.registry_value::
(&key.0) + }?; + let search_cpath = package.get::<_, StdString>("cpath").unwrap_or_default(); + + let find_symbol = |lib: &Library| unsafe { + if let Ok(entry) = lib.get::(format!("luaopen_{modname}\0").as_bytes()) + { + return lua.create_c_function(*entry).map(Value::Function); + } + // Try all in one mode + if let Ok(entry) = lib.get::( + format!("luaopen_{}\0", modname.replace('.', "_")).as_bytes(), + ) { + return lua.create_c_function(*entry).map(Value::Function); + } + "cannot find module entrypoint".into_lua(lua) + }; + + if let Some(file_path) = package_searchpath(&modname, &search_cpath, true) { + let file_path = file_path.canonicalize()?; + // Load the library and check for symbol + unsafe { + // Check if it's already loaded + if let Some(lib) = lua.app_data_ref::().unwrap().get(&file_path) { + return find_symbol(lib); + } + if let Ok(lib) = Library::new(&file_path) { + // Check version + let mod_version = lib.get::<*const u32>(b"MLUA_LUAU_ABI_VERSION"); + let mod_version = mod_version.map(|v| **v).unwrap_or_default(); + if mod_version != TARGET_MLUA_LUAU_ABI_VERSION { + let err = format!("wrong module ABI version (expected {TARGET_MLUA_LUAU_ABI_VERSION}, got {mod_version})"); + return err.into_lua(lua); + } + let symbol = find_symbol(&lib); + lua.app_data_mut::() + .unwrap() + .insert(file_path, lib); + return symbol; + } + } + } + + Ok(Value::Nil) +} diff --git a/src/memory.rs b/src/memory.rs index ed685147..672a8647 100644 --- a/src/memory.rs +++ b/src/memory.rs @@ -4,6 +4,7 @@ use std::ptr; pub(crate) static ALLOCATOR: ffi::lua_Alloc = allocator; +#[repr(C)] #[derive(Default)] pub(crate) struct MemoryState { used_memory: isize, diff --git a/tests/module/Cargo.toml b/tests/module/Cargo.toml index c2e0da8d..f107ad73 100644 --- a/tests/module/Cargo.toml +++ b/tests/module/Cargo.toml @@ -18,6 +18,7 @@ lua53 = ["mlua/lua53"] lua52 = ["mlua/lua52"] lua51 = ["mlua/lua51"] luajit = ["mlua/luajit"] +luau = ["mlua/luau"] [dependencies] mlua = { path = "../..", features = ["module"] } diff --git a/tests/module/loader/Cargo.toml b/tests/module/loader/Cargo.toml index b51f002c..64b196ff 100644 --- a/tests/module/loader/Cargo.toml +++ b/tests/module/loader/Cargo.toml @@ -10,6 +10,7 @@ lua53 = ["mlua/lua53"] lua52 = ["mlua/lua52"] lua51 = ["mlua/lua51"] luajit = ["mlua/luajit"] +luau = ["mlua/luau"] vendored = ["mlua/vendored"] [dependencies] diff --git a/tests/module/loader/tests/load.rs b/tests/module/loader/tests/load.rs index d06ece4f..25f85ab0 100644 --- a/tests/module/loader/tests/load.rs +++ b/tests/module/loader/tests/load.rs @@ -4,7 +4,7 @@ use std::path::PathBuf; use mlua::{Lua, Result}; #[test] -fn test_module() -> Result<()> { +fn test_module_simple() -> Result<()> { let lua = make_lua()?; lua.load( r#"