1
- use std:: path:: Path ;
2
- use std:: { env, io} ;
1
+ use std:: env;
3
2
4
3
const ROCM_FEATURE_PREFIX : & str = "CARGO_FEATURE_ROCM__" ;
4
+ const ROCM_HIP_FEATURE_PREFIX : & str = "CARGO_FEATURE_HIP_" ;
5
5
6
- /// Reads a header inside the rocm folder, that contains the lib's version
7
- fn get_system_hip_version ( rocm_path : impl AsRef < Path > ) -> std:: io:: Result < ( u8 , u8 , u32 ) > {
8
- let version_path = rocm_path. as_ref ( ) . join ( "include/hip/hip_version.h" ) ;
9
- let version_file = std:: fs:: read_to_string ( version_path) ?;
10
- let version_lines = version_file. lines ( ) . collect :: < Vec < _ > > ( ) ;
6
+ include ! ( "../build-script/src/lib.rs" ) ;
11
7
12
- let system_major = version_lines
13
- . iter ( )
14
- . find_map ( |line| line. strip_prefix ( "#define HIP_VERSION_MAJOR " ) )
15
- . expect ( "Invalid hip_version.h file structure: Major version line not found" )
16
- . parse :: < u8 > ( )
17
- . expect ( "Invalid hip_version.h file structure: Couldn't parse major version" ) ;
18
- let system_minor = version_lines
19
- . iter ( )
20
- . find_map ( |line| line. strip_prefix ( "#define HIP_VERSION_MINOR " ) )
21
- . expect ( "Invalid hip_version.h file structure: Minor version line not found" )
22
- . parse :: < u8 > ( )
23
- . expect ( "Invalid hip_version.h file structure: Couldn't parse minor version" ) ;
24
- let system_patch = version_lines
25
- . iter ( )
26
- . find_map ( |line| line. strip_prefix ( "#define HIP_VERSION_PATCH " ) )
27
- . expect ( "Invalid hip_version.h file structure: Patch version line not found" )
28
- . parse :: < u32 > ( )
29
- . expect ( "Invalid hip_version.h file structure: Couldn't parse patch version" ) ;
30
- let release_patch = hip_header_patch_number_to_release_patch_number ( system_patch) ;
31
- if release_patch. is_none ( ) {
32
- println ! ( "cargo::warning=Unknown release version for patch version {system_patch}. This patch does not correspond to an official release patch." ) ;
8
+ /// Make sure that at least one and only one hip feature is set
9
+ fn ensure_single_rocm_hip_feature_set ( ) {
10
+ let mut enabled_features = Vec :: new ( ) ;
11
+
12
+ for ( key, value) in env:: vars ( ) {
13
+ if key. starts_with ( ROCM_HIP_FEATURE_PREFIX ) && value == "1" {
14
+ enabled_features. push ( format ! (
15
+ "rocm__{}" ,
16
+ key. strip_prefix( ROCM_HIP_FEATURE_PREFIX ) . unwrap( )
17
+ ) ) ;
18
+ }
33
19
}
34
20
35
- Ok ( (
36
- system_major,
37
- system_minor,
38
- release_patch. unwrap_or ( system_patch) ,
39
- ) )
21
+ match enabled_features. len ( ) {
22
+ 1 => { }
23
+ 0 => panic ! ( "No ROCm HIP feature is enabled. One ROCm HIP feature must be set." ) ,
24
+ _ => panic ! (
25
+ "Multiple ROCm HIP features are enabled: {:?}. Only one can be set." ,
26
+ enabled_features
27
+ ) ,
28
+ }
40
29
}
41
30
42
- /// The official patch number of a ROCm release is not the same of the patch number
43
- /// in the header files. In the header files the patch number is a monotonic build
44
- /// that changes only when there are actual changes in the HIP libraries.
45
- /// This function maps the header patch number to their official latest release number.
46
- /// For instance if both versions 6.2.2 and 6.2.4 have the same patch version in their
47
- /// header file then this function will return 4.
48
- fn hip_header_patch_number_to_release_patch_number ( number : u32 ) -> Option < u32 > {
49
- match number {
50
- 41134 => Some ( 4 ) , // 6.2.4
51
- 42131 => Some ( 0 ) , // 6.3.0
52
- _ => None ,
31
+ /// Checks if the version inside `rocm_path` matches crate version
32
+ fn check_rocm_version ( rocm_path : impl AsRef < Path > ) -> std:: io:: Result < bool > {
33
+ let rocm_system_version = get_rocm_system_version ( rocm_path) ?;
34
+ let rocm_feature_version = get_rocm_feature_version ( ) ;
35
+
36
+ if rocm_system_version. major == rocm_feature_version. major {
37
+ let mismatches = match ( rocm_system_version. minor == rocm_feature_version. minor , rocm_system_version. patch == rocm_feature_version. patch ) {
38
+ // Perfect match, don't need a warning
39
+ ( true , true ) => return Ok ( true ) ,
40
+ ( true , false ) => "Patch" ,
41
+ ( false , _) => "Minor" ,
42
+ } ;
43
+ println ! ( "cargo::warning=ROCm {mismatches} version mismatch between cubecl-hip-sys expected version ({rocm_feature_version}) and found ROCm version on the system ({rocm_system_version}). Build process might fail due to incompatible library bindings." ) ;
44
+ Ok ( true )
45
+ } else {
46
+ Ok ( false )
53
47
}
54
48
}
55
49
56
- /// Return the ROCm version corresponding to the enabled feature
57
- fn get_rocm_feature_version ( ) -> io :: Result < ( u8 , u8 , u32 ) > {
50
+ /// Return the ROCm version corresponding to the enabled rocm__<version> feature
51
+ fn get_rocm_feature_version ( ) -> Version {
58
52
for ( key, value) in env:: vars ( ) {
59
53
if key. starts_with ( ROCM_FEATURE_PREFIX ) && value == "1" {
60
54
if let Some ( version) = key. strip_prefix ( ROCM_FEATURE_PREFIX ) {
@@ -65,66 +59,32 @@ fn get_rocm_feature_version() -> io::Result<(u8, u8, u32)> {
65
59
parts[ 1 ] . parse :: < u8 > ( ) ,
66
60
parts[ 2 ] . parse :: < u32 > ( ) ,
67
61
) {
68
- return Ok ( ( major, minor, patch) ) ;
62
+ return Version { major, minor, patch} ;
69
63
}
70
64
}
71
65
}
72
66
}
73
67
}
74
68
75
- Err ( io:: Error :: new (
76
- io:: ErrorKind :: NotFound ,
77
- "No valid ROCm feature version found. One 'rocm__<version>' feature must be set. For instance for ROCm 6.2.2 the feature is rocm__6_2_2." ,
78
- ) )
69
+ panic ! ( "No valid ROCm feature version found. One 'rocm__<version>' feature must be set. For instance for ROCm 6.2.2 the feature is rocm__6_2_2." )
79
70
}
80
71
81
- /// Make sure that feature is set correctly
82
- fn ensure_single_rocm_feature_set ( ) {
83
- let mut enabled_features = Vec :: new ( ) ;
84
-
72
+ /// Return the ROCm HIP patch version corresponding to the enabled hip_<patch_version> feature
73
+ fn get_hip_feature_patch_version ( ) -> u32 {
85
74
for ( key, value) in env:: vars ( ) {
86
- if key. starts_with ( ROCM_FEATURE_PREFIX ) && value == "1" {
87
- enabled_features. push ( format ! (
88
- "rocm__{}" ,
89
- key. strip_prefix( ROCM_FEATURE_PREFIX ) . unwrap( )
90
- ) ) ;
75
+ if key. starts_with ( ROCM_HIP_FEATURE_PREFIX ) && value == "1" {
76
+ if let Some ( patch) = key. strip_prefix ( ROCM_HIP_FEATURE_PREFIX ) {
77
+ if let Ok ( patch) = patch. parse :: < u32 > ( ) {
78
+ return patch;
79
+ }
80
+ }
91
81
}
92
82
}
93
83
94
- match enabled_features. len ( ) {
95
- 1 => { }
96
- 0 => panic ! ( "No ROCm version features are enabled. One ROCm version feature must be set." ) ,
97
- _ => panic ! (
98
- "Multiple ROCm version features are enabled: {:?}. Only one can be set." ,
99
- enabled_features
100
- ) ,
101
- }
102
- }
103
-
104
- /// Checks if the version inside `rocm_path` matches crate version
105
- fn check_version ( rocm_path : impl AsRef < Path > ) -> std:: io:: Result < bool > {
106
- let ( system_major, system_minor, system_patch) = get_system_hip_version ( rocm_path) ?;
107
- let ( crate_major, crate_minor, crate_patch) = get_rocm_feature_version ( ) ?;
108
-
109
- if crate_major == system_major {
110
- let mismatches = match ( crate_minor == system_minor, crate_patch == system_patch) {
111
- // Perfect match, don't need a warning
112
- ( true , true ) => return Ok ( true ) ,
113
- ( false , true ) => "Minor" ,
114
- ( true , false ) => "Patch" ,
115
- ( false , false ) => "Both minor and patch" ,
116
- } ;
117
- println ! ( "cargo::warning={mismatches} version mismatch between cubecl-hip-sys bindings and system HIP version. Want {}, but found {}" ,
118
- [ crate_major, crate_minor, crate_patch as u8 ] . map( |el| el. to_string( ) ) . join( "." ) ,
119
- [ system_major, system_minor, system_patch as u8 ] . map( |el| el. to_string( ) ) . join( "." ) ) ;
120
- Ok ( true )
121
- } else {
122
- Ok ( false )
123
- }
84
+ panic ! ( "No valid ROCm HIP feature found. One 'hip_<patch>' feature must be set." )
124
85
}
125
86
126
87
fn main ( ) {
127
- ensure_single_rocm_feature_set ( ) ;
128
88
129
89
println ! ( "cargo::rerun-if-changed=build.rs" ) ;
130
90
println ! ( "cargo::rerun-if-env-changed=CUBECL_ROCM_PATH" ) ;
@@ -146,18 +106,25 @@ fn main() {
146
106
} )
147
107
. peekable ( ) ;
148
108
let have_candidates = rocm_path_candidates. peek ( ) . is_some ( ) ;
149
- let rocm_path = rocm_path_candidates. find ( |path| check_version ( path) . unwrap_or_default ( ) ) ;
109
+ let rocm_path = rocm_path_candidates. find ( |path| check_rocm_version ( path) . unwrap_or_default ( ) ) ;
150
110
151
111
if let Some ( valid_rocm_path) = rocm_path {
112
+ ensure_single_rocm_hip_feature_set ( ) ;
113
+ // verify HIP compatbility
114
+ let Version { patch : hip_system_patch_version, ..} = get_hip_system_version ( valid_rocm_path) . unwrap ( ) ;
115
+ let hip_feature_patch_version = get_hip_feature_patch_version ( ) ;
116
+ if hip_system_patch_version != hip_feature_patch_version {
117
+ panic ! ( "Imcompatible HIP bindings found. Expected to find HIP patch version {hip_feature_patch_version}, but found HIP patch version {hip_system_patch_version}." ) ;
118
+ }
119
+
152
120
println ! ( "cargo::rustc-link-lib=dylib=hiprtc" ) ;
153
121
println ! ( "cargo::rustc-link-lib=dylib=amdhip64" ) ;
154
122
println ! ( "cargo::rustc-link-search=native={}/lib" , valid_rocm_path) ;
155
123
} else if have_candidates {
156
- panic ! (
157
- "None of the found ROCm installations match crate version {}" ,
158
- env!( "CARGO_PKG_VERSION" )
159
- ) ;
124
+ let rocm_feature_version = get_rocm_feature_version ( ) ;
125
+ panic ! ( "None of the found ROCm installations match version {rocm_feature_version}." ) ;
160
126
} else if paths. len ( ) > 1 {
161
- panic ! ( "HIP headers not found in any of the defined CUBECL_ROCM_PATH, ROCM_PATH or HIP_PATH directories ." ) ;
127
+ panic ! ( "HIP headers not found in any of the directories set in CUBECL_ROCM_PATH, ROCM_PATH or HIP_PATH environment variable ." ) ;
162
128
}
163
129
}
130
+
0 commit comments