Skip to content

Commit

Permalink
Add HIP bindings patch version 42131
Browse files Browse the repository at this point in the history
Fix CI
  • Loading branch information
syl20bnr committed Dec 10, 2024
1 parent c78c696 commit 75bcfaf
Show file tree
Hide file tree
Showing 6 changed files with 8,215 additions and 23 deletions.
20 changes: 10 additions & 10 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 11 additions & 3 deletions crates/build-script/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::path::Path;
use std::fmt;
use std::path::Path;

pub struct Version {
pub major: u8,
Expand Down Expand Up @@ -41,7 +41,11 @@ pub fn get_rocm_system_version(rocm_path: impl AsRef<Path>) -> std::io::Result<V
.parse::<u32>()
.expect("Invalid rocm_version.h file structure: Couldn't parse patch version.");

Ok(Version { major, minor, patch })
Ok(Version {
major,
minor,
patch,
})
}

/// Reads the HIP header inside the rocm folder that contains the HIP specific version
Expand Down Expand Up @@ -72,5 +76,9 @@ pub fn get_hip_system_version(rocm_path: impl AsRef<Path>) -> std::io::Result<Ve
.parse::<u32>()
.expect("Invalid hip_version.h file structure: Couldn't parse patch version.");

Ok(Version { major, minor, patch })
Ok(Version {
major,
minor,
patch,
})
}
24 changes: 16 additions & 8 deletions crates/cubecl-hip-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,11 @@ fn check_rocm_version(rocm_path: impl AsRef<Path>) -> std::io::Result<bool> {
let rocm_system_version = get_rocm_system_version(rocm_path)?;
let rocm_feature_version = get_rocm_feature_version();

if rocm_system_version.major == rocm_feature_version.major {
let mismatches = match (rocm_system_version.minor == rocm_feature_version.minor, rocm_system_version.patch == rocm_feature_version.patch) {
if rocm_system_version.major == rocm_feature_version.major {
let mismatches = match (
rocm_system_version.minor == rocm_feature_version.minor,
rocm_system_version.patch == rocm_feature_version.patch,
) {
// Perfect match, don't need a warning
(true, true) => return Ok(true),
(true, false) => "Patch",
Expand All @@ -59,7 +62,11 @@ fn get_rocm_feature_version() -> Version {
parts[1].parse::<u8>(),
parts[2].parse::<u32>(),
) {
return Version {major, minor, patch};
return Version {
major,
minor,
patch,
};
}
}
}
Expand All @@ -85,7 +92,6 @@ fn get_hip_feature_patch_version() -> u32 {
}

fn main() {

println!("cargo::rerun-if-changed=build.rs");
println!("cargo::rerun-if-env-changed=CUBECL_ROCM_PATH");
println!("cargo::rerun-if-env-changed=ROCM_PATH");
Expand All @@ -110,11 +116,14 @@ fn main() {

if let Some(valid_rocm_path) = rocm_path {
ensure_single_rocm_hip_feature_set();
// verify HIP compatbility
let Version {patch: hip_system_patch_version, ..} = get_hip_system_version(valid_rocm_path).unwrap();
// verify HIP compatibility
let Version {
patch: hip_system_patch_version,
..
} = get_hip_system_version(valid_rocm_path).unwrap();
let hip_feature_patch_version = get_hip_feature_patch_version();
if hip_system_patch_version != hip_feature_patch_version {
panic!("Imcompatible HIP bindings found. Expected to find HIP patch version {hip_feature_patch_version}, but found HIP patch version {hip_system_patch_version}.");
panic!("Incompatible HIP bindings found. Expected to find HIP patch version {hip_feature_patch_version}, but found HIP patch version {hip_system_patch_version}.");
}

println!("cargo::rustc-link-lib=dylib=hiprtc");
Expand All @@ -127,4 +136,3 @@ fn main() {
panic!("HIP headers not found in any of the directories set in CUBECL_ROCM_PATH, ROCM_PATH or HIP_PATH environment variable.");
}
}

Loading

0 comments on commit 75bcfaf

Please sign in to comment.