Skip to content

Commit

Permalink
Improve DLL filtering on Windows (#170)
Browse files Browse the repository at this point in the history
  • Loading branch information
KyleMayes committed May 29, 2024
1 parent d16b874 commit 74e4c3a
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 18 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
## [1.8.1] - UNRELEASED

### Fixed
- Improve DLL search on Windows to take target architecture into account (e.g., ARM64 vs x86-64)

## [1.8.0] - 2024-05-26

### Changed
Expand Down
29 changes: 24 additions & 5 deletions build/dynamic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ fn parse_elf_header(path: &Path) -> io::Result<u8> {
}
}

/// Extracts the magic number from the PE header in a shared library.
fn parse_pe_header(path: &Path) -> io::Result<u16> {
/// Extracts the magic number and machine type from the PE header in a shared library.
fn parse_pe_header(path: &Path) -> io::Result<(u16, u16)> {
let mut file = File::open(path)?;

// Extract the header offset.
Expand All @@ -45,7 +45,15 @@ fn parse_pe_header(path: &Path) -> io::Result<u16> {
let mut buffer = [0; 2];
file.seek(SeekFrom::Current(20))?;
file.read_exact(&mut buffer)?;
Ok(u16::from_le_bytes(buffer))
let magic_number = u16::from_le_bytes(buffer);

// Extract the machine type.
let mut buffer = [0; 2];
file.seek(SeekFrom::Current(-22))?;
file.read_exact(&mut buffer)?;
let machine_type = u16::from_le_bytes(buffer);

return Ok((magic_number, machine_type));
}

/// Checks that a `libclang` shared library matches the target platform.
Expand All @@ -63,7 +71,7 @@ fn validate_library(path: &Path) -> Result<(), String> {

Ok(())
} else if target_os!("windows") {
let magic = parse_pe_header(path).map_err(|e| e.to_string())?;
let (magic, machine_type) = parse_pe_header(path).map_err(|e| e.to_string())?;

if target_pointer_width!("32") && magic != 267 {
return Err("invalid DLL (64-bit)".into());
Expand All @@ -73,7 +81,18 @@ fn validate_library(path: &Path) -> Result<(), String> {
return Err("invalid DLL (32-bit)".into());
}

Ok(())
let arch_mismatch = match machine_type {
0x014C if !target_arch!("x86") => Some("x86"),
0x8664 if !target_arch!("x86_64") => Some("x86-64"),
0xAA64 if !target_arch!("aarch64") => Some("ARM64"),
_ => None,
};

if let Some(arch) = arch_mismatch {
Err(format!("invalid DLL ({arch})"))
} else {
Ok(())
}
} else {
Ok(())
}
Expand Down
11 changes: 11 additions & 0 deletions build/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,17 @@ macro_rules! target_os {
};
}

macro_rules! target_arch {
($arch:expr) => {
if cfg!(test) && ::std::env::var("_CLANG_SYS_TEST").is_ok() {
let var = ::std::env::var("_CLANG_SYS_TEST_ARCH");
var.map_or(false, |v| v == $arch)
} else {
cfg!(target_arch = $arch)
}
};
}

macro_rules! target_pointer_width {
($pointer_width:expr) => {
if cfg!(test) && ::std::env::var("_CLANG_SYS_TEST").is_ok() {
Expand Down
102 changes: 89 additions & 13 deletions tests/build.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#![allow(dead_code)]

use core::fmt;
use std::collections::HashMap;
use std::env;
use std::fs;
Expand All @@ -26,9 +27,38 @@ struct RunCommandMock {
responses: HashMap<Vec<String>, String>,
}


#[derive(Copy, Clone, Debug)]
enum Arch {
ARM64,
X86,
X86_64,
}

impl Arch {
fn pe_machine_type(self) -> u16 {
match self {
Arch::ARM64 => 0xAA64,
Arch::X86 => 0x014C,
Arch::X86_64 => 0x8664,
}
}
}

impl fmt::Display for Arch {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Arch::ARM64 => write!(f, "aarch64"),
Arch::X86 => write!(f, "x86"),
Arch::X86_64 => write!(f, "x86_64"),
}
}
}

#[derive(Debug)]
struct Env {
os: String,
arch: Arch,
pointer_width: String,
env: Option<String>,
vars: HashMap<String, (Option<String>, Option<String>)>,
Expand All @@ -39,9 +69,10 @@ struct Env {
}

impl Env {
fn new(os: &str, pointer_width: &str) -> Self {
fn new(os: &str, arch: Arch, pointer_width: &str) -> Self {
Env {
os: os.into(),
arch,
pointer_width: pointer_width.into(),
env: None,
vars: HashMap::new(),
Expand Down Expand Up @@ -84,11 +115,12 @@ impl Env {
self
}

fn dll(self, path: &str, pointer_width: &str) -> Self {
fn dll(self, path: &str, arch: Arch, pointer_width: &str) -> Self {
// PE header.
let mut contents = [0; 64];
contents[0x3C..0x3C + 4].copy_from_slice(&i32::to_le_bytes(10));
contents[10..14].copy_from_slice(&[b'P', b'E', 0, 0]);
contents[14..16].copy_from_slice(&u16::to_le_bytes(arch.pe_machine_type()));
let magic = if pointer_width == "64" { 523 } else { 267 };
contents[34..36].copy_from_slice(&u16::to_le_bytes(magic));

Expand Down Expand Up @@ -117,6 +149,7 @@ impl Env {
fn enable(self) -> Self {
env::set_var("_CLANG_SYS_TEST", "yep");
env::set_var("_CLANG_SYS_TEST_OS", &self.os);
env::set_var("_CLANG_SYS_TEST_ARCH", &format!("{}", self.arch));
env::set_var("_CLANG_SYS_TEST_POINTER_WIDTH", &self.pointer_width);
if let Some(env) = &self.env {
env::set_var("_CLANG_SYS_TEST_ENV", env);
Expand Down Expand Up @@ -155,6 +188,7 @@ impl Drop for Env {
fn drop(&mut self) {
env::remove_var("_CLANG_SYS_TEST");
env::remove_var("_CLANG_SYS_TEST_OS");
env::remove_var("_CLANG_SYS_TEST_ARCH");
env::remove_var("_CLANG_SYS_TEST_POINTER_WIDTH");
env::remove_var("_CLANG_SYS_TEST_ENV");

Expand Down Expand Up @@ -185,17 +219,31 @@ fn test_all() {
test_windows_bin_sibling();
test_windows_mingw_gnu();
test_windows_mingw_msvc();
test_windows_arm64_on_x86_64();
test_windows_x86_64_on_arm64();
}
}

macro_rules! assert_error {
($result:expr, $contents:expr $(,)?) => {
if let Err(error) = $result {
if !error.contains($contents) {
panic!("expected error to contain {:?}, received: {error:?}", $contents);
}
} else {
panic!("expected error, received: {:?}", $result);
}
};
}

//================================================
// Dynamic
//================================================

// Linux -----------------------------------------

fn test_linux_directory_preference() {
let _env = Env::new("linux", "64")
let _env = Env::new("linux", Arch::X86_64, "64")
.so("usr/lib/libclang.so.1", "64")
.so("usr/local/lib/libclang.so.1", "64")
.enable();
Expand All @@ -207,7 +255,7 @@ fn test_linux_directory_preference() {
}

fn test_linux_version_preference() {
let _env = Env::new("linux", "64")
let _env = Env::new("linux", Arch::X86_64, "64")
.so("usr/lib/libclang-3.so", "64")
.so("usr/lib/libclang-3.5.so", "64")
.so("usr/lib/libclang-3.5.0.so", "64")
Expand All @@ -220,7 +268,7 @@ fn test_linux_version_preference() {
}

fn test_linux_directory_and_version_preference() {
let _env = Env::new("linux", "64")
let _env = Env::new("linux", Arch::X86_64, "64")
.so("usr/local/llvm/lib/libclang-3.so", "64")
.so("usr/local/lib/libclang-3.5.so", "64")
.so("usr/lib/libclang-3.5.0.so", "64")
Expand All @@ -236,9 +284,9 @@ fn test_linux_directory_and_version_preference() {

#[cfg(target_os = "windows")]
fn test_windows_bin_sibling() {
let _env = Env::new("windows", "64")
let _env = Env::new("windows", Arch::X86_64, "64")
.dir("Program Files\\LLVM\\lib")
.dll("Program Files\\LLVM\\bin\\libclang.dll", "64")
.dll("Program Files\\LLVM\\bin\\libclang.dll", Arch::X86_64, "64")
.enable();

assert_eq!(
Expand All @@ -249,12 +297,12 @@ fn test_windows_bin_sibling() {

#[cfg(target_os = "windows")]
fn test_windows_mingw_gnu() {
let _env = Env::new("windows", "64")
let _env = Env::new("windows", Arch::X86_64, "64")
.env("gnu")
.dir("MSYS\\MinGW\\lib")
.dll("MSYS\\MinGW\\bin\\clang.dll", "64")
.dll("MSYS\\MinGW\\bin\\clang.dll", Arch::X86_64, "64")
.dir("Program Files\\LLVM\\lib")
.dll("Program Files\\LLVM\\bin\\libclang.dll", "64")
.dll("Program Files\\LLVM\\bin\\libclang.dll", Arch::X86_64, "64")
.enable();

assert_eq!(
Expand All @@ -265,16 +313,44 @@ fn test_windows_mingw_gnu() {

#[cfg(target_os = "windows")]
fn test_windows_mingw_msvc() {
let _env = Env::new("windows", "64")
let _env = Env::new("windows", Arch::X86_64, "64")
.env("msvc")
.dir("MSYS\\MinGW\\lib")
.dll("MSYS\\MinGW\\bin\\clang.dll", "64")
.dll("MSYS\\MinGW\\bin\\clang.dll", Arch::X86_64, "64")
.dir("Program Files\\LLVM\\lib")
.dll("Program Files\\LLVM\\bin\\libclang.dll", "64")
.dll("Program Files\\LLVM\\bin\\libclang.dll", Arch::X86_64, "64")
.enable();

assert_eq!(
dynamic::find(true),
Ok(("Program Files\\LLVM\\bin".into(), "libclang.dll".into())),
);
}

#[cfg(target_os = "windows")]
fn test_windows_arm64_on_x86_64() {
let _env = Env::new("windows", Arch::X86_64, "64")
.env("msvc")
.dir("Program Files\\LLVM\\lib")
.dll("Program Files\\LLVM\\bin\\libclang.dll", Arch::ARM64, "64")
.enable();

assert_error!(
dynamic::find(true),
"invalid: [(Program Files\\LLVM\\bin\\libclang.dll: invalid DLL (ARM64)",
);
}

#[cfg(target_os = "windows")]
fn test_windows_x86_64_on_arm64() {
let _env = Env::new("windows", Arch::ARM64, "64")
.env("msvc")
.dir("Program Files\\LLVM\\lib")
.dll("Program Files\\LLVM\\bin\\libclang.dll", Arch::X86_64, "64")
.enable();

assert_error!(
dynamic::find(true),
"invalid: [(Program Files\\LLVM\\bin\\libclang.dll: invalid DLL (x86-64)",
);
}

0 comments on commit 74e4c3a

Please sign in to comment.