Skip to content

Commit c78c696

Browse files
committed
Refactor version management
We now check for the ROCm global version first, if there is a mismatch we either panic (major version mismatch) or print a warning during the build (minor or patch version mismatch). Then we check for bindings compability using the patch number of the HIP library.
1 parent e889bad commit c78c696

File tree

10 files changed

+176
-8208
lines changed

10 files changed

+176
-8208
lines changed

Cargo.lock

Lines changed: 5 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/build-script/Cargo.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
[package]
2+
authors = ["Tracel Technologies Inc."]
3+
name = "build-script"
4+
edition.workspace = true
5+
license.workspace = true
6+
readme.workspace = true
7+
version.workspace = true
8+
rust-version = "1.81"
9+

crates/build-script/src/lib.rs

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
use std::path::Path;
2+
use std::fmt;
3+
4+
pub struct Version {
5+
pub major: u8,
6+
pub minor: u8,
7+
pub patch: u32,
8+
}
9+
10+
impl fmt::Display for Version {
11+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
12+
write!(f, "{}.{}.{}", self.major, self.minor, self.patch)
13+
}
14+
}
15+
16+
/// Reads the header inside the rocm folder that contains the ROCm global version
17+
pub fn get_rocm_system_version(rocm_path: impl AsRef<Path>) -> std::io::Result<Version> {
18+
let version_path = rocm_path.as_ref().join("include/rocm-core/rocm_version.h");
19+
let version_file = std::fs::read_to_string(version_path)?;
20+
let version_lines = version_file.lines().collect::<Vec<_>>();
21+
22+
let major = version_lines
23+
.iter()
24+
.find_map(|line| line.strip_prefix("#define ROCM_VERSION_MAJOR "))
25+
.expect("Invalid rocm_version.h file structure: Major version line not found.")
26+
.trim()
27+
.parse::<u8>()
28+
.expect("Invalid rocm_version.h file structure: Couldn't parse major version.");
29+
let minor = version_lines
30+
.iter()
31+
.find_map(|line| line.strip_prefix("#define ROCM_VERSION_MINOR "))
32+
.expect("Invalid rocm_version.h file structure: Minor version line not found.")
33+
.trim()
34+
.parse::<u8>()
35+
.expect("Invalid rocm_version.h file structure: Couldn't parse minor version.");
36+
let patch = version_lines
37+
.iter()
38+
.find_map(|line| line.strip_prefix("#define ROCM_VERSION_PATCH "))
39+
.expect("Invalid rocm_version.h file structure: Patch version line not found.")
40+
.trim()
41+
.parse::<u32>()
42+
.expect("Invalid rocm_version.h file structure: Couldn't parse patch version.");
43+
44+
Ok(Version { major, minor, patch })
45+
}
46+
47+
/// Reads the HIP header inside the rocm folder that contains the HIP specific version
48+
pub fn get_hip_system_version(rocm_path: impl AsRef<Path>) -> std::io::Result<Version> {
49+
let version_path = rocm_path.as_ref().join("include/hip/hip_version.h");
50+
let version_file = std::fs::read_to_string(version_path)?;
51+
let version_lines = version_file.lines().collect::<Vec<_>>();
52+
53+
let major = version_lines
54+
.iter()
55+
.find_map(|line| line.strip_prefix("#define HIP_VERSION_MAJOR "))
56+
.expect("Invalid hip_version.h file structure: Major version line not found.")
57+
.trim()
58+
.parse::<u8>()
59+
.expect("Invalid hip_version.h file structure: Couldn't parse major version.");
60+
let minor = version_lines
61+
.iter()
62+
.find_map(|line| line.strip_prefix("#define HIP_VERSION_MINOR "))
63+
.expect("Invalid hip_version.h file structure: Minor version line not found.")
64+
.trim()
65+
.parse::<u8>()
66+
.expect("Invalid hip_version.h file structure: Couldn't parse minor version.");
67+
let patch = version_lines
68+
.iter()
69+
.find_map(|line| line.strip_prefix("#define HIP_VERSION_PATCH "))
70+
.expect("Invalid hip_version.h file structure: Patch version line not found.")
71+
.trim()
72+
.parse::<u32>()
73+
.expect("Invalid hip_version.h file structure: Couldn't parse patch version.");
74+
75+
Ok(Version { major, minor, patch })
76+
}

