Skip to content

Commit

Permalink
Refactor version management
Browse files Browse the repository at this point in the history
We now check for the ROCm global version first, if there is a mismatch we
either panic (major version mismatch) or print a warning during the build
(minor or patch version mismatch).

Then we check for bindings compability using the patch number of the HIP
library.
  • Loading branch information
syl20bnr committed Dec 9, 2024
1 parent e889bad commit c78c696
Show file tree
Hide file tree
Showing 10 changed files with 176 additions and 8,208 deletions.
5 changes: 5 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions crates/build-script/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[package]
authors = ["Tracel Technologies Inc."]
name = "build-script"
edition.workspace = true
license.workspace = true
readme.workspace = true
version.workspace = true
rust-version = "1.81"

76 changes: 76 additions & 0 deletions crates/build-script/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
use std::path::Path;
use std::fmt;

pub struct Version {
pub major: u8,
pub minor: u8,
pub patch: u32,
}

impl fmt::Display for Version {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}.{}.{}", self.major, self.minor, self.patch)
}
}

/// Reads the header inside the rocm folder that contains the ROCm global version
pub fn get_rocm_system_version(rocm_path: impl AsRef<Path>) -> std::io::Result<Version> {
let version_path = rocm_path.as_ref().join("include/rocm-core/rocm_version.h");
let version_file = std::fs::read_to_string(version_path)?;
let version_lines = version_file.lines().collect::<Vec<_>>();

let major = version_lines
.iter()
.find_map(|line| line.strip_prefix("#define ROCM_VERSION_MAJOR "))
.expect("Invalid rocm_version.h file structure: Major version line not found.")
.trim()
.parse::<u8>()
.expect("Invalid rocm_version.h file structure: Couldn't parse major version.");
let minor = version_lines
.iter()
.find_map(|line| line.strip_prefix("#define ROCM_VERSION_MINOR "))
.expect("Invalid rocm_version.h file structure: Minor version line not found.")
.trim()
.parse::<u8>()
.expect("Invalid rocm_version.h file structure: Couldn't parse minor version.");
let patch = version_lines
.iter()
.find_map(|line| line.strip_prefix("#define ROCM_VERSION_PATCH "))
.expect("Invalid rocm_version.h file structure: Patch version line not found.")
.trim()
.parse::<u32>()
.expect("Invalid rocm_version.h file structure: Couldn't parse patch version.");

Ok(Version { major, minor, patch })
}

/// Reads the HIP header inside the rocm folder that contains the HIP specific version
pub fn get_hip_system_version(rocm_path: impl AsRef<Path>) -> std::io::Result<Version> {
let version_path = rocm_path.as_ref().join("include/hip/hip_version.h");
let version_file = std::fs::read_to_string(version_path)?;
let version_lines = version_file.lines().collect::<Vec<_>>();

let major = version_lines
.iter()
.find_map(|line| line.strip_prefix("#define HIP_VERSION_MAJOR "))
.expect("Invalid hip_version.h file structure: Major version line not found.")
.trim()
.parse::<u8>()
.expect("Invalid hip_version.h file structure: Couldn't parse major version.");
let minor = version_lines
.iter()
.find_map(|line| line.strip_prefix("#define HIP_VERSION_MINOR "))
.expect("Invalid hip_version.h file structure: Minor version line not found.")
.trim()
.parse::<u8>()
.expect("Invalid hip_version.h file structure: Couldn't parse minor version.");
let patch = version_lines
.iter()
.find_map(|line| line.strip_prefix("#define HIP_VERSION_PATCH "))
.expect("Invalid hip_version.h file structure: Patch version line not found.")
.trim()
.parse::<u32>()
.expect("Invalid hip_version.h file structure: Couldn't parse patch version.");

Ok(Version { major, minor, patch })
}
11 changes: 9 additions & 2 deletions crates/cubecl-hip-sys/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,15 @@ rust-version = "1.81"

[features]
default = ["rocm__6_2_2"]
rocm__6_2_2 = []
rocm__6_2_4 = []

# ROCm versions
rocm__6_2_2 = [ "hip_41134" ]
rocm__6_2_4 = [ "hip_41134" ]
rocm__6_3_0 = [ "hip_42131" ]

# HIP versions
hip_41134 = []
hip_42131 = []

[dependencies]
libc = { workspace = true }
Expand Down
161 changes: 64 additions & 97 deletions crates/cubecl-hip-sys/build.rs
Original file line number Diff line number Diff line change
@@ -1,60 +1,54 @@
use std::path::Path;
use std::{env, io};
use std::env;

