Skip to content

Commit 3f08f4f

Browse files
Add safe API for querying CUDA function attributes (#479)
* Add safe API for querying CUDA function attributes * get_attribute now binds to context before taking action Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Formatting --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 12cbbab commit 3f08f4f

File tree

3 files changed

+135
-0
lines changed

3 files changed

+135
-0
lines changed

examples/10-function-attributes.rs

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
use cudarc::{
2+
driver::{CudaContext, DriverError},
3+
nvrtc::Ptx,
4+
};
5+
6+
fn main() -> Result<(), DriverError> {
7+
let ctx = CudaContext::new(0)?;
8+
9+
println!("Device: {}", ctx.name()?);
10+
println!();
11+
12+
// Load the module with the sin_kernel
13+
let module = ctx.load_module(Ptx::from_file("./examples/sin.ptx"))?;
14+
let sin_kernel = module.load_function("sin_kernel")?;
15+
16+
// Query function attributes
17+
println!("=== Function Attributes for 'sin_kernel' ===");
18+
println!();
19+
20+
println!("Resource Usage:");
21+
println!(" Registers per thread: {}", sin_kernel.num_regs()?);
22+
println!(
23+
" Static shared memory: {} bytes",
24+
sin_kernel.shared_size_bytes()?
25+
);
26+
println!(
27+
" Constant memory: {} bytes",
28+
sin_kernel.const_size_bytes()?
29+
);
30+
println!(
31+
" Local memory per thread: {} bytes",
32+
sin_kernel.local_size_bytes()?
33+
);
34+
println!();
35+
36+
println!("Limits:");
37+
println!(
38+
" Max threads per block: {}",
39+
sin_kernel.max_threads_per_block()?
40+
);
41+
println!();
42+
43+
println!("Compilation Info:");
44+
let ptx_ver = sin_kernel.ptx_version()?;
45+
let bin_ver = sin_kernel.binary_version()?;
46+
println!(
47+
" PTX version: {}.{}",
48+
ptx_ver / 10,
49+
ptx_ver % 10
50+
);
51+
println!(
52+
" Binary version: {}.{}",
53+
bin_ver / 10,
54+
bin_ver % 10
55+
);
56+
println!();
57+
58+
// Use occupancy API to get optimal launch configuration
59+
extern "C" fn no_dynamic_smem(_block_size: std::ffi::c_int) -> usize {
60+
0
61+
}
62+
let (min_grid_size, block_size) =
63+
sin_kernel.occupancy_max_potential_block_size(no_dynamic_smem, 0, 0, None)?;
64+
65+
println!("=== Optimal Launch Configuration (sin_kernel) ===");
66+
println!(" Suggested block size: {}", block_size);
67+
println!(" Min grid size: {}", min_grid_size);
68+
println!(" Total threads per grid: {}", min_grid_size * block_size);
69+
70+
Ok(())
71+
}

src/driver/result.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,24 @@ pub mod device {
180180

181181
pub mod function {
182182
use super::sys::{self, CUfunc_cache_enum, CUfunction_attribute_enum};
183+
use std::mem::MaybeUninit;
184+
185+
/// Gets a specific attribute of a CUDA function.
186+
///
187+
/// See [cuda docs](https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1g5e92a1b0d8d1b82cb00dcfb2de15961b)
188+
///
189+
/// # Safety
190+
/// Function must exist.
191+
pub unsafe fn get_function_attribute(
192+
f: sys::CUfunction,
193+
attribute: CUfunction_attribute_enum,
194+
) -> Result<i32, super::DriverError> {
195+
let mut value = MaybeUninit::uninit();
196+
unsafe {
197+
sys::cuFuncGetAttribute(value.as_mut_ptr(), attribute, f).result()?;
198+
Ok(value.assume_init())
199+
}
200+
}
183201

184202
/// Sets the specific attribute of a cuda function.
185203
///

src/driver/safe/core.rs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1933,6 +1933,52 @@ impl CudaFunction {
19331933
Ok(cluster_size as u32)
19341934
}
19351935

1936+
/// Get the value of a specific attribute of this [CudaFunction].
1937+
///
1938+
/// See [CUDA docs](https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1g5e92a1b0d8d1b82cb00dcfb2de15961b)
1939+
pub fn get_attribute(
1940+
&self,
1941+
attribute: CUfunction_attribute_enum,
1942+
) -> Result<i32, result::DriverError> {
1943+
self.module.ctx.bind_to_thread()?;
1944+
unsafe { result::function::get_function_attribute(self.cu_function, attribute) }
1945+
}
1946+
1947+
/// Get the number of registers used per thread.
1948+
pub fn num_regs(&self) -> Result<i32, result::DriverError> {
1949+
self.get_attribute(CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_NUM_REGS)
1950+
}
1951+
1952+
/// Get the size of statically-allocated shared memory in bytes.
1953+
pub fn shared_size_bytes(&self) -> Result<i32, result::DriverError> {
1954+
self.get_attribute(CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES)
1955+
}
1956+
1957+
/// Get the size of constant memory in bytes used by this function.
1958+
pub fn const_size_bytes(&self) -> Result<i32, result::DriverError> {
1959+
self.get_attribute(CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_CONST_SIZE_BYTES)
1960+
}
1961+
1962+
/// Get the size of local memory in bytes used per thread.
1963+
pub fn local_size_bytes(&self) -> Result<i32, result::DriverError> {
1964+
self.get_attribute(CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES)
1965+
}
1966+
1967+
/// Get the maximum number of threads per block for this function.
1968+
pub fn max_threads_per_block(&self) -> Result<i32, result::DriverError> {
1969+
self.get_attribute(CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK)
1970+
}
1971+
1972+
/// Get the PTX virtual architecture version for which the function was compiled.
1973+
pub fn ptx_version(&self) -> Result<i32, result::DriverError> {
1974+
self.get_attribute(CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_PTX_VERSION)
1975+
}
1976+
1977+
/// Get the binary architecture version for which the function was compiled.
1978+
pub fn binary_version(&self) -> Result<i32, result::DriverError> {
1979+
self.get_attribute(CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_BINARY_VERSION)
1980+
}
1981+
19361982
/// Set the value of a specific attribute of this [CudaFunction].
19371983
pub fn set_attribute(
19381984
&self,

0 commit comments

Comments
 (0)