crates/cubecl-hip-sys/Cargo.toml

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,15 @@ rust-version = "1.81"
1313

1414
[features]
1515
default = ["rocm__6_2_2"]
16-
rocm__6_2_2 = []
17-
rocm__6_2_4 = []
16+
17+
# ROCm versions
18+
rocm__6_2_2 = [ "hip_41134" ]
19+
rocm__6_2_4 = [ "hip_41134" ]
20+
rocm__6_3_0 = [ "hip_42131" ]
21+
22+
# HIP versions
23+
hip_41134 = []
24+
hip_42131 = []
1825

1926
[dependencies]
2027
libc = { workspace = true }

crates/cubecl-hip-sys/build.rs

Lines changed: 64 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,54 @@
1-
use std::path::Path;
2-
use std::{env, io};
1+
use std::env;
32

43
const ROCM_FEATURE_PREFIX: &str = "CARGO_FEATURE_ROCM__";
4+
const ROCM_HIP_FEATURE_PREFIX: &str = "CARGO_FEATURE_HIP_";
55

6-
/// Reads a header inside the rocm folder, that contains the lib's version
7-
fn get_system_hip_version(rocm_path: impl AsRef<Path>) -> std::io::Result<(u8, u8, u32)> {
8-
let version_path = rocm_path.as_ref().join("include/hip/hip_version.h");
9-
let version_file = std::fs::read_to_string(version_path)?;
10-
let version_lines = version_file.lines().collect::<Vec<_>>();
6+
include!("../build-script/src/lib.rs");
117

12-
let system_major = version_lines
13-
.iter()
14-
.find_map(|line| line.strip_prefix("#define HIP_VERSION_MAJOR "))
15-
.expect("Invalid hip_version.h file structure: Major version line not found")
16-
.parse::<u8>()
17-
.expect("Invalid hip_version.h file structure: Couldn't parse major version");
18-
let system_minor = version_lines
19-
.iter()
20-
.find_map(|line| line.strip_prefix("#define HIP_VERSION_MINOR "))
21-
.expect("Invalid hip_version.h file structure: Minor version line not found")
22-
.parse::<u8>()
23-
.expect("Invalid hip_version.h file structure: Couldn't parse minor version");
24-
let system_patch = version_lines
25-
.iter()
26-
.find_map(|line| line.strip_prefix("#define HIP_VERSION_PATCH "))
27-
.expect("Invalid hip_version.h file structure: Patch version line not found")
28-
.parse::<u32>()
29-
.expect("Invalid hip_version.h file structure: Couldn't parse patch version");
30-
let release_patch = hip_header_patch_number_to_release_patch_number(system_patch);
31-
if release_patch.is_none() {
32-
println!("cargo::warning=Unknown release version for patch version {system_patch}. This patch does not correspond to an official release patch.");
8+
/// Make sure that at least one and only one hip feature is set
9+
fn ensure_single_rocm_hip_feature_set() {
10+
let mut enabled_features = Vec::new();
11+
12+
for (key, value) in env::vars() {
13+
if key.starts_with(ROCM_HIP_FEATURE_PREFIX) && value == "1" {
14+
enabled_features.push(format!(
15+
"rocm__{}",
16+
key.strip_prefix(ROCM_HIP_FEATURE_PREFIX).unwrap()
17+
));
18+
}
3319
}
3420

35-
Ok((
36-
system_major,
37-
system_minor,
38-
release_patch.unwrap_or(system_patch),
39-
))
21+
match enabled_features.len() {
22+
1 => {}
23+
0 => panic!("No ROCm HIP feature is enabled. One ROCm HIP feature must be set."),
24+
_ => panic!(
25+
"Multiple ROCm HIP features are enabled: {:?}. Only one can be set.",
26+
enabled_features
27+
),
28+
}
4029
}
4130

42-
/// The official patch number of a ROCm release is not the same of the patch number
43-
/// in the header files. In the header files the patch number is a monotonic build
44-
/// that changes only when there are actual changes in the HIP libraries.
45-
/// This function maps the header patch number to their official latest release number.
46-
/// For instance if both versions 6.2.2 and 6.2.4 have the same patch version in their
47-
/// header file then this function will return 4.
48-
fn hip_header_patch_number_to_release_patch_number(number: u32) -> Option<u32> {
49-
match number {
50-
41134 => Some(4), // 6.2.4
51-
42131 => Some(0), // 6.3.0
52-
_ => None,
31+
/// Checks if the version inside `rocm_path` matches crate version
32+
fn check_rocm_version(rocm_path: impl AsRef<Path>) -> std::io::Result<bool> {
33+
let rocm_system_version = get_rocm_system_version(rocm_path)?;
34+
let rocm_feature_version = get_rocm_feature_version();
35+
36+
if rocm_system_version.major == rocm_feature_version.major {
37+
let mismatches = match (rocm_system_version.minor == rocm_feature_version.minor, rocm_system_version.patch == rocm_feature_version.patch) {
38+
// Perfect match, don't need a warning
39+
(true, true) => return Ok(true),
40+
(true, false) => "Patch",
41+
(false, _) => "Minor",
42+
};
43+
println!("cargo::warning=ROCm {mismatches} version mismatch between cubecl-hip-sys expected version ({rocm_feature_version}) and found ROCm version on the system ({rocm_system_version}). Build process might fail due to incompatible library bindings.");
44+
Ok(true)
45+
} else {
46+
Ok(false)
5347
}
5448
}
5549

56-
/// Return the ROCm version corresponding to the enabled feature
57-
fn get_rocm_feature_version() -> io::Result<(u8, u8, u32)> {
50+
/// Return the ROCm version corresponding to the enabled rocm__<version> feature
51+
fn get_rocm_feature_version() -> Version {
5852
for (key, value) in env::vars() {
5953
if key.starts_with(ROCM_FEATURE_PREFIX) && value == "1" {
6054
if let Some(version) = key.strip_prefix(ROCM_FEATURE_PREFIX) {
@@ -65,66 +59,32 @@ fn get_rocm_feature_version() -> io::Result<(u8, u8, u32)> {
6559
parts[1].parse::<u8>(),
6660
parts[2].parse::<u32>(),
6761
) {
68-
return Ok((major, minor, patch));
62+
return Version {major, minor, patch};
6963
}
7064
}
7165
}
7266
}
7367
}
7468

75-
Err(io::Error::new(
76-
io::ErrorKind::NotFound,
77-
"No valid ROCm feature version found. One 'rocm__<version>' feature must be set. For instance for ROCm 6.2.2 the feature is rocm__6_2_2.",
78-
))
69+
panic!("No valid ROCm feature version found. One 'rocm__<version>' feature must be set. For instance for ROCm 6.2.2 the feature is rocm__6_2_2.")
7970
}
8071

81-
/// Make sure that feature is set correctly
82-
fn ensure_single_rocm_feature_set() {
83-
let mut enabled_features = Vec::new();
84-
72+
/// Return the ROCm HIP patch version corresponding to the enabled hip_<patch_version> feature
73+
fn get_hip_feature_patch_version() -> u32 {
8574
for (key, value) in env::vars() {
86-
if key.starts_with(ROCM_FEATURE_PREFIX) && value == "1" {
87-
enabled_features.push(format!(
88-
"rocm__{}",
89-
key.strip_prefix(ROCM_FEATURE_PREFIX).unwrap()
90-
));
75+
if key.starts_with(ROCM_HIP_FEATURE_PREFIX) && value == "1" {
76+
if let Some(patch) = key.strip_prefix(ROCM_HIP_FEATURE_PREFIX) {
77+
if let Ok(patch) = patch.parse::<u32>() {
78+
return patch;
79+
}
80+
}
9181
}
9282
}
9383

94-
match enabled_features.len() {
95-
1 => {}
96-
0 => panic!("No ROCm version features are enabled. One ROCm version feature must be set."),
97-
_ => panic!(
98-
"Multiple ROCm version features are enabled: {:?}. Only one can be set.",
99-
enabled_features
100-
),
101-
}
102-
}
103-
104-
/// Checks if the version inside `rocm_path` matches crate version
105-
fn check_version(rocm_path: impl AsRef<Path>) -> std::io::Result<bool> {
106-
let (system_major, system_minor, system_patch) = get_system_hip_version(rocm_path)?;
107-
let (crate_major, crate_minor, crate_patch) = get_rocm_feature_version()?;
108-
109-
if crate_major == system_major {
110-
let mismatches = match (crate_minor == system_minor, crate_patch == system_patch) {
111-
// Perfect match, don't need a warning
112-
(true, true) => return Ok(true),
113-
(false, true) => "Minor",
114-
(true, false) => "Patch",
115-
(false, false) => "Both minor and patch",
116-
};
117-
println!("cargo::warning={mismatches} version mismatch between cubecl-hip-sys bindings and system HIP version. Want {}, but found {}",
118-
[crate_major, crate_minor, crate_patch as u8].map(|el| el.to_string()).join("."),
119-
[system_major, system_minor, system_patch as u8].map(|el| el.to_string()).join("."));
120-
Ok(true)
121-
} else {
122-
Ok(false)
123-
}
84+
panic!("No valid ROCm HIP feature found. One 'hip_<patch>' feature must be set.")
12485
}
12586

12687
fn main() {
127-
ensure_single_rocm_feature_set();
12888

12989
println!("cargo::rerun-if-changed=build.rs");
13090
println!("cargo::rerun-if-env-changed=CUBECL_ROCM_PATH");
@@ -146,18 +106,25 @@ fn main() {
146106
})
147107
.peekable();
148108
let have_candidates = rocm_path_candidates.peek().is_some();
149-
let rocm_path = rocm_path_candidates.find(|path| check_version(path).unwrap_or_default());
109+
let rocm_path = rocm_path_candidates.find(|path| check_rocm_version(path).unwrap_or_default());
150110

151111
if let Some(valid_rocm_path) = rocm_path {
112+
ensure_single_rocm_hip_feature_set();
113+
// verify HIP compatbility
114+
let Version {patch: hip_system_patch_version, ..} = get_hip_system_version(valid_rocm_path).unwrap();
115+
let hip_feature_patch_version = get_hip_feature_patch_version();
116+
if hip_system_patch_version != hip_feature_patch_version {
117+
panic!("Imcompatible HIP bindings found. Expected to find HIP patch version {hip_feature_patch_version}, but found HIP patch version {hip_system_patch_version}.");
118+
}
119+
152120
println!("cargo::rustc-link-lib=dylib=hiprtc");
153121
println!("cargo::rustc-link-lib=dylib=amdhip64");
154122
println!("cargo::rustc-link-search=native={}/lib", valid_rocm_path);
155123
} else if have_candidates {
156-
panic!(
157-
"None of the found ROCm installations match crate version {}",
158-
env!("CARGO_PKG_VERSION")
159-
);
124+
let rocm_feature_version = get_rocm_feature_version();
125+
panic!("None of the found ROCm installations match version {rocm_feature_version}.");
160126
} else if paths.len() > 1 {
161-
panic!("HIP headers not found in any of the defined CUBECL_ROCM_PATH, ROCM_PATH or HIP_PATH directories.");
127+
panic!("HIP headers not found in any of the directories set in CUBECL_ROCM_PATH, ROCM_PATH or HIP_PATH environment variable.");
162128
}
163129
}
130+

0 commit comments

Comments
 (0)