const ROCM_FEATURE_PREFIX: &str = "CARGO_FEATURE_ROCM__";
const ROCM_HIP_FEATURE_PREFIX: &str = "CARGO_FEATURE_HIP_";

/// Reads a header inside the rocm folder, that contains the lib's version
fn get_system_hip_version(rocm_path: impl AsRef<Path>) -> std::io::Result<(u8, u8, u32)> {
let version_path = rocm_path.as_ref().join("include/hip/hip_version.h");
let version_file = std::fs::read_to_string(version_path)?;
let version_lines = version_file.lines().collect::<Vec<_>>();
include!("../build-script/src/lib.rs");

let system_major = version_lines
.iter()
.find_map(|line| line.strip_prefix("#define HIP_VERSION_MAJOR "))
.expect("Invalid hip_version.h file structure: Major version line not found")
.parse::<u8>()
.expect("Invalid hip_version.h file structure: Couldn't parse major version");
let system_minor = version_lines
.iter()
.find_map(|line| line.strip_prefix("#define HIP_VERSION_MINOR "))
.expect("Invalid hip_version.h file structure: Minor version line not found")
.parse::<u8>()
.expect("Invalid hip_version.h file structure: Couldn't parse minor version");
let system_patch = version_lines
.iter()
.find_map(|line| line.strip_prefix("#define HIP_VERSION_PATCH "))
.expect("Invalid hip_version.h file structure: Patch version line not found")
.parse::<u32>()
.expect("Invalid hip_version.h file structure: Couldn't parse patch version");
let release_patch = hip_header_patch_number_to_release_patch_number(system_patch);
if release_patch.is_none() {
println!("cargo::warning=Unknown release version for patch version {system_patch}. This patch does not correspond to an official release patch.");
/// Make sure that at least one and only one hip feature is set
fn ensure_single_rocm_hip_feature_set() {
let mut enabled_features = Vec::new();

for (key, value) in env::vars() {
if key.starts_with(ROCM_HIP_FEATURE_PREFIX) && value == "1" {
enabled_features.push(format!(
"rocm__{}",
key.strip_prefix(ROCM_HIP_FEATURE_PREFIX).unwrap()
));
}
}

Ok((
system_major,
system_minor,
release_patch.unwrap_or(system_patch),
))
match enabled_features.len() {
1 => {}
0 => panic!("No ROCm HIP feature is enabled. One ROCm HIP feature must be set."),
_ => panic!(
"Multiple ROCm HIP features are enabled: {:?}. Only one can be set.",
enabled_features
),
}
}

/// The official patch number of a ROCm release is not the same of the patch number
/// in the header files. In the header files the patch number is a monotonic build
/// that changes only when there are actual changes in the HIP libraries.
/// This function maps the header patch number to their official latest release number.
/// For instance if both versions 6.2.2 and 6.2.4 have the same patch version in their
/// header file then this function will return 4.
fn hip_header_patch_number_to_release_patch_number(number: u32) -> Option<u32> {
match number {
41134 => Some(4), // 6.2.4
42131 => Some(0), // 6.3.0
_ => None,
/// Checks if the version inside `rocm_path` matches crate version
fn check_rocm_version(rocm_path: impl AsRef<Path>) -> std::io::Result<bool> {
let rocm_system_version = get_rocm_system_version(rocm_path)?;
let rocm_feature_version = get_rocm_feature_version();

if rocm_system_version.major == rocm_feature_version.major {
let mismatches = match (rocm_system_version.minor == rocm_feature_version.minor, rocm_system_version.patch == rocm_feature_version.patch) {
// Perfect match, don't need a warning
(true, true) => return Ok(true),
(true, false) => "Patch",
(false, _) => "Minor",
};
println!("cargo::warning=ROCm {mismatches} version mismatch between cubecl-hip-sys expected version ({rocm_feature_version}) and found ROCm version on the system ({rocm_system_version}). Build process might fail due to incompatible library bindings.");
Ok(true)
} else {
Ok(false)
}
}

/// Return the ROCm version corresponding to the enabled feature
fn get_rocm_feature_version() -> io::Result<(u8, u8, u32)> {
/// Return the ROCm version corresponding to the enabled rocm__<version> feature
fn get_rocm_feature_version() -> Version {
for (key, value) in env::vars() {
if key.starts_with(ROCM_FEATURE_PREFIX) && value == "1" {
if let Some(version) = key.strip_prefix(ROCM_FEATURE_PREFIX) {
Expand All @@ -65,66 +59,32 @@ fn get_rocm_feature_version() -> io::Result<(u8, u8, u32)> {
parts[1].parse::<u8>(),
parts[2].parse::<u32>(),
) {
return Ok((major, minor, patch));
return Version {major, minor, patch};
}
}
}
}
}

Err(io::Error::new(
io::ErrorKind::NotFound,
"No valid ROCm feature version found. One 'rocm__<version>' feature must be set. For instance for ROCm 6.2.2 the feature is rocm__6_2_2.",
))
panic!("No valid ROCm feature version found. One 'rocm__<version>' feature must be set. For instance for ROCm 6.2.2 the feature is rocm__6_2_2.")
}

/// Make sure that feature is set correctly
fn ensure_single_rocm_feature_set() {
let mut enabled_features = Vec::new();

/// Return the ROCm HIP patch version corresponding to the enabled hip_<patch_version> feature
fn get_hip_feature_patch_version() -> u32 {
for (key, value) in env::vars() {
if key.starts_with(ROCM_FEATURE_PREFIX) && value == "1" {
enabled_features.push(format!(
"rocm__{}",
key.strip_prefix(ROCM_FEATURE_PREFIX).unwrap()
));
if key.starts_with(ROCM_HIP_FEATURE_PREFIX) && value == "1" {
if let Some(patch) = key.strip_prefix(ROCM_HIP_FEATURE_PREFIX) {
if let Ok(patch) = patch.parse::<u32>() {
return patch;
}
}
}
}

match enabled_features.len() {
1 => {}
0 => panic!("No ROCm version features are enabled. One ROCm version feature must be set."),
_ => panic!(
"Multiple ROCm version features are enabled: {:?}. Only one can be set.",
enabled_features
),
}
}

/// Checks if the version inside `rocm_path` matches crate version
fn check_version(rocm_path: impl AsRef<Path>) -> std::io::Result<bool> {
let (system_major, system_minor, system_patch) = get_system_hip_version(rocm_path)?;
let (crate_major, crate_minor, crate_patch) = get_rocm_feature_version()?;

if crate_major == system_major {
let mismatches = match (crate_minor == system_minor, crate_patch == system_patch) {
// Perfect match, don't need a warning
(true, true) => return Ok(true),
(false, true) => "Minor",
(true, false) => "Patch",
(false, false) => "Both minor and patch",
};
println!("cargo::warning={mismatches} version mismatch between cubecl-hip-sys bindings and system HIP version. Want {}, but found {}",
[crate_major, crate_minor, crate_patch as u8].map(|el| el.to_string()).join("."),
[system_major, system_minor, system_patch as u8].map(|el| el.to_string()).join("."));
Ok(true)
} else {
Ok(false)
}
panic!("No valid ROCm HIP feature found. One 'hip_<patch>' feature must be set.")
}

fn main() {
ensure_single_rocm_feature_set();

println!("cargo::rerun-if-changed=build.rs");
println!("cargo::rerun-if-env-changed=CUBECL_ROCM_PATH");
Expand All @@ -146,18 +106,25 @@ fn main() {
})
.peekable();
let have_candidates = rocm_path_candidates.peek().is_some();
let rocm_path = rocm_path_candidates.find(|path| check_version(path).unwrap_or_default());
let rocm_path = rocm_path_candidates.find(|path| check_rocm_version(path).unwrap_or_default());

if let Some(valid_rocm_path) = rocm_path {
ensure_single_rocm_hip_feature_set();
// verify HIP compatbility
let Version {patch: hip_system_patch_version, ..} = get_hip_system_version(valid_rocm_path).unwrap();
let hip_feature_patch_version = get_hip_feature_patch_version();
if hip_system_patch_version != hip_feature_patch_version {
panic!("Imcompatible HIP bindings found. Expected to find HIP patch version {hip_feature_patch_version}, but found HIP patch version {hip_system_patch_version}.");
}

println!("cargo::rustc-link-lib=dylib=hiprtc");
println!("cargo::rustc-link-lib=dylib=amdhip64");
println!("cargo::rustc-link-search=native={}/lib", valid_rocm_path);
} else if have_candidates {
panic!(
"None of the found ROCm installations match crate version {}",
env!("CARGO_PKG_VERSION")
);
let rocm_feature_version = get_rocm_feature_version();
panic!("None of the found ROCm installations match version {rocm_feature_version}.");
} else if paths.len() > 1 {
panic!("HIP headers not found in any of the defined CUBECL_ROCM_PATH, ROCM_PATH or HIP_PATH directories.");
panic!("HIP headers not found in any of the directories set in CUBECL_ROCM_PATH, ROCM_PATH or HIP_PATH environment variable.");
}
}

Loading

0 comments on commit c78c696

Please sign in to comment.