From aef416862945ee169b50ffbaf50a44474fcfc79a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Thu, 4 Sep 2025 18:21:00 +0200 Subject: [PATCH 01/23] Add ArcSin, ArcCos, ArcTan and ArcTan2 to float operations This just makes the compiler happy and is not yet tested! --- .../src/frontend/operation/binary.rs | 13 +++++ .../src/frontend/operation/unary.rs | 41 ++++++++++++++++ crates/cubecl-cpp/src/metal/dialect.rs | 4 ++ crates/cubecl-cpp/src/shared/base.rs | 20 ++++++++ crates/cubecl-cpp/src/shared/binary.rs | 48 +++++++++++++++++++ crates/cubecl-cpp/src/shared/dialect.rs | 4 ++ crates/cubecl-cpp/src/shared/instruction.rs | 8 ++++ crates/cubecl-cpp/src/shared/unary.rs | 3 ++ crates/cubecl-ir/src/arithmetic.rs | 8 ++++ crates/cubecl-ir/src/processing.rs | 13 +++++ crates/cubecl-opt/src/instructions.rs | 6 ++- crates/cubecl-opt/src/passes/constant_prop.rs | 17 +++++++ crates/cubecl-spirv/src/arithmetic.rs | 32 +++++++++++++ crates/cubecl-spirv/src/extensions.rs | 20 ++++++++ .../cubecl-wgpu/src/compiler/wgsl/compiler.rs | 17 +++++++ .../src/compiler/wgsl/instructions.rs | 33 +++++++++++++ 16 files changed, 286 insertions(+), 1 deletion(-) diff --git a/crates/cubecl-core/src/frontend/operation/binary.rs b/crates/cubecl-core/src/frontend/operation/binary.rs index 081069eab..04ec52336 100644 --- a/crates/cubecl-core/src/frontend/operation/binary.rs +++ b/crates/cubecl-core/src/frontend/operation/binary.rs @@ -223,6 +223,19 @@ impl_binary_func!( f32, f64 ); +impl_binary_func!( + ArcTan2, + atan2, + __expand_atan2, + __expand_atan2_method, + Arithmetic::ArcTan2, + f16, + bf16, + flex32, + tf32, + f32, + f64 +); impl_binary_func!( Max, max, diff --git a/crates/cubecl-core/src/frontend/operation/unary.rs b/crates/cubecl-core/src/frontend/operation/unary.rs index afa5b333c..9443e2f09 100644 --- a/crates/cubecl-core/src/frontend/operation/unary.rs +++ b/crates/cubecl-core/src/frontend/operation/unary.rs @@ -187,6 +187,47 @@ impl_unary_func!( f32, f64 ); +// TODO: Missing: SinH, ArcSinH, CosH, ArcCosH, ArcTanH, Tan +// TODO: Add function for converting between degree and radiants +// Open Questions: +// - When to use metal safe / atomic stuff and when not +// - When do I need to check for Bfloats and stuff? 
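+//
+// Usage sketch (illustrative only; it mirrors the runtime tests added later in
+// this series). Inside a #[cube] kernel the new ops expand like any other
+// Float operation, assuming `x` and `y` are values of some `F: Float`:
+//
+//     let r = F::sqrt(x * x + y * y);
+//     let phi = F::atan2(y, x);     // signed angle in (-pi, pi]
+//     let x_back = r * F::cos(phi); // round-trip back to Cartesian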
+impl_unary_func!( + ArcCos, + acos, + __exapnd_acos, + Arithmetic::ArcCos, + f16, + bf16, + flex32, + tf32, + f32, + f64 +); +impl_unary_func!( + ArcSin, + asin, + __exapnd_asin, + Arithmetic::ArcSin, + f16, + bf16, + flex32, + tf32, + f32, + f64 +); +impl_unary_func!( + ArcTan, + atan, + __exapnd_atan, + Arithmetic::ArcTan, + f16, + bf16, + flex32, + tf32, + f32, + f64 +); impl_unary_func!( Sqrt, sqrt, diff --git a/crates/cubecl-cpp/src/metal/dialect.rs b/crates/cubecl-cpp/src/metal/dialect.rs index ba5698a2a..c783af450 100644 --- a/crates/cubecl-cpp/src/metal/dialect.rs +++ b/crates/cubecl-cpp/src/metal/dialect.rs @@ -803,6 +803,10 @@ impl DialectInstructions for MslDialect { write!(f, "safe_tanh_scalar({input})") } + fn compile_instruction_atan2(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "atan2") + } + // unary fn compile_instruction_find_first_set>( f: &mut std::fmt::Formatter<'_>, diff --git a/crates/cubecl-cpp/src/shared/base.rs b/crates/cubecl-cpp/src/shared/base.rs index 054f6bd51..5ae4b1129 100644 --- a/crates/cubecl-cpp/src/shared/base.rs +++ b/crates/cubecl-cpp/src/shared/base.rs @@ -932,6 +932,26 @@ impl CppCompiler { D::register_instruction_extension(&mut self.extensions, &instruction); instructions.push(instruction) } + gpu::Arithmetic::ArcCos(op) => { + let instruction = Instruction::ArcCos(self.compile_unary(op, out)); + D::register_instruction_extension(&mut self.extensions, &instruction); + instructions.push(instruction) + } + gpu::Arithmetic::ArcSin(op) => { + let instruction = Instruction::ArcSin(self.compile_unary(op, out)); + D::register_instruction_extension(&mut self.extensions, &instruction); + instructions.push(instruction) + } + gpu::Arithmetic::ArcTan(op) => { + let instruction = Instruction::ArcTan(self.compile_unary(op, out)); + D::register_instruction_extension(&mut self.extensions, &instruction); + instructions.push(instruction) + } + gpu::Arithmetic::ArcTan2(op) => { + let instruction = Instruction::ArcTan2(self.compile_binary(op, out)); + D::register_instruction_extension(&mut self.extensions, &instruction); + instructions.push(instruction) + } gpu::Arithmetic::Powf(op) => { instructions.push(Instruction::Powf(self.compile_binary(op, out))) } diff --git a/crates/cubecl-cpp/src/shared/binary.rs b/crates/cubecl-cpp/src/shared/binary.rs index 449be4f30..84cfd39b7 100644 --- a/crates/cubecl-cpp/src/shared/binary.rs +++ b/crates/cubecl-cpp/src/shared/binary.rs @@ -221,6 +221,54 @@ impl Binary for Powf { } } +pub struct ArcTan2; + +impl Binary for ArcTan2 { + // ArcTan2 doesn't support half and no half equivalent exists + fn format_scalar( + f: &mut std::fmt::Formatter<'_>, + lhs: Lhs, + rhs: Rhs, + item: Item, + ) -> std::fmt::Result { + let elem = item.elem; + match elem { + Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => { + write!(f, "{elem}(")?; + D::compile_instruction_atan2(f)?; + write!(f, "(float({lhs}), float({rhs})))") + } + _ => { + D::compile_instruction_atan2(f)?; + write!(f, "({lhs}, {rhs})") + } + } + } + + // ArcTan2 doesn't support half and no half equivalent exists + fn unroll_vec( + f: &mut Formatter<'_>, + lhs: &Variable, + rhs: &Variable, + out: &Variable, + ) -> core::fmt::Result { + let item_out = out.item(); + let index = out.item().vectorization; + + let out = out.fmt_left(); + writeln!(f, "{out} = {item_out}{{")?; + for i in 0..index { + let lhsi = lhs.index(i); + let rhsi = rhs.index(i); + + Self::format_scalar(f, lhsi, rhsi, item_out)?; + f.write_str(", ")?; + } + + f.write_str("};\n") + } +} + pub 
struct Max; impl Binary for Max { diff --git a/crates/cubecl-cpp/src/shared/dialect.rs b/crates/cubecl-cpp/src/shared/dialect.rs index 08f179895..b6e70857a 100644 --- a/crates/cubecl-cpp/src/shared/dialect.rs +++ b/crates/cubecl-cpp/src/shared/dialect.rs @@ -603,6 +603,10 @@ pub trait DialectInstructions { write!(f, "powf") } + fn compile_instruction_atan2(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "atan2") + } + fn compile_instruction_half_function_name_prefix() -> &'static str { "h" } diff --git a/crates/cubecl-cpp/src/shared/instruction.rs b/crates/cubecl-cpp/src/shared/instruction.rs index 0b386dc62..550a1423c 100644 --- a/crates/cubecl-cpp/src/shared/instruction.rs +++ b/crates/cubecl-cpp/src/shared/instruction.rs @@ -163,6 +163,10 @@ pub enum Instruction { Cos(UnaryInstruction), Sin(UnaryInstruction), Tanh(UnaryInstruction), + ArcCos(UnaryInstruction), + ArcSin(UnaryInstruction), + ArcTan(UnaryInstruction), + ArcTan2(BinaryInstruction), Powf(BinaryInstruction), Sqrt(UnaryInstruction), Min(BinaryInstruction), @@ -503,6 +507,10 @@ for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{ Instruction::Cos(it) => Cos::format(f, &it.input, &it.out), Instruction::Sin(it) => Sin::format(f, &it.input, &it.out), Instruction::Tanh(it) => Tanh::format(f, &it.input, &it.out), + Instruction::ArcCos(it) => ArcCos::format(f, &it.input, &it.out), + Instruction::ArcSin(it) => ArcSin::format(f, &it.input, &it.out), + Instruction::ArcTan(it) => ArcTan::format(f, &it.input, &it.out), + Instruction::ArcTan2(it) => ArcTan2::format(f, &it.lhs, &it.rhs, &it.out), Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out), Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out), Instruction::Max(it) => Max::format(f, &it.lhs, &it.rhs, &it.out), diff --git a/crates/cubecl-cpp/src/shared/unary.rs b/crates/cubecl-cpp/src/shared/unary.rs index 749bb9184..c0976439e 100644 --- a/crates/cubecl-cpp/src/shared/unary.rs +++ b/crates/cubecl-cpp/src/shared/unary.rs @@ -151,6 +151,9 @@ macro_rules! 
function { function!(Log, "log"); function!(Cos, "cos"); function!(Sin, "sin"); +function!(ArcCos, "acos"); +function!(ArcSin, "asin"); +function!(ArcTan, "atan"); function!(Sqrt, "sqrt"); function!(Exp, "exp"); function!(Ceil, "ceil"); diff --git a/crates/cubecl-ir/src/arithmetic.rs b/crates/cubecl-ir/src/arithmetic.rs index 03343f450..e798a01ef 100644 --- a/crates/cubecl-ir/src/arithmetic.rs +++ b/crates/cubecl-ir/src/arithmetic.rs @@ -23,6 +23,10 @@ pub enum Arithmetic { Cos(UnaryOperator), Sin(UnaryOperator), Tanh(UnaryOperator), + ArcCos(UnaryOperator), + ArcSin(UnaryOperator), + ArcTan(UnaryOperator), + ArcTan2(BinaryOperator), Powf(BinaryOperator), Sqrt(UnaryOperator), Round(UnaryOperator), @@ -61,6 +65,10 @@ impl Display for Arithmetic { Arithmetic::Cos(op) => write!(f, "{}.cos()", op.input), Arithmetic::Sin(op) => write!(f, "{}.sin()", op.input), Arithmetic::Tanh(op) => write!(f, "{}.tanh()", op.input), + Arithmetic::ArcCos(op) => write!(f, "{}.acos()", op.input), + Arithmetic::ArcSin(op) => write!(f, "{}.asin()", op.input), + Arithmetic::ArcTan(op) => write!(f, "{}.atan()", op.input), + Arithmetic::ArcTan2(op) => write!(f, "{}.atan2({})", op.lhs, op.rhs), Arithmetic::Powf(op) => write!(f, "{}.pow({})", op.lhs, op.rhs), Arithmetic::Sqrt(op) => write!(f, "{}.sqrt()", op.input), Arithmetic::Round(op) => write!(f, "{}.round()", op.input), diff --git a/crates/cubecl-ir/src/processing.rs b/crates/cubecl-ir/src/processing.rs index f734a165c..0a2ad59fe 100644 --- a/crates/cubecl-ir/src/processing.rs +++ b/crates/cubecl-ir/src/processing.rs @@ -104,6 +104,19 @@ impl ScopeProcessing { Arithmetic::Tanh(op) => { sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); } + Arithmetic::ArcCos(op) => { + sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); + } + Arithmetic::ArcSin(op) => { + sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); + } + Arithmetic::ArcTan(op) => { + sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); + } + Arithmetic::ArcTan2(op) => { + sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap()); + sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap()); + } Arithmetic::Powf(op) => { sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap()); sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap()); diff --git a/crates/cubecl-opt/src/instructions.rs b/crates/cubecl-opt/src/instructions.rs index 0d4e99f9a..56e5ab2bb 100644 --- a/crates/cubecl-opt/src/instructions.rs +++ b/crates/cubecl-opt/src/instructions.rs @@ -80,7 +80,8 @@ impl Optimizer { | Arithmetic::Min(binary_operator) | Arithmetic::Remainder(binary_operator) | Arithmetic::Dot(binary_operator) - | Arithmetic::MulHi(binary_operator) => self.visit_binop(binary_operator, visit_read), + | Arithmetic::MulHi(binary_operator) + | Arithmetic::ArcTan2(binary_operator) => self.visit_binop(binary_operator, visit_read), Arithmetic::Abs(unary_operator) | Arithmetic::Exp(unary_operator) @@ -89,6 +90,9 @@ impl Optimizer { | Arithmetic::Cos(unary_operator) | Arithmetic::Sin(unary_operator) | Arithmetic::Tanh(unary_operator) + | Arithmetic::ArcCos(unary_operator) + | Arithmetic::ArcSin(unary_operator) + | Arithmetic::ArcTan(unary_operator) | Arithmetic::Sqrt(unary_operator) | Arithmetic::Round(unary_operator) | Arithmetic::Floor(unary_operator) diff --git a/crates/cubecl-opt/src/passes/constant_prop.rs b/crates/cubecl-opt/src/passes/constant_prop.rs index 181ba7729..25a87dd9d 100644 --- 
a/crates/cubecl-opt/src/passes/constant_prop.rs +++ b/crates/cubecl-opt/src/passes/constant_prop.rs @@ -369,6 +369,23 @@ fn try_const_eval_arithmetic(op: &mut Arithmetic) -> Option Arithmetic::Cos(op) => const_eval_float!(op.input; num::Float::cos), Arithmetic::Sin(op) => const_eval_float!(op.input; num::Float::sin), Arithmetic::Tanh(op) => const_eval_float!(op.input; num::Float::tanh), + Arithmetic::ArcCos(op) => const_eval_float!(op.input; num::Float::acos), + Arithmetic::ArcSin(op) => const_eval_float!(op.input; num::Float::asin), + Arithmetic::ArcTan(op) => const_eval_float!(op.input; num::Float::atan), + Arithmetic::ArcTan2(op) => { + use ConstantScalarValue::*; + if let (Some(lhs), Some(rhs)) = (op.lhs.as_const(), op.rhs.as_const()) { + let rhs = rhs.cast_to(lhs.storage_type()); + Some(match (lhs, rhs) { + (Float(lhs, kind), Float(rhs, _)) => { + ConstantScalarValue::Float(lhs.atan2(rhs), kind) + } + _ => unreachable!(), + }) + } else { + None + } + } Arithmetic::Sqrt(op) => const_eval_float!(op.input; num::Float::sqrt), Arithmetic::Round(op) => const_eval_float!(op.input; num::Float::round), Arithmetic::Floor(op) => const_eval_float!(op.input; num::Float::floor), diff --git a/crates/cubecl-spirv/src/arithmetic.rs b/crates/cubecl-spirv/src/arithmetic.rs index 92f44404b..e20c20919 100644 --- a/crates/cubecl-spirv/src/arithmetic.rs +++ b/crates/cubecl-spirv/src/arithmetic.rs @@ -301,6 +301,38 @@ impl SpirvCompiler { } }) } + Arithmetic::ArcCos(op) => { + self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + T::acos(b, ty, input, out); + if matches!(out_ty.elem(), Elem::Relaxed) { + b.decorate(out, Decoration::RelaxedPrecision, []); + } + }) + } + Arithmetic::ArcSin(op) => { + self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + T::asin(b, ty, input, out); + if matches!(out_ty.elem(), Elem::Relaxed) { + b.decorate(out, Decoration::RelaxedPrecision, []); + } + }) + } + Arithmetic::ArcTan(op) => { + self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + T::atan(b, ty, input, out); + if matches!(out_ty.elem(), Elem::Relaxed) { + b.decorate(out, Decoration::RelaxedPrecision, []); + } + }) + } + Arithmetic::ArcTan2(op) => { + self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| { + T::atan2(b, ty, lhs, rhs, out); + if matches!(out_ty.elem(), Elem::Relaxed) { + b.decorate(out, Decoration::RelaxedPrecision, []); + } + }) + } Arithmetic::Powf(op) => { self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| { let bool = match out_ty { diff --git a/crates/cubecl-spirv/src/extensions.rs b/crates/cubecl-spirv/src/extensions.rs index 6f2dab234..3191b62a1 100644 --- a/crates/cubecl-spirv/src/extensions.rs +++ b/crates/cubecl-spirv/src/extensions.rs @@ -13,6 +13,10 @@ pub trait TargetExtensions { fn sin(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn cos(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn tanh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); + fn asin(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); + fn acos(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); + fn atan(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); + fn atan2(b: &mut SpirvCompiler, ty: Word, lhs: Word, rhs: Word, out: Word); fn pow(b: &mut SpirvCompiler, ty: Word, lhs: Word, rhs: Word, out: Word); fn exp(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn log(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); @@ -71,6 +75,22 @@ pub mod 
glcompute { b.tanh_id(ty, Some(out), input).unwrap(); } + fn asin(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) { + b.asin_id(ty, Some(out), input).unwrap(); + } + + fn acos(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) { + b.acos_id(ty, Some(out), input).unwrap(); + } + + fn atan(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) { + b.atan_id(ty, Some(out), input).unwrap(); + } + + fn atan2(b: &mut SpirvCompiler, ty: Word, lhs: Word, rhs: Word, out: Word) { + b.atan2_id(ty, Some(out), lhs, rhs).unwrap(); + } + fn pow(b: &mut SpirvCompiler, ty: Word, lhs: Word, rhs: Word, out: Word) { b.pow_id(ty, Some(out), lhs, rhs).unwrap(); } diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs index a9a8132ba..fc165d7fd 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs @@ -752,6 +752,23 @@ impl WgslCompiler { input: self.compile_variable(op.input), out: self.compile_variable(out), }), + cube::Arithmetic::ArcCos(op) => instructions.push(wgsl::Instruction::ArcCos { + input: self.compile_variable(op.input), + out: self.compile_variable(out), + }), + cube::Arithmetic::ArcSin(op) => instructions.push(wgsl::Instruction::ArcSin { + input: self.compile_variable(op.input), + out: self.compile_variable(out), + }), + cube::Arithmetic::ArcTan(op) => instructions.push(wgsl::Instruction::ArcTan { + input: self.compile_variable(op.input), + out: self.compile_variable(out), + }), + cube::Arithmetic::ArcTan2(op) => instructions.push(wgsl::Instruction::ArcTan2 { + lhs: self.compile_variable(op.lhs), + rhs: self.compile_variable(op.rhs), + out: self.compile_variable(out), + }), cube::Arithmetic::Powf(op) => instructions.push(wgsl::Instruction::Powf { lhs: self.compile_variable(op.lhs), rhs: self.compile_variable(op.rhs), diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs index be3718260..dd0a6bd6d 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs @@ -121,6 +121,23 @@ pub enum Instruction { input: Variable, out: Variable, }, + ArcCos { + input: Variable, + out: Variable, + }, + ArcSin { + input: Variable, + out: Variable, + }, + ArcTan { + input: Variable, + out: Variable, + }, + ArcTan2 { + lhs: Variable, + rhs: Variable, + out: Variable, + }, Powf { lhs: Variable, rhs: Variable, @@ -602,6 +619,22 @@ impl Display for Instruction { result } + Instruction::ArcCos { input, out } => { + let out = out.fmt_left(); + writeln!(f, "{out} = acos({input});") + } + Instruction::ArcSin { input, out } => { + let out = out.fmt_left(); + writeln!(f, "{out} = asin({input});") + } + Instruction::ArcTan { input, out } => { + let out = out.fmt_left(); + writeln!(f, "{out} = atan({input});") + } + Instruction::ArcTan2 { lhs, rhs, out } => { + let out = out.fmt_left(); + writeln!(f, "{out} = atan2({lhs}, {rhs});") + } Instruction::Recip { input, out } => { let out = out.fmt_left(); write!(f, "{out} = 1.0 / {input};") From 416fa3d9e24def07c5a335b1c5cffe06ca1f34d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Thu, 4 Sep 2025 20:22:47 +0200 Subject: [PATCH 02/23] Add tests for acos, asin, atan and atan2 --- .../cubecl-core/src/frontend/element/float.rs | 4 + .../src/frontend/element/float/typemap.rs | 4 + .../src/frontend/operation/unary.rs | 7 +- .../cubecl-core/src/runtime_tests/binary.rs | 30 +++++ 
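
The expected values in the tests below are truncated decimal expansions of exact
angles: 1.57079632679 ≈ π/2, 0.78539816339 ≈ π/4, 2.35619449019 ≈ 3π/4,
0.52359877559 ≈ π/6, 1.04719755119 ≈ π/3, 2.09439510239 ≈ 2π/3 and
3.14159265359 ≈ π; for example atan2(1, 0) = π/2 and atan2(-1, -1) = -3π/4.
A minimal way to regenerate the atan2 table (plain std Rust, not part of the
patch itself):

    fn main() {
        // (y, x) pairs matching the lhs/rhs columns of test_atan2
        for (y, x) in [(0.0f64, 1.0), (1.0, 0.0), (-1.0, 0.0), (1.0, 1.0), (-1.0, -1.0)] {
            println!("atan2({y}, {x}) = {}", y.atan2(x));
        }
    }
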
crates/cubecl-core/src/runtime_tests/unary.rs | 110 ++++++++++++++++++ .../compiler/visitor/operation/arithmetic.rs | 41 +++++++ 6 files changed, 193 insertions(+), 3 deletions(-) diff --git a/crates/cubecl-core/src/frontend/element/float.rs b/crates/cubecl-core/src/frontend/element/float.rs index 1b1ee5605..0036b0463 100644 --- a/crates/cubecl-core/src/frontend/element/float.rs +++ b/crates/cubecl-core/src/frontend/element/float.rs @@ -26,6 +26,10 @@ pub trait Float: + Cos + Sin + Tanh + + ArcCos + + ArcSin + + ArcTan + + ArcTan2 + Powf + Sqrt + Round diff --git a/crates/cubecl-core/src/frontend/element/float/typemap.rs b/crates/cubecl-core/src/frontend/element/float/typemap.rs index 1d4aec3a6..6b19efab8 100644 --- a/crates/cubecl-core/src/frontend/element/float/typemap.rs +++ b/crates/cubecl-core/src/frontend/element/float/typemap.rs @@ -244,6 +244,10 @@ impl Log1p for ElemExpand {} impl Cos for ElemExpand {} impl Sin for ElemExpand {} impl Tanh for ElemExpand {} +impl ArcCos for ElemExpand {} +impl ArcSin for ElemExpand {} +impl ArcTan for ElemExpand {} +impl ArcTan2 for ElemExpand {} impl Powf for ElemExpand {} impl Sqrt for ElemExpand {} impl Round for ElemExpand {} diff --git a/crates/cubecl-core/src/frontend/operation/unary.rs b/crates/cubecl-core/src/frontend/operation/unary.rs index 9443e2f09..3c23fe719 100644 --- a/crates/cubecl-core/src/frontend/operation/unary.rs +++ b/crates/cubecl-core/src/frontend/operation/unary.rs @@ -192,10 +192,11 @@ impl_unary_func!( // Open Questions: // - When to use metal safe / atomic stuff and when not // - When do I need to check for Bfloats and stuff? +// - How to add support for LLVM MIR (in cube-cpu) impl_unary_func!( ArcCos, acos, - __exapnd_acos, + __expand_acos, Arithmetic::ArcCos, f16, bf16, @@ -207,7 +208,7 @@ impl_unary_func!( impl_unary_func!( ArcSin, asin, - __exapnd_asin, + __expand_asin, Arithmetic::ArcSin, f16, bf16, @@ -219,7 +220,7 @@ impl_unary_func!( impl_unary_func!( ArcTan, atan, - __exapnd_atan, + __expand_atan, Arithmetic::ArcTan, f16, bf16, diff --git a/crates/cubecl-core/src/runtime_tests/binary.rs b/crates/cubecl-core/src/runtime_tests/binary.rs index 4d82b4b1d..00e4e81cb 100644 --- a/crates/cubecl-core/src/runtime_tests/binary.rs +++ b/crates/cubecl-core/src/runtime_tests/binary.rs @@ -150,6 +150,35 @@ test_binary_impl!( ] ); +test_binary_impl!( + test_atan2, + F, + F::atan2, + [ + { + input_vectorization: 1, + out_vectorization: 1, + lhs: as_type![F: 0., 1., -1., 1., -1.], + rhs: as_type![F: 1., 0., 0., 1., -1.], + expected: as_type![F: 0., 1.57079632679, -1.57079632679, 0.78539816339, -2.35619449019] + }, + { + input_vectorization: 2, + out_vectorization: 2, + lhs: as_type![F: 0., 1., -1., 1.], + rhs: as_type![F: 1., 0., 0., 1.], + expected: as_type![F: 0., 1.57079632679, -1.57079632679, 0.78539816339] + }, + { + input_vectorization: 4, + out_vectorization: 4, + lhs: as_type![F: 0., 1., -1., 1.], + rhs: as_type![F: 1., 0., 0., 1.], + expected: as_type![F: 0., 1.57079632679, -1.57079632679, 0.78539816339] + } + ] +); + #[cube(launch_unchecked)] fn test_mulhi_kernel( lhs: &Array>, @@ -250,6 +279,7 @@ macro_rules! testgen_binary { add_test!(test_dot); add_test!(test_powf); + add_test!(test_atan2); } }; } diff --git a/crates/cubecl-core/src/runtime_tests/unary.rs b/crates/cubecl-core/src/runtime_tests/unary.rs index 87672c6e4..9aeb3aaf6 100644 --- a/crates/cubecl-core/src/runtime_tests/unary.rs +++ b/crates/cubecl-core/src/runtime_tests/unary.rs @@ -168,6 +168,111 @@ macro_rules! 
test_unary_impl_int_fixed { }; } +test_unary_impl!(test_sin, F, F::sin, [ + { + input_vectorization: 1, + out_vectorization: 1, + input: as_type![F: 0., 1.57079632679, 3.14159265359, -1.57079632679], + expected: as_type![F: 0., 1., 0., -1.] + }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![F: 0., 1.57079632679, 3.14159265359, -1.57079632679], + expected: as_type![F: 0., 1., 0., -1.] + }, + { + input_vectorization: 4, + out_vectorization: 4, + input: as_type![F: 0., 1.57079632679, 3.14159265359, -1.57079632679], + expected: as_type![F: 0., 1., 0., -1.] + } +]); + +test_unary_impl!(test_cos, F, F::cos, [ + { + input_vectorization: 1, + out_vectorization: 1, + input: as_type![F: 0., 1.57079632679, 3.14159265359, -1.57079632679], + expected: as_type![F: 1., 0., -1., 0.] + }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![F: 0., 1.57079632679, 3.14159265359, -1.57079632679], + expected: as_type![F: 1., 0., -1., 0.] + }, + { + input_vectorization: 4, + out_vectorization: 4, + input: as_type![F: 0., 1.57079632679, 3.14159265359, -1.57079632679], + expected: as_type![F: 1., 0., -1., 0.] + } +]); + +test_unary_impl!(test_asin, F, F::asin, [ + { + input_vectorization: 1, + out_vectorization: 1, + input: as_type![F: 0., 0.5, 1., -0.5, -1.], + expected: as_type![F: 0., 0.52359877559, 1.57079632679, -0.52359877559, -1.57079632679] + }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![F: 0., 0.5, 1., -0.5], + expected: as_type![F: 0., 0.52359877559, 1.57079632679, -0.52359877559] + }, + { + input_vectorization: 4, + out_vectorization: 4, + input: as_type![F: 0., 0.5, 1., -0.5], + expected: as_type![F: 0., 0.52359877559, 1.57079632679, -0.52359877559] + } +]); + +test_unary_impl!(test_acos, F, F::acos, [ + { + input_vectorization: 1, + out_vectorization: 1, + input: as_type![F: 1., 0.5, 0., -0.5, -1.], + expected: as_type![F: 0., 1.04719755119, 1.57079632679, 2.09439510239, 3.14159265359] + }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![F: 1., 0.5, 0., -0.5], + expected: as_type![F: 0., 1.04719755119, 1.57079632679, 2.09439510239] + }, + { + input_vectorization: 4, + out_vectorization: 4, + input: as_type![F: 1., 0.5, 0., -0.5], + expected: as_type![F: 0., 1.04719755119, 1.57079632679, 2.09439510239] + } +]); + +test_unary_impl!(test_atan, F, F::atan, [ + { + input_vectorization: 1, + out_vectorization: 1, + input: as_type![F: 0., 1., -1., 1000., -1000.], + expected: as_type![F: 0., 0.78539816339, -0.78539816339, 1.56979632472, -1.56979632472] + }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![F: 0., 1., -1., 1000.], + expected: as_type![F: 0., 0.78539816339, -0.78539816339, 1.56979632472] + }, + { + input_vectorization: 4, + out_vectorization: 4, + input: as_type![F: 0., 1., -1., 1000.], + expected: as_type![F: 0., 0.78539816339, -0.78539816339, 1.56979632472] + } +]); + test_unary_impl!( test_magnitude, F, @@ -376,6 +481,11 @@ macro_rules! 
testgen_unary { }; } + add_test!(test_sin); + add_test!(test_cos); + add_test!(test_asin); + add_test!(test_acos); + add_test!(test_atan); add_test!(test_normalize); add_test!(test_magnitude); add_test!(test_abs); diff --git a/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs b/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs index 0e16e4ab2..a71ff4dfd 100644 --- a/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs +++ b/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs @@ -28,6 +28,47 @@ impl<'a> Visitor<'a> { let result = self.append_operation_with_result(operation); self.insert_variable(out, result); } + Arithmetic::ArcCos(_acos) => { + todo!("intr_acos does not exist") + /*let value = self.get_variable(acos.input); + let result = self.append_operation_with_result(llvm_ods::intr_acos( + self.context, + value, + self.location, + )); + self.insert_variable(out, result);*/ + } + Arithmetic::ArcSin(_asin) => { + todo!("intr_asin does not exist") + /*let value = self.get_variable(asin.input); + let result = self.append_operation_with_result(llvm_ods::intr_asin( + self.context, + value, + self.location, + )); + self.insert_variable(out, result);*/ + } + Arithmetic::ArcTan(_atan) => { + todo!("intr_atan does not exist") + /*let value = self.get_variable(acos.input); + let result = self.append_operation_with_result(llvm_ods::intr_atan( + self.context, + value, + self.location, + )); + self.insert_variable(out, result);*/ + } + Arithmetic::ArcTan2(_atan2) => { + todo!("intr_atan2 does not exist") + /*let (y, x) = self.get_binary_op_variable(atan2.lhs, atan2.rhs); + let result = self.append_operation_with_result(llvm_ods::intr_atan2( + self.context, + y, + x, + self.location, + )); + self.insert_variable(out, result);*/ + } Arithmetic::Ceil(ceil) => { let value = self.get_variable(ceil.input); let result = self.append_operation_with_result(llvm_ods::intr_ceil( From 99a5d96bc7501f2e4e1d09a1f8c47acffa180ac3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Thu, 4 Sep 2025 21:27:28 +0200 Subject: [PATCH 03/23] Add sinh, cosh, asinh, acosh and atanh --- .../cubecl-core/src/frontend/element/float.rs | 5 + .../src/frontend/element/float/typemap.rs | 5 + .../src/frontend/operation/unary.rs | 62 +++++++++- crates/cubecl-core/src/runtime_tests/unary.rs | 110 ++++++++++++++++++ crates/cubecl-cpp/src/shared/base.rs | 25 ++++ crates/cubecl-cpp/src/shared/instruction.rs | 10 ++ crates/cubecl-cpp/src/shared/unary.rs | 5 + .../compiler/visitor/operation/arithmetic.rs | 50 ++++++++ crates/cubecl-ir/src/arithmetic.rs | 10 ++ crates/cubecl-ir/src/processing.rs | 15 +++ crates/cubecl-opt/src/instructions.rs | 5 + crates/cubecl-opt/src/passes/constant_prop.rs | 5 + crates/cubecl-spirv/src/arithmetic.rs | 40 +++++++ crates/cubecl-spirv/src/extensions.rs | 25 ++++ .../cubecl-wgpu/src/compiler/wgsl/compiler.rs | 20 ++++ .../src/compiler/wgsl/instructions.rs | 40 +++++++ 16 files changed, 430 insertions(+), 2 deletions(-) diff --git a/crates/cubecl-core/src/frontend/element/float.rs b/crates/cubecl-core/src/frontend/element/float.rs index 0036b0463..9a9361251 100644 --- a/crates/cubecl-core/src/frontend/element/float.rs +++ b/crates/cubecl-core/src/frontend/element/float.rs @@ -26,9 +26,14 @@ pub trait Float: + Cos + Sin + Tanh + + Sinh + + Cosh + ArcCos + ArcSin + ArcTan + + ArcSinh + + ArcCosh + + ArcTanh + ArcTan2 + Powf + Sqrt diff --git a/crates/cubecl-core/src/frontend/element/float/typemap.rs 
b/crates/cubecl-core/src/frontend/element/float/typemap.rs index 6b19efab8..7c3bd3a13 100644 --- a/crates/cubecl-core/src/frontend/element/float/typemap.rs +++ b/crates/cubecl-core/src/frontend/element/float/typemap.rs @@ -244,9 +244,14 @@ impl Log1p for ElemExpand {} impl Cos for ElemExpand {} impl Sin for ElemExpand {} impl Tanh for ElemExpand {} +impl Sinh for ElemExpand {} +impl Cosh for ElemExpand {} impl ArcCos for ElemExpand {} impl ArcSin for ElemExpand {} impl ArcTan for ElemExpand {} +impl ArcSinh for ElemExpand {} +impl ArcCosh for ElemExpand {} +impl ArcTanh for ElemExpand {} impl ArcTan2 for ElemExpand {} impl Powf for ElemExpand {} impl Sqrt for ElemExpand {} diff --git a/crates/cubecl-core/src/frontend/operation/unary.rs b/crates/cubecl-core/src/frontend/operation/unary.rs index 3c23fe719..c0e7adc9f 100644 --- a/crates/cubecl-core/src/frontend/operation/unary.rs +++ b/crates/cubecl-core/src/frontend/operation/unary.rs @@ -187,8 +187,30 @@ impl_unary_func!( f32, f64 ); -// TODO: Missing: SinH, ArcSinH, CosH, ArcCosH, ArcTanH, Tan -// TODO: Add function for converting between degree and radiants +impl_unary_func!( + Sinh, + sinh, + __expand_sinh, + Arithmetic::Sinh, + f16, + bf16, + flex32, + tf32, + f32, + f64 +); +impl_unary_func!( + Cosh, + cosh, + __expand_cosh, + Arithmetic::Cosh, + f16, + bf16, + flex32, + tf32, + f32, + f64 +); // Open Questions: // - When to use metal safe / atomic stuff and when not // - When do I need to check for Bfloats and stuff? @@ -229,6 +251,42 @@ impl_unary_func!( f32, f64 ); +impl_unary_func!( + ArcSinh, + asinh, + __expand_asinh, + Arithmetic::ArcSinh, + f16, + bf16, + flex32, + tf32, + f32, + f64 +); +impl_unary_func!( + ArcCosh, + acosh, + __expand_acosh, + Arithmetic::ArcCosh, + f16, + bf16, + flex32, + tf32, + f32, + f64 +); +impl_unary_func!( + ArcTanh, + atanh, + __expand_atanh, + Arithmetic::ArcTanh, + f16, + bf16, + flex32, + tf32, + f32, + f64 +); impl_unary_func!( Sqrt, sqrt, diff --git a/crates/cubecl-core/src/runtime_tests/unary.rs b/crates/cubecl-core/src/runtime_tests/unary.rs index 9aeb3aaf6..89f9cc986 100644 --- a/crates/cubecl-core/src/runtime_tests/unary.rs +++ b/crates/cubecl-core/src/runtime_tests/unary.rs @@ -273,6 +273,111 @@ test_unary_impl!(test_atan, F, F::atan, [ } ]); +test_unary_impl!(test_sinh, F, F::sinh, [ + { + input_vectorization: 1, + out_vectorization: 1, + input: as_type![F: 0., 1., -1., 2., -2.], + expected: as_type![F: 0., 1.1752011936, -1.1752011936, 3.6268604078, -3.6268604078] + }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![F: 0., 1., -1., 2.], + expected: as_type![F: 0., 1.1752011936, -1.1752011936, 3.6268604078] + }, + { + input_vectorization: 4, + out_vectorization: 4, + input: as_type![F: 0., 1., -1., 2.], + expected: as_type![F: 0., 1.1752011936, -1.1752011936, 3.6268604078] + } +]); + +test_unary_impl!(test_cosh, F, F::cosh, [ + { + input_vectorization: 1, + out_vectorization: 1, + input: as_type![F: 0., 1., -1., 2., -2.], + expected: as_type![F: 1., 1.5430806348, 1.5430806348, 3.7621956911, 3.7621956911] + }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![F: 0., 1., -1., 2.], + expected: as_type![F: 1., 1.5430806348, 1.5430806348, 3.7621956911] + }, + { + input_vectorization: 4, + out_vectorization: 4, + input: as_type![F: 0., 1., -1., 2.], + expected: as_type![F: 1., 1.5430806348, 1.5430806348, 3.7621956911] + } +]); + +test_unary_impl!(test_asinh, F, F::asinh, [ + { + input_vectorization: 1, + out_vectorization: 1, + input: as_type![F: 0., 
1., -1., 2., -2.], + expected: as_type![F: 0., 0.88137358702, -0.88137358702, 1.44363547517, -1.44363547517] + }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![F: 0., 1., -1., 2.], + expected: as_type![F: 0., 0.88137358702, -0.88137358702, 1.44363547517] + }, + { + input_vectorization: 4, + out_vectorization: 4, + input: as_type![F: 0., 1., -1., 2.], + expected: as_type![F: 0., 0.88137358702, -0.88137358702, 1.44363547517] + } +]); + +test_unary_impl!(test_acosh, F, F::acosh, [ + { + input_vectorization: 1, + out_vectorization: 1, + input: as_type![F: 1., 2., 3., 10.], + expected: as_type![F: 0., 1.31695789692, 1.76274717404, 2.99322284612] + }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![F: 1., 2., 3., 10.], + expected: as_type![F: 0., 1.31695789692, 1.76274717404, 2.99322284612] + }, + { + input_vectorization: 4, + out_vectorization: 4, + input: as_type![F: 1., 2., 3., 10.], + expected: as_type![F: 0., 1.31695789692, 1.76274717404, 2.99322284612] + } +]); + +test_unary_impl!(test_atanh, F, F::atanh, [ + { + input_vectorization: 1, + out_vectorization: 1, + input: as_type![F: 0., 0.5, -0.5, 0.9, -0.9], + expected: as_type![F: 0., 0.54930614433, -0.54930614433, 1.47221948958, -1.47221948958] + }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![F: 0., 0.5, -0.5, 0.9], + expected: as_type![F: 0., 0.54930614433, -0.54930614433, 1.47221948958] + }, + { + input_vectorization: 4, + out_vectorization: 4, + input: as_type![F: 0., 0.5, -0.5, 0.9], + expected: as_type![F: 0., 0.54930614433, -0.54930614433, 1.47221948958] + } +]); + test_unary_impl!( test_magnitude, F, @@ -483,9 +588,14 @@ macro_rules! testgen_unary { add_test!(test_sin); add_test!(test_cos); + add_test!(test_sinh); + add_test!(test_cosh); add_test!(test_asin); add_test!(test_acos); add_test!(test_atan); + add_test!(test_asinh); + add_test!(test_acosh); + add_test!(test_atanh); add_test!(test_normalize); add_test!(test_magnitude); add_test!(test_abs); diff --git a/crates/cubecl-cpp/src/shared/base.rs b/crates/cubecl-cpp/src/shared/base.rs index 5ae4b1129..bd4f09524 100644 --- a/crates/cubecl-cpp/src/shared/base.rs +++ b/crates/cubecl-cpp/src/shared/base.rs @@ -932,6 +932,16 @@ impl CppCompiler { D::register_instruction_extension(&mut self.extensions, &instruction); instructions.push(instruction) } + gpu::Arithmetic::Sinh(op) => { + let instruction = Instruction::Sinh(self.compile_unary(op, out)); + D::register_instruction_extension(&mut self.extensions, &instruction); + instructions.push(instruction) + } + gpu::Arithmetic::Cosh(op) => { + let instruction = Instruction::Cosh(self.compile_unary(op, out)); + D::register_instruction_extension(&mut self.extensions, &instruction); + instructions.push(instruction) + } gpu::Arithmetic::ArcCos(op) => { let instruction = Instruction::ArcCos(self.compile_unary(op, out)); D::register_instruction_extension(&mut self.extensions, &instruction); @@ -947,6 +957,21 @@ impl CppCompiler { D::register_instruction_extension(&mut self.extensions, &instruction); instructions.push(instruction) } + gpu::Arithmetic::ArcSinh(op) => { + let instruction = Instruction::ArcSinh(self.compile_unary(op, out)); + D::register_instruction_extension(&mut self.extensions, &instruction); + instructions.push(instruction) + } + gpu::Arithmetic::ArcCosh(op) => { + let instruction = Instruction::ArcCosh(self.compile_unary(op, out)); + D::register_instruction_extension(&mut self.extensions, &instruction); + instructions.push(instruction) + } + 
gpu::Arithmetic::ArcTanh(op) => { + let instruction = Instruction::ArcTanh(self.compile_unary(op, out)); + D::register_instruction_extension(&mut self.extensions, &instruction); + instructions.push(instruction) + } gpu::Arithmetic::ArcTan2(op) => { let instruction = Instruction::ArcTan2(self.compile_binary(op, out)); D::register_instruction_extension(&mut self.extensions, &instruction); diff --git a/crates/cubecl-cpp/src/shared/instruction.rs b/crates/cubecl-cpp/src/shared/instruction.rs index 550a1423c..df87077a3 100644 --- a/crates/cubecl-cpp/src/shared/instruction.rs +++ b/crates/cubecl-cpp/src/shared/instruction.rs @@ -163,9 +163,14 @@ pub enum Instruction { Cos(UnaryInstruction), Sin(UnaryInstruction), Tanh(UnaryInstruction), + Sinh(UnaryInstruction), + Cosh(UnaryInstruction), ArcCos(UnaryInstruction), ArcSin(UnaryInstruction), ArcTan(UnaryInstruction), + ArcSinh(UnaryInstruction), + ArcCosh(UnaryInstruction), + ArcTanh(UnaryInstruction), ArcTan2(BinaryInstruction), Powf(BinaryInstruction), Sqrt(UnaryInstruction), @@ -507,9 +512,14 @@ for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{ Instruction::Cos(it) => Cos::format(f, &it.input, &it.out), Instruction::Sin(it) => Sin::format(f, &it.input, &it.out), Instruction::Tanh(it) => Tanh::format(f, &it.input, &it.out), + Instruction::Sinh(it) => Sinh::format(f, &it.input, &it.out), + Instruction::Cosh(it) => Cosh::format(f, &it.input, &it.out), Instruction::ArcCos(it) => ArcCos::format(f, &it.input, &it.out), Instruction::ArcSin(it) => ArcSin::format(f, &it.input, &it.out), Instruction::ArcTan(it) => ArcTan::format(f, &it.input, &it.out), + Instruction::ArcSinh(it) => ArcSinh::format(f, &it.input, &it.out), + Instruction::ArcCosh(it) => ArcCosh::format(f, &it.input, &it.out), + Instruction::ArcTanh(it) => ArcTanh::format(f, &it.input, &it.out), Instruction::ArcTan2(it) => ArcTan2::format(f, &it.lhs, &it.rhs, &it.out), Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out), Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out), diff --git a/crates/cubecl-cpp/src/shared/unary.rs b/crates/cubecl-cpp/src/shared/unary.rs index c0976439e..4ff115a39 100644 --- a/crates/cubecl-cpp/src/shared/unary.rs +++ b/crates/cubecl-cpp/src/shared/unary.rs @@ -151,9 +151,14 @@ macro_rules! 
function { function!(Log, "log"); function!(Cos, "cos"); function!(Sin, "sin"); +function!(Sinh, "sinh"); +function!(Cosh, "cosh"); function!(ArcCos, "acos"); function!(ArcSin, "asin"); function!(ArcTan, "atan"); +function!(ArcSinh, "asinh"); +function!(ArcCosh, "acosh"); +function!(ArcTanh, "atanh"); function!(Sqrt, "sqrt"); function!(Exp, "exp"); function!(Ceil, "ceil"); diff --git a/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs b/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs index a71ff4dfd..1f87a0adf 100644 --- a/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs +++ b/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs @@ -28,6 +28,26 @@ impl<'a> Visitor<'a> { let result = self.append_operation_with_result(operation); self.insert_variable(out, result); } + Arithmetic::Sinh(_sinh) => { + todo!("intr_sinh does not exist") + /*let value = self.get_variable(sinh.input); + let result = self.append_operation_with_result(llvm_ods::intr_sinh( + self.context, + value, + self.location, + )); + self.insert_variable(out, result);*/ + } + Arithmetic::Cosh(_cosh) => { + todo!("intr_cosh does not exist") + /*let value = self.get_variable(cosh.input); + let result = self.append_operation_with_result(llvm_ods::intr_cosh( + self.context, + value, + self.location, + )); + self.insert_variable(out, result);*/ + } Arithmetic::ArcCos(_acos) => { todo!("intr_acos does not exist") /*let value = self.get_variable(acos.input); @@ -58,6 +78,36 @@ impl<'a> Visitor<'a> { )); self.insert_variable(out, result);*/ } + Arithmetic::ArcSinh(_asinh) => { + todo!("intr_asinh does not exist") + /*let value = self.get_variable(asinh.input); + let result = self.append_operation_with_result(llvm_ods::intr_asinh( + self.context, + value, + self.location, + )); + self.insert_variable(out, result);*/ + } + Arithmetic::ArcCosh(_acosh) => { + todo!("intr_acosh does not exist") + /*let value = self.get_variable(acosh.input); + let result = self.append_operation_with_result(llvm_ods::intr_acosh( + self.context, + value, + self.location, + )); + self.insert_variable(out, result);*/ + } + Arithmetic::ArcTanh(_atanh) => { + todo!("intr_atanh does not exist") + /*let value = self.get_variable(atanh.input); + let result = self.append_operation_with_result(llvm_ods::intr_atanh( + self.context, + value, + self.location, + )); + self.insert_variable(out, result);*/ + } Arithmetic::ArcTan2(_atan2) => { todo!("intr_atan2 does not exist") /*let (y, x) = self.get_binary_op_variable(atan2.lhs, atan2.rhs); diff --git a/crates/cubecl-ir/src/arithmetic.rs b/crates/cubecl-ir/src/arithmetic.rs index e798a01ef..fdedb3a5c 100644 --- a/crates/cubecl-ir/src/arithmetic.rs +++ b/crates/cubecl-ir/src/arithmetic.rs @@ -23,9 +23,14 @@ pub enum Arithmetic { Cos(UnaryOperator), Sin(UnaryOperator), Tanh(UnaryOperator), + Sinh(UnaryOperator), + Cosh(UnaryOperator), ArcCos(UnaryOperator), ArcSin(UnaryOperator), ArcTan(UnaryOperator), + ArcSinh(UnaryOperator), + ArcCosh(UnaryOperator), + ArcTanh(UnaryOperator), ArcTan2(BinaryOperator), Powf(BinaryOperator), Sqrt(UnaryOperator), @@ -65,9 +70,14 @@ impl Display for Arithmetic { Arithmetic::Cos(op) => write!(f, "{}.cos()", op.input), Arithmetic::Sin(op) => write!(f, "{}.sin()", op.input), Arithmetic::Tanh(op) => write!(f, "{}.tanh()", op.input), + Arithmetic::Sinh(op) => write!(f, "{}.sinh()", op.input), + Arithmetic::Cosh(op) => write!(f, "{}.cosh()", op.input), Arithmetic::ArcCos(op) => write!(f, "{}.acos()", op.input), Arithmetic::ArcSin(op) => write!(f, 
"{}.asin()", op.input), Arithmetic::ArcTan(op) => write!(f, "{}.atan()", op.input), + Arithmetic::ArcSinh(op) => write!(f, "{}.asinh()", op.input), + Arithmetic::ArcCosh(op) => write!(f, "{}.acosh()", op.input), + Arithmetic::ArcTanh(op) => write!(f, "{}.atanh()", op.input), Arithmetic::ArcTan2(op) => write!(f, "{}.atan2({})", op.lhs, op.rhs), Arithmetic::Powf(op) => write!(f, "{}.pow({})", op.lhs, op.rhs), Arithmetic::Sqrt(op) => write!(f, "{}.sqrt()", op.input), diff --git a/crates/cubecl-ir/src/processing.rs b/crates/cubecl-ir/src/processing.rs index 0a2ad59fe..84f3cef0f 100644 --- a/crates/cubecl-ir/src/processing.rs +++ b/crates/cubecl-ir/src/processing.rs @@ -104,6 +104,12 @@ impl ScopeProcessing { Arithmetic::Tanh(op) => { sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); } + Arithmetic::Sinh(op) => { + sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); + } + Arithmetic::Cosh(op) => { + sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); + } Arithmetic::ArcCos(op) => { sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); } @@ -113,6 +119,15 @@ impl ScopeProcessing { Arithmetic::ArcTan(op) => { sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); } + Arithmetic::ArcSinh(op) => { + sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); + } + Arithmetic::ArcCosh(op) => { + sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); + } + Arithmetic::ArcTanh(op) => { + sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); + } Arithmetic::ArcTan2(op) => { sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap()); sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap()); diff --git a/crates/cubecl-opt/src/instructions.rs b/crates/cubecl-opt/src/instructions.rs index 56e5ab2bb..e38c9ebbf 100644 --- a/crates/cubecl-opt/src/instructions.rs +++ b/crates/cubecl-opt/src/instructions.rs @@ -90,9 +90,14 @@ impl Optimizer { | Arithmetic::Cos(unary_operator) | Arithmetic::Sin(unary_operator) | Arithmetic::Tanh(unary_operator) + | Arithmetic::Sinh(unary_operator) + | Arithmetic::Cosh(unary_operator) | Arithmetic::ArcCos(unary_operator) | Arithmetic::ArcSin(unary_operator) | Arithmetic::ArcTan(unary_operator) + | Arithmetic::ArcSinh(unary_operator) + | Arithmetic::ArcCosh(unary_operator) + | Arithmetic::ArcTanh(unary_operator) | Arithmetic::Sqrt(unary_operator) | Arithmetic::Round(unary_operator) | Arithmetic::Floor(unary_operator) diff --git a/crates/cubecl-opt/src/passes/constant_prop.rs b/crates/cubecl-opt/src/passes/constant_prop.rs index 25a87dd9d..612f56a51 100644 --- a/crates/cubecl-opt/src/passes/constant_prop.rs +++ b/crates/cubecl-opt/src/passes/constant_prop.rs @@ -369,9 +369,14 @@ fn try_const_eval_arithmetic(op: &mut Arithmetic) -> Option Arithmetic::Cos(op) => const_eval_float!(op.input; num::Float::cos), Arithmetic::Sin(op) => const_eval_float!(op.input; num::Float::sin), Arithmetic::Tanh(op) => const_eval_float!(op.input; num::Float::tanh), + Arithmetic::Sinh(op) => const_eval_float!(op.input; num::Float::sinh), + Arithmetic::Cosh(op) => const_eval_float!(op.input; num::Float::cosh), Arithmetic::ArcCos(op) => const_eval_float!(op.input; num::Float::acos), Arithmetic::ArcSin(op) => const_eval_float!(op.input; num::Float::asin), Arithmetic::ArcTan(op) => const_eval_float!(op.input; num::Float::atan), + Arithmetic::ArcSinh(op) => const_eval_float!(op.input; num::Float::asinh), + Arithmetic::ArcCosh(op) => const_eval_float!(op.input; 
num::Float::acosh), + Arithmetic::ArcTanh(op) => const_eval_float!(op.input; num::Float::atanh), Arithmetic::ArcTan2(op) => { use ConstantScalarValue::*; if let (Some(lhs), Some(rhs)) = (op.lhs.as_const(), op.rhs.as_const()) { diff --git a/crates/cubecl-spirv/src/arithmetic.rs b/crates/cubecl-spirv/src/arithmetic.rs index e20c20919..bafe08b5a 100644 --- a/crates/cubecl-spirv/src/arithmetic.rs +++ b/crates/cubecl-spirv/src/arithmetic.rs @@ -301,6 +301,22 @@ impl SpirvCompiler { } }) } + Arithmetic::Sinh(op) => { + self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + T::sinh(b, ty, input, out); + if matches!(out_ty.elem(), Elem::Relaxed) { + b.decorate(out, Decoration::RelaxedPrecision, []); + } + }) + } + Arithmetic::Cosh(op) => { + self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + T::cosh(b, ty, input, out); + if matches!(out_ty.elem(), Elem::Relaxed) { + b.decorate(out, Decoration::RelaxedPrecision, []); + } + }) + } Arithmetic::ArcCos(op) => { self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { T::acos(b, ty, input, out); @@ -325,6 +341,30 @@ impl SpirvCompiler { } }) } + Arithmetic::ArcSinh(op) => { + self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + T::asinh(b, ty, input, out); + if matches!(out_ty.elem(), Elem::Relaxed) { + b.decorate(out, Decoration::RelaxedPrecision, []); + } + }) + } + Arithmetic::ArcCosh(op) => { + self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + T::acosh(b, ty, input, out); + if matches!(out_ty.elem(), Elem::Relaxed) { + b.decorate(out, Decoration::RelaxedPrecision, []); + } + }) + } + Arithmetic::ArcTanh(op) => { + self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + T::atanh(b, ty, input, out); + if matches!(out_ty.elem(), Elem::Relaxed) { + b.decorate(out, Decoration::RelaxedPrecision, []); + } + }) + } Arithmetic::ArcTan2(op) => { self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| { T::atan2(b, ty, lhs, rhs, out); diff --git a/crates/cubecl-spirv/src/extensions.rs b/crates/cubecl-spirv/src/extensions.rs index 3191b62a1..b8c52b572 100644 --- a/crates/cubecl-spirv/src/extensions.rs +++ b/crates/cubecl-spirv/src/extensions.rs @@ -13,9 +13,14 @@ pub trait TargetExtensions { fn sin(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn cos(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn tanh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); + fn sinh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); + fn cosh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn asin(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn acos(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn atan(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); + fn asinh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); + fn acosh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); + fn atanh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn atan2(b: &mut SpirvCompiler, ty: Word, lhs: Word, rhs: Word, out: Word); fn pow(b: &mut SpirvCompiler, ty: Word, lhs: Word, rhs: Word, out: Word); fn exp(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); @@ -75,6 +80,14 @@ pub mod glcompute { b.tanh_id(ty, Some(out), input).unwrap(); } + fn sinh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) { + b.sinh_id(ty, Some(out), input).unwrap(); + } + + fn cosh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) { + b.cosh_id(ty, 
Some(out), input).unwrap(); + } + fn asin(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) { b.asin_id(ty, Some(out), input).unwrap(); } @@ -87,6 +100,18 @@ pub mod glcompute { b.atan_id(ty, Some(out), input).unwrap(); } + fn asinh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) { + b.asinh_id(ty, Some(out), input).unwrap(); + } + + fn acosh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) { + b.acosh_id(ty, Some(out), input).unwrap(); + } + + fn atanh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) { + b.atanh_id(ty, Some(out), input).unwrap(); + } + fn atan2(b: &mut SpirvCompiler, ty: Word, lhs: Word, rhs: Word, out: Word) { b.atan2_id(ty, Some(out), lhs, rhs).unwrap(); } diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs index fc165d7fd..30e6f40b0 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs @@ -752,6 +752,14 @@ impl WgslCompiler { input: self.compile_variable(op.input), out: self.compile_variable(out), }), + cube::Arithmetic::Sinh(op) => instructions.push(wgsl::Instruction::Sinh { + input: self.compile_variable(op.input), + out: self.compile_variable(out), + }), + cube::Arithmetic::Cosh(op) => instructions.push(wgsl::Instruction::Cosh { + input: self.compile_variable(op.input), + out: self.compile_variable(out), + }), cube::Arithmetic::ArcCos(op) => instructions.push(wgsl::Instruction::ArcCos { input: self.compile_variable(op.input), out: self.compile_variable(out), @@ -764,6 +772,18 @@ impl WgslCompiler { input: self.compile_variable(op.input), out: self.compile_variable(out), }), + cube::Arithmetic::ArcSinh(op) => instructions.push(wgsl::Instruction::ArcSinh { + input: self.compile_variable(op.input), + out: self.compile_variable(out), + }), + cube::Arithmetic::ArcCosh(op) => instructions.push(wgsl::Instruction::ArcCosh { + input: self.compile_variable(op.input), + out: self.compile_variable(out), + }), + cube::Arithmetic::ArcTanh(op) => instructions.push(wgsl::Instruction::ArcTanh { + input: self.compile_variable(op.input), + out: self.compile_variable(out), + }), cube::Arithmetic::ArcTan2(op) => instructions.push(wgsl::Instruction::ArcTan2 { lhs: self.compile_variable(op.lhs), rhs: self.compile_variable(op.rhs), diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs index dd0a6bd6d..c42f7262b 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs @@ -121,6 +121,14 @@ pub enum Instruction { input: Variable, out: Variable, }, + Sinh { + input: Variable, + out: Variable, + }, + Cosh { + input: Variable, + out: Variable, + }, ArcCos { input: Variable, out: Variable, @@ -133,6 +141,18 @@ pub enum Instruction { input: Variable, out: Variable, }, + ArcSinh { + input: Variable, + out: Variable, + }, + ArcCosh { + input: Variable, + out: Variable, + }, + ArcTanh { + input: Variable, + out: Variable, + }, ArcTan2 { lhs: Variable, rhs: Variable, @@ -619,6 +639,14 @@ impl Display for Instruction { result } + Instruction::Sinh { input, out } => { + let out = out.fmt_left(); + writeln!(f, "{out} = sinh({input});") + } + Instruction::Cosh { input, out } => { + let out = out.fmt_left(); + writeln!(f, "{out} = cosh({input});") + } Instruction::ArcCos { input, out } => { let out = out.fmt_left(); writeln!(f, "{out} = acos({input});") @@ -631,6 +659,18 @@ impl Display for Instruction { let out = 
out.fmt_left(); writeln!(f, "{out} = atan({input});") } + Instruction::ArcSinh { input, out } => { + let out = out.fmt_left(); + writeln!(f, "{out} = asinh({input});") + } + Instruction::ArcCosh { input, out } => { + let out = out.fmt_left(); + writeln!(f, "{out} = acosh({input});") + } + Instruction::ArcTanh { input, out } => { + let out = out.fmt_left(); + writeln!(f, "{out} = atanh({input});") + } Instruction::ArcTan2 { lhs, rhs, out } => { let out = out.fmt_left(); writeln!(f, "{out} = atan2({lhs}, {rhs});") From 89b1642d058f8968114ca6fb71dfc001b5341db9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Thu, 4 Sep 2025 21:34:51 +0200 Subject: [PATCH 04/23] Add degrees and radians function --- .../cubecl-core/src/frontend/element/float.rs | 2 + .../src/frontend/element/float/typemap.rs | 2 + .../src/frontend/operation/unary.rs | 24 ++++++++++ crates/cubecl-core/src/runtime_tests/unary.rs | 44 +++++++++++++++++++ crates/cubecl-cpp/src/shared/base.rs | 10 +++++ crates/cubecl-cpp/src/shared/instruction.rs | 4 ++ crates/cubecl-cpp/src/shared/unary.rs | 2 + .../compiler/visitor/operation/arithmetic.rs | 20 +++++++++ crates/cubecl-ir/src/arithmetic.rs | 4 ++ crates/cubecl-ir/src/processing.rs | 6 +++ crates/cubecl-opt/src/instructions.rs | 2 + crates/cubecl-opt/src/passes/constant_prop.rs | 2 + crates/cubecl-spirv/src/arithmetic.rs | 16 +++++++ crates/cubecl-spirv/src/extensions.rs | 10 +++++ .../cubecl-wgpu/src/compiler/wgsl/compiler.rs | 8 ++++ .../src/compiler/wgsl/instructions.rs | 16 +++++++ 16 files changed, 172 insertions(+) diff --git a/crates/cubecl-core/src/frontend/element/float.rs b/crates/cubecl-core/src/frontend/element/float.rs index 9a9361251..d01ae8e29 100644 --- a/crates/cubecl-core/src/frontend/element/float.rs +++ b/crates/cubecl-core/src/frontend/element/float.rs @@ -34,6 +34,8 @@ pub trait Float: + ArcSinh + ArcCosh + ArcTanh + + Degrees + + Radians + ArcTan2 + Powf + Sqrt diff --git a/crates/cubecl-core/src/frontend/element/float/typemap.rs b/crates/cubecl-core/src/frontend/element/float/typemap.rs index 7c3bd3a13..2c2067ae2 100644 --- a/crates/cubecl-core/src/frontend/element/float/typemap.rs +++ b/crates/cubecl-core/src/frontend/element/float/typemap.rs @@ -252,6 +252,8 @@ impl ArcTan for ElemExpand {} impl ArcSinh for ElemExpand {} impl ArcCosh for ElemExpand {} impl ArcTanh for ElemExpand {} +impl Degrees for ElemExpand {} +impl Radians for ElemExpand {} impl ArcTan2 for ElemExpand {} impl Powf for ElemExpand {} impl Sqrt for ElemExpand {} diff --git a/crates/cubecl-core/src/frontend/operation/unary.rs b/crates/cubecl-core/src/frontend/operation/unary.rs index c0e7adc9f..05cb6c665 100644 --- a/crates/cubecl-core/src/frontend/operation/unary.rs +++ b/crates/cubecl-core/src/frontend/operation/unary.rs @@ -287,6 +287,30 @@ impl_unary_func!( f32, f64 ); +impl_unary_func!( + Degrees, + degrees, + __expand_degrees, + Arithmetic::Degrees, + f16, + bf16, + flex32, + tf32, + f32, + f64 +); +impl_unary_func!( + Radians, + radians, + __expand_radians, + Arithmetic::Radians, + f16, + bf16, + flex32, + tf32, + f32, + f64 +); impl_unary_func!( Sqrt, sqrt, diff --git a/crates/cubecl-core/src/runtime_tests/unary.rs b/crates/cubecl-core/src/runtime_tests/unary.rs index 89f9cc986..0c0ef4eb0 100644 --- a/crates/cubecl-core/src/runtime_tests/unary.rs +++ b/crates/cubecl-core/src/runtime_tests/unary.rs @@ -378,6 +378,48 @@ test_unary_impl!(test_atanh, F, F::atanh, [ } ]); +test_unary_impl!(test_degrees, F, F::degrees, [ + { + input_vectorization: 1, + 
out_vectorization: 1, + input: as_type![F: 0., 1.5707963268, 3.1415926536, -1.5707963268, -3.1415926536], + expected: as_type![F: 0., 90., 180., -90., -180.] + }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![F: 0., 1.5707963268, 3.1415926536, -1.5707963268], + expected: as_type![F: 0., 90., 180., -90.] + }, + { + input_vectorization: 4, + out_vectorization: 4, + input: as_type![F: 0., 1.5707963268, 3.1415926536, -1.5707963268], + expected: as_type![F: 0., 90., 180., -90.] + } +]); + +test_unary_impl!(test_radians, F, F::radians, [ + { + input_vectorization: 1, + out_vectorization: 1, + input: as_type![F: 0., 90., 180., -90., -180.], + expected: as_type![F: 0., 1.5707963268, 3.1415926536, -1.5707963268, -3.1415926536] + }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![F: 0., 90., 180., -90.], + expected: as_type![F: 0., 1.5707963268, 3.1415926536, -1.5707963268] + }, + { + input_vectorization: 4, + out_vectorization: 4, + input: as_type![F: 0., 90., 180., -90.], + expected: as_type![F: 0., 1.5707963268, 3.1415926536, -1.5707963268] + } +]); + test_unary_impl!( test_magnitude, F, @@ -596,6 +638,8 @@ macro_rules! testgen_unary { add_test!(test_asinh); add_test!(test_acosh); add_test!(test_atanh); + add_test!(test_degrees); + add_test!(test_radians); add_test!(test_normalize); add_test!(test_magnitude); add_test!(test_abs); diff --git a/crates/cubecl-cpp/src/shared/base.rs b/crates/cubecl-cpp/src/shared/base.rs index bd4f09524..a3bdba49a 100644 --- a/crates/cubecl-cpp/src/shared/base.rs +++ b/crates/cubecl-cpp/src/shared/base.rs @@ -972,6 +972,16 @@ impl CppCompiler { D::register_instruction_extension(&mut self.extensions, &instruction); instructions.push(instruction) } + gpu::Arithmetic::Degrees(op) => { + let instruction = Instruction::Degrees(self.compile_unary(op, out)); + D::register_instruction_extension(&mut self.extensions, &instruction); + instructions.push(instruction) + } + gpu::Arithmetic::Radians(op) => { + let instruction = Instruction::Radians(self.compile_unary(op, out)); + D::register_instruction_extension(&mut self.extensions, &instruction); + instructions.push(instruction) + } gpu::Arithmetic::ArcTan2(op) => { let instruction = Instruction::ArcTan2(self.compile_binary(op, out)); D::register_instruction_extension(&mut self.extensions, &instruction); diff --git a/crates/cubecl-cpp/src/shared/instruction.rs b/crates/cubecl-cpp/src/shared/instruction.rs index df87077a3..9d4ad8ee8 100644 --- a/crates/cubecl-cpp/src/shared/instruction.rs +++ b/crates/cubecl-cpp/src/shared/instruction.rs @@ -171,6 +171,8 @@ pub enum Instruction { ArcSinh(UnaryInstruction), ArcCosh(UnaryInstruction), ArcTanh(UnaryInstruction), + Degrees(UnaryInstruction), + Radians(UnaryInstruction), ArcTan2(BinaryInstruction), Powf(BinaryInstruction), Sqrt(UnaryInstruction), @@ -520,6 +522,8 @@ for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{ Instruction::ArcSinh(it) => ArcSinh::format(f, &it.input, &it.out), Instruction::ArcCosh(it) => ArcCosh::format(f, &it.input, &it.out), Instruction::ArcTanh(it) => ArcTanh::format(f, &it.input, &it.out), + Instruction::Degrees(it) => Degrees::format(f, &it.input, &it.out), + Instruction::Radians(it) => Radians::format(f, &it.input, &it.out), Instruction::ArcTan2(it) => ArcTan2::format(f, &it.lhs, &it.rhs, &it.out), Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out), Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out), diff --git a/crates/cubecl-cpp/src/shared/unary.rs 
b/crates/cubecl-cpp/src/shared/unary.rs index 4ff115a39..e9bae51f8 100644 --- a/crates/cubecl-cpp/src/shared/unary.rs +++ b/crates/cubecl-cpp/src/shared/unary.rs @@ -159,6 +159,8 @@ function!(ArcTan, "atan"); function!(ArcSinh, "asinh"); function!(ArcCosh, "acosh"); function!(ArcTanh, "atanh"); +function!(Degrees, "degrees"); +function!(Radians, "radians"); function!(Sqrt, "sqrt"); function!(Exp, "exp"); function!(Ceil, "ceil"); diff --git a/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs b/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs index 1f87a0adf..0a7f88b5e 100644 --- a/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs +++ b/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs @@ -78,6 +78,26 @@ impl<'a> Visitor<'a> { )); self.insert_variable(out, result);*/ } + Arithmetic::Degrees(_degrees) => { + todo!("intr_degrees does not exist") + /*let value = self.get_variable(degrees.input); + let result = self.append_operation_with_result(llvm_ods::intr_degrees( + self.context, + value, + self.location, + )); + self.insert_variable(out, result);*/ + } + Arithmetic::Radians(_radians) => { + todo!("intr_radians does not exist") + /*let value = self.get_variable(radians.input); + let result = self.append_operation_with_result(llvm_ods::intr_radians( + self.context, + value, + self.location, + )); + self.insert_variable(out, result);*/ + } Arithmetic::ArcSinh(_asinh) => { todo!("intr_asinh does not exist") /*let value = self.get_variable(asinh.input); diff --git a/crates/cubecl-ir/src/arithmetic.rs b/crates/cubecl-ir/src/arithmetic.rs index fdedb3a5c..2ffedf3fc 100644 --- a/crates/cubecl-ir/src/arithmetic.rs +++ b/crates/cubecl-ir/src/arithmetic.rs @@ -31,6 +31,8 @@ pub enum Arithmetic { ArcSinh(UnaryOperator), ArcCosh(UnaryOperator), ArcTanh(UnaryOperator), + Degrees(UnaryOperator), + Radians(UnaryOperator), ArcTan2(BinaryOperator), Powf(BinaryOperator), Sqrt(UnaryOperator), @@ -78,6 +80,8 @@ impl Display for Arithmetic { Arithmetic::ArcSinh(op) => write!(f, "{}.asinh()", op.input), Arithmetic::ArcCosh(op) => write!(f, "{}.acosh()", op.input), Arithmetic::ArcTanh(op) => write!(f, "{}.atanh()", op.input), + Arithmetic::Degrees(op) => write!(f, "{}.degrees()", op.input), + Arithmetic::Radians(op) => write!(f, "{}.radians()", op.input), Arithmetic::ArcTan2(op) => write!(f, "{}.atan2({})", op.lhs, op.rhs), Arithmetic::Powf(op) => write!(f, "{}.pow({})", op.lhs, op.rhs), Arithmetic::Sqrt(op) => write!(f, "{}.sqrt()", op.input), diff --git a/crates/cubecl-ir/src/processing.rs b/crates/cubecl-ir/src/processing.rs index 84f3cef0f..5eee2bae5 100644 --- a/crates/cubecl-ir/src/processing.rs +++ b/crates/cubecl-ir/src/processing.rs @@ -128,6 +128,12 @@ impl ScopeProcessing { Arithmetic::ArcTanh(op) => { sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); } + Arithmetic::Degrees(op) => { + sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); + } + Arithmetic::Radians(op) => { + sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); + } Arithmetic::ArcTan2(op) => { sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap()); sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap()); diff --git a/crates/cubecl-opt/src/instructions.rs b/crates/cubecl-opt/src/instructions.rs index e38c9ebbf..69fd4cf01 100644 --- a/crates/cubecl-opt/src/instructions.rs +++ b/crates/cubecl-opt/src/instructions.rs @@ -98,6 +98,8 @@ impl Optimizer { | Arithmetic::ArcSinh(unary_operator) | 
Arithmetic::ArcCosh(unary_operator) | Arithmetic::ArcTanh(unary_operator) + | Arithmetic::Degrees(unary_operator) + | Arithmetic::Radians(unary_operator) | Arithmetic::Sqrt(unary_operator) | Arithmetic::Round(unary_operator) | Arithmetic::Floor(unary_operator) diff --git a/crates/cubecl-opt/src/passes/constant_prop.rs b/crates/cubecl-opt/src/passes/constant_prop.rs index 612f56a51..527bad269 100644 --- a/crates/cubecl-opt/src/passes/constant_prop.rs +++ b/crates/cubecl-opt/src/passes/constant_prop.rs @@ -377,6 +377,8 @@ fn try_const_eval_arithmetic(op: &mut Arithmetic) -> Option Arithmetic::ArcSinh(op) => const_eval_float!(op.input; num::Float::asinh), Arithmetic::ArcCosh(op) => const_eval_float!(op.input; num::Float::acosh), Arithmetic::ArcTanh(op) => const_eval_float!(op.input; num::Float::atanh), + Arithmetic::Degrees(op) => const_eval_float!(op.input; num::Float::to_degrees), + Arithmetic::Radians(op) => const_eval_float!(op.input; num::Float::to_radians), Arithmetic::ArcTan2(op) => { use ConstantScalarValue::*; if let (Some(lhs), Some(rhs)) = (op.lhs.as_const(), op.rhs.as_const()) { diff --git a/crates/cubecl-spirv/src/arithmetic.rs b/crates/cubecl-spirv/src/arithmetic.rs index bafe08b5a..8b1f4513b 100644 --- a/crates/cubecl-spirv/src/arithmetic.rs +++ b/crates/cubecl-spirv/src/arithmetic.rs @@ -365,6 +365,22 @@ impl SpirvCompiler { } }) } + Arithmetic::Degrees(op) => { + self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + T::degrees(b, ty, input, out); + if matches!(out_ty.elem(), Elem::Relaxed) { + b.decorate(out, Decoration::RelaxedPrecision, []); + } + }) + } + Arithmetic::Radians(op) => { + self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + T::radians(b, ty, input, out); + if matches!(out_ty.elem(), Elem::Relaxed) { + b.decorate(out, Decoration::RelaxedPrecision, []); + } + }) + } Arithmetic::ArcTan2(op) => { self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| { T::atan2(b, ty, lhs, rhs, out); diff --git a/crates/cubecl-spirv/src/extensions.rs b/crates/cubecl-spirv/src/extensions.rs index b8c52b572..d696d42a7 100644 --- a/crates/cubecl-spirv/src/extensions.rs +++ b/crates/cubecl-spirv/src/extensions.rs @@ -21,6 +21,8 @@ pub trait TargetExtensions { fn asinh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn acosh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn atanh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); + fn degrees(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); + fn radians(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn atan2(b: &mut SpirvCompiler, ty: Word, lhs: Word, rhs: Word, out: Word); fn pow(b: &mut SpirvCompiler, ty: Word, lhs: Word, rhs: Word, out: Word); fn exp(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); @@ -112,6 +114,14 @@ pub mod glcompute { b.atanh_id(ty, Some(out), input).unwrap(); } + fn degrees(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) { + b.degrees_id(ty, Some(out), input).unwrap(); + } + + fn radians(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) { + b.radians_id(ty, Some(out), input).unwrap(); + } + fn atan2(b: &mut SpirvCompiler, ty: Word, lhs: Word, rhs: Word, out: Word) { b.atan2_id(ty, Some(out), lhs, rhs).unwrap(); } diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs index 30e6f40b0..1c516a60b 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs @@ -784,6 
+784,14 @@ impl WgslCompiler { input: self.compile_variable(op.input), out: self.compile_variable(out), }), + cube::Arithmetic::Degrees(op) => instructions.push(wgsl::Instruction::Degrees { + input: self.compile_variable(op.input), + out: self.compile_variable(out), + }), + cube::Arithmetic::Radians(op) => instructions.push(wgsl::Instruction::Radians { + input: self.compile_variable(op.input), + out: self.compile_variable(out), + }), cube::Arithmetic::ArcTan2(op) => instructions.push(wgsl::Instruction::ArcTan2 { lhs: self.compile_variable(op.lhs), rhs: self.compile_variable(op.rhs), diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs index c42f7262b..0d5339b2f 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs @@ -153,6 +153,14 @@ pub enum Instruction { input: Variable, out: Variable, }, + Degrees { + input: Variable, + out: Variable, + }, + Radians { + input: Variable, + out: Variable, + }, ArcTan2 { lhs: Variable, rhs: Variable, @@ -671,6 +679,14 @@ impl Display for Instruction { let out = out.fmt_left(); writeln!(f, "{out} = atanh({input});") } + Instruction::Degrees { input, out } => { + let out = out.fmt_left(); + writeln!(f, "{out} = degrees({input});") + } + Instruction::Radians { input, out } => { + let out = out.fmt_left(); + writeln!(f, "{out} = radians({input});") + } Instruction::ArcTan2 { lhs, rhs, out } => { let out = out.fmt_left(); writeln!(f, "{out} = atan2({lhs}, {rhs});") From e6d04109bca0e9b52b6d7cede48373b1bd5670b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Sat, 6 Sep 2025 00:35:32 +0200 Subject: [PATCH 05/23] Implement trigonometric functions in CPU backend Also try to handle half precision for some of the new trigonometric operations --- .../src/frontend/operation/unary.rs | 4 - crates/cubecl-cpp/src/shared/dialect.rs | 14 +++ crates/cubecl-cpp/src/shared/unary.rs | 50 ++++++-- .../compiler/visitor/operation/arithmetic.rs | 116 ++++++++---------- 4 files changed, 106 insertions(+), 78 deletions(-) diff --git a/crates/cubecl-core/src/frontend/operation/unary.rs b/crates/cubecl-core/src/frontend/operation/unary.rs index 05cb6c665..10a08af4d 100644 --- a/crates/cubecl-core/src/frontend/operation/unary.rs +++ b/crates/cubecl-core/src/frontend/operation/unary.rs @@ -211,10 +211,6 @@ impl_unary_func!( f32, f64 ); -// Open Questions: -// - When to use metal safe / atomic stuff and when not -// - When do I need to check for Bfloats and stuff? 
-// - How to add support for LLVM MIR (in cube-cpu) impl_unary_func!( ArcCos, acos, diff --git a/crates/cubecl-cpp/src/shared/dialect.rs b/crates/cubecl-cpp/src/shared/dialect.rs index b6e70857a..f122fd320 100644 --- a/crates/cubecl-cpp/src/shared/dialect.rs +++ b/crates/cubecl-cpp/src/shared/dialect.rs @@ -539,6 +539,20 @@ pub trait DialectInstructions { } } + fn compile_instruction_degrees_scalar>( + f: &mut std::fmt::Formatter<'_>, + input: T, + ) -> std::fmt::Result { + write!(f, "{input}*57.29577951308232") + } + + fn compile_instruction_radians_scalar>( + f: &mut std::fmt::Formatter<'_>, + input: T, + ) -> std::fmt::Result { + write!(f, "{input}*0.017453292519943295") + } + // unary fn compile_instruction_find_first_set>( f: &mut std::fmt::Formatter<'_>, diff --git a/crates/cubecl-cpp/src/shared/unary.rs b/crates/cubecl-cpp/src/shared/unary.rs index e9bae51f8..55e1957b5 100644 --- a/crates/cubecl-cpp/src/shared/unary.rs +++ b/crates/cubecl-cpp/src/shared/unary.rs @@ -151,16 +151,14 @@ macro_rules! function { function!(Log, "log"); function!(Cos, "cos"); function!(Sin, "sin"); -function!(Sinh, "sinh"); -function!(Cosh, "cosh"); -function!(ArcCos, "acos"); -function!(ArcSin, "asin"); -function!(ArcTan, "atan"); -function!(ArcSinh, "asinh"); -function!(ArcCosh, "acosh"); -function!(ArcTanh, "atanh"); -function!(Degrees, "degrees"); -function!(Radians, "radians"); +function!(Sinh, "sinh", false); +function!(Cosh, "cosh", false); +function!(ArcCos, "acos", false); +function!(ArcSin, "asin", false); +function!(ArcTan, "atan", false); +function!(ArcSinh, "asinh", false); +function!(ArcCosh, "acosh", false); +function!(ArcTanh, "atanh", false); function!(Sqrt, "sqrt"); function!(Exp, "exp"); function!(Ceil, "ceil"); @@ -202,6 +200,38 @@ impl Unary for Tanh { } } +pub struct Degrees; + +impl Unary for Degrees { + fn format_scalar>( + f: &mut std::fmt::Formatter<'_>, + input: Input, + _out_elem: Elem, + ) -> std::fmt::Result { + D::compile_instruction_degrees_scalar(f, input) + } + + fn can_optimize() -> bool { + false + } +} + +pub struct Radians; + +impl Unary for Radians { + fn format_scalar>( + f: &mut std::fmt::Formatter<'_>, + input: Input, + _out_elem: Elem, + ) -> std::fmt::Result { + D::compile_instruction_radians_scalar(f, input) + } + + fn can_optimize() -> bool { + false + } +} + pub fn zero_extend(input: impl Component) -> String { match input.elem() { Elem::I8 => format!("{}({}({input}))", Elem::::U32, Elem::::U8), diff --git a/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs b/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs index 0a7f88b5e..07282d0d4 100644 --- a/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs +++ b/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs @@ -3,7 +3,7 @@ use tracel_llvm::melior::{ dialect::{ arith::{self}, llvm, - ods::{llvm as llvm_ods, vector}, + ods::{llvm as llvm_ods, math as math_ods, vector}, }, ir::Attribute, }; @@ -28,116 +28,104 @@ impl<'a> Visitor<'a> { let result = self.append_operation_with_result(operation); self.insert_variable(out, result); } - Arithmetic::Sinh(_sinh) => { - todo!("intr_sinh does not exist") - /*let value = self.get_variable(sinh.input); + Arithmetic::Sinh(sinh) => { + let value = self.get_variable(sinh.input); let result = self.append_operation_with_result(llvm_ods::intr_sinh( self.context, value, self.location, )); - self.insert_variable(out, result);*/ + self.insert_variable(out, result); } - Arithmetic::Cosh(_cosh) => { - todo!("intr_cosh does not exist") - 
/*let value = self.get_variable(cosh.input); + Arithmetic::Cosh(cosh) => { + let value = self.get_variable(cosh.input); let result = self.append_operation_with_result(llvm_ods::intr_cosh( self.context, value, self.location, )); - self.insert_variable(out, result);*/ + self.insert_variable(out, result); } - Arithmetic::ArcCos(_acos) => { - todo!("intr_acos does not exist") - /*let value = self.get_variable(acos.input); - let result = self.append_operation_with_result(llvm_ods::intr_acos( + Arithmetic::ArcCos(acos) => { + let value = self.get_variable(acos.input); + let result = self.append_operation_with_result(math_ods::acos( self.context, value, self.location, )); - self.insert_variable(out, result);*/ + self.insert_variable(out, result); } - Arithmetic::ArcSin(_asin) => { - todo!("intr_asin does not exist") - /*let value = self.get_variable(asin.input); - let result = self.append_operation_with_result(llvm_ods::intr_asin( + Arithmetic::ArcSin(asin) => { + let value = self.get_variable(asin.input); + let result = self.append_operation_with_result(math_ods::asin( self.context, value, self.location, )); - self.insert_variable(out, result);*/ + self.insert_variable(out, result); } - Arithmetic::ArcTan(_atan) => { - todo!("intr_atan does not exist") - /*let value = self.get_variable(acos.input); - let result = self.append_operation_with_result(llvm_ods::intr_atan( + Arithmetic::ArcTan(atan) => { + let value = self.get_variable(atan.input); + let result = self.append_operation_with_result(math_ods::atan( self.context, value, self.location, )); - self.insert_variable(out, result);*/ + self.insert_variable(out, result); } - Arithmetic::Degrees(_degrees) => { - todo!("intr_degrees does not exist") - /*let value = self.get_variable(degrees.input); - let result = self.append_operation_with_result(llvm_ods::intr_degrees( - self.context, - value, - self.location, - )); - self.insert_variable(out, result);*/ + Arithmetic::Degrees(degrees) => { + let value = self.get_variable(degrees.input); + // 180 / pi + let f = self.create_float_constant_from_item(degrees.input.ty, 57.29577951308232); + let result = + self.append_operation_with_result(arith::mulf(value, f, self.location)); + self.insert_variable(out, result); } - Arithmetic::Radians(_radians) => { - todo!("intr_radians does not exist") - /*let value = self.get_variable(radians.input); - let result = self.append_operation_with_result(llvm_ods::intr_radians( - self.context, - value, - self.location, - )); - self.insert_variable(out, result);*/ + Arithmetic::Radians(radians) => { + let value = self.get_variable(radians.input); + // pi / 180 + let f = + self.create_float_constant_from_item(radians.input.ty, 0.017453292519943295); + let result = + self.append_operation_with_result(arith::mulf(value, f, self.location)); + self.insert_variable(out, result); } - Arithmetic::ArcSinh(_asinh) => { - todo!("intr_asinh does not exist") - /*let value = self.get_variable(asinh.input); - let result = self.append_operation_with_result(llvm_ods::intr_asinh( + Arithmetic::ArcSinh(asinh) => { + let value = self.get_variable(asinh.input); + let result = self.append_operation_with_result(math_ods::asinh( self.context, value, self.location, )); - self.insert_variable(out, result);*/ + self.insert_variable(out, result); } - Arithmetic::ArcCosh(_acosh) => { - todo!("intr_acosh does not exist") - /*let value = self.get_variable(acosh.input); - let result = self.append_operation_with_result(llvm_ods::intr_acosh( + Arithmetic::ArcCosh(acosh) => { + let value = 
self.get_variable(acosh.input); + let result = self.append_operation_with_result(math_ods::acosh( self.context, value, self.location, )); - self.insert_variable(out, result);*/ + self.insert_variable(out, result); } - Arithmetic::ArcTanh(_atanh) => { - todo!("intr_atanh does not exist") - /*let value = self.get_variable(atanh.input); - let result = self.append_operation_with_result(llvm_ods::intr_atanh( + Arithmetic::ArcTanh(atanh) => { + let value = self.get_variable(atanh.input); + let result = self.append_operation_with_result(math_ods::atanh( self.context, value, self.location, )); - self.insert_variable(out, result);*/ + self.insert_variable(out, result); } - Arithmetic::ArcTan2(_atan2) => { - todo!("intr_atan2 does not exist") - /*let (y, x) = self.get_binary_op_variable(atan2.lhs, atan2.rhs); - let result = self.append_operation_with_result(llvm_ods::intr_atan2( + Arithmetic::ArcTan2(atan2) => { + let (lhs, rhs) = self.get_binary_op_variable(atan2.lhs, atan2.rhs); + let result = self.append_operation_with_result(math_ods::atan_2( self.context, - y, - x, + lhs, + rhs, self.location, )); - self.insert_variable(out, result);*/ + self.insert_variable(out, result); } Arithmetic::Ceil(ceil) => { let value = self.get_variable(ceil.input); From 67357884d6dd23d4b84dd60424110dcbc88730d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Sat, 6 Sep 2025 17:54:37 +0200 Subject: [PATCH 06/23] Add to_degrees and to_radians functions to cube-std --- crates/cubecl-cpp/src/shared/dialect.rs | 14 -------------- crates/cubecl-cpp/src/shared/unary.rs | 4 ++-- crates/cubecl-std/src/lib.rs | 11 +++++++++++ 3 files changed, 13 insertions(+), 16 deletions(-) diff --git a/crates/cubecl-cpp/src/shared/dialect.rs b/crates/cubecl-cpp/src/shared/dialect.rs index f122fd320..b6e70857a 100644 --- a/crates/cubecl-cpp/src/shared/dialect.rs +++ b/crates/cubecl-cpp/src/shared/dialect.rs @@ -539,20 +539,6 @@ pub trait DialectInstructions { } } - fn compile_instruction_degrees_scalar>( - f: &mut std::fmt::Formatter<'_>, - input: T, - ) -> std::fmt::Result { - write!(f, "{input}*57.29577951308232") - } - - fn compile_instruction_radians_scalar>( - f: &mut std::fmt::Formatter<'_>, - input: T, - ) -> std::fmt::Result { - write!(f, "{input}*0.017453292519943295") - } - // unary fn compile_instruction_find_first_set>( f: &mut std::fmt::Formatter<'_>, diff --git a/crates/cubecl-cpp/src/shared/unary.rs b/crates/cubecl-cpp/src/shared/unary.rs index 55e1957b5..a7e97ccb5 100644 --- a/crates/cubecl-cpp/src/shared/unary.rs +++ b/crates/cubecl-cpp/src/shared/unary.rs @@ -208,7 +208,7 @@ impl Unary for Degrees { input: Input, _out_elem: Elem, ) -> std::fmt::Result { - D::compile_instruction_degrees_scalar(f, input) + write!(f, "{input}*57.29577951308232f") } fn can_optimize() -> bool { @@ -224,7 +224,7 @@ impl Unary for Radians { input: Input, _out_elem: Elem, ) -> std::fmt::Result { - D::compile_instruction_radians_scalar(f, input) + write!(f, "{input}*0.017453292519943295f") } fn can_optimize() -> bool { diff --git a/crates/cubecl-std/src/lib.rs b/crates/cubecl-std/src/lib.rs index 8260c10b0..647dbcfbb 100644 --- a/crates/cubecl-std/src/lib.rs +++ b/crates/cubecl-std/src/lib.rs @@ -1,4 +1,5 @@ //! Cubecl standard library. 
+use core::f32; extern crate alloc; @@ -23,3 +24,13 @@ pub fn div_ceil(a: u32, b: u32) -> u32 { (a + b - 1) / b } + +#[cube] +pub fn to_degrees<F: Float>(val: F) -> F { + val * F::new(180.0 / f32::consts::PI) +} + +#[cube] +pub fn to_radians<F: Float>(val: F) -> F { + val * F::new(f32::consts::PI / 180.0) +} From d07c2b2d907d609aaea7f87c0c7c9010758a0473 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Sun, 7 Sep 2025 00:25:55 +0200 Subject: [PATCH 07/23] Register math-to-llvm transform pass in MLIR --- crates/cubecl-cpu/src/compiler/module.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/crates/cubecl-cpu/src/compiler/module.rs index 0a2f5cc09..db9e012d2 100644 --- a/crates/cubecl-cpu/src/compiler/module.rs +++ b/crates/cubecl-cpu/src/compiler/module.rs @@ -73,6 +73,7 @@ impl<'a> Module<'a> { pass_manager.add_pass(pass::conversion::create_vector_to_llvm()); pass_manager.add_pass(pass::conversion::create_arith_to_llvm()); pass_manager.add_pass(pass::conversion::create_func_to_llvm()); + pass_manager.add_pass(pass::conversion::create_math_to_llvm()); pass_manager.add_pass(pass::transform::create_inliner()); pass_manager.add_pass(pass::conversion::create_reconcile_unrealized_casts()); pass_manager.add_pass(pass::transform::create_sccp()); From 7aa191333215114eb24da0fda2e2ed15d6f496e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Sun, 7 Sep 2025 12:03:30 +0200 Subject: [PATCH 08/23] Disable all ods_math-dependent arithmetics for now --- crates/cubecl-cpu/src/compiler/module.rs | 2 +- .../compiler/visitor/operation/arithmetic.rs | 135 ++++++++++-------- 2 files changed, 79 insertions(+), 58 deletions(-) diff --git a/crates/cubecl-cpu/src/compiler/module.rs index db9e012d2..407004349 100644 --- a/crates/cubecl-cpu/src/compiler/module.rs +++ b/crates/cubecl-cpu/src/compiler/module.rs @@ -73,7 +73,7 @@ impl<'a> Module<'a> { pass_manager.add_pass(pass::conversion::create_vector_to_llvm()); pass_manager.add_pass(pass::conversion::create_arith_to_llvm()); pass_manager.add_pass(pass::conversion::create_func_to_llvm()); - pass_manager.add_pass(pass::conversion::create_math_to_llvm()); + // pass_manager.add_pass(pass::conversion::create_math_to_llvm()); pass_manager.add_pass(pass::transform::create_inliner()); pass_manager.add_pass(pass::conversion::create_reconcile_unrealized_casts()); pass_manager.add_pass(pass::transform::create_sccp()); diff --git a/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs index 07282d0d4..2065c4677 100644 --- a/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs +++ b/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs @@ -3,7 +3,7 @@ use tracel_llvm::melior::{ dialect::{ arith::{self}, llvm, - ods::{llvm as llvm_ods, math as math_ods, vector}, + ods::{llvm as llvm_ods, vector}, }, ir::Attribute, }; @@ -28,104 +28,90 @@ impl<'a> Visitor<'a> { let result = self.append_operation_with_result(operation); self.insert_variable(out, result); } - Arithmetic::Sinh(sinh) => { - let value = self.get_variable(sinh.input); - let result = self.append_operation_with_result(llvm_ods::intr_sinh( - self.context, - value, - self.location, - )); - self.insert_variable(out, result); - } - Arithmetic::Cosh(cosh) => { - let value = self.get_variable(cosh.input); - let result = self.append_operation_with_result(llvm_ods::intr_cosh( - self.context, - value, - self.location, -
)); - self.insert_variable(out, result); - } - Arithmetic::ArcCos(acos) => { - let value = self.get_variable(acos.input); + Arithmetic::ArcCos(_acos) => { + todo!( + "Arc operations are only available through the ods::math module, which can not be properly loaded at the moment." + ); + /*let value = self.get_variable(acos.input); let result = self.append_operation_with_result(math_ods::acos( self.context, value, self.location, )); - self.insert_variable(out, result); + self.insert_variable(out, result);*/ } - Arithmetic::ArcSin(asin) => { - let value = self.get_variable(asin.input); + Arithmetic::ArcSin(_asin) => { + todo!( + "Arc operations are only available through the ods::math module, which can not be properly loaded at the moment." + ); + /*let value = self.get_variable(asin.input); let result = self.append_operation_with_result(math_ods::asin( self.context, value, self.location, )); - self.insert_variable(out, result); + self.insert_variable(out, result);*/ } - Arithmetic::ArcTan(atan) => { - let value = self.get_variable(atan.input); + Arithmetic::ArcTan(_atan) => { + todo!( + "Arc operations are only available through the ods::math module, which can not be properly loaded at the moment." + ); + /*let value = self.get_variable(atan.input); let result = self.append_operation_with_result(math_ods::atan( self.context, value, self.location, )); - self.insert_variable(out, result); - } - Arithmetic::Degrees(degrees) => { - let value = self.get_variable(degrees.input); - // 180 / pi - let f = self.create_float_constant_from_item(degrees.input.ty, 57.29577951308232); - let result = - self.append_operation_with_result(arith::mulf(value, f, self.location)); - self.insert_variable(out, result); + self.insert_variable(out, result);*/ } - Arithmetic::Radians(radians) => { - let value = self.get_variable(radians.input); - // pi / 180 - let f = - self.create_float_constant_from_item(radians.input.ty, 0.017453292519943295); - let result = - self.append_operation_with_result(arith::mulf(value, f, self.location)); - self.insert_variable(out, result); - } - Arithmetic::ArcSinh(asinh) => { - let value = self.get_variable(asinh.input); + Arithmetic::ArcSinh(_asinh) => { + todo!( + "Arc operations are only available through the ods::math module, which can not be properly loaded at the moment." + ); + /*let value = self.get_variable(asinh.input); let result = self.append_operation_with_result(math_ods::asinh( self.context, value, self.location, )); - self.insert_variable(out, result); + self.insert_variable(out, result);*/ } - Arithmetic::ArcCosh(acosh) => { - let value = self.get_variable(acosh.input); + Arithmetic::ArcCosh(_acosh) => { + todo!( + "Arc operations are only available through the ods::math module, which can not be properly loaded at the moment." + ); + /*let value = self.get_variable(acosh.input); let result = self.append_operation_with_result(math_ods::acosh( self.context, value, self.location, )); - self.insert_variable(out, result); + self.insert_variable(out, result);*/ } - Arithmetic::ArcTanh(atanh) => { - let value = self.get_variable(atanh.input); + Arithmetic::ArcTanh(_atanh) => { + todo!( + "Arc operations are only available through the ods::math module, which can not be properly loaded at the moment." 
+ ); + /*let value = self.get_variable(atanh.input); let result = self.append_operation_with_result(math_ods::atanh( self.context, value, self.location, )); - self.insert_variable(out, result); + self.insert_variable(out, result);*/ } - Arithmetic::ArcTan2(atan2) => { - let (lhs, rhs) = self.get_binary_op_variable(atan2.lhs, atan2.rhs); + Arithmetic::ArcTan2(_atan2) => { + todo!( + "Arc operations are only available through the ods::math module, which can not be properly loaded at the moment." + ); + /*let (lhs, rhs) = self.get_binary_op_variable(atan2.lhs, atan2.rhs); let result = self.append_operation_with_result(math_ods::atan_2( self.context, lhs, rhs, self.location, )); - self.insert_variable(out, result); + self.insert_variable(out, result);*/ } Arithmetic::Ceil(ceil) => { let value = self.get_variable(ceil.input); @@ -193,6 +179,23 @@ impl<'a> Visitor<'a> { )); self.insert_variable(out, result); } + Arithmetic::Cosh(cosh) => { + let value = self.get_variable(cosh.input); + let result = self.append_operation_with_result(llvm_ods::intr_cosh( + self.context, + value, + self.location, + )); + self.insert_variable(out, result); + } + Arithmetic::Degrees(degrees) => { + let value = self.get_variable(degrees.input); + // 180 / pi + let f = self.create_float_constant_from_item(degrees.input.ty, 57.29577951308232); + let result = + self.append_operation_with_result(arith::mulf(value, f, self.location)); + self.insert_variable(out, result); + } Arithmetic::Div(div) => { let (lhs, rhs) = self.get_binary_op_variable(div.lhs, div.rhs); let operation = if div.lhs.storage_type().is_signed_int() { @@ -425,6 +428,15 @@ impl<'a> Visitor<'a> { )); self.insert_variable(out, result); } + Arithmetic::Radians(radians) => { + let value = self.get_variable(radians.input); + // pi / 180 + let f = + self.create_float_constant_from_item(radians.input.ty, 0.017453292519943295); + let result = + self.append_operation_with_result(arith::mulf(value, f, self.location)); + self.insert_variable(out, result); + } Arithmetic::Recip(recip) => { let value = self.get_variable(recip.input); let one = self.create_float_constant_from_item(recip.input.ty, 1.0); @@ -476,6 +488,15 @@ impl<'a> Visitor<'a> { )); self.insert_variable(out, output); } + Arithmetic::Sinh(sinh) => { + let value = self.get_variable(sinh.input); + let result = self.append_operation_with_result(llvm_ods::intr_sinh( + self.context, + value, + self.location, + )); + self.insert_variable(out, result); + } Arithmetic::Sqrt(sqrt) => { let input = self.get_variable(sqrt.input); let output = self.append_operation_with_result(llvm_ods::intr_sqrt( From 52429f7f291fe94f6a3143f8b419a6b082ccb055 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Sun, 7 Sep 2025 12:17:05 +0200 Subject: [PATCH 09/23] Add dummy implementations instead of todo! 
to satisfy compilation of tests --- .../compiler/visitor/operation/arithmetic.rs | 77 ++++++++++++------- 1 file changed, 49 insertions(+), 28 deletions(-) diff --git a/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs b/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs index 2065c4677..afc91406e 100644 --- a/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs +++ b/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs @@ -28,10 +28,13 @@ impl<'a> Visitor<'a> { let result = self.append_operation_with_result(operation); self.insert_variable(out, result); } - Arithmetic::ArcCos(_acos) => { - todo!( - "Arc operations are only available through the ods::math module, which can not be properly loaded at the moment." - ); + Arithmetic::ArcCos(acos) => { + // Arc operations are only available through the ods::math module, + // which can not be properly loaded at the moment. + // Using dummy for now to satisfy compilation of other tests + let value = self.get_variable(acos.input); + let abs = self.get_absolute_val(acos.input.ty, value); + self.insert_variable(out, abs); /*let value = self.get_variable(acos.input); let result = self.append_operation_with_result(math_ods::acos( self.context, @@ -40,10 +43,13 @@ impl<'a> Visitor<'a> { )); self.insert_variable(out, result);*/ } - Arithmetic::ArcSin(_asin) => { - todo!( - "Arc operations are only available through the ods::math module, which can not be properly loaded at the moment." - ); + Arithmetic::ArcSin(asin) => { + // Arc operations are only available through the ods::math module, + // which can not be properly loaded at the moment. + // Using dummy for now to satisfy compilation of other tests + let value = self.get_variable(asin.input); + let abs = self.get_absolute_val(asin.input.ty, value); + self.insert_variable(out, abs); /*let value = self.get_variable(asin.input); let result = self.append_operation_with_result(math_ods::asin( self.context, @@ -52,10 +58,13 @@ impl<'a> Visitor<'a> { )); self.insert_variable(out, result);*/ } - Arithmetic::ArcTan(_atan) => { - todo!( - "Arc operations are only available through the ods::math module, which can not be properly loaded at the moment." - ); + Arithmetic::ArcTan(atan) => { + // Arc operations are only available through the ods::math module, + // which can not be properly loaded at the moment. + // Using dummy for now to satisfy compilation of other tests + let value = self.get_variable(atan.input); + let abs = self.get_absolute_val(atan.input.ty, value); + self.insert_variable(out, abs); /*let value = self.get_variable(atan.input); let result = self.append_operation_with_result(math_ods::atan( self.context, @@ -64,10 +73,13 @@ impl<'a> Visitor<'a> { )); self.insert_variable(out, result);*/ } - Arithmetic::ArcSinh(_asinh) => { - todo!( - "Arc operations are only available through the ods::math module, which can not be properly loaded at the moment." - ); + Arithmetic::ArcSinh(asinh) => { + // Arc operations are only available through the ods::math module, + // which can not be properly loaded at the moment. 
+ // Using dummy for now to satisfy compilation of other tests + let value = self.get_variable(asinh.input); + let abs = self.get_absolute_val(asinh.input.ty, value); + self.insert_variable(out, abs); /*let value = self.get_variable(asinh.input); let result = self.append_operation_with_result(math_ods::asinh( self.context, @@ -76,10 +88,13 @@ impl<'a> Visitor<'a> { )); self.insert_variable(out, result);*/ } - Arithmetic::ArcCosh(_acosh) => { - todo!( - "Arc operations are only available through the ods::math module, which can not be properly loaded at the moment." - ); + Arithmetic::ArcCosh(acosh) => { + // Arc operations are only available through the ods::math module, + // which can not be properly loaded at the moment. + // Using dummy for now to satisfy compilation of other tests + let value = self.get_variable(acosh.input); + let abs = self.get_absolute_val(acosh.input.ty, value); + self.insert_variable(out, abs); /*let value = self.get_variable(acosh.input); let result = self.append_operation_with_result(math_ods::acosh( self.context, @@ -88,10 +103,13 @@ impl<'a> Visitor<'a> { )); self.insert_variable(out, result);*/ } - Arithmetic::ArcTanh(_atanh) => { - todo!( - "Arc operations are only available through the ods::math module, which can not be properly loaded at the moment." - ); + Arithmetic::ArcTanh(atanh) => { + // Arc operations are only available through the ods::math module, + // which can not be properly loaded at the moment. + // Using dummy for now to satisfy compilation of other tests + let value = self.get_variable(atanh.input); + let abs = self.get_absolute_val(atanh.input.ty, value); + self.insert_variable(out, abs); /*let value = self.get_variable(atanh.input); let result = self.append_operation_with_result(math_ods::atanh( self.context, @@ -100,10 +118,13 @@ impl<'a> Visitor<'a> { )); self.insert_variable(out, result);*/ } - Arithmetic::ArcTan2(_atan2) => { - todo!( - "Arc operations are only available through the ods::math module, which can not be properly loaded at the moment." - ); + Arithmetic::ArcTan2(atan2) => { + // Arc operations are only available through the ods::math module, + // which can not be properly loaded at the moment. 
+ // Using dummy for now to satisfy compilation of other tests + let value = self.get_variable(atan2.lhs); + let abs = self.get_absolute_val(atan2.lhs.ty, value); + self.insert_variable(out, abs); /*let (lhs, rhs) = self.get_binary_op_variable(atan2.lhs, atan2.rhs); + let result = self.append_operation_with_result(math_ods::atan_2( self.context, lhs, rhs, self.location, )); self.insert_variable(out, result);*/ From 11cd568e3b6ecfaa6068089c9fe370bed4d31c20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Tue, 9 Sep 2025 19:59:26 +0200 Subject: [PATCH 10/23] Fix merge formatting --- crates/cubecl-cpp/src/shared/binary.rs | 3 +-- crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/crates/cubecl-cpp/src/shared/binary.rs index a0642bcb6..f7fdc3c3b 100644 --- a/crates/cubecl-cpp/src/shared/binary.rs +++ b/crates/cubecl-cpp/src/shared/binary.rs @@ -224,7 +224,6 @@ impl Binary for Powf { pub struct Powi; - impl Binary for Powi { // Powi doesn't support half and no half equivalent exists fn format_scalar( f: &mut std::fmt::Formatter<'_>, lhs: Lhs, rhs: Rhs, item: Item, @@ -271,7 +270,7 @@ impl Binary for Powi { f.write_str("};\n") } } - + pub struct ArcTan2; impl Binary for ArcTan2 { diff --git a/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs index 53ce32c15..2f8bebfa7 100644 --- a/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs +++ b/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs @@ -474,7 +474,7 @@ impl<'a> Visitor<'a> { self.create_float_constant_from_item(radians.input.ty, 0.017453292519943295); let result = self.append_operation_with_result(arith::mulf(value, f, self.location)); - self.insert_variable(out, result); + self.insert_variable(out, result); } Arithmetic::Recip(recip) => { let value = self.get_variable(recip.input); From 7977e7f34e65401099cdf40dc82d8cd450f5a8e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Tue, 9 Sep 2025 20:09:03 +0200 Subject: [PATCH 11/23] Rename degrees and radians to to_degrees and to_radians to reflect Rust's f32 operations --- crates/cubecl-core/src/frontend/operation/unary.rs | 8 ++++---- crates/cubecl-core/src/runtime_tests/unary.rs | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/crates/cubecl-core/src/frontend/operation/unary.rs index 10a08af4d..ccac524ac 100644 --- a/crates/cubecl-core/src/frontend/operation/unary.rs +++ b/crates/cubecl-core/src/frontend/operation/unary.rs @@ -285,8 +285,8 @@ impl_unary_func!( ); impl_unary_func!( Degrees, - degrees, - __expand_degrees, + to_degrees, + __expand_to_degrees, Arithmetic::Degrees, f16, bf16, @@ -297,8 +297,8 @@ impl_unary_func!( ); impl_unary_func!( Radians, - radians, - __expand_radians, + to_radians, + __expand_to_radians, Arithmetic::Radians, f16, bf16, diff --git a/crates/cubecl-core/src/runtime_tests/unary.rs index 0c0ef4eb0..c4253ec7a 100644 --- a/crates/cubecl-core/src/runtime_tests/unary.rs +++ b/crates/cubecl-core/src/runtime_tests/unary.rs @@ -378,7 +378,7 @@ test_unary_impl!(test_atanh, F, F::atanh, [ } ]); -test_unary_impl!(test_degrees, F, F::degrees, [ +test_unary_impl!(test_degrees, F, F::to_degrees, [ { input_vectorization: 1, out_vectorization: 1, @@ -399,7 +399,7 @@ test_unary_impl!(test_degrees, F, F::degrees, [ } ]); -test_unary_impl!(test_radians, F, F::radians, [ +test_unary_impl!(test_radians, F, F::to_radians, [ {
input_vectorization: 1, out_vectorization: 1, From 5f06df57473d89acfcfcdede71aac657f95d6935 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Wed, 10 Sep 2025 17:36:48 +0200 Subject: [PATCH 12/23] Add tan operation --- .../cubecl-core/src/frontend/element/float.rs | 1 + .../src/frontend/element/float/typemap.rs | 1 + .../src/frontend/operation/unary.rs | 12 ++++++ crates/cubecl-core/src/runtime_tests/unary.rs | 38 +++++++++++++++---- crates/cubecl-cpp/src/shared/base.rs | 3 ++ crates/cubecl-cpp/src/shared/instruction.rs | 2 + crates/cubecl-cpp/src/shared/unary.rs | 1 + .../compiler/visitor/operation/arithmetic.rs | 15 ++++++++ crates/cubecl-ir/src/arithmetic.rs | 2 + crates/cubecl-ir/src/processing.rs | 3 ++ crates/cubecl-opt/src/instructions.rs | 1 + crates/cubecl-opt/src/passes/constant_prop.rs | 1 + crates/cubecl-spirv/src/arithmetic.rs | 8 ++++ crates/cubecl-spirv/src/extensions.rs | 5 +++ .../cubecl-wgpu/src/compiler/wgsl/compiler.rs | 4 ++ .../src/compiler/wgsl/instructions.rs | 8 ++++ 16 files changed, 97 insertions(+), 8 deletions(-) diff --git a/crates/cubecl-core/src/frontend/element/float.rs b/crates/cubecl-core/src/frontend/element/float.rs index 3a7954bb8..434783fec 100644 --- a/crates/cubecl-core/src/frontend/element/float.rs +++ b/crates/cubecl-core/src/frontend/element/float.rs @@ -25,6 +25,7 @@ pub trait Float: + Log1p + Cos + Sin + + Tan + Tanh + Sinh + Cosh diff --git a/crates/cubecl-core/src/frontend/element/float/typemap.rs b/crates/cubecl-core/src/frontend/element/float/typemap.rs index 3a39737ce..188dc1f9a 100644 --- a/crates/cubecl-core/src/frontend/element/float/typemap.rs +++ b/crates/cubecl-core/src/frontend/element/float/typemap.rs @@ -243,6 +243,7 @@ impl Log for ElemExpand {} impl Log1p for ElemExpand {} impl Cos for ElemExpand {} impl Sin for ElemExpand {} +impl Tan for ElemExpand {} impl Tanh for ElemExpand {} impl Sinh for ElemExpand {} impl Cosh for ElemExpand {} diff --git a/crates/cubecl-core/src/frontend/operation/unary.rs b/crates/cubecl-core/src/frontend/operation/unary.rs index ccac524ac..0779f4d64 100644 --- a/crates/cubecl-core/src/frontend/operation/unary.rs +++ b/crates/cubecl-core/src/frontend/operation/unary.rs @@ -175,6 +175,18 @@ impl_unary_func!( f32, f64 ); +impl_unary_func!( + Tan, + tan, + __expand_tan, + Arithmetic::Tan, + f16, + bf16, + flex32, + tf32, + f32, + f64 +); impl_unary_func!( Tanh, tanh, diff --git a/crates/cubecl-core/src/runtime_tests/unary.rs b/crates/cubecl-core/src/runtime_tests/unary.rs index c4253ec7a..da546f506 100644 --- a/crates/cubecl-core/src/runtime_tests/unary.rs +++ b/crates/cubecl-core/src/runtime_tests/unary.rs @@ -1,3 +1,4 @@ +use std::f32::consts::PI; use std::fmt::Display; use crate::{self as cubecl, as_type}; @@ -210,6 +211,27 @@ test_unary_impl!(test_cos, F, F::cos, [ } ]); +test_unary_impl!(test_tan, F, F::tan, [ + { + input_vectorization: 1, + out_vectorization: 1, + input: as_type![F: 0., 0.78539816339, 1.04719755119, -0.78539816339], + expected: as_type![F: 0., 1., 1.73205080757, -1.] + }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![F: 0., 0.78539816339, 1.04719755119, -0.78539816339], + expected: as_type![F: 0., 1., 1.73205080757, -1.] + }, + { + input_vectorization: 4, + out_vectorization: 4, + input: as_type![F: 0., 0.78539816339, 1.04719755119, -0.78539816339], + expected: as_type![F: 0., 1., 1.73205080757, -1.] 
+ } +]); + test_unary_impl!(test_asin, F, F::asin, [ { input_vectorization: 1, @@ -382,19 +404,19 @@ test_unary_impl!(test_degrees, F, F::to_degrees, [ { input_vectorization: 1, out_vectorization: 1, - input: as_type![F: 0., 1.5707963268, 3.1415926536, -1.5707963268, -3.1415926536], - expected: as_type![F: 0., 90., 180., -90., -180.] + input: as_type![F: 0., PI / 2., PI, PI * 2., -PI / 2., -PI, -PI * 2.], + expected: as_type![F: 0., 90., 180., 360., -90., -180., -360.] }, { input_vectorization: 2, out_vectorization: 2, - input: as_type![F: 0., 1.5707963268, 3.1415926536, -1.5707963268], + input: as_type![F: 0., PI / 2., PI, -PI / 2.], expected: as_type![F: 0., 90., 180., -90.] }, { input_vectorization: 4, out_vectorization: 4, - input: as_type![F: 0., 1.5707963268, 3.1415926536, -1.5707963268], + input: as_type![F: 0., PI / 2., PI, -PI / 2.], expected: as_type![F: 0., 90., 180., -90.] } ]); @@ -403,20 +425,20 @@ test_unary_impl!(test_radians, F, F::to_radians, [ { input_vectorization: 1, out_vectorization: 1, - input: as_type![F: 0., 90., 180., -90., -180.], - expected: as_type![F: 0., 1.5707963268, 3.1415926536, -1.5707963268, -3.1415926536] + input: as_type![F: 0., 90., 180., 360., -90., -180., -360.], + expected: as_type![F: 0., PI / 2., PI, PI * 2., -PI / 2., -PI, -PI * 2.] }, { input_vectorization: 2, out_vectorization: 2, input: as_type![F: 0., 90., 180., -90.], - expected: as_type![F: 0., 1.5707963268, 3.1415926536, -1.5707963268] + expected: as_type![F: 0., PI / 2., PI, -PI / 2.] }, { input_vectorization: 4, out_vectorization: 4, input: as_type![F: 0., 90., 180., -90.], - expected: as_type![F: 0., 1.5707963268, 3.1415926536, -1.5707963268] + expected: as_type![F: 0., PI / 2., PI, -PI / 2.] } ]); diff --git a/crates/cubecl-cpp/src/shared/base.rs b/crates/cubecl-cpp/src/shared/base.rs index 1c904768d..af8e7558b 100644 --- a/crates/cubecl-cpp/src/shared/base.rs +++ b/crates/cubecl-cpp/src/shared/base.rs @@ -927,6 +927,9 @@ impl CppCompiler { gpu::Arithmetic::Sin(op) => { instructions.push(Instruction::Sin(self.compile_unary(op, out))) } + gpu::Arithmetic::Tan(op) => { + instructions.push(Instruction::Tan(self.compile_unary(op, out))) + } gpu::Arithmetic::Tanh(op) => { let instruction = Instruction::Tanh(self.compile_unary(op, out)); D::register_instruction_extension(&mut self.extensions, &instruction); diff --git a/crates/cubecl-cpp/src/shared/instruction.rs b/crates/cubecl-cpp/src/shared/instruction.rs index 863a69eb3..80071858e 100644 --- a/crates/cubecl-cpp/src/shared/instruction.rs +++ b/crates/cubecl-cpp/src/shared/instruction.rs @@ -162,6 +162,7 @@ pub enum Instruction { Log1p(UnaryInstruction), Cos(UnaryInstruction), Sin(UnaryInstruction), + Tan(UnaryInstruction), Tanh(UnaryInstruction), Sinh(UnaryInstruction), Cosh(UnaryInstruction), @@ -514,6 +515,7 @@ for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{ Instruction::Log1p(it) => Log1p::format(f, &it.input, &it.out), Instruction::Cos(it) => Cos::format(f, &it.input, &it.out), Instruction::Sin(it) => Sin::format(f, &it.input, &it.out), + Instruction::Tan(it) => Tan::format(f, &it.input, &it.out), Instruction::Tanh(it) => Tanh::format(f, &it.input, &it.out), Instruction::Sinh(it) => Sinh::format(f, &it.input, &it.out), Instruction::Cosh(it) => Cosh::format(f, &it.input, &it.out), diff --git a/crates/cubecl-cpp/src/shared/unary.rs b/crates/cubecl-cpp/src/shared/unary.rs index a7e97ccb5..a79fbc327 100644 --- a/crates/cubecl-cpp/src/shared/unary.rs +++ b/crates/cubecl-cpp/src/shared/unary.rs @@ -151,6 +151,7 @@ macro_rules! 
function { function!(Log, "log"); function!(Cos, "cos"); function!(Sin, "sin"); +function!(Tan, "tan"); function!(Sinh, "sinh", false); function!(Cosh, "cosh", false); function!(ArcCos, "acos", false); diff --git a/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs b/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs index 2f8bebfa7..24bbf40ac 100644 --- a/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs +++ b/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs @@ -555,6 +555,21 @@ impl<'a> Visitor<'a> { let result = self.append_operation_with_result(operation); self.insert_variable(out, result); } + Arithmetic::Tan(tan) => { + // Tan operations are only available through the ods::math module, + // which can not be properly loaded at the moment. + // Using dummy for now to satisfy compilation of other tests + let value = self.get_variable(tan.input); + let abs = self.get_absolute_val(tan.input.ty, value); + self.insert_variable(out, abs); + /*let value = self.get_variable(tan.input); + let result = self.append_operation_with_result(math_ods::tan( + self.context, + value, + self.location, + )); + self.insert_variable(out, result);*/ + } Arithmetic::Tanh(tanh) => { let input = self.get_variable(tanh.input); let output = self.append_operation_with_result(llvm_ods::intr_tanh( diff --git a/crates/cubecl-ir/src/arithmetic.rs b/crates/cubecl-ir/src/arithmetic.rs index b19bb661f..b27a946b1 100644 --- a/crates/cubecl-ir/src/arithmetic.rs +++ b/crates/cubecl-ir/src/arithmetic.rs @@ -22,6 +22,7 @@ pub enum Arithmetic { Log1p(UnaryOperator), Cos(UnaryOperator), Sin(UnaryOperator), + Tan(UnaryOperator), Tanh(UnaryOperator), Sinh(UnaryOperator), Cosh(UnaryOperator), @@ -72,6 +73,7 @@ impl Display for Arithmetic { Arithmetic::Log1p(op) => write!(f, "{}.log_1p()", op.input), Arithmetic::Cos(op) => write!(f, "{}.cos()", op.input), Arithmetic::Sin(op) => write!(f, "{}.sin()", op.input), + Arithmetic::Tan(op) => write!(f, "{}.tan()", op.input), Arithmetic::Tanh(op) => write!(f, "{}.tanh()", op.input), Arithmetic::Sinh(op) => write!(f, "{}.sinh()", op.input), Arithmetic::Cosh(op) => write!(f, "{}.cosh()", op.input), diff --git a/crates/cubecl-ir/src/processing.rs b/crates/cubecl-ir/src/processing.rs index 971483af6..be81a89da 100644 --- a/crates/cubecl-ir/src/processing.rs +++ b/crates/cubecl-ir/src/processing.rs @@ -101,6 +101,9 @@ impl ScopeProcessing { Arithmetic::Sin(op) => { sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); } + Arithmetic::Tan(op) => { + sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); + } Arithmetic::Tanh(op) => { sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); } diff --git a/crates/cubecl-opt/src/instructions.rs b/crates/cubecl-opt/src/instructions.rs index 21ae02e7e..7614baac8 100644 --- a/crates/cubecl-opt/src/instructions.rs +++ b/crates/cubecl-opt/src/instructions.rs @@ -90,6 +90,7 @@ impl Optimizer { | Arithmetic::Log1p(unary_operator) | Arithmetic::Cos(unary_operator) | Arithmetic::Sin(unary_operator) + | Arithmetic::Tan(unary_operator) | Arithmetic::Tanh(unary_operator) | Arithmetic::Sinh(unary_operator) | Arithmetic::Cosh(unary_operator) diff --git a/crates/cubecl-opt/src/passes/constant_prop.rs b/crates/cubecl-opt/src/passes/constant_prop.rs index f4f6eaf88..e58a401fc 100644 --- a/crates/cubecl-opt/src/passes/constant_prop.rs +++ b/crates/cubecl-opt/src/passes/constant_prop.rs @@ -372,6 +372,7 @@ fn try_const_eval_arithmetic(op: &mut Arithmetic) -> Option 
Arithmetic::Log1p(op) => const_eval_float!(op.input; num::Float::ln_1p), Arithmetic::Cos(op) => const_eval_float!(op.input; num::Float::cos), Arithmetic::Sin(op) => const_eval_float!(op.input; num::Float::sin), + Arithmetic::Tan(op) => const_eval_float!(op.input; num::Float::tan), Arithmetic::Tanh(op) => const_eval_float!(op.input; num::Float::tanh), Arithmetic::Sinh(op) => const_eval_float!(op.input; num::Float::sinh), Arithmetic::Cosh(op) => const_eval_float!(op.input; num::Float::cosh), diff --git a/crates/cubecl-spirv/src/arithmetic.rs b/crates/cubecl-spirv/src/arithmetic.rs index 20e29100d..d1cefa402 100644 --- a/crates/cubecl-spirv/src/arithmetic.rs +++ b/crates/cubecl-spirv/src/arithmetic.rs @@ -293,6 +293,14 @@ impl SpirvCompiler { } }) } + Arithmetic::Tan(op) => { + self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + T::tan(b, ty, input, out); + if matches!(out_ty.elem(), Elem::Relaxed) { + b.decorate(out, Decoration::RelaxedPrecision, []); + } + }) + } Arithmetic::Tanh(op) => { self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { T::tanh(b, ty, input, out); diff --git a/crates/cubecl-spirv/src/extensions.rs b/crates/cubecl-spirv/src/extensions.rs index d696d42a7..cb2f1fada 100644 --- a/crates/cubecl-spirv/src/extensions.rs +++ b/crates/cubecl-spirv/src/extensions.rs @@ -12,6 +12,7 @@ pub trait TargetExtensions { fn ceil(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn sin(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn cos(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); + fn tan(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn tanh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn sinh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn cosh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); @@ -78,6 +79,10 @@ pub mod glcompute { b.cos_id(ty, Some(out), input).unwrap(); } + fn tan(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) { + b.tan_id(ty, Some(out), input).unwrap(); + } + fn tanh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) { b.tanh_id(ty, Some(out), input).unwrap(); } diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs index 2268d27f9..94b4a3bfc 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs @@ -748,6 +748,10 @@ impl WgslCompiler { input: self.compile_variable(op.input), out: self.compile_variable(out), }), + cube::Arithmetic::Tan(op) => instructions.push(wgsl::Instruction::Tan { + input: self.compile_variable(op.input), + out: self.compile_variable(out), + }), cube::Arithmetic::Tanh(op) => instructions.push(wgsl::Instruction::Tanh { input: self.compile_variable(op.input), out: self.compile_variable(out), diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs index 84c6572c8..b2b867087 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs @@ -117,6 +117,10 @@ pub enum Instruction { input: Variable, out: Variable, }, + Tan { + input: Variable, + out: Variable, + }, Tanh { input: Variable, out: Variable, @@ -634,6 +638,10 @@ impl Display for Instruction { let out = out.fmt_left(); writeln!(f, "{out} = sin({input});") } + Instruction::Tan { input, out } => { + let out = out.fmt_left(); + writeln!(f, "{out} = tan({input});") + } Instruction::Tanh { input, out } => { 
#[cfg(target_os = "macos")] let result = super::call_safe_tanh(f, input, out); From 5e1f84d9cafc575fd655f001b702cd7b17e0210b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Sat, 13 Sep 2025 01:11:27 +0200 Subject: [PATCH 13/23] Make unary runtime tests epsilon-dependent Change the epsilon for to_degrees() to 0.3, which matches the maximum f16 error for our valid tests. --- crates/cubecl-core/src/runtime_tests/unary.rs | 22 +++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/crates/cubecl-core/src/runtime_tests/unary.rs index da546f506..63ec94e67 100644 --- a/crates/cubecl-core/src/runtime_tests/unary.rs +++ b/crates/cubecl-core/src/runtime_tests/unary.rs @@ -47,6 +47,24 @@ macro_rules! test_unary_impl { input: $input:expr, expected: $expected:expr }),*]) => { + test_unary_impl!($test_name, $float_type, $unary_func, [$({ + input_vectorization: $input_vectorization, + out_vectorization: $out_vectorization, + input: $input, + expected: $expected + }),*], 0.02); + }; + ( + $test_name:ident, + $float_type:ident, + $unary_func:expr, + [$({ + input_vectorization: $input_vectorization:expr, + out_vectorization: $out_vectorization:expr, + input: $input:expr, + expected: $expected:expr + }),*], + $epsilon:expr) => { pub fn $test_name(client: ComputeClient) { #[cube(launch_unchecked)] fn test_function<$float_type: Float>(input: &Array<$float_type>, output: &mut Array<$float_type>) { @@ -71,7 +89,7 @@ macro_rules! test_unary_impl { ) }; - assert_equals_approx::(&client, output_handle, $expected, $float_type::new(0.02)); + assert_equals_approx::(&client, output_handle, $expected, $float_type::new($epsilon)); } )* } @@ -419,7 +437,7 @@ test_unary_impl!(test_degrees, F, F::to_degrees, [ input: as_type![F: 0., PI / 2., PI, -PI / 2.], expected: as_type![F: 0., 90., 180., -90.] } -]); +], 0.3); From e55eb53ef6b571221de1a261695d7762197bddfd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Sat, 13 Sep 2025 01:14:27 +0200 Subject: [PATCH 14/23] Add trigonometry module Move to_degrees and to_radians there --- crates/cubecl-std/src/lib.rs | 15 +- crates/cubecl-std/src/tests/mod.rs | 2 + crates/cubecl-std/src/tests/trigonometry.rs | 632 ++++++++++++++++++++ crates/cubecl-std/src/trigonometry.rs | 322 ++++++++++ 4 files changed, 959 insertions(+), 12 deletions(-) create mode 100644 crates/cubecl-std/src/tests/trigonometry.rs create mode 100644 crates/cubecl-std/src/trigonometry.rs
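The new crates/cubecl-std/src/trigonometry.rs (322 lines per the diffstat above) sits beyond this excerpt; the tests that follow pin down its surface. As a hedged sketch only, the sincos helper they exercise needs nothing more than pairing the existing operations; this is an illustration of a plausible shape, not the patch's actual implementation:

#[cube]
pub fn sincos<F: Float>(val: F) -> (F, F) {
    // Two plain calls; a backend with a fused sincos instruction could lower
    // this pair into one. The tuple return matches how the tests destructure it.
    (F::sin(val), F::cos(val))
}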
diff --git a/crates/cubecl-std/src/lib.rs index 647dbcfbb..3abb487f3 100644 --- a/crates/cubecl-std/src/lib.rs +++ b/crates/cubecl-std/src/lib.rs @@ -1,6 +1,4 @@ //! Cubecl standard library. -use core::f32; - extern crate alloc; mod reinterpret_slice; pub use reinterpret_slice::*; mod fast_math; pub use fast_math::*; +mod trigonometry; +pub use trigonometry::*; + mod option; pub use option::*; @@ -24,13 +25,3 @@ pub mod tests; pub fn div_ceil(a: u32, b: u32) -> u32 { (a + b - 1) / b } - -#[cube] -pub fn to_degrees<F: Float>(val: F) -> F { - val * F::new(180.0 / f32::consts::PI) -} - -#[cube] -pub fn to_radians<F: Float>(val: F) -> F { - val * F::new(f32::consts::PI / 180.0) -} diff --git a/crates/cubecl-std/src/tests/mod.rs index cd9c92a32..bcd799585 100644 --- a/crates/cubecl-std/src/tests/mod.rs +++ b/crates/cubecl-std/src/tests/mod.rs @@ -1,5 +1,6 @@ pub mod reinterpret_slice; pub mod tensor; +pub mod trigonometry; #[macro_export] macro_rules! testgen { @@ -9,6 +10,7 @@ macro_rules! testgen { use half::{bf16, f16}; cubecl_std::testgen_reinterpret_slice!(); + cubecl_std::testgen_trigonometry!(); } }; } diff --git a/crates/cubecl-std/src/tests/trigonometry.rs new file mode 100644 index 000000000..f5692202f --- /dev/null +++ b/crates/cubecl-std/src/tests/trigonometry.rs @@ -0,0 +1,632 @@ +use cubecl::prelude::*; +use cubecl_core as cubecl; +use std::f32::consts::{PI, TAU}; + +use crate::trigonometry::*; + +#[cube(launch_unchecked)] +fn kernel_to_degrees(input: &Array<f32>, output: &mut Array<f32>) { + if UNIT_POS < input.len() { + output[UNIT_POS] = to_degrees::<f32>(input[UNIT_POS]); + } +} + +pub fn test_to_degrees<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) { + let input_data = vec![0.0, PI / 6.0, PI / 4.0, PI / 2.0, PI, TAU]; + let expected = vec![0.0, 30.0, 45.0, 90.0, 180.0, 360.0]; + + let input = client.create(f32::as_bytes(&input_data)); + let output = client.empty(input_data.len() * core::mem::size_of::<f32>()); + + unsafe { + kernel_to_degrees::launch_unchecked::<R>( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::new(input_data.len() as u32, 1, 1), + ArrayArg::from_raw_parts::<f32>(&input, input_data.len(), 1), + ArrayArg::from_raw_parts::<f32>(&output, input_data.len(), 1), + ); + } + + let actual = client.read_one(output); + let actual = f32::from_bytes(&actual); + + for (i, (&expected_val, &actual_val)) in expected.iter().zip(actual.iter()).enumerate() { + assert!( + (expected_val - actual_val).abs() < 1e-5, + "Test {} failed: expected {}, got {}", + i, + expected_val, + actual_val + ); + } +} + +#[cube(launch_unchecked)] +fn kernel_to_radians(input: &Array<f32>, output: &mut Array<f32>) { + if UNIT_POS < input.len() { + output[UNIT_POS] = to_radians::<f32>(input[UNIT_POS]); + } +} + +pub fn test_to_radians<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) { + let input_data = vec![0.0, 30.0, 45.0, 90.0, 180.0, 360.0]; + let expected = vec![0.0, PI / 6.0, PI / 4.0, PI / 2.0, PI, TAU]; + + let input = client.create(f32::as_bytes(&input_data)); + let output = client.empty(input_data.len() * core::mem::size_of::<f32>()); + + unsafe { + kernel_to_radians::launch_unchecked::<R>( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::new(input_data.len() as u32, 1, 1), + ArrayArg::from_raw_parts::<f32>(&input, input_data.len(), 1), + ArrayArg::from_raw_parts::<f32>(&output, input_data.len(), 1), + ); + } + + let actual = client.read_one(output); + let actual = f32::from_bytes(&actual); + + for (i, (&expected_val, &actual_val)) in expected.iter().zip(actual.iter()).enumerate() { + assert!( + (expected_val - actual_val).abs() < 1e-5, + "Test {} failed: expected {}, got {}", + i, + expected_val, + actual_val + ); + } +}
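The conversion factors hard-coded by the C++ and CPU backends earlier in the series, 57.29577951308232 and 0.017453292519943295, are exactly 180/PI and PI/180, which is what the two tests above verify end to end. A quick host-side check of the constants themselves (plain Rust; the test name is illustrative, not from the patch):

#[test]
fn degree_radian_factors_match() {
    // 180/pi = 57.29577951308232..., pi/180 = 0.017453292519943295...
    assert!((180.0_f64 / std::f64::consts::PI - 57.29577951308232).abs() < 1e-12);
    assert!((std::f64::consts::PI / 180.0 - 0.017453292519943295).abs() < 1e-16);
}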
+#[cube(launch_unchecked)]
+fn kernel_sincos<F: Float>(input: &Array<F>, sin_output: &mut Array<F>, cos_output: &mut Array<F>) {
+    if UNIT_POS < input.len() {
+        let (sin_val, cos_val) = sincos::<F>(input[UNIT_POS]);
+        sin_output[UNIT_POS] = sin_val;
+        cos_output[UNIT_POS] = cos_val;
+    }
+}
+
+pub fn test_sincos<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
+    let input_data = vec![0.0, PI / 6.0, PI / 4.0, PI / 3.0, PI / 2.0, PI];
+
+    let input = client.create(f32::as_bytes(&input_data));
+    let sin_output = client.empty(input_data.len() * core::mem::size_of::<f32>());
+    let cos_output = client.empty(input_data.len() * core::mem::size_of::<f32>());
+
+    unsafe {
+        kernel_sincos::launch_unchecked::<f32, R>(
+            &client,
+            CubeCount::Static(1, 1, 1),
+            CubeDim::new(input_data.len() as u32, 1, 1),
+            ArrayArg::from_raw_parts::<f32>(&input, input_data.len(), 1),
+            ArrayArg::from_raw_parts::<f32>(&sin_output, input_data.len(), 1),
+            ArrayArg::from_raw_parts::<f32>(&cos_output, input_data.len(), 1),
+        );
+    }
+
+    let actual_sin = client.read_one(sin_output);
+    let actual_sin = f32::from_bytes(&actual_sin);
+    let actual_cos = client.read_one(cos_output);
+    let actual_cos = f32::from_bytes(&actual_cos);
+
+    for (i, &angle) in input_data.iter().enumerate() {
+        let expected_sin = angle.sin();
+        let expected_cos = angle.cos();
+
+        assert!(
+            (expected_sin - actual_sin[i]).abs() < 1e-6,
+            "Sin test {} failed: expected {}, got {}",
+            i,
+            expected_sin,
+            actual_sin[i]
+        );
+
+        assert!(
+            (expected_cos - actual_cos[i]).abs() < 1e-6,
+            "Cos test {} failed: expected {}, got {}",
+            i,
+            expected_cos,
+            actual_cos[i]
+        );
+    }
+}
+
+#[cube(launch_unchecked)]
+fn kernel_normalize_angle<F: Float>(input: &Array<F>, output: &mut Array<F>) {
+    if UNIT_POS < input.len() {
+        output[UNIT_POS] = normalize_angle::<F>(input[UNIT_POS]);
+    }
+}
+
+pub fn test_normalize_angle<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
+    let input_data = vec![
+        0.0,
+        PI,
+        TAU,
+        3.0 * PI,
+        4.0 * PI,
+        -PI,
+        -TAU,
+        -3.0 * PI,
+        PI + 0.5,
+        -PI + 0.5,
+    ];
+
+    let expected = vec![0.0, PI, 0.0, PI, 0.0, PI, 0.0, PI, PI + 0.5, PI + 0.5];
+
+    let input = client.create(f32::as_bytes(&input_data));
+    let output = client.empty(input_data.len() * core::mem::size_of::<f32>());
+
+    unsafe {
+        kernel_normalize_angle::launch_unchecked::<f32, R>(
+            &client,
+            CubeCount::Static(1, 1, 1),
+            CubeDim::new(input_data.len() as u32, 1, 1),
+            ArrayArg::from_raw_parts::<f32>(&input, input_data.len(), 1),
+            ArrayArg::from_raw_parts::<f32>(&output, input_data.len(), 1),
+        );
+    }
+
+    let actual = client.read_one(output);
+    let actual = f32::from_bytes(&actual);
+
+    for (i, (&expected_val, &actual_val)) in expected.iter().zip(actual.iter()).enumerate() {
+        assert!(
+            (expected_val - actual_val).abs() < 1e-5,
+            "Test {} failed: expected {}, got {}",
+            i,
+            expected_val,
+            actual_val
+        );
+    }
+}
+
+#[cube(launch_unchecked)]
+fn kernel_normalize_angle_signed<F: Float>(input: &Array<F>, output: &mut Array<F>) {
+    if UNIT_POS < input.len() {
+        output[UNIT_POS] = normalize_angle_signed::<F>(input[UNIT_POS]);
+    }
+}
+
+pub fn test_normalize_angle_signed<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
+    let input_data = vec![
+        0.0,
+        PI,
+        TAU,
+        // 3*PI can result in float errors -> add a small offset to the test
+        3.0 * PI + 1e-5,
+        4.0 * PI + 1e-5,
+        -PI,
+        -TAU,
+        -3.0 * PI + 1e-5,
+        PI + 0.5,
+        -PI + 0.5,
+    ];
+
+    let expected = vec![
+        0.0,
+        -PI,
+        0.0,
+        -PI + 1e-5,
+        0.0 + 1e-5,
+        -PI,
+        0.0,
+        -PI + 1e-5,
+        -PI + 0.5,
+        -PI + 0.5,
+    ];
+
+    let input = client.create(f32::as_bytes(&input_data));
+    let output = client.empty(input_data.len() * core::mem::size_of::<f32>());
+
+    unsafe {
+        kernel_normalize_angle_signed::launch_unchecked::<f32, R>(
+            &client,
+            CubeCount::Static(1, 1, 1),
+            CubeDim::new(input_data.len() as u32, 1, 1),
+            ArrayArg::from_raw_parts::<f32>(&input, input_data.len(), 1),
+            ArrayArg::from_raw_parts::<f32>(&output, input_data.len(), 1),
+        );
+    }
+
+    let actual = client.read_one(output);
+    let actual = f32::from_bytes(&actual);
+
+    for (i, (&expected_val, &actual_val)) in expected.iter().zip(actual.iter()).enumerate() {
+        assert!(
+            (expected_val - actual_val).abs() < 1e-5,
+            "Test {} failed: expected {}, got {}",
+            i,
+            expected_val,
+            actual_val
+        );
+    }
+}
+
+#[cube(launch_unchecked)]
+fn kernel_lerp_angle<F: Float>(from: &Array<F>, to: &Array<F>, t: &Array<F>, output: &mut Array<F>) {
+    if UNIT_POS < from.len() {
+        output[UNIT_POS] = lerp_angle::<F>(from[UNIT_POS], to[UNIT_POS], t[UNIT_POS]);
+    }
+}
+
+pub fn test_lerp_angle<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
+    let from_data = vec![0.0, 0.1, PI - 0.1, 0.0];
+    let to_data = vec![PI, TAU - 0.1, PI + 0.1, PI];
+    let t_data = vec![0.5, 0.5, 0.5, 0.5];
+
+    let from = client.create(f32::as_bytes(&from_data));
+    let to = client.create(f32::as_bytes(&to_data));
+    let t = client.create(f32::as_bytes(&t_data));
+    let output = client.empty(from_data.len() * core::mem::size_of::<f32>());
+
+    unsafe {
+        kernel_lerp_angle::launch_unchecked::<f32, R>(
+            &client,
+            CubeCount::Static(1, 1, 1),
+            CubeDim::new(from_data.len() as u32, 1, 1),
+            ArrayArg::from_raw_parts::<f32>(&from, from_data.len(), 1),
+            ArrayArg::from_raw_parts::<f32>(&to, to_data.len(), 1),
+            ArrayArg::from_raw_parts::<f32>(&t, t_data.len(), 1),
+            ArrayArg::from_raw_parts::<f32>(&output, from_data.len(), 1),
+        );
+    }
+
+    let actual = client.read_one(output);
+    let actual = f32::from_bytes(&actual);
+
+    // Test case 0: 0 to π should give π/2
+    assert!(
+        (actual[0] - PI / 2.0).abs() < 1e-5,
+        "Lerp angle test 0 failed"
+    );
+
+    // Test case 1: wraparound case - should take shortest path
+    assert!(
+        actual[1].abs() < 1e-5 || (actual[1] - TAU).abs() < 1e-5,
+        "Lerp angle test 1 failed: {}",
+        actual[1]
+    );
+
+    // Test case 2: small difference around π
+    assert!((actual[2] - PI).abs() < 1e-5, "Lerp angle test 2 failed");
+
+    // Test case 3: 0 to π should give π/2
+    assert!(
+        (actual[3] - PI / 2.0).abs() < 1e-5,
+        "Lerp angle test 3 failed"
+    );
+}
+
+#[cube(launch_unchecked)]
+fn kernel_angle_distance<F: Float>(from: &Array<F>, to: &Array<F>, output: &mut Array<F>) {
+    if UNIT_POS < from.len() {
+        output[UNIT_POS] = angle_distance::<F>(from[UNIT_POS], to[UNIT_POS]);
+    }
+}
+
+pub fn test_angle_distance<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
+    let from_data = vec![0.0, 0.1, PI, 0.0];
+    let to_data = vec![PI, TAU - 0.1, 0.0, TAU - 0.1];
+    let expected = vec![PI, -0.2, -PI, -0.1];
+
+    let from = client.create(f32::as_bytes(&from_data));
+    let to = client.create(f32::as_bytes(&to_data));
+    let output = client.empty(from_data.len() * core::mem::size_of::<f32>());
+
+    unsafe {
+        kernel_angle_distance::launch_unchecked::<f32, R>(
+            &client,
+            CubeCount::Static(1, 1, 1),
+            CubeDim::new(from_data.len() as u32, 1, 1),
+            ArrayArg::from_raw_parts::<f32>(&from, from_data.len(), 1),
+            ArrayArg::from_raw_parts::<f32>(&to, to_data.len(), 1),
+            ArrayArg::from_raw_parts::<f32>(&output, from_data.len(), 1),
+        );
+    }
+
+    let actual = client.read_one(output);
+    let actual = f32::from_bytes(&actual);
+
+    for (i, (&expected_val, &actual_val)) in expected.iter().zip(actual.iter()).enumerate() {
+        assert!(
+            (expected_val - actual_val).abs() < 1e-5,
+            "Angle distance test {} failed: expected {}, got {}",
+            i,
+            expected_val,
+            actual_val
+        );
+    }
+}
+
+#[cube(launch_unchecked)]
+fn kernel_vector_angle_2d<F: Float>(
+    x1: &Array<F>,
+    y1: &Array<F>,
+    x2: &Array<F>,
+    y2: &Array<F>,
+    output: &mut Array<F>,
+) {
+    if UNIT_POS < x1.len() {
+        output[UNIT_POS] =
+            vector_angle_2d::<F>(x1[UNIT_POS], y1[UNIT_POS], x2[UNIT_POS], y2[UNIT_POS]);
+    }
+}
+
+pub fn test_vector_angle_2d<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
+    // Simplified test case
+    let x1_data = vec![1.0];
+    let y1_data = vec![0.0];
+    let x2_data = vec![0.0];
+    let y2_data = vec![1.0];
+    let expected = vec![PI / 2.0];
+
+    let x1 = client.create(f32::as_bytes(&x1_data));
+    let y1 = client.create(f32::as_bytes(&y1_data));
+    let x2 = client.create(f32::as_bytes(&x2_data));
+    let y2 = client.create(f32::as_bytes(&y2_data));
+    let output = client.empty(x1_data.len() * core::mem::size_of::<f32>());
+
+    unsafe {
+        kernel_vector_angle_2d::launch_unchecked::<f32, R>(
+            &client,
+            CubeCount::Static(1, 1, 1),
+            CubeDim::new(x1_data.len() as u32, 1, 1),
+            ArrayArg::from_raw_parts::<f32>(&x1, x1_data.len(), 1),
+            ArrayArg::from_raw_parts::<f32>(&y1, y1_data.len(), 1),
+            ArrayArg::from_raw_parts::<f32>(&x2, x2_data.len(), 1),
+            ArrayArg::from_raw_parts::<f32>(&y2, y2_data.len(), 1),
+            ArrayArg::from_raw_parts::<f32>(&output, x1_data.len(), 1),
+        );
+    }
+
+    let actual = client.read_one(output);
+    let actual = f32::from_bytes(&actual);
+
+    for (i, (&expected_val, &actual_val)) in expected.iter().zip(actual.iter()).enumerate() {
+        assert!(
+            (expected_val - actual_val).abs() < 1e-5,
+            "Vector angle 2D test {} failed: expected {}, got {}",
+            i,
+            expected_val,
+            actual_val
+        );
+    }
+}
+
+#[cube(launch_unchecked)]
+fn kernel_rotate_2d<F: Float>(
+    x: &Array<F>,
+    y: &Array<F>,
+    angle: &Array<F>,
+    x_out: &mut Array<F>,
+    y_out: &mut Array<F>,
+) {
+    if UNIT_POS < x.len() {
+        let (new_x, new_y) = rotate_2d::<F>(x[UNIT_POS], y[UNIT_POS], angle[UNIT_POS]);
+        x_out[UNIT_POS] = new_x;
+        y_out[UNIT_POS] = new_y;
+    }
+}
+
+pub fn test_rotate_2d<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
+    let x_data = vec![1.0, 0.0, 1.0, 1.0];
+    let y_data = vec![0.0, 1.0, 1.0, 0.0];
+    let angle_data = vec![PI / 2.0, PI / 2.0, PI / 4.0, PI];
+
+    let expected_x = vec![0.0, -1.0, 0.0, -1.0];
+    let expected_y = vec![1.0, 0.0, 1.414213562373095, 0.0];
+
+    let x = client.create(f32::as_bytes(&x_data));
+    let y = client.create(f32::as_bytes(&y_data));
+    let angle = client.create(f32::as_bytes(&angle_data));
+    let x_out = client.empty(x_data.len() * core::mem::size_of::<f32>());
+    let y_out = client.empty(y_data.len() * core::mem::size_of::<f32>());
+
+    unsafe {
+        kernel_rotate_2d::launch_unchecked::<f32, R>(
+            &client,
+            CubeCount::Static(1, 1, 1),
+            CubeDim::new(x_data.len() as u32, 1, 1),
+            ArrayArg::from_raw_parts::<f32>(&x, x_data.len(), 1),
+            ArrayArg::from_raw_parts::<f32>(&y, y_data.len(), 1),
+            ArrayArg::from_raw_parts::<f32>(&angle, angle_data.len(), 1),
+            ArrayArg::from_raw_parts::<f32>(&x_out, x_data.len(), 1),
+            ArrayArg::from_raw_parts::<f32>(&y_out, y_data.len(), 1),
+        );
+    }
+
+    let actual_x = client.read_one(x_out);
+    let actual_x = f32::from_bytes(&actual_x);
+    let actual_y = client.read_one(y_out);
+    let actual_y = f32::from_bytes(&actual_y);
+
+    for i in 0..x_data.len() {
+        assert!(
+            (expected_x[i] - actual_x[i]).abs() < 1e-5,
+            "Rotate 2D X test {} failed: expected {}, got {}",
+            i,
+            expected_x[i],
+            actual_x[i]
+        );
+
+        assert!(
+            (expected_y[i] - actual_y[i]).abs() < 1e-5,
+            "Rotate 2D Y test {} failed: expected {}, got {}",
+            i,
+            expected_y[i],
+            actual_y[i]
+        );
+    }
+}
+
+#[cube(launch_unchecked)]
+fn kernel_hypot<F: Float>(x: &Array<F>, y: &Array<F>, output: &mut Array<F>) {
+    if UNIT_POS < x.len() {
+        output[UNIT_POS] = hypot::<F>(x[UNIT_POS], y[UNIT_POS]);
+    }
+}
+
+pub fn test_hypot<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
+    let x_data = vec![3.0, 0.0, 1.0, 5.0, 0.0];
+    let y_data = vec![4.0, 1.0, 1.0, 12.0, 0.0];
+    let expected = vec![5.0, 1.0, 1.4142135623730951, 13.0, 0.0];
+
+    let x = client.create(f32::as_bytes(&x_data));
+    let y = client.create(f32::as_bytes(&y_data));
+    let output = client.empty(x_data.len() * core::mem::size_of::<f32>());
+
+    unsafe {
+        kernel_hypot::launch_unchecked::<f32, R>(
+            &client,
+            CubeCount::Static(1, 1, 1),
+            CubeDim::new(x_data.len() as u32, 1, 1),
+            ArrayArg::from_raw_parts::<f32>(&x, x_data.len(), 1),
+            ArrayArg::from_raw_parts::<f32>(&y, y_data.len(), 1),
+            ArrayArg::from_raw_parts::<f32>(&output, x_data.len(), 1),
+        );
+    }
+
+    let actual = client.read_one(output);
+    let actual = f32::from_bytes(&actual);
+
+    for (i, (&expected_val, &actual_val)) in expected.iter().zip(actual.iter()).enumerate() {
+        assert!(
+            (expected_val - actual_val).abs() < 1e-5,
+            "Hypot test {} failed: expected {}, got {}",
+            i,
+            expected_val,
+            actual_val
+        );
+    }
+}
+
+#[cube(launch_unchecked)]
+fn kernel_sinc<F: Float>(input: &Array<F>, output: &mut Array<F>) {
+    if UNIT_POS < input.len() {
+        output[UNIT_POS] = sinc::<F>(input[UNIT_POS]);
+    }
+}
+
+pub fn test_sinc<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
+    let input_data = vec![0.0, 1.0, -1.0, 0.5, -0.5, 2.0];
+    // Expected values for normalized sinc function: sin(πx)/(πx)
+    let expected = vec![
+        1.0,                // sinc(0) = 1
+        0.0,                // sinc(1) ≈ 0 (actually 3.8986e-17, but effectively 0)
+        0.0,                // sinc(-1) ≈ 0
+        0.6366197723675814, // sinc(0.5) = sin(π/2)/(π/2) = 1/(π/2) ≈ 0.6366
+        0.6366197723675814, // sinc(-0.5) = sinc(0.5)
+        0.0,                // sinc(2) ≈ 0
+    ];
+
+    let input = client.create(f32::as_bytes(&input_data));
+    let output = client.empty(input_data.len() * core::mem::size_of::<f32>());
+
+    unsafe {
+        kernel_sinc::launch_unchecked::<f32, R>(
+            &client,
+            CubeCount::Static(1, 1, 1),
+            CubeDim::new(input_data.len() as u32, 1, 1),
+            ArrayArg::from_raw_parts::<f32>(&input, input_data.len(), 1),
+            ArrayArg::from_raw_parts::<f32>(&output, input_data.len(), 1),
+        );
+    }
+
+    let actual = client.read_one(output);
+    let actual = f32::from_bytes(&actual);
+
+    for (i, (&expected_val, &actual_val)) in expected.iter().zip(actual.iter()).enumerate() {
+        let tolerance = if i == 1 || i == 2 || i == 5 {
+            1e-3
+        } else {
+            1e-5
+        }; // More tolerance for near-zero values
+        assert!(
+            (expected_val - actual_val).abs() < tolerance,
+            "Sinc test {} failed: expected {}, got {}",
+            i,
+            expected_val,
+            actual_val
+        );
+    }
+}
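One subtlety in the sinc expectations just above: at integer inputs the mathematical value is exactly 0, but the kernel evaluates sin(π·x) with a rounded π, so the result is only near zero; that is why tests 1, 2 and 5 use the looser 1e-3 tolerance. A host-side illustration (plain Rust, not part of the patch):

```rust
fn main() {
    // sin(PI * 1.0) is not exactly 0.0: PI is rounded to the nearest f32,
    // and the kernel inherits the same kind of rounding error.
    let val = (std::f32::consts::PI * 1.0_f32).sin();
    assert!(val != 0.0 && val.abs() < 1e-3);
}
```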
+
+#[macro_export]
+macro_rules! testgen_trigonometry {
+    () => {
+        mod trigonometry {
+            use super::*;
+            use $crate::tests::trigonometry::*;
+
+            #[test]
+            fn test_to_degrees_conversion() {
+                let client = TestRuntime::client(&Default::default());
+                test_to_degrees::<TestRuntime>(client);
+            }
+
+            #[test]
+            fn test_to_radians_conversion() {
+                let client = TestRuntime::client(&Default::default());
+                test_to_radians::<TestRuntime>(client);
+            }
+
+            #[test]
+            fn test_sincos_computation() {
+                let client = TestRuntime::client(&Default::default());
+                test_sincos::<TestRuntime>(client);
+            }
+
+            #[test]
+            fn test_normalize_angle_positive() {
+                let client = TestRuntime::client(&Default::default());
+                test_normalize_angle::<TestRuntime>(client);
+            }
+
+            #[test]
+            fn test_normalize_angle_signed_range() {
+                let client = TestRuntime::client(&Default::default());
+                test_normalize_angle_signed::<TestRuntime>(client);
+            }
+
+            #[test]
+            fn test_lerp_angle_interpolation() {
+                let client = TestRuntime::client(&Default::default());
+                test_lerp_angle::<TestRuntime>(client);
+            }
+
+            #[test]
+            fn test_angle_distance_calculation() {
+                let client = TestRuntime::client(&Default::default());
+                test_angle_distance::<TestRuntime>(client);
+            }
+
+            #[test]
+            fn test_vector_angle_2d_computation() {
+                let client = TestRuntime::client(&Default::default());
+                test_vector_angle_2d::<TestRuntime>(client);
+            }
+
+            #[test]
+            fn test_rotate_2d_transformation() {
+                let client = TestRuntime::client(&Default::default());
+                test_rotate_2d::<TestRuntime>(client);
+            }
+
+            #[test]
+            fn test_hypot_computation() {
+                let client = TestRuntime::client(&Default::default());
+                test_hypot::<TestRuntime>(client);
+            }
+
+            #[test]
+            fn test_sinc_function() {
+                let client = TestRuntime::client(&Default::default());
+                test_sinc::<TestRuntime>(client);
+            }
+        }
+    };
+}
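Before the module source itself (next diff), a minimal sketch of how these helpers are meant to compose inside a kernel. The `kernel_steer` example is hypothetical (not part of the patch) and assumes the `cubecl_std` re-exports added to `lib.rs` above:

```rust,ignore
use cubecl::prelude::*;
use cubecl_std::{angle_distance, normalize_angle_signed};

// Hypothetical kernel: steer each angle toward `target` along the
// shortest arc, keeping results in [-π, π).
#[cube(launch_unchecked)]
fn kernel_steer<F: Float>(angles: &Array<F>, target: F, rate: F, out: &mut Array<F>) {
    if UNIT_POS < angles.len() {
        let current = normalize_angle_signed::<F>(angles[UNIT_POS]);
        // angle_distance already picks the shortest signed arc.
        let step = angle_distance::<F>(current, target) * rate;
        out[UNIT_POS] = normalize_angle_signed::<F>(current + step);
    }
}
```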
diff --git a/crates/cubecl-std/src/trigonometry.rs b/crates/cubecl-std/src/trigonometry.rs
new file mode 100644
index 000000000..7e370dd19
--- /dev/null
+++ b/crates/cubecl-std/src/trigonometry.rs
@@ -0,0 +1,322 @@
+//! Trigonometric functions and utilities for CubeCL.
+//!
+//! This module provides basic trigonometric operations and angle conversion utilities
+//! that can be used in any GPU kernel.
+
+use core::f32;
+use cubecl::prelude::*;
+use cubecl_core as cubecl;
+
+/// Converts an angle from radians to degrees.
+///
+/// # Example
+///
+/// ```rust,ignore
+/// let radians = F::new(std::f32::consts::PI);
+/// let degrees = to_degrees(radians);
+/// assert!((degrees - F::new(180.0)).abs() < F::new(1e-6));
+/// ```
+#[cube]
+pub fn to_degrees<F: Float>(val: F) -> F {
+    val * F::new(180.0 / f32::consts::PI)
+}
+
+/// Converts an angle from degrees to radians.
+///
+/// # Example
+///
+/// ```rust,ignore
+/// let degrees = F::new(180.0);
+/// let radians = to_radians(degrees);
+/// assert!((radians - F::new(std::f32::consts::PI)).abs() < F::new(1e-6));
+/// ```
+#[cube]
+pub fn to_radians<F: Float>(val: F) -> F {
+    val * F::new(f32::consts::PI / 180.0)
+}
+
+/// Computes both sine and cosine of an angle simultaneously.
+///
+/// This can be more efficient than computing sin and cos separately
+/// on some GPU architectures.
+///
+/// # Arguments
+///
+/// * `val` - The angle in radians
+///
+/// # Returns
+///
+/// A tuple containing (sine, cosine) of the input angle
+///
+/// # Example
+///
+/// ```rust,ignore
+/// let angle = F::new(std::f32::consts::PI / 4.0);
+/// let (sin_val, cos_val) = sincos(angle);
+/// ```
+#[cube]
+pub fn sincos<F: Float>(val: F) -> (F, F) {
+    (F::sin(val), F::cos(val))
+}
+
+/// Normalizes an angle to the range [0, 2π).
+///
+/// # Arguments
+///
+/// * `angle` - The angle in radians to normalize
+///
+/// # Returns
+///
+/// The angle normalized to the range [0, 2π)
+///
+/// # Example
+///
+/// ```rust,ignore
+/// let angle = F::new(3.0 * std::f32::consts::PI);
+/// let normalized = normalize_angle(angle);
+/// assert!((normalized - F::new(std::f32::consts::PI)).abs() < F::new(1e-6));
+/// ```
+#[cube]
+pub fn normalize_angle<F: Float>(angle: F) -> F {
+    let tau = F::new(f32::consts::TAU);
+    angle - F::floor(angle / tau) * tau
+}
+
+/// Normalizes an angle to the range [-π, π).
+///
+/// # Arguments
+///
+/// * `angle` - The angle in radians to normalize
+///
+/// # Returns
+///
+/// The angle normalized to the range [-π, π)
+///
+/// # Example
+///
+/// ```rust,ignore
+/// let angle = F::new(3.0 * std::f32::consts::PI);
+/// let normalized = normalize_angle_signed(angle);
+/// // 3π wraps to π, which the half-open signed range maps to -π.
+/// assert!((normalized + F::new(std::f32::consts::PI)).abs() < F::new(1e-6));
+/// ```
+#[cube]
+pub fn normalize_angle_signed<F: Float>(angle: F) -> F {
+    let pi = F::new(f32::consts::PI);
+    let tau = F::new(f32::consts::TAU);
+    let normalized = angle - F::floor(angle / tau) * tau;
+    if normalized >= pi {
+        normalized - tau
+    } else {
+        normalized
+    }
+}
+
+/// Linear interpolation between two angles, taking the shortest path.
+///
+/// This function correctly handles the wraparound at 2π to ensure
+/// the interpolation follows the shortest circular arc.
+///
+/// # Arguments
+///
+/// * `from` - The starting angle in radians
+/// * `to` - The ending angle in radians
+/// * `t` - The interpolation factor (0.0 = from, 1.0 = to)
+///
+/// # Returns
+///
+/// The interpolated angle
+///
+/// # Example
+///
+/// ```rust,ignore
+/// let from = F::new(0.1);
+/// let to = F::new(std::f32::consts::TAU - 0.1);
+/// let mid = lerp_angle(from, to, F::new(0.5));
+/// assert!(mid.abs() < F::new(1e-6) || (mid - F::new(std::f32::consts::TAU)).abs() < F::new(1e-6));
+/// ```
+#[cube]
+pub fn lerp_angle<F: Float>(from: F, to: F, t: F) -> F {
+    let pi = F::new(f32::consts::PI);
+    let tau = F::new(f32::consts::TAU);
+
+    let diff = to - from;
+    let normalized_diff = if diff > pi {
+        diff - tau
+    } else if diff < -pi {
+        diff + tau
+    } else {
+        diff
+    };
+
+    normalize_angle::<F>(from + normalized_diff * t)
+}
+
+/// Computes the shortest angular distance between two angles.
+///
+/// # Arguments
+///
+/// * `from` - The first angle in radians
+/// * `to` - The second angle in radians
+///
+/// # Returns
+///
+/// The shortest angular distance, positive if the shortest rotation from
+/// `from` to `to` goes in the direction of increasing angle (counter-clockwise
+/// in the standard orientation)
+///
+/// # Example
+///
+/// ```rust,ignore
+/// let angle1 = F::new(0.1);
+/// let angle2 = F::new(std::f32::consts::TAU - 0.1);
+/// let distance = angle_distance(angle1, angle2);
+/// assert!((distance - F::new(-0.2)).abs() < F::new(1e-6));
+/// ```
+#[cube]
+pub fn angle_distance<F: Float>(from: F, to: F) -> F {
+    let pi = F::new(f32::consts::PI);
+    let tau = F::new(f32::consts::TAU);
+
+    let diff = to - from;
+    if diff > pi {
+        diff - tau
+    } else if diff < -pi {
+        diff + tau
+    } else {
+        diff
+    }
+}
+
+/// Smoothstep interpolation for angles.
+///
+/// Applies smoothstep interpolation (3t² - 2t³) between two angles,
+/// taking the shortest circular path.
+///
+/// # Arguments
+///
+/// * `from` - The starting angle in radians
+/// * `to` - The ending angle in radians
+/// * `t` - The interpolation factor (0.0 = from, 1.0 = to)
+///
+/// # Returns
+///
+/// The smoothly interpolated angle
+///
+/// # Example
+///
+/// ```rust,ignore
+/// let from = F::new(0.0);
+/// let to = F::new(std::f32::consts::PI);
+/// let smooth = smoothstep_angle(from, to, F::new(0.5));
+/// ```
+#[cube]
+pub fn smoothstep_angle<F: Float>(from: F, to: F, t: F) -> F {
+    let smooth_t = t * t * (F::new(3.0) - F::new(2.0) * t);
+    lerp_angle::<F>(from, to, smooth_t)
+}
+
+/// Computes the angle between two 2D vectors.
+///
+/// # Arguments
+///
+/// * `x1`, `y1` - Components of the first vector
+/// * `x2`, `y2` - Components of the second vector
+///
+/// # Returns
+///
+/// The angle between the vectors in radians
+///
+/// # Example
+///
+/// ```rust,ignore
+/// let angle = vector_angle_2d(F::new(1.0), F::new(0.0), F::new(0.0), F::new(1.0));
+/// assert!((angle - F::new(std::f32::consts::PI / 2.0)).abs() < F::new(1e-6));
+/// ```
+#[cube]
+pub fn vector_angle_2d<F: Float>(x1: F, y1: F, x2: F, y2: F) -> F {
+    let dot = x1 * x2 + y1 * y2;
+    let det = x1 * y2 - y1 * x2;
+    F::atan2(det, dot)
+}
+
+/// Rotates a 2D point around the origin by the given angle.
+///
+/// # Arguments
+///
+/// * `x`, `y` - The point coordinates
+/// * `angle` - The rotation angle in radians
+///
+/// # Returns
+///
+/// A tuple containing the rotated coordinates (x', y')
+///
+/// # Example
+///
+/// ```rust,ignore
+/// let (x, y) = rotate_2d(F::new(1.0), F::new(0.0), F::new(std::f32::consts::PI / 2.0));
+/// assert!(x.abs() < F::new(1e-6) && (y - F::new(1.0)).abs() < F::new(1e-6));
+/// ```
+#[cube]
+pub fn rotate_2d<F: Float>(x: F, y: F, angle: F) -> (F, F) {
+    let cos_a = F::cos(angle);
+    let sin_a = F::sin(angle);
+    (x * cos_a - y * sin_a, x * sin_a + y * cos_a)
+}
+
+/// Computes the hypotenuse of a right triangle given the lengths of the other two sides.
+///
+/// This is a direct `sqrt(x² + y²)` implementation. Unlike the usual libm
+/// `hypot`, it does not guard against intermediate overflow or underflow
+/// for extreme inputs.
+///
+/// # Arguments
+///
+/// * `x` - Length of one side
+/// * `y` - Length of the other side
+///
+/// # Returns
+///
+/// The length of the hypotenuse
+///
+/// # Example
+///
+/// ```rust,ignore
+/// let hyp = hypot(F::new(3.0), F::new(4.0));
+/// assert!((hyp - F::new(5.0)).abs() < F::new(1e-6));
+/// ```
+#[cube]
+pub fn hypot<F: Float>(x: F, y: F) -> F {
+    F::sqrt(x * x + y * y)
+}
+
+/// Computes the normalized sinc function.
+///
+/// The sinc function is defined as:
+/// - `sinc(x) = sin(πx) / (πx)` for x ≠ 0
+/// - `sinc(0) = 1`
+///
+/// This is the normalized sinc function used in digital signal processing.
+///
+/// # Arguments
+///
+/// * `x` - The input value
+///
+/// # Returns
+///
+/// The sinc of the input
+///
+/// # Example
+///
+/// ```rust,ignore
+/// let result = sinc(F::new(0.0));
+/// assert!((result - F::new(1.0)).abs() < F::new(1e-6));
+///
+/// let result = sinc(F::new(1.0));
+/// assert!(result.abs() < F::new(1e-6)); // sinc(1) ≈ 0
+/// ```
+#[cube]
+pub fn sinc<F: Float>(x: F) -> F {
+    let pi_x = F::new(f32::consts::PI) * x;
+    if F::abs(x) < F::new(1e-8) {
+        F::new(1.0)
+    } else {
+        F::sin(pi_x) / pi_x
+    }
+}

From 93dda9b11f30935a8ab101cf61d485e3b8d03ab4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tobias=20H=C3=B6lzer?=
Date: Sat, 13 Sep 2025 12:19:02 +0200
Subject: [PATCH 15/23] Update spir-v calls to forked version

---
 crates/cubecl-spirv/src/extensions.rs | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/crates/cubecl-spirv/src/extensions.rs b/crates/cubecl-spirv/src/extensions.rs
index 4a5897327..c2093b0fd 100644
--- a/crates/cubecl-spirv/src/extensions.rs
+++ b/crates/cubecl-spirv/src/extensions.rs
@@ -78,7 +78,7 @@ pub mod glcompute {
     }

     fn tan(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) {
-        b.tan_id(ty, Some(out), input).unwrap();
+        b.cl_tan_id(ty, Some(out), input).unwrap();
     }

     fn tanh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) {
@@ -86,47 +86,47 @@
     }

     fn sinh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) {
-        b.sinh_id(ty, Some(out), input).unwrap();
+        b.cl_sinh_id(ty, Some(out), input).unwrap();
     }

     fn cosh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) {
-        b.cosh_id(ty, Some(out), input).unwrap();
+        b.cl_cosh_id(ty, Some(out), input).unwrap();
     }

     fn asin(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) {
-        b.asin_id(ty, Some(out), input).unwrap();
+        b.cl_asin_id(ty, Some(out), input).unwrap();
     }

     fn acos(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) {
-        b.acos_id(ty, Some(out), input).unwrap();
+        b.cl_acos_id(ty, Some(out), input).unwrap();
     }

     fn atan(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) {
-        b.atan_id(ty, Some(out), input).unwrap();
+        b.cl_atan_id(ty, Some(out), input).unwrap();
     }

     fn asinh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) {
-        b.asinh_id(ty, Some(out), input).unwrap();
+        b.cl_asinh_id(ty, Some(out), input).unwrap();
     }

     fn acosh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) {
-        b.acosh_id(ty, Some(out), input).unwrap();
+        b.cl_acosh_id(ty, Some(out), input).unwrap();
     }

     fn atanh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) {
-        b.atanh_id(ty, Some(out), input).unwrap();
+        b.cl_atanh_id(ty, Some(out), input).unwrap();
     }

     fn degrees(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) {
-        b.degrees_id(ty, Some(out), input).unwrap();
+        b.cl_degrees_id(ty, Some(out), input).unwrap();
     }

     fn radians(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) {
-        b.radians_id(ty, Some(out), input).unwrap();
+        b.cl_radians_id(ty, Some(out), input).unwrap();
     }

     fn atan2(b: &mut SpirvCompiler, ty: Word, lhs: Word, rhs: Word, out: Word) {
-        b.atan2_id(ty, Some(out), lhs, rhs).unwrap();
+        b.cl_atan2_id(ty, Some(out), lhs, rhs).unwrap();
     }

     fn pow(b: &mut SpirvCompiler, ty: Word, lhs: Word, rhs: Word, out: Word) {

From 1319a30d70519b7ad01742ab34e803ae554886ab Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tobias=20H=C3=B6lzer?=
Date: Tue, 7 Oct 2025 16:19:56 +0200
Subject: [PATCH 16/23] Remove unnecessary trig functions in std

---
 crates/cubecl-std/src/tests/trigonometry.rs | 510 --------------------
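An editor's note on the `[-π, π)` normalization used throughout patch 14: the shortest-arc logic in `lerp_angle` and `angle_distance` is just a branch on the wrapped difference. A host-side reference in plain Rust (assumed equivalent to the kernel math, not part of the patch) makes the wraparound case concrete:

```rust
use std::f32::consts::{PI, TAU};

// Reference version of angle_distance: wrap (to - from) into [-π, π).
fn angle_distance_ref(from: f32, to: f32) -> f32 {
    let diff = to - from;
    if diff > PI {
        diff - TAU
    } else if diff < -PI {
        diff + TAU
    } else {
        diff
    }
}

fn main() {
    // Crossing the 0/2π seam: the short way from 0.1 to TAU - 0.1 is -0.2.
    assert!((angle_distance_ref(0.1, TAU - 0.1) + 0.2).abs() < 1e-6);
}
```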
crates/cubecl-std/src/trigonometry.rs | 261 ---------- 2 files changed, 771 deletions(-) diff --git a/crates/cubecl-std/src/tests/trigonometry.rs b/crates/cubecl-std/src/tests/trigonometry.rs index f5692202f..ab17421a5 100644 --- a/crates/cubecl-std/src/tests/trigonometry.rs +++ b/crates/cubecl-std/src/tests/trigonometry.rs @@ -80,389 +80,6 @@ pub fn test_to_radians(client: ComputeClient) } } -#[cube(launch_unchecked)] -fn kernel_sincos(input: &Array, sin_output: &mut Array, cos_output: &mut Array) { - if UNIT_POS < input.len() { - let (sin_val, cos_val) = sincos::(input[UNIT_POS]); - sin_output[UNIT_POS] = sin_val; - cos_output[UNIT_POS] = cos_val; - } -} - -pub fn test_sincos(client: ComputeClient) { - let input_data = vec![0.0, PI / 6.0, PI / 4.0, PI / 3.0, PI / 2.0, PI]; - - let input = client.create(f32::as_bytes(&input_data)); - let sin_output = client.empty(input_data.len() * core::mem::size_of::()); - let cos_output = client.empty(input_data.len() * core::mem::size_of::()); - - unsafe { - kernel_sincos::launch_unchecked::( - &client, - CubeCount::Static(1, 1, 1), - CubeDim::new(input_data.len() as u32, 1, 1), - ArrayArg::from_raw_parts::(&input, input_data.len(), 1), - ArrayArg::from_raw_parts::(&sin_output, input_data.len(), 1), - ArrayArg::from_raw_parts::(&cos_output, input_data.len(), 1), - ); - } - - let actual_sin = client.read_one(sin_output); - let actual_sin = f32::from_bytes(&actual_sin); - let actual_cos = client.read_one(cos_output); - let actual_cos = f32::from_bytes(&actual_cos); - - for (i, &angle) in input_data.iter().enumerate() { - let expected_sin = angle.sin(); - let expected_cos = angle.cos(); - - assert!( - (expected_sin - actual_sin[i]).abs() < 1e-6, - "Sin test {} failed: expected {}, got {}", - i, - expected_sin, - actual_sin[i] - ); - - assert!( - (expected_cos - actual_cos[i]).abs() < 1e-6, - "Cos test {} failed: expected {}, got {}", - i, - expected_cos, - actual_cos[i] - ); - } -} - -#[cube(launch_unchecked)] -fn kernel_normalize_angle(input: &Array, output: &mut Array) { - if UNIT_POS < input.len() { - output[UNIT_POS] = normalize_angle::(input[UNIT_POS]); - } -} - -pub fn test_normalize_angle(client: ComputeClient) { - let input_data = vec![ - 0.0, - PI, - TAU, - 3.0 * PI, - 4.0 * PI, - -PI, - -TAU, - -3.0 * PI, - PI + 0.5, - -PI + 0.5, - ]; - - let expected = vec![0.0, PI, 0.0, PI, 0.0, PI, 0.0, PI, PI + 0.5, PI + 0.5]; - - let input = client.create(f32::as_bytes(&input_data)); - let output = client.empty(input_data.len() * core::mem::size_of::()); - - unsafe { - kernel_normalize_angle::launch_unchecked::( - &client, - CubeCount::Static(1, 1, 1), - CubeDim::new(input_data.len() as u32, 1, 1), - ArrayArg::from_raw_parts::(&input, input_data.len(), 1), - ArrayArg::from_raw_parts::(&output, input_data.len(), 1), - ); - } - - let actual = client.read_one(output); - let actual = f32::from_bytes(&actual); - - for (i, (&expected_val, &actual_val)) in expected.iter().zip(actual.iter()).enumerate() { - assert!( - (expected_val - actual_val).abs() < 1e-5, - "Test {} failed: expected {}, got {}", - i, - expected_val, - actual_val - ); - } -} - -#[cube(launch_unchecked)] -fn kernel_normalize_angle_signed(input: &Array, output: &mut Array) { - if UNIT_POS < input.len() { - output[UNIT_POS] = normalize_angle_signed::(input[UNIT_POS]); - } -} - -pub fn test_normalize_angle_signed(client: ComputeClient) { - let input_data = vec![ - 0.0, - PI, - TAU, - // 3*PI can result in float errors -> add a small offset to the test - 3.0 * PI + 1e-5, - 4.0 * PI + 1e-5, - -PI, - 
-TAU, - -3.0 * PI + 1e-5, - PI + 0.5, - -PI + 0.5, - ]; - - let expected = vec![ - 0.0, - -PI, - 0.0, - -PI + 1e-5, - 0.0 + 1e-5, - -PI, - 0.0, - -PI + 1e-5, - -PI + 0.5, - -PI + 0.5, - ]; - - let input = client.create(f32::as_bytes(&input_data)); - let output = client.empty(input_data.len() * core::mem::size_of::()); - - unsafe { - kernel_normalize_angle_signed::launch_unchecked::( - &client, - CubeCount::Static(1, 1, 1), - CubeDim::new(input_data.len() as u32, 1, 1), - ArrayArg::from_raw_parts::(&input, input_data.len(), 1), - ArrayArg::from_raw_parts::(&output, input_data.len(), 1), - ); - } - - let actual = client.read_one(output); - let actual = f32::from_bytes(&actual); - - for (i, (&expected_val, &actual_val)) in expected.iter().zip(actual.iter()).enumerate() { - assert!( - (expected_val - actual_val).abs() < 1e-5, - "Test {} failed: expected {}, got {}", - i, - expected_val, - actual_val - ); - } -} - -#[cube(launch_unchecked)] -fn kernel_lerp_angle(from: &Array, to: &Array, t: &Array, output: &mut Array) { - if UNIT_POS < from.len() { - output[UNIT_POS] = lerp_angle::(from[UNIT_POS], to[UNIT_POS], t[UNIT_POS]); - } -} - -pub fn test_lerp_angle(client: ComputeClient) { - let from_data = vec![0.0, 0.1, PI - 0.1, 0.0]; - let to_data = vec![PI, TAU - 0.1, PI + 0.1, PI]; - let t_data = vec![0.5, 0.5, 0.5, 0.5]; - - let from = client.create(f32::as_bytes(&from_data)); - let to = client.create(f32::as_bytes(&to_data)); - let t = client.create(f32::as_bytes(&t_data)); - let output = client.empty(from_data.len() * core::mem::size_of::()); - - unsafe { - kernel_lerp_angle::launch_unchecked::( - &client, - CubeCount::Static(1, 1, 1), - CubeDim::new(from_data.len() as u32, 1, 1), - ArrayArg::from_raw_parts::(&from, from_data.len(), 1), - ArrayArg::from_raw_parts::(&to, to_data.len(), 1), - ArrayArg::from_raw_parts::(&t, t_data.len(), 1), - ArrayArg::from_raw_parts::(&output, from_data.len(), 1), - ); - } - - let actual = client.read_one(output); - let actual = f32::from_bytes(&actual); - - // Test case 0: 0 to π should give π/2 - assert!( - (actual[0] - PI / 2.0).abs() < 1e-5, - "Lerp angle test 0 failed" - ); - - // Test case 1: wraparound case - should take shortest path - assert!( - actual[1].abs() < 1e-5 || (actual[1] - TAU).abs() < 1e-5, - "Lerp angle test 1 failed: {}", - actual[1] - ); - - // Test case 2: small difference around π - assert!((actual[2] - PI).abs() < 1e-5, "Lerp angle test 2 failed"); - - // Test case 3: 0 to π should give π/2 - assert!( - (actual[3] - PI / 2.0).abs() < 1e-5, - "Lerp angle test 3 failed" - ); -} - -#[cube(launch_unchecked)] -fn kernel_angle_distance(from: &Array, to: &Array, output: &mut Array) { - if UNIT_POS < from.len() { - output[UNIT_POS] = angle_distance::(from[UNIT_POS], to[UNIT_POS]); - } -} - -pub fn test_angle_distance(client: ComputeClient) { - let from_data = vec![0.0, 0.1, PI, 0.0]; - let to_data = vec![PI, TAU - 0.1, 0.0, TAU - 0.1]; - let expected = vec![PI, -0.2, -PI, -0.1]; - - let from = client.create(f32::as_bytes(&from_data)); - let to = client.create(f32::as_bytes(&to_data)); - let output = client.empty(from_data.len() * core::mem::size_of::()); - - unsafe { - kernel_angle_distance::launch_unchecked::( - &client, - CubeCount::Static(1, 1, 1), - CubeDim::new(from_data.len() as u32, 1, 1), - ArrayArg::from_raw_parts::(&from, from_data.len(), 1), - ArrayArg::from_raw_parts::(&to, to_data.len(), 1), - ArrayArg::from_raw_parts::(&output, from_data.len(), 1), - ); - } - - let actual = client.read_one(output); - let actual = 
f32::from_bytes(&actual); - - for (i, (&expected_val, &actual_val)) in expected.iter().zip(actual.iter()).enumerate() { - assert!( - (expected_val - actual_val).abs() < 1e-5, - "Angle distance test {} failed: expected {}, got {}", - i, - expected_val, - actual_val - ); - } -} - -#[cube(launch_unchecked)] -fn kernel_vector_angle_2d( - x1: &Array, - y1: &Array, - x2: &Array, - y2: &Array, - output: &mut Array, -) { - if UNIT_POS < x1.len() { - output[UNIT_POS] = - vector_angle_2d::(x1[UNIT_POS], y1[UNIT_POS], x2[UNIT_POS], y2[UNIT_POS]); - } -} - -pub fn test_vector_angle_2d(client: ComputeClient) { - // Simplified test case - let x1_data = vec![1.0]; - let y1_data = vec![0.0]; - let x2_data = vec![0.0]; - let y2_data = vec![1.0]; - let expected = vec![PI / 2.0]; - - let x1 = client.create(f32::as_bytes(&x1_data)); - let y1 = client.create(f32::as_bytes(&y1_data)); - let x2 = client.create(f32::as_bytes(&x2_data)); - let y2 = client.create(f32::as_bytes(&y2_data)); - let output = client.empty(x1_data.len() * core::mem::size_of::()); - - unsafe { - kernel_vector_angle_2d::launch_unchecked::( - &client, - CubeCount::Static(1, 1, 1), - CubeDim::new(x1_data.len() as u32, 1, 1), - ArrayArg::from_raw_parts::(&x1, x1_data.len(), 1), - ArrayArg::from_raw_parts::(&y1, y1_data.len(), 1), - ArrayArg::from_raw_parts::(&x2, x2_data.len(), 1), - ArrayArg::from_raw_parts::(&y2, y2_data.len(), 1), - ArrayArg::from_raw_parts::(&output, x1_data.len(), 1), - ); - } - - let actual = client.read_one(output); - let actual = f32::from_bytes(&actual); - - for (i, (&expected_val, &actual_val)) in expected.iter().zip(actual.iter()).enumerate() { - assert!( - (expected_val - actual_val).abs() < 1e-5, - "Vector angle 2D test {} failed: expected {}, got {}", - i, - expected_val, - actual_val - ); - } -} - -#[cube(launch_unchecked)] -fn kernel_rotate_2d( - x: &Array, - y: &Array, - angle: &Array, - x_out: &mut Array, - y_out: &mut Array, -) { - if UNIT_POS < x.len() { - let (new_x, new_y) = rotate_2d::(x[UNIT_POS], y[UNIT_POS], angle[UNIT_POS]); - x_out[UNIT_POS] = new_x; - y_out[UNIT_POS] = new_y; - } -} - -pub fn test_rotate_2d(client: ComputeClient) { - let x_data = vec![1.0, 0.0, 1.0, 1.0]; - let y_data = vec![0.0, 1.0, 1.0, 0.0]; - let angle_data = vec![PI / 2.0, PI / 2.0, PI / 4.0, PI]; - - let expected_x = vec![0.0, -1.0, 0.0, -1.0]; - let expected_y = vec![1.0, 0.0, 1.414213562373095, 0.0]; - - let x = client.create(f32::as_bytes(&x_data)); - let y = client.create(f32::as_bytes(&y_data)); - let angle = client.create(f32::as_bytes(&angle_data)); - let x_out = client.empty(x_data.len() * core::mem::size_of::()); - let y_out = client.empty(y_data.len() * core::mem::size_of::()); - - unsafe { - kernel_rotate_2d::launch_unchecked::( - &client, - CubeCount::Static(1, 1, 1), - CubeDim::new(x_data.len() as u32, 1, 1), - ArrayArg::from_raw_parts::(&x, x_data.len(), 1), - ArrayArg::from_raw_parts::(&y, y_data.len(), 1), - ArrayArg::from_raw_parts::(&angle, angle_data.len(), 1), - ArrayArg::from_raw_parts::(&x_out, x_data.len(), 1), - ArrayArg::from_raw_parts::(&y_out, y_data.len(), 1), - ); - } - - let actual_x = client.read_one(x_out); - let actual_x = f32::from_bytes(&actual_x); - let actual_y = client.read_one(y_out); - let actual_y = f32::from_bytes(&actual_y); - - for i in 0..x_data.len() { - assert!( - (expected_x[i] - actual_x[i]).abs() < 1e-5, - "Rotate 2D X test {} failed: expected {}, got {}", - i, - expected_x[i], - actual_x[i] - ); - - assert!( - (expected_y[i] - actual_y[i]).abs() < 1e-5, - "Rotate 2D Y test 
{} failed: expected {}, got {}", - i, - expected_y[i], - actual_y[i] - ); - } -} - #[cube(launch_unchecked)] fn kernel_hypot(x: &Array, y: &Array, output: &mut Array) { if UNIT_POS < x.len() { @@ -503,130 +120,3 @@ pub fn test_hypot(client: ComputeClient) { ); } } - -#[cube(launch_unchecked)] -fn kernel_sinc(input: &Array, output: &mut Array) { - if UNIT_POS < input.len() { - output[UNIT_POS] = sinc::(input[UNIT_POS]); - } -} - -pub fn test_sinc(client: ComputeClient) { - let input_data = vec![0.0, 1.0, -1.0, 0.5, -0.5, 2.0]; - // Expected values for normalized sinc function: sin(πx)/(πx) - let expected = vec![ - 1.0, // sinc(0) = 1 - 0.0, // sinc(1) ≈ 0 (actually 3.8986e-17, but effectively 0) - 0.0, // sinc(-1) ≈ 0 - 0.6366197723675814, // sinc(0.5) = sin(π/2)/(π/2) = 1/(π/2) ≈ 0.6366 - 0.6366197723675814, // sinc(-0.5) = sinc(0.5) - 0.0, // sinc(2) ≈ 0 - ]; - - let input = client.create(f32::as_bytes(&input_data)); - let output = client.empty(input_data.len() * core::mem::size_of::()); - - unsafe { - kernel_sinc::launch_unchecked::( - &client, - CubeCount::Static(1, 1, 1), - CubeDim::new(input_data.len() as u32, 1, 1), - ArrayArg::from_raw_parts::(&input, input_data.len(), 1), - ArrayArg::from_raw_parts::(&output, input_data.len(), 1), - ); - } - - let actual = client.read_one(output); - let actual = f32::from_bytes(&actual); - - for (i, (&expected_val, &actual_val)) in expected.iter().zip(actual.iter()).enumerate() { - let tolerance = if i == 1 || i == 2 || i == 5 { - 1e-3 - } else { - 1e-5 - }; // More tolerance for near-zero values - assert!( - (expected_val - actual_val).abs() < tolerance, - "Sinc test {} failed: expected {}, got {}", - i, - expected_val, - actual_val - ); - } -} - -#[macro_export] -macro_rules! testgen_trigonometry { - () => { - mod trigonometry { - use super::*; - use $crate::tests::trigonometry::*; - - #[test] - fn test_to_degrees_conversion() { - let client = TestRuntime::client(&Default::default()); - test_to_degrees::(client); - } - - #[test] - fn test_to_radians_conversion() { - let client = TestRuntime::client(&Default::default()); - test_to_radians::(client); - } - - #[test] - fn test_sincos_computation() { - let client = TestRuntime::client(&Default::default()); - test_sincos::(client); - } - - #[test] - fn test_normalize_angle_positive() { - let client = TestRuntime::client(&Default::default()); - test_normalize_angle::(client); - } - - #[test] - fn test_normalize_angle_signed_range() { - let client = TestRuntime::client(&Default::default()); - test_normalize_angle_signed::(client); - } - - #[test] - fn test_lerp_angle_interpolation() { - let client = TestRuntime::client(&Default::default()); - test_lerp_angle::(client); - } - - #[test] - fn test_angle_distance_calculation() { - let client = TestRuntime::client(&Default::default()); - test_angle_distance::(client); - } - - #[test] - fn test_vector_angle_2d_computation() { - let client = TestRuntime::client(&Default::default()); - test_vector_angle_2d::(client); - } - - #[test] - fn test_rotate_2d_transformation() { - let client = TestRuntime::client(&Default::default()); - test_rotate_2d::(client); - } - - #[test] - fn test_hypot_computation() { - let client = TestRuntime::client(&Default::default()); - test_hypot::(client); - } - - #[test] - fn test_sinc_function() { - let client = TestRuntime::client(&Default::default()); - test_sinc::(client); - } - } - }; -} diff --git a/crates/cubecl-std/src/trigonometry.rs b/crates/cubecl-std/src/trigonometry.rs index 7e370dd19..355bc101b 100644 --- 
a/crates/cubecl-std/src/trigonometry.rs +++ b/crates/cubecl-std/src/trigonometry.rs @@ -35,232 +35,6 @@ pub fn to_radians(val: F) -> F { val * F::new(f32::consts::PI / 180.0) } -/// Computes both sine and cosine of an angle simultaneously. -/// -/// This can be more efficient than computing sin and cos separately -/// on some GPU architectures. -/// -/// # Arguments -/// -/// * `val` - The angle in radians -/// -/// # Returns -/// -/// A tuple containing (sine, cosine) of the input angle -/// -/// # Example -/// -/// ```rust,ignore -/// let angle = F::new(std::f32::consts::PI / 4.0); -/// let (sin_val, cos_val) = sincos(angle); -/// ``` -#[cube] -pub fn sincos(val: F) -> (F, F) { - (F::sin(val), F::cos(val)) -} - -/// Normalizes an angle to the range [0, 2π). -/// -/// # Arguments -/// -/// * `angle` - The angle in radians to normalize -/// -/// # Returns -/// -/// The angle normalized to the range [0, 2π) -/// -/// # Example -/// -/// ```rust,ignore -/// let angle = F::new(3.0 * std::f32::consts::PI); -/// let normalized = normalize_angle(angle); -/// assert!((normalized - F::new(std::f32::consts::PI)).abs() < F::new(1e-6)); -/// ``` -#[cube] -pub fn normalize_angle(angle: F) -> F { - let tau = F::new(f32::consts::TAU); - angle - F::floor(angle / tau) * tau -} - -/// Normalizes an angle to the range [-π, π). -/// -/// # Arguments -/// -/// * `angle` - The angle in radians to normalize -/// -/// # Returns -/// -/// The angle normalized to the range [-π, π) -/// -/// # Example -/// -/// ```rust,ignore -/// let angle = F::new(3.0 * std::f32::consts::PI); -/// let normalized = normalize_angle_signed(angle); -/// assert!((normalized - F::new(std::f32::consts::PI)).abs() < F::new(1e-6)); -/// ``` -#[cube] -pub fn normalize_angle_signed(angle: F) -> F { - let pi = F::new(f32::consts::PI); - let tau = F::new(f32::consts::TAU); - let normalized = angle - F::floor(angle / tau) * tau; - if normalized >= pi { - normalized - tau - } else { - normalized - } -} - -/// Linear interpolation between two angles, taking the shortest path. -/// -/// This function correctly handles the wraparound at 2π to ensure -/// the interpolation follows the shortest circular arc. -/// -/// # Arguments -/// -/// * `from` - The starting angle in radians -/// * `to` - The ending angle in radians -/// * `t` - The interpolation factor (0.0 = from, 1.0 = to) -/// -/// # Returns -/// -/// The interpolated angle -/// -/// # Example -/// -/// ```rust,ignore -/// let from = F::new(0.1); -/// let to = F::new(std::f32::consts::TAU - 0.1); -/// let mid = lerp_angle(from, to, F::new(0.5)); -/// assert!(mid.abs() < F::new(1e-6) || (mid - F::new(std::f32::consts::TAU)).abs() < F::new(1e-6)); -/// ``` -#[cube] -pub fn lerp_angle(from: F, to: F, t: F) -> F { - let pi = F::new(f32::consts::PI); - let tau = F::new(f32::consts::TAU); - - let diff = to - from; - let normalized_diff = if diff > pi { - diff - tau - } else if diff < -pi { - diff + tau - } else { - diff - }; - - normalize_angle::(from + normalized_diff * t) -} - -/// Computes the shortest angular distance between two angles. 
-/// -/// # Arguments -/// -/// * `from` - The first angle in radians -/// * `to` - The second angle in radians -/// -/// # Returns -/// -/// The shortest angular distance, positive if `to` is clockwise from `from` -/// -/// # Example -/// -/// ```rust,ignore -/// let angle1 = F::new(0.1); -/// let angle2 = F::new(std::f32::consts::TAU - 0.1); -/// let distance = angle_distance(angle1, angle2); -/// assert!((distance - F::new(-0.2)).abs() < F::new(1e-6)); -/// ``` -#[cube] -pub fn angle_distance(from: F, to: F) -> F { - let pi = F::new(f32::consts::PI); - let tau = F::new(f32::consts::TAU); - - let diff = to - from; - if diff > pi { - diff - tau - } else if diff < -pi { - diff + tau - } else { - diff - } -} - -/// Smoothstep interpolation for angles. -/// -/// Applies smoothstep interpolation (3t² - 2t³) between two angles, -/// taking the shortest circular path. -/// -/// # Arguments -/// -/// * `from` - The starting angle in radians -/// * `to` - The ending angle in radians -/// * `t` - The interpolation factor (0.0 = from, 1.0 = to) -/// -/// # Returns -/// -/// The smoothly interpolated angle -/// -/// # Example -/// -/// ```rust,ignore -/// let from = F::new(0.0); -/// let to = F::new(std::f32::consts::PI); -/// let smooth = smoothstep_angle(from, to, F::new(0.5)); -/// ``` -#[cube] -pub fn smoothstep_angle(from: F, to: F, t: F) -> F { - let smooth_t = t * t * (F::new(3.0) - F::new(2.0) * t); - lerp_angle::(from, to, smooth_t) -} - -/// Computes the angle between two 2D vectors. -/// -/// # Arguments -/// -/// * `x1`, `y1` - Components of the first vector -/// * `x2`, `y2` - Components of the second vector -/// -/// # Returns -/// -/// The angle between the vectors in radians -/// -/// # Example -/// -/// ```rust,ignore -/// let angle = vector_angle_2d(F::new(1.0), F::new(0.0), F::new(0.0), F::new(1.0)); -/// assert!((angle - F::new(std::f32::consts::PI / 2.0)).abs() < F::new(1e-6)); -/// ``` -#[cube] -pub fn vector_angle_2d(x1: F, y1: F, x2: F, y2: F) -> F { - let dot = x1 * x2 + y1 * y2; - let det = x1 * y2 - y1 * x2; - F::atan2(det, dot) -} - -/// Rotates a 2D point around the origin by the given angle. -/// -/// # Arguments -/// -/// * `x`, `y` - The point coordinates -/// * `angle` - The rotation angle in radians -/// -/// # Returns -/// -/// A tuple containing the rotated coordinates (x', y') -/// -/// # Example -/// -/// ```rust,ignore -/// let (x, y) = rotate_2d(F::new(1.0), F::new(0.0), F::new(std::f32::consts::PI / 2.0)); -/// assert!(x.abs() < F::new(1e-6) && (y - F::new(1.0)).abs() < F::new(1e-6)); -/// ``` -#[cube] -pub fn rotate_2d(x: F, y: F, angle: F) -> (F, F) { - let cos_a = F::cos(angle); - let sin_a = F::sin(angle); - (x * cos_a - y * sin_a, x * sin_a + y * cos_a) -} - /// Computes the hypotenuse of a right triangle given the lengths of the other two sides. /// /// This function computes `sqrt(x² + y²)` in a numerically stable way that avoids @@ -285,38 +59,3 @@ pub fn rotate_2d(x: F, y: F, angle: F) -> (F, F) { pub fn hypot(x: F, y: F) -> F { F::sqrt(x * x + y * y) } - -/// Computes the normalized sinc function. -/// -/// The sinc function is defined as: -/// - `sinc(x) = sin(πx) / (πx)` for x ≠ 0 -/// - `sinc(0) = 1` -/// -/// This is the normalized sinc function used in digital signal processing. 
-///
-/// # Arguments
-///
-/// * `x` - The input value
-///
-/// # Returns
-///
-/// The sinc of the input
-///
-/// # Example
-///
-/// ```rust,ignore
-/// let result = sinc(F::new(0.0));
-/// assert!((result - F::new(1.0)).abs() < F::new(1e-6));
-///
-/// let result = sinc(F::new(1.0));
-/// assert!(result.abs() < F::new(1e-6)); // sinc(1) ≈ 0
-/// ```
-#[cube]
-pub fn sinc<F: Float>(x: F) -> F {
-    let pi_x = F::new(f32::consts::PI) * x;
-    if F::abs(x) < F::new(1e-8) {
-        F::new(1.0)
-    } else {
-        F::sin(pi_x) / pi_x
-    }
-}

From 0af20fcf32fa554611ff94407287a1016fadc4c6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tobias=20H=C3%B6lzer?=
Date: Tue, 7 Oct 2025 16:42:25 +0200
Subject: [PATCH 17/23] Fix for refactored launch and reenable std trig tests

---
 .../src/frontend/operation/binary.rs        |  2 --
 crates/cubecl-std/src/tests/trigonometry.rs | 28 +++++++++++++++++++
 2 files changed, 28 insertions(+), 2 deletions(-)

diff --git a/crates/cubecl-core/src/frontend/operation/binary.rs b/crates/cubecl-core/src/frontend/operation/binary.rs
index 9a628dc17..53be63601 100644
--- a/crates/cubecl-core/src/frontend/operation/binary.rs
+++ b/crates/cubecl-core/src/frontend/operation/binary.rs
@@ -255,8 +255,6 @@ impl_binary_func!(
 impl_binary_func!(
     ArcTan2,
     atan2,
-    __expand_atan2,
-    __expand_atan2_method,
     Arithmetic::ArcTan2,
     f16,
     bf16,
diff --git a/crates/cubecl-std/src/tests/trigonometry.rs b/crates/cubecl-std/src/tests/trigonometry.rs
index ab17421a5..97c0e3af1 100644
--- a/crates/cubecl-std/src/tests/trigonometry.rs
+++ b/crates/cubecl-std/src/tests/trigonometry.rs
@@ -120,3 +120,31 @@ pub fn test_hypot<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
         );
     }
 }
+
+#[macro_export]
+macro_rules! testgen_trigonometry {
+    () => {
+        mod trigonometry {
+            use super::*;
+            use $crate::tests::trigonometry::*;
+
+            #[test]
+            fn test_to_degrees_conversion() {
+                let client = TestRuntime::client(&Default::default());
+                test_to_degrees::<TestRuntime>(client);
+            }
+
+            #[test]
+            fn test_to_radians_conversion() {
+                let client = TestRuntime::client(&Default::default());
+                test_to_radians::<TestRuntime>(client);
+            }
+
+            #[test]
+            fn test_hypot_computation() {
+                let client = TestRuntime::client(&Default::default());
+                test_hypot::<TestRuntime>(client);
+            }
+        }
+    };
+}
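The next patch re-enables the MLIR `math`-dialect lowering that the earlier dummy implementations worked around: the visitor emits `math_ods::*` ops again, and `create_math_to_llvm` turns them into LLVM math intrinsics during module compilation. As a reading aid, the relevant part of the pass pipeline from the diff, with annotations added by the editor (same calls as in the patch; the comments are interpretation, not source text):

```rust,ignore
pass_manager.add_pass(pass::conversion::create_vector_to_llvm());
pass_manager.add_pass(pass::conversion::create_arith_to_llvm());
pass_manager.add_pass(pass::conversion::create_func_to_llvm());
// Re-enabled below: lowers math-dialect ops (math.acos, math.atan2, ...)
// produced by the arithmetic visitor into LLVM intrinsics.
pass_manager.add_pass(pass::conversion::create_math_to_llvm());
pass_manager.add_pass(pass::transform::create_inliner());
pass_manager.add_pass(pass::conversion::create_reconcile_unrealized_casts());
pass_manager.add_pass(pass::transform::create_sccp());
```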
From 4eff5a820f2bc3cf10f9c178b2aa922205686dd6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tobias=20H=C3%B6lzer?=
Date: Tue, 7 Oct 2025 16:53:20 +0200
Subject: [PATCH 18/23] Remove dummy implementations for ODS math arithmetic

---
 crates/cubecl-cpu/src/compiler/module.rs      |  2 +-
 .../compiler/visitor/operation/arithmetic.rs  | 68 +++----------------
 2 files changed, 11 insertions(+), 59 deletions(-)

diff --git a/crates/cubecl-cpu/src/compiler/module.rs b/crates/cubecl-cpu/src/compiler/module.rs
index af962a5b1..c07b9a3ad 100644
--- a/crates/cubecl-cpu/src/compiler/module.rs
+++ b/crates/cubecl-cpu/src/compiler/module.rs
@@ -73,7 +73,7 @@ impl<'a> Module<'a> {
         pass_manager.add_pass(pass::conversion::create_vector_to_llvm());
         pass_manager.add_pass(pass::conversion::create_arith_to_llvm());
         pass_manager.add_pass(pass::conversion::create_func_to_llvm());
-        // pass_manager.add_pass(pass::conversion::create_math_to_llvm());
+        pass_manager.add_pass(pass::conversion::create_math_to_llvm());
         pass_manager.add_pass(pass::transform::create_inliner());
         pass_manager.add_pass(pass::conversion::create_reconcile_unrealized_casts());
         pass_manager.add_pass(pass::transform::create_sccp());
diff --git a/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs b/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs
index a488192d3..5835c7809 100644
--- a/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs
+++ b/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs
@@ -3,7 +3,7 @@ use tracel_llvm::mlir_rs::{
     dialect::{
         arith::{self},
         llvm,
-        ods::{llvm as llvm_ods, vector},
+        ods::{llvm as llvm_ods, math as math_ods, vector},
     },
     ir::Attribute,
 };
@@ -29,110 +29,68 @@ impl<'a> Visitor<'a> {
                 self.insert_variable(out, result);
             }
             Arithmetic::ArcCos(acos) => {
-                // Arc operations are only available through the ods::math module,
-                // which can not be properly loaded at the moment.
-                // Using dummy for now to satisfy compilation of other tests
                 let value = self.get_variable(acos.input);
-                let abs = self.get_absolute_val(acos.input.ty, value);
-                self.insert_variable(out, abs);
-                /*let value = self.get_variable(acos.input);
                 let result = self.append_operation_with_result(math_ods::acos(
                     self.context,
                     value,
                     self.location,
                 ));
-                self.insert_variable(out, result);*/
+                self.insert_variable(out, result);
             }
             Arithmetic::ArcSin(asin) => {
-                // Arc operations are only available through the ods::math module,
-                // which can not be properly loaded at the moment.
-                // Using dummy for now to satisfy compilation of other tests
                 let value = self.get_variable(asin.input);
-                let abs = self.get_absolute_val(asin.input.ty, value);
-                self.insert_variable(out, abs);
-                /*let value = self.get_variable(asin.input);
                 let result = self.append_operation_with_result(math_ods::asin(
                     self.context,
                     value,
                     self.location,
                 ));
-                self.insert_variable(out, result);*/
+                self.insert_variable(out, result);
            }
            Arithmetic::ArcTan(atan) => {
-                // Arc operations are only available through the ods::math module,
-                // which can not be properly loaded at the moment.
-                // Using dummy for now to satisfy compilation of other tests
                 let value = self.get_variable(atan.input);
-                let abs = self.get_absolute_val(atan.input.ty, value);
-                self.insert_variable(out, abs);
-                /*let value = self.get_variable(atan.input);
                 let result = self.append_operation_with_result(math_ods::atan(
                     self.context,
                     value,
                     self.location,
                 ));
-                self.insert_variable(out, result);*/
+                self.insert_variable(out, result);
            }
            Arithmetic::ArcSinh(asinh) => {
-                // Arc operations are only available through the ods::math module,
-                // which can not be properly loaded at the moment.
-                // Using dummy for now to satisfy compilation of other tests
                 let value = self.get_variable(asinh.input);
-                let abs = self.get_absolute_val(asinh.input.ty, value);
-                self.insert_variable(out, abs);
-                /*let value = self.get_variable(asinh.input);
                 let result = self.append_operation_with_result(math_ods::asinh(
                     self.context,
                     value,
                     self.location,
                 ));
-                self.insert_variable(out, result);*/
+                self.insert_variable(out, result);
            }
            Arithmetic::ArcCosh(acosh) => {
-                // Arc operations are only available through the ods::math module,
-                // which can not be properly loaded at the moment.
-                // Using dummy for now to satisfy compilation of other tests
                 let value = self.get_variable(acosh.input);
-                let abs = self.get_absolute_val(acosh.input.ty, value);
-                self.insert_variable(out, abs);
-                /*let value = self.get_variable(acosh.input);
                 let result = self.append_operation_with_result(math_ods::acosh(
                     self.context,
                     value,
                     self.location,
                 ));
-                self.insert_variable(out, result);*/
+                self.insert_variable(out, result);
            }
            Arithmetic::ArcTanh(atanh) => {
-                // Arc operations are only available through the ods::math module,
-                // which can not be properly loaded at the moment.
-                // Using dummy for now to satisfy compilation of other tests
                 let value = self.get_variable(atanh.input);
-                let abs = self.get_absolute_val(atanh.input.ty, value);
-                self.insert_variable(out, abs);
-                /*let value = self.get_variable(atanh.input);
                 let result = self.append_operation_with_result(math_ods::atanh(
                     self.context,
                     value,
                     self.location,
                 ));
-                self.insert_variable(out, result);*/
+                self.insert_variable(out, result);
            }
            Arithmetic::ArcTan2(atan2) => {
-                // Arc operations are only available through the ods::math module,
-                // which can not be properly loaded at the moment.
-                // Using dummy for now to satisfy compilation of other tests
-                let value = self.get_variable(atan2.lhs);
-                let abs = self.get_absolute_val(atan2.lhs.ty, value);
-                self.insert_variable(out, abs);
-                /*let (lhs, rhs) = self.get_binary_op_variable(atan2.lhs, atan2.rhs);
+                let (lhs, rhs) = self.get_binary_op_variable(atan2.lhs, atan2.rhs);
                 let result = self.append_operation_with_result(math_ods::atan_2(
                     self.context,
                     lhs,
                     rhs,
                     self.location,
                 ));
-                self.insert_variable(out, result);*/
+                self.insert_variable(out, result);
            }
            Arithmetic::SaturatingAdd(_) => {
                 unreachable!("Should be removed by preprocessor")
@@ -559,19 +517,13 @@ impl<'a> Visitor<'a> {
                 self.insert_variable(out, result);
            }
            Arithmetic::Tan(tan) => {
-                // Tan operations are only available through the ods::math module,
-                // which can not be properly loaded at the moment.
-                // Using dummy for now to satisfy compilation of other tests
                 let value = self.get_variable(tan.input);
-                let abs = self.get_absolute_val(tan.input.ty, value);
-                self.insert_variable(out, abs);
-                /*let value = self.get_variable(tan.input);
                 let result = self.append_operation_with_result(math_ods::tan(
                     self.context,
                     value,
                     self.location,
                 ));
-                self.insert_variable(out, result);*/
+                self.insert_variable(out, result);
            }
            Arithmetic::SaturatingSub(_) => {
                 unreachable!("Should be removed by preprocessor")

From 8b9a348d4fe074b535c462a05e816750e87c9b43 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tobias=20H%C3%B6lzer?=
Date: Thu, 6 Nov 2025 13:59:12 +0100
Subject: [PATCH 19/23] Add rsqrt and add math_to_libm pass

---
 .../cubecl-core/src/frontend/element/float.rs |  1 +
 .../src/frontend/element/float/typemap.rs     |  1 +
 .../src/frontend/operation/unary.rs           | 12 ++++
 crates/cubecl-core/src/runtime_tests/unary.rs | 67 +++++++++++++++++++
 crates/cubecl-cpp/src/shared/base.rs          |  3 +
 crates/cubecl-cpp/src/shared/instruction.rs   |  2 +
 crates/cubecl-cpp/src/shared/unary.rs         |  1 +
 crates/cubecl-cpu/src/compiler/module.rs      |  1 +
 .../compiler/visitor/operation/arithmetic.rs  |  9 +++
 crates/cubecl-ir/src/arithmetic.rs            |  2 +
 crates/cubecl-ir/src/processing.rs            |  3 +
 crates/cubecl-opt/src/instructions.rs         |  1 +
 crates/cubecl-opt/src/passes/constant_prop.rs | 11 +++
 crates/cubecl-spirv/src/arithmetic.rs         | 11 +++
 .../cubecl-wgpu/src/compiler/wgsl/compiler.rs |  4 ++
 .../src/compiler/wgsl/instructions.rs         |  8 +++
 16 files changed, 137 insertions(+)

diff --git a/crates/cubecl-core/src/frontend/element/float.rs b/crates/cubecl-core/src/frontend/element/float.rs
index 43d80984b..69b558619 100644
--- a/crates/cubecl-core/src/frontend/element/float.rs
+++ b/crates/cubecl-core/src/frontend/element/float.rs
@@ -41,6 +41,7 @@ pub trait Float:
     + Powf
     + Powi
     + Sqrt
+    + Rsqrt
     + Round
     + Floor
     + Ceil
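`Rsqrt` joins the `Float` trait alongside `Sqrt` because GPU backends typically expose a dedicated reciprocal-square-root instruction that is cheaper than a divide plus `sqrt`. Semantically it is just `1/sqrt(x)`, which is what the runtime test vectors below encode. A host-side reference (editor's sketch, not backend code):

```rust
// Reference semantics for the new op.
fn rsqrt_ref(x: f32) -> f32 {
    1.0 / x.sqrt()
}

fn main() {
    assert_eq!(rsqrt_ref(4.0), 0.5); // matches the test_rsqrt vectors below
}
```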
b/crates/cubecl-core/src/frontend/element/float/typemap.rs @@ -259,6 +259,7 @@ impl ArcTan2 for ElemExpand {} impl Powf for ElemExpand {} impl Powi for ElemExpand {} impl Sqrt for ElemExpand {} +impl Rsqrt for ElemExpand {} impl Round for ElemExpand {} impl Floor for ElemExpand {} impl Ceil for ElemExpand {} diff --git a/crates/cubecl-core/src/frontend/operation/unary.rs b/crates/cubecl-core/src/frontend/operation/unary.rs index 221226643..82e482249 100644 --- a/crates/cubecl-core/src/frontend/operation/unary.rs +++ b/crates/cubecl-core/src/frontend/operation/unary.rs @@ -331,6 +331,18 @@ impl_unary_func!( f32, f64 ); +impl_unary_func!( + Rsqrt, + rsqrt, + __expand_rsqrt, + Arithmetic::Rsqrt, + f16, + bf16, + flex32, + tf32, + f32, + f64 +); impl_unary_func!( Round, round, diff --git a/crates/cubecl-core/src/runtime_tests/unary.rs b/crates/cubecl-core/src/runtime_tests/unary.rs index 1c9d75df0..b22e26f70 100644 --- a/crates/cubecl-core/src/runtime_tests/unary.rs +++ b/crates/cubecl-core/src/runtime_tests/unary.rs @@ -401,6 +401,27 @@ test_unary_impl!(test_cosh, F, F::cosh, [ } ]); +test_unary_impl!(test_tanh, F, F::tanh, [ + { + input_vectorization: 1, + out_vectorization: 1, + input: as_type![F: 0., 1., -1., 2., -2.], + expected: as_type![F: 0., 0.7615941559, -0.7615941559, 0.9640275801, -0.9640275801] + }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![F: 0., 1., -1., 2.], + expected: as_type![F: 0., 0.7615941559, -0.7615941559, 0.9640275801] + }, + { + input_vectorization: 4, + out_vectorization: 4, + input: as_type![F: 0., 1., -1., 2.], + expected: as_type![F: 0., 0.7615941559, -0.7615941559, 0.9640275801] + } +]); + test_unary_impl!(test_asinh, F, F::asinh, [ { input_vectorization: 1, @@ -464,6 +485,48 @@ test_unary_impl!(test_atanh, F, F::atanh, [ } ]); +test_unary_impl!(test_sqrt, F, F::sqrt, [ + { + input_vectorization: 1, + out_vectorization: 1, + input: as_type![F: 0., 1., 4., 9., 16., 25.], + expected: as_type![F: 0., 1., 2., 3., 4., 5.] + }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![F: 0., 1., 4., 9.], + expected: as_type![F: 0., 1., 2., 3.] + }, + { + input_vectorization: 4, + out_vectorization: 4, + input: as_type![F: 0., 1., 4., 9.], + expected: as_type![F: 0., 1., 2., 3.] + } +]); + +test_unary_impl!(test_rsqrt, F, F::rsqrt, [ + { + input_vectorization: 1, + out_vectorization: 1, + input: as_type![F: 1., 4., 9., 16., 25.], + expected: as_type![F: 1., 0.5, 0.33333333333, 0.25, 0.2] + }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![F: 1., 4., 9., 16.], + expected: as_type![F: 1., 0.5, 0.33333333333, 0.25] + }, + { + input_vectorization: 4, + out_vectorization: 4, + input: as_type![F: 1., 4., 9., 16.], + expected: as_type![F: 1., 0.5, 0.33333333333, 0.25] + } +]); + test_unary_impl!(test_degrees, F, F::to_degrees, [ { input_vectorization: 1, @@ -770,8 +833,10 @@ macro_rules! testgen_unary { add_test!(test_sin); add_test!(test_cos); + add_test!(test_tan); add_test!(test_sinh); add_test!(test_cosh); + add_test!(test_tanh); add_test!(test_asin); add_test!(test_acos); add_test!(test_atan); @@ -782,6 +847,8 @@ macro_rules! 
testgen_unary { add_test!(test_radians); add_test!(test_normalize); add_test!(test_magnitude); + add_test!(test_sqrt); + add_test!(test_rsqrt); add_test!(test_abs); add_test!(test_is_nan); add_test!(test_is_inf); diff --git a/crates/cubecl-cpp/src/shared/base.rs b/crates/cubecl-cpp/src/shared/base.rs index 7b531c9c8..b46c68e36 100644 --- a/crates/cubecl-cpp/src/shared/base.rs +++ b/crates/cubecl-cpp/src/shared/base.rs @@ -1025,6 +1025,9 @@ impl CppCompiler { gpu::Arithmetic::Sqrt(op) => { instructions.push(Instruction::Sqrt(self.compile_unary(op, out))) } + gpu::Arithmetic::Rsqrt(op) => { + instructions.push(Instruction::Rsqrt(self.compile_unary(op, out))) + } gpu::Arithmetic::Erf(op) => { let instruction = Instruction::Erf(self.compile_unary(op, out)); D::register_instruction_extension(&mut self.extensions, &instruction); diff --git a/crates/cubecl-cpp/src/shared/instruction.rs b/crates/cubecl-cpp/src/shared/instruction.rs index 78bf5be77..d7a12d20c 100644 --- a/crates/cubecl-cpp/src/shared/instruction.rs +++ b/crates/cubecl-cpp/src/shared/instruction.rs @@ -180,6 +180,7 @@ pub enum Instruction { Powf(BinaryInstruction), Powi(BinaryInstruction), Sqrt(UnaryInstruction), + Rsqrt(UnaryInstruction), Min(BinaryInstruction), Max(BinaryInstruction), Not(UnaryInstruction), @@ -543,6 +544,7 @@ for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{ Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out), Instruction::Powi(it) => Powi::format(f, &it.lhs, &it.rhs, &it.out), Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out), + Instruction::Rsqrt(it) => Rsqrt::format(f, &it.input, &it.out), Instruction::Max(it) => Max::format(f, &it.lhs, &it.rhs, &it.out), Instruction::Min(it) => Min::format(f, &it.lhs, &it.rhs, &it.out), Instruction::Not(it) => Not::format(f, &it.input, &it.out), diff --git a/crates/cubecl-cpp/src/shared/unary.rs b/crates/cubecl-cpp/src/shared/unary.rs index e80d79ce0..6a55feeb9 100644 --- a/crates/cubecl-cpp/src/shared/unary.rs +++ b/crates/cubecl-cpp/src/shared/unary.rs @@ -161,6 +161,7 @@ function!(ArcSinh, "asinh", false); function!(ArcCosh, "acosh", false); function!(ArcTanh, "atanh", false); function!(Sqrt, "sqrt"); +function!(Rsqrt, "rsqrt"); function!(Exp, "exp"); function!(Ceil, "ceil"); function!(Floor, "floor"); diff --git a/crates/cubecl-cpu/src/compiler/module.rs b/crates/cubecl-cpu/src/compiler/module.rs index c07b9a3ad..e76f60c87 100644 --- a/crates/cubecl-cpu/src/compiler/module.rs +++ b/crates/cubecl-cpu/src/compiler/module.rs @@ -74,6 +74,7 @@ impl<'a> Module<'a> { pass_manager.add_pass(pass::conversion::create_arith_to_llvm()); pass_manager.add_pass(pass::conversion::create_func_to_llvm()); pass_manager.add_pass(pass::conversion::create_math_to_llvm()); + pass_manager.add_pass(pass::conversion::create_math_to_libm()); pass_manager.add_pass(pass::transform::create_inliner()); pass_manager.add_pass(pass::conversion::create_reconcile_unrealized_casts()); pass_manager.add_pass(pass::transform::create_sccp()); diff --git a/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs b/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs index 5835c7809..423fac127 100644 --- a/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs +++ b/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs @@ -479,6 +479,15 @@ impl<'a> Visitor<'a> { )); self.insert_variable(out, output); } + Arithmetic::Rsqrt(rsqrt) => { + let input = self.get_variable(rsqrt.input); + let output = self.append_operation_with_result(math_ods::rsqrt( + 
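// math_ods::rsqrt emits the MLIR `math.rsqrt` operation. The
// create_math_to_libm() conversion registered in module.rs above is the
// fallback path: math-dialect ops that create_math_to_llvm() leaves
// unlowered (no matching LLVM intrinsic) become plain libm calls instead
// of failing to legalize. (Intent inferred from the pass ordering above.)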
self.context, + input, + self.location, + )); + self.insert_variable(out, output); + } Arithmetic::Sin(sin) => { let input = self.get_variable(sin.input); let output = self.append_operation_with_result(llvm_ods::intr_sin( diff --git a/crates/cubecl-ir/src/arithmetic.rs b/crates/cubecl-ir/src/arithmetic.rs index 9002f1d70..088f624dd 100644 --- a/crates/cubecl-ir/src/arithmetic.rs +++ b/crates/cubecl-ir/src/arithmetic.rs @@ -41,6 +41,7 @@ pub enum Arithmetic { Powf(BinaryOperator), Powi(BinaryOperator), Sqrt(UnaryOperator), + Rsqrt(UnaryOperator), Round(UnaryOperator), Floor(UnaryOperator), Ceil(UnaryOperator), @@ -94,6 +95,7 @@ impl Display for Arithmetic { Arithmetic::Powf(op) => write!(f, "{}.pow({})", op.lhs, op.rhs), Arithmetic::Powi(op) => write!(f, "{}.powi({})", op.lhs, op.rhs), Arithmetic::Sqrt(op) => write!(f, "{}.sqrt()", op.input), + Arithmetic::Rsqrt(op) => write!(f, "{}.rsqrt()", op.input), Arithmetic::Round(op) => write!(f, "{}.round()", op.input), Arithmetic::Floor(op) => write!(f, "{}.floor()", op.input), Arithmetic::Ceil(op) => write!(f, "{}.ceil()", op.input), diff --git a/crates/cubecl-ir/src/processing.rs b/crates/cubecl-ir/src/processing.rs index fdb3c2681..fa154c2a0 100644 --- a/crates/cubecl-ir/src/processing.rs +++ b/crates/cubecl-ir/src/processing.rs @@ -159,6 +159,9 @@ impl ScopeProcessing { Arithmetic::Sqrt(op) => { sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); } + Arithmetic::Rsqrt(op) => { + sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); + } Arithmetic::Round(op) => { sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); } diff --git a/crates/cubecl-opt/src/instructions.rs b/crates/cubecl-opt/src/instructions.rs index 0ed147ea0..2e6d8f663 100644 --- a/crates/cubecl-opt/src/instructions.rs +++ b/crates/cubecl-opt/src/instructions.rs @@ -106,6 +106,7 @@ impl Optimizer { | Arithmetic::Degrees(unary_operator) | Arithmetic::Radians(unary_operator) | Arithmetic::Sqrt(unary_operator) + | Arithmetic::Rsqrt(unary_operator) | Arithmetic::Round(unary_operator) | Arithmetic::Floor(unary_operator) | Arithmetic::Ceil(unary_operator) diff --git a/crates/cubecl-opt/src/passes/constant_prop.rs b/crates/cubecl-opt/src/passes/constant_prop.rs index d22cfe29f..af135af35 100644 --- a/crates/cubecl-opt/src/passes/constant_prop.rs +++ b/crates/cubecl-opt/src/passes/constant_prop.rs @@ -443,6 +443,17 @@ fn try_const_eval_arithmetic(op: &mut Arithmetic) -> Option } } Arithmetic::Sqrt(op) => const_eval_float!(op.input; num::Float::sqrt), + Arithmetic::Rsqrt(op) => { + use ConstantScalarValue::*; + if let Some(input) = op.input.as_const() { + match input { + Float(input, kind) => Some(ConstantScalarValue::Float(1. 
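// Rsqrt cannot reuse the const_eval_float! macro because num::Float
// exposes sqrt but no rsqrt, so the fold is spelled out as 1 / sqrt(x).
// The unreachable! arm is sound: Rsqrt is only generated for float
// element types.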
/ input.sqrt(), kind)), + _ => unreachable!(), + } + } else { + None + } + } Arithmetic::Round(op) => const_eval_float!(op.input; num::Float::round), Arithmetic::Floor(op) => const_eval_float!(op.input; num::Float::floor), Arithmetic::Ceil(op) => const_eval_float!(op.input; num::Float::ceil), diff --git a/crates/cubecl-spirv/src/arithmetic.rs b/crates/cubecl-spirv/src/arithmetic.rs index f96de0c07..42d18cdee 100644 --- a/crates/cubecl-spirv/src/arithmetic.rs +++ b/crates/cubecl-spirv/src/arithmetic.rs @@ -456,6 +456,17 @@ impl SpirvCompiler { } }) } + Arithmetic::Rsqrt(op) => { + self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + let sqrt = b.id(); + T::sqrt(b, ty, input, sqrt); + let one = out_ty.const_u32(b, 1); + b.f_div(ty, Some(out), one, sqrt).unwrap(); + if matches!(out_ty.elem(), Elem::Relaxed) { + b.decorate(out, Decoration::RelaxedPrecision, []); + } + }) + } Arithmetic::Round(op) => { self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { T::round(b, ty, input, out); diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs index f3eebe272..492fc5e85 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs @@ -824,6 +824,10 @@ impl WgslCompiler { input: self.compile_variable(op.input), out: self.compile_variable(out), }), + cube::Arithmetic::Rsqrt(op) => instructions.push(wgsl::Instruction::Rsqrt { + input: self.compile_variable(op.input), + out: self.compile_variable(out), + }), cube::Arithmetic::Round(op) => instructions.push(wgsl::Instruction::Round { input: self.compile_variable(op.input), out: self.compile_variable(out), diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs index 0f6837901..350e74702 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs @@ -179,6 +179,10 @@ pub enum Instruction { input: Variable, out: Variable, }, + Rsqrt { + input: Variable, + out: Variable, + }, Recip { input: Variable, out: Variable, @@ -648,6 +652,10 @@ impl Display for Instruction { let out = out.fmt_left(); writeln!(f, "{out} = sqrt({input});") } + Instruction::Rsqrt { input, out } => { + let out = out.fmt_left(); + writeln!(f, "{out} = rsqrt({input});") + } Instruction::Log1p { input, out } => { let out = out.fmt_left(); writeln!(f, "{out} = log({input} + 1.0);") From 0d51481edc76c8a259be5d0695f275d46211e15c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Thu, 6 Nov 2025 14:02:26 +0100 Subject: [PATCH 20/23] Fix formatting of string to satisfy lints --- crates/cubecl-cuda/src/compute/command.rs | 8 ++++---- crates/cubecl-cuda/src/compute/storage/cpu.rs | 3 +-- crates/cubecl-cuda/src/compute/storage/gpu.rs | 5 +---- crates/cubecl-hip/src/compute/command.rs | 2 +- crates/cubecl-hip/src/compute/storage/cpu.rs | 3 +-- crates/cubecl-hip/src/compute/storage/gpu.rs | 2 +- crates/cubecl-runtime/src/stream/scheduler.rs | 3 +-- 7 files changed, 10 insertions(+), 16 deletions(-) diff --git a/crates/cubecl-cuda/src/compute/command.rs b/crates/cubecl-cuda/src/compute/command.rs index f8f31c8c7..920f7ff98 100644 --- a/crates/cubecl-cuda/src/compute/command.rs +++ b/crates/cubecl-cuda/src/compute/command.rs @@ -429,12 +429,12 @@ pub(crate) unsafe fn write_to_gpu( unsafe { cuMemcpy2DAsync_v2(&cpy, stream) .result() - .map_err(|e| IoError::Unknown(format!("CUDA memcpy failed: {}", 
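// clippy's uninlined_format_args lint motivates every hunk in this
// patch: format!("...: {}", e) becomes format!("...: {e}"). Inline
// capture only works for plain identifiers in scope; field accesses and
// method calls still need the positional {} form.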
e)))?; + .map_err(|e| IoError::Unknown(format!("CUDA memcpy failed: {e}")))?; } } else { unsafe { cudarc::driver::result::memcpy_htod_async(dst_ptr, data, stream) - .map_err(|e| IoError::Unknown(format!("CUDA 2D memcpy failed: {}", e)))?; + .map_err(|e| IoError::Unknown(format!("CUDA 2D memcpy failed: {e}")))?; } }; @@ -454,7 +454,7 @@ pub(crate) unsafe fn write_to_cpu( if rank <= 1 { unsafe { cudarc::driver::result::memcpy_dtoh_async(bytes.deref_mut(), resource_ptr, stream) - .map_err(|e| IoError::Unknown(format!("CUDA memcpy failed: {}", e)))?; + .map_err(|e| IoError::Unknown(format!("CUDA memcpy failed: {e}")))?; } return Ok(()); } @@ -480,7 +480,7 @@ pub(crate) unsafe fn write_to_cpu( unsafe { cuMemcpy2DAsync_v2(&cpy, stream) .result() - .map_err(|e| IoError::Unknown(format!("CUDA 2D memcpy failed: {}", e)))?; + .map_err(|e| IoError::Unknown(format!("CUDA 2D memcpy failed: {e}")))?; } Ok(()) diff --git a/crates/cubecl-cuda/src/compute/storage/cpu.rs b/crates/cubecl-cuda/src/compute/storage/cpu.rs index aec974b15..7d834fc33 100644 --- a/crates/cubecl-cuda/src/compute/storage/cpu.rs +++ b/crates/cubecl-cuda/src/compute/storage/cpu.rs @@ -83,8 +83,7 @@ impl ComputeStorage for PinnedMemoryStorage { if result != cudarc::driver::sys::CUresult::CUDA_SUCCESS { return Err(IoError::Unknown(format!( - "cuMemAllocHost_v2 failed with error code: {:?}", - result + "cuMemAllocHost_v2 failed with error code: {result:?}" ))); } diff --git a/crates/cubecl-cuda/src/compute/storage/gpu.rs b/crates/cubecl-cuda/src/compute/storage/gpu.rs index 5378b2a3c..94ddae8f2 100644 --- a/crates/cubecl-cuda/src/compute/storage/gpu.rs +++ b/crates/cubecl-cuda/src/compute/storage/gpu.rs @@ -147,10 +147,7 @@ impl ComputeStorage for GpuStorage { return Err(IoError::BufferTooBig(size as usize)); } Err(other) => { - return Err(IoError::Unknown(format!( - "CUDA allocation error: {}", - other - ))); + return Err(IoError::Unknown(format!("CUDA allocation error: {other}"))); } }; diff --git a/crates/cubecl-hip/src/compute/command.rs b/crates/cubecl-hip/src/compute/command.rs index ac9df1b20..456039503 100644 --- a/crates/cubecl-hip/src/compute/command.rs +++ b/crates/cubecl-hip/src/compute/command.rs @@ -291,7 +291,7 @@ impl<'a> Command<'a> { ); if status != HIP_SUCCESS { - return Err(IoError::Unknown(format!("HIP memcpy failed: {}", status))); + return Err(IoError::Unknown(format!("HIP memcpy failed: {status}"))); } } return Ok(()); diff --git a/crates/cubecl-hip/src/compute/storage/cpu.rs b/crates/cubecl-hip/src/compute/storage/cpu.rs index dbb477ef2..04a93ca41 100644 --- a/crates/cubecl-hip/src/compute/storage/cpu.rs +++ b/crates/cubecl-hip/src/compute/storage/cpu.rs @@ -89,8 +89,7 @@ impl ComputeStorage for PinnedMemoryStorage { if result != HIP_SUCCESS { return Err(IoError::Unknown(format!( - "cuMemAllocHost_v2 failed with error code: {:?}", - result + "cuMemAllocHost_v2 failed with error code: {result:?}" ))); } diff --git a/crates/cubecl-hip/src/compute/storage/gpu.rs b/crates/cubecl-hip/src/compute/storage/gpu.rs index 8ff39c243..eeae0f30e 100644 --- a/crates/cubecl-hip/src/compute/storage/gpu.rs +++ b/crates/cubecl-hip/src/compute/storage/gpu.rs @@ -126,7 +126,7 @@ impl ComputeStorage for GpuStorage { match status { HIP_SUCCESS => {} other => { - return Err(IoError::Unknown(format!("HIP allocation error: {}", other))); + return Err(IoError::Unknown(format!("HIP allocation error: {other}"))); } } self.memory.insert(id, dptr); diff --git a/crates/cubecl-runtime/src/stream/scheduler.rs 
b/crates/cubecl-runtime/src/stream/scheduler.rs index e3073807e..a75ff7fdd 100644 --- a/crates/cubecl-runtime/src/stream/scheduler.rs +++ b/crates/cubecl-runtime/src/stream/scheduler.rs @@ -165,8 +165,7 @@ impl SchedulerMultiStream { |level| !matches!(level, StreamingLogLevel::Disabled), || { format!( - "Flushing streams {:?} before registering more tasks on {stream_id}", - to_flush + "Flushing streams {to_flush:?} before registering more tasks on {stream_id}" ) }, ); From 6ad1e5fe52b0b6652d40ddbe32bc5fa7afbb5fd1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Thu, 6 Nov 2025 14:46:05 +0100 Subject: [PATCH 21/23] Merge branch 'main' into feature/arc-trigonomic-functions --- .github/workflows/ci.yml | 18 +- .github/workflows/publish-template.yml | 24 - .github/workflows/publish.yml | 240 ++- Cargo.toml | 18 +- crates/cubecl-attention/Cargo.toml | 13 +- crates/cubecl-attention/src/base.rs | 34 +- .../cubecl-attention/src/components/args.rs | 872 ++++++--- .../src/components/batch/base.rs | 7 +- .../src/components/batch/entry_point.rs | 36 +- .../src/components/batch/hypercube/base.rs | 2 - .../src/components/batch/mod.rs | 2 +- .../batch/{dummy => simple}/attention.rs | 12 +- .../batch/{dummy => simple}/config.rs | 12 +- .../src/components/batch/simple/mod.rs | 6 + .../batch/{dummy => simple}/setup.rs | 19 +- .../accelerated/attention.rs} | 168 +- .../components/fragment/accelerated/config.rs | 66 + .../dummy => fragment/accelerated}/mod.rs | 2 +- .../accelerated/setup.rs | 19 +- .../attention_matmul => fragment}/base.rs | 83 +- .../fragment/dummy_register/attention.rs | 495 +++++ .../fragment/dummy_register/config.rs | 116 ++ .../dummy_register/mod.rs | 4 +- .../dummy_register/setup.rs | 22 +- .../src/components/fragment/fragments.rs | 74 + .../attention_matmul => fragment}/mod.rs | 3 + .../fragment/unit_register/attention.rs | 345 ++++ .../fragment/unit_register/config.rs | 73 + .../unit_register}/mod.rs | 4 +- .../fragment/unit_register/setup.rs | 42 + .../src/components/global/base.rs | 26 +- .../src/components/global/dummy/mod.rs | 12 - .../src/components/global/dummy/read.rs | 226 --- .../src/components/global/dummy/writer.rs | 87 - .../src/components/global/mod.rs | 2 +- .../global/{dummy => simple}/attention.rs | 106 +- .../global/{dummy => simple}/config.rs | 18 +- .../src/components/global/simple/mod.rs | 10 + .../components/global/simple/reader/base.rs | 15 + .../components/global/simple/reader/key.rs | 109 ++ .../components/global/simple/reader/mask.rs | 147 ++ .../components/global/simple/reader/mod.rs | 11 + .../components/global/simple/reader/query.rs | 48 + .../components/global/simple/reader/value.rs | 103 + .../global/{dummy => simple}/setup.rs | 14 +- .../components/global/simple/writer/mod.rs | 26 + .../components/global/simple/writer/plane.rs | 91 + .../components/global/simple/writer/unit.rs | 83 + .../src/components/line_size.rs | 8 +- .../cubecl-attention/src/components/mask.rs | 94 - crates/cubecl-attention/src/components/mod.rs | 3 +- .../src/components/problem.rs | 4 +- .../src/components/selection.rs | 1 + .../cubecl-attention/src/components/spec.rs | 2 +- .../src/components/stage/base.rs | 106 +- .../src/components/stage/dummy/attention.rs | 252 --- .../components/stage/kv_reuse_attention.rs | 281 +++ .../src/components/stage/mod.rs | 7 +- .../src/components/stage/partitioner.rs | 16 + .../src/components/stage/plane/attention.rs | 35 + .../src/components/stage/plane/config.rs | 87 + .../components/stage/{dummy => plane}/mod.rs | 2 - 
.../stage/{dummy => plane}/setup.rs | 42 +- .../stage/{dummy => }/tile_partitions.rs | 291 +-- .../src/components/stage/unit/attention.rs | 35 + .../stage/{dummy => unit}/config.rs | 51 +- .../src/components/stage/unit/mod.rs | 7 + .../src/components/stage/unit/setup.rs | 115 ++ .../src/components/tile/base.rs | 214 ++- .../src/components/tile/dummy/attention.rs | 140 -- .../attention_matmul/accelerated/config.rs | 182 -- .../attention_matmul/dummy_register/config.rs | 171 -- .../attention_matmul/dummy_register/matmul.rs | 215 --- .../tile/dummy/fragment/accumulator.rs | 112 -- .../tile/dummy/fragment/key_value.rs | 100 - .../src/components/tile/dummy/fragment/mod.rs | 9 - .../components/tile/dummy/fragment/query.rs | 23 - .../components/tile/dummy/fragment/softmax.rs | 184 -- .../src/components/tile/dummy/mod.rs | 9 - .../src/components/tile/dummy/setup.rs | 36 - .../src/components/tile/mod.rs | 6 +- .../src/components/tile/row/mod.rs | 7 + .../src/components/tile/row/reduce/base.rs | 61 + .../tile/row/reduce/broadcast_reducer.rs | 91 + .../src/components/tile/row/reduce/mod.rs | 11 + .../tile/row/reduce/naive_reducer.rs | 60 + .../components/tile/row/reduce/reduce_op.rs | 58 + .../tile/row/reduce/unit_reducer.rs | 23 + .../src/components/tile/row/rowwise.rs | 202 ++ .../src/components/tile/row/state.rs | 38 + .../src/components/tile/rowwise.rs | 116 -- .../src/components/tile/tiles/accumulator.rs | 39 +- .../src/components/tile/tiles/key_value.rs | 108 ++ .../src/components/tile/tiles/mask.rs | 159 ++ .../src/components/tile/tiles/mod.rs | 6 + .../src/components/tile/tiles/query.rs | 27 + .../src/components/tile/tiles/softmax.rs | 71 +- .../cubecl-attention/src/kernels/algorithm.rs | 7 +- crates/cubecl-attention/src/kernels/dummy.rs | 48 +- crates/cubecl-attention/src/kernels/mod.rs | 1 + crates/cubecl-attention/src/kernels/unit.rs | 35 + crates/cubecl-attention/src/lib.rs | 7 +- .../src/tests/attention_test_launcher.rs | 42 +- .../cubecl-attention/src/tests/macros/mod.rs | 805 +------- .../src/tests/macros/suite.rs | 958 +++++++++ .../cubecl-attention/src/tests/test_utils.rs | 54 +- crates/cubecl-common/Cargo.toml | 10 +- crates/cubecl-common/src/device.rs | 566 +++++- crates/cubecl-common/src/lib.rs | 3 + crates/cubecl-common/src/quant/mod.rs | 2 + .../src => cubecl-common/src/quant}/scheme.rs | 26 +- crates/cubecl-convolution/Cargo.toml | 14 +- .../src/components/config.rs | 62 +- .../src/components/global/args.rs | 147 +- .../src/components/global/base.rs | 22 +- .../src/components/global/entry_point.rs | 66 +- .../src/components/global/layout/bias.rs | 34 + .../src/components/global/layout/im2col.rs | 71 +- .../src/components/global/layout/mod.rs | 2 + .../src/components/global/layout/spatial.rs | 49 +- .../src/components/global/layout/weight.rs | 68 +- .../src/components/global/layout/write.rs | 41 +- .../src/components/global/memory/tma.rs | 9 +- .../global/multi_stage/tma/config.rs | 2 +- .../global/multi_stage/tma/convolution.rs | 61 +- .../global/multi_stage/tma/launch.rs | 2 +- .../global/multi_stage/tma/setup.rs | 2 +- .../src/components/global/read/reader/bias.rs | 19 +- .../global/read/reader/im2col_tma.rs | 81 +- .../components/global/read/reader/layout.rs | 36 +- .../global/read/reader/weight_tma.rs | 4 +- .../global/single_stage/simple/convolution.rs | 53 +- .../global/single_stage/simple/launch.rs | 2 +- .../global/single_stage/simple/setup.rs | 2 +- .../global/single_stage/tma/convolution.rs | 52 +- .../global/single_stage/tma/launch.rs | 2 +- 
.../global/single_stage/tma/setup.rs | 2 +- .../cubecl-convolution/src/components/mod.rs | 4 +- .../src/components/problem.rs | 1 + .../src/components/selection.rs | 19 +- .../src/components/stage/reader.rs | 15 +- .../src/kernels/layered/algorithm/mod.rs | 6 +- .../layered/algorithm/multi_stage_tma.rs | 4 +- .../src/kernels/layered/algorithm/simple.rs | 4 +- .../kernels/layered/algorithm/simple_tma.rs | 6 +- .../kernels/layered/selector/select_kernel.rs | 18 +- crates/cubecl-convolution/src/launch.rs | 21 +- .../src/tests/convolution_test_launcher.rs | 19 +- .../src/tests/test_macros/mod.rs | 2 +- .../src/tests/test_macros/suite.rs | 6 +- .../src/tests/test_utils.rs | 6 +- crates/cubecl-core/Cargo.toml | 8 +- crates/cubecl-core/src/codegen/integrator.rs | 13 +- crates/cubecl-core/src/compute/launcher.rs | 4 +- crates/cubecl-core/src/frontend/branch.rs | 19 +- crates/cubecl-core/src/frontend/comment.rs | 10 - .../src/frontend/container/line/base.rs | 6 +- .../src/frontend/container/line/ops.rs | 4 +- .../src/frontend/container/sequence/base.rs | 11 +- .../src/frontend/container/sequence/launch.rs | 8 + .../src/frontend/container/shared_memory.rs | 4 +- .../src/frontend/container/tensor/base.rs | 2 +- .../src/frontend/container/tensor/launch.rs | 9 +- crates/cubecl-core/src/frontend/debug.rs | 11 + .../src/frontend/element/atomic.rs | 6 +- .../cubecl-core/src/frontend/element/base.rs | 5 + .../cubecl-core/src/frontend/element/bool.rs | 9 +- .../src/frontend/element/cube_elem.rs | 21 +- .../cubecl-core/src/frontend/element/float.rs | 16 +- .../src/frontend/element/float/fp4.rs | 18 +- .../src/frontend/element/float/fp6.rs | 10 +- .../src/frontend/element/float/fp8.rs | 23 +- .../src/frontend/element/float/relaxed.rs | 9 +- .../frontend/element/float/tensor_float.rs | 9 +- .../src/frontend/element/float/typemap.rs | 7 +- .../cubecl-core/src/frontend/element/int.rs | 9 +- .../src/frontend/element/int/typemap.rs | 8 +- .../cubecl-core/src/frontend/element/uint.rs | 9 +- crates/cubecl-core/src/frontend/mod.rs | 2 - .../src/frontend/operation/unary.rs | 20 +- crates/cubecl-core/src/frontend/options.rs | 49 +- crates/cubecl-core/src/frontend/plane.rs | 147 ++ crates/cubecl-core/src/id.rs | 5 +- .../cubecl-core/src/post_processing/unroll.rs | 9 +- crates/cubecl-core/src/prelude.rs | 2 +- crates/cubecl-core/src/runtime.rs | 26 +- .../cubecl-core/src/runtime_tests/assign.rs | 6 +- .../cubecl-core/src/runtime_tests/atomic.rs | 8 +- .../cubecl-core/src/runtime_tests/barrier.rs | 10 +- .../cubecl-core/src/runtime_tests/binary.rs | 17 +- .../cubecl-core/src/runtime_tests/branch.rs | 16 +- .../cubecl-core/src/runtime_tests/cluster.rs | 2 +- crates/cubecl-core/src/runtime_tests/cmma.rs | 42 +- .../src/runtime_tests/comparison.rs | 2 +- .../src/runtime_tests/const_match.rs | 2 +- .../src/runtime_tests/constants.rs | 2 +- crates/cubecl-core/src/runtime_tests/debug.rs | 6 +- .../src/runtime_tests/different_rank.rs | 6 +- crates/cubecl-core/src/runtime_tests/enums.rs | 6 +- crates/cubecl-core/src/runtime_tests/index.rs | 2 +- .../cubecl-core/src/runtime_tests/launch.rs | 8 +- crates/cubecl-core/src/runtime_tests/line.rs | 16 +- .../cubecl-core/src/runtime_tests/metadata.rs | 14 +- .../src/runtime_tests/minifloat.rs | 8 +- crates/cubecl-core/src/runtime_tests/mod.rs | 2 + .../cubecl-core/src/runtime_tests/numeric.rs | 44 + crates/cubecl-core/src/runtime_tests/plane.rs | 303 ++- .../src/runtime_tests/saturating.rs | 8 +- .../cubecl-core/src/runtime_tests/sequence.rs | 6 +- 
crates/cubecl-core/src/runtime_tests/slice.rs | 18 +- .../cubecl-core/src/runtime_tests/stream.rs | 4 +- .../src/runtime_tests/synchronization.rs | 6 +- .../cubecl-core/src/runtime_tests/tensor.rs | 2 +- .../src/runtime_tests/tensormap.rs | 17 +- .../cubecl-core/src/runtime_tests/topology.rs | 2 +- crates/cubecl-core/src/runtime_tests/unary.rs | 86 +- .../cubecl-core/src/runtime_tests/unroll.rs | 8 +- crates/cubecl-cpp/Cargo.toml | 8 +- crates/cubecl-cpp/src/metal/dialect.rs | 98 +- crates/cubecl-cpp/src/shared/base.rs | 203 +- crates/cubecl-cpp/src/shared/binary.rs | 28 + crates/cubecl-cpp/src/shared/instruction.rs | 62 +- crates/cubecl-cpp/src/shared/unary.rs | 14 +- crates/cubecl-cpp/src/shared/warp.rs | 81 + crates/cubecl-cpu/Cargo.toml | 26 +- crates/cubecl-cpu/src/compiler/mlir_data.rs | 10 +- .../compiler/visitor/operation/arithmetic.rs | 30 +- .../src/compiler/visitor/operation/mod.rs | 2 +- crates/cubecl-cpu/src/compute/scheduler.rs | 9 +- crates/cubecl-cpu/src/compute/server.rs | 26 +- crates/cubecl-cpu/src/lib.rs | 1 + crates/cubecl-cpu/src/runtime.rs | 129 +- crates/cubecl-cuda/Cargo.toml | 33 +- crates/cubecl-cuda/build.rs | 13 + crates/cubecl-cuda/src/compute/command.rs | 18 +- crates/cubecl-cuda/src/compute/context.rs | 4 - crates/cubecl-cuda/src/compute/server.rs | 217 ++- crates/cubecl-cuda/src/compute/stream.rs | 19 +- crates/cubecl-cuda/src/lib.rs | 1 + crates/cubecl-cuda/src/runtime.rs | 397 ++-- crates/cubecl-hip/Cargo.toml | 26 +- crates/cubecl-hip/src/compute/command.rs | 110 +- crates/cubecl-hip/src/compute/server.rs | 102 +- crates/cubecl-hip/src/compute/stream.rs | 19 +- crates/cubecl-hip/src/runtime.rs | 283 +-- crates/cubecl-ir/Cargo.toml | 7 +- crates/cubecl-ir/src/arithmetic.rs | 6 +- crates/cubecl-ir/src/lib.rs | 2 + crates/cubecl-ir/src/marker.rs | 75 + crates/cubecl-ir/src/operation.rs | 15 +- crates/cubecl-ir/src/plane.rs | 16 + crates/cubecl-ir/src/processing.rs | 7 +- crates/cubecl-ir/src/scope.rs | 14 +- crates/cubecl-ir/src/type_hash.rs | 2 + crates/cubecl-ir/src/variable.rs | 20 +- crates/cubecl-macros/Cargo.toml | 2 +- crates/cubecl-macros/src/generate/kernel.rs | 79 +- crates/cubecl-macros/src/generate/launch.rs | 23 +- crates/cubecl-macros/src/parse/branch.rs | 26 +- crates/cubecl-macros/src/parse/cube_impl.rs | 16 +- crates/cubecl-macros/src/parse/expression.rs | 2 +- crates/cubecl-macros/src/parse/helpers.rs | 29 +- crates/cubecl-macros/src/parse/kernel.rs | 67 +- crates/cubecl-matmul/Cargo.toml | 12 +- crates/cubecl-matmul/src/base.rs | 256 ++- .../src/components/batch/base.rs | 14 +- .../src/components/batch/entry_point.rs | 40 +- .../src/components/batch/layout.rs | 46 + .../cubecl-matmul/src/components/batch/mod.rs | 2 + .../batch/partitioned_matmul/matmul.rs | 23 +- .../partitioned_matmul/partition/matmul.rs | 105 +- .../batch/partitioned_matmul/setup.rs | 4 +- .../src/components/global/args.rs | 1708 +++-------------- .../src/components/global/base.rs | 28 +- .../src/components/global/memory/config.rs | 2 +- .../src/components/global/memory/iterator.rs | 5 + .../src/components/global/memory/layout.rs | 453 ++++- .../multi_stage/double_buffering/config.rs | 2 +- .../multi_stage/double_buffering/matmul.rs | 40 +- .../multi_stage/double_buffering/setup.rs | 2 +- .../global/multi_stage/ordered/config.rs | 2 +- .../global/multi_stage/ordered/matmul.rs | 42 +- .../global/multi_stage/ordered/setup.rs | 2 +- .../global/read/reader/sync_full_reader.rs | 15 +- .../global/read/reader/sync_partial_reader.rs | 24 +- .../global/read/reader/tma_reader.rs 
| 4 +- .../read/strategy/async_full_cooperative.rs | 6 +- .../global/read/strategy/async_full_cyclic.rs | 4 +- .../async_full_maximize_slice_length.rs | 6 +- .../async_full_maximize_unit_count.rs | 4 +- .../async_partial_maximize_slice_length.rs | 6 +- .../global/read/strategy/sync_full_cyclic.rs | 6 +- .../global/read/strategy/sync_full_ordered.rs | 6 +- .../global/read/strategy/sync_full_strided.rs | 6 +- .../read/strategy/sync_full_tilewise.rs | 6 +- .../read/strategy/sync_partial_cyclic.rs | 16 +- .../read/strategy/sync_partial_tilewise.rs | 6 +- .../global/single_stage/barrier/config.rs | 4 +- .../global/single_stage/barrier/matmul.rs | 49 +- .../global/single_stage/barrier/setup.rs | 2 +- .../global/single_stage/simple/config.rs | 2 +- .../global/single_stage/simple/matmul.rs | 49 +- .../global/single_stage/simple/setup.rs | 2 +- .../global/single_stage/tma/config.rs | 4 +- .../global/single_stage/tma/matmul.rs | 42 +- .../global/single_stage/tma/setup.rs | 2 +- .../src/components/global/write/plane.rs | 2 +- .../src/components/global/write/unit.rs | 33 +- .../cubecl-matmul/src/components/line_size.rs | 19 +- .../cubecl-matmul/src/components/problem.rs | 5 +- .../src/components/stage/base.rs | 2 +- .../stage/matmul/partition/fragments.rs | 12 +- .../stage/matmul/partition/matmul.rs | 53 +- .../stage/matmul/partitioned_matmul.rs | 12 +- .../stage/matmul/plane_partitioned/setup.rs | 2 +- .../stage/matmul/unit_partitioned/setup.rs | 2 +- .../src/components/stage/memory/layout.rs | 42 +- .../cubecl-matmul/src/components/tile/base.rs | 20 +- .../tile/{accelerated => cmma}/config.rs | 10 +- .../tile/{accelerated => cmma}/matmul.rs | 12 +- .../tile/{accelerated => cmma}/mod.rs | 0 .../tile/{accelerated => cmma}/reader.rs | 0 .../tile/{accelerated => cmma}/setup.rs | 39 +- .../tile/{accelerated => cmma}/writer.rs | 0 .../src/components/tile/mma/config.rs | 4 +- .../src/components/tile/mma/setup.rs | 42 +- .../cubecl-matmul/src/components/tile/mod.rs | 2 +- .../plane_vec_mat_inner_product/config.rs | 4 +- .../tile/plane_vec_mat_inner_product/setup.rs | 2 +- .../src/components/tile/register/config.rs | 36 +- .../src/components/tile/register/matmul.rs | 4 +- .../src/components/tile/register/setup.rs | 5 +- .../src/components/tile/register/writer.rs | 2 +- .../src/kernels/layered/algorithm/base.rs | 6 +- .../layered/algorithm/double_buffering.rs | 6 +- .../kernels/layered/algorithm/double_unit.rs | 7 +- .../algorithm/ordered_double_buffering.rs | 2 +- .../src/kernels/layered/algorithm/simple.rs | 23 +- .../layered/algorithm/simple_barrier.rs | 2 +- .../kernels/layered/algorithm/simple_tma.rs | 2 +- .../kernels/layered/algorithm/simple_unit.rs | 7 +- .../src/kernels/layered/algorithm/vecmat.rs | 6 +- .../cubecl-matmul/src/kernels/layered/base.rs | 300 ++- .../cubecl-matmul/src/kernels/layered/mod.rs | 2 +- .../src/kernels/layered/selector/plane.rs | 64 +- .../kernels/layered/selector/select_kernel.rs | 36 +- .../src/kernels/layered/selector/unit.rs | 157 +- crates/cubecl-matmul/src/kernels/naive.rs | 253 ++- .../macros/common/problem/problem_size.rs | 6 + .../layered/macros/plane_accelerated/mod.rs | 2 +- .../src/tests/layered/macros/tma/mod.rs | 2 +- .../src/tests/layered/matmul_test_launcher.rs | 79 +- .../src/tests/layered/tma_test_launcher.rs | 72 +- .../cubecl-matmul/src/tests/naive/macros.rs | 7 + crates/cubecl-matmul/src/tests/naive/tests.rs | 20 +- crates/cubecl-matmul/src/tests/naive/utils.rs | 10 +- crates/cubecl-matmul/src/tests/test_utils.rs | 14 +- 
crates/cubecl-matmul/src/tune_key.rs | 72 +- crates/cubecl-opt/Cargo.toml | 6 +- crates/cubecl-opt/src/analyses/liveness.rs | 6 +- crates/cubecl-opt/src/analyses/uniformity.rs | 9 + crates/cubecl-opt/src/control_flow.rs | 5 +- crates/cubecl-opt/src/gvn/analysis.rs | 21 +- crates/cubecl-opt/src/gvn/numbering.rs | 15 +- crates/cubecl-opt/src/instructions.rs | 11 +- crates/cubecl-opt/src/passes/constant_prop.rs | 17 +- crates/cubecl-quant/Cargo.toml | 8 +- crates/cubecl-quant/src/dequantize.rs | 10 +- crates/cubecl-quant/src/layout/scales.rs | 16 +- crates/cubecl-quant/src/lib.rs | 2 +- crates/cubecl-quant/src/quantize.rs | 8 +- crates/cubecl-random/Cargo.toml | 8 +- crates/cubecl-random/src/base.rs | 2 +- crates/cubecl-random/src/bernoulli.rs | 2 +- crates/cubecl-random/src/normal.rs | 2 +- crates/cubecl-random/src/uniform.rs | 2 +- crates/cubecl-reduce/Cargo.toml | 6 +- crates/cubecl-reduce/src/config.rs | 30 +- crates/cubecl-reduce/src/launch.rs | 2 +- crates/cubecl-reduce/src/lib.rs | 5 +- crates/cubecl-reduce/src/primitives.rs | 65 + crates/cubecl-reduce/src/shared_sum.rs | 5 +- crates/cubecl-reduce/src/strategy.rs | 8 +- crates/cubecl-reduce/src/test_shuffle.rs | 243 +++ crates/cubecl-runtime/Cargo.toml | 6 +- crates/cubecl-runtime/benches/dynamic.rs | 16 +- crates/cubecl-runtime/src/base.rs | 94 - crates/cubecl-runtime/src/channel/base.rs | 107 -- crates/cubecl-runtime/src/channel/cell.rs | 172 -- crates/cubecl-runtime/src/channel/mod.rs | 17 - crates/cubecl-runtime/src/channel/mpsc.rs | 408 ---- crates/cubecl-runtime/src/channel/mutex.rs | 157 -- crates/cubecl-runtime/src/client.rs | 348 ++-- crates/cubecl-runtime/src/config/base.rs | 5 + crates/cubecl-runtime/src/config/logger.rs | 35 +- crates/cubecl-runtime/src/config/memory.rs | 48 + crates/cubecl-runtime/src/config/mod.rs | 2 + crates/cubecl-runtime/src/lib.rs | 4 - crates/cubecl-runtime/src/logging/server.rs | 32 +- .../src/memory_management/base.rs | 48 +- .../src/memory_management/memory_manage.rs | 249 ++- .../src/memory_management/memory_pool/base.rs | 82 +- .../memory_pool/exclusive_pool.rs | 31 +- .../memory_management/memory_pool/index.rs | 65 - .../memory_pool/memory_page.rs | 684 +++++++ .../src/memory_management/memory_pool/mod.rs | 10 +- .../memory_pool/persistent_pool.rs | 204 ++ .../src/memory_management/memory_pool/ring.rs | 305 --- .../memory_pool/sliced_pool.rs | 336 +--- .../memory_pool/static_pool.rs | 85 - crates/cubecl-runtime/src/server.rs | 56 +- crates/cubecl-runtime/src/storage/base.rs | 8 +- crates/cubecl-runtime/src/stream/scheduler.rs | 26 +- crates/cubecl-runtime/src/tune/base.rs | 45 +- crates/cubecl-runtime/src/tune/local.rs | 9 +- .../cubecl-runtime/src/tune/tune_benchmark.rs | 13 +- crates/cubecl-runtime/src/tune/tuner.rs | 10 +- crates/cubecl-runtime/tests/dummy/compute.rs | 71 +- crates/cubecl-runtime/tests/dummy/server.rs | 42 +- .../tests/dummy/tune/autotune_operations.rs | 4 +- crates/cubecl-spirv/Cargo.toml | 10 +- crates/cubecl-spirv/src/arithmetic.rs | 133 +- crates/cubecl-spirv/src/atomic.rs | 10 +- crates/cubecl-spirv/src/bitwise.rs | 12 +- crates/cubecl-spirv/src/compiler.rs | 54 +- crates/cubecl-spirv/src/extensions.rs | 11 + crates/cubecl-spirv/src/instruction.rs | 59 +- crates/cubecl-spirv/src/subgroup.rs | 28 + crates/cubecl-spirv/src/target.rs | 6 +- crates/cubecl-std/Cargo.toml | 5 +- crates/cubecl-std/src/fast_math.rs | 2 +- crates/cubecl-std/src/lib.rs | 2 + crates/cubecl-std/src/quant/base.rs | 9 + crates/cubecl-std/src/quant/dequantize.rs | 100 + 
crates/cubecl-std/src/quant/mod.rs | 6 + crates/cubecl-std/src/quant/view.rs | 321 ++++ crates/cubecl-std/src/tensor/contiguous.rs | 219 ++- crates/cubecl-std/src/tensor/handle.rs | 4 +- crates/cubecl-std/src/tensor/identity.rs | 4 +- crates/cubecl-std/src/tensor/layout/as_dyn.rs | 98 + crates/cubecl-std/src/tensor/layout/linear.rs | 18 +- crates/cubecl-std/src/tensor/layout/mod.rs | 1 + .../cubecl-std/src/tensor/layout/permuted.rs | 28 +- .../cubecl-std/src/tensor/layout/strided.rs | 4 +- .../cubecl-std/src/tensor/layout/virtual.rs | 2 +- crates/cubecl-std/src/tensor/view/base.rs | 130 +- crates/cubecl-std/src/tensor/view/launch.rs | 258 ++- .../tensor/view/operations/virtual_tensor.rs | 8 +- crates/cubecl-std/src/tests/mod.rs | 1 + .../cubecl-std/src/tests/reinterpret_slice.rs | 14 +- crates/cubecl-std/src/tests/trigonometry.rs | 6 +- crates/cubecl-std/src/tests/view/mod.rs | 1 + crates/cubecl-std/src/tests/view/quantized.rs | 214 +++ crates/cubecl-wgpu/Cargo.toml | 26 +- .../cubecl-wgpu/src/compiler/wgsl/compiler.rs | 39 +- .../src/compiler/wgsl/extension.rs | 12 +- .../src/compiler/wgsl/instructions.rs | 14 +- .../cubecl-wgpu/src/compiler/wgsl/subgroup.rs | 36 + crates/cubecl-wgpu/src/compute/mem_manager.rs | 15 +- crates/cubecl-wgpu/src/compute/schedule.rs | 9 + crates/cubecl-wgpu/src/compute/server.rs | 13 +- crates/cubecl-wgpu/src/compute/stream.rs | 7 +- crates/cubecl-wgpu/src/lib.rs | 3 + crates/cubecl-wgpu/src/runtime.rs | 70 +- crates/cubecl/Cargo.toml | 22 +- crates/cubecl/benches/conv2d.rs | 2 +- crates/cubecl/benches/matmul.rs | 25 +- crates/cubecl/benches/memcpy_async.rs | 4 +- crates/cubecl/benches/unary.rs | 2 +- cubecl-book/src/SUMMARY.md | 4 +- .../src/advanced-usage/math_optimizations.md | 112 ++ .../src/algorithms/quantized_matmul.md | 189 -- cubecl-book/src/algorithms/summary.md | 10 - cubecl-book/src/core-features/features.md | 5 + .../src/getting-started/installation.md | 27 +- .../src/getting-started/src/bin/v2-gpu.rs | 2 +- .../src/getting-started/src/bin/v3-gpu.rs | 2 +- .../src/getting-started/src/bin/v4-gpu.rs | 2 +- .../src/getting-started/src/bin/v5-gpu.rs | 2 +- .../src/getting-started/src/bin/v6-gpu.rs | 2 +- .../src/getting-started/src/bin/v7-gpu.rs | 2 +- .../src/getting-started/src/gpu_tensor.rs | 6 +- examples/device_sharing/Cargo.toml | 2 +- examples/fusing/Cargo.toml | 2 +- examples/gelu/Cargo.toml | 2 +- examples/normalization/Cargo.toml | 2 +- examples/sum_things/Cargo.toml | 2 +- examples/sum_things/src/lib.rs | 8 +- xtask/Cargo.toml | 2 +- xtask/src/commands/profile.rs | 2 +- 490 files changed, 16028 insertions(+), 10432 deletions(-) delete mode 100644 .github/workflows/publish-template.yml rename crates/cubecl-attention/src/components/batch/{dummy => simple}/attention.rs (74%) rename crates/cubecl-attention/src/components/batch/{dummy => simple}/config.rs (77%) create mode 100644 crates/cubecl-attention/src/components/batch/simple/mod.rs rename crates/cubecl-attention/src/components/batch/{dummy => simple}/setup.rs (77%) rename crates/cubecl-attention/src/components/{tile/dummy/attention_matmul/accelerated/matmul.rs => fragment/accelerated/attention.rs} (58%) create mode 100644 crates/cubecl-attention/src/components/fragment/accelerated/config.rs rename crates/cubecl-attention/src/components/{batch/dummy => fragment/accelerated}/mod.rs (59%) rename crates/cubecl-attention/src/components/{tile/dummy/attention_matmul => fragment}/accelerated/setup.rs (63%) rename crates/cubecl-attention/src/components/{tile/dummy/attention_matmul => 
fragment}/base.rs (62%) create mode 100644 crates/cubecl-attention/src/components/fragment/dummy_register/attention.rs create mode 100644 crates/cubecl-attention/src/components/fragment/dummy_register/config.rs rename crates/cubecl-attention/src/components/{tile/dummy/attention_matmul => fragment}/dummy_register/mod.rs (53%) rename crates/cubecl-attention/src/components/{tile/dummy/attention_matmul => fragment}/dummy_register/setup.rs (61%) create mode 100644 crates/cubecl-attention/src/components/fragment/fragments.rs rename crates/cubecl-attention/src/components/{tile/dummy/attention_matmul => fragment}/mod.rs (55%) create mode 100644 crates/cubecl-attention/src/components/fragment/unit_register/attention.rs create mode 100644 crates/cubecl-attention/src/components/fragment/unit_register/config.rs rename crates/cubecl-attention/src/components/{tile/dummy/attention_matmul/accelerated => fragment/unit_register}/mod.rs (53%) create mode 100644 crates/cubecl-attention/src/components/fragment/unit_register/setup.rs delete mode 100644 crates/cubecl-attention/src/components/global/dummy/mod.rs delete mode 100644 crates/cubecl-attention/src/components/global/dummy/read.rs delete mode 100644 crates/cubecl-attention/src/components/global/dummy/writer.rs rename crates/cubecl-attention/src/components/global/{dummy => simple}/attention.rs (51%) rename crates/cubecl-attention/src/components/global/{dummy => simple}/config.rs (83%) create mode 100644 crates/cubecl-attention/src/components/global/simple/mod.rs create mode 100644 crates/cubecl-attention/src/components/global/simple/reader/base.rs create mode 100644 crates/cubecl-attention/src/components/global/simple/reader/key.rs create mode 100644 crates/cubecl-attention/src/components/global/simple/reader/mask.rs create mode 100644 crates/cubecl-attention/src/components/global/simple/reader/mod.rs create mode 100644 crates/cubecl-attention/src/components/global/simple/reader/query.rs create mode 100644 crates/cubecl-attention/src/components/global/simple/reader/value.rs rename crates/cubecl-attention/src/components/global/{dummy => simple}/setup.rs (67%) create mode 100644 crates/cubecl-attention/src/components/global/simple/writer/mod.rs create mode 100644 crates/cubecl-attention/src/components/global/simple/writer/plane.rs create mode 100644 crates/cubecl-attention/src/components/global/simple/writer/unit.rs delete mode 100644 crates/cubecl-attention/src/components/mask.rs delete mode 100644 crates/cubecl-attention/src/components/stage/dummy/attention.rs create mode 100644 crates/cubecl-attention/src/components/stage/kv_reuse_attention.rs create mode 100644 crates/cubecl-attention/src/components/stage/partitioner.rs create mode 100644 crates/cubecl-attention/src/components/stage/plane/attention.rs create mode 100644 crates/cubecl-attention/src/components/stage/plane/config.rs rename crates/cubecl-attention/src/components/stage/{dummy => plane}/mod.rs (66%) rename crates/cubecl-attention/src/components/stage/{dummy => plane}/setup.rs (74%) rename crates/cubecl-attention/src/components/stage/{dummy => }/tile_partitions.rs (56%) create mode 100644 crates/cubecl-attention/src/components/stage/unit/attention.rs rename crates/cubecl-attention/src/components/stage/{dummy => unit}/config.rs (55%) create mode 100644 crates/cubecl-attention/src/components/stage/unit/mod.rs create mode 100644 crates/cubecl-attention/src/components/stage/unit/setup.rs delete mode 100644 crates/cubecl-attention/src/components/tile/dummy/attention.rs delete mode 100644 
crates/cubecl-attention/src/components/tile/dummy/attention_matmul/accelerated/config.rs delete mode 100644 crates/cubecl-attention/src/components/tile/dummy/attention_matmul/dummy_register/config.rs delete mode 100644 crates/cubecl-attention/src/components/tile/dummy/attention_matmul/dummy_register/matmul.rs delete mode 100644 crates/cubecl-attention/src/components/tile/dummy/fragment/accumulator.rs delete mode 100644 crates/cubecl-attention/src/components/tile/dummy/fragment/key_value.rs delete mode 100644 crates/cubecl-attention/src/components/tile/dummy/fragment/mod.rs delete mode 100644 crates/cubecl-attention/src/components/tile/dummy/fragment/query.rs delete mode 100644 crates/cubecl-attention/src/components/tile/dummy/fragment/softmax.rs delete mode 100644 crates/cubecl-attention/src/components/tile/dummy/mod.rs delete mode 100644 crates/cubecl-attention/src/components/tile/dummy/setup.rs create mode 100644 crates/cubecl-attention/src/components/tile/row/mod.rs create mode 100644 crates/cubecl-attention/src/components/tile/row/reduce/base.rs create mode 100644 crates/cubecl-attention/src/components/tile/row/reduce/broadcast_reducer.rs create mode 100644 crates/cubecl-attention/src/components/tile/row/reduce/mod.rs create mode 100644 crates/cubecl-attention/src/components/tile/row/reduce/naive_reducer.rs create mode 100644 crates/cubecl-attention/src/components/tile/row/reduce/reduce_op.rs create mode 100644 crates/cubecl-attention/src/components/tile/row/reduce/unit_reducer.rs create mode 100644 crates/cubecl-attention/src/components/tile/row/rowwise.rs create mode 100644 crates/cubecl-attention/src/components/tile/row/state.rs delete mode 100644 crates/cubecl-attention/src/components/tile/rowwise.rs create mode 100644 crates/cubecl-attention/src/components/tile/tiles/key_value.rs create mode 100644 crates/cubecl-attention/src/components/tile/tiles/mask.rs create mode 100644 crates/cubecl-attention/src/components/tile/tiles/query.rs create mode 100644 crates/cubecl-attention/src/kernels/unit.rs create mode 100644 crates/cubecl-attention/src/tests/macros/suite.rs create mode 100644 crates/cubecl-common/src/quant/mod.rs rename crates/{cubecl-quant/src => cubecl-common/src/quant}/scheme.rs (89%) create mode 100644 crates/cubecl-convolution/src/components/global/layout/bias.rs delete mode 100644 crates/cubecl-core/src/frontend/comment.rs create mode 100644 crates/cubecl-core/src/runtime_tests/numeric.rs create mode 100644 crates/cubecl-cuda/build.rs create mode 100644 crates/cubecl-ir/src/marker.rs create mode 100644 crates/cubecl-matmul/src/components/batch/layout.rs rename crates/cubecl-matmul/src/components/tile/{accelerated => cmma}/config.rs (94%) rename crates/cubecl-matmul/src/components/tile/{accelerated => cmma}/matmul.rs (89%) rename crates/cubecl-matmul/src/components/tile/{accelerated => cmma}/mod.rs (100%) rename crates/cubecl-matmul/src/components/tile/{accelerated => cmma}/reader.rs (100%) rename crates/cubecl-matmul/src/components/tile/{accelerated => cmma}/setup.rs (53%) rename crates/cubecl-matmul/src/components/tile/{accelerated => cmma}/writer.rs (100%) create mode 100644 crates/cubecl-reduce/src/test_shuffle.rs delete mode 100644 crates/cubecl-runtime/src/base.rs delete mode 100644 crates/cubecl-runtime/src/channel/base.rs delete mode 100644 crates/cubecl-runtime/src/channel/cell.rs delete mode 100644 crates/cubecl-runtime/src/channel/mod.rs delete mode 100644 crates/cubecl-runtime/src/channel/mpsc.rs delete mode 100644 crates/cubecl-runtime/src/channel/mutex.rs 
create mode 100644 crates/cubecl-runtime/src/config/memory.rs delete mode 100644 crates/cubecl-runtime/src/memory_management/memory_pool/index.rs create mode 100644 crates/cubecl-runtime/src/memory_management/memory_pool/memory_page.rs create mode 100644 crates/cubecl-runtime/src/memory_management/memory_pool/persistent_pool.rs delete mode 100644 crates/cubecl-runtime/src/memory_management/memory_pool/ring.rs delete mode 100644 crates/cubecl-runtime/src/memory_management/memory_pool/static_pool.rs create mode 100644 crates/cubecl-std/src/quant/base.rs create mode 100644 crates/cubecl-std/src/quant/dequantize.rs create mode 100644 crates/cubecl-std/src/quant/mod.rs create mode 100644 crates/cubecl-std/src/quant/view.rs create mode 100644 crates/cubecl-std/src/tensor/layout/as_dyn.rs create mode 100644 crates/cubecl-std/src/tests/view/mod.rs create mode 100644 crates/cubecl-std/src/tests/view/quantized.rs create mode 100644 cubecl-book/src/advanced-usage/math_optimizations.md delete mode 100644 cubecl-book/src/algorithms/quantized_matmul.md delete mode 100644 cubecl-book/src/algorithms/summary.md diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 86f92485f..12594f524 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -61,7 +61,7 @@ jobs: toolchain: stable steps: - name: Setup Rust - uses: tracel-ai/github-actions/setup-rust@v4 + uses: tracel-ai/github-actions/setup-rust@v5 with: rust-toolchain: ${{ matrix.toolchain }} cache-key: ${{ matrix.rust }}-linux @@ -81,7 +81,7 @@ jobs: run: cargo xtask check lint # -------------------------------------------------------------------------------- - name: Typos - uses: tracel-ai/github-actions/check-typos@v4 + uses: tracel-ai/github-actions/check-typos@v5 documentation: runs-on: ubuntu-22.04 @@ -94,7 +94,7 @@ jobs: toolchain: stable steps: - name: Setup Rust - uses: tracel-ai/github-actions/setup-rust@v4 + uses: tracel-ai/github-actions/setup-rust@v5 with: rust-toolchain: ${{ matrix.toolchain }} cache-key: ${{ matrix.rust }}-linux @@ -114,7 +114,7 @@ jobs: '@keep-alive:false', '@machine-type:n2-standard-16', '@os:linux', - '@zone:northamerica-northeast1-b' + '@zones:northamerica-northeast1-b' ] needs: [prepare-checks, code-quality] strategy: @@ -127,13 +127,13 @@ jobs: toolchain: ${{ needs.prepare-checks.outputs.rust-prev-version }} steps: - name: Setup Rust - uses: tracel-ai/github-actions/setup-rust@v4 + uses: tracel-ai/github-actions/setup-rust@v5 with: rust-toolchain: ${{ matrix.toolchain }} cache-key: ${{ matrix.rust }}-linux # -------------------------------------------------------------------------------- - name: Setup Linux runner - uses: tracel-ai/github-actions/setup-linux@v4 + uses: tracel-ai/github-actions/setup-linux@v5 with: apt-packages: "build-essential" vulkan-sdk-version: ${{ env.VULKAN_SDK_VERSION }} @@ -157,14 +157,14 @@ jobs: # toolchain: stable # steps: # - name: Setup Rust - # uses: tracel-ai/github-actions/setup-rust@v4 + # uses: tracel-ai/github-actions/setup-rust@v5 # with: # rust-toolchain: ${{ matrix.toolchain }} # cache-key: ${{ matrix.rust }}-windows # # -------------------------------------------------------------------------------- # - name: Setup Windows runner # if: env.DISABLE_WGPU != '1' - # uses: tracel-ai/github-actions/setup-windows@v4 + # uses: tracel-ai/github-actions/setup-windows@v5 # with: # dxc-release: ${{ env.DXC_RELEASE }} # dxc-filename: ${{ env.DXC_FILENAME }} @@ -186,7 +186,7 @@ jobs: # toolchain: stable # steps: # - name: Setup Rust - # uses: 
tracel-ai/github-actions/setup-rust@v4 + # uses: tracel-ai/github-actions/setup-rust@v5 # with: # rust-toolchain: ${{ matrix.toolchain }} # cache-key: ${{ matrix.rust }}-macos diff --git a/.github/workflows/publish-template.yml b/.github/workflows/publish-template.yml deleted file mode 100644 index 9e10ea98a..000000000 --- a/.github/workflows/publish-template.yml +++ /dev/null @@ -1,24 +0,0 @@ -on: - workflow_call: - inputs: - crate: - required: true - type: string - secrets: - CRATES_IO_API_TOKEN: - required: true - -jobs: - publish-crate: - runs-on: ubuntu-latest - steps: - - name: checkout - uses: actions/checkout@v3 - - - name: install rust - uses: dtolnay/rust-toolchain@stable - - - name: publish to crates.io - run: cargo xtask publish ${{ inputs.crate }} - env: - CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 1edff3da9..30ac8974e 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -4,90 +4,125 @@ on: push: tags: - "v*" + workflow_dispatch: + inputs: + dry-run-only: + description: "Run xtask publish in dry-run mode (no publish)" + type: boolean + required: false + default: false jobs: + check-version: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: tracel-ai/github-actions/check-version@v5 + with: + tag: ${{ github.ref_name }} + cargo_toml_path: Cargo.toml + publish-cubecl-common: - uses: tracel-ai/cubecl/.github/workflows/publish-template.yml@main + needs: + - check-version + uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v5 with: crate: cubecl-common - secrets: inherit + dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} + secrets: + CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-cubecl-macros-internal: - uses: tracel-ai/cubecl/.github/workflows/publish-template.yml@main + needs: + - check-version + uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v5 with: crate: cubecl-macros-internal - secrets: inherit + dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} + secrets: + CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-cubecl-ir: - uses: tracel-ai/cubecl/.github/workflows/publish-template.yml@main - with: - crate: cubecl-ir needs: - publish-cubecl-common - publish-cubecl-macros-internal - secrets: inherit + uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v5 + with: + crate: cubecl-ir + dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} + secrets: + CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-cubecl-std: - uses: tracel-ai/cubecl/.github/workflows/publish-template.yml@main - with: - crate: cubecl-std needs: - publish-cubecl-core - publish-cubecl-runtime - secrets: inherit + uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v5 + with: + crate: cubecl-std + dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} + secrets: + CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-cubecl-runtime: - uses: tracel-ai/cubecl/.github/workflows/publish-template.yml@main - with: - crate: cubecl-runtime needs: - publish-cubecl-ir - publish-cubecl-common - secrets: inherit + uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v5 + with: + crate: cubecl-runtime + dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || 
false }} + secrets: + CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-cubecl-macros: - uses: tracel-ai/cubecl/.github/workflows/publish-template.yml@main - with: - crate: cubecl-macros needs: - publish-cubecl-common - secrets: inherit + uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v5 + with: + crate: cubecl-macros + dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} + secrets: + CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-cubecl-core: - uses: tracel-ai/cubecl/.github/workflows/publish-template.yml@main - with: - crate: cubecl-core needs: - publish-cubecl-ir - publish-cubecl-runtime - publish-cubecl-macros - secrets: inherit + uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v5 + with: + crate: cubecl-core + dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} + secrets: + CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-cubecl-random: - uses: tracel-ai/cubecl/.github/workflows/publish-template.yml@main - with: - crate: cubecl-random needs: - publish-cubecl-runtime - publish-cubecl-std - publish-cubecl-core - publish-cubecl-common - secrets: inherit + uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v5 + with: + crate: cubecl-random + dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} + secrets: + CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-cubecl-reduce: - uses: tracel-ai/cubecl/.github/workflows/publish-template.yml@main - with: - crate: cubecl-reduce needs: - publish-cubecl-std - publish-cubecl-runtime - publish-cubecl-core - secrets: inherit + uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v5 + with: + crate: cubecl-reduce + dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} + secrets: + CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-cubecl-matmul: - uses: tracel-ai/cubecl/.github/workflows/publish-template.yml@main - with: - crate: cubecl-matmul needs: - publish-cubecl-runtime - publish-cubecl-std @@ -95,48 +130,71 @@ jobs: - publish-cubecl-core - publish-cubecl-common - publish-cubecl-random - secrets: inherit + uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v5 + with: + crate: cubecl-matmul + dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} + secrets: + CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-cubecl-convolution: - uses: tracel-ai/cubecl/.github/workflows/publish-template.yml@main + needs: + - publish-cubecl-runtime + - publish-cubecl-std + - publish-cubecl-reduce + - publish-cubecl-core + - publish-cubecl-common + - publish-cubecl-random + - publish-cubecl-matmul + uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v5 with: crate: cubecl-convolution + dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} + secrets: + CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} + + publish-cubecl-attention: needs: - publish-cubecl-runtime - publish-cubecl-std - - publish-cubecl-reduce - publish-cubecl-core - publish-cubecl-common - publish-cubecl-random - publish-cubecl-matmul - secrets: inherit + uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v5 + with: + crate: cubecl-attention + dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} + secrets: + 
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-cubecl-opt: - uses: tracel-ai/cubecl/.github/workflows/publish-template.yml@main - with: - crate: cubecl-opt needs: - publish-cubecl-ir - publish-cubecl-common - publish-cubecl-core - secrets: inherit + uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v5 + with: + crate: cubecl-opt + dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} + secrets: + CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-cubecl-spirv: - uses: tracel-ai/cubecl/.github/workflows/publish-template.yml@main - with: - crate: cubecl-spirv needs: - publish-cubecl-opt - publish-cubecl-common - publish-cubecl-core - publish-cubecl-random - publish-cubecl-runtime - secrets: inherit + uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v5 + with: + crate: cubecl-spirv + dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} + secrets: + CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-cubecl-wgpu: - uses: tracel-ai/cubecl/.github/workflows/publish-template.yml@main - with: - crate: cubecl-wgpu needs: - publish-cubecl-std - publish-cubecl-spirv @@ -146,22 +204,28 @@ jobs: - publish-cubecl-matmul - publish-cubecl-convolution - publish-cubecl-reduce - secrets: inherit + - publish-cubecl-attention + uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v5 + with: + crate: cubecl-wgpu + dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} + secrets: + CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-cubecl-cpp: - uses: tracel-ai/cubecl/.github/workflows/publish-template.yml@main - with: - crate: cubecl-cpp needs: - publish-cubecl-common - publish-cubecl-runtime - publish-cubecl-core - secrets: inherit + - publish-cubecl-opt + uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v5 + with: + crate: cubecl-cpp + dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} + secrets: + CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-cubecl-cuda: - uses: tracel-ai/cubecl/.github/workflows/publish-template.yml@main - with: - crate: cubecl-cuda needs: - publish-cubecl-std - publish-cubecl-cpp @@ -171,12 +235,15 @@ jobs: - publish-cubecl-matmul - publish-cubecl-convolution - publish-cubecl-reduce - secrets: inherit + - publish-cubecl-attention + uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v5 + with: + crate: cubecl-cuda + dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} + secrets: + CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} publish-cubecl-hip: - uses: tracel-ai/cubecl/.github/workflows/publish-template.yml@main - with: - crate: cubecl-hip needs: - publish-cubecl-std - publish-cubecl-cpp @@ -186,12 +253,46 @@ jobs: - publish-cubecl-matmul - publish-cubecl-convolution - publish-cubecl-reduce - secrets: inherit + uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v5 + with: + crate: cubecl-hip + dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }} + secrets: + CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} - publish-cubecl: - uses: tracel-ai/cubecl/.github/workflows/publish-template.yml@main + publish-cubecl-cpu: + needs: + - publish-cubecl-common + - publish-cubecl-std + - publish-cubecl-core + - publish-cubecl-runtime + - publish-cubecl-reduce + - 
publish-cubecl-opt
+      - publish-cubecl-matmul
+      - publish-cubecl-convolution
+      - publish-cubecl-random
+    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v5
     with:
-      crate: cubecl
+      crate: cubecl-cpu
+      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
+    secrets:
+      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
+
+  publish-cubecl-quant:
+    needs:
+      - publish-cubecl-core
+      - publish-cubecl-common
+      - publish-cubecl-runtime
+      - publish-cubecl-std
+    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v5
+    with:
+      crate: cubecl-quant
+      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
+    secrets:
+      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
+
+  publish-cubecl:
     needs:
       - publish-cubecl-core
       - publish-cubecl-cuda
@@ -199,4 +300,9 @@ jobs:
       - publish-cubecl-matmul
       - publish-cubecl-convolution
       - publish-cubecl-reduce
-    secrets: inherit
+    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v5
+    with:
+      crate: cubecl
+      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
+    secrets:
+      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
diff --git a/Cargo.toml b/Cargo.toml
index 9ee2ef8bc..10fe9c227 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -11,7 +11,7 @@ edition = "2024"
 license = "MIT OR Apache-2.0"
 readme = "README.md"
 rust-version = "1.88"
-version = "0.7.0"
+version = "0.9.0"
 
 [workspace.dependencies]
 bitflags = { version = "2.9.1", features = ["serde"] }
@@ -73,7 +73,7 @@ num-traits = { version = "0.2.19", default-features = false, features = [
 cfg-if = "1.0.0"
 darling = "0.21.0"
-enumset = "1.1.10"
+enumset = { version = "1.1.10", default-features = false }
 ident_case = "1"
 paste = "1"
 proc-macro2 = "1"
@@ -85,7 +85,7 @@ tracy-client = { version = "0.18.0" }
 
 ### For xtask crate ###
 strum = { version = "0.27.1", features = ["derive"] }
-tracel-xtask = { version = "=2.1.8" }
+tracel-xtask = { version = "=2.1.11" }
 
 portable-atomic = { version = "1.11", default-features = false, features = [
     "serde",
@@ -103,20 +103,18 @@ tracel-llvm = { version = "20.1.4-5", features = ["mlir-helpers"] }
 # tracel-llvm = { git = "https://github.com/tracel-ai/tracel-llvm.git", branch = "fix/linux", package = "tracel-llvm", features = ["mlir-helpers"] }
 # tracel-llvm = { path = "../tracel-llvm/crates/tracel-llvm", features = ["mlir-helpers"] }
 
-# CubeCL-CUDA
-cudarc = { version = "0.17.2", features = [
+cudarc = { version = "0.17.7", features = [
     "std",
     "driver",
-    "cuda-version-from-build-system",
-    "dynamic-loading",
-], default-features = false }
+    "nvrtc",
+], default-features = false } # CubeCL-CUDA
 
 # CubeCL-SPIR-V
-rspirv = { git = "https://github.com/tracel-ai/tracel-rspirv.git", rev = "9b4037a2e14fe4138c8d4a8cfcac40c6577b5549", package = "tracel-rspirv" }
+rspirv = { package = "tracel-rspirv", version = "0.12.0" }
 
 # CubeCL-WGPU
 ash = "0.38"
-tracel-ash = { git = "https://github.com/tracel-ai/tracel-ash.git", rev = "cedd894bbdb03f0635900e7c0d317020fdd38263" }
+tracel-ash = "0.38.0"
 
 # Build deps
 cfg_aliases = "0.2.1"
diff --git a/crates/cubecl-attention/Cargo.toml b/crates/cubecl-attention/Cargo.toml
index c0fa26ed7..d7cfe28d7 100644
--- a/crates/cubecl-attention/Cargo.toml
+++ b/crates/cubecl-attention/Cargo.toml
@@ -19,13 +19,12 @@ attention_tests = []
 
 [dependencies]
 bytemuck = { workspace = true }
-cubecl-common = { path = "../cubecl-common", version = "0.7.0", default-features = false }
-cubecl-core = { path = "../cubecl-core", version = "0.7.0", default-features = false } -cubecl-runtime = { path = "../cubecl-runtime", version = "0.7.0", default-features = false } -cubecl-std = { path = "../cubecl-std", version = "0.7.0", default-features = false } -cubecl-matmul = { path = "../cubecl-matmul", version = "0.7.0", default-features = false } -cubecl-reduce = { path = "../cubecl-reduce", version = "0.7.0", default-features = false } -cubecl-random = { path = "../cubecl-random", version = "0.7.0", default-features = false } +cubecl-common = { path = "../cubecl-common", version = "0.9.0", default-features = false } +cubecl-core = { path = "../cubecl-core", version = "0.9.0", default-features = false } +cubecl-runtime = { path = "../cubecl-runtime", version = "0.9.0", default-features = false } +cubecl-std = { path = "../cubecl-std", version = "0.9.0", default-features = false } +cubecl-matmul = { path = "../cubecl-matmul", version = "0.9.0", default-features = false } +cubecl-random = { path = "../cubecl-random", version = "0.9.0", default-features = false } half = { workspace = true, features = ["bytemuck"] } pretty_assertions = { workspace = true, optional = true } serde = { workspace = true } diff --git a/crates/cubecl-attention/src/base.rs b/crates/cubecl-attention/src/base.rs index 90a5a3b2c..bf7f8e7b5 100644 --- a/crates/cubecl-attention/src/base.rs +++ b/crates/cubecl-attention/src/base.rs @@ -9,12 +9,11 @@ use crate::{ AttentionTilingScheme, AvailableLineSizes, args::TensorInputsLaunch, attention_types::*, batch::HypercubeSelection, }, - kernels::{Algorithm, dummy::DummyAlgorithm}, + kernels::{Algorithm, dummy::DummyRegisterAlgorithm}, }; use crate::components::batch::BatchAttentionConfig; use crate::components::batch::BatchAttentionFamily; -use cubecl_core::frontend::CubePrimitive; pub enum Strategy { /// Temporary implementation @@ -24,10 +23,11 @@ pub enum Strategy { #[allow(clippy::result_large_err)] pub fn launch( strategy: &Strategy, - client: &ComputeClient, + client: &ComputeClient, query: TensorHandle>, key: TensorHandle>, value: TensorHandle>, + mask: Option>>, out: TensorHandle>, ) -> Result<(), AttentionSetupError> { launch_ref::( @@ -36,6 +36,7 @@ pub fn launch( &query.as_ref(), &key.as_ref(), &value.as_ref(), + &mask.as_ref().map(|m| m.as_ref()), &out.as_ref(), ) } @@ -43,30 +44,32 @@ pub fn launch( #[allow(clippy::result_large_err)] pub fn launch_ref( strategy: &Strategy, - client: &ComputeClient, + client: &ComputeClient, query: &TensorHandleRef, key: &TensorHandleRef, value: &TensorHandleRef, + mask: &Option>, out: &TensorHandleRef, ) -> Result<(), AttentionSetupError> { match strategy { - Strategy::Tmp => launch_tmp::(client, query, key, value, out), + Strategy::Tmp => launch_tmp::(client, query, key, value, mask, out), } } pub fn launch_tmp( - client: &ComputeClient, + client: &ComputeClient, query: &TensorHandleRef, key: &TensorHandleRef, value: &TensorHandleRef, + mask: &Option>, out: &TensorHandleRef, ) -> Result<(), AttentionSetupError> { let line_sizes = AvailableLineSizes::from_elem_types::( - &QG::::as_type_native_unchecked(), - &MSK::::as_type_native_unchecked(), - &OG::::as_type_native_unchecked(), + query.elem_size, + size_of::>(), + out.elem_size, ); - let line_sizes = DummyAlgorithm::filter_line_sizes(line_sizes) + let line_sizes = DummyRegisterAlgorithm::filter_line_sizes(line_sizes) .filter_with_tensor(AttentionIdent::Query, query.strides, query.shape) .filter_with_tensor(AttentionIdent::Key, key.strides, key.shape) 
.filter_with_tensor(AttentionIdent::Value, value.strides, value.shape) @@ -81,7 +84,8 @@ pub fn launch_tmp( num_heads: query.shape[2], head_dim: query.shape[3], val_dim: value.shape[3], - masked: false, + masked: mask.is_some(), + causal: false, }; let tile_size = AttentionTileSize { @@ -105,16 +109,17 @@ pub fn launch_tmp( }, plane_dim: 32, reuse_key_value: false, + two_rows_in_array_tile: false, }; - let config = DummyAlgorithm::setup::(client, &problem, &selection, &line_sizes)?; + let config = DummyRegisterAlgorithm::setup::(client, &problem, &selection, &line_sizes)?; let cube_count_plan = config .hypercube_config() .cube_count_plan(&problem, &selection); unsafe { - ::BatchAttention::launch_unchecked::( + ::BatchAttention::launch_unchecked::( client, config.cube_dim(), cube_count_plan.resolve(), @@ -122,6 +127,9 @@ pub fn launch_tmp( query.as_tensor_arg(line_sizes.query), key.as_tensor_arg(line_sizes.key), value.as_tensor_arg(line_sizes.value), + mask.as_ref() + .map(|it| it.as_tensor_arg(line_sizes.out)) + .into(), ), out.as_tensor_arg(line_sizes.out), cube_count_plan.as_args(), diff --git a/crates/cubecl-attention/src/components/args.rs b/crates/cubecl-attention/src/components/args.rs index 822018608..982ef2792 100644 --- a/crates/cubecl-attention/src/components/args.rs +++ b/crates/cubecl-attention/src/components/args.rs @@ -1,7 +1,7 @@ use cubecl::prelude::*; use cubecl_core::{self as cubecl}; use cubecl_std::{ - CubeOption, CubeOptionExpand, + CubeOption, CubeOptionArgs, CubeOptionExpand, tensor::r#virtual::{VirtualTensorOperations, VirtualTensorOperationsExpand}, }; @@ -16,7 +16,7 @@ pub trait ConcreteInputsFactory: LaunchArg { query: &'a TensorHandleRef<'a, R>, key: &'a TensorHandleRef<'a, R>, value: &'a TensorHandleRef<'a, R>, - // mask: &'a TensorHandleRef<'a, R>, + mask: &'a Option>, selection: &AttentionSelection, problem: &AttentionProblem, line_sizes: &AttentionLineSizes, @@ -38,206 +38,278 @@ pub trait ConcreteOutputFactory: LaunchArg { /// Arguments for the attention algorithm. pub trait AttentionArgs: Send + Sync + 'static + Clone { /// Type used for the input. - type Input: LaunchArg + CubeType; + type Input: LaunchArg + CubeType; /// Type used for the output. type Output: LaunchArg + CubeType; /// Inner state that is used to create [tensor inputs](TensorInput) and /// [tensor outputs](TensorOutput) . - type State: CubeType; + type State: CubeType; /// Init the state. - fn init_state( - input: &Self::Input, + fn init_state( + input: &Self::Input, output: &mut Self::Output, - ) -> Self::State; + ) -> Self::State; + + /// Whether the mask argument is present. Returns `CubeOption` to allow matching at + /// comptime + fn has_mask( + state: &Self::State, + ) -> CubeOption<()>; /// Read the line of the query tensor using the state at the given coordinate. - fn read_query( - state: &Self::State, + fn read_query( + state: &Self::State, coordinate: u32, ) -> Line; /// Read the line of the key tensor using the state at the given coordinate. - fn read_key( - state: &Self::State, + fn read_key( + state: &Self::State, coordinate: u32, ) -> Line; /// Read the line of the value tensor using the state at the given coordinate. - fn read_value( - state: &Self::State, + fn read_value( + state: &Self::State, coordinate: u32, ) -> Line; + /// Read the line of the mask tensor using the state at the given coordinate. + fn read_mask( + state: &Self::State, + coordinate: u32, + ) -> Line; /// Read the line of the query tensor using the state at the given coordinate. 
- fn read_window_query( - state: &Self::State, + fn read_window_query( + state: &Self::State, start: u32, end: u32, ) -> Slice>; - /// Read the line of the key tensor using the state at the given coordinate. - fn read_window_key( - state: &Self::State, + fn read_window_key( + state: &Self::State, start: u32, end: u32, ) -> Slice>; - /// Read the line of the value tensor using the state at the given coordinate. - fn read_window_value( - state: &Self::State, + fn read_window_value( + state: &Self::State, start: u32, end: u32, ) -> Slice>; + /// Read the line of the mask tensor using the state at the given coordinate. + fn read_window_mask( + state: &Self::State, + start: u32, + end: u32, + ) -> Slice>; /// Reinterpret query as tensor map - fn as_tensor_map_query( - state: &Self::State, + fn as_tensor_map_query( + state: &Self::State, ) -> CubeOption>; - /// Reinterpret key as tensor map - fn as_tensor_map_key( - state: &Self::State, + fn as_tensor_map_key( + state: &Self::State, ) -> CubeOption>; - /// Reinterpret value as tensor map - fn as_tensor_map_value( - state: &Self::State, + fn as_tensor_map_value( + state: &Self::State, ) -> CubeOption>; + /// Reinterpret mask as tensor map + fn as_tensor_map_mask( + state: &Self::State, + ) -> CubeOption>; /// Write the line to the output at the given coordinate using the state. - fn write_out( - state: &mut Self::State, + fn write_out( + state: &mut Self::State, coordinate: u32, value: Line, ); /// Get the rank of the query tensor using the state. - fn rank_query(state: &Self::State) -> u32; + fn rank_query( + state: &Self::State, + ) -> u32; /// Get the rank of the key tensor using the state. - fn rank_key(state: &Self::State) -> u32; + fn rank_key( + state: &Self::State, + ) -> u32; /// Get the rank of the value tensor using the state. - fn rank_value(state: &Self::State) -> u32; + fn rank_value( + state: &Self::State, + ) -> u32; + /// Get the rank of the mask tensor using the state. + fn rank_mask( + state: &Self::State, + ) -> u32; /// Get the rank of the out tensor using the state. - fn rank_out(state: &Self::State) -> u32; + fn rank_out( + state: &Self::State, + ) -> u32; /// Get the length of the query tensor using the state. - fn len_query(state: &Self::State) -> u32; + fn len_query( + state: &Self::State, + ) -> u32; /// Get the length of the key tensor using the state. - fn len_key(state: &Self::State) -> u32; + fn len_key( + state: &Self::State, + ) -> u32; /// Get the length of the value tensor using the state. - fn len_value(state: &Self::State) -> u32; + fn len_value( + state: &Self::State, + ) -> u32; + /// Get the length of the mask tensor using the state. + fn len_mask( + state: &Self::State, + ) -> u32; /// Get the length of the out tensor using the state. - fn len_out(state: &Self::State) -> u32; + fn len_out( + state: &Self::State, + ) -> u32; /// Get the buffer length of the query tensor using the state. - fn buffer_len_query( - state: &Self::State, + fn buffer_len_query( + state: &Self::State, ) -> u32; /// Get the buffer length of the key tensor using the state. - fn buffer_len_key( - state: &Self::State, + fn buffer_len_key( + state: &Self::State, ) -> u32; /// Get the buffer length of the value tensor using the state. - fn buffer_len_value( - state: &Self::State, + fn buffer_len_value( + state: &Self::State, + ) -> u32; + /// Get the buffer length of the mask tensor using the state. + fn buffer_len_mask( + state: &Self::State, ) -> u32; /// Get the buffer length of the out tensor using the state. 
- fn buffer_len_out( - state: &Self::State, + fn buffer_len_out( + state: &Self::State, ) -> u32; /// Get the shape of the query tensor using the state. - fn shape_query( - state: &Self::State, + fn shape_query( + state: &Self::State, axis: u32, ) -> u32; /// Get the shape of the key tensor using the state. - fn shape_key( - state: &Self::State, + fn shape_key( + state: &Self::State, axis: u32, ) -> u32; /// Get the shape of the value tensor using the state. - fn shape_value( - state: &Self::State, + fn shape_value( + state: &Self::State, + axis: u32, + ) -> u32; + /// Get the shape of the mask tensor using the state. + fn shape_mask( + state: &Self::State, axis: u32, ) -> u32; /// Get the shape of the out tensor using the state. - fn shape_out( - state: &Self::State, + fn shape_out( + state: &Self::State, axis: u32, ) -> u32; /// Get the stride of the query tensor using the state. - fn stride_query( - state: &Self::State, + fn stride_query( + state: &Self::State, axis: u32, ) -> u32; /// Get the stride of the key tensor using the state. - fn stride_key( - state: &Self::State, + fn stride_key( + state: &Self::State, axis: u32, ) -> u32; /// Get the stride of the value tensor using the state. - fn stride_value( - state: &Self::State, + fn stride_value( + state: &Self::State, + axis: u32, + ) -> u32; + /// Get the stride of the mask tensor using the state. + fn stride_mask( + state: &Self::State, axis: u32, ) -> u32; /// Get the stride of the out tensor using the state. - fn stride_out( - state: &Self::State, + fn stride_out( + state: &Self::State, axis: u32, ) -> u32; - fn line_size_query( - state: &Self::State, + /// Get the line size of the query tensor using the state. + fn line_size_query( + state: &Self::State, + ) -> comptime_type!(u32); + /// Get the line size of the key tensor using the state. + fn line_size_key( + state: &Self::State, ) -> comptime_type!(u32); - fn line_size_key( - state: &Self::State, + /// Get the line size of the value tensor using the state. + fn line_size_value( + state: &Self::State, ) -> comptime_type!(u32); - fn line_size_value( - state: &Self::State, + /// Get the line size of the mask tensor using the state. + fn line_size_mask( + state: &Self::State, ) -> comptime_type!(u32); - fn line_size_out( - state: &Self::State, + /// Get the line size of the out tensor using the state. + fn line_size_out( + state: &Self::State, ) -> comptime_type!(u32); } /// Tensor input representation. /// /// You can use the tensor input as if it was a pointer to the actually tensor. 
-pub struct TensorQuery { - state: *const GA::State, +pub struct TensorQuery { + state: *const GA::State, } -pub struct TensorKey { - state: *const GA::State, +pub struct TensorKey { + state: *const GA::State, } -pub struct TensorValue { - state: *const GA::State, +pub struct TensorValue { + state: *const GA::State, } -impl VirtualTensorOperations - for TensorQuery +pub struct TensorMask { + state: *const GA::State, +} + +impl + VirtualTensorOperations for TensorQuery { } -impl VirtualTensorOperations - for TensorKey +impl + VirtualTensorOperations for TensorKey { } -impl VirtualTensorOperations - for TensorValue +impl + VirtualTensorOperations for TensorValue { } -impl VirtualTensorOperations - for TensorOutput +impl + VirtualTensorOperations for TensorMask { } -impl VirtualTensorOperationsExpand - for TensorOutputExpand +impl + VirtualTensorOperations for TensorOutput +{ +} + +impl + VirtualTensorOperationsExpand for TensorOutputExpand { fn __expand_read_method( &self, @@ -298,12 +370,12 @@ impl VirtualTensorOpe } } -impl Lined - for TensorOutput +impl Lined + for TensorOutput { } -impl LinedExpand - for TensorOutputExpand +impl LinedExpand + for TensorOutputExpand { fn line_size(&self) -> u32 { let mut scope = Scope::root(false); @@ -311,8 +383,8 @@ impl LinedExpand } } -impl VirtualTensorOperationsExpand - for TensorQueryExpand +impl + VirtualTensorOperationsExpand for TensorQueryExpand { fn __expand_read_method( &self, @@ -372,12 +444,12 @@ impl VirtualTensorOpe } } -impl Lined - for TensorQuery +impl Lined + for TensorQuery { } -impl LinedExpand - for TensorQueryExpand +impl LinedExpand + for TensorQueryExpand { fn line_size(&self) -> u32 { let mut scope = Scope::root(false); @@ -385,8 +457,8 @@ impl LinedExpand } } -impl VirtualTensorOperationsExpand - for TensorKeyExpand +impl + VirtualTensorOperationsExpand for TensorKeyExpand { fn __expand_read_method( &self, @@ -446,12 +518,12 @@ impl VirtualTensorOpe } } -impl Lined - for TensorKey +impl Lined + for TensorKey { } -impl LinedExpand - for TensorKeyExpand +impl LinedExpand + for TensorKeyExpand { fn line_size(&self) -> u32 { let mut scope = Scope::root(false); @@ -459,8 +531,8 @@ impl LinedExpand } } -impl VirtualTensorOperationsExpand - for TensorValueExpand +impl + VirtualTensorOperationsExpand for TensorValueExpand { fn __expand_read_method( &self, @@ -520,12 +592,12 @@ impl VirtualTensorOpe } } -impl Lined - for TensorValue +impl Lined + for TensorValue { } -impl LinedExpand - for TensorValueExpand +impl LinedExpand + for TensorValueExpand { fn line_size(&self) -> u32 { let mut scope = Scope::root(false); @@ -533,40 +605,124 @@ impl LinedExpand } } +impl + VirtualTensorOperationsExpand for TensorMaskExpand +{ + fn __expand_read_method( + &self, + scope: &mut Scope, + index: ExpandElementTyped, + ) -> ExpandElementTyped> { + TensorMaskExpand::__expand_read_method(self.clone(), scope, index) + } + fn __expand_read_window_method( + &self, + context: &mut Scope, + start: ExpandElementTyped, + end: ExpandElementTyped, + ) -> SliceExpand, ReadOnly> { + TensorMaskExpand::__expand_read_window_method(self.clone(), context, start, end) + } + + fn __expand_write_method( + &self, + _scope: &mut Scope, + _index: ExpandElementTyped, + _value: ExpandElementTyped>, + ) { + panic!("Can't write to input tensor"); + } + + fn __expand_shape_method( + &self, + scope: &mut Scope, + axis: ExpandElementTyped, + ) -> ExpandElementTyped { + TensorMaskExpand::__expand_shape_method(self.clone(), scope, axis) + } + + fn __expand_stride_method( + &self, + 
scope: &mut Scope, + axis: ExpandElementTyped, + ) -> ExpandElementTyped { + TensorMaskExpand::__expand_stride_method(self.clone(), scope, axis) + } + + fn __expand_rank_method(&self, scope: &mut Scope) -> ExpandElementTyped { + TensorMaskExpand::__expand_rank_method(self.clone(), scope) + } + + fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped { + TensorMaskExpand::__expand_len_method(self.clone(), scope) + } + + fn __expand_buffer_len_method(&self, scope: &mut Scope) -> ExpandElementTyped { + TensorMaskExpand::__expand_buffer_len_method(self.clone(), scope) + } + + fn __expand_as_tensor_map_method(&self, scope: &mut Scope) -> CubeOptionExpand> { + TensorMaskExpand::__expand_as_tensor_map_method(self.clone(), scope) + } +} + +impl Lined + for TensorMask +{ +} +impl LinedExpand + for TensorMaskExpand +{ + fn line_size(&self) -> u32 { + let mut scope = Scope::root(false); + TensorMaskExpand::__expand_line_size_method(self.clone(), &mut scope) + } +} + /// Tensor output representation. /// /// You can use the tensor output as if it was a pointer to the actually tensor. /// /// # Warning +/// # Warning /// /// There is no mutability guarantee. -pub struct TensorOutput { - state: *mut GA::State, +pub struct TensorOutput { + state: *mut GA::State, } /// Expand type for [tensor input](TensorInput). -pub struct TensorQueryExpand { - state: as CubeType>::ExpandType, +pub struct TensorQueryExpand +{ + state: as CubeType>::ExpandType, } -pub struct TensorKeyExpand { - state: as CubeType>::ExpandType, +pub struct TensorKeyExpand { + state: as CubeType>::ExpandType, } -pub struct TensorValueExpand { - state: as CubeType>::ExpandType, +pub struct TensorValueExpand +{ + state: as CubeType>::ExpandType, +} + +pub struct TensorMaskExpand { + state: as CubeType>::ExpandType, } /// Expand type for [tensor output](TensorOutput). -pub struct TensorOutputExpand { - state: as CubeType>::ExpandType, +pub struct TensorOutputExpand +{ + state: as CubeType>::ExpandType, } #[cube] -impl TensorQuery { +impl + TensorQuery +{ /// Create a [tensor input](TensorInput) from the state and the [ident](TensorInputIdent). - pub fn new(state: &MA::State) -> TensorQuery { - TensorQuery:: { state } + pub fn new(state: &MA::State) -> TensorQuery { + TensorQuery:: { state } } //// Read the tensor at the given coordinate. @@ -617,10 +773,12 @@ impl TensorQuery TensorKey { +impl + TensorKey +{ /// Create a [tensor input](TensorInput) from the state and the [ident](TensorInputIdent). - pub fn new(state: &MA::State) -> TensorKey { - TensorKey:: { state } + pub fn new(state: &MA::State) -> TensorKey { + TensorKey:: { state } } //// Read the tensor at the given coordinate. @@ -671,10 +829,12 @@ impl TensorKey TensorValue { +impl + TensorValue +{ /// Create a [tensor input](TensorInput) from the state and the [ident](TensorInputIdent). - pub fn new(state: &MA::State) -> TensorValue { - TensorValue:: { state } + pub fn new(state: &MA::State) -> TensorValue { + TensorValue:: { state } } //// Read the tensor at the given coordinate. @@ -725,10 +885,68 @@ impl TensorValue TensorOutput { +impl + TensorMask +{ + /// Create a [tensor input](TensorInput) from the state and the [ident](TensorInputIdent). + pub fn new(state: &MA::State) -> TensorMask { + TensorMask:: { state } + } + + //// Read the tensor at the given coordinate. + pub fn read_window(&self, start: u32, end: u32) -> Slice> { + unsafe { MA::read_window_mask(&(*self.state), start, end) } + } + + /// Read the tensor at the given coordinate. 
+ pub fn read(&self, coordinate: u32) -> Line { + unsafe { MA::read_mask(&(*self.state), coordinate) } + } + + /// Get the shape of the tensor at the given axis. + pub fn shape(&self, axis: u32) -> u32 { + unsafe { MA::shape_mask(&(*self.state), axis) } + } + + /// Get the stride of the tensor at the given axis. + pub fn stride(&self, axis: u32) -> u32 { + unsafe { MA::stride_mask(&(*self.state), axis) } + } + + /// Get the rank of the tensor. + pub fn rank(&self) -> u32 { + unsafe { MA::rank_mask(&(*self.state)) } + } + + /// Get the length of the tensor. + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> u32 { + unsafe { MA::len_mask(&(*self.state)) } + } + + /// Get the buffer length of the tensor. + pub fn buffer_len(&self) -> u32 { + unsafe { MA::buffer_len_mask(&(*self.state)) } + } + + /// Get the buffer length of the tensor. + pub fn as_tensor_map(&self) -> CubeOption> { + unsafe { MA::as_tensor_map_mask(&(*self.state)) } + } + + /// Get the line size of the tensor. + pub fn line_size(&self) -> comptime_type!(u32) { + unsafe { MA::line_size_mask(&(*self.state)) } + } +} + +#[cube] +impl + TensorOutput +{ /// Create a [tensor output](TensorOutput) from the state. - pub fn new(state: &mut GA::State) -> TensorOutput { - TensorOutput:: { state } + pub fn new(state: &mut GA::State) -> TensorOutput { + TensorOutput:: { state } } /// Write the value to tensor at the given coordinate. @@ -776,18 +994,19 @@ pub struct TensorArgs; #[derive(CubeLaunch, CubeType)] /// Input representation for [TensorArgs] implementing [AttentionArgs]. -pub struct TensorInputs { +pub struct TensorInputs { pub query: Tensor>, pub key: Tensor>, pub value: Tensor>, - // pub mask: CubeOption>>, + pub mask: CubeOption>>, } -impl ConcreteInputsFactory for TensorInputs { +impl ConcreteInputsFactory for TensorInputs { fn create<'a, R: Runtime>( query: &'a TensorHandleRef<'a, R>, key: &'a TensorHandleRef<'a, R>, value: &'a TensorHandleRef<'a, R>, + mask: &'a Option>, _selection: &AttentionSelection, _problem: &AttentionProblem, line_sizes: &AttentionLineSizes, @@ -796,7 +1015,10 @@ impl ConcreteInputsFactory for TensorInputs CubeOptionArgs::Some(mask.as_tensor_arg(line_sizes.mask)), + None => CubeOptionArgs::None, + }, ) } } @@ -813,234 +1035,328 @@ impl ConcreteOutputFactory for Tensor> { } #[derive(CubeType)] -pub struct AttentionState { +pub struct AttentionState { pub query: *const Tensor>, pub key: *const Tensor>, pub value: *const Tensor>, + pub mask: CubeOption<*const Tensor>>, pub output: *mut Tensor>, } #[cube] impl AttentionArgs for TensorArgs { - type Input = TensorInputs; + type Input = TensorInputs; type Output = Tensor>; - type State = AttentionState; + type State = AttentionState; - fn init_state( - input: &Self::Input, + fn init_state( + input: &Self::Input, output: &mut Self::Output, - ) -> Self::State { - AttentionState:: { + ) -> Self::State { + let mask = match &input.mask { + CubeOption::None => CubeOption::new_None(), + CubeOption::Some(mask) => { + let ptr: *const Tensor> = mask; + CubeOption::new_Some(ptr) + } + }; + + AttentionState:: { query: &input.query, key: &input.key, value: &input.value, + mask, output, } } - fn read_query( - state: &Self::State, + fn has_mask( + state: &Self::State, + ) -> CubeOption<()> { + match state.mask { + CubeOption::None => CubeOption::new_None(), + CubeOption::Some(_) => CubeOption::new_Some(()), + } + } + + fn read_query( + state: &Self::State, coordinate: u32, ) -> Line { unsafe { (*state.query)[coordinate] } } - fn read_key( - state: 
&Self::State, + fn read_key( + state: &Self::State, coordinate: u32, ) -> Line { unsafe { (*state.key)[coordinate] } } - fn read_value( - state: &Self::State, + fn read_value( + state: &Self::State, coordinate: u32, ) -> Line { unsafe { (*state.value)[coordinate] } } - fn read_window_query( - state: &Self::State, + fn read_mask( + state: &Self::State, + coordinate: u32, + ) -> Line { + unsafe { (*state.mask.unwrap())[coordinate] } + } + + fn read_window_query( + state: &Self::State, start: u32, end: u32, ) -> Slice> { unsafe { (*state.query).slice(start, end) } } - fn read_window_key( - state: &Self::State, + fn read_window_key( + state: &Self::State, start: u32, end: u32, ) -> Slice> { unsafe { (*state.key).slice(start, end) } } - fn read_window_value( - state: &Self::State, + fn read_window_value( + state: &Self::State, start: u32, end: u32, ) -> Slice> { unsafe { (*state.value).slice(start, end) } } - fn as_tensor_map_query( - _state: &Self::State, + fn read_window_mask( + state: &Self::State, + start: u32, + end: u32, + ) -> Slice> { + unsafe { (*state.mask.unwrap()).slice(start, end) } + } + + fn as_tensor_map_query( + _state: &Self::State, ) -> CubeOption> { CubeOption::new_None() } - fn as_tensor_map_key( - _state: &Self::State, + fn as_tensor_map_key( + _state: &Self::State, ) -> CubeOption> { CubeOption::new_None() } - fn as_tensor_map_value( - _state: &Self::State, + fn as_tensor_map_value( + _state: &Self::State, ) -> CubeOption> { CubeOption::new_None() } - fn shape_query( - state: &Self::State, + fn as_tensor_map_mask( + _state: &Self::State, + ) -> CubeOption> { + CubeOption::new_None() + } + + fn shape_query( + state: &Self::State, dim: u32, ) -> u32 { unsafe { (*state.query).shape(dim) } } - fn shape_key( - state: &Self::State, + fn shape_key( + state: &Self::State, dim: u32, ) -> u32 { unsafe { (*state.key).shape(dim) } } - fn shape_value( - state: &Self::State, + fn shape_value( + state: &Self::State, dim: u32, ) -> u32 { unsafe { (*state.value).shape(dim) } } - fn shape_out( - state: &Self::State, + fn shape_mask( + state: &Self::State, + dim: u32, + ) -> u32 { + unsafe { (*state.mask.unwrap()).shape(dim) } + } + + fn shape_out( + state: &Self::State, dim: u32, ) -> u32 { unsafe { (*state.output).shape(dim) } } - fn stride_query( - state: &Self::State, + fn stride_query( + state: &Self::State, dim: u32, ) -> u32 { unsafe { (*state.query).stride(dim) } } - fn stride_key( - state: &Self::State, + fn stride_key( + state: &Self::State, dim: u32, ) -> u32 { unsafe { (*state.key).stride(dim) } } - fn stride_value( - state: &Self::State, + fn stride_value( + state: &Self::State, dim: u32, ) -> u32 { unsafe { (*state.value).stride(dim) } } - fn stride_out( - state: &Self::State, + fn stride_mask( + state: &Self::State, + dim: u32, + ) -> u32 { + unsafe { (*state.mask.unwrap()).stride(dim) } + } + + fn stride_out( + state: &Self::State, dim: u32, ) -> u32 { unsafe { (*state.output).stride(dim) } } - fn write_out( - state: &mut Self::State, + fn write_out( + state: &mut Self::State, coordinate: u32, value: Line, ) { unsafe { (*state.output)[coordinate] = value } } - fn rank_query(state: &Self::State) -> u32 { + fn rank_query( + state: &Self::State, + ) -> u32 { unsafe { (*state.query).rank() } } - fn rank_key(state: &Self::State) -> u32 { + fn rank_key( + state: &Self::State, + ) -> u32 { unsafe { (*state.key).rank() } } - fn rank_value(state: &Self::State) -> u32 { + fn rank_value( + state: &Self::State, + ) -> u32 { unsafe { (*state.value).rank() } } - fn rank_out(state: 
&Self::State) -> u32 { + fn rank_mask( + state: &Self::State, + ) -> u32 { + unsafe { (*state.mask.unwrap()).rank() } + } + + fn rank_out( + state: &Self::State, + ) -> u32 { unsafe { (*state.output).rank() } } - fn len_query(state: &Self::State) -> u32 { + fn len_query( + state: &Self::State, + ) -> u32 { unsafe { (*state.query).len() } } - fn len_key(state: &Self::State) -> u32 { + fn len_key( + state: &Self::State, + ) -> u32 { unsafe { (*state.key).len() } } - fn len_value(state: &Self::State) -> u32 { + fn len_value( + state: &Self::State, + ) -> u32 { unsafe { (*state.value).len() } } - fn len_out(state: &Self::State) -> u32 { + fn len_mask( + state: &Self::State, + ) -> u32 { + unsafe { (*state.mask.unwrap()).len() } + } + + fn len_out( + state: &Self::State, + ) -> u32 { unsafe { (*state.output).len() } } - fn buffer_len_query( - state: &Self::State, + fn buffer_len_query( + state: &Self::State, ) -> u32 { unsafe { (*state.query).buffer_len() } } - fn buffer_len_key( - state: &Self::State, + fn buffer_len_key( + state: &Self::State, ) -> u32 { unsafe { (*state.key).buffer_len() } } - fn buffer_len_value( - state: &Self::State, + fn buffer_len_value( + state: &Self::State, ) -> u32 { unsafe { (*state.value).buffer_len() } } - fn buffer_len_out( - state: &Self::State, + fn buffer_len_mask( + state: &Self::State, + ) -> u32 { + unsafe { (*state.mask.unwrap()).buffer_len() } + } + + fn buffer_len_out( + state: &Self::State, ) -> u32 { unsafe { (*state.output).buffer_len() } } - fn line_size_query( - state: &Self::State, + fn line_size_query( + state: &Self::State, ) -> comptime_type!(u32) { unsafe { (*state.query).line_size() } } - fn line_size_key( - state: &Self::State, + fn line_size_key( + state: &Self::State, ) -> comptime_type!(u32) { unsafe { (*state.key).line_size() } } - fn line_size_value( - state: &Self::State, + fn line_size_value( + state: &Self::State, ) -> comptime_type!(u32) { unsafe { (*state.value).line_size() } } - fn line_size_out( - state: &Self::State, + fn line_size_mask( + state: &Self::State, + ) -> comptime_type!(u32) { + unsafe { (*state.mask.unwrap()).line_size() } + } + + fn line_size_out( + state: &Self::State, ) -> comptime_type!(u32) { unsafe { (*state.output).line_size() } } @@ -1049,14 +1365,14 @@ impl AttentionArgs for TensorArgs { mod __query { use super::*; - impl CubeType - for TensorQuery + impl CubeType + for TensorQuery { - type ExpandType = TensorQueryExpand; + type ExpandType = TensorQueryExpand; } - impl Clone - for TensorQueryExpand + impl Clone + for TensorQueryExpand { fn clone(&self) -> Self { Self { @@ -1065,30 +1381,30 @@ mod __query { } } - impl IntoMut - for TensorQueryExpand + impl IntoMut + for TensorQueryExpand { fn into_mut(mut self, scope: &mut Scope) -> Self { self.state = self.state.into_mut(scope); self } } - impl CubeDebug - for TensorQueryExpand + impl CubeDebug + for TensorQueryExpand { fn set_debug_name(&self, scope: &mut Scope, name: &'static str) { self.state.set_debug_name(scope, name); } } - impl Clone - for TensorQuery + impl Clone + for TensorQuery { fn clone(&self) -> Self { *self } } - impl Copy - for TensorQuery + impl Copy + for TensorQuery { } } @@ -1096,14 +1412,14 @@ mod __query { mod __key { use super::*; - impl CubeType - for TensorKey + impl CubeType + for TensorKey { - type ExpandType = TensorKeyExpand; + type ExpandType = TensorKeyExpand; } - impl Clone - for TensorKeyExpand + impl Clone + for TensorKeyExpand { fn clone(&self) -> Self { Self { @@ -1112,42 +1428,92 @@ mod __key { } } - impl IntoMut - for 
TensorKeyExpand + impl IntoMut + for TensorKeyExpand { fn into_mut(mut self, scope: &mut Scope) -> Self { self.state = self.state.into_mut(scope); self } } - impl CubeDebug - for TensorKeyExpand + impl CubeDebug + for TensorKeyExpand { fn set_debug_name(&self, scope: &mut Scope, name: &'static str) { self.state.set_debug_name(scope, name); } } - impl Clone - for TensorKey + impl Clone + for TensorKey { fn clone(&self) -> Self { *self } } - impl Copy for TensorKey {} + impl Copy + for TensorKey + { + } } mod __value { use super::*; - impl CubeType - for TensorValue + impl CubeType + for TensorValue + { + type ExpandType = TensorValueExpand; + } + + impl Clone + for TensorValueExpand + { + fn clone(&self) -> Self { + Self { + state: self.state.clone(), + } + } + } + + impl IntoMut + for TensorValueExpand + { + fn into_mut(mut self, scope: &mut Scope) -> Self { + self.state = self.state.into_mut(scope); + self + } + } + impl CubeDebug + for TensorValueExpand + { + fn set_debug_name(&self, scope: &mut Scope, name: &'static str) { + self.state.set_debug_name(scope, name); + } + } + impl Clone + for TensorValue + { + fn clone(&self) -> Self { + *self + } + } + impl Copy + for TensorValue + { + } +} + +mod __mask { + use super::*; + + impl CubeType + for TensorMask { - type ExpandType = TensorValueExpand; + type ExpandType = TensorMaskExpand; } - impl Clone - for TensorValueExpand + impl Clone + for TensorMaskExpand { fn clone(&self) -> Self { Self { @@ -1156,30 +1522,30 @@ mod __value { } } - impl IntoMut - for TensorValueExpand + impl IntoMut + for TensorMaskExpand { fn into_mut(mut self, scope: &mut Scope) -> Self { self.state = self.state.into_mut(scope); self } } - impl CubeDebug - for TensorValueExpand + impl CubeDebug + for TensorMaskExpand { fn set_debug_name(&self, scope: &mut Scope, name: &'static str) { self.state.set_debug_name(scope, name); } } - impl Clone - for TensorValue + impl Clone + for TensorMask { fn clone(&self) -> Self { *self } } - impl Copy - for TensorValue + impl Copy + for TensorMask { } } @@ -1187,22 +1553,22 @@ mod __value { mod __output { use super::*; - impl CubeType - for TensorOutput + impl CubeType + for TensorOutput { - type ExpandType = TensorOutputExpand; + type ExpandType = TensorOutputExpand; } - impl Clone - for TensorOutput + impl Clone + for TensorOutput { fn clone(&self) -> Self { *self } } - impl Clone - for TensorOutputExpand + impl Clone + for TensorOutputExpand { fn clone(&self) -> Self { Self { @@ -1211,8 +1577,8 @@ mod __output { } } - impl IntoMut - for TensorOutputExpand + impl IntoMut + for TensorOutputExpand { fn into_mut(mut self, scope: &mut Scope) -> Self { self.state = self.state.into_mut(scope); @@ -1220,16 +1586,16 @@ mod __output { } } - impl CubeDebug - for TensorOutputExpand + impl CubeDebug + for TensorOutputExpand { fn set_debug_name(&self, scope: &mut Scope, name: &'static str) { self.state.set_debug_name(scope, name); } } - impl Copy - for TensorOutput + impl Copy + for TensorOutput { } } diff --git a/crates/cubecl-attention/src/components/batch/base.rs b/crates/cubecl-attention/src/components/batch/base.rs index bba0d34e5..ce12b4d5a 100644 --- a/crates/cubecl-attention/src/components/batch/base.rs +++ b/crates/cubecl-attention/src/components/batch/base.rs @@ -1,6 +1,6 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; -use cubecl_std::tensor::r#virtual::VirtualTensor; +use cubecl_std::{CubeOption, tensor::r#virtual::VirtualTensor}; use crate::components::{ AttentionLineSizes, AttentionPrecision, AttentionProblem, 
AttentionSelection, @@ -25,7 +25,7 @@ pub trait BatchAttentionFamily: Send + Sync + 'static { /// /// Out-of-bounds can happen unsafe fn launch_unchecked<'a, MS: AttentionSpec, R: Runtime>( - client: &ComputeClient<::Server, ::Channel>, + client: &ComputeClient<::Server>, cube_dim: CubeDim, cube_count: CubeCount, input: InputRuntimeArg<'a, MS, R>, @@ -38,7 +38,7 @@ pub trait BatchAttentionFamily: Send + Sync + 'static { /// /// This function may return an error if the configuration cannot be supported on the current runtime. fn setup( - client: &ComputeClient, + client: &ComputeClient, problem: &AttentionProblem, selection: &AttentionSelection, line_sizes: &AttentionLineSizes, @@ -61,6 +61,7 @@ pub trait BatchAttention: 'static + Send + Sync { query: VirtualTensor>, key: VirtualTensor>, value: VirtualTensor>, + mask: CubeOption>>, out: VirtualTensor, ReadWrite>, cube_count_args: CubeCountInput, #[comptime] config: Self::Config, diff --git a/crates/cubecl-attention/src/components/batch/entry_point.rs b/crates/cubecl-attention/src/components/batch/entry_point.rs index 10c1e8e6c..8a7a74d57 100644 --- a/crates/cubecl-attention/src/components/batch/entry_point.rs +++ b/crates/cubecl-attention/src/components/batch/entry_point.rs @@ -1,5 +1,6 @@ use crate::components::args::AttentionArgs; use crate::components::args::TensorKey; +use crate::components::args::TensorMask; use crate::components::args::TensorOutput; use crate::components::args::TensorQuery; use crate::components::args::TensorValue; @@ -9,8 +10,9 @@ use crate::components::batch::base::BatchAttention; use cubecl_core as cubecl; use cubecl_core::prelude::*; use cubecl_std::tensor::r#virtual::VirtualTensor; +use cubecl_std::{CubeOption, CubeOptionExpand}; -type Input = ::Input; +type Input = ::Input; type Output = ::Output; #[cube(launch_unchecked)] @@ -31,27 +33,41 @@ pub(crate) fn attention< OS: Float, BMMF: BatchAttentionFamily, >( - inputs: &Input, + inputs: &Input, output: &mut Output, cube_count_args: CubeCountInput, #[comptime] config: BMMF::Config, ) { let mut state = Args::init_state(inputs, output); - let query = TensorQuery::::new(&state); - let key = TensorKey::::new(&state); - let value = TensorValue::::new(&state); - let mut out = TensorOutput::::new(&mut state); + let query = TensorQuery::::new(&state); + let query = VirtualTensor::::new::>(&query); - let query = VirtualTensor::::new::>(&query); - let key = VirtualTensor::::new::>(&key); - let value = VirtualTensor::::new::>(&value); - let out = VirtualTensor::::new::>(&mut out); + let key = TensorKey::::new(&state); + let key = VirtualTensor::::new::>(&key); + + let value = TensorValue::::new(&state); + let value = VirtualTensor::::new::>(&value); + + let has_mask = Args::has_mask(&state); + let mask: CubeOption> = match has_mask { + CubeOption::Some(_) => { + let mask = TensorMask::::new(&state); + let mask = VirtualTensor::::new::>(&mask); + CubeOption::new_Some(mask) + } + CubeOption::None => CubeOption::new_None(), + }; + + let mut out = TensorOutput::::new(&mut state); + let out = + VirtualTensor::::new::>(&mut out); BMMF::Attention::<(QG, QT, KG, KS, VG, VS, KVT, SM, ACC, MSK, OG, OS)>::execute( query, key, value, + mask, out, cube_count_args, config, diff --git a/crates/cubecl-attention/src/components/batch/hypercube/base.rs b/crates/cubecl-attention/src/components/batch/hypercube/base.rs index dad528387..c6f647af5 100644 --- a/crates/cubecl-attention/src/components/batch/hypercube/base.rs +++ b/crates/cubecl-attention/src/components/batch/hypercube/base.rs @@ -53,8 
+53,6 @@ impl CubeCountPlan { #[derive(CubeType, CubeLaunch)] /// CubeCountPlan stripped of non-essential runtime information -/// -/// This enum is given as runtime input to the matmul pub enum CubeCountInput { Tmp { dummy: u32 }, } diff --git a/crates/cubecl-attention/src/components/batch/mod.rs b/crates/cubecl-attention/src/components/batch/mod.rs index d2d6cf864..02c175c2a 100644 --- a/crates/cubecl-attention/src/components/batch/mod.rs +++ b/crates/cubecl-attention/src/components/batch/mod.rs @@ -1,4 +1,4 @@ -pub mod dummy; +pub mod simple; mod base; mod entry_point; diff --git a/crates/cubecl-attention/src/components/batch/dummy/attention.rs b/crates/cubecl-attention/src/components/batch/simple/attention.rs similarity index 74% rename from crates/cubecl-attention/src/components/batch/dummy/attention.rs rename to crates/cubecl-attention/src/components/batch/simple/attention.rs index 7a4e480ee..cb1f1ec65 100644 --- a/crates/cubecl-attention/src/components/batch/dummy/attention.rs +++ b/crates/cubecl-attention/src/components/batch/simple/attention.rs @@ -1,31 +1,32 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; -use cubecl_std::tensor::r#virtual::VirtualTensor; +use cubecl_std::{CubeOption, tensor::r#virtual::VirtualTensor}; use std::marker::PhantomData; use crate::components::{ AttentionPrecision, attention_types::*, batch::{ - BatchAttention, BatchAttentionConfig, CubeCountInput, dummy::config::DummyBatchConfig, + BatchAttention, BatchAttentionConfig, CubeCountInput, simple::config::SimpleBatchConfig, }, global::{GlobalAttention, GlobalAttentionConfig as _}, }; -pub struct DummyBatchAttention> { +pub struct SimpleBatchAttention> { _phantom: PhantomData<(AP, GA)>, } #[cube] impl, AP: AttentionPrecision> BatchAttention - for DummyBatchAttention + for SimpleBatchAttention { - type Config = DummyBatchConfig; + type Config = SimpleBatchConfig; fn execute( query: VirtualTensor>, key: VirtualTensor>, value: VirtualTensor>, + mask: CubeOption>>, out: VirtualTensor, ReadWrite>, _cube_count_args: CubeCountInput, #[comptime] config: Self::Config, @@ -46,6 +47,7 @@ impl, AP: AttentionPrecision> BatchAttention GA::init_query_reader(q_offset, query, global_config), GA::init_key_reader(key, global_config), GA::init_value_reader(value, global_config), + GA::init_mask_reader(q_offset, mask, seq_kv, global_config), GA::init_writer(q_offset, out, global_config), seq_q, seq_kv, diff --git a/crates/cubecl-attention/src/components/batch/dummy/config.rs b/crates/cubecl-attention/src/components/batch/simple/config.rs similarity index 77% rename from crates/cubecl-attention/src/components/batch/dummy/config.rs rename to crates/cubecl-attention/src/components/batch/simple/config.rs index 47741bdec..418c035ba 100644 --- a/crates/cubecl-attention/src/components/batch/dummy/config.rs +++ b/crates/cubecl-attention/src/components/batch/simple/config.rs @@ -7,13 +7,13 @@ use crate::components::{ }; #[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] -pub struct DummyBatchConfig { +pub struct SimpleBatchConfig { global_config: G, hypercube_config: HypercubeConfig, - seq_k: u32, + seq_kv: u32, } -impl BatchAttentionConfig for DummyBatchConfig { +impl BatchAttentionConfig for SimpleBatchConfig { type GlobalConfig = G; fn global_config(&self) -> Self::GlobalConfig { @@ -29,12 +29,12 @@ impl BatchAttentionConfig for DummyBatchConfig { } } -impl DummyBatchConfig { - pub fn new(global_config: G, hypercube_config: HypercubeConfig, seq_k: u32) -> Self { +impl SimpleBatchConfig { + pub fn new(global_config: G, 
hypercube_config: HypercubeConfig, seq_kv: u32) -> Self { Self { global_config, hypercube_config, - seq_k, + seq_kv, } } diff --git a/crates/cubecl-attention/src/components/batch/simple/mod.rs b/crates/cubecl-attention/src/components/batch/simple/mod.rs new file mode 100644 index 000000000..299efbf2a --- /dev/null +++ b/crates/cubecl-attention/src/components/batch/simple/mod.rs @@ -0,0 +1,6 @@ +mod attention; +mod config; +mod setup; + +pub use attention::*; +pub use setup::SimpleBatchAttentionFamily; diff --git a/crates/cubecl-attention/src/components/batch/dummy/setup.rs b/crates/cubecl-attention/src/components/batch/simple/setup.rs similarity index 77% rename from crates/cubecl-attention/src/components/batch/dummy/setup.rs rename to crates/cubecl-attention/src/components/batch/simple/setup.rs index 220d31e94..35c342d4a 100644 --- a/crates/cubecl-attention/src/components/batch/dummy/setup.rs +++ b/crates/cubecl-attention/src/components/batch/simple/setup.rs @@ -7,29 +7,29 @@ use crate::components::{ attention_types::*, batch::{ BatchAttentionFamily, - dummy::{DummyBatchAttention, config::DummyBatchConfig}, entry_point::attention, + simple::{SimpleBatchAttention, config::SimpleBatchConfig}, }, global::GlobalAttentionFamily, }; -pub struct DummyBatchAttentionFamily { +pub struct SimpleBatchAttentionFamily { _phantom: PhantomData, } -impl BatchAttentionFamily for DummyBatchAttentionFamily { - type Attention = DummyBatchAttention>; - type Config = DummyBatchConfig; +impl BatchAttentionFamily for SimpleBatchAttentionFamily { + type Attention = SimpleBatchAttention>; + type Config = SimpleBatchConfig; fn setup( - client: &ComputeClient, + client: &ComputeClient, problem: &AttentionProblem, selection: &AttentionSelection, line_sizes: &AttentionLineSizes, ) -> Result { let global_config = GA::setup::(client, problem, selection, line_sizes)?; - DummyBatchConfig::new( + SimpleBatchConfig::new( global_config, selection .hypercube_selection @@ -44,10 +44,7 @@ impl BatchAttentionFamily for DummyBatchAttentionFami AS: crate::components::AttentionSpec, R: cubecl_core::Runtime, >( - client: &cubecl_core::prelude::ComputeClient< - ::Server, - ::Channel, - >, + client: &cubecl_core::prelude::ComputeClient<::Server>, cube_dim: cubecl_core::CubeDim, cube_count: cubecl_core::CubeCount, input: crate::components::InputRuntimeArg<'a, AS, R>, diff --git a/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/accelerated/matmul.rs b/crates/cubecl-attention/src/components/fragment/accelerated/attention.rs similarity index 58% rename from crates/cubecl-attention/src/components/tile/dummy/attention_matmul/accelerated/matmul.rs rename to crates/cubecl-attention/src/components/fragment/accelerated/attention.rs index 35a965b02..8c77c61d8 100644 --- a/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/accelerated/matmul.rs +++ b/crates/cubecl-attention/src/components/fragment/accelerated/attention.rs @@ -4,19 +4,75 @@ use cubecl_matmul::components::tile::StridedTile; use crate::components::AttentionPrecision; use crate::components::attention_types::*; -use crate::components::tile::dummy::accelerated::AcceleratedAttentionMatmulConfig; -use crate::components::tile::dummy::{AttentionMatmul, AttentionMatmulConfig as _}; +use crate::components::fragment::accelerated::AcceleratedFragmentAttentionConfig; +use crate::components::fragment::{FragmentAttention, FragmentAttentionConfig as _}; +use crate::components::fragment::{FragmentLayout, FragmentLayoutExpand}; +use 
crate::components::fragment::{FragmentMask, FragmentMaskExpand}; +use crate::components::fragment::{FragmentOps, FragmentOpsExpand}; +use crate::components::tile::RowWise; +use cubecl_std::tensor::layout::Coords2d; /// Performs two matmuls with fragment reuse for key/value and score/prob -pub struct AcceleratedAttentionMatmul; +pub struct AcceleratedFragmentAttention; + +#[derive(CubeType)] +pub struct TODO; #[cube] -impl AttentionMatmul for AcceleratedAttentionMatmul { - type Config = AcceleratedAttentionMatmulConfig; +impl FragmentLayout for TODO { + fn absolute_pos(&self, _local_pos: Coords2d) -> Coords2d { + todo!() + } + fn num_units_per_row(&self) -> comptime_type!(u32) { + todo!() + } +} + +#[cube] +impl FragmentOps for cmma::Matrix { + type Layout = TODO; + + fn rowwise_max(&self) -> RowWise { + todo!() + } + + fn rowwise_sum(&self) -> RowWise { + todo!() + } + + fn rowwise_scale(&mut self, _val: &RowWise) { + todo!() + } + + fn scale_and_mask(_this: &mut Self, _scale: E, _mask: &M) { + todo!() + } + + fn exp_diff(&mut self, _val: &RowWise) { + todo!() + } + + fn layout(&self) -> Self::Layout { + todo!() + } +} + +#[cube] +impl FragmentMask for cmma::Matrix { + fn should_mask(&self, _local_pos: Coords2d) -> bool { + todo!() + } +} + +#[cube] +impl FragmentAttention for AcceleratedFragmentAttention { + type Config = AcceleratedFragmentAttentionConfig; type Query = cmma::Matrix>; type KeyValue = cmma::Matrix>; + type Mask = cmma::Matrix>; type Softmax = cmma::Matrix>; type Accumulator = cmma::Matrix>; + type FragmentLayout = TODO; fn score_matmul( lhs: &Self::Query, @@ -36,42 +92,26 @@ impl AttentionMatmul for AcceleratedAttentionMatmul cmma::execute::, KVT, ACC, ACC>(lhs, rhs, out, out); } - fn allocate_fill_query( - tile: &StridedTile, - #[comptime] config: Self::Config, - ) -> Self::Query { - let (slice, stride) = tile.as_unlined(); + fn allocate_query(#[comptime] config: Self::Config) -> Self::Query { let size = config.attention_tile_size().to_score_matmul_tile_size(); - if config.cast_query() { - let query = unsafe { - cmma::Matrix::>::uninitialized( - cmma::MatrixIdent::A, - size.m(), - size.n(), - size.k(), - cmma::MatrixLayout::RowMajor, - ) - }; - - cmma::load(&query, &slice, stride); - query - } else { - let tmp = unsafe { - cmma::Matrix::::uninitialized( - cmma::MatrixIdent::A, - size.m(), - size.n(), - size.k(), - cmma::MatrixLayout::RowMajor, - ) - }; - - cmma::load(&tmp, &slice, stride); - cmma::cast::>(&tmp) + unsafe { + cmma::Matrix::>::uninitialized( + cmma::MatrixIdent::A, + size.m(), + size.n(), + size.k(), + cmma::MatrixLayout::RowMajor, + ) } } + fn fill_query(tile: &StridedTile, fragment: &mut Self::Query) { + let (slice, stride) = tile.as_unlined(); + + cmma::load(fragment, &slice, stride); + } + fn allocate_key(#[comptime] config: Self::Config) -> Self::KeyValue { let size = config.attention_tile_size(); unsafe { @@ -107,7 +147,20 @@ impl AttentionMatmul for AcceleratedAttentionMatmul cmma::MatrixIdent::B, // m not relevant because it's a B size.seq_q, - // n and k match key, but we are guaranteed that value takes the same space (albeit not the same shape) + // n and k match key, and we assume value takes the same space + size.seq_kv, + size.head_dim, + cmma::MatrixLayout::RowMajor, + ) + } + } + + fn allocate_mask(#[comptime] config: Self::Config) -> Self::Mask { + let size = config.attention_tile_size(); + unsafe { + cmma::Matrix::>::uninitialized( + cmma::MatrixIdent::Accumulator, + size.seq_q, size.seq_kv, size.head_dim, cmma::MatrixLayout::RowMajor, @@ 
-124,6 +177,14 @@ impl AttentionMatmul for AcceleratedAttentionMatmul
         cmma::load(rhs, &slice, stride);
     }

+    fn fill_mask(
+        _tile: &StridedTile,
+        _mask: &mut Self::Mask,
+        #[comptime] _config: Self::Config,
+    ) {
+        todo!()
+    }
+
     fn allocate_softmax(#[comptime] config: Self::Config) -> Self::Softmax {
         let size = config.attention_tile_size();
         unsafe {
@@ -131,7 +192,7 @@ impl AttentionMatmul for AcceleratedAttentionMatmul
                 cmma::MatrixIdent::Accumulator,
                 size.seq_q,
                 size.seq_kv,
-                size.head_dim, // k, because we take accumulator point of view
+                size.head_dim, // k, because we take score matmul acc point of view
                 cmma::MatrixLayout::RowMajor,
             )
         }
@@ -154,7 +215,7 @@ impl AttentionMatmul for AcceleratedAttentionMatmul
         }
     }

-    fn zero_accumulator(acc: &mut Self::Accumulator, #[comptime] _config: Self::Config) {
+    fn zero_accumulator(acc: &mut Self::Accumulator) {
         cmma::fill(acc, ACC::::from_int(0));
     }

@@ -172,32 +233,7 @@ impl AttentionMatmul for AcceleratedAttentionMatmul
         );
     }

-    fn tmp_fill_accumulator(
-        tile: &StridedTile>,
-        acc: &mut Self::Accumulator,
-        #[comptime] _config: Self::Config,
-    ) {
-        let (slice, stride) = tile.as_unlined();
-        cmma::load_with_layout(acc, &slice, stride, cmma::MatrixLayout::RowMajor);
-    }
-
-    fn tmp_fill_prob(
-        tile: &StridedTile>,
-        prob: &mut Self::Softmax,
-        #[comptime] _config: Self::Config,
-    ) {
-        let (slice, stride) = tile.as_unlined();
-        cmma::load_with_layout(prob, &slice, stride, cmma::MatrixLayout::RowMajor);
-    }
-
-    fn tmp_write_softmax(
-        softmax: &Self::Softmax,
-        slice: &mut SliceMut>>,
-        #[comptime] config: Self::Config,
-    ) {
-        cmma::store(
-            slice,
-            softmax,
-            config.attention_tile_size().seq_kv,
-            cmma::MatrixLayout::RowMajor,
-        );
+    fn softmax_layout(#[comptime] _config: Self::Config) -> Self::FragmentLayout {
+        todo!()
     }
 }
diff --git a/crates/cubecl-attention/src/components/fragment/accelerated/config.rs b/crates/cubecl-attention/src/components/fragment/accelerated/config.rs
new file mode 100644
index 000000000..dd1bec04c
--- /dev/null
+++ b/crates/cubecl-attention/src/components/fragment/accelerated/config.rs
@@ -0,0 +1,66 @@
+use std::fmt::Debug;
+use std::hash::Hash;
+
+use crate::components::fragment::FragmentAttentionConfig;
+use crate::components::{AttentionPrecision, AttentionSetupError, AttentionTileSize};
+
+#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
+pub struct AcceleratedFragmentAttentionConfig {
+    plane_dim: u32,
+    num_planes: u32,
+    attention_tile_size: AttentionTileSize,
+    query_stage_line_size: u32,
+    key_value_stage_line_size: u32,
+    check_bounds: bool,
+}
+
+impl FragmentAttentionConfig for AcceleratedFragmentAttentionConfig {
+    fn plane_dim(&self) -> u32 {
+        self.plane_dim
+    }
+
+    fn num_planes(&self) -> u32 {
+        self.num_planes
+    }
+
+    fn attention_tile_size(&self) -> AttentionTileSize {
+        self.attention_tile_size
+    }
+
+    fn num_rows_per_unit(&self) -> u32 {
+        todo!()
+    }
+
+    fn causal_mask(&self) -> bool {
+        todo!()
+    }
+
+    fn materialized_mask(&self) -> bool {
+        todo!()
+    }
+}
+
+impl AcceleratedFragmentAttentionConfig {
+    pub fn new(
+        plane_dim: u32,
+        attention_tile_size: AttentionTileSize,
+        query_stage_line_size: u32,
+        key_value_stage_line_size: u32,
+        check_bounds: bool,
+        num_planes: u32,
+    ) -> Result {
+        Self {
+            plane_dim,
+            num_planes,
+            attention_tile_size,
+            query_stage_line_size,
+            key_value_stage_line_size,
+            check_bounds,
+        }
+        .validate()
+    }
+
+    pub fn validate(self) -> Result {
+        Ok(self)
+    }
+}
diff --git a/crates/cubecl-attention/src/components/batch/dummy/mod.rs b/crates/cubecl-attention/src/components/fragment/accelerated/mod.rs
similarity index 59%
rename from crates/cubecl-attention/src/components/batch/dummy/mod.rs
rename to crates/cubecl-attention/src/components/fragment/accelerated/mod.rs
index 8da144650..967f0ae95 100644
--- a/crates/cubecl-attention/src/components/batch/dummy/mod.rs
+++ b/crates/cubecl-attention/src/components/fragment/accelerated/mod.rs
@@ -3,4 +3,4 @@ mod config;
 mod setup;

 pub use attention::*;
-pub use setup::DummyBatchAttentionFamily;
+pub use config::*;
diff --git a/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/accelerated/setup.rs b/crates/cubecl-attention/src/components/fragment/accelerated/setup.rs
similarity index 63%
rename from crates/cubecl-attention/src/components/tile/dummy/attention_matmul/accelerated/setup.rs
rename to crates/cubecl-attention/src/components/fragment/accelerated/setup.rs
index c132cccc4..add58a20a 100644
--- a/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/accelerated/setup.rs
+++ b/crates/cubecl-attention/src/components/fragment/accelerated/setup.rs
@@ -2,16 +2,16 @@ use cubecl_matmul::components::ComputeResources;

 use crate::components::{
     AttentionPrecision, AttentionSetupError, InvalidConfigError,
-    tile::dummy::{
-        AttentionMatmulFamily,
-        accelerated::{AcceleratedAttentionMatmul, AcceleratedAttentionMatmulConfig},
+    fragment::{
+        FragmentAttentionFamily,
+        accelerated::{AcceleratedFragmentAttention, AcceleratedFragmentAttentionConfig},
     },
 };

-impl AttentionMatmulFamily for AcceleratedAttentionMatmul {
-    type Matmul = AcceleratedAttentionMatmul;
+impl FragmentAttentionFamily for AcceleratedFragmentAttention {
+    type FragmentAttention = AcceleratedFragmentAttention;

-    type Config = AcceleratedAttentionMatmulConfig;
+    type Config = AcceleratedFragmentAttentionConfig;

     fn requires_accelerator() -> bool {
         true
@@ -22,18 +22,19 @@ impl AttentionMatmulFamily for AcceleratedAttentionMatmul {
     }

     fn setup(
-        _client: &cubecl_core::prelude::ComputeClient,
+        _client: &cubecl_core::prelude::ComputeClient,
         problem: &crate::components::AttentionProblem,
         selection: &crate::components::AttentionSelection,
         line_sizes: &crate::components::AttentionLineSizes,
+        num_planes: u32,
     ) -> Result {
-        AcceleratedAttentionMatmulConfig::new::(
+        AcceleratedFragmentAttentionConfig::new::(
             selection.plane_dim,
             selection.tiling_scheme.tile_size,
-            1,
             line_sizes.query as u32,
             line_sizes.key as u32,
             !(problem.seq_kv as u32).is_multiple_of(selection.tiling_scheme.tile_size.seq_kv),
+            num_planes,
         )
     }
 }
diff --git a/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/base.rs b/crates/cubecl-attention/src/components/fragment/base.rs
similarity index 62%
rename from crates/cubecl-attention/src/components/tile/dummy/attention_matmul/base.rs
rename to crates/cubecl-attention/src/components/fragment/base.rs
index 92a6d8a51..79b614f7d 100644
--- a/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/base.rs
+++ b/crates/cubecl-attention/src/components/fragment/base.rs
@@ -4,20 +4,25 @@ use cubecl_matmul::components::ComputeResources;
 use cubecl_matmul::components::tile::StridedTile;

 use crate::components::attention_types::*;
+use crate::components::fragment::{FragmentLayout, FragmentMask, FragmentOps};
 use crate::components::{
-    AttentionIdent, AttentionLineSizes, AttentionPrecision, AttentionProblem, AttentionSelection,
+    AttentionLineSizes, AttentionPrecision, AttentionProblem, AttentionSelection,
     AttentionSetupError, AttentionTileSize, AvailableLineSizes, InvalidConfigError,
 };
 use std::fmt::Debug;
 use std::hash::Hash;

 #[cube]
-pub trait AttentionMatmul: Send + Sync + 'static {
-    type Config: AttentionMatmulConfig;
+pub trait FragmentAttention: Send + Sync + 'static {
+    type Config: FragmentAttentionConfig;

     type Query: CubeType;
     type KeyValue: CubeType;
-    type Softmax: CubeType;
-    type Accumulator: CubeType;
+    type Mask: FragmentMask;
+    type Softmax: FragmentOps, Layout = Self::FragmentLayout>;
+    type Accumulator: FragmentOps, Layout = Self::FragmentLayout>;
+    type FragmentLayout: FragmentLayout;
+
+    fn softmax_layout(#[comptime] config: Self::Config) -> Self::FragmentLayout;

     fn score_matmul(
         lhs: &Self::Query,
@@ -33,77 +38,56 @@ pub trait AttentionMatmul: Send + Sync + 'static {
         #[comptime] config: Self::Config,
     );

-    fn allocate_fill_query(
-        tile: &StridedTile,
-        #[comptime] config: Self::Config,
-    ) -> Self::Query;
+    fn allocate_query(#[comptime] config: Self::Config) -> Self::Query;
+    fn allocate_mask(#[comptime] config: Self::Config) -> Self::Mask;

     fn allocate_key(#[comptime] config: Self::Config) -> Self::KeyValue;
     fn allocate_value(#[comptime] config: Self::Config) -> Self::KeyValue;
     fn allocate_key_value(#[comptime] config: Self::Config) -> Self::KeyValue;

-    fn fill_key_value(
+    fn allocate_softmax(#[comptime] config: Self::Config) -> Self::Softmax;
+    fn allocate_accumulator(#[comptime] config: Self::Config) -> Self::Accumulator;
+
+    fn fill_query(tile: &StridedTile, fragment: &mut Self::Query);
+    fn fill_key_value(
         tile: &StridedTile,
-        rhs: &mut Self::KeyValue,
+        fragment: &mut Self::KeyValue,
+        #[comptime] config: Self::Config,
+    );
+    fn fill_mask(
+        tile: &StridedTile,
+        fragment: &mut Self::Mask,
         #[comptime] config: Self::Config,
     );

-    fn allocate_softmax(#[comptime] config: Self::Config) -> Self::Softmax;
     fn zero_softmax(softmax: &mut Self::Softmax, #[comptime] config: Self::Config);
+    fn zero_accumulator(acc: &mut Self::Accumulator);

-    fn allocate_accumulator(#[comptime] config: Self::Config) -> Self::Accumulator;
-    fn zero_accumulator(acc: &mut Self::Accumulator, #[comptime] config: Self::Config);
-
-    fn write_results(
+    fn write_results(
         out: &Self::Accumulator,
         slice: &mut SliceMut>,
         #[comptime] config: Self::Config,
     );
-
-    // These methods should be deletable when we have proper control over fragments
-    fn tmp_fill_accumulator(
-        tile: &StridedTile>,
-        acc: &mut Self::Accumulator,
-        #[comptime] config: Self::Config,
-    );
-    fn tmp_fill_prob(
-        tile: &StridedTile>,
-        prob: &mut Self::Softmax,
-        #[comptime] config: Self::Config,
-    );
-    fn tmp_write_softmax(
-        softmax: &Self::Softmax,
-        slice: &mut SliceMut>>,
-        #[comptime] config: Self::Config,
-    );
 }

 /// Configuration for the Tile Attention level
-pub trait AttentionMatmulConfig:
+pub trait FragmentAttentionConfig:
     Copy + Clone + Eq + PartialEq + Hash + Debug + Send + Sync + 'static
 {
     fn plane_dim(&self) -> u32;
-
-    // TODO try to remove this
     fn num_planes(&self) -> u32;

-    fn stage_line_size(&self, ident: AttentionIdent) -> u32;
     fn attention_tile_size(&self) -> AttentionTileSize;

-    // If AP::EI != FP::Q
-    fn cast_query(&self) -> bool;
-
-    fn num_units_per_row(&self, ident: AttentionIdent) -> u32;
-    fn num_cols_per_unit(&self, ident: AttentionIdent) -> u32;
-    fn num_rows_per_unit(&self, ident: AttentionIdent) -> u32;
-
-    fn check_bounds(&self) -> bool;
+    fn num_rows_per_unit(&self) -> u32;
+    fn causal_mask(&self) -> bool;
+    fn materialized_mask(&self) -> bool;
 }

-pub trait AttentionMatmulFamily: Send + Sync + 'static {
+pub trait FragmentAttentionFamily: Send + Sync + 'static {
     /// The specific [FragmentAttention] implementation associated with this family.
-    type Matmul: AttentionMatmul;
+    type FragmentAttention: FragmentAttention;

     /// The configuration type associated with this family.
-    type Config: AttentionMatmulConfig;
+    type Config: FragmentAttentionConfig;

     /// Returns whether this fragment attention requires specialized hardware accelerators (e.g., tensor cores).
     fn requires_accelerator() -> bool;
@@ -115,10 +99,11 @@ pub trait AttentionMatmulFamily: Send + Sync + 'static {
     ///
     /// This function may return an error if the configuration cannot be supported on the current runtime.
     fn setup(
-        client: &ComputeClient,
+        client: &ComputeClient,
         problem: &AttentionProblem,
         selection: &AttentionSelection,
         line_sizes: &AttentionLineSizes,
+        num_planes: u32,
     ) -> Result;

     /// Filters out line sizes that are incompatible with this family.
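The fragment-level `setup` now takes `num_planes` from the caller instead of hard-coding it. A hypothetical call site would look roughly like this (`MyFragmentFamily`, `MyPrecision`, and the surrounding variables are illustrative stand-ins, not names from this patch):

    let config = MyFragmentFamily::setup::<MyPrecision, R>(
        &client,
        &problem,
        &selection,
        &line_sizes,
        num_planes, // forwarded by the caller rather than fixed to 1
    )?;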
diff --git a/crates/cubecl-attention/src/components/fragment/dummy_register/attention.rs b/crates/cubecl-attention/src/components/fragment/dummy_register/attention.rs
new file mode 100644
index 000000000..201f9ff27
--- /dev/null
+++ b/crates/cubecl-attention/src/components/fragment/dummy_register/attention.rs
@@ -0,0 +1,495 @@
+use std::cmp::max;
+
+use cubecl_core as cubecl;
+use cubecl_core::prelude::*;
+use cubecl_matmul::components::tile::StridedTile;
+use cubecl_std::tensor::layout::Coords2d;
+
+use crate::components::AttentionPrecision;
+use crate::components::attention_types::*;
+use crate::components::fragment::{FragmentMask, FragmentMaskExpand};
+use crate::components::tile::{RowVal, RowWise};
+
+use crate::components::fragment::dummy_register::DummyRegisterAttentionMatmulConfig;
+use crate::components::fragment::{FragmentAttention, FragmentAttentionConfig as _};
+use crate::components::fragment::{FragmentLayout, FragmentLayoutExpand};
+use crate::components::fragment::{FragmentOps, FragmentOpsExpand};
+
+pub struct DummyRegisterFragmentAttention;
+
+#[derive(CubeType)]
+/// Mimics fragment behaviour, though execution is not efficient
+/// Assumes:
+/// - unit_size.0 * unit_size.1 * plane_dim == total_size.0 * total_size.1 (a constraint on the total element count, not on each dimension)
+pub struct ArrayTile {
+    array: Array,
+    layout: ArrayTileLayout,
+}
+
+#[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)]
+pub enum InnerLayout {
+    /// Each unit has all its elements contiguous inside the same row
+    ///
+    /// 0, 0, 1, 1, 2, 2, 3, 3,
+    /// 4, 4, 5, 5, 6, 6, 7, 7,
+    /// 8, 8, 9, 9, 10, 10, 11, 11,
+    /// 12, 12, 13, 13, 14, 14, 15, 15,
+    /// 16, 16, 17, 17, 18, 18, 19, 19,
+    /// 20, 20, 21, 21, 22, 22, 23, 23,
+    /// 24, 24, 25, 25, 26, 26, 27, 27,
+    /// 28, 28, 29, 29, 30, 30, 31, 31,
+    Contiguous,
+    /// Each unit spreads its elements along two rows
+    ///
+    /// 0, 1, 2, 3, 4, 5, 6, 7,
+    /// 8, 9, 10, 11, 12, 13, 14, 15,
+    /// 16, 17, 18, 19, 20, 21, 22, 23,
+    /// 24, 25, 26, 27, 28, 29, 30, 31,
+    /// 0, 1, 2, 3, 4, 5, 6, 7,
+    /// 8, 9, 10, 11, 12, 13, 14, 15,
+    /// 16, 17, 18, 19, 20, 21, 22, 23,
+    /// 24, 25, 26, 27, 28, 29, 30, 31,
+    ///
+    /// e.g. with plane_dim = 32 and an 8x8 tile (64 elements, 2 per unit),
+    /// Contiguous gives unit_size (1, 2) while SplitRows gives (2, 1)
+    SplitRows,
+}
+
+#[cube]
+impl ArrayTile {
+    pub fn new(layout: ArrayTileLayout) -> ArrayTile {
+        let array = Array::::new(comptime!(layout.unit_size.0 * layout.unit_size.1));
+        ArrayTile:: { array, layout }
+    }
+
+    pub fn zero(&mut self) {
+        for i in 0..self.layout.unit_size.0 * self.layout.unit_size.1 {
+            self.array[i] = E::from_int(0);
+        }
+    }
+}
+
+#[derive(CubeType, Copy, Clone)]
+pub struct ArrayTileLayout {
+    #[cube(comptime)]
+    total_size: Coords2d,
+    #[cube(comptime)]
+    unit_size: Coords2d,
+    #[cube(comptime)]
+    num_units_per_row: u32,
+    #[cube(comptime)]
+    plane_dim: u32,
+}
+
+#[cube]
+impl ArrayTileLayout {
+    pub fn new(
+        #[comptime] total_size: Coords2d,
+        #[comptime] plane_dim: u32,
+        #[comptime] inner_layout: InnerLayout,
+    ) -> ArrayTileLayout {
+        let total_elements = total_size.0 * total_size.1;
+        let elements_per_unit = total_elements.div_ceil(plane_dim);
+
+        let (num_rows_per_unit, num_cols_per_unit) = match inner_layout {
+            InnerLayout::Contiguous => (1u32, elements_per_unit),
+            InnerLayout::SplitRows => (2u32, elements_per_unit / 2u32),
+        };
+        let unit_size = (num_rows_per_unit, num_cols_per_unit);
+
+        let num_units_per_row = comptime!(total_size.1 / unit_size.1);
+
+        ArrayTileLayout {
+            total_size,
+            unit_size,
+            num_units_per_row,
+            plane_dim,
+        }
+    }
+}
+
+#[cube]
+impl FragmentLayout for ArrayTileLayout {
+    fn absolute_pos(&self, local_pos: Coords2d) -> Coords2d {
+        let abs_row_index = {
+            let row_0 = UNIT_POS_X / self.num_units_per_row;
+            let row_jump = comptime!(self.plane_dim / self.num_units_per_row);
+
+            local_pos.0 * row_jump + row_0
+        };
+
+        let abs_col_index = self.unit_size.1 * (UNIT_POS_X % self.num_units_per_row) + local_pos.1;
+
+        (abs_row_index, abs_col_index)
+    }
+
+    fn num_units_per_row(&self) -> comptime_type!(u32) {
+        comptime!(self.total_size.1 / self.unit_size.1)
+    }
+}
+
+#[cube]
+impl FragmentOps for ArrayTile {
+    type Layout = ArrayTileLayout;
+
+    fn rowwise_max(&self) -> RowWise {
+        let mut vals = Sequence::new();
+
+        #[unroll]
+        for r in 0..self.layout.unit_size.0 {
+            let row_offset = r * self.layout.unit_size.1;
+            let mut val = E::min_value();
+
+            #[unroll]
+            for c in 0..self.layout.unit_size.1 {
+                let index = row_offset + c;
+                val = Max::max(val, self.array[index]);
+            }
+
+            vals.push(RowVal:: { val });
+        }
+
+        RowWise:: {
+            num_rows: self.layout.unit_size.0,
+            vals,
+        }
+    }
+
+    fn rowwise_sum(&self) -> RowWise {
+        let mut vals = Sequence::new();
+
+        #[unroll]
+        for r in 0..self.layout.unit_size.0 {
+            let row_offset = r * self.layout.unit_size.1;
+            let mut val = E::from_int(0);
+
+            #[unroll]
+            for c in 0..self.layout.unit_size.1 {
+                let index = row_offset + c;
+                val += self.array[index];
+            }
+
+            vals.push(RowVal:: { val });
+        }
+
+        RowWise:: {
+            num_rows: self.layout.unit_size.0,
+            vals,
+        }
+    }
+
+    fn rowwise_scale(&mut self, scale: &RowWise) {
+        #[unroll]
+        for r in 0..self.layout.unit_size.0 {
+            let row_offset = r * self.layout.unit_size.1;
+            #[unroll]
+            for c in 0..self.layout.unit_size.1 {
+                let index = row_offset + c;
+                self.array[index] = self.array[index] * scale.index(r);
+            }
+        }
+    }
+
+    fn scale_and_mask(this: &mut Self, scale: E, mask: &M) {
+        #[unroll]
+        for r in 0..this.layout.unit_size.0 {
+            let row_offset = r * this.layout.unit_size.1;
+            #[unroll]
+            for c in 0..this.layout.unit_size.1 {
+                let index = row_offset + c;
+                this.array[index] = this.array[index] * scale
+                    + E::cast_from(mask.should_mask((r, c).runtime())) * E::min_value();
+            }
+        }
+    }
+
+    fn exp_diff(&mut self, val: &RowWise) {
+        #[unroll]
+        for r in 0..self.layout.unit_size.0 {
+            let row_offset = r * self.layout.unit_size.1;
+            #[unroll]
+            for c in 0..self.layout.unit_size.1 {
+                let index = row_offset + c;
+                self.array[index] = Exp::exp(self.array[index] - val.index(r));
+            }
+        }
+    }
+
+    fn layout(&self) -> Self::Layout {
+        self.layout
+    }
+}
+
+#[cube]
+impl FragmentMask for ArrayTile {
+    fn should_mask(&self, local_pos: Coords2d) -> bool {
+        bool::cast_from(self.array[local_pos.0 * self.layout.unit_size.1 + local_pos.1])
+    }
+}
+
+#[cube]
+fn array_tile_to_tmp_smem(
+    array_tile: &ArrayTile,
+    #[comptime] num_planes: u32,
+) -> SliceMut {
+    let tile_size = comptime!(array_tile.layout.total_size.0 * array_tile.layout.total_size.1);
+    let mut tmp_smem = SharedMemory::::new(comptime!(num_planes * tile_size));
+
+    let start = UNIT_POS_Y * tile_size;
+    let end = start + tile_size;
+    let mut tmp_smem_slice = tmp_smem.slice_mut(start, end);
+
+    if UNIT_POS_X == 0 {
+        for i in 0..tile_size {
+            tmp_smem_slice[i] = E::from_int(0);
+        }
+    }
+    sync_cube();
+
+    for r in 0..array_tile.layout.unit_size.0 {
+        for c in 0..array_tile.layout.unit_size.1 {
+            let (row, col) = array_tile.layout.absolute_pos((r, c));
+            let index = row * array_tile.layout.total_size.1 + col;
+            tmp_smem_slice[index] = array_tile.array[r * array_tile.layout.unit_size.1 + c];
+        }
+    }
+
+    tmp_smem_slice
+}
+
+#[cube]
+fn tmp_smem_to_array_tile(tmp_smem_slice: &SliceMut, array_tile: &mut ArrayTile) {
+    for r in 0..array_tile.layout.unit_size.0 {
+        for c in 0..array_tile.layout.unit_size.1 {
+            let (row, col) = array_tile.layout.absolute_pos((r, c));
+            let index = row * array_tile.layout.total_size.1 + col;
+            array_tile.array[r * array_tile.layout.unit_size.1 + c] = tmp_smem_slice[index];
+        }
+    }
+}
+
+#[cube]
+fn strided_tile_to_array_tile(
+    strided_tile: &StridedTile,
+    array_tile: &mut ArrayTile,
+) {
+    for r in 0..array_tile.layout.unit_size.0 {
+        for c in 0..array_tile.layout.unit_size.1 {
+            let (row, col) = array_tile.layout.absolute_pos((r, c));
+            array_tile.array[r * array_tile.layout.unit_size.1 + c] =
+                E2::cast_from(strided_tile.get_line(row, col))
+        }
+    }
+}
+
+#[cube]
+fn array_tile_to_slice(
+    array_tile: &ArrayTile,
+    slice: &mut SliceMut>,
+) {
+    for r in 0..array_tile.layout.unit_size.0 {
+        for c in 0..array_tile.layout.unit_size.1 {
+            let (row, col) = array_tile.layout.absolute_pos((r, c));
+            let index = row * array_tile.layout.total_size.1 + col;
+            slice[index] = Line::cast_from(array_tile.array[r * array_tile.layout.unit_size.1 + c]);
+        }
+    }
+}
+
+#[cube]
+impl FragmentAttention for DummyRegisterFragmentAttention {
+    type Config = DummyRegisterAttentionMatmulConfig;
+
+    type Query = ArrayTile>;
+    type KeyValue = ArrayTile>;
+    type Mask = ArrayTile>;
+    type Softmax = ArrayTile>;
+    type Accumulator = ArrayTile>;
+    type FragmentLayout = ArrayTileLayout;
+
+    fn softmax_layout(#[comptime] config: Self::Config) -> ArrayTileLayout {
+        ArrayTileLayout::new(
+            (
+                config.attention_tile_size().seq_q,
+                config.attention_tile_size().seq_kv,
+            ),
+            config.plane_dim(),
+            config.inner_layout(),
+        )
+    }
+
+    fn score_matmul(
+        lhs: &Self::Query,
+        rhs: &Self::KeyValue,
+        out: &mut Self::Softmax,
+        #[comptime] config: Self::Config,
+    ) {
+        let tmp_lhs_smem_slice = array_tile_to_tmp_smem::>(lhs, config.num_planes());
+        let tmp_rhs_smem_slice = array_tile_to_tmp_smem::>(rhs, config.num_planes());
+        let mut tmp_out_smem_slice = array_tile_to_tmp_smem::>(out, config.num_planes());
+        sync_cube();
+
+        if UNIT_POS_X == 0 {
+            let (m, n, k) = comptime! {let (m, n, k): (u32, u32, u32) = config.attention_tile_size().to_score_matmul_tile_size().into(); (m, n, k)};
+
+            for i in 0..m {
+                for j in 0..n {
+                    let mut sum = SM::::from_int(0);
+                    for ki in 0..k {
+                        let lhs_val = tmp_lhs_smem_slice[i * k + ki];
+                        let rhs_val = tmp_rhs_smem_slice[ki * n + j];
+                        sum += SM::::cast_from(lhs_val) * SM::::cast_from(rhs_val);
+                    }
+                    tmp_out_smem_slice[i * n + j] = tmp_out_smem_slice[i * n + j] + sum;
+                }
+            }
+        }
+
+        sync_cube();
+        tmp_smem_to_array_tile(&tmp_out_smem_slice, out);
+        sync_cube();
+    }
+
+    fn value_matmul(
+        lhs: &Self::Softmax,
+        rhs: &Self::KeyValue,
+        out: &mut Self::Accumulator,
+        #[comptime] config: Self::Config,
+    ) {
+        sync_cube();
+        let tmp_lhs_smem_slice = array_tile_to_tmp_smem::>(lhs, config.num_planes());
+        let tmp_rhs_smem_slice = array_tile_to_tmp_smem::>(rhs, config.num_planes());
+        let mut tmp_out_smem_slice = array_tile_to_tmp_smem::>(out, config.num_planes());
+        sync_cube();
+
+        if UNIT_POS_X == 0 {
+            let (m, n, k) = comptime! {let (m, n, k): (u32, u32, u32) = config.attention_tile_size().to_value_matmul_tile_size().into(); (m, n, k)};
+
+            for i in 0..m {
+                for j in 0..n {
+                    let mut sum = ACC::::from_int(0);
+                    for ki in 0..k {
+                        let lhs_val = tmp_lhs_smem_slice[i * k + ki];
+                        let rhs_val = tmp_rhs_smem_slice[ki * n + j];
+                        sum += ACC::::cast_from(lhs_val) * ACC::::cast_from(rhs_val);
+                    }
+                    tmp_out_smem_slice[i * n + j] = tmp_out_smem_slice[i * n + j] + sum;
+                }
+            }
+        }
+
+        sync_cube();
+        tmp_smem_to_array_tile(&tmp_out_smem_slice, out);
+        sync_cube();
+    }
+
+    fn allocate_key_value(#[comptime] config: Self::Config) -> Self::KeyValue {
+        ArrayTile::new(ArrayTileLayout::new(
+            (
+                comptime!(max(
+                    config.attention_tile_size().head_dim,
+                    config.attention_tile_size().seq_kv,
+                )),
+                comptime!(max(
+                    config.attention_tile_size().seq_kv,
+                    config.attention_tile_size().val_dim,
+                )),
+            ),
+            config.plane_dim(),
+            config.inner_layout(),
+        ))
+    }
+
+    fn allocate_key(#[comptime] config: Self::Config) -> Self::KeyValue {
+        ArrayTile::new(ArrayTileLayout::new(
+            (
+                config.attention_tile_size().head_dim,
+                config.attention_tile_size().seq_kv,
+            ),
+            config.plane_dim(),
+            config.inner_layout(),
+        ))
+    }
+
+    fn allocate_value(#[comptime] config: Self::Config) -> Self::KeyValue {
+        ArrayTile::new(ArrayTileLayout::new(
+            (
+                config.attention_tile_size().seq_kv,
+                config.attention_tile_size().val_dim,
+            ),
+            config.plane_dim(),
+            config.inner_layout(),
+        ))
+    }
+
+    fn allocate_mask(#[comptime] config: Self::Config) -> Self::Mask {
+        ArrayTile::new(>::softmax_layout(config))
+    }
+
+    fn allocate_softmax(#[comptime] config: Self::Config) -> Self::Softmax {
+        ArrayTile::new(>::softmax_layout(config))
+    }
+
+    fn allocate_accumulator(#[comptime] config: Self::Config) -> Self::Accumulator {
+        ArrayTile::new(ArrayTileLayout::new(
+            (
+                config.attention_tile_size().seq_q,
+                config.attention_tile_size().val_dim,
+            ),
+            config.plane_dim(),
+            config.inner_layout(),
+        ))
+    }
+
+    fn allocate_query(#[comptime] config: Self::Config) -> Self::Query {
+        let seq_q = config.attention_tile_size().seq_q;
+        let head_dim = config.attention_tile_size().head_dim;
+
+        ArrayTile::new(ArrayTileLayout::new(
+            (seq_q, head_dim),
+            config.plane_dim(),
+            config.inner_layout(),
+        ))
+    }
+
+    fn fill_query(tile: &StridedTile, fragment: &mut Self::Query) {
+        strided_tile_to_array_tile(tile, fragment);
+
+        sync_cube();
+    }
+
+    fn fill_key_value(
+        tile: &StridedTile,
+        rhs: &mut Self::KeyValue,
+        #[comptime] _config: Self::Config,
+    ) {
+        strided_tile_to_array_tile(tile, rhs);
+
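+        // Cube-wide barrier: all units must finish copying the staged tile into
+        // their registers before the shared stage can be refilled.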
+        sync_cube();
+    }
+
+    fn fill_mask(
+        tile: &StridedTile,
+        mask: &mut Self::Mask,
+        #[comptime] _config: Self::Config,
+    ) {
+        strided_tile_to_array_tile(tile, mask);
+
+        sync_cube();
+    }
+
+    fn zero_softmax(softmax: &mut Self::Softmax, #[comptime] _config: Self::Config) {
+        softmax.zero();
+        sync_cube();
+    }
+
+    fn zero_accumulator(acc: &mut Self::Accumulator) {
+        acc.zero();
+        sync_cube();
+    }
+
+    fn write_results(
+        out: &Self::Accumulator,
+        slice: &mut SliceMut>,
+        #[comptime] _config: Self::Config,
+    ) {
+        array_tile_to_slice(out, slice);
+    }
+}
diff --git a/crates/cubecl-attention/src/components/fragment/dummy_register/config.rs b/crates/cubecl-attention/src/components/fragment/dummy_register/config.rs
new file mode 100644
index 000000000..a1a77ee2a
--- /dev/null
+++ b/crates/cubecl-attention/src/components/fragment/dummy_register/config.rs
@@ -0,0 +1,116 @@
+use std::fmt::Debug;
+use std::hash::Hash;
+
+use crate::components::fragment::FragmentAttentionConfig;
+use crate::components::fragment::dummy_register::InnerLayout;
+use crate::components::{AttentionPrecision, AttentionSetupError, AttentionTileSize};
+
+#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
+pub struct DummyRegisterAttentionMatmulConfig {
+    plane_dim: u32,
+    attention_tile_size: AttentionTileSize,
+    num_planes: u32,
+    query_stage_line_size: u32,
+    key_value_stage_line_size: u32,
+    check_bounds: bool,
+    inner_layout: InnerLayout,
+    causal_mask: bool,
+    materialized_mask: bool,
+}
+
+impl FragmentAttentionConfig for DummyRegisterAttentionMatmulConfig {
+    fn plane_dim(&self) -> u32 {
+        self.plane_dim
+    }
+
+    fn num_planes(&self) -> u32 {
+        self.num_planes
+    }
+
+    fn attention_tile_size(&self) -> AttentionTileSize {
+        self.attention_tile_size
+    }
+
+    fn num_rows_per_unit(&self) -> u32 {
+        match self.inner_layout {
+            InnerLayout::Contiguous => 1u32,
+            InnerLayout::SplitRows => 2u32,
+        }
+    }
+
+    fn causal_mask(&self) -> bool {
+        self.causal_mask
+    }
+
+    fn materialized_mask(&self) -> bool {
+        self.materialized_mask
+    }
+}
+
+impl DummyRegisterAttentionMatmulConfig {
+    #[allow(clippy::too_many_arguments)]
+    pub fn new(
+        plane_dim: u32,
+        attention_tile_size: AttentionTileSize,
+        num_planes: u32,
+        query_stage_line_size: u32,
+        key_value_stage_line_size: u32,
+        check_bounds: bool,
+        two_rows_in_array_tile: bool,
+        causal_mask: bool,
+        materialized_mask: bool,
+    ) -> Result {
+        Self {
+            plane_dim,
+            attention_tile_size,
+            num_planes,
+            query_stage_line_size,
+            key_value_stage_line_size,
+            check_bounds,
+            inner_layout: if two_rows_in_array_tile {
+                InnerLayout::SplitRows
+            } else {
+                InnerLayout::Contiguous
+            },
+            causal_mask,
+            materialized_mask,
+        }
+        .validate()
+    }
+
+    pub fn validate(self) -> Result {
+        let softmax_num_rows = self.attention_tile_size.seq_q;
+        let softmax_num_cols = self.attention_tile_size.seq_kv;
+        let softmax_total = softmax_num_rows * softmax_num_cols;
+
+        if softmax_total % self.plane_dim != 0 {
+            return Err(AttentionSetupError::InvalidConfig(Box::new(
+                "Softmax size should be divisible by plane dim",
+            )));
+        }
+
+        if self.inner_layout == InnerLayout::Contiguous && softmax_num_rows > self.plane_dim {
+            return Err(AttentionSetupError::InvalidConfig(Box::new(
+                "More than one row per unit not supported with this inner layout",
+            )));
+        }
+
+        if self.inner_layout == InnerLayout::SplitRows && softmax_total % (2 * self.plane_dim) != 0
+        {
+            return Err(AttentionSetupError::InvalidConfig(Box::new(
+                "With split rows, units must have two elements each",
+            )));
+        }
+
+        if self.attention_tile_size.head_dim < self.attention_tile_size.val_dim {
+            return Err(AttentionSetupError::InvalidConfig(Box::new(
+                "Can't have tile head_dim < tile val dim (not sure why)",
+            )));
+        }
+
+        Ok(self)
+    }
+
+    pub fn inner_layout(&self) -> InnerLayout {
+        self.inner_layout
+    }
+}
diff --git a/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/dummy_register/mod.rs b/crates/cubecl-attention/src/components/fragment/dummy_register/mod.rs
similarity index 53%
rename from crates/cubecl-attention/src/components/tile/dummy/attention_matmul/dummy_register/mod.rs
rename to crates/cubecl-attention/src/components/fragment/dummy_register/mod.rs
index 44a0b24f7..967f0ae95 100644
--- a/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/dummy_register/mod.rs
+++ b/crates/cubecl-attention/src/components/fragment/dummy_register/mod.rs
@@ -1,6 +1,6 @@
+mod attention;
 mod config;
-mod matmul;
 mod setup;

+pub use attention::*;
 pub use config::*;
-pub use matmul::*;
diff --git a/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/dummy_register/setup.rs b/crates/cubecl-attention/src/components/fragment/dummy_register/setup.rs
similarity index 61%
rename from crates/cubecl-attention/src/components/tile/dummy/attention_matmul/dummy_register/setup.rs
rename to crates/cubecl-attention/src/components/fragment/dummy_register/setup.rs
index d5b166770..0ea32a0e1 100644
--- a/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/dummy_register/setup.rs
+++ b/crates/cubecl-attention/src/components/fragment/dummy_register/setup.rs
@@ -1,22 +1,20 @@
 use cubecl_core::client::ComputeClient;
 use cubecl_matmul::components::ComputeResources;

+use crate::components::fragment::dummy_register::DummyRegisterAttentionMatmulConfig;
+use crate::components::fragment::dummy_register::DummyRegisterFragmentAttention;
 use crate::components::{
     AttentionLineSizes, AttentionPrecision, AttentionProblem, AttentionSelection,
-    AttentionSetupError, InvalidConfigError,
-    tile::dummy::{
-        AttentionMatmulFamily,
-        dummy_register::{DummyRegisterAttentionMatmul, DummyRegisterAttentionMatmulConfig},
-    },
+    AttentionSetupError, InvalidConfigError, fragment::FragmentAttentionFamily,
 };

-impl AttentionMatmulFamily for DummyRegisterAttentionMatmul {
-    type Matmul = DummyRegisterAttentionMatmul;
+impl FragmentAttentionFamily for DummyRegisterFragmentAttention {
+    type FragmentAttention = DummyRegisterFragmentAttention;

     type Config = DummyRegisterAttentionMatmulConfig;

     fn requires_accelerator() -> bool {
-        true
+        false
     }

     fn computation_resources() -> Result {
@@ -24,18 +22,22 @@ impl AttentionMatmulFamily for DummyRegisterAttentionMatmul {
     }

     fn setup(
-        _client: &ComputeClient,
+        _client: &ComputeClient,
         problem: &AttentionProblem,
         selection: &AttentionSelection,
         line_sizes: &AttentionLineSizes,
+        num_planes: u32,
     ) -> Result {
         DummyRegisterAttentionMatmulConfig::new::(
             selection.plane_dim,
             selection.tiling_scheme.tile_size,
-            1,
+            num_planes,
             line_sizes.query as u32,
             line_sizes.key as u32,
             !(problem.seq_kv as u32).is_multiple_of(selection.tiling_scheme.tile_size.seq_kv),
+            selection.two_rows_in_array_tile,
+            problem.causal,
+            problem.masked,
         )
     }
 }
diff --git a/crates/cubecl-attention/src/components/fragment/fragments.rs b/crates/cubecl-attention/src/components/fragment/fragments.rs
new file mode 100644
index 000000000..124d08a08
--- /dev/null
+++ b/crates/cubecl-attention/src/components/fragment/fragments.rs
@@ -0,0 +1,74 @@
+use cubecl_core as cubecl;
+use cubecl_core::prelude::*;
+
+use crate::components::tile::RowWise;
+use cubecl_std::tensor::layout::Coords2d;
+
+#[cube]
+/// Describes how a fragment is distributed across units
+/// The layout is independent of the data and data types
+pub trait FragmentLayout: CubeType {
+    /// Maps the (row, col) of the registers of a single unit to the position within the whole tile
+    ///
+    /// Example: for simplicity, if we had a 4 units warp for a 4x4 tile divided as such:
+    /// 0, 0, 1, 1,
+    /// 2, 2, 3, 3,
+    /// 0, 0, 1, 1,
+    /// 2, 2, 3, 3,
+    /// Then we would have:
+    /// unit_0: absolute_pos((0, 0)) == (0, 0)
+    /// unit_0: absolute_pos((0, 1)) == (0, 1)
+    /// unit_0: absolute_pos((1, 0)) == (2, 0)
+    /// unit_0: absolute_pos((1, 1)) == (2, 1)
+    /// ...
+    /// unit_3: absolute_pos((0, 0)) == (1, 2)
+    /// unit_3: absolute_pos((0, 1)) == (1, 3)
+    /// unit_3: absolute_pos((1, 0)) == (3, 2)
+    /// unit_3: absolute_pos((1, 1)) == (3, 3)
+    fn absolute_pos(&self, local_pos: Coords2d) -> Coords2d;
+
+    /// Gives how many units participate in the same row
+    ///
+    /// Example: for simplicity, if we had a 4 units warp for a 4x4 tile divided as such:
+    /// 0, 0, 1, 1,
+    /// 2, 2, 3, 3,
+    /// 0, 0, 1, 1,
+    /// 2, 2, 3, 3,
+    /// Then it would output 2, because each row is spread across two different units (0 and 1, or 2 and 3)
+    /// Layouts with varying num_units_per_row are not supported
+    fn num_units_per_row(&self) -> comptime_type!(u32);
+}
+
+#[cube]
+/// Operations on a fragment, having a specific fragment layout
+pub trait FragmentOps {
+    /// How the fragment is distributed across units
+    type Layout: FragmentLayout;
+
+    /// Get the layout of the fragment
+    fn layout(&self) -> Self::Layout;
+
+    /// Return the maximum of each row
+    /// Units only output values for rows they participate in
+    fn rowwise_max(&self) -> RowWise;
+
+    /// Return the sum of each row
+    /// Units only output values for rows they participate in
+    fn rowwise_sum(&self) -> RowWise;
+
+    /// Scale each element in a row by a value for this row
+    fn rowwise_scale(&mut self, val: &RowWise);
+
+    /// Scale every element by a constant factor, and mask the values identified by the mask
+    fn scale_and_mask(this: &mut Self, scale: E, mask: &M);
+
+    /// Replace each value x_ij with e^(x_ij - m_i), for every row i
+    fn exp_diff(&mut self, m: &RowWise);
+}
+
+#[cube]
+/// Describes which elements of a fragment should be masked
+pub trait FragmentMask: CubeType {
+    /// Returns `true` if the element at `local_pos` should be masked
+    fn should_mask(&self, local_pos: Coords2d) -> bool;
+}
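To make the layout contract concrete, here is a schematic implementation for the 4-unit, 4x4 example pictured in the doc comments above (a sketch only: `ToyLayout` is hypothetical, and elided generic parameters are omitted as elsewhere in this patch):

    #[derive(CubeType, Copy, Clone)]
    pub struct ToyLayout;

    #[cube]
    impl FragmentLayout for ToyLayout {
        fn absolute_pos(&self, local_pos: Coords2d) -> Coords2d {
            // Each unit owns a 1x2 slab per local row; local rows interleave with
            // stride 2, so unit 3 maps (0, 0) to (1, 2) and (1, 0) to (3, 2).
            let row = local_pos.0 * 2 + UNIT_POS_X / 2;
            let col = (UNIT_POS_X % 2) * 2 + local_pos.1;
            (row, col)
        }

        fn num_units_per_row(&self) -> comptime_type!(u32) {
            // Each row is shared by two units (0 and 1, or 2 and 3).
            2u32
        }
    }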
diff --git a/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/mod.rs b/crates/cubecl-attention/src/components/fragment/mod.rs
similarity index 55%
rename from crates/cubecl-attention/src/components/tile/dummy/attention_matmul/mod.rs
rename to crates/cubecl-attention/src/components/fragment/mod.rs
index 974f6651d..6c294f7fe 100644
--- a/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/mod.rs
+++ b/crates/cubecl-attention/src/components/fragment/mod.rs
@@ -1,6 +1,9 @@
 pub mod accelerated;
 pub mod dummy_register;
+pub mod unit_register;

 mod base;
+mod fragments;

 pub use base::*;
+pub use fragments::*;
diff --git a/crates/cubecl-attention/src/components/fragment/unit_register/attention.rs b/crates/cubecl-attention/src/components/fragment/unit_register/attention.rs
new file mode 100644
index 000000000..0a7ac4980
--- /dev/null
+++ b/crates/cubecl-attention/src/components/fragment/unit_register/attention.rs
@@ -0,0 +1,345 @@
+use std::cmp::max;
+
+use cubecl_core as cubecl;
+use cubecl_core::prelude::*;
+use cubecl_matmul::components::tile::StridedTile;
+use cubecl_std::tensor::layout::Coords2d;
+
+use crate::components::AttentionPrecision;
+use crate::components::attention_types::*;
+use crate::components::fragment::FragmentAttentionConfig;
+use crate::components::fragment::unit_register::UnitRegisterFragmentAttentionConfig;
+use crate::components::fragment::{FragmentMask, FragmentMaskExpand};
+use crate::components::tile::RowVal;
+use crate::components::tile::RowWise;
+
+use crate::components::fragment::FragmentAttention;
+use crate::components::fragment::{FragmentLayout, FragmentLayoutExpand};
+use crate::components::fragment::{FragmentOps, FragmentOpsExpand};
+
+pub struct UnitRegisterFragmentAttention;
+
+#[derive(CubeType)]
+pub struct UnitTile {
+    data: Array,
+    layout: UnitTileLayout,
+}
+
+#[derive(CubeType, Copy, Clone)]
+pub struct UnitTileLayout {
+    #[cube(comptime)]
+    num_rows: u32,
+    #[cube(comptime)]
+    num_cols: u32,
+}
+
+#[cube]
+impl UnitTile {
+    pub fn new(layout: UnitTileLayout) -> UnitTile {
+        let data = Array::::new(comptime!(layout.num_rows * layout.num_cols));
+        UnitTile:: { data, layout }
+    }
+
+    pub fn zero(&mut self) {
+        for i in 0..self.layout.num_rows * self.layout.num_cols {
+            self.data[i] = E::from_int(0);
+        }
+    }
+
+    pub fn get(&self, i: u32, j: u32) -> E {
+        self.data[i * self.layout.num_cols + j]
+    }
+
+    pub fn accumulate(&mut self, i: u32, j: u32, val: E) {
+        self.data[i * self.layout.num_cols + j] += val;
+    }
+}
+
+#[cube]
+impl UnitTileLayout {
+    pub fn new(#[comptime] num_rows: u32, #[comptime] num_cols: u32) -> UnitTileLayout {
+        UnitTileLayout { num_rows, num_cols }
+    }
+}
+
+#[cube]
+impl FragmentLayout for UnitTileLayout {
+    fn absolute_pos(&self, local_pos: Coords2d) -> Coords2d {
+        local_pos
+    }
+
+    fn num_units_per_row(&self) -> comptime_type!(u32) {
+        1u32
+    }
+}
+
+#[cube]
+impl FragmentOps for UnitTile {
+    type Layout = UnitTileLayout;
+
+    fn rowwise_max(&self) -> RowWise {
+        let mut vals = Sequence::new();
+
+        #[unroll]
+        for r in 0..self.layout.num_rows {
+            let row_offset = r * self.layout.num_cols;
+            let mut val = E::min_value();
+
+            #[unroll]
+            for c in 0..self.layout.num_cols {
+                let index = row_offset + c;
+                val = Max::max(val, self.data[index]);
+            }
+
+            vals.push(RowVal:: { val });
+        }
+
+        RowWise:: {
+            num_rows: self.layout.num_rows,
+            vals,
+        }
+    }
+
+    fn rowwise_sum(&self) -> RowWise {
+        let mut vals = Sequence::new();
+
+        #[unroll]
+        for r in 0..self.layout.num_rows {
+            let row_offset = r * self.layout.num_cols;
+            let mut val = E::from_int(0);
+
+            #[unroll]
+            for c in 0..self.layout.num_cols {
+                let index = row_offset + c;
+                val += self.data[index];
+            }
+
+            vals.push(RowVal:: { val });
+        }
+
+        RowWise:: {
+            num_rows: self.layout.num_rows,
+            vals,
+        }
+    }
+
+    fn rowwise_scale(&mut self, scale: &RowWise) {
+        #[unroll]
+        for r in 0..self.layout.num_rows {
+            let row_offset = r * self.layout.num_cols;
+            #[unroll]
+            for c in 0..self.layout.num_cols {
+                let index = row_offset + c;
+                self.data[index] = self.data[index] * scale.index(r);
+            }
+        }
+    }
+
+    fn scale_and_mask(this: &mut Self, scale: E, mask: &M) {
+        #[unroll]
+        for r in 0..this.layout.num_rows {
+            let row_offset = r * this.layout.num_cols;
+            #[unroll]
+            for c in 0..this.layout.num_cols {
+                let index = row_offset + c;
+                this.data[index] = this.data[index] * scale
+                    + E::cast_from(mask.should_mask((r, c).runtime())) * E::min_value();
+            }
+        }
+    }
+
+    fn exp_diff(&mut self, val: &RowWise) {
+        #[unroll]
+        for r in 0..self.layout.num_rows {
+            let row_offset = r * self.layout.num_cols;
+            #[unroll]
+            for c in 0..self.layout.num_cols {
+                let index = row_offset + c;
+                self.data[index] = Exp::exp(self.data[index] - val.index(r));
+            }
+        }
+    }
+
+    fn layout(&self) -> Self::Layout {
+        self.layout
+    }
+}
+
+#[cube]
+impl FragmentMask for UnitTile {
+    fn should_mask(&self, local_pos: Coords2d) -> bool {
+        bool::cast_from(self.data[local_pos.0 * self.layout.num_cols + local_pos.1])
+    }
+}
+
+#[cube]
+impl FragmentAttention for UnitRegisterFragmentAttention {
+    type Config = UnitRegisterFragmentAttentionConfig;
+
+    type Query = UnitTile>;
+    type KeyValue = UnitTile>;
+    type Mask = UnitTile>;
+    type Softmax = UnitTile>;
+    type Accumulator = UnitTile>;
+    type FragmentLayout = UnitTileLayout;
+
+    fn softmax_layout(#[comptime] config: Self::Config) -> Self::FragmentLayout {
+        UnitTileLayout {
+            num_rows: config.attention_tile_size().seq_q,
+            num_cols: config.attention_tile_size().seq_kv,
+        }
+    }
+
+    fn score_matmul(
+        lhs: &Self::Query,
+        rhs: &Self::KeyValue,
+        out: &mut Self::Softmax,
+        #[comptime] config: Self::Config,
+    ) {
+        let (m, n, k) = comptime! {let (m, n, k): (u32, u32, u32) = config.attention_tile_size().to_score_matmul_tile_size().into(); (m, n, k)};
+        unit_inner_matmul(lhs, rhs, out, m, n, k);
+    }
+
+    fn value_matmul(
+        lhs: &Self::Softmax,
+        rhs: &Self::KeyValue,
+        out: &mut Self::Accumulator,
+        #[comptime] config: Self::Config,
+    ) {
+        let (m, n, k) = comptime! {let (m, n, k): (u32, u32, u32) = config.attention_tile_size().to_value_matmul_tile_size().into(); (m, n, k)};
+        unit_inner_matmul(lhs, rhs, out, m, n, k);
+    }
+
+    fn allocate_key_value(#[comptime] config: Self::Config) -> Self::KeyValue {
+        UnitTile::new(UnitTileLayout::new(
+            comptime!(max(
+                config.attention_tile_size().head_dim,
+                config.attention_tile_size().seq_kv,
+            )),
+            comptime!(max(
+                config.attention_tile_size().seq_kv,
+                config.attention_tile_size().val_dim,
+            )),
+        ))
+    }
+
+    fn allocate_key(#[comptime] config: Self::Config) -> Self::KeyValue {
+        UnitTile::new(UnitTileLayout::new(
+            config.attention_tile_size().head_dim,
+            config.attention_tile_size().seq_kv,
+        ))
+    }
+
+    fn allocate_value(#[comptime] config: Self::Config) -> Self::KeyValue {
+        UnitTile::new(UnitTileLayout::new(
+            config.attention_tile_size().seq_kv,
+            config.attention_tile_size().val_dim,
+        ))
+    }
+
+    fn allocate_mask(#[comptime] config: Self::Config) -> Self::Mask {
+        UnitTile::new(>::softmax_layout(config))
+    }
+
+    fn allocate_softmax(#[comptime] config: Self::Config) -> Self::Softmax {
+        UnitTile::new(>::softmax_layout(config))
+    }
+
+    fn allocate_accumulator(#[comptime] config: Self::Config) -> Self::Accumulator {
+        UnitTile::new(UnitTileLayout::new(
+            config.attention_tile_size().seq_q,
+            config.attention_tile_size().val_dim,
+        ))
+    }
+
+    fn allocate_query(#[comptime] config: Self::Config) -> Self::Query {
+        UnitTile::new(UnitTileLayout::new(
+            config.attention_tile_size().seq_q,
+            config.attention_tile_size().head_dim,
+        ))
+    }
+
+    fn fill_query(tile: &StridedTile, fragment: &mut Self::Query) {
+        strided_tile_to_array_tile(tile, fragment);
+    }
+
+    fn fill_key_value(
+        tile: &StridedTile,
+        fragment: &mut Self::KeyValue,
+        #[comptime] _config: Self::Config,
+    ) {
+        strided_tile_to_array_tile(tile, fragment);
+    }
+
+    fn fill_mask(
+        tile: &StridedTile,
+        fragment: &mut Self::Mask,
+        #[comptime] _config: Self::Config,
+    ) {
+        strided_tile_to_array_tile(tile, fragment);
+    }
+
+    fn zero_softmax(softmax: &mut Self::Softmax, #[comptime] _config: Self::Config) {
+        softmax.zero();
+    }
+
+    fn zero_accumulator(acc: &mut Self::Accumulator) {
+        acc.zero();
+    }
+
+    fn write_results(
+        out: &Self::Accumulator,
+        slice: &mut SliceMut>,
+        #[comptime] _config: Self::Config,
+    ) {
+        array_tile_to_slice(out, slice)
+    }
+}
+
+#[cube]
+fn strided_tile_to_array_tile(
+    strided_tile: &StridedTile,
+    unit_tile: &mut UnitTile,
+) {
+    for row in 0..unit_tile.layout.num_rows {
+        for col in 0..unit_tile.layout.num_cols {
+            unit_tile.data[row * unit_tile.layout.num_cols + col] =
+                E2::cast_from(strided_tile.get_line(row, col))
+        }
+    }
+}
+
+#[cube]
+fn array_tile_to_slice(
+    unit_tile: &UnitTile,
+    slice: &mut SliceMut>,
+) {
+    for row in 0..unit_tile.layout.num_rows {
+        for col in 0..unit_tile.layout.num_cols {
+            let index = row * unit_tile.layout.num_cols + col;
+            slice[index] = Line::cast_from(unit_tile.data[index]);
+        }
+    }
+}
+
+#[cube]
+fn unit_inner_matmul(
+    lhs: &UnitTile,
+    rhs: &UnitTile,
+    out: &mut UnitTile,
+    #[comptime] m: u32,
+    #[comptime] n: u32,
+    #[comptime] k: u32,
+) {
+    for m_ in 0..m {
+        for n_ in 0..n {
+            let mut sum = Acc::from_int(0);
+            for k_ in 0..k {
+                let lhs_val = lhs.get(m_, k_);
+                let rhs_val = rhs.get(k_, n_);
+                sum += Acc::cast_from(lhs_val) * Acc::cast_from(rhs_val);
+            }
+            out.accumulate(m_, n_, sum);
+        }
+    }
+}
diff --git a/crates/cubecl-attention/src/components/fragment/unit_register/config.rs b/crates/cubecl-attention/src/components/fragment/unit_register/config.rs
new file mode 100644
index 000000000..b1319c211
--- /dev/null
+++ b/crates/cubecl-attention/src/components/fragment/unit_register/config.rs
@@ -0,0 +1,73 @@
+use std::fmt::Debug;
+use std::hash::Hash;
+
+use crate::components::fragment::FragmentAttentionConfig;
+use crate::components::{AttentionPrecision, AttentionSetupError, AttentionTileSize};
+
+#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
+pub struct UnitRegisterFragmentAttentionConfig {
+    plane_dim: u32,
+    num_planes: u32,
+    attention_tile_size: AttentionTileSize,
+    query_stage_line_size: u32,
+    key_value_stage_line_size: u32,
+    check_bounds: bool,
+    causal_mask: bool,
+    materialized_mask: bool,
+}
+
+impl FragmentAttentionConfig for UnitRegisterFragmentAttentionConfig {
+    fn plane_dim(&self) -> u32 {
+        self.plane_dim
+    }
+
+    fn num_planes(&self) -> u32 {
+        self.num_planes
+    }
+
+    fn attention_tile_size(&self) -> AttentionTileSize {
+        self.attention_tile_size
+    }
+
+    fn num_rows_per_unit(&self) -> u32 {
+        self.attention_tile_size.seq_q
+    }
+
+    fn causal_mask(&self) -> bool {
+        self.causal_mask
+    }
+
+    fn materialized_mask(&self) -> bool {
+        self.materialized_mask
+    }
+}
+
+impl UnitRegisterFragmentAttentionConfig {
+    #[allow(clippy::too_many_arguments)]
+    pub fn new(
+        plane_dim: u32,
+        attention_tile_size: AttentionTileSize,
+        query_stage_line_size: u32,
+        key_value_stage_line_size: u32,
+        check_bounds: bool,
+        num_planes: u32,
+        causal_mask: bool,
+        materialized_mask: bool,
+    ) -> Result {
+        Self {
+            plane_dim,
+            num_planes,
+            attention_tile_size,
+            query_stage_line_size,
+            key_value_stage_line_size,
+            check_bounds,
+            causal_mask,
+            materialized_mask,
+        }
+        .validate()
+    }
+
+    pub fn validate(self) -> Result {
+        Ok(self)
+    }
+}
diff --git a/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/accelerated/mod.rs b/crates/cubecl-attention/src/components/fragment/unit_register/mod.rs
similarity index 53%
rename from crates/cubecl-attention/src/components/tile/dummy/attention_matmul/accelerated/mod.rs
rename to crates/cubecl-attention/src/components/fragment/unit_register/mod.rs
index 44a0b24f7..967f0ae95 100644
--- a/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/accelerated/mod.rs
+++ b/crates/cubecl-attention/src/components/fragment/unit_register/mod.rs
@@ -1,6 +1,6 @@
+mod attention;
 mod config;
-mod matmul;
 mod setup;

+pub use attention::*;
 pub use config::*;
-pub use matmul::*;
diff --git a/crates/cubecl-attention/src/components/fragment/unit_register/setup.rs b/crates/cubecl-attention/src/components/fragment/unit_register/setup.rs
new file mode 100644
index 000000000..cff9a3fa4
--- /dev/null
+++ b/crates/cubecl-attention/src/components/fragment/unit_register/setup.rs
@@ -0,0 +1,42 @@
+use cubecl_core::client::ComputeClient;
+use cubecl_matmul::components::ComputeResources;
+
+use crate::components::fragment::unit_register::UnitRegisterFragmentAttention;
+use crate::components::fragment::unit_register::UnitRegisterFragmentAttentionConfig;
+use crate::components::{
+    AttentionLineSizes, AttentionPrecision, AttentionProblem, AttentionSelection,
+    AttentionSetupError, InvalidConfigError, fragment::FragmentAttentionFamily,
+};
+
+impl FragmentAttentionFamily for UnitRegisterFragmentAttention {
+    type FragmentAttention = UnitRegisterFragmentAttention;
+
+    type Config = UnitRegisterFragmentAttentionConfig;
+
+    fn requires_accelerator() -> bool {
+        false
+    }
+
+    fn computation_resources() -> Result {
+        Ok(ComputeResources::Units(1))
+    }
+
+    fn setup(
+        _client: &ComputeClient,
+        problem: &AttentionProblem,
+        selection: &AttentionSelection,
+        line_sizes: &AttentionLineSizes,
+        num_planes: u32,
+    ) -> Result {
+        UnitRegisterFragmentAttentionConfig::new::(
+            selection.plane_dim,
+            selection.tiling_scheme.tile_size,
+            line_sizes.query as u32,
+            line_sizes.key as u32,
+            !(problem.seq_kv as u32).is_multiple_of(selection.tiling_scheme.tile_size.seq_kv),
+            num_planes,
+            problem.causal,
+            problem.masked,
+        )
+    }
+}
diff --git a/crates/cubecl-attention/src/components/global/base.rs b/crates/cubecl-attention/src/components/global/base.rs
index b9131b4d2..6395859a4 100644
--- a/crates/cubecl-attention/src/components/global/base.rs
+++ b/crates/cubecl-attention/src/components/global/base.rs
@@ -1,12 +1,14 @@
 use cubecl_core as cubecl;
 use cubecl_core::prelude::*;
+
+use crate::components::global::simple::{AttentionReader, AttentionWriter};
 use cubecl_matmul::components::{global::memory::GlobalMemoryConfig, stage::StageMemoryConfig};
-use cubecl_std::tensor::r#virtual::VirtualTensor;
+use cubecl_std::{CubeOption, tensor::r#virtual::VirtualTensor};

 use crate::components::{
     AttentionIdent, AttentionLineSizes, AttentionPrecision, AttentionProblem, AttentionSelection,
     AttentionSetupError, AttentionTilingScheme, AvailableLineSizes, attention_types::*,
-    global::dummy::QueryReader, stage::StageAttentionConfig,
+    global::simple::QueryReader, stage::StageAttentionConfig,
 };
 use std::{fmt::Debug, hash::Hash};

@@ -22,7 +24,7 @@ pub trait GlobalAttentionFamily: Send + Sync + 'static {
     ///
     /// This function may return an error if the configuration cannot be supported on the current runtime.
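     ///
     /// A hypothetical call site (names illustrative, not from this patch):
     /// `let config = MyGlobalFamily::setup(&client, &problem, &selection, &line_sizes)?;`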
fn setup( - client: &ComputeClient, + client: &ComputeClient, problem: &AttentionProblem, selection: &AttentionSelection, line_sizes: &AttentionLineSizes, @@ -39,12 +41,14 @@ pub trait GlobalAttentionFamily: Send + Sync + 'static { #[cube] pub trait GlobalAttention: 'static + Send + Sync { /// Writes to Out at the same offset it loaded Query - type Writer: CubeType; + type Writer: AttentionWriter, OG>; /// Loads to SMEM transposed - type KeyReader: CubeType; + type KeyReader: AttentionReader, Self::Config>; + /// Loads to SMEM as is + type ValueReader: AttentionReader, Self::Config>; /// Loads to SMEM as is - type ValueReader: CubeType; + type MaskReader: CubeType; /// The configuration type associated with this Attention. type Config: GlobalAttentionConfig; @@ -53,6 +57,7 @@ pub trait GlobalAttention: 'static + Send + Sync { query_reader: QueryReader, key_reader: Self::KeyReader, value_reader: Self::ValueReader, + mask_reader: Self::MaskReader, writer: Self::Writer, seq_q: u32, seq_kv: u32, @@ -75,6 +80,13 @@ pub trait GlobalAttention: 'static + Send + Sync { #[comptime] config: Self::Config, ) -> Self::ValueReader; + fn init_mask_reader( + q_offset: u32, + mask: CubeOption>>, + seq_kv_shape: u32, + #[comptime] config: Self::Config, + ) -> Self::MaskReader; + fn init_writer( q_offset: u32, out: VirtualTensor, ReadWrite>, @@ -97,4 +109,6 @@ pub trait GlobalAttentionConfig: fn global_memory_config(&self, ident: AttentionIdent) -> GlobalMemoryConfig; fn tiling_scheme(&self) -> AttentionTilingScheme; + + fn causal_mask(&self) -> bool; } diff --git a/crates/cubecl-attention/src/components/global/dummy/mod.rs b/crates/cubecl-attention/src/components/global/dummy/mod.rs deleted file mode 100644 index 5514f8659..000000000 --- a/crates/cubecl-attention/src/components/global/dummy/mod.rs +++ /dev/null @@ -1,12 +0,0 @@ -mod attention; -mod config; -mod read; -mod setup; -mod writer; - -pub use attention::*; -pub use read::*; -pub use setup::DummyGlobalAttentionFamily; - -// tmp -pub use config::DummyGlobalConfig; diff --git a/crates/cubecl-attention/src/components/global/dummy/read.rs b/crates/cubecl-attention/src/components/global/dummy/read.rs deleted file mode 100644 index b459f83cc..000000000 --- a/crates/cubecl-attention/src/components/global/dummy/read.rs +++ /dev/null @@ -1,226 +0,0 @@ -use crate::components::attention_types::*; -use cubecl_core as cubecl; -use cubecl_core::prelude::*; -use cubecl_matmul::components::global::{ - memory::{GlobalIterator, ViewDirection}, - read::tiled::TiledLayout, -}; -use cubecl_matmul::components::stage::StridedStage; -use cubecl_matmul::components::tile::StridedTile; -use cubecl_matmul::components::{MatrixLayout, StageIdent}; -use cubecl_std::tensor::{View, layout::Coords2d}; -use std::marker::PhantomData; - -use crate::components::global::base::GlobalAttentionConfig; -use crate::components::stage::StageAttentionConfig; -use crate::components::tile::AttentionTilingLayout; -use crate::components::{AttentionIdent, AttentionPrecision}; - -#[derive(CubeType)] -pub struct QueryReader { - query: View>, Coords2d>, -} - -#[derive(CubeType)] -pub struct DummyKeyReader { - global_iter: GlobalIterator>>, - stage_memory: StridedStage, AttentionTilingLayout>, - - #[cube(comptime)] - _phantom: PhantomData, -} - -#[derive(CubeType)] -pub struct DummyValueReader { - global_iter: GlobalIterator>>, - stage_memory: StridedStage, AttentionTilingLayout>, - - #[cube(comptime)] - _phantom: PhantomData, -} - -#[cube] -impl QueryReader { - pub fn new(q_offset: u32, query: View>, 
Coords2d>) -> Self { - let query = query.slice((q_offset, 0), query.shape()); - - QueryReader:: { query } - } - - pub fn get_tile( - &self, - tile: Coords2d, - #[comptime] config: S, - ) -> StridedTile> { - let (row_in_partition, col) = tile; - let attention_tile_size = config.tiling_scheme().tile_size; - - let row = row_in_partition + UNIT_POS_Y * config.tiling_scheme().partition_size.seq_q; - - StridedTile::>::new_strided( - self.query - .slice( - ( - row * attention_tile_size.seq_q, - col * attention_tile_size.head_dim, - ), - (attention_tile_size.seq_q, attention_tile_size.head_dim).runtime(), - ) - .to_linear_slice(), - config.tiling_scheme().elements_in_partition_head_dim(), - MatrixLayout::RowMajor, - ) - } -} - -#[cube] -impl DummyKeyReader { - pub fn new(key: View>, Coords2d>, step: u32, #[comptime] config: G) -> Self { - let global_iter = GlobalIterator::new(key, step, ViewDirection::Row, false); - let stage_memory = StridedStage::new(StageIdent::Rhs, config.score_stage_memory_config()); - - DummyKeyReader:: { - global_iter, - stage_memory, - _phantom: PhantomData, - } - } - - pub fn stage(&self) -> StridedStage, AttentionTilingLayout> { - self.stage_memory - } - - pub fn read_transposed(&mut self, #[comptime] config: G) { - // TODO this reader is bad - if UNIT_POS_Y == 0 { - let memory_config = config.global_memory_config(AttentionIdent::Key); - - let mut slice = self.stage_memory.as_slice_mut(1u32); - - let tile_rows_load = memory_config.elements_in_tile_row; - let tile_cols_load = memory_config.elements_in_tile_col; - let partition_rows_load = memory_config.elements_in_stage_row / tile_rows_load; - let partition_cols_load = memory_config.elements_in_stage_col / tile_cols_load; - - let units_per_tile_row = comptime!(config.plane_dim() / tile_rows_load); - let tile_cols_per_unit = comptime!(div_ceil(tile_cols_load, units_per_tile_row)); - - let row_load_in_tile = UNIT_POS_X / units_per_tile_row; - let col_load_in_tile_start = (UNIT_POS_X % units_per_tile_row) * tile_cols_per_unit; - - // Assumes row tiling order - let num_elements_per_tile = tile_rows_load * tile_cols_load; - let tile_row_stride_store = partition_rows_load * num_elements_per_tile; - let tile_col_stride_store = num_elements_per_tile; - - let layout = TiledLayout::new(memory_config); - let view = self.global_iter.view().view(layout); - - #[unroll] - for tile_row_load in 0..partition_rows_load { - #[unroll] - for tile_col_load in 0..partition_cols_load { - if row_load_in_tile < tile_rows_load { - #[unroll] - for i in 0..tile_cols_per_unit { - let col_load = col_load_in_tile_start + i; - - if col_load < tile_cols_load { - let tile_row_store = tile_col_load; - let tile_col_store = tile_row_load; - let tile_row_store_offset = tile_row_store * tile_row_stride_store; - let tile_col_store_offset = tile_col_store * tile_col_stride_store; - let store_offset = tile_row_store_offset + tile_col_store_offset; - - let index_load = row_load_in_tile * tile_cols_load + col_load; - let index_store = col_load * tile_rows_load + row_load_in_tile; - - slice[index_store + store_offset] = Line::cast_from( - view.read_checked(((tile_row_load, tile_col_load), index_load)), - ); - } - } - } - } - } - } - } - - pub fn advance_view(&mut self) { - self.global_iter.advance(); - } -} - -#[cube] -impl DummyValueReader { - pub fn new(value: View>, Coords2d>, step: u32, #[comptime] config: G) -> Self { - let global_iter = GlobalIterator::new(value, step, ViewDirection::Row, false); - let stage_memory = StridedStage::new(StageIdent::Rhs, 
config.value_stage_memory_config()); - - DummyValueReader:: { - global_iter, - stage_memory, - _phantom: PhantomData, - } - } - - pub fn stage(&self) -> StridedStage, AttentionTilingLayout> { - self.stage_memory - } - - pub fn read(&mut self, #[comptime] config: G) { - if UNIT_POS_Y == 0 { - // TODO this reader is bad, it's not coalesced - let memory_config = config.global_memory_config(AttentionIdent::Value); - let mut slice = self.stage_memory.as_slice_mut(1u32); - - let tile_rows = memory_config.elements_in_tile_row; - let tile_cols = memory_config.elements_in_tile_col; - let partition_rows = memory_config.elements_in_stage_row / tile_rows; - let partition_cols = memory_config.elements_in_stage_col / tile_cols; - - let units_per_tile_row = comptime!(config.plane_dim() / tile_rows); - let tile_cols_per_unit = comptime!(div_ceil(tile_cols, units_per_tile_row)); - - let row_in_tile = UNIT_POS_X / units_per_tile_row; - let col_in_tile_start = (UNIT_POS_X % units_per_tile_row) * tile_cols_per_unit; - - // Assumes row tiling order - let num_elements_per_tile = tile_rows * tile_cols; - let tile_row_stride = partition_cols * num_elements_per_tile; - let tile_col_stride = num_elements_per_tile; - - let layout = TiledLayout::new(memory_config); - let view = self.global_iter.view().view(layout); - - #[unroll] - for tile_row in 0..partition_rows { - #[unroll] - for tile_col in 0..partition_cols { - if row_in_tile < tile_rows { - #[unroll] - for i in 0..tile_cols_per_unit { - let col = col_in_tile_start + i; - - if col < tile_cols { - let tile_row_offset = tile_row * tile_row_stride; - let tile_col_offset = tile_col * tile_col_stride; - let offset = tile_row_offset + tile_col_offset; - - let index = row_in_tile * tile_cols + col; - - slice[index + offset] = Line::cast_from( - view.read_checked(((tile_row, tile_col), index)), - ); - } - } - } - } - } - } - } - - pub fn advance_view(&mut self) { - self.global_iter.advance(); - } -} diff --git a/crates/cubecl-attention/src/components/global/dummy/writer.rs b/crates/cubecl-attention/src/components/global/dummy/writer.rs deleted file mode 100644 index f82a3ee02..000000000 --- a/crates/cubecl-attention/src/components/global/dummy/writer.rs +++ /dev/null @@ -1,87 +0,0 @@ -use cubecl::prelude::*; -use cubecl_core::{self as cubecl}; -use cubecl_matmul::components::{ - MatrixLayout, MatrixPrecision, - global::{ - PartitionedStage, WriteEvent, WriteEventExpand, WriteEventListener, - memory::GlobalMemoryConfig, - plane_write, - read::tiled::{TiledCoords, TiledLayout}, - }, - stage::StageMemoryConfig, -}; -use cubecl_std::tensor::{View, layout::Coords2d}; - -use crate::components::stage::StageAttentionConfig; - -#[derive(CubeType)] -pub struct DummyWriter { - global: View, TiledCoords, ReadWrite>, - stage: PartitionedStage, - #[cube(comptime)] - plane_dim: u32, - #[cube(comptime)] - config: GlobalMemoryConfig, -} - -#[cube] -impl DummyWriter { - pub fn new( - global: View, Coords2d, ReadWrite>, - #[comptime] global_config: GlobalMemoryConfig, - #[comptime] stage_config: S, - ) -> Self { - let stage_mem_config = comptime! 
{ - let tile_rows = stage_config.tiling_scheme().elements_in_partition_val_dim(); - let tile_cols = stage_config.tiling_scheme().elements_in_partition_seq_q(); - let planes = stage_config.num_planes(); - - StageMemoryConfig { - num_main_flow_planes: planes, - elements_in_tile_row: tile_rows, - elements_in_tile_col: tile_cols, - tiles_in_stage_row: 1, - tiles_in_stage_col: planes, - stage_line_size: 1, - matrix_layout: MatrixLayout::RowMajor, - num_stages: 1, - } - }; - - let stage = PartitionedStage::new((0u32, UNIT_POS_Y), stage_mem_config); - - DummyWriter:: { - global: global.view_mut(TiledLayout::new(global_config)), - stage, - plane_dim: stage_config.plane_dim(), - config: global_config, - } - } - - fn write(&mut self, tile_pos: Coords2d) { - plane_write::( - &mut self.global, - &self.stage.unit_tile, - tile_pos, - comptime![self.plane_dim], - comptime![self.config], - ) - } - - pub fn stage(&mut self) -> PartitionedStage { - self.stage - } -} - -#[cube] -impl WriteEventListener for DummyWriter { - fn on_event(this: &mut Self, event: WriteEvent) { - #[allow(clippy::single_match)] - match event { - WriteEvent::TileStored { tile } => { - this.write(tile); - } - _ => {} - } - } -} diff --git a/crates/cubecl-attention/src/components/global/mod.rs b/crates/cubecl-attention/src/components/global/mod.rs index 9c316f052..3f277f2f2 100644 --- a/crates/cubecl-attention/src/components/global/mod.rs +++ b/crates/cubecl-attention/src/components/global/mod.rs @@ -1,4 +1,4 @@ -pub mod dummy; +pub mod simple; mod base; mod layout; diff --git a/crates/cubecl-attention/src/components/global/dummy/attention.rs b/crates/cubecl-attention/src/components/global/simple/attention.rs similarity index 51% rename from crates/cubecl-attention/src/components/global/dummy/attention.rs rename to crates/cubecl-attention/src/components/global/simple/attention.rs index e82f996f9..254c21852 100644 --- a/crates/cubecl-attention/src/components/global/dummy/attention.rs +++ b/crates/cubecl-attention/src/components/global/simple/attention.rs @@ -1,26 +1,28 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; -use cubecl_matmul::components::{global::PartitionedStage, stage::StridedStage}; +use cubecl_matmul::components::global::PartitionedStage; +use cubecl_matmul::components::stage::StridedStage; use cubecl_std::tensor::r#virtual::VirtualTensor; +use cubecl_std::{CubeOption, CubeOptionExpand}; use std::marker::PhantomData; -use crate::components::GlobalMask; use crate::components::attention_types::*; use crate::components::global::base::GlobalAttentionConfig; -use crate::components::global::dummy::writer::DummyWriter; +use crate::components::global::simple::reader::{AttentionReader, AttentionReaderExpand}; +use crate::components::global::simple::{AttentionWriter, AttentionWriterExpand, MaskReader}; use crate::components::global::{ AttentionGlobalLayout, - dummy::{DummyKeyReader, DummyValueReader}, + simple::{DummyKeyReader, DummyValueReader}, }; -use crate::components::stage::StageAttention; +use crate::components::stage::{AttentionPartitioner, StageAttention}; use crate::components::tile::AttentionTilingLayout; -use crate::components::{AttentionIdent, global::dummy::QueryReader}; +use crate::components::{AttentionIdent, global::simple::QueryReader}; use crate::components::{ AttentionPrecision, - global::{GlobalAttention, dummy::config::DummyGlobalConfig}, + global::{GlobalAttention, simple::config::SimpleGlobalConfig}, }; -pub struct DummyGlobalAttention> { +pub struct SimpleGlobalAttention> { _phantom: 
PhantomData<(AP, SA)>, } @@ -33,65 +35,81 @@ impl< OutStage = PartitionedStage>, >, AP: AttentionPrecision, -> GlobalAttention for DummyGlobalAttention +> GlobalAttention for SimpleGlobalAttention { type KeyReader = DummyKeyReader; type ValueReader = DummyValueReader; + type MaskReader = MaskReader; - type Writer = DummyWriter<(OG, OS)>; + type Writer = ::Writer, OG>; - type Config = DummyGlobalConfig; + type Config = SimpleGlobalConfig; fn execute( query_reader: QueryReader, mut key_reader: Self::KeyReader, mut value_reader: Self::ValueReader, + mut mask_reader: Self::MaskReader, mut writer: Self::Writer, seq_q: u32, seq_kv: u32, #[comptime] config: Self::Config, ) { - let key_stage = key_reader.stage(); - let value_stage = value_reader.stage(); + let mut key_stage = key_reader.init_stage(config); + let mut value_stage = value_reader.init_stage(config); - let mut stage_state = SA::init_state(config.stage_config()); + let mut query_registers = SA::init_query(config.stage_config()); + let mut key_value_registers = SA::init_key_value(config.stage_config()); + let mut mask_registers = + SA::init_mask(CubeOption::new_Some((seq_q, seq_kv)), config.stage_config()); + let mut softmax_registers = SA::init_softmax(config.stage_config()); + let mut accumulator_registers = SA::init_accumulator(config.stage_config()); - let (query, mut key_value, mut softmax, mut accumulator) = - SA::init_partitions(query_reader, config.stage_config()); + let mut stage_state = SA::init_state(config.stage_config()); let seq_kv_stage = config.tiling_scheme().elements_in_partition_seq_kv(); let num_stage_iterations = seq_kv.div_ceil(seq_kv_stage); - let mask = GlobalMask::new(seq_q, seq_kv, config.tiling_scheme()); - for i in 0..num_stage_iterations { - key_reader.read_transposed(config); - value_reader.read(config); + SA::read_query(&query_reader, &mut query_registers, config.stage_config()); + + for _ in 0..num_stage_iterations { + key_reader.read_global(&mut key_stage, config); + value_reader.read_global(&mut value_stage, config); + + SA::read_mask(&mask_reader, &mut mask_registers, config.stage_config()); + sync_cube(); SA::execute( + &query_registers, &key_stage, &value_stage, - &query, - &mut key_value, - &mut softmax, - mask.to_stage(CUBE_POS, i), - &mut accumulator, + &mut key_value_registers, + &mask_registers, + &mut softmax_registers, + &mut accumulator_registers, &mut stage_state, config.stage_config(), ); sync_cube(); + key_reader.advance_view(); value_reader.advance_view(); + mask_reader.advance_view(); } - SA::rescale(&mut accumulator, stage_state, config.stage_config()); + SA::rescale( + &mut accumulator_registers, + stage_state, + config.stage_config(), + ); let mut out_stage = writer.stage(); SA::write::( - &accumulator, + &accumulator_registers, &mut out_stage, &mut writer, config.stage_config(), @@ -119,7 +137,7 @@ impl< let step = reduction_step::(config); let layout = AttentionGlobalLayout::new(&key, 0, config.global_memory_config(AttentionIdent::Key)); - DummyKeyReader::new(key.view(layout), step, config) + DummyKeyReader::new(key.view(layout), step) } fn init_value_reader( @@ -132,7 +150,37 @@ impl< 0, config.global_memory_config(AttentionIdent::Value), ); - DummyValueReader::new(value.view(layout), step, config) + DummyValueReader::new(value.view(layout), step) + } + + fn init_mask_reader( + q_offset: u32, + mask: CubeOption>>, + seq_kv_shape: u32, + #[comptime] config: Self::Config, + ) -> Self::MaskReader { + let step = reduction_step::(config); + let inner_q_offset = ::seq_q_index() + * 
config.tiling_scheme().elements_in_partition_seq_q(); + + match mask { + CubeOption::Some(mask) => { + let layout = AttentionGlobalLayout::new( + &mask, + 0, + config.global_memory_config(AttentionIdent::Value), + ); + + MaskReader::new_materialized( + q_offset, + inner_q_offset, + mask.view(layout), + step, + seq_kv_shape, + ) + } + CubeOption::None => MaskReader::new_logical(q_offset + inner_q_offset, step), + } } fn init_writer( diff --git a/crates/cubecl-attention/src/components/global/dummy/config.rs b/crates/cubecl-attention/src/components/global/simple/config.rs similarity index 83% rename from crates/cubecl-attention/src/components/global/dummy/config.rs rename to crates/cubecl-attention/src/components/global/simple/config.rs index 6fecc75c7..bc3bc6b5f 100644 --- a/crates/cubecl-attention/src/components/global/dummy/config.rs +++ b/crates/cubecl-attention/src/components/global/simple/config.rs @@ -9,12 +9,13 @@ use crate::components::{ }; #[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] -pub struct DummyGlobalConfig { +pub struct SimpleGlobalConfig { stage_config: S, num_planes: u32, + causal_mask: bool, } -impl GlobalAttentionConfig for DummyGlobalConfig { +impl GlobalAttentionConfig for SimpleGlobalConfig { type StageConfig = S; fn score_stage_memory_config(&self) -> StageMemoryConfig { @@ -66,13 +67,22 @@ impl GlobalAttentionConfig for DummyGlobalConfig { fn tiling_scheme(&self) -> AttentionTilingScheme { self.stage_config.tiling_scheme() } + + fn causal_mask(&self) -> bool { + self.causal_mask + } } -impl DummyGlobalConfig { - pub fn new(stage_config: S, num_planes: u32) -> Result { +impl SimpleGlobalConfig { + pub fn new( + stage_config: S, + num_planes: u32, + causal_mask: bool, + ) -> Result { Self { stage_config, num_planes, + causal_mask, } .validate() } diff --git a/crates/cubecl-attention/src/components/global/simple/mod.rs b/crates/cubecl-attention/src/components/global/simple/mod.rs new file mode 100644 index 000000000..cf6ded21d --- /dev/null +++ b/crates/cubecl-attention/src/components/global/simple/mod.rs @@ -0,0 +1,10 @@ +mod attention; +mod config; +mod reader; +mod setup; +mod writer; + +pub use attention::*; +pub use reader::*; +pub use setup::SimpleGlobalAttentionFamily; +pub use writer::*; diff --git a/crates/cubecl-attention/src/components/global/simple/reader/base.rs b/crates/cubecl-attention/src/components/global/simple/reader/base.rs new file mode 100644 index 000000000..4981a876c --- /dev/null +++ b/crates/cubecl-attention/src/components/global/simple/reader/base.rs @@ -0,0 +1,15 @@ +use cubecl_core as cubecl; +use cubecl_core::prelude::*; + +use crate::components::global::GlobalAttentionConfig; + +#[cube] +pub trait AttentionReader { + type Stage: CubeType; + + fn init_stage(&mut self, #[comptime] config: G) -> Self::Stage; + + fn read_global(&mut self, stage: &mut Self::Stage, #[comptime] config: G); + + fn advance_view(&mut self); +} diff --git a/crates/cubecl-attention/src/components/global/simple/reader/key.rs b/crates/cubecl-attention/src/components/global/simple/reader/key.rs new file mode 100644 index 000000000..24c0157c2 --- /dev/null +++ b/crates/cubecl-attention/src/components/global/simple/reader/key.rs @@ -0,0 +1,109 @@ +use crate::components::attention_types::*; +use crate::components::global::simple::reader::{AttentionReader, AttentionReaderExpand}; +use cubecl_core as cubecl; +use cubecl_core::prelude::*; +use cubecl_matmul::components::StageIdent; +use cubecl_matmul::components::global::{ + memory::{GlobalIterator, ViewDirection}, + 
read::tiled::TiledLayout, +}; +use cubecl_matmul::components::stage::StridedStage; +use cubecl_std::tensor::{View, layout::Coords2d}; +use std::marker::PhantomData; + +use crate::components::global::base::GlobalAttentionConfig; +use crate::components::tile::AttentionTilingLayout; +use crate::components::{AttentionIdent, AttentionPrecision}; + +#[derive(CubeType)] +pub struct DummyKeyReader { + global_iter: GlobalIterator>>, + + #[cube(comptime)] + _phantom: PhantomData, +} + +#[cube] +impl DummyKeyReader { + pub fn new(key: View>, Coords2d>, step: u32) -> Self { + let global_iter = GlobalIterator::new(key, step, ViewDirection::Row, false); + + DummyKeyReader:: { + global_iter, + _phantom: PhantomData, + } + } +} + +#[cube] +impl AttentionReader, G> + for DummyKeyReader +{ + type Stage = StridedStage, AttentionTilingLayout>; + + fn init_stage(&mut self, #[comptime] config: G) -> Self::Stage { + StridedStage::new(StageIdent::Rhs, config.score_stage_memory_config()) + } + + fn read_global(&mut self, stage: &mut Self::Stage, #[comptime] config: G) { + // TODO: this reader is inefficient: only plane 0 performs the load, and accesses are not coalesced + if UNIT_POS_Y == 0 { + let memory_config = config.global_memory_config(AttentionIdent::Key); + + let mut slice = stage.as_slice_mut(1u32); + + let tile_rows_load = memory_config.elements_in_tile_row; + let tile_cols_load = memory_config.elements_in_tile_col; + let partition_rows_load = memory_config.elements_in_stage_row / tile_rows_load; + let partition_cols_load = memory_config.elements_in_stage_col / tile_cols_load; + + let units_per_tile_row = comptime!(config.plane_dim() / tile_rows_load); + let tile_cols_per_unit = comptime!(div_ceil(tile_cols_load, units_per_tile_row)); + + let row_load_in_tile = UNIT_POS_X / units_per_tile_row; + let col_load_in_tile_start = (UNIT_POS_X % units_per_tile_row) * tile_cols_per_unit; + + // Assumes row tiling order + let num_elements_per_tile = tile_rows_load * tile_cols_load; + let tile_row_stride_store = partition_rows_load * num_elements_per_tile; + let tile_col_stride_store = num_elements_per_tile; + + let layout = TiledLayout::new(memory_config); + let view = self.global_iter.view().view(layout); + + #[unroll] + for tile_row_load in 0..partition_rows_load { + #[unroll] + for tile_col_load in 0..partition_cols_load { + if row_load_in_tile < tile_rows_load { + #[unroll] + for i in 0..tile_cols_per_unit { + let col_load = col_load_in_tile_start + i; + + if col_load < tile_cols_load { + let tile_row_store = tile_col_load; + let tile_col_store = tile_row_load; + let tile_row_store_offset = tile_row_store * tile_row_stride_store; + let tile_col_store_offset = tile_col_store * tile_col_stride_store; + let store_offset = tile_row_store_offset + tile_col_store_offset; + + let index_load = row_load_in_tile * tile_cols_load + col_load; + let index_store = col_load * tile_rows_load + row_load_in_tile; + + slice[index_store + store_offset] = + Line::cast_from(view.read_checked(( + (tile_row_load, tile_col_load).runtime(), + index_load, + ))); + } + } + } + } + } + } + } + + fn advance_view(&mut self) { + self.global_iter.advance(); + } +} diff --git a/crates/cubecl-attention/src/components/global/simple/reader/mask.rs b/crates/cubecl-attention/src/components/global/simple/reader/mask.rs new file mode 100644 index 000000000..5a772242c --- /dev/null +++ b/crates/cubecl-attention/src/components/global/simple/reader/mask.rs @@ -0,0 +1,147 @@ +use crate::components::attention_types::*; +use cubecl_core as cubecl; +use cubecl_core::prelude::*; +use cubecl_matmul::components::MatrixLayout; +use
cubecl_matmul::components::global::memory::{GlobalIterator, ViewDirection}; +use cubecl_matmul::components::tile::StridedTile; +use cubecl_std::tensor::{View, layout::Coords2d}; + +use crate::components::AttentionPrecision; +use crate::components::stage::{AttentionPartitioner, StageAttentionConfig}; +use cubecl_std::CubeOption; + +#[derive(CubeType)] +pub struct LogicalIterator { + row: u32, + col: RuntimeCell, + step_col: u32, +} + +#[cube] +impl LogicalIterator { + fn init(q_offset: u32, step_col: u32) -> LogicalIterator { + LogicalIterator { + row: q_offset, + col: RuntimeCell::new(0), + step_col, + } + } + + fn read(&self) -> Coords2d { + (self.row, self.col.read()) + } + + fn advance(&mut self) { + self.col.store(self.col.read() + self.step_col); + } + } + +#[derive(CubeType)] +pub struct MaterializedMaskReader { + global_iter: GlobalIterator>, + logical_iter: LogicalIterator, + // TODO: not sure if this is mandatory, but it is needed as the stride when reading from global memory + seq_kv_shape: u32, +} + +#[derive(CubeType)] +pub enum MaskReader { + Materialized(MaterializedMaskReader>), + Logical(LogicalIterator), +} + +#[cube] +impl MaskReader { + pub fn new_logical(inner_q_offset: u32, step: u32) -> Self { + MaskReader::::new_Logical(LogicalIterator::init(inner_q_offset, step)) + } + + pub fn new_materialized( + outer_q_offset: u32, + inner_q_offset: u32, + mask: View>, Coords2d>, + step: u32, + seq_kv_shape: u32, + ) -> Self { + let mask = mask.slice((outer_q_offset, 0), mask.shape()); + let global_iter = GlobalIterator::new(mask, step, ViewDirection::Col, false); + + MaskReader::::new_Materialized(MaterializedMaskReader::new( + global_iter, + LogicalIterator::init(inner_q_offset, step), + seq_kv_shape, + )) + } + + pub fn read( + &self, + #[comptime] pos_in_partition: Coords2d, + #[comptime] config: S, + ) -> (Coords2d, CubeOption>>) { + match self { + MaskReader::Materialized(materialized_mask_reader) => { + materialized_mask_reader.read::(pos_in_partition, config) + } + MaskReader::Logical(logical_iterator) => { + (logical_iterator.read(), CubeOption::new_None()) + } + } + } + + pub fn advance_view(&mut self) { + match self { + MaskReader::Logical(logical_iter) => logical_iter.advance(), + MaskReader::Materialized(materialized_mask_reader) => { + materialized_mask_reader.advance() + } + } + } +} + +#[cube] +impl MaterializedMaskReader { + fn new( + global_iter: GlobalIterator>, + logical_iter: LogicalIterator, + seq_kv_shape: u32, + ) -> Self { + MaterializedMaskReader:: { + global_iter, + logical_iter, + seq_kv_shape, + } + } + + fn read( + &self, + #[comptime] pos_in_partition: Coords2d, + #[comptime] config: S, + ) -> (Coords2d, CubeOption>) { + let (row_in_partition, col) = pos_in_partition; + let attention_tile_size = config.tiling_scheme().tile_size; + + let row = row_in_partition + P::seq_q_index() * config.tiling_scheme().partition_size.seq_q; + + let tile = StridedTile::::new_strided( + self.global_iter + .view() + .slice( + ( + row * attention_tile_size.seq_q, + col.runtime() * attention_tile_size.seq_kv, + ), + (attention_tile_size.seq_q, attention_tile_size.seq_kv).runtime(), + ) + .to_linear_slice(), + self.seq_kv_shape, + MatrixLayout::RowMajor, + ); + + (self.logical_iter.read(), CubeOption::new_Some(tile)) + } + + fn advance(&mut self) { + self.global_iter.advance(); + self.logical_iter.advance() + } +} diff --git a/crates/cubecl-attention/src/components/global/simple/reader/mod.rs b/crates/cubecl-attention/src/components/global/simple/reader/mod.rs new file mode 100644 index
000000000..577a6f2e5 --- /dev/null +++ b/crates/cubecl-attention/src/components/global/simple/reader/mod.rs @@ -0,0 +1,11 @@ +mod base; +mod key; +mod mask; +mod query; +mod value; + +pub use base::*; +pub use key::*; +pub use mask::*; +pub use query::*; +pub use value::*; diff --git a/crates/cubecl-attention/src/components/global/simple/reader/query.rs b/crates/cubecl-attention/src/components/global/simple/reader/query.rs new file mode 100644 index 000000000..0d5a30ebc --- /dev/null +++ b/crates/cubecl-attention/src/components/global/simple/reader/query.rs @@ -0,0 +1,48 @@ +use crate::components::attention_types::*; +use cubecl_core as cubecl; +use cubecl_core::prelude::*; +use cubecl_matmul::components::MatrixLayout; +use cubecl_matmul::components::tile::StridedTile; +use cubecl_std::tensor::{View, layout::Coords2d}; + +use crate::components::AttentionPrecision; +use crate::components::stage::{AttentionPartitioner, StageAttentionConfig}; + +#[derive(CubeType)] +pub struct QueryReader { + query: View>, Coords2d>, +} + +#[cube] +impl QueryReader { + pub fn new(q_offset: u32, query: View>, Coords2d>) -> Self { + let query = query.slice((q_offset, 0), query.shape()); + + QueryReader:: { query } + } + + pub fn get_tile( + &self, + tile: Coords2d, + #[comptime] config: S, + ) -> StridedTile> { + let (row_in_partition, col) = tile; + let attention_tile_size = config.tiling_scheme().tile_size; + + let row = row_in_partition + P::seq_q_index() * config.tiling_scheme().partition_size.seq_q; + + StridedTile::>::new_strided( + self.query + .slice( + ( + row * attention_tile_size.seq_q, + col * attention_tile_size.head_dim, + ), + (attention_tile_size.seq_q, attention_tile_size.head_dim).runtime(), + ) + .to_linear_slice(), + config.tiling_scheme().elements_in_partition_head_dim(), + MatrixLayout::RowMajor, + ) + } +} diff --git a/crates/cubecl-attention/src/components/global/simple/reader/value.rs b/crates/cubecl-attention/src/components/global/simple/reader/value.rs new file mode 100644 index 000000000..5c63fd862 --- /dev/null +++ b/crates/cubecl-attention/src/components/global/simple/reader/value.rs @@ -0,0 +1,103 @@ +use crate::components::attention_types::*; +use crate::components::global::simple::reader::{AttentionReader, AttentionReaderExpand}; +use cubecl_core as cubecl; +use cubecl_core::prelude::*; +use cubecl_matmul::components::StageIdent; +use cubecl_matmul::components::global::{ + memory::{GlobalIterator, ViewDirection}, + read::tiled::TiledLayout, +}; +use cubecl_matmul::components::stage::StridedStage; +use cubecl_std::tensor::{View, layout::Coords2d}; +use std::marker::PhantomData; + +use crate::components::global::base::GlobalAttentionConfig; +use crate::components::tile::AttentionTilingLayout; +use crate::components::{AttentionIdent, AttentionPrecision}; + +#[derive(CubeType)] +pub struct DummyValueReader { + global_iter: GlobalIterator>>, + + #[cube(comptime)] + _phantom: PhantomData, +} + +#[cube] +impl DummyValueReader { + pub fn new(value: View>, Coords2d>, step: u32) -> Self { + let global_iter = GlobalIterator::new(value, step, ViewDirection::Row, false); + + DummyValueReader:: { + global_iter, + _phantom: PhantomData, + } + } +} + +#[cube] +impl AttentionReader, G> + for DummyValueReader +{ + type Stage = StridedStage, AttentionTilingLayout>; + + fn init_stage(&mut self, #[comptime] config: G) -> Self::Stage { + StridedStage::new(StageIdent::Rhs, config.value_stage_memory_config()) + } + + fn read_global(&mut self, stage: &mut Self::Stage, #[comptime] config: G) { + if 
UNIT_POS_Y == 0 { + // TODO: this reader is inefficient; its accesses are not coalesced + let memory_config = config.global_memory_config(AttentionIdent::Value); + let mut slice = stage.as_slice_mut(1u32); + + let tile_rows = memory_config.elements_in_tile_row; + let tile_cols = memory_config.elements_in_tile_col; + let partition_rows = memory_config.elements_in_stage_row / tile_rows; + let partition_cols = memory_config.elements_in_stage_col / tile_cols; + + let units_per_tile_row = comptime!(config.plane_dim() / tile_rows); + let tile_cols_per_unit = comptime!(div_ceil(tile_cols, units_per_tile_row)); + + let row_in_tile = UNIT_POS_X / units_per_tile_row; + let col_in_tile_start = (UNIT_POS_X % units_per_tile_row) * tile_cols_per_unit; + + // Assumes row tiling order + let num_elements_per_tile = tile_rows * tile_cols; + let tile_row_stride = partition_cols * num_elements_per_tile; + let tile_col_stride = num_elements_per_tile; + + let layout = TiledLayout::new(memory_config); + let view = self.global_iter.view().view(layout); + + #[unroll] + for tile_row in 0..partition_rows { + #[unroll] + for tile_col in 0..partition_cols { + if row_in_tile < tile_rows { + #[unroll] + for i in 0..tile_cols_per_unit { + let col = col_in_tile_start + i; + + if col < tile_cols { + let tile_row_offset = tile_row * tile_row_stride; + let tile_col_offset = tile_col * tile_col_stride; + let offset = tile_row_offset + tile_col_offset; + + let index = row_in_tile * tile_cols + col; + + slice[index + offset] = Line::cast_from( + view.read_checked(((tile_row, tile_col).runtime(), index)), + ); + } + } + } + } + } + } + } + + fn advance_view(&mut self) { + self.global_iter.advance(); + } +} diff --git a/crates/cubecl-attention/src/components/global/dummy/setup.rs b/crates/cubecl-attention/src/components/global/simple/setup.rs similarity index 67% rename from crates/cubecl-attention/src/components/global/dummy/setup.rs rename to crates/cubecl-attention/src/components/global/simple/setup.rs index f590b1163..e10d6a08f 100644 --- a/crates/cubecl-attention/src/components/global/dummy/setup.rs +++ b/crates/cubecl-attention/src/components/global/simple/setup.rs @@ -8,12 +8,12 @@ use crate::components::{ AttentionSetupError, global::{ GlobalAttentionFamily, - dummy::{DummyGlobalAttention, config::DummyGlobalConfig}, + simple::{SimpleGlobalAttention, config::SimpleGlobalConfig}, }, stage::{StageAttentionConfig as _, StageAttentionFamily}, }; -pub struct DummyGlobalAttentionFamily { +pub struct SimpleGlobalAttentionFamily { _phantom: PhantomData, } @@ -23,20 +23,20 @@ impl< ValueStage = StridedStageFamily, OutStage = PartitionedStageFamily, >, -> GlobalAttentionFamily for DummyGlobalAttentionFamily +> GlobalAttentionFamily for SimpleGlobalAttentionFamily { - type Attention = DummyGlobalAttention>; + type Attention = SimpleGlobalAttention>; - type Config = DummyGlobalConfig; + type Config = SimpleGlobalConfig; fn setup( - client: &ComputeClient, + client: &ComputeClient, problem: &AttentionProblem, selection: &AttentionSelection, line_sizes: &AttentionLineSizes, ) -> Result { let stage_config = SA::setup::(client, problem, selection, line_sizes)?; - DummyGlobalConfig::new(stage_config, stage_config.num_planes()) + SimpleGlobalConfig::new(stage_config, stage_config.num_planes(), problem.causal) } } diff --git a/crates/cubecl-attention/src/components/global/simple/writer/mod.rs b/crates/cubecl-attention/src/components/global/simple/writer/mod.rs new file mode 100644 index 000000000..aa5b8e049 --- /dev/null +++ 
b/crates/cubecl-attention/src/components/global/simple/writer/mod.rs @@ -0,0 +1,26 @@ +use cubecl::prelude::*; +use cubecl_core::{self as cubecl}; + +use cubecl_matmul::components::global::{ + PartitionedStage, WriteEventListener, memory::GlobalMemoryConfig, +}; + +mod plane; +mod unit; + +use cubecl_std::tensor::{View, layout::Coords2d}; +pub use plane::*; +pub use unit::*; + +use crate::components::stage::StageAttentionConfig; + +#[cube] +pub trait AttentionWriter: WriteEventListener { + fn new( + global: View, Coords2d, ReadWrite>, + #[comptime] global_config: GlobalMemoryConfig, + #[comptime] stage_config: S, + ) -> Self; + + fn stage(&mut self) -> PartitionedStage; +} diff --git a/crates/cubecl-attention/src/components/global/simple/writer/plane.rs b/crates/cubecl-attention/src/components/global/simple/writer/plane.rs new file mode 100644 index 000000000..c467d3ccf --- /dev/null +++ b/crates/cubecl-attention/src/components/global/simple/writer/plane.rs @@ -0,0 +1,91 @@ +use cubecl::prelude::*; +use cubecl_core::{self as cubecl}; +use cubecl_matmul::components::{ + MatrixLayout, + global::{ + PartitionedStage, WriteEvent, WriteEventExpand, WriteEventListener, + memory::GlobalMemoryConfig, + plane_write, + read::tiled::{TiledCoords, TiledLayout}, + }, + stage::StageMemoryConfig, +}; +use cubecl_std::tensor::{View, layout::Coords2d}; + +use crate::components::{ + global::simple::{AttentionWriter, AttentionWriterExpand}, + stage::{AttentionPartitioner, StageAttentionConfig, plane::PlanePartitioner}, +}; + +#[derive(CubeType)] +pub struct PlaneAttentionWriter { + global: View, TiledCoords, ReadWrite>, + stage: PartitionedStage, + + #[cube(comptime)] + plane_dim: u32, + #[cube(comptime)] + config: GlobalMemoryConfig, +} + +#[cube] +impl PlaneAttentionWriter {} + +#[cube] +impl WriteEventListener for PlaneAttentionWriter { + fn on_event(this: &mut Self, event: WriteEvent) { + #[allow(clippy::single_match)] + match event { + WriteEvent::TileStored { tile } => plane_write::( + &mut this.global, + &this.stage.unit_tile, + tile, + comptime![this.plane_dim], + comptime![this.config], + ), + _ => {} + } + } +} + +#[cube] +impl AttentionWriter for PlaneAttentionWriter { + fn new( + global: View, Coords2d, ReadWrite>, + #[comptime] global_config: GlobalMemoryConfig, + #[comptime] stage_config: S, + ) -> Self { + let stage_mem_config = comptime! 
{ + let elements_in_tile_row = stage_config.tiling_scheme().elements_in_partition_seq_q(); + let elements_in_tile_col = stage_config.tiling_scheme().elements_in_partition_val_dim(); + let planes = stage_config.num_planes(); + + StageMemoryConfig { + num_main_flow_planes: planes, + elements_in_tile_row, + elements_in_tile_col, + // Each plane has its own slot in the row direction + tiles_in_stage_row: planes, + // Each plane needs only one slot + tiles_in_stage_col: 1, + stage_line_size: 1, + matrix_layout: MatrixLayout::RowMajor, + num_stages: 1, + } + }; + + let stage = + PartitionedStage::new((PlanePartitioner::seq_q_index(), 0u32), stage_mem_config); + + PlaneAttentionWriter:: { + global: global.view_mut(TiledLayout::new(global_config)), + stage, + plane_dim: stage_config.plane_dim(), + config: global_config, + } + } + + fn stage(&mut self) -> PartitionedStage { + self.stage + } +} diff --git a/crates/cubecl-attention/src/components/global/simple/writer/unit.rs b/crates/cubecl-attention/src/components/global/simple/writer/unit.rs new file mode 100644 index 000000000..2fdb32347 --- /dev/null +++ b/crates/cubecl-attention/src/components/global/simple/writer/unit.rs @@ -0,0 +1,83 @@ +use cubecl::prelude::*; +use cubecl_core::{self as cubecl}; +use cubecl_matmul::components::{ + MatrixLayout, + global::{ + PartitionedStage, WriteEvent, WriteEventExpand, WriteEventListener, + memory::GlobalMemoryConfig, + read::tiled::{TiledCoords, TiledLayout}, + unit_write, + }, + stage::StageMemoryConfig, +}; +use cubecl_std::tensor::{View, layout::Coords2d}; + +use crate::components::{ + global::simple::{AttentionWriter, AttentionWriterExpand}, + stage::{AttentionPartitioner, StageAttentionConfig, unit::UnitPartitioner}, +}; + +#[derive(CubeType)] +pub struct UnitAttentionWriter { + global: View, TiledCoords, ReadWrite>, + stage: PartitionedStage, + + #[cube(comptime)] + config: GlobalMemoryConfig, +} + +#[cube] +impl WriteEventListener for UnitAttentionWriter { + fn on_event(this: &mut Self, event: WriteEvent) { + #[allow(clippy::single_match)] + match event { + WriteEvent::TileStored { tile } => unit_write::( + &mut this.global, + &this.stage.unit_tile, + tile, + comptime![this.config], + ), + _ => {} + } + } +} + +#[cube] +impl AttentionWriter for UnitAttentionWriter { + fn new( + global: View, Coords2d, ReadWrite>, + #[comptime] global_config: GlobalMemoryConfig, + #[comptime] stage_config: S, + ) -> Self { + let stage_mem_config = comptime! 
{ + let elements_in_tile_row = stage_config.tiling_scheme().elements_in_partition_seq_q(); + let elements_in_tile_col = stage_config.tiling_scheme().elements_in_partition_val_dim(); + let planes = stage_config.num_planes(); + + StageMemoryConfig { + num_main_flow_planes: planes, + elements_in_tile_row, + elements_in_tile_col, + // Each unit has its own slot in the row direction + tiles_in_stage_row: planes, + // Each unit needs only one slot + tiles_in_stage_col: 1, + stage_line_size: 1, + matrix_layout: MatrixLayout::RowMajor, + num_stages: 1, + } + }; + + let stage = PartitionedStage::new((UnitPartitioner::seq_q_index(), 0u32), stage_mem_config); + + UnitAttentionWriter:: { + global: global.view_mut(TiledLayout::new(global_config)), + stage, + config: global_config, + } + } + + fn stage(&mut self) -> PartitionedStage { + self.stage + } +} diff --git a/crates/cubecl-attention/src/components/line_size.rs b/crates/cubecl-attention/src/components/line_size.rs index 290f7451c..419d957d5 100644 --- a/crates/cubecl-attention/src/components/line_size.rs +++ b/crates/cubecl-attention/src/components/line_size.rs @@ -1,6 +1,6 @@ use std::fmt::Debug; -use cubecl_core::{LineSizeError, Runtime, ir::StorageType, tensor_line_size_parallel}; +use cubecl_core::{LineSizeError, Runtime, tensor_line_size_parallel}; use crate::components::{AttentionIdent, AttentionSetupError}; @@ -29,11 +29,7 @@ pub struct AvailableLineSizes { } impl AvailableLineSizes { - pub fn from_elem_types( - elem_in: &StorageType, - elem_mask: &StorageType, - elem_out: &StorageType, - ) -> Self { + pub fn from_elem_types(elem_in: usize, elem_mask: usize, elem_out: usize) -> Self { let in_available: Vec = R::io_optimized_line_sizes_unchecked(elem_in).collect(); let mask_available: Vec = R::io_optimized_line_sizes_unchecked(elem_mask).collect(); let out_available = R::io_optimized_line_sizes_unchecked(elem_out).collect(); diff --git a/crates/cubecl-attention/src/components/mask.rs b/crates/cubecl-attention/src/components/mask.rs deleted file mode 100644 index 6fe83a74c..000000000 --- a/crates/cubecl-attention/src/components/mask.rs +++ /dev/null @@ -1,94 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; - -use crate::components::AttentionTilingScheme; - -#[derive(CubeType, Copy, Clone)] -pub struct GlobalMask { - q_bound: u32, - kv_bound: u32, - #[cube(comptime)] - tiling_scheme: AttentionTilingScheme, -} - -#[derive(CubeType, Copy, Clone)] -pub struct StageMask { - q_bound: u32, - kv_bound: u32, - #[cube(comptime)] - tiling_scheme: AttentionTilingScheme, -} - -#[derive(CubeType, Copy, Clone)] -pub struct PartitionMask { - q_bound: u32, - kv_bound: u32, - #[cube(comptime)] - tiling_scheme: AttentionTilingScheme, -} - -#[derive(CubeType, Copy, Clone)] -pub struct TileMask { - q_bound: u32, - kv_bound: u32, -} - -#[cube] -impl GlobalMask { - pub fn new( - q_bound: u32, - kv_bound: u32, - #[comptime] tiling_scheme: AttentionTilingScheme, - ) -> GlobalMask { - GlobalMask { - q_bound, - kv_bound, - tiling_scheme, - } - } - - pub fn to_stage(&self, row: u32, col: u32) -> StageMask { - let q_factor = comptime!(self.tiling_scheme.elements_in_stage_seq_q()); - let kv_factor = comptime!(self.tiling_scheme.elements_in_stage_seq_kv()); - - StageMask { - q_bound: self.q_bound.saturating_sub(row * q_factor), - kv_bound: self.kv_bound.saturating_sub(col * kv_factor), - tiling_scheme: self.tiling_scheme, - } - } -} - -#[cube] -impl StageMask { - pub fn to_partition(&self, row: u32) -> PartitionMask { - let q_factor = 
comptime!(self.tiling_scheme.elements_in_partition_seq_q()); - - PartitionMask { - q_bound: self.q_bound.saturating_sub(row * q_factor), - kv_bound: self.kv_bound, - tiling_scheme: self.tiling_scheme, - } - } -} - -#[cube] -impl PartitionMask { - pub fn to_tile(self, row: u32, col: u32) -> TileMask { - let q_factor = comptime!(self.tiling_scheme.elements_in_tile_seq_q()); - let kv_factor = comptime!(self.tiling_scheme.elements_in_tile_seq_kv()); - - TileMask { - q_bound: self.q_bound.saturating_sub(row * q_factor), - kv_bound: self.kv_bound.saturating_sub(col * kv_factor), - } - } -} - -#[cube] -impl TileMask { - pub fn apply(&self, row: u32, col: u32) -> Line { - let should_mask = Line::::cast_from(row >= self.q_bound || col >= self.kv_bound); - should_mask * Line::cast_from(-999999) - } -} diff --git a/crates/cubecl-attention/src/components/mod.rs b/crates/cubecl-attention/src/components/mod.rs index 8a6d26741..0222f7cf3 100644 --- a/crates/cubecl-attention/src/components/mod.rs +++ b/crates/cubecl-attention/src/components/mod.rs @@ -1,5 +1,6 @@ pub mod args; pub mod batch; +pub mod fragment; pub mod global; pub mod stage; pub mod tile; @@ -7,7 +8,6 @@ pub mod tile; mod error; mod ident; mod line_size; -mod mask; mod problem; mod selection; mod spec; @@ -16,7 +16,6 @@ mod tiling_scheme; pub use error::*; pub use ident::*; pub use line_size::*; -pub use mask::*; pub use problem::*; pub use selection::*; pub use spec::*; diff --git a/crates/cubecl-attention/src/components/problem.rs b/crates/cubecl-attention/src/components/problem.rs index d5e66f488..7f457085f 100644 --- a/crates/cubecl-attention/src/components/problem.rs +++ b/crates/cubecl-attention/src/components/problem.rs @@ -16,6 +16,8 @@ pub struct AttentionProblem { /// Usually equal to `head_dim`, but may differ in some variants pub val_dim: usize, - /// Whether a mask is applied (shape is always [batch, seq_q, heads, seq_k]) + /// Whether a mask is supplied (shape is always [batch, seq_q, heads, seq_kv]) pub masked: bool, + /// Whether there is a causal mask + pub causal: bool, } diff --git a/crates/cubecl-attention/src/components/selection.rs b/crates/cubecl-attention/src/components/selection.rs index dd2725008..a26a61562 100644 --- a/crates/cubecl-attention/src/components/selection.rs +++ b/crates/cubecl-attention/src/components/selection.rs @@ -8,4 +8,5 @@ pub struct AttentionSelection { pub plane_dim: u32, pub reuse_key_value: bool, + pub two_rows_in_array_tile: bool, } diff --git a/crates/cubecl-attention/src/components/spec.rs b/crates/cubecl-attention/src/components/spec.rs index d599c9717..53469a290 100644 --- a/crates/cubecl-attention/src/components/spec.rs +++ b/crates/cubecl-attention/src/components/spec.rs @@ -205,7 +205,7 @@ impl< } /// Input argument -pub type InputArg = as AttentionArgs>::Input, KG, VG>; +pub type InputArg = as AttentionArgs>::Input, KG, VG, MSK>; /// Output argument pub type OutputArg = as AttentionArgs>::Output>; diff --git a/crates/cubecl-attention/src/components/stage/base.rs b/crates/cubecl-attention/src/components/stage/base.rs index 7b2cd5741..5cda3d5e2 100644 --- a/crates/cubecl-attention/src/components/stage/base.rs +++ b/crates/cubecl-attention/src/components/stage/base.rs @@ -1,21 +1,23 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; use cubecl_matmul::components::{ + MatrixLayout, StageIdent, TilingScheme, global::{WriteEventListener, WriteTiling}, - stage::StageFamily, + stage::{StageFamily, StageMemoryConfig}, }; use std::{fmt::Debug, hash::Hash}; -use 
crate::components::attention_types::*; -use crate::components::stage::dummy::AttentionStageMemoryConfig; -use crate::components::{AttentionIdent, StageMask}; +use crate::components::tile::RunningState; use crate::components::{ AttentionLineSizes, AttentionPrecision, AttentionProblem, AttentionSelection, - AttentionSetupError, AvailableLineSizes, - global::GlobalAttentionConfig, - tile::{AttentionTilingLayout, dummy::AttentionMatmulConfig}, + AttentionSetupError, AvailableLineSizes, global::GlobalAttentionConfig, + tile::AttentionTilingLayout, }; -use crate::components::{AttentionTilingScheme, global::dummy::QueryReader}; +use crate::components::{AttentionTilingScheme, global::simple::QueryReader}; +use crate::components::{attention_types::*, fragment::FragmentAttentionConfig}; +use crate::components::{global::simple::MaskReader, stage::AttentionPartitioner}; +use cubecl_std::CubeOption; +use cubecl_std::tensor::layout::Coords2d; /// A family of [TileAttention] implementations that operate with any [precision](AttentionPrecision). pub trait StageAttentionFamily: Send + Sync + 'static { @@ -39,7 +41,7 @@ pub trait StageAttentionFamily: Send + Sync + 'static { /// /// This function may return an error if the configuration cannot be supported on the current runtime. fn setup( - client: &ComputeClient, + client: &ComputeClient, problem: &AttentionProblem, selection: &AttentionSelection, line_sizes: &AttentionLineSizes, @@ -61,49 +63,59 @@ pub trait StageAttention: 'static + Send + Sync { /// The configuration type associated with this Attention. type Config: StageAttentionConfig; + type Partitioner: AttentionPartitioner; - type State: CubeType; + type QueryRegisters: CubeType; + type KeyValueRegisters: CubeType; + type SoftmaxRegisters: CubeType; + type AccumulatorRegisters: CubeType; + type MaskRegisters: CubeType; - type QueryPartition: CubeType; - type KeyValuePartition: CubeType; - type SoftmaxPartition: CubeType; - type AccumulatorPartition: CubeType; - - fn init_state(#[comptime] config: Self::Config) -> Self::State; + fn init_state(#[comptime] config: Self::Config) -> Sequence>>; fn execute( - key_reader: &Self::KeyStage, - value_reader: &Self::ValueStage, - query: &Self::QueryPartition, - key_value: &mut Self::KeyValuePartition, - score: &mut Self::SoftmaxPartition, - mask: StageMask, - accumulator: &mut Self::AccumulatorPartition, - prev_state: &mut Self::State, + query: &Self::QueryRegisters, + key_stage: &Self::KeyStage, + value_stage: &Self::ValueStage, + key_value: &mut Self::KeyValueRegisters, + mask_partition: &Self::MaskRegisters, + score: &mut Self::SoftmaxRegisters, + accumulator: &mut Self::AccumulatorRegisters, + prev_state: &mut Sequence>>, #[comptime] config: Self::Config, ); fn rescale( - acc: &mut Self::AccumulatorPartition, - state: Self::State, + acc: &mut Self::AccumulatorRegisters, + state: Sequence>>, #[comptime] config: Self::Config, ); fn write( - acc: &Self::AccumulatorPartition, + acc: &Self::AccumulatorRegisters, stage: &mut Self::OutStage, writer: &mut W, #[comptime] tile_config: Self::Config, ); - fn init_partitions( - query_loader: QueryReader, + fn init_query(#[comptime] config: Self::Config) -> Self::QueryRegisters; + fn init_key_value(#[comptime] config: Self::Config) -> Self::KeyValueRegisters; + fn init_mask( + out_of_bounds: CubeOption, + #[comptime] config: Self::Config, + ) -> Self::MaskRegisters; + fn init_softmax(#[comptime] config: Self::Config) -> Self::SoftmaxRegisters; + fn init_accumulator(#[comptime] config: Self::Config) -> 
Self::AccumulatorRegisters; + + fn read_query( + reader: &QueryReader, + registers: &mut Self::QueryRegisters, + #[comptime] config: Self::Config, + ); + fn read_mask( + reader: &MaskReader, + registers: &mut Self::MaskRegisters, #[comptime] config: Self::Config, - ) -> ( - Self::QueryPartition, - Self::KeyValuePartition, - Self::SoftmaxPartition, - Self::AccumulatorPartition, ); } @@ -111,17 +123,37 @@ pub trait StageAttention: 'static + Send + Sync { pub trait StageAttentionConfig: Copy + Clone + Eq + PartialEq + Hash + Debug + Send + Sync + 'static { - type AttentionMatmulConfig: AttentionMatmulConfig; + type FragmentAttentionConfig: FragmentAttentionConfig; fn plane_dim(&self) -> u32; fn num_planes(&self) -> u32; - fn tile_config(&self) -> Self::AttentionMatmulConfig; + fn tile_config(&self) -> Self::FragmentAttentionConfig; fn score_stage_memory_config(&self) -> AttentionStageMemoryConfig; fn value_stage_memory_config(&self) -> AttentionStageMemoryConfig; fn tiling_scheme(&self) -> AttentionTilingScheme; fn reuse_key_value(&self) -> bool; - fn num_rows_per_unit(&self, ident: AttentionIdent) -> u32; + fn num_rows_per_unit(&self) -> u32; +} + +#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] +pub struct AttentionStageMemoryConfig { + pub matmul_tiling_scheme: TilingScheme, +} + +impl AttentionStageMemoryConfig { + pub fn into_matmul_config(&self, ident: StageIdent) -> StageMemoryConfig { + StageMemoryConfig { + num_main_flow_planes: 1, + elements_in_tile_row: self.matmul_tiling_scheme.elements_in_tile_row(ident), + elements_in_tile_col: self.matmul_tiling_scheme.elements_in_tile_col(ident), + tiles_in_stage_row: self.matmul_tiling_scheme.tiles_in_stage_row(ident), + tiles_in_stage_col: self.matmul_tiling_scheme.tiles_in_stage_col(ident), + stage_line_size: 1, + matrix_layout: MatrixLayout::RowMajor, + num_stages: 1, + } + } } diff --git a/crates/cubecl-attention/src/components/stage/dummy/attention.rs b/crates/cubecl-attention/src/components/stage/dummy/attention.rs deleted file mode 100644 index 26f5139fe..000000000 --- a/crates/cubecl-attention/src/components/stage/dummy/attention.rs +++ /dev/null @@ -1,252 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; -use cubecl_matmul::components::{ - global::{WriteEvent, WriteEventListener}, - stage::Stage, - tile::io::Strided, -}; -use std::marker::PhantomData; - -use crate::components::StageMask; -use crate::components::attention_types::*; -use crate::components::global::dummy::QueryReader; -use crate::components::stage::dummy::SoftmaxPartition; -use crate::components::stage::dummy::StageState; -use crate::components::stage::dummy::{Accumulators, DummyStageConfig, KeyValues, Queries}; -use crate::components::stage::{StageAttention, StageAttentionConfig}; -use crate::components::tile::RowWise; -use crate::components::tile::TileAttention; -use crate::components::{AttentionPrecision, global::GlobalAttentionConfig}; - -pub struct DummyStageAttention> { - _phantom: PhantomData<(AP, SK, SV, SO, TA)>, -} - -#[cube] -impl< - AP: AttentionPrecision, - SK: Stage, ReadOnly, TileKind = Strided>, - SV: Stage, ReadOnly, TileKind = Strided>, - SO: Stage, ReadWrite, TileKind = Strided>, - TA: TileAttention, -> StageAttention for DummyStageAttention -{ - type Config = DummyStageConfig; - - type KeyStage = SK; - type ValueStage = SV; - type OutStage = SO; - - type State = StageState; - type QueryPartition = Queries; - type KeyValuePartition = KeyValues; - type SoftmaxPartition = SoftmaxPartition; - type AccumulatorPartition = Accumulators; - 
- fn execute( - key_reader: &Self::KeyStage, - value_reader: &Self::ValueStage, - query_partition: &Self::QueryPartition, - key_value_partition: &mut Self::KeyValuePartition, - softmax_partition: &mut Self::SoftmaxPartition, - mask: StageMask, - accumulator_partition: &mut Self::AccumulatorPartition, - state: &mut Self::State, - #[comptime] config: Self::Config, - ) { - let partition_mask = mask.to_partition(UNIT_POS_Y); - - let p = config.tiling_scheme().partition_size; - - let mut kv = comptime![0u32]; - - #[unroll] - #[allow(clippy::explicit_counter_loop)] - for _ in 0..p.seq_kv { - let mut hd = comptime![0u32]; - - #[unroll] - #[allow(clippy::explicit_counter_loop)] - for _ in 0..p.head_dim { - let key_tile = SK::tile(key_reader, (hd, kv).runtime()); - - TA::fill_key( - &key_tile, - key_value_partition.get_key_at_mut(hd, kv, config), - config.tile_config(), - ); - - comptime![hd += 1]; - } - - let mut q = comptime![0u32]; - let mut scales = Sequence::>>::new(); - - #[unroll] - #[allow(clippy::explicit_counter_loop)] - for _ in 0..p.seq_q { - let softmax_tile = softmax_partition.get_at_mut(q, kv, config); - TA::zero_softmax(softmax_tile, config.tile_config()); - - let mut hd = comptime![0u32]; - - #[unroll] - #[allow(clippy::explicit_counter_loop)] - for _ in 0..p.head_dim { - let query_tile = query_partition.get_at(q, hd, config); - let key_tile = key_value_partition.get_key_at(hd, kv, config); - - TA::accumulate_score(query_tile, key_tile, softmax_tile, config.tile_config()); - - comptime![hd += 1]; - } - - let state_q = state.get_at_mut(q); - - let accumulator_scale = TA::softmax( - softmax_tile, - partition_mask.to_tile(q, kv), - state_q, - config.tiling_scheme().elements_in_partition_head_dim(), - ); - - scales.push(accumulator_scale); - - comptime![q += 1]; - } - - let mut vd = comptime![0u32]; - - #[unroll] - #[allow(clippy::explicit_counter_loop)] - for _ in 0..p.val_dim { - let value_tile = SV::tile(value_reader, (kv, vd).runtime()); - - TA::fill_value( - &value_tile, - key_value_partition.get_value_at_mut(kv, vd, config), - config.tile_config(), - ); - - comptime![vd += 1]; - } - - let mut q = comptime![0u32]; - - #[unroll] - #[allow(clippy::explicit_counter_loop)] - for _ in 0..p.seq_q { - let mut vd = comptime![0u32]; - let softmax_tile = softmax_partition.get_at(q, kv, config); - - #[unroll] - #[allow(clippy::explicit_counter_loop)] - for _ in 0..p.val_dim { - TA::accumulate_value( - softmax_tile, - key_value_partition.get_value_at(kv, vd, config), - accumulator_partition.get_at_mut(q, vd, config), - scales.index(q), - config.tile_config(), - ); - - comptime![vd += 1]; - } - - comptime![q += 1]; - } - - comptime![kv += 1]; - } - } - - fn rescale( - acc: &mut Self::AccumulatorPartition, - state: Self::State, - #[comptime] config: Self::Config, - ) { - let p = config.tiling_scheme().partition_size; - - let mut q = comptime!(0u32); - - #[unroll] - #[allow(clippy::explicit_counter_loop)] - for _ in 0..p.seq_q { - let mut vd = comptime!(0u32); - - #[unroll] - #[allow(clippy::explicit_counter_loop)] - for _ in 0..p.val_dim { - TA::rescale( - Self::AccumulatorPartition::get_at_mut(acc, q, vd, config), - state.get_at(q), - config.tile_config(), - ); - - comptime![vd += 1]; - } - - comptime![q += 1]; - } - } - - fn init_state(#[comptime] config: Self::Config) -> Self::State { - StageState::::init::(config) - } - - fn write( - acc: &Self::AccumulatorPartition, - stage: &mut Self::OutStage, - writer: &mut W, - #[comptime] stage_config: Self::Config, - ) { - let p = 
stage_config.tiling_scheme().partition_size; - let mut q = comptime!(0u32); - - W::on_event(writer, WriteEvent::new_Begin()); - - #[unroll] - #[allow(clippy::explicit_counter_loop)] - for _ in 0..p.seq_q { - let mut kv = comptime!(0u32); - - #[unroll] - #[allow(clippy::explicit_counter_loop)] - for _ in 0..p.val_dim { - let tile_pos = (q + UNIT_POS_Y * p.seq_q, kv.runtime()); - let mut tile = Self::OutStage::tile(stage, tile_pos); - - TA::write_results( - &mut tile, - Self::AccumulatorPartition::get_at(acc, q, kv, stage_config), - stage_config.tile_config(), - ); - - W::on_event(writer, WriteEvent::new_TileStored(tile_pos)); - - comptime![kv += 1]; - } - - comptime![q += 1]; - } - - W::on_event(writer, WriteEvent::new_Finish()); - } - - fn init_partitions( - query_loader: QueryReader, - #[comptime] config: Self::Config, - ) -> ( - Self::QueryPartition, - Self::KeyValuePartition, - Self::SoftmaxPartition, - Self::AccumulatorPartition, - ) { - ( - Self::QueryPartition::new(query_loader, config), - Self::KeyValuePartition::new(config), - Self::SoftmaxPartition::new(config), - Self::AccumulatorPartition::new(config), - ) - } -} diff --git a/crates/cubecl-attention/src/components/stage/kv_reuse_attention.rs b/crates/cubecl-attention/src/components/stage/kv_reuse_attention.rs new file mode 100644 index 000000000..b8d1eadf0 --- /dev/null +++ b/crates/cubecl-attention/src/components/stage/kv_reuse_attention.rs @@ -0,0 +1,281 @@ +use cubecl_core as cubecl; +use cubecl_core::prelude::*; +use cubecl_matmul::components::{ + global::{WriteEvent, WriteEventListener}, + stage::Stage, + tile::io::Strided, +}; +use std::marker::PhantomData; + +use crate::components::global::simple::QueryReader; +use crate::components::stage::StageAttentionConfig; +use crate::components::tile::RowWise; +use crate::components::tile::RunningState; +use crate::components::tile::TileAttention; +use crate::components::{AttentionPrecision, global::GlobalAttentionConfig}; +use crate::components::{attention_types::*, stage::StageAttention}; +use crate::components::{ + fragment::FragmentAttention, + stage::tile_partitions::{ + AccumulatorPartition, KeyValues, MaskPartition, QueryPartition, SoftmaxPartition, + }, +}; +use crate::components::{global::simple::MaskReader, stage::partitioner::AttentionPartitioner}; +use cubecl_std::CubeOption; +use cubecl_std::tensor::layout::Coords2d; + +#[derive(CubeType)] +pub struct KVReuseStageAttention< + AP: AttentionPrecision, + SK, + SV, + SO, + FA: FragmentAttention, + P: AttentionPartitioner, + S: StageAttentionConfig, +> { + #[cube(comptime)] + _phantom: PhantomData<(AP, SK, SV, SO, FA, P, S)>, +} + +#[cube] +impl< + AP: AttentionPrecision, + SK: Stage, ReadOnly, TileKind = Strided>, + SV: Stage, ReadOnly, TileKind = Strided>, + SO: Stage, ReadWrite, TileKind = Strided>, + FA: FragmentAttention, + P: AttentionPartitioner, + S: StageAttentionConfig, +> StageAttention for KVReuseStageAttention +{ + type KeyStage = SK; + type ValueStage = SV; + type OutStage = SO; + + type Config = S; + type Partitioner = P; + + type QueryRegisters = QueryPartition; + type KeyValueRegisters = KeyValues; + type SoftmaxRegisters = SoftmaxPartition; + type AccumulatorRegisters = AccumulatorPartition; + type MaskRegisters = MaskPartition; + + fn execute( + query_partition: &QueryPartition, + key_stage: &SK, + value_stage: &SV, + key_value_partition: &mut KeyValues, + mask_partition: &MaskPartition, + softmax_partition: &mut SoftmaxPartition, + accumulator_partition: &mut AccumulatorPartition, + state: &mut 
Sequence>>, + #[comptime] config: S, + ) { + let p = config.tiling_scheme().partition_size; + + let mut max_placeholder = + TileAttention::::init_max_placeholder(config.num_rows_per_unit()); + let mut sum_placeholder = + TileAttention::::init_sum_placeholder(config.num_rows_per_unit()); + + #[unroll] + for kv in 0..p.seq_kv { + #[unroll] + for hd in 0..p.head_dim { + let key_tile = SK::tile(key_stage, (hd, kv).runtime()); + + TileAttention::fill_key( + &key_tile, + key_value_partition.get_key_at_mut(hd, kv, config), + config.tile_config(), + ); + } + + let mut scales = Sequence::>>::new(); + + #[unroll] + for q in 0..p.seq_q { + let softmax_tile = softmax_partition.get_at_mut(q, kv, config); + TileAttention::zero_softmax(softmax_tile, config.tile_config()); + + let mask_tile = mask_partition.get_at(q, kv, config.tiling_scheme()); + + #[unroll] + for hd in 0..p.head_dim { + let query_tile = query_partition.get_at(q, hd, config); + let key_tile = key_value_partition.get_key_at(hd, kv, config); + + TileAttention::accumulate_score( + query_tile, + key_tile, + softmax_tile, + config.tile_config(), + ); + } + + let state_q = state.index_mut(q); + + scales.push(TileAttention::softmax::( + softmax_tile, + mask_tile, + state_q, + &mut max_placeholder, + &mut sum_placeholder, + config.tiling_scheme().elements_in_partition_head_dim(), + config.tile_config(), + )); + } + + #[unroll] + for vd in 0..p.val_dim { + let value_tile = SV::tile(value_stage, (kv, vd).runtime()); + + TileAttention::fill_value( + &value_tile, + key_value_partition.get_value_at_mut(kv, vd, config), + config.tile_config(), + ); + } + + #[unroll] + for q in 0..p.seq_q { + let softmax_tile = softmax_partition.get_at(q, kv, config); + + #[unroll] + for vd in 0..p.val_dim { + TileAttention::accumulate_value( + softmax_tile, + key_value_partition.get_value_at(kv, vd, config), + accumulator_partition.get_at_mut(q, vd, config), + scales.index(q), + config.tile_config(), + ); + } + } + } + } + + fn rescale( + acc: &mut AccumulatorPartition, + state: Sequence>>, + #[comptime] config: S, + ) { + let p = config.tiling_scheme().partition_size; + + #[unroll] + for q in 0..p.seq_q { + #[unroll] + for vd in 0..p.val_dim { + TileAttention::::rescale( + AccumulatorPartition::::get_at_mut(acc, q, vd, config), + state.index(q), + ); + } + } + } + + fn init_state(#[comptime] config: S) -> Sequence>> { + let p = config.tiling_scheme().partition_size; + let mut sequence = Sequence::new(); + + #[unroll] + for _ in 0..comptime!(p.seq_q) { + sequence.push(TileAttention::::init_state(config.tile_config())); + } + + sequence + } + + fn write( + acc: &AccumulatorPartition, + stage: &mut SO, + writer: &mut W, + #[comptime] stage_config: S, + ) { + let p = stage_config.tiling_scheme().partition_size; + + W::on_event(writer, WriteEvent::new_Begin()); + + #[unroll] + for q in 0..p.seq_q { + #[unroll] + for vd in 0..p.val_dim { + let tile_pos = (q + P::seq_q_index() * p.seq_q, vd.runtime()); + let mut tile = SO::tile(stage, tile_pos); + + TileAttention::::write_results( + &mut tile, + AccumulatorPartition::::get_at(acc, q, vd, stage_config), + stage_config.tile_config(), + ); + + W::on_event(writer, WriteEvent::new_TileStored(tile_pos)); + } + } + + W::on_event(writer, WriteEvent::new_Finish()); + } + + fn init_query(#[comptime] config: S) -> QueryPartition { + QueryPartition::::new(config) + } + + fn init_key_value(#[comptime] config: S) -> KeyValues { + KeyValues::::new(config) + } + + fn init_softmax(#[comptime] config: S) -> SoftmaxPartition { + 
SoftmaxPartition::::new(config) + } + + fn init_accumulator(#[comptime] config: S) -> AccumulatorPartition { + AccumulatorPartition::::new(config) + } + + fn init_mask( + out_of_bounds: CubeOption, + #[comptime] config: S, + ) -> MaskPartition { + MaskPartition::::new(out_of_bounds, config) + } + + fn read_query( + reader: &QueryReader, + registers: &mut QueryPartition, + #[comptime] config: S, + ) { + let p = config.tiling_scheme().partition_size; + + #[unroll] + for q in 0..p.seq_q { + #[unroll] + for hd in 0..p.head_dim { + let tile_to_write = registers.get_at_mut(q, hd, config); + let tile_read = reader.get_tile::((q, hd).runtime(), config); + + tile_to_write.update(&tile_read); + } + } + } + + fn read_mask( + reader: &MaskReader, + registers: &mut MaskPartition, + #[comptime] config: S, + ) { + let p = config.tiling_scheme().partition_size; + + #[unroll] + for q in 0..p.seq_q { + #[unroll] + for kv in 0..p.seq_kv { + let mask_tile = registers.get_at_mut(q, kv, config.tiling_scheme()); + + let (new_origin, tile) = reader.read::((q, kv), config); + mask_tile.update(new_origin, tile); + } + } + } +} diff --git a/crates/cubecl-attention/src/components/stage/mod.rs b/crates/cubecl-attention/src/components/stage/mod.rs index 38256d83a..6f700eaf0 100644 --- a/crates/cubecl-attention/src/components/stage/mod.rs +++ b/crates/cubecl-attention/src/components/stage/mod.rs @@ -1,5 +1,10 @@ -pub mod dummy; +pub mod plane; +pub mod unit; mod base; +mod kv_reuse_attention; +mod partitioner; +mod tile_partitions; pub use base::*; +pub use partitioner::*; diff --git a/crates/cubecl-attention/src/components/stage/partitioner.rs b/crates/cubecl-attention/src/components/stage/partitioner.rs new file mode 100644 index 000000000..5d6ea0b93 --- /dev/null +++ b/crates/cubecl-attention/src/components/stage/partitioner.rs @@ -0,0 +1,16 @@ +use cubecl_core as cubecl; +use cubecl_core::prelude::*; +use cubecl_std::tensor::layout::Coords1d; + +use crate::components::global::simple::AttentionWriter; +use crate::components::tile::Reducer; + +#[cube] +/// Defines how the stage is partitioned among compute primitives (e.g., units or planes). +/// Controls global writeback and compute indexing. 
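+///
+/// For example (an illustrative sketch, not code from this patch), a stage
+/// implementation can derive the output tile row owned by a compute primitive
+/// from the partitioner's index; here `P` stands for an `AttentionPartitioner`
+/// and `p` for the comptime partition size used elsewhere in this series:
+///
+/// ```ignore
+/// // The plane partitioner below returns UNIT_POS_Y as its seq_q index, so
+/// // plane i owns the i-th group of seq_q tiles when writing its results.
+/// let tile_row = q + P::seq_q_index() * p.seq_q;
+/// ```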
+pub trait AttentionPartitioner: Send + Sync + 'static { + type Reducer: Reducer; + type Writer: AttentionWriter; + + fn seq_q_index() -> Coords1d; +} diff --git a/crates/cubecl-attention/src/components/stage/plane/attention.rs b/crates/cubecl-attention/src/components/stage/plane/attention.rs new file mode 100644 index 000000000..66326bed7 --- /dev/null +++ b/crates/cubecl-attention/src/components/stage/plane/attention.rs @@ -0,0 +1,35 @@ +use cubecl_core as cubecl; +use cubecl_core::prelude::*; +use cubecl_std::tensor::layout::Coords1d; + +use crate::components::{ + fragment::FragmentAttention, + global::simple::PlaneAttentionWriter, + stage::{ + kv_reuse_attention::KVReuseStageAttention, partitioner::AttentionPartitioner, + plane::PlaneKVReuseStageConfig, + }, + tile::BroadcastReducer, +}; + +pub type PlaneKVReuseStageAttention = KVReuseStageAttention< + AP, + SK, + SV, + SO, + FA, + PlanePartitioner, + PlaneKVReuseStageConfig<>::Config>, +>; + +pub struct PlanePartitioner {} + +#[cube] +impl AttentionPartitioner for PlanePartitioner { + type Reducer = BroadcastReducer; + type Writer = PlaneAttentionWriter; + + fn seq_q_index() -> Coords1d { + UNIT_POS_Y + } +} diff --git a/crates/cubecl-attention/src/components/stage/plane/config.rs b/crates/cubecl-attention/src/components/stage/plane/config.rs new file mode 100644 index 000000000..c369ef2ac --- /dev/null +++ b/crates/cubecl-attention/src/components/stage/plane/config.rs @@ -0,0 +1,87 @@ +use crate::components::{ + AttentionSetupError, AttentionTilingScheme, + fragment::FragmentAttentionConfig, + stage::{AttentionStageMemoryConfig, StageAttentionConfig}, +}; + +#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] +pub struct PlaneKVReuseStageConfig { + fragment_config: FC, + score_stage_memory_config: AttentionStageMemoryConfig, + value_stage_memory_config: AttentionStageMemoryConfig, + tiling_scheme: AttentionTilingScheme, + reuse_key_value: bool, + num_planes: u32, +} + +impl StageAttentionConfig for PlaneKVReuseStageConfig { + type FragmentAttentionConfig = FC; + + fn plane_dim(&self) -> u32 { + self.fragment_config.plane_dim() + } + + fn num_planes(&self) -> u32 { + self.num_planes + } + + fn tile_config(&self) -> Self::FragmentAttentionConfig { + self.fragment_config + } + + fn score_stage_memory_config(&self) -> AttentionStageMemoryConfig { + self.score_stage_memory_config + } + + fn value_stage_memory_config(&self) -> AttentionStageMemoryConfig { + self.value_stage_memory_config + } + + fn tiling_scheme(&self) -> AttentionTilingScheme { + self.tiling_scheme + } + + fn reuse_key_value(&self) -> bool { + self.reuse_key_value + } + + fn num_rows_per_unit(&self) -> u32 { + self.fragment_config.num_rows_per_unit() + } +} + +impl PlaneKVReuseStageConfig { + pub fn new( + fragment_config: FC, + score_stage_memory_config: AttentionStageMemoryConfig, + value_stage_memory_config: AttentionStageMemoryConfig, + tiling_scheme: AttentionTilingScheme, + reuse_key_value: bool, + num_planes: u32, + ) -> Result { + Self { + fragment_config, + score_stage_memory_config, + value_stage_memory_config, + tiling_scheme, + reuse_key_value, + num_planes, + } + .validate() + } + + pub fn validate(self) -> Result { + if self.reuse_key_value + && (self.tiling_scheme.tile_size.head_dim != self.tiling_scheme.tile_size.val_dim + || self.tiling_scheme.partition_size.head_dim + != self.tiling_scheme.partition_size.val_dim) + { + return Err(AttentionSetupError::InvalidConfig(Box::new( + "When reusing key/value, head_dim must equal val_dim in both tile_size and 
partition_size." + .to_string(), + ))); + } + + Ok(self) + } +} diff --git a/crates/cubecl-attention/src/components/stage/dummy/mod.rs b/crates/cubecl-attention/src/components/stage/plane/mod.rs similarity index 66% rename from crates/cubecl-attention/src/components/stage/dummy/mod.rs rename to crates/cubecl-attention/src/components/stage/plane/mod.rs index ab488aabd..cce5602e3 100644 --- a/crates/cubecl-attention/src/components/stage/dummy/mod.rs +++ b/crates/cubecl-attention/src/components/stage/plane/mod.rs @@ -1,9 +1,7 @@ mod attention; mod config; mod setup; -mod tile_partitions; pub use attention::*; pub use config::*; pub use setup::*; -pub use tile_partitions::*; diff --git a/crates/cubecl-attention/src/components/stage/dummy/setup.rs b/crates/cubecl-attention/src/components/stage/plane/setup.rs similarity index 74% rename from crates/cubecl-attention/src/components/stage/dummy/setup.rs rename to crates/cubecl-attention/src/components/stage/plane/setup.rs index 52154a153..40991b2d2 100644 --- a/crates/cubecl-attention/src/components/stage/dummy/setup.rs +++ b/crates/cubecl-attention/src/components/stage/plane/setup.rs @@ -1,6 +1,13 @@ use std::marker::PhantomData; -use crate::components::attention_types::*; +use crate::components::{ + attention_types::*, + fragment::FragmentAttentionFamily, + stage::{ + AttentionStageMemoryConfig, + plane::{PlaneKVReuseStageAttention, config::PlaneKVReuseStageConfig}, + }, +}; use cubecl_core::{client::ComputeClient, prelude::ReadWrite}; use cubecl_matmul::components::{ GlobalPartitionSize, TilingScheme, stage::StageFamily, tile::io::Strided, @@ -8,56 +15,51 @@ use cubecl_matmul::components::{ use crate::components::{ AttentionLineSizes, AttentionPrecision, AttentionProblem, AttentionSelection, - AttentionSetupError, - stage::{ - StageAttentionFamily, - dummy::{AttentionStageMemoryConfig, DummyStageAttention, DummyStageConfig}, - }, - tile::{AttentionTilingLayout, TileAttentionFamily}, + AttentionSetupError, stage::StageAttentionFamily, tile::AttentionTilingLayout, }; -pub struct DummyStageAttentionFamily< - TA: TileAttentionFamily, +pub struct PlaneKVReuseStageAttentionFamily< + FA: FragmentAttentionFamily, SK: StageFamily, SV: StageFamily, SO: StageFamily, > { - _phantom: PhantomData<(TA, SK, SV, SO)>, + _phantom: PhantomData<(FA, SK, SV, SO)>, } impl< - TA: TileAttentionFamily, + FA: FragmentAttentionFamily, SK: StageFamily, SV: StageFamily, SO: StageFamily, -> StageAttentionFamily for DummyStageAttentionFamily +> StageAttentionFamily for PlaneKVReuseStageAttentionFamily { - type Attention = DummyStageAttention< + type Attention = PlaneKVReuseStageAttention< AP, SK::Stage, AttentionTilingLayout>, SV::Stage, AttentionTilingLayout>, SO::Stage, AttentionTilingLayout>, - TA::Attention, + FA::FragmentAttention, >; type KeyStage = SK; type ValueStage = SV; type OutStage = SO; - type Config = DummyStageConfig; + type Config = PlaneKVReuseStageConfig; fn setup( - client: &ComputeClient, + client: &ComputeClient, problem: &AttentionProblem, selection: &AttentionSelection, line_sizes: &AttentionLineSizes, ) -> Result { - let tile_config = TA::setup::(client, problem, selection, line_sizes)?; - let num_planes = selection.tiling_scheme.stage_size.seq_q - * TA::computation_resources()?.num_planes(selection.plane_dim)?; + * FA::computation_resources()?.num_planes(selection.plane_dim)?; + + let tile_config = FA::setup::(client, problem, selection, line_sizes, num_planes)?; - DummyStageConfig::new( + PlaneKVReuseStageConfig::new( tile_config, 
score_attention_stage_memory_config(selection), value_attention_stage_memory_config(selection), diff --git a/crates/cubecl-attention/src/components/stage/dummy/tile_partitions.rs b/crates/cubecl-attention/src/components/stage/tile_partitions.rs similarity index 56% rename from crates/cubecl-attention/src/components/stage/dummy/tile_partitions.rs rename to crates/cubecl-attention/src/components/stage/tile_partitions.rs index e90eb75c0..50aa6518c 100644 --- a/crates/cubecl-attention/src/components/stage/dummy/tile_partitions.rs +++ b/crates/cubecl-attention/src/components/stage/tile_partitions.rs @@ -4,19 +4,24 @@ use std::marker::PhantomData; use cubecl::prelude::*; use cubecl_core as cubecl; -use crate::components::AttentionIdent; -use crate::components::attention_types::*; -use crate::components::global::dummy::QueryReader; -use crate::components::tile::RunningState; +use crate::components::AttentionTilingScheme; +use crate::components::fragment::FragmentAttention; +use crate::components::tile::AccumulatorTile; +use crate::components::tile::KeyValueTile; +use crate::components::tile::MaskTile; +use crate::components::tile::QueryTile; +use crate::components::tile::SoftmaxTile; use crate::components::{AttentionPrecision, stage::StageAttentionConfig, tile::TileAttention}; +use cubecl_std::CubeOption; +use cubecl_std::tensor::layout::Coords2d; #[derive(CubeType)] -pub struct Accumulators< +pub struct QueryPartition< AP: AttentionPrecision, - TA: TileAttention, - S: StageAttentionConfig, + FA: FragmentAttention, + S: StageAttentionConfig, > { - sequence: Sequence, + sequence: Sequence>, #[cube(comptime)] _phantom: PhantomData, } @@ -24,88 +29,20 @@ pub struct Accumulators< #[cube] impl< AP: AttentionPrecision, - TA: TileAttention, - S: StageAttentionConfig, -> Accumulators + FA: FragmentAttention, + S: StageAttentionConfig, +> QueryPartition { - pub fn new(#[comptime] config: S) -> Accumulators { + pub fn new(#[comptime] config: S) -> QueryPartition { let p = config.tiling_scheme().partition_size; let mut sequence = Sequence::new(); #[unroll] - for _ in 0..comptime!(p.seq_q * p.val_dim) { - sequence.push(TA::init_accumulator(config.tile_config())); + for _ in 0..comptime!(p.seq_q * p.head_dim) { + sequence.push(TileAttention::::init_query(config.tile_config())); } - Accumulators:: { - sequence, - _phantom: PhantomData, - } - } - - pub fn get_at( - &self, - #[comptime] i: u32, - #[comptime] j: u32, - #[comptime] config: S, - ) -> &TA::AccumulatorTile { - let p = config.tiling_scheme().partition_size; - self.sequence.index(comptime!(i * p.val_dim + j)) - } - - pub fn get_at_mut( - &mut self, - #[comptime] i: u32, - #[comptime] j: u32, - #[comptime] config: S, - ) -> &mut TA::AccumulatorTile { - let p = config.tiling_scheme().partition_size; - self.sequence.index_mut(comptime!(i * p.val_dim + j)) - } -} - -#[derive(CubeType)] -pub struct Queries< - AP: AttentionPrecision, - TA: TileAttention, - S: StageAttentionConfig, -> { - sequence: Sequence, - #[cube(comptime)] - _phantom: PhantomData, -} - -#[cube] -impl< - AP: AttentionPrecision, - TA: TileAttention, - S: StageAttentionConfig, -> Queries -{ - pub fn new(query_loader: QueryReader, #[comptime] config: S) -> Queries { - let p = config.tiling_scheme().partition_size; - let mut sequence = Sequence::new(); - - let mut q = comptime!(0u32); - - #[unroll] - #[allow(clippy::explicit_counter_loop)] - for _ in 0..comptime!(p.seq_q) { - let mut hd = comptime!(0u32); - - #[unroll] - #[allow(clippy::explicit_counter_loop)] - for _ in 
0..comptime!(p.head_dim) { - let tile = query_loader.get_tile::((q, hd).runtime(), config); - sequence.push(TA::init_query(&tile, config.tile_config())); - - comptime![hd += 1]; - } - - comptime![q += 1]; - } - - Queries:: { + QueryPartition:: { sequence, _phantom: PhantomData, } @@ -116,7 +53,7 @@ impl< #[comptime] q: u32, #[comptime] hd: u32, #[comptime] config: S, - ) -> &TA::QueryTile { + ) -> &QueryTile { let p = config.tiling_scheme().partition_size; self.sequence.index(comptime!(q * p.head_dim + hd)) } @@ -126,7 +63,7 @@ impl< #[comptime] q: u32, #[comptime] hd: u32, #[comptime] config: S, - ) -> &mut TA::QueryTile { + ) -> &mut QueryTile { let p = config.tiling_scheme().partition_size; self.sequence.index_mut(comptime!(q * p.head_dim + hd)) } @@ -135,20 +72,20 @@ impl< #[derive(CubeType)] pub enum KeyValues< AP: AttentionPrecision, - TA: TileAttention, - S: StageAttentionConfig, + FA: FragmentAttention, + S: StageAttentionConfig, > { - Reuse(KeyValueSequence), - Separate(KeyValueSequence, KeyValueSequence), + Reuse(KeyValueSequence), + Separate(KeyValueSequence, KeyValueSequence), } #[derive(CubeType)] pub struct KeyValueSequence< AP: AttentionPrecision, - TA: TileAttention, - S: StageAttentionConfig, + FA: FragmentAttention, + S: StageAttentionConfig, > { - sequence: Sequence, + sequence: Sequence>, #[cube(comptime)] _phantom: PhantomData, } @@ -156,21 +93,21 @@ pub struct KeyValueSequence< #[cube] impl< AP: AttentionPrecision, - TA: TileAttention, - S: StageAttentionConfig, -> KeyValues + FA: FragmentAttention, + S: StageAttentionConfig, +> KeyValues { - pub fn new(#[comptime] config: S) -> KeyValues { + pub fn new(#[comptime] config: S) -> KeyValues { if config.reuse_key_value() { let p = config.tiling_scheme().partition_size; let mut sequence = Sequence::new(); #[unroll] for _ in 0..comptime!(p.seq_kv * max(p.head_dim, p.val_dim)) { - sequence.push(TA::init_key_value(config.tile_config())); + sequence.push(KeyValueTile::new_key_value(config.tile_config())); } - KeyValues::::new_Reuse(KeyValueSequence:: { + KeyValues::::new_Reuse(KeyValueSequence:: { sequence, _phantom: PhantomData, }) @@ -181,19 +118,19 @@ impl< #[unroll] for _ in 0..comptime!(p.head_dim * p.seq_kv) { - keys.push(TA::init_key(config.tile_config())); + keys.push(KeyValueTile::new_key(config.tile_config())); } #[unroll] for _ in 0..comptime!(p.seq_kv * p.val_dim) { - values.push(TA::init_value(config.tile_config())); + values.push(KeyValueTile::new_value(config.tile_config())); } - KeyValues::::new_Separate( - KeyValueSequence:: { + KeyValues::::new_Separate( + KeyValueSequence:: { sequence: keys, _phantom: PhantomData, }, - KeyValueSequence:: { + KeyValueSequence:: { sequence: values, _phantom: PhantomData, }, @@ -206,7 +143,7 @@ impl< #[comptime] hd: u32, #[comptime] kv: u32, #[comptime] config: S, - ) -> &TA::KeyValueTile { + ) -> &KeyValueTile { let index = hd * config.tiling_scheme().partition_size.seq_kv + kv; match self { KeyValues::Reuse(key_values) => key_values.sequence.index(index), @@ -219,7 +156,7 @@ impl< #[comptime] hd: u32, #[comptime] kv: u32, #[comptime] config: S, - ) -> &mut TA::KeyValueTile { + ) -> &mut KeyValueTile { let index = hd * config.tiling_scheme().partition_size.seq_kv + kv; match self { KeyValues::Reuse(key_values) => key_values.sequence.index_mut(index), @@ -232,7 +169,7 @@ impl< #[comptime] kv: u32, #[comptime] vd: u32, #[comptime] config: S, - ) -> &TA::KeyValueTile { + ) -> &KeyValueTile { let index = kv * config.tiling_scheme().partition_size.val_dim + vd; match self { 
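+            // Both variants resolve through the row-major index computed above:
+            // Reuse mode keeps key and value tiles in one shared sequence, while
+            // Separate mode indexes its dedicated value sequence.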
KeyValues::Reuse(key_values) => key_values.sequence.index(index), @@ -245,7 +182,7 @@ impl< #[comptime] kv: u32, #[comptime] vd: u32, #[comptime] config: S, - ) -> &mut TA::KeyValueTile { + ) -> &mut KeyValueTile { let index = kv * config.tiling_scheme().partition_size.val_dim + vd; match self { KeyValues::Reuse(key_values) => key_values.sequence.index_mut(index), @@ -257,10 +194,10 @@ impl< #[derive(CubeType)] pub struct SoftmaxPartition< AP: AttentionPrecision, - TA: TileAttention, - S: StageAttentionConfig, + FA: FragmentAttention, + S: StageAttentionConfig, > { - sequence: Sequence, + sequence: Sequence>, #[cube(comptime)] _phantom: PhantomData, } @@ -268,20 +205,20 @@ pub struct SoftmaxPartition< #[cube] impl< AP: AttentionPrecision, - TA: TileAttention, - S: StageAttentionConfig, -> SoftmaxPartition + FA: FragmentAttention, + S: StageAttentionConfig, +> SoftmaxPartition { - pub fn new(#[comptime] config: S) -> SoftmaxPartition { + pub fn new(#[comptime] config: S) -> SoftmaxPartition { let p = config.tiling_scheme().partition_size; let mut sequence = Sequence::new(); #[unroll] for _ in 0..comptime!(p.seq_q * p.seq_kv) { - sequence.push(TA::init_softmax(config.tile_config())); + sequence.push(SoftmaxTile::new(config.tile_config())); } - SoftmaxPartition:: { + SoftmaxPartition:: { sequence, _phantom: PhantomData, } @@ -292,7 +229,7 @@ impl< #[comptime] q: u32, #[comptime] kv: u32, #[comptime] config: S, - ) -> &TA::SoftmaxTile { + ) -> &SoftmaxTile { let index = q * config.tiling_scheme().partition_size.seq_kv + kv; self.sequence.index(index) } @@ -302,38 +239,130 @@ impl< #[comptime] q: u32, #[comptime] kv: u32, #[comptime] config: S, - ) -> &mut TA::SoftmaxTile { + ) -> &mut SoftmaxTile { let index = q * config.tiling_scheme().partition_size.seq_kv + kv; self.sequence.index_mut(index) } } #[derive(CubeType)] -pub struct StageState { - sequence: Sequence>>, +pub struct MaskPartition< + AP: AttentionPrecision, + FA: FragmentAttention, + S: StageAttentionConfig, +> { + sequence: Sequence>, + #[cube(comptime)] + _phantom: PhantomData, } #[cube] -impl StageState { - pub fn init(#[comptime] config: S) -> StageState { +impl< + AP: AttentionPrecision, + FA: FragmentAttention, + S: StageAttentionConfig, +> MaskPartition +{ + pub fn new( + out_of_bounds: CubeOption, + #[comptime] config: S, + ) -> MaskPartition { let p = config.tiling_scheme().partition_size; let mut sequence = Sequence::new(); + let mut q = comptime![0]; + #[unroll] - for _ in 0..comptime!(p.seq_q) { - sequence.push(RunningState::>::init( - config.num_rows_per_unit(AttentionIdent::Softmax), - )); + for _ in 0..p.seq_q { + let mut kv = comptime![0]; + + #[unroll] + for _ in 0..p.seq_kv { + sequence.push(MaskTile::new(out_of_bounds, (q, kv), config.tile_config())); + + comptime![kv += 1]; + } + + comptime![q += 1]; + } + + MaskPartition:: { + sequence, + _phantom: PhantomData, + } + } + + pub fn get_at( + &self, + #[comptime] q: u32, + #[comptime] kv: u32, + #[comptime] tiling_scheme: AttentionTilingScheme, + ) -> &MaskTile { + let p = tiling_scheme.partition_size; + self.sequence.index(comptime!(q * p.seq_kv + kv)) + } + + pub fn get_at_mut( + &mut self, + #[comptime] q: u32, + #[comptime] kv: u32, + #[comptime] tiling_scheme: AttentionTilingScheme, + ) -> &mut MaskTile { + let p = tiling_scheme.partition_size; + self.sequence.index_mut(comptime!(q * p.seq_kv + kv)) + } +} + +#[derive(CubeType)] +pub struct AccumulatorPartition< + AP: AttentionPrecision, + FA: FragmentAttention, + S: StageAttentionConfig, +> { + sequence: 
Sequence>, + #[cube(comptime)] + _phantom: PhantomData, +} + +#[cube] +impl< + AP: AttentionPrecision, + FA: FragmentAttention, + S: StageAttentionConfig, +> AccumulatorPartition +{ + pub fn new(#[comptime] config: S) -> AccumulatorPartition { + let p = config.tiling_scheme().partition_size; + let mut sequence = Sequence::new(); + + #[unroll] + for _ in 0..comptime!(p.seq_q * p.val_dim) { + sequence.push(AccumulatorTile::new(config.tile_config())); } - StageState:: { sequence } + AccumulatorPartition:: { + sequence, + _phantom: PhantomData, + } } - pub fn get_at(&self, #[comptime] q: u32) -> &RunningState> { - self.sequence.index(q) + pub fn get_at( + &self, + #[comptime] i: u32, + #[comptime] j: u32, + #[comptime] config: S, + ) -> &AccumulatorTile { + let p = config.tiling_scheme().partition_size; + self.sequence.index(comptime!(i * p.val_dim + j)) } - pub fn get_at_mut(&mut self, #[comptime] q: u32) -> &mut RunningState> { - self.sequence.index_mut(q) + pub fn get_at_mut( + &mut self, + #[comptime] i: u32, + #[comptime] j: u32, + #[comptime] config: S, + ) -> &mut AccumulatorTile { + let p = config.tiling_scheme().partition_size; + self.sequence.index_mut(comptime!(i * p.val_dim + j)) } } diff --git a/crates/cubecl-attention/src/components/stage/unit/attention.rs b/crates/cubecl-attention/src/components/stage/unit/attention.rs new file mode 100644 index 000000000..b51087bab --- /dev/null +++ b/crates/cubecl-attention/src/components/stage/unit/attention.rs @@ -0,0 +1,35 @@ +use cubecl_core as cubecl; +use cubecl_core::prelude::*; +use cubecl_std::tensor::layout::Coords1d; + +use crate::components::{ + fragment::FragmentAttention, + global::simple::UnitAttentionWriter, + stage::{ + kv_reuse_attention::KVReuseStageAttention, partitioner::AttentionPartitioner, + unit::UnitKVReuseStageConfig, + }, + tile::UnitReducer, +}; + +pub type UnitKVReuseStageAttention = KVReuseStageAttention< + AP, + SK, + SV, + SO, + FA, + UnitPartitioner, + UnitKVReuseStageConfig<>::Config>, +>; + +pub struct UnitPartitioner {} + +#[cube] +impl AttentionPartitioner for UnitPartitioner { + type Reducer = UnitReducer; + type Writer = UnitAttentionWriter; + + fn seq_q_index() -> Coords1d { + UNIT_POS + } +} diff --git a/crates/cubecl-attention/src/components/stage/dummy/config.rs b/crates/cubecl-attention/src/components/stage/unit/config.rs similarity index 55% rename from crates/cubecl-attention/src/components/stage/dummy/config.rs rename to crates/cubecl-attention/src/components/stage/unit/config.rs index c9c142a9c..a29b7930c 100644 --- a/crates/cubecl-attention/src/components/stage/dummy/config.rs +++ b/crates/cubecl-attention/src/components/stage/unit/config.rs @@ -1,13 +1,12 @@ -use cubecl_matmul::components::{MatrixLayout, StageIdent, TilingScheme, stage::StageMemoryConfig}; - use crate::components::{ - AttentionIdent, AttentionSetupError, AttentionTilingScheme, stage::StageAttentionConfig, - tile::dummy::AttentionMatmulConfig, + AttentionSetupError, AttentionTilingScheme, + fragment::FragmentAttentionConfig, + stage::{AttentionStageMemoryConfig, StageAttentionConfig}, }; #[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] -pub struct DummyStageConfig { - tile_config: FC, +pub struct UnitKVReuseStageConfig { + fragment_config: FC, score_stage_memory_config: AttentionStageMemoryConfig, value_stage_memory_config: AttentionStageMemoryConfig, tiling_scheme: AttentionTilingScheme, @@ -15,19 +14,19 @@ pub struct DummyStageConfig { num_planes: u32, } -impl StageAttentionConfig for DummyStageConfig { - type 
AttentionMatmulConfig = FC; +impl StageAttentionConfig for UnitKVReuseStageConfig { + type FragmentAttentionConfig = FC; fn plane_dim(&self) -> u32 { - 32 + self.fragment_config.plane_dim() } fn num_planes(&self) -> u32 { self.num_planes } - fn tile_config(&self) -> Self::AttentionMatmulConfig { - self.tile_config + fn tile_config(&self) -> Self::FragmentAttentionConfig { + self.fragment_config } fn score_stage_memory_config(&self) -> AttentionStageMemoryConfig { @@ -46,14 +45,14 @@ impl StageAttentionConfig for DummyStageConfig { self.reuse_key_value } - fn num_rows_per_unit(&self, ident: AttentionIdent) -> u32 { - self.tile_config.num_rows_per_unit(ident) + fn num_rows_per_unit(&self) -> u32 { + self.fragment_config.num_rows_per_unit() } } -impl DummyStageConfig { +impl UnitKVReuseStageConfig { pub fn new( - tile_config: FC, + fragment_config: FC, score_stage_memory_config: AttentionStageMemoryConfig, value_stage_memory_config: AttentionStageMemoryConfig, tiling_scheme: AttentionTilingScheme, @@ -61,7 +60,7 @@ impl DummyStageConfig { num_planes: u32, ) -> Result { Self { - tile_config, + fragment_config, score_stage_memory_config, value_stage_memory_config, tiling_scheme, @@ -86,23 +85,3 @@ impl DummyStageConfig { Ok(self) } } - -#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] -pub struct AttentionStageMemoryConfig { - pub matmul_tiling_scheme: TilingScheme, -} - -impl AttentionStageMemoryConfig { - pub fn into_matmul_config(&self, ident: StageIdent) -> StageMemoryConfig { - StageMemoryConfig { - num_main_flow_planes: 1, - elements_in_tile_row: self.matmul_tiling_scheme.elements_in_tile_row(ident), - elements_in_tile_col: self.matmul_tiling_scheme.elements_in_tile_col(ident), - tiles_in_stage_row: self.matmul_tiling_scheme.tiles_in_stage_row(ident), - tiles_in_stage_col: self.matmul_tiling_scheme.tiles_in_stage_col(ident), - stage_line_size: 1, - matrix_layout: MatrixLayout::RowMajor, - num_stages: 1, - } - } -} diff --git a/crates/cubecl-attention/src/components/stage/unit/mod.rs b/crates/cubecl-attention/src/components/stage/unit/mod.rs new file mode 100644 index 000000000..cce5602e3 --- /dev/null +++ b/crates/cubecl-attention/src/components/stage/unit/mod.rs @@ -0,0 +1,7 @@ +mod attention; +mod config; +mod setup; + +pub use attention::*; +pub use config::*; +pub use setup::*; diff --git a/crates/cubecl-attention/src/components/stage/unit/setup.rs b/crates/cubecl-attention/src/components/stage/unit/setup.rs new file mode 100644 index 000000000..c8961a18b --- /dev/null +++ b/crates/cubecl-attention/src/components/stage/unit/setup.rs @@ -0,0 +1,115 @@ +use std::marker::PhantomData; + +use crate::components::{ + attention_types::*, + fragment::FragmentAttentionFamily, + stage::{ + AttentionStageMemoryConfig, + unit::{UnitKVReuseStageAttention, config::UnitKVReuseStageConfig}, + }, +}; +use cubecl_core::{client::ComputeClient, prelude::ReadWrite}; +use cubecl_matmul::components::{ + ComputeResources, GlobalPartitionSize, TilingScheme, stage::StageFamily, tile::io::Strided, +}; + +use crate::components::{ + AttentionLineSizes, AttentionPrecision, AttentionProblem, AttentionSelection, + AttentionSetupError, stage::StageAttentionFamily, tile::AttentionTilingLayout, +}; + +pub struct UnitKVReuseStageAttentionFamily< + FA: FragmentAttentionFamily, + SK: StageFamily, + SV: StageFamily, + SO: StageFamily, +> { + _phantom: PhantomData<(FA, SK, SV, SO)>, +} + +impl< + FA: FragmentAttentionFamily, + SK: StageFamily, + SV: StageFamily, + SO: StageFamily, +> StageAttentionFamily for 
UnitKVReuseStageAttentionFamily +{ + type Attention = UnitKVReuseStageAttention< + AP, + SK::Stage, AttentionTilingLayout>, + SV::Stage, AttentionTilingLayout>, + SO::Stage, AttentionTilingLayout>, + FA::FragmentAttention, + >; + + type KeyStage = SK; + type ValueStage = SV; + type OutStage = SO; + + type Config = UnitKVReuseStageConfig; + + fn setup( + client: &ComputeClient, + problem: &AttentionProblem, + selection: &AttentionSelection, + line_sizes: &AttentionLineSizes, + ) -> Result { + let compute_resources = if let ComputeResources::Units(units) = FA::computation_resources()? + { + ComputeResources::Units(units * selection.tiling_scheme.stage_size.seq_q) + } else { + return Err(AttentionSetupError::InvalidConfig(Box::new( + "Error: Tried to use a unit stage attention with a plane fragment attention." + .to_string(), + ))); + }; + + let num_planes = compute_resources.num_planes(selection.plane_dim)?; + let tile_config = FA::setup::(client, problem, selection, line_sizes, num_planes)?; + + UnitKVReuseStageConfig::new( + tile_config, + score_attention_stage_memory_config(selection), + value_attention_stage_memory_config(selection), + selection.tiling_scheme, + selection.reuse_key_value, + num_planes, + ) + } +} + +fn score_attention_stage_memory_config( + selection: &AttentionSelection, +) -> AttentionStageMemoryConfig { + let att_tile_size = selection.tiling_scheme.tile_size; + let att_partition_size = selection.tiling_scheme.partition_size; + let att_stage_size = selection.tiling_scheme.stage_size; + + let matmul_tiling_scheme = TilingScheme { + tile_size: att_tile_size.to_score_matmul_tile_size(), + partition_size: att_partition_size.to_score_matmul_partition_size(), + stage_size: (att_stage_size.seq_q, 1, 1).into(), + global_partition_size: GlobalPartitionSize::new(1, 1, 1), + }; + AttentionStageMemoryConfig { + matmul_tiling_scheme, + } +} + +fn value_attention_stage_memory_config( + selection: &AttentionSelection, +) -> AttentionStageMemoryConfig { + let att_tile_size = selection.tiling_scheme.tile_size; + let att_partition_size = selection.tiling_scheme.partition_size; + let att_stage_size = selection.tiling_scheme.stage_size; + + let matmul_tiling_scheme = TilingScheme { + tile_size: att_tile_size.to_value_matmul_tile_size(), + partition_size: att_partition_size.to_value_matmul_partition_size(), + stage_size: (att_stage_size.seq_q, 1, 1).into(), + global_partition_size: GlobalPartitionSize::new(1, 1, 1), + }; + AttentionStageMemoryConfig { + matmul_tiling_scheme, + } +} diff --git a/crates/cubecl-attention/src/components/tile/base.rs b/crates/cubecl-attention/src/components/tile/base.rs index 987a1efc5..84f6a863b 100644 --- a/crates/cubecl-attention/src/components/tile/base.rs +++ b/crates/cubecl-attention/src/components/tile/base.rs @@ -1,115 +1,159 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; use cubecl_matmul::components::{ - ComputeResources, stage::{ContiguousTilingLayout, RowMajorTilingOrder}, tile::StridedTile, }; +use crate::components::tile::{ + AccumulatorTile, KeyValueTile, MaskTile, QueryTile, Reducer, SoftmaxTile, +}; use crate::components::{ - AttentionLineSizes, AttentionPrecision, AttentionProblem, AttentionSelection, - AttentionSetupError, AvailableLineSizes, + AttentionPrecision, attention_types::*, - tile::{RowWise, RunningState, dummy::AttentionMatmulConfig}, + fragment::FragmentAttentionConfig, + tile::{RowWise, RunningState}, }; -use crate::components::{InvalidConfigError, tile::AccumulatorTile}; -use crate::components::{TileMask, 
tile::SoftmaxTile}; +use std::marker::PhantomData; + +use crate::components::fragment::FragmentAttention; +use cubecl_std::CubeOption; +use cubecl_std::tensor::layout::Coords2d; pub type AttentionTilingLayout = ContiguousTilingLayout; -/// A family of [TileAttention] implementations that operate with any [precision](AttentionPrecision). -pub trait TileAttentionFamily: Send + Sync + 'static { - /// The specific [TileAttention] implementation associated with this family. - type Attention: TileAttention; +#[derive(CubeType)] +pub struct TileAttention> { + #[cube(comptime)] + _phantom: PhantomData<(AP, FA)>, +} - /// The configuration type associated with this Attention family. - type Config: AttentionMatmulConfig; +#[cube] +impl> TileAttention { + pub fn rescale(acc: &mut AccumulatorTile, prev_state: &RunningState>) { + acc.scale_div(prev_state.l()); + } - /// Constructs the configuration based on the Attention problem, selection, and line sizes. - /// - /// This function may return an error if the configuration cannot be supported on the current runtime. - fn setup( - client: &ComputeClient, - problem: &AttentionProblem, - selection: &AttentionSelection, - line_sizes: &AttentionLineSizes, - ) -> Result; + pub fn write_results( + tile: &mut StridedTile, ReadWrite>, + acc: &AccumulatorTile, + #[comptime] config: FA::Config, + ) { + FA::write_results(&acc.fragment, &mut tile.slice, config) + } - /// Filters out line sizes that are incompatible with this Attention family. - /// - /// By default, returns the input unchanged. - fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes { - available_line_sizes + pub fn init_accumulator(#[comptime] config: FA::Config) -> AccumulatorTile { + AccumulatorTile::new(config) } - fn computation_resources() -> Result; -} + pub fn init_query(#[comptime] config: FA::Config) -> QueryTile { + QueryTile::new(config) + } -#[cube] -pub trait TileAttention: 'static + Send + Sync { - /// The configuration type associated with this Attention. 
- type Config: AttentionMatmulConfig; - - type QueryTile: CubeType; - type KeyValueTile: CubeType; - type SoftmaxTile: SoftmaxTile; - type AccumulatorTile: AccumulatorTile>; - - fn rescale( - acc: &mut Self::AccumulatorTile, - prev_state: &RunningState>, - #[comptime] config: Self::Config, - ); - - fn write_results( - tile: &mut StridedTile, ReadWrite>, - acc: &Self::AccumulatorTile, - #[comptime] tile_config: Self::Config, - ); + pub fn init_key_value(#[comptime] config: FA::Config) -> KeyValueTile { + KeyValueTile::new_key_value(config) + } + + pub fn init_key(#[comptime] config: FA::Config) -> KeyValueTile { + KeyValueTile::new_key(config) + } - fn init_accumulator(#[comptime] config: Self::Config) -> Self::AccumulatorTile; + pub fn init_value(#[comptime] config: FA::Config) -> KeyValueTile { + KeyValueTile::new_value(config) + } - fn init_query(tile: &StridedTile>, #[comptime] config: Self::Config) -> Self::QueryTile; + pub fn init_mask( + out_of_bounds: CubeOption, + #[comptime] partition_pos: Coords2d, + #[comptime] config: FA::Config, + ) -> MaskTile { + MaskTile::new(out_of_bounds, partition_pos, config) + } - fn init_key_value(#[comptime] config: Self::Config) -> Self::KeyValueTile; - fn init_key(#[comptime] config: Self::Config) -> Self::KeyValueTile; - fn init_value(#[comptime] config: Self::Config) -> Self::KeyValueTile; + pub fn init_softmax(#[comptime] config: FA::Config) -> SoftmaxTile { + SoftmaxTile::new(config) + } - fn init_softmax(#[comptime] config: Self::Config) -> Self::SoftmaxTile; + pub fn init_state(#[comptime] config: FA::Config) -> RunningState> { + RunningState::>::init(config.num_rows_per_unit()) + } - fn fill_key( + pub fn fill_key( tile: &StridedTile, - rhs: &mut Self::KeyValueTile, - #[comptime] config: Self::Config, - ); + registers: &mut KeyValueTile, + #[comptime] config: FA::Config, + ) { + FA::fill_key_value(tile, registers.key_mut(), config); + } - fn fill_value( + pub fn fill_value( tile: &StridedTile, - rhs: &mut Self::KeyValueTile, - #[comptime] config: Self::Config, - ); - - fn zero_softmax(score: &mut Self::SoftmaxTile, #[comptime] config: Self::Config); - - fn accumulate_score( - query: &Self::QueryTile, - key_value: &Self::KeyValueTile, - softmax: &mut Self::SoftmaxTile, - #[comptime] config: Self::Config, - ); - - fn softmax( - softmax: &mut Self::SoftmaxTile, - mask: TileMask, + registers: &mut KeyValueTile, + #[comptime] config: FA::Config, + ) { + FA::fill_key_value(tile, registers.value_mut(), config); + } + + pub fn zero_softmax(score: &mut SoftmaxTile, #[comptime] config: FA::Config) { + FA::zero_softmax(&mut score.fragment, config); + } + + pub fn accumulate_score( + query: &QueryTile, + key_value: &KeyValueTile, + softmax: &mut SoftmaxTile, + #[comptime] config: FA::Config, + ) { + FA::score_matmul( + &query.fragment, + key_value.key(), + &mut softmax.fragment, + config, + ); + } + + pub fn softmax( + softmax: &mut SoftmaxTile, + mask: &MaskTile, state: &mut RunningState>, + max_placeholder: &mut RowWise>, + sum_placeholder: &mut RowWise>, #[comptime] dk: u32, - ) -> RowWise>; - - fn accumulate_value( - softmax: &Self::SoftmaxTile, - key_value: &Self::KeyValueTile, - accumulator: &mut Self::AccumulatorTile, - scale: &RowWise>, - #[comptime] config: Self::Config, - ); + #[comptime] config: FA::Config, + ) -> RowWise> { + SoftmaxTile::scale_and_mask( + softmax, + SM::::new(comptime!(1.0 / (dk as f32).sqrt())), + mask, + ); + + softmax.row_max::(max_placeholder, state.m(), config); + + softmax.to_prob::(state, max_placeholder, 
sum_placeholder, config) + } + + pub fn accumulate_value( + softmax: &SoftmaxTile, + key_value: &KeyValueTile, + accumulator: &mut AccumulatorTile, + scale: &RowWise>, + #[comptime] config: FA::Config, + ) { + accumulator.scale_mul(scale); + + FA::value_matmul( + &softmax.fragment, + key_value.value(), + &mut accumulator.fragment, + config, + ); + } + + pub fn init_max_placeholder(#[comptime] num_rows: u32) -> RowWise> { + RowWise::new_min_value(num_rows) + } + + pub fn init_sum_placeholder(#[comptime] num_rows: u32) -> RowWise> { + RowWise::new_zero(num_rows) + } } diff --git a/crates/cubecl-attention/src/components/tile/dummy/attention.rs b/crates/cubecl-attention/src/components/tile/dummy/attention.rs deleted file mode 100644 index e0bed2cca..000000000 --- a/crates/cubecl-attention/src/components/tile/dummy/attention.rs +++ /dev/null @@ -1,140 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; -use cubecl_matmul::components::tile::StridedTile; -use std::marker::PhantomData; - -use crate::components::TileMask; -use crate::components::attention_types::*; -use crate::components::tile::AccumulatorTile as _; -use crate::components::tile::AccumulatorTileExpand; -use crate::components::tile::ScaleMode; -use crate::components::tile::SoftmaxTileExpand; -use crate::components::tile::dummy::DummyAccumulator; -use crate::components::tile::dummy::{AttentionMatmul, DummySoftmax}; -use crate::components::tile::{RowWise, RunningState, SoftmaxTile, TileAttention}; -use crate::components::{ - AttentionPrecision, - tile::dummy::{KeyValueFragment, QueryFragment}, -}; - -pub struct DummyTileAttention> { - _phantom: PhantomData<(AP, AM)>, -} - -#[cube] -impl> TileAttention - for DummyTileAttention -{ - type Config = AM::Config; - - type QueryTile = QueryFragment; - type KeyValueTile = KeyValueFragment; - type SoftmaxTile = DummySoftmax; - type AccumulatorTile = DummyAccumulator; - - fn rescale( - acc: &mut Self::AccumulatorTile, - prev_state: &RunningState>, - #[comptime] _config: Self::Config, - ) { - acc.scale(&prev_state.l.cast::>(), ScaleMode::Divide); - } - - fn write_results( - tile: &mut StridedTile, ReadWrite>, - acc: &Self::AccumulatorTile, - #[comptime] tile_config: Self::Config, - ) { - AM::write_results(&acc.fragment, &mut tile.slice, tile_config) - } - - fn init_accumulator(#[comptime] config: Self::Config) -> Self::AccumulatorTile { - Self::AccumulatorTile::new(config) - } - - fn init_query(tile: &StridedTile>, #[comptime] config: Self::Config) -> Self::QueryTile { - Self::QueryTile::new(tile, config) - } - - fn init_key_value(#[comptime] config: Self::Config) -> Self::KeyValueTile { - Self::KeyValueTile::new_key_value(config) - } - - fn init_key(#[comptime] config: Self::Config) -> Self::KeyValueTile { - Self::KeyValueTile::new_key(config) - } - - fn init_value(#[comptime] config: Self::Config) -> Self::KeyValueTile { - Self::KeyValueTile::new_value(config) - } - - fn init_softmax(#[comptime] config: Self::Config) -> Self::SoftmaxTile { - Self::SoftmaxTile::new(config) - } - - fn fill_key( - tile: &StridedTile, - rhs: &mut Self::KeyValueTile, - #[comptime] config: Self::Config, - ) { - AM::fill_key_value(tile, rhs.key_mut(), config); - } - - fn fill_value( - tile: &StridedTile, - rhs: &mut Self::KeyValueTile, - #[comptime] config: Self::Config, - ) { - AM::fill_key_value(tile, rhs.value_mut(), config); - } - - fn zero_softmax(score: &mut Self::SoftmaxTile, #[comptime] config: Self::Config) { - AM::zero_softmax(&mut score.fragment, config); - } - - fn accumulate_score( - query: 
&Self::QueryTile, - key_value: &Self::KeyValueTile, - softmax: &mut Self::SoftmaxTile, - #[comptime] config: Self::Config, - ) { - AM::score_matmul( - &query.fragment, - key_value.key(), - &mut softmax.fragment, - config, - ); - } - - fn softmax( - softmax: &mut Self::SoftmaxTile, - mask: TileMask, - state: &mut RunningState>, - #[comptime] dk: u32, - ) -> RowWise> { - let inv_sqrt_dk = SM::::new(comptime!(1.0 / (dk as f32).sqrt())); - - softmax.scale_and_mask(inv_sqrt_dk, mask); - - let score_max = softmax.row_max(state.m.copy()); - - softmax.to_prob(state, &score_max) - } - - fn accumulate_value( - softmax: &Self::SoftmaxTile, - key_value: &Self::KeyValueTile, - accumulator: &mut Self::AccumulatorTile, - scale: &RowWise>, - #[comptime] config: Self::Config, - ) { - accumulator.scale(scale, ScaleMode::Multiply); - - AM::value_matmul( - &softmax.fragment, - key_value.value(), - &mut accumulator.fragment, - config, - ); - } -} diff --git a/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/accelerated/config.rs b/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/accelerated/config.rs deleted file mode 100644 index 95036a53b..000000000 --- a/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/accelerated/config.rs +++ /dev/null @@ -1,182 +0,0 @@ -use cubecl_matmul::components::{MatrixLayout, StageIdent, TileSize, tile::TileConfig}; -use std::fmt::Debug; -use std::hash::Hash; - -use crate::components::{ - AttentionIdent, AttentionPrecision, AttentionSetupError, AttentionTileSize, attention_types::*, - tile::dummy::AttentionMatmulConfig, -}; -use cubecl_core::frontend::CubePrimitive; - -#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] -pub struct AcceleratedAttentionMatmulConfig { - plane_dim: u32, - score_config: ScoreConfig, - value_config: ValueConfig, - attention_tile_size: AttentionTileSize, - num_planes: u32, - query_stage_line_size: u32, - key_value_stage_line_size: u32, - cast_query: bool, - check_bounds: bool, -} - -#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] -pub struct ScoreConfig { - plane_dim: u32, - tile_size: TileSize, - query_stage_line_size: u32, - key_value_stage_line_size: u32, -} - -impl TileConfig for ScoreConfig { - fn plane_dim(&self) -> u32 { - self.plane_dim - } - - fn matrix_layout(&self, _ident: StageIdent) -> MatrixLayout { - MatrixLayout::RowMajor - } - - fn stage_line_size(&self, ident: StageIdent) -> u32 { - match ident { - StageIdent::Lhs => self.query_stage_line_size, - StageIdent::Rhs => self.key_value_stage_line_size, - StageIdent::Acc => todo!(), - StageIdent::Out => todo!(), - } - } - - fn global_line_size(&self, _ident: StageIdent) -> u32 { - panic!() - } - - fn tile_size(&self) -> &TileSize { - &self.tile_size - } -} - -#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] -pub struct ValueConfig { - plane_dim: u32, - tile_size: TileSize, - key_value_stage_line_size: u32, -} - -impl TileConfig for ValueConfig { - fn plane_dim(&self) -> u32 { - self.plane_dim - } - - fn matrix_layout(&self, _ident: StageIdent) -> MatrixLayout { - MatrixLayout::RowMajor - } - - fn stage_line_size(&self, ident: StageIdent) -> u32 { - match ident { - StageIdent::Lhs => todo!(), - StageIdent::Rhs => self.key_value_stage_line_size, - StageIdent::Acc => todo!(), - StageIdent::Out => todo!(), - } - } - - fn global_line_size(&self, _ident: StageIdent) -> u32 { - panic!() - } - - fn tile_size(&self) -> &TileSize { - &self.tile_size - } -} - -impl AttentionMatmulConfig for AcceleratedAttentionMatmulConfig { - fn plane_dim(&self) 
-> u32 { - self.plane_dim - } - - fn num_planes(&self) -> u32 { - self.num_planes - } - - fn stage_line_size(&self, ident: AttentionIdent) -> u32 { - match ident { - AttentionIdent::Query => self.query_stage_line_size, - AttentionIdent::Key => self.key_value_stage_line_size, - AttentionIdent::Softmax => unreachable!("Not a materialized stage"), - AttentionIdent::Value => self.key_value_stage_line_size, - AttentionIdent::Mask => todo!(), - AttentionIdent::Out => 1, - } - } - - fn attention_tile_size(&self) -> AttentionTileSize { - self.attention_tile_size - } - - fn cast_query(&self) -> bool { - self.cast_query - } - - fn num_units_per_row(&self, ident: AttentionIdent) -> u32 { - // TODO depends on layout, this assumes they are all in the same row - self.plane_dim / self.attention_tile_size.num_rows(ident) - } - - fn num_cols_per_unit(&self, ident: AttentionIdent) -> u32 { - // TODO depends on layout, this assumes they are all in the same row - self.attention_tile_size - .num_cols(ident) - .div_ceil(self.num_units_per_row(ident)) - } - - fn check_bounds(&self) -> bool { - self.check_bounds - } - - fn num_rows_per_unit(&self, ident: AttentionIdent) -> u32 { - // TODO depends on layout, this assumes they are all in the same row - self.attention_tile_size.num_rows(ident) / self.plane_dim - } -} - -impl AcceleratedAttentionMatmulConfig { - pub fn new( - plane_dim: u32, - attention_tile_size: AttentionTileSize, - num_planes: u32, - query_stage_line_size: u32, - key_value_stage_line_size: u32, - check_bounds: bool, - ) -> Result { - let score_config = ScoreConfig { - plane_dim, - tile_size: attention_tile_size.to_score_matmul_tile_size(), - query_stage_line_size, - key_value_stage_line_size, - }; - let value_config = ValueConfig { - plane_dim, - tile_size: attention_tile_size.to_value_matmul_tile_size(), - key_value_stage_line_size, - }; - - Self { - plane_dim, - score_config, - value_config, - attention_tile_size, - num_planes, - query_stage_line_size, - key_value_stage_line_size, - cast_query: QG::::as_type_native_unchecked() - == QT::::as_type_native_unchecked(), - check_bounds, - } - .validate() - } - - pub fn validate(self) -> Result { - Ok(self) - } -} diff --git a/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/dummy_register/config.rs b/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/dummy_register/config.rs deleted file mode 100644 index 49b8b8a1e..000000000 --- a/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/dummy_register/config.rs +++ /dev/null @@ -1,171 +0,0 @@ -use cubecl_matmul::components::{MatrixLayout, StageIdent, TileSize, tile::TileConfig}; -use std::fmt::Debug; -use std::hash::Hash; - -use crate::components::attention_types::*; -use crate::components::{ - AttentionIdent, AttentionPrecision, AttentionSetupError, AttentionTileSize, - tile::dummy::AttentionMatmulConfig, -}; -use cubecl_core::frontend::CubePrimitive; - -#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] -pub struct DummyRegisterAttentionMatmulConfig { - plane_dim: u32, - attention_tile_size: AttentionTileSize, - num_planes: u32, - query_stage_line_size: u32, - key_value_stage_line_size: u32, - cast_query: bool, - check_bounds: bool, -} - -#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] -pub struct ScoreConfig { - plane_dim: u32, - tile_size: TileSize, - query_stage_line_size: u32, - key_value_stage_line_size: u32, -} - -impl TileConfig for ScoreConfig { - fn plane_dim(&self) -> u32 { - self.plane_dim - } - - fn matrix_layout(&self, _ident: StageIdent) 
-> MatrixLayout { - MatrixLayout::RowMajor - } - - fn stage_line_size(&self, ident: StageIdent) -> u32 { - match ident { - StageIdent::Lhs => self.query_stage_line_size, - StageIdent::Rhs => self.key_value_stage_line_size, - StageIdent::Acc => todo!(), - StageIdent::Out => todo!(), - } - } - - fn global_line_size(&self, _ident: StageIdent) -> u32 { - panic!() - } - - fn tile_size(&self) -> &TileSize { - &self.tile_size - } -} - -#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] -pub struct ValueConfig { - plane_dim: u32, - tile_size: TileSize, - key_value_stage_line_size: u32, -} - -impl TileConfig for ValueConfig { - fn plane_dim(&self) -> u32 { - self.plane_dim - } - - fn matrix_layout(&self, _ident: StageIdent) -> MatrixLayout { - MatrixLayout::RowMajor - } - - fn stage_line_size(&self, ident: StageIdent) -> u32 { - match ident { - StageIdent::Lhs => todo!(), - StageIdent::Rhs => self.key_value_stage_line_size, - StageIdent::Acc => todo!(), - StageIdent::Out => todo!(), - } - } - - fn global_line_size(&self, _ident: StageIdent) -> u32 { - panic!() - } - - fn tile_size(&self) -> &TileSize { - &self.tile_size - } -} - -impl AttentionMatmulConfig for DummyRegisterAttentionMatmulConfig { - fn plane_dim(&self) -> u32 { - self.plane_dim - } - - fn num_planes(&self) -> u32 { - self.num_planes - } - - fn stage_line_size(&self, ident: AttentionIdent) -> u32 { - match ident { - AttentionIdent::Query => self.query_stage_line_size, - AttentionIdent::Key => self.key_value_stage_line_size, - AttentionIdent::Softmax => unreachable!("Not a materialized stage"), - AttentionIdent::Value => self.key_value_stage_line_size, - AttentionIdent::Mask => todo!(), - AttentionIdent::Out => 1, - } - } - - fn attention_tile_size(&self) -> AttentionTileSize { - self.attention_tile_size - } - - fn cast_query(&self) -> bool { - self.cast_query - } - - fn num_units_per_row(&self, ident: AttentionIdent) -> u32 { - self.plane_dim / self.attention_tile_size.num_rows(ident) - } - - fn num_cols_per_unit(&self, ident: AttentionIdent) -> u32 { - self.attention_tile_size - .num_cols(ident) - .div_ceil(self.num_units_per_row(ident)) - } - - fn num_rows_per_unit(&self, ident: AttentionIdent) -> u32 { - self.attention_tile_size - .num_rows(ident) - .div_ceil(self.plane_dim) - } - - fn check_bounds(&self) -> bool { - self.check_bounds - } -} - -impl DummyRegisterAttentionMatmulConfig { - pub fn new( - plane_dim: u32, - attention_tile_size: AttentionTileSize, - num_planes: u32, - query_stage_line_size: u32, - key_value_stage_line_size: u32, - check_bounds: bool, - ) -> Result { - Self { - plane_dim, - attention_tile_size, - num_planes, - query_stage_line_size, - key_value_stage_line_size, - cast_query: QG::::as_type_native_unchecked() - == QT::::as_type_native_unchecked(), - check_bounds, - } - .validate() - } - - pub fn validate(self) -> Result { - if self.attention_tile_size.head_dim < self.attention_tile_size.val_dim { - return Err(AttentionSetupError::InvalidConfig(Box::new( - "Can't have tile head_dim < tile val dim (not sure why)", - ))); - } - Ok(self) - } -} diff --git a/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/dummy_register/matmul.rs b/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/dummy_register/matmul.rs deleted file mode 100644 index 97d543d37..000000000 --- a/crates/cubecl-attention/src/components/tile/dummy/attention_matmul/dummy_register/matmul.rs +++ /dev/null @@ -1,215 +0,0 @@ -use std::cmp::max; - -use cubecl_core as cubecl; -use cubecl_core::prelude::*; -use 
cubecl_matmul::components::tile::StridedTile; - -use crate::components::AttentionPrecision; -use crate::components::attention_types::*; -use crate::components::tile::dummy::dummy_register::DummyRegisterAttentionMatmulConfig; -use crate::components::tile::dummy::{AttentionMatmul, AttentionMatmulConfig as _}; - -/// Dummy AttentionMatmul implementation using simple arrays -/// Only lane 0 performs computations, other lanes idle -pub struct DummyRegisterAttentionMatmul; - -#[cube] -impl AttentionMatmul for DummyRegisterAttentionMatmul { - type Config = DummyRegisterAttentionMatmulConfig; - - type Query = Array>; - type KeyValue = Array>; - type Softmax = Array>; - type Accumulator = Array>; - - fn score_matmul( - lhs: &Self::Query, - rhs: &Self::KeyValue, - out: &mut Self::Softmax, - #[comptime] config: Self::Config, - ) { - if UNIT_POS_X == 0 { - let (m, n, k) = comptime! {let (m, n, k): (u32, u32, u32) = config.attention_tile_size().to_score_matmul_tile_size().into(); (m, n, k)}; - - for i in 0..m { - for j in 0..n { - let mut sum = SM::::from_int(0); - for ki in 0..k { - let lhs_val = lhs[i * k + ki]; - let rhs_val = rhs[ki * n + j]; - sum += SM::::cast_from(lhs_val) * SM::::cast_from(rhs_val); - } - out[i * n + j] += sum; - } - } - } - - sync_cube(); - } - - fn value_matmul( - lhs: &Self::Softmax, - rhs: &Self::KeyValue, - out: &mut Self::Accumulator, - #[comptime] config: Self::Config, - ) { - if UNIT_POS_X == 0 { - let (m, n, k) = comptime! {let (m, n, k): (u32, u32, u32) = config.attention_tile_size().to_value_matmul_tile_size().into(); (m, n, k)}; - - for i in 0..m { - for j in 0..n { - let mut sum = ACC::::from_int(0); - for ki in 0..k { - let lhs_val = lhs[i * k + ki]; - let rhs_val = rhs[ki * n + j]; - sum += ACC::::cast_from(lhs_val) * ACC::::cast_from(rhs_val); - } - out[i * n + j] += sum; - } - } - } - - sync_cube(); - } - - fn allocate_fill_query( - tile: &StridedTile, - #[comptime] config: Self::Config, - ) -> Self::Query { - let seq_q = config.attention_tile_size().seq_q; - let head_dim = config.attention_tile_size().head_dim; - - let mut query = Array::>::new(seq_q * head_dim); - - if UNIT_POS_X == 0 { - // Only lane 0 fills the data - for q in 0..seq_q { - for hd in 0..head_dim { - query[q * head_dim + hd] = QT::::cast_from(tile.get_line(q, hd)); - } - } - } - - sync_cube(); - query - } - - fn allocate_key_value(#[comptime] config: Self::Config) -> Self::KeyValue { - Array::>::new(comptime!(max( - config.attention_tile_size().key_size(), - config.attention_tile_size().value_size(), - ))) - } - - fn allocate_key(#[comptime] config: Self::Config) -> Self::KeyValue { - Array::>::new(config.attention_tile_size().key_size()) - } - - fn allocate_value(#[comptime] config: Self::Config) -> Self::KeyValue { - Array::>::new(config.attention_tile_size().value_size()) - } - - fn fill_key_value( - tile: &StridedTile, - rhs: &mut Self::KeyValue, - #[comptime] config: Self::Config, - ) { - if UNIT_POS_X == 0 { - let size = config.attention_tile_size().key_size(); - for i in 0..size { - rhs[i] = KVT::::cast_from(tile.as_unlined().0[i]); - } - } - - sync_cube(); - } - - fn allocate_softmax(#[comptime] config: Self::Config) -> Self::Softmax { - Array::>::new(config.attention_tile_size().softmax_size()) - } - - fn zero_softmax(softmax: &mut Self::Softmax, #[comptime] config: Self::Config) { - if UNIT_POS_X == 0 { - let len = config.attention_tile_size().softmax_size(); - for i in 0..len { - softmax[i] = SM::::from_int(0); - } - } - sync_cube(); - } - - fn allocate_accumulator(#[comptime] 
config: Self::Config) -> Self::Accumulator { - Array::>::new(config.attention_tile_size().accumulator_size()) - } - - fn zero_accumulator(acc: &mut Self::Accumulator, #[comptime] config: Self::Config) { - if UNIT_POS_X == 0 { - let len = config.attention_tile_size().accumulator_size(); - for i in 0..len { - acc[i] = ACC::::from_int(0); - } - } - - sync_cube(); - } - - fn write_results( - out: &Self::Accumulator, - slice: &mut SliceMut>, - #[comptime] config: Self::Config, - ) { - if UNIT_POS_X == 0 { - let size = config.attention_tile_size().accumulator_size(); - for i in 0..size { - slice[i] = Line::cast_from(out[i]); - } - } - - sync_cube(); - } - - fn tmp_fill_accumulator( - tile: &StridedTile>, - acc: &mut Self::Accumulator, - #[comptime] config: Self::Config, - ) { - if UNIT_POS_X == 0 { - let size = config.attention_tile_size().accumulator_size(); - for i in 0..size { - acc[i] = tile.as_unlined().0[i]; - } - } - - sync_cube(); - } - - fn tmp_fill_prob( - tile: &StridedTile>, - prob: &mut Self::Softmax, - #[comptime] config: Self::Config, - ) { - if UNIT_POS_X == 0 { - let len = config.attention_tile_size().softmax_size(); - for i in 0..len { - prob[i] = tile.as_unlined().0[i]; - } - } - - sync_cube(); - } - - fn tmp_write_softmax( - softmax: &Self::Softmax, - slice: &mut SliceMut>>, - #[comptime] config: Self::Config, - ) { - if UNIT_POS_X == 0 { - let size = config.attention_tile_size().softmax_size(); - for i in 0..size { - slice[i] = Line::cast_from(softmax[i]); - } - } - - sync_cube(); - } -} diff --git a/crates/cubecl-attention/src/components/tile/dummy/fragment/accumulator.rs b/crates/cubecl-attention/src/components/tile/dummy/fragment/accumulator.rs deleted file mode 100644 index 57d6edb6f..000000000 --- a/crates/cubecl-attention/src/components/tile/dummy/fragment/accumulator.rs +++ /dev/null @@ -1,112 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; -use cubecl_matmul::components::MatrixLayout; -use cubecl_matmul::components::tile::StridedTile; - -use crate::components::AttentionIdent; -use crate::components::AttentionPrecision; -use crate::components::attention_types::*; -use crate::components::tile::AccumulatorTile; -use crate::components::tile::AccumulatorTileExpand; -use crate::components::tile::RowWise; -use crate::components::tile::ScaleMode; -use crate::components::tile::dummy::AttentionMatmul; -use crate::components::tile::dummy::AttentionMatmulConfig; - -#[derive(CubeType)] -pub struct DummyAccumulator> { - tmp_smem: SharedMemory>, - pub fragment: AM::Accumulator, - - row: u32, - col_start: u32, - - tmp_smem_start: u32, - tmp_smem_end: u32, - - #[cube(comptime)] - num_rows: u32, - #[cube(comptime)] - num_cols: u32, - #[cube(comptime)] - num_cols_per_unit: u32, - #[cube(comptime)] - config: AM::Config, -} - -#[cube] -impl> DummyAccumulator { - pub fn new(#[comptime] config: AM::Config) -> DummyAccumulator { - let mut fragment = AM::allocate_accumulator(config); - AM::zero_accumulator(&mut fragment, config); - - let num_rows = config.attention_tile_size().num_rows(AttentionIdent::Out); - let num_cols = config.attention_tile_size().num_cols(AttentionIdent::Out); - let num_units_per_row = config.num_units_per_row(AttentionIdent::Out); - let num_cols_per_unit = config.num_cols_per_unit(AttentionIdent::Out); - - let row = UNIT_POS_X / num_units_per_row; - let col_start = (UNIT_POS_X % num_units_per_row) * num_cols_per_unit; - - let acc_size = config.attention_tile_size().accumulator_size(); - let tmp_smem_start = UNIT_POS_Y * acc_size; - let tmp_smem_end = 
tmp_smem_start + acc_size; - - DummyAccumulator:: { - tmp_smem: SharedMemory::new(acc_size * config.num_planes()), - fragment, - row, - col_start, - tmp_smem_start, - tmp_smem_end, - num_rows, - num_cols, - num_cols_per_unit, - config, - } - } -} - -#[cube] -impl> AccumulatorTile> - for DummyAccumulator -{ - fn scale(&mut self, scale: &RowWise>, #[comptime] scale_op: ScaleMode) { - let mut slice = self - .tmp_smem - .slice_mut(self.tmp_smem_start, self.tmp_smem_end) - .try_cast_unchecked(); - - AM::write_results::>(&self.fragment, &mut slice, self.config); - - if self.row < self.num_rows { - #[unroll] - for i in 0..self.num_cols_per_unit { - let col = self.col_start + i; - - if col < self.num_cols { - match scale_op { - ScaleMode::Multiply => { - slice[self.row * self.num_cols + col] = slice - [self.row * self.num_cols + col] - * Line::cast_from(scale.index(0u32)) - } - ScaleMode::Divide => { - slice[self.row * self.num_cols + col] = slice - [self.row * self.num_cols + col] - / Line::cast_from(scale.index(0u32)) - } - } - } - } - } - - let tile = StridedTile::>::new_strided( - slice.to_slice(), - self.num_cols.runtime(), - MatrixLayout::RowMajor, - ); - - AM::tmp_fill_accumulator(&tile, &mut self.fragment, self.config); - } -} diff --git a/crates/cubecl-attention/src/components/tile/dummy/fragment/key_value.rs b/crates/cubecl-attention/src/components/tile/dummy/fragment/key_value.rs deleted file mode 100644 index 318ea8fe6..000000000 --- a/crates/cubecl-attention/src/components/tile/dummy/fragment/key_value.rs +++ /dev/null @@ -1,100 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; - -use crate::components::AttentionPrecision; -use crate::components::tile::dummy::AttentionMatmul; - -#[derive(CubeType)] -pub enum KeyValueFragment> { - Reuse(ReuseKV), - Key(Key), - Value(Value), -} - -#[cube] -impl> KeyValueFragment { - pub fn new_key_value(#[comptime] config: AM::Config) -> Self { - Self::new_Reuse(ReuseKV::new(config)) - } - - pub fn new_key(#[comptime] config: AM::Config) -> Self { - Self::new_Key(Key::new(config)) - } - - pub fn new_value(#[comptime] config: AM::Config) -> Self { - Self::new_Value(Value::new(config)) - } - - pub fn key(&self) -> &AM::KeyValue { - match self { - KeyValueFragment::Reuse(reuse_kv) => &reuse_kv.fragment, - KeyValueFragment::Key(key) => &key.fragment, - KeyValueFragment::Value(_) => panic!("Tried to access key on value-only fragment"), - } - } - - pub fn key_mut(&mut self) -> &mut AM::KeyValue { - match self { - KeyValueFragment::Reuse(reuse_kv) => &mut reuse_kv.fragment, - KeyValueFragment::Key(key) => &mut key.fragment, - KeyValueFragment::Value(_) => panic!("Tried to access key on value-only fragment"), - } - } - - pub fn value(&self) -> &AM::KeyValue { - match self { - KeyValueFragment::Reuse(reuse_kv) => &reuse_kv.fragment, - KeyValueFragment::Key(_) => panic!("Tried to access value on key-only fragment"), - KeyValueFragment::Value(value) => &value.fragment, - } - } - - pub fn value_mut(&mut self) -> &mut AM::KeyValue { - match self { - KeyValueFragment::Reuse(reuse_kv) => &mut reuse_kv.fragment, - KeyValueFragment::Key(_) => panic!("Tried to access value on key-only fragment"), - KeyValueFragment::Value(value) => &mut value.fragment, - } - } -} - -#[derive(CubeType)] -pub struct ReuseKV> { - pub fragment: AM::KeyValue, -} - -#[cube] -impl> ReuseKV { - pub fn new(#[comptime] config: AM::Config) -> Self { - let fragment = AM::allocate_key_value(config); - ReuseKV:: { fragment } - } -} - -#[derive(CubeType)] -pub struct Key> { - pub 
fragment: AM::KeyValue, -} - -#[cube] -impl> Key { - pub fn new(#[comptime] config: AM::Config) -> Self { - Key:: { - fragment: AM::allocate_key(config), - } - } -} - -#[derive(CubeType)] -pub struct Value> { - pub fragment: AM::KeyValue, -} - -#[cube] -impl> Value { - pub fn new(#[comptime] config: AM::Config) -> Self { - Value:: { - fragment: AM::allocate_value(config), - } - } -} diff --git a/crates/cubecl-attention/src/components/tile/dummy/fragment/mod.rs b/crates/cubecl-attention/src/components/tile/dummy/fragment/mod.rs deleted file mode 100644 index 3a0d94be5..000000000 --- a/crates/cubecl-attention/src/components/tile/dummy/fragment/mod.rs +++ /dev/null @@ -1,9 +0,0 @@ -mod accumulator; -mod key_value; -mod query; -mod softmax; - -pub use accumulator::*; -pub use key_value::*; -pub use query::*; -pub use softmax::*; diff --git a/crates/cubecl-attention/src/components/tile/dummy/fragment/query.rs b/crates/cubecl-attention/src/components/tile/dummy/fragment/query.rs deleted file mode 100644 index 3c888a242..000000000 --- a/crates/cubecl-attention/src/components/tile/dummy/fragment/query.rs +++ /dev/null @@ -1,23 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; -use cubecl_matmul::components::tile::StridedTile; - -use crate::components::AttentionPrecision; -use crate::components::tile::dummy::AttentionMatmul; - -#[derive(CubeType)] -pub struct QueryFragment> { - pub fragment: AM::Query, -} - -#[cube] -impl> QueryFragment { - pub fn new( - tile: &StridedTile, - #[comptime] config: AM::Config, - ) -> QueryFragment { - QueryFragment:: { - fragment: AM::allocate_fill_query(tile, config), - } - } -} diff --git a/crates/cubecl-attention/src/components/tile/dummy/fragment/softmax.rs b/crates/cubecl-attention/src/components/tile/dummy/fragment/softmax.rs deleted file mode 100644 index 08d813d93..000000000 --- a/crates/cubecl-attention/src/components/tile/dummy/fragment/softmax.rs +++ /dev/null @@ -1,184 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; -use cubecl_matmul::components::MatrixLayout; -use cubecl_matmul::components::tile::StridedTile; - -use crate::components::AttentionPrecision; -use crate::components::attention_types::*; -use crate::components::tile::RowWise; -use crate::components::{ - AttentionIdent, TileMask, - tile::{ - RunningState, SoftmaxTile, SoftmaxTileExpand, - dummy::{AttentionMatmul, AttentionMatmulConfig}, - }, -}; - -#[derive(CubeType)] -pub struct DummySoftmax> { - tmp_smem: SharedMemory>, - pub fragment: AM::Softmax, - - row: u32, - col_start: u32, - - tmp_smem_start: u32, - tmp_smem_end: u32, - - #[cube(comptime)] - num_rows: u32, - #[cube(comptime)] - num_cols: u32, - #[cube(comptime)] - num_cols_per_unit: u32, - #[cube(comptime)] - config: AM::Config, -} - -#[cube] -impl> DummySoftmax { - pub fn new(#[comptime] config: AM::Config) -> Self { - let mut fragment = AM::allocate_softmax(config); - AM::zero_softmax(&mut fragment, config); - - let num_rows = config - .attention_tile_size() - .num_rows(AttentionIdent::Softmax); - let num_cols = config - .attention_tile_size() - .num_cols(AttentionIdent::Softmax); - let num_units_per_row = config.num_units_per_row(AttentionIdent::Softmax); - let num_cols_per_unit = config.num_cols_per_unit(AttentionIdent::Softmax); - - let row = UNIT_POS_X / num_units_per_row; - let col_start = (UNIT_POS_X % num_units_per_row) * num_cols_per_unit; - - let score_size = config.attention_tile_size().softmax_size(); - let tmp_smem_start = UNIT_POS_Y * score_size; - let tmp_smem_end = tmp_smem_start + 
score_size; - - DummySoftmax:: { - tmp_smem: SharedMemory::>::new(score_size * config.num_planes()), - fragment, - row, - col_start, - tmp_smem_start, - tmp_smem_end, - num_rows, - num_cols, - num_cols_per_unit, - config, - } - } -} - -#[cube] -impl> SoftmaxTile for DummySoftmax { - fn init_state() -> RunningState> { - RunningState::init(1u32) - } - - fn zero(&mut self) { - AM::zero_softmax(&mut self.fragment, self.config); - } - - fn scale_and_mask(&mut self, scale: SM, mask: TileMask) { - let mut slice = self - .tmp_smem - .slice_mut(self.tmp_smem_start, self.tmp_smem_end) - .try_cast_unchecked(); - - AM::tmp_write_softmax(&self.fragment, &mut slice, self.config); - - if self.row < self.num_rows { - #[unroll] - for i in 0..self.num_cols_per_unit { - let col = self.col_start + i; - - if col < self.num_cols { - let index = self.row * self.num_cols + col; - - slice[index] = - slice[index] * Line::cast_from(scale) + mask.apply::>(self.row, col); - } - } - } - - sync_cube(); - - let tile = StridedTile::>::new_strided( - slice.to_slice().try_cast_unchecked(), - self.num_cols.runtime(), - MatrixLayout::RowMajor, - ); - AM::tmp_fill_prob(&tile, &mut self.fragment, self.config); - sync_cube(); - } - - fn row_max(&self, base: RowWise>) -> RowWise> { - let slice = self.tmp_smem.slice(self.tmp_smem_start, self.tmp_smem_end); - - let row_offset = self.row * self.num_cols; - let mut row_max = base.index(0u32); - - for i in 0..self.num_cols { - let ts = slice[row_offset + i]; - if ts > row_max { - row_max = ts; - } - } - - RowWise::>::single(row_max) - } - - fn to_prob( - &mut self, - state: &mut RunningState>, - new_m: &RowWise>, - ) -> RowWise> { - let new_m_val = new_m.index(0u32); - - let mut slice = self - .tmp_smem - .slice_mut(self.tmp_smem_start, self.tmp_smem_end) - .try_cast_unchecked(); - - if self.row < self.num_rows { - #[unroll] - for i in 0..self.num_cols_per_unit { - let col = self.col_start + i; - - if col < self.num_cols { - let index = self.row * self.num_cols + col; - slice[index] = Exp::exp(slice[index] - Line::cast_from(new_m_val)); - } - } - } - - sync_cube(); - - let tile = StridedTile::>::new_strided( - slice.to_slice(), - self.num_cols.runtime(), - MatrixLayout::RowMajor, - ); - AM::tmp_fill_prob(&tile, &mut self.fragment, self.config); - - sync_cube(); - let slice = self.tmp_smem.slice(self.tmp_smem_start, self.tmp_smem_end); - - let row_offset = self.row * self.num_cols; - - let mut row_sum = SM::::from_int(0); - for i in 0..self.num_cols { - row_sum += slice[row_offset + i]; - } - - let exp_m_diff = Exp::exp(state.m.index(0u32) - new_m_val); - let new_l = exp_m_diff * state.l.index(0u32) + row_sum; - - state.update(new_m.copy(), RowWise::single(new_l)); - - RowWise::>::single(ACC::::cast_from(exp_m_diff)) - } -} diff --git a/crates/cubecl-attention/src/components/tile/dummy/mod.rs b/crates/cubecl-attention/src/components/tile/dummy/mod.rs deleted file mode 100644 index 2e8f3f997..000000000 --- a/crates/cubecl-attention/src/components/tile/dummy/mod.rs +++ /dev/null @@ -1,9 +0,0 @@ -mod attention; -mod attention_matmul; -mod fragment; -mod setup; - -pub use attention::*; -pub use attention_matmul::*; -pub use fragment::*; -pub use setup::DummyTileAttentionFamily; diff --git a/crates/cubecl-attention/src/components/tile/dummy/setup.rs b/crates/cubecl-attention/src/components/tile/dummy/setup.rs deleted file mode 100644 index 4e50b1197..000000000 --- a/crates/cubecl-attention/src/components/tile/dummy/setup.rs +++ /dev/null @@ -1,36 +0,0 @@ -use std::marker::PhantomData; - 
-use cubecl_core::client::ComputeClient;
-use cubecl_matmul::components::ComputeResources;
-
-use crate::components::{
-    AttentionLineSizes, AttentionPrecision, AttentionProblem, AttentionSelection,
-    AttentionSetupError, InvalidConfigError,
-    tile::{
-        TileAttentionFamily,
-        dummy::{AttentionMatmulFamily, DummyTileAttention},
-    },
-};
-
-pub struct DummyTileAttentionFamily {
-    _phantom: PhantomData,
-}
-
-impl TileAttentionFamily for DummyTileAttentionFamily {
-    type Attention = DummyTileAttention>;
-
-    type Config = FM::Config;
-
-    fn setup(
-        client: &ComputeClient,
-        problem: &AttentionProblem,
-        selection: &AttentionSelection,
-        line_sizes: &AttentionLineSizes,
-    ) -> Result {
-        FM::setup::(client, problem, selection, line_sizes)
-    }
-
-    fn computation_resources() -> Result {
-        Ok(ComputeResources::Planes(1))
-    }
-}
diff --git a/crates/cubecl-attention/src/components/tile/mod.rs b/crates/cubecl-attention/src/components/tile/mod.rs
index 21506cd61..4cd9d22d2 100644
--- a/crates/cubecl-attention/src/components/tile/mod.rs
+++ b/crates/cubecl-attention/src/components/tile/mod.rs
@@ -1,9 +1,7 @@
-pub mod dummy;
-
 mod base;
-mod rowwise;
+mod row;
 mod tiles;
 
 pub use base::*;
-pub use rowwise::*;
+pub use row::*;
 pub use tiles::*;
diff --git a/crates/cubecl-attention/src/components/tile/row/mod.rs b/crates/cubecl-attention/src/components/tile/row/mod.rs
new file mode 100644
index 000000000..ed4173afc
--- /dev/null
+++ b/crates/cubecl-attention/src/components/tile/row/mod.rs
@@ -0,0 +1,7 @@
+mod reduce;
+mod rowwise;
+mod state;
+
+pub use reduce::*;
+pub use rowwise::*;
+pub use state::*;
diff --git a/crates/cubecl-attention/src/components/tile/row/reduce/base.rs b/crates/cubecl-attention/src/components/tile/row/reduce/base.rs
new file mode 100644
index 000000000..70124edad
--- /dev/null
+++ b/crates/cubecl-attention/src/components/tile/row/reduce/base.rs
@@ -0,0 +1,61 @@
+use cubecl_core as cubecl;
+use cubecl_core::prelude::*;
+
+use crate::components::fragment::FragmentAttentionConfig;
+use crate::components::fragment::FragmentOps;
+use crate::components::tile::RowMax;
+use crate::components::tile::RowSum;
+use crate::components::tile::RowWise;
+
+#[cube]
+/// Computes the sum of each row of a fragment, using the Reducer's strategy
+pub fn row_sum, R: Reducer, TC: FragmentAttentionConfig>(
+    vals: &mut RowWise,
+    data: &F,
+    #[comptime] config: TC,
+) {
+    vals.fill(E::from_int(0));
+    R::reduce::(vals, data, config)
+}
+
+#[cube]
+/// Computes the max of each row of a fragment, using the Reducer's strategy.
+/// The running max is seeded from `base`.
+pub fn row_max, R: Reducer, TC: FragmentAttentionConfig>(
+    vals: &mut RowWise,
+    base: &RowWise,
+    data: &F,
+    #[comptime] config: TC,
+) {
+    vals.copy_from(base);
+    R::reduce::(vals, data, config)
+}
+
+#[cube]
+/// Strategy for reducing across units participating in the same row
+pub trait Reducer: CubeType {
+    /// Reduction algorithm, applied in place in `vals`
+    fn reduce, RO: ReduceOp, TC: FragmentAttentionConfig>(
+        vals: &mut RowWise,
+        data: &F,
+        #[comptime] config: TC,
+    );
+}
+
+#[cube]
+/// A reduction operation
+pub trait ReduceOp {
+    /// Applies the reduction to the elements of the same row held by the unit
+    fn reduce_local>(data: &F) -> RowWise;
+
+    /// Applies the reduction to the elements of the same row held by the unit
+    /// and to the accumulator, storing the result in the accumulator
+    fn reduce_local_accumulate>(data: &F, acc: &mut RowWise);
+
+    /// The basic operation on two single values
+    fn reduce_step_scalar(a: E, b: E) -> E;
+
+    /// Accumulates `elem` into `acc`.
into acc. + /// If mask is activated, the element gets masked prior to being accumulated + fn reduce_step_rowwise(acc: &mut RowWise, elem: &RowWise, mask: bool); +} diff --git a/crates/cubecl-attention/src/components/tile/row/reduce/broadcast_reducer.rs b/crates/cubecl-attention/src/components/tile/row/reduce/broadcast_reducer.rs new file mode 100644 index 000000000..8cdf9a942 --- /dev/null +++ b/crates/cubecl-attention/src/components/tile/row/reduce/broadcast_reducer.rs @@ -0,0 +1,91 @@ +use cubecl_core as cubecl; +use cubecl_core::prelude::*; + +use crate::components::fragment::FragmentAttentionConfig; +use crate::components::fragment::FragmentLayout; +use crate::components::fragment::{FragmentLayoutExpand, FragmentOps, FragmentOpsExpand}; +use crate::components::tile::ReduceOp; +use crate::components::tile::Reducer; +use crate::components::tile::{RowVal, RowWise}; + +#[derive(CubeType)] +/// Applies reduction on rows, masking planes that do not participate in the row +/// +/// TODO: uses shared memory to plane_broadcast, should be replaced with +/// a plane primitive +pub struct BroadcastReducer {} + +#[cube] +impl Reducer for BroadcastReducer { + fn reduce, RO: ReduceOp, FC: FragmentAttentionConfig>( + vals: &mut RowWise, + data: &F, + #[comptime] config: FC, + ) { + let num_units_per_row = data.layout().num_units_per_row(); + let num_shares_within_plane = comptime!((num_units_per_row as f32).log2().ceil() as u32); + + let unit_pos = UNIT_POS_X; + let unit_pos_in_row = unit_pos % num_units_per_row; + + let mut fpb = FakePlaneBroadcast::::new(config.plane_dim(), config.num_planes()); + + RO::reduce_local_accumulate::(data, vals); + + for i in 0..num_shares_within_plane { + let offset = num_units_per_row >> (i + 1); + let source_unit = unit_pos + offset; + + let value_from_source = fpb.plane_broadcast(vals, source_unit); + + // Mask if outside the row + let mask = unit_pos_in_row + offset >= num_units_per_row; + RO::reduce_step_rowwise(vals, &value_from_source, mask); + } + + // Broadcast back to subgroup + let result = &fpb.plane_broadcast(vals, unit_pos - unit_pos_in_row); + vals.copy_from(result); + } +} + +#[derive(CubeType)] +struct FakePlaneBroadcast { + slice: SliceMut, +} + +#[cube] +impl FakePlaneBroadcast { + pub fn new(#[comptime] plane_dim: u32, #[comptime] num_planes: u32) -> Self { + let mut smem = SharedMemory::::new(plane_dim * num_planes); + let start = UNIT_POS_Y * plane_dim; + let end = start + plane_dim; + FakePlaneBroadcast:: { + slice: smem.slice_mut(start, end), + } + } + + pub fn plane_broadcast(&mut self, val: &RowWise, source_unit: u32) -> RowWise { + let mut result = Sequence::new(); + + let mut row = comptime![0]; + + #[unroll] + for _ in 0..val.num_rows { + self.slice[UNIT_POS_X] = val.index(row); + sync_cube(); + + result.push(RowVal:: { + val: self.slice[source_unit], + }); + sync_cube(); + + comptime![row += 1]; + } + + RowWise:: { + num_rows: val.num_rows, + vals: result, + } + } +} diff --git a/crates/cubecl-attention/src/components/tile/row/reduce/mod.rs b/crates/cubecl-attention/src/components/tile/row/reduce/mod.rs new file mode 100644 index 000000000..eed5c238d --- /dev/null +++ b/crates/cubecl-attention/src/components/tile/row/reduce/mod.rs @@ -0,0 +1,11 @@ +mod base; +mod broadcast_reducer; +mod naive_reducer; +mod reduce_op; +mod unit_reducer; + +pub use base::*; +pub use broadcast_reducer::*; +pub use naive_reducer::*; +pub use reduce_op::*; +pub use unit_reducer::*; diff --git 
a/crates/cubecl-attention/src/components/tile/row/reduce/naive_reducer.rs b/crates/cubecl-attention/src/components/tile/row/reduce/naive_reducer.rs new file mode 100644 index 000000000..a398e5627 --- /dev/null +++ b/crates/cubecl-attention/src/components/tile/row/reduce/naive_reducer.rs @@ -0,0 +1,60 @@ +use cubecl_core as cubecl; +use cubecl_core::prelude::*; + +use crate::components::fragment::FragmentAttentionConfig; +use crate::components::fragment::FragmentLayout; +use crate::components::fragment::{FragmentLayoutExpand, FragmentOps, FragmentOpsExpand}; +use crate::components::tile::ReduceOp; +use crate::components::tile::Reducer; +use crate::components::tile::RowWise; + +#[derive(CubeType)] +/// Naive row reducer using shared memory +pub struct NaiveReducer {} + +#[cube] +impl Reducer for NaiveReducer { + fn reduce, RO: ReduceOp, FC: FragmentAttentionConfig>( + vals: &mut RowWise, + data: &F, + #[comptime] config: FC, + ) { + let num_vals_in_plane = config.num_rows_per_unit() * config.plane_dim(); + let mut smem = SharedMemory::::new(num_vals_in_plane * config.num_planes()); + + let local_vals = RO::reduce_local::(data); + + let plane_offset = UNIT_POS_Y * num_vals_in_plane; + let unit_offset = UNIT_POS_X; + + #[unroll] + for r in 0..config.num_rows_per_unit() { + let row_offset = r * config.plane_dim(); + let offset = plane_offset + row_offset + unit_offset; + + smem[offset] = local_vals.index(r); + } + + sync_cube(); + + let num_units_per_row = data.layout().num_units_per_row(); + + #[unroll] + for r in 0..config.num_rows_per_unit() { + let mut val = vals.index(r); + + let row_offset = r * config.plane_dim(); + + for c in 0..num_units_per_row { + let unit_offset = (UNIT_POS_X / num_units_per_row) * num_units_per_row; + let offset = plane_offset + row_offset + unit_offset; + + val = RO::reduce_step_scalar(val, smem[offset + c]); + } + + vals.replace_at(r, val); + } + + sync_cube(); + } +} diff --git a/crates/cubecl-attention/src/components/tile/row/reduce/reduce_op.rs b/crates/cubecl-attention/src/components/tile/row/reduce/reduce_op.rs new file mode 100644 index 000000000..bc9ca8e43 --- /dev/null +++ b/crates/cubecl-attention/src/components/tile/row/reduce/reduce_op.rs @@ -0,0 +1,58 @@ +use cubecl_core as cubecl; +use cubecl_core::prelude::*; + +use crate::components::fragment::{FragmentOps, FragmentOpsExpand}; +use crate::components::tile::ReduceOp; +use crate::components::tile::RowWise; + +#[derive(CubeType)] +/// Max reduction operation +pub struct RowMax {} + +#[derive(CubeType)] +/// Sum reduction operation +pub struct RowSum {} + +#[cube] +impl ReduceOp for RowMax { + fn reduce_local>(data: &F) -> RowWise { + data.rowwise_max() + } + + fn reduce_local_accumulate>(data: &F, acc: &mut RowWise) { + acc.max_inplace(&Self::reduce_local::(data)) + } + + fn reduce_step_rowwise(acc: &mut RowWise, elem: &RowWise, mask: bool) { + let mut masked = RowWise::new_filled(elem.num_rows, E::cast_from(mask) * E::min_value()); + masked.add_inplace(elem); + + acc.max_inplace(&masked) + } + + fn reduce_step_scalar(a: E, b: E) -> E { + Max::max(a, b) + } +} + +#[cube] +impl ReduceOp for RowSum { + fn reduce_local>(data: &F) -> RowWise { + data.rowwise_sum() + } + + fn reduce_local_accumulate>(data: &F, acc: &mut RowWise) { + acc.add_inplace(&Self::reduce_local::(data)) + } + + fn reduce_step_rowwise(acc: &mut RowWise, elem: &RowWise, mask: bool) { + let mut masked = RowWise::new_filled(elem.num_rows, E::cast_from(!mask)); + masked.mul_inplace(elem); + + acc.add_inplace(&masked) + } + + fn 
reduce_step_scalar(a: E, b: E) -> E {
+        a + b
+    }
+}
diff --git a/crates/cubecl-attention/src/components/tile/row/reduce/unit_reducer.rs b/crates/cubecl-attention/src/components/tile/row/reduce/unit_reducer.rs
new file mode 100644
index 000000000..bee8eaf5a
--- /dev/null
+++ b/crates/cubecl-attention/src/components/tile/row/reduce/unit_reducer.rs
@@ -0,0 +1,23 @@
+use cubecl_core as cubecl;
+use cubecl_core::prelude::*;
+
+use crate::components::fragment::FragmentAttentionConfig;
+use crate::components::fragment::FragmentOps;
+use crate::components::tile::ReduceOp;
+use crate::components::tile::Reducer;
+use crate::components::tile::RowWise;
+
+#[derive(CubeType)]
+/// Trivial reducer for one unit
+pub struct UnitReducer {}
+
+#[cube]
+impl Reducer for UnitReducer {
+    fn reduce<E: Numeric, F: FragmentOps<E>, RO: ReduceOp<E>, FC: FragmentAttentionConfig>(
+        vals: &mut RowWise<E>,
+        data: &F,
+        #[comptime] _config: FC,
+    ) {
+        RO::reduce_local_accumulate::<F>(data, vals);
+    }
+}
diff --git a/crates/cubecl-attention/src/components/tile/row/rowwise.rs b/crates/cubecl-attention/src/components/tile/row/rowwise.rs
new file mode 100644
index 000000000..e3f70139f
--- /dev/null
+++ b/crates/cubecl-attention/src/components/tile/row/rowwise.rs
@@ -0,0 +1,202 @@
+use cubecl_core as cubecl;
+use cubecl_core::prelude::*;
+
+#[derive(CubeType)]
+/// Contains one value per row of the fragment to which the unit contributes
+///
+/// Example: for an 8x8 tile shared by a plane of 32 units,
+/// every unit holds 2 values in the tile.
+///
+/// In the following layout, values are held contiguously, and num_rows=1 because
+/// both occurrences of the same plane id fall in the same row
+/// 0, 0, 1, 1, 2, 2, 3, 3,
+/// 4, 4, 5, 5, 6, 6, 7, 7,
+/// 8, 8, 9, 9, 10, 10, 11, 11,
+/// 12, 12, 13, 13, 14, 14, 15, 15,
+/// 16, 16, 17, 17, 18, 18, 19, 19,
+/// 20, 20, 21, 21, 22, 22, 23, 23,
+/// 24, 24, 25, 25, 26, 26, 27, 27,
+/// 28, 28, 29, 29, 30, 30, 31, 31,
+///
+/// In the following layout, values are held disjointly, and num_rows=2 because
+/// the two occurrences of the same plane id are not in the same row
+/// 0, 1, 2, 3, 4, 5, 6, 7,
+/// 8, 9, 10, 11, 12, 13, 14, 15,
+/// 16, 17, 18, 19, 20, 21, 22, 23,
+/// 24, 25, 26, 27, 28, 29, 30, 31,
+/// 0, 1, 2, 3, 4, 5, 6, 7,
+/// 8, 9, 10, 11, 12, 13, 14, 15,
+/// 16, 17, 18, 19, 20, 21, 22, 23,
+/// 24, 25, 26, 27, 28, 29, 30, 31,
+pub struct RowWise<E: Numeric> {
+    #[cube(comptime)]
+    pub num_rows: u32,
+    pub vals: Sequence<RowVal<E>>,
+}
+
+#[derive(CubeType)]
+/// Wrapper over a value to enable mutating it
+pub struct RowVal<E: Numeric> {
+    pub val: E,
+}
+
+#[cube]
+impl<E: Numeric> RowWise<E> {
+    /// Create a RowWise with the provided value at every row
+    pub fn new_filled(#[comptime] num_rows: u32, val: E) -> RowWise<E> {
+        let mut vals = Sequence::new();
+        #[unroll]
+        for _ in 0..num_rows {
+            vals.push(RowVal::<E> { val });
+        }
+        RowWise::<E> { num_rows, vals }
+    }
+
+    /// Fill the existing RowWise with the provided value at every row
+    pub fn fill(&mut self, val: E) {
+        #[unroll]
+        for i in 0..self.num_rows {
+            let row_val = self.vals.index_mut(i);
+            row_val.val = val;
+        }
+    }
+
+    /// Create a RowWise with -infinity at every row
+    pub fn new_min_value(#[comptime] num_rows: u32) -> RowWise<E> {
+        Self::new_filled(num_rows, E::min_value())
+    }
+
+    /// Create a RowWise with zero at every row
+    pub fn new_zero(#[comptime] num_rows: u32) -> RowWise<E> {
+        Self::new_filled(num_rows, E::from_int(0))
+    }
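+
+    // Example (sketch): seed a two-row running max at -infinity, then reset it:
+    //     let mut acc = RowWise::<f32>::new_min_value(2u32);
+    //     acc.fill(f32::from_int(0));
+    // Since `num_rows` is comptime, the per-row loops in these methods fully unroll.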
+
+    /// Fill the current RowWise with the value of other at each row
+    pub fn copy_from(&mut self, other: &RowWise<E>) {
+        #[unroll]
+        for i in 0..self.num_rows {
+            let row_val = self.vals.index_mut(i);
+            row_val.val = other.index(i);
+        }
+    }
+
+    /// Return the value at row i
+    pub fn index(&self, i: u32) -> E {
+        self.vals.index(i).val
+    }
+
+    /// For each row, adds the current and other values, and outputs a new RowWise
+    pub fn add(&self, other: &RowWise<E>) -> RowWise<E> {
+        let mut vals = Sequence::new();
+
+        #[unroll]
+        for i in 0..self.num_rows {
+            let val = self.index(i) + other.index(i);
+            vals.push(RowVal::<E> { val });
+        }
+
+        RowWise::<E> {
+            num_rows: self.num_rows,
+            vals,
+        }
+    }
+
+    /// For each row, adds the other value to the current RowWise
+    pub fn add_inplace(&mut self, other: &RowWise<E>) {
+        #[unroll]
+        for i in 0..self.num_rows {
+            let row_val = self.vals.index_mut(i);
+            row_val.val += other.index(i);
+        }
+    }
+
+    /// For each row, multiplies the current and other values, and outputs a new RowWise
+    pub fn mul(&self, other: &RowWise<E>) -> RowWise<E> {
+        let mut vals = Sequence::new();
+
+        #[unroll]
+        for i in 0..self.num_rows {
+            let val = self.index(i) * other.index(i);
+            vals.push(RowVal::<E> { val });
+        }
+
+        RowWise::<E> {
+            num_rows: self.num_rows,
+            vals,
+        }
+    }
+
+    /// For each row, multiplies the current RowWise by the other value, in place
+    pub fn mul_inplace(&mut self, other: &RowWise<E>) {
+        #[unroll]
+        for i in 0..self.num_rows {
+            let row_val = self.vals.index_mut(i);
+            row_val.val *= other.index(i);
+        }
+    }
+
+    /// For each row, takes the max of the current and other values, in place
+    pub fn max_inplace(&mut self, other: &RowWise<E>) {
+        #[unroll]
+        for i in 0..self.num_rows {
+            let row_val = self.vals.index_mut(i);
+            row_val.val = Max::max(row_val.val, other.index(i));
+        }
+    }
+
+    /// Replaces the value at row i
+    pub fn replace_at(&mut self, #[comptime] i: u32, new_val: E) {
+        let row_val = self.vals.index_mut(i);
+        row_val.val = new_val;
+    }
+
+    /// Returns a copy of row_wise, cast into E2
+    pub fn cast_from<E2: Numeric>(row_wise: &RowWise<E>) -> RowWise<E2> {
+        let mut vals = Sequence::new();
+
+        #[unroll]
+        for i in 0..row_wise.num_rows {
+            let val = E2::cast_from(row_wise.index(i));
+            vals.push(RowVal::<E2> { val });
+        }
+
+        RowWise::<E2> {
+            num_rows: row_wise.num_rows,
+            vals,
+        }
+    }
+}
+
+#[cube]
+impl<E: Float> RowWise<E> {
+    /// Computes e^(self - other) at every row, and outputs a new RowWise
+    pub fn exp_diff(&self, other: &RowWise<E>) -> RowWise<E> {
+        let mut vals = Sequence::new();
+        let mut i = comptime![0u32];
+
+        #[unroll]
+        for _ in 0..self.num_rows {
+            let val = Exp::exp(self.index(i) - other.index(i));
+            vals.push(RowVal::<E> { val });
+
+            comptime![i += 1];
+        }
+
+        RowWise::<E> {
+            num_rows: self.num_rows,
+            vals,
+        }
+    }
+
+    /// Replaces the value v at each row with 1/v
+    pub fn recip_inplace(&mut self) {
+        let mut i = comptime![0u32];
+        #[unroll]
+        for _ in 0..self.num_rows {
+            let row_val = self.vals.index_mut(i);
+            row_val.val = Recip::recip(row_val.val);
+
+            comptime![i += 1];
+        }
+    }
+}
diff --git a/crates/cubecl-attention/src/components/tile/row/state.rs b/crates/cubecl-attention/src/components/tile/row/state.rs
new file mode 100644
index 000000000..12168321b
--- /dev/null
+++ b/crates/cubecl-attention/src/components/tile/row/state.rs
@@ -0,0 +1,38 @@
+use cubecl_core as cubecl;
+use cubecl_core::prelude::*;
+
+use crate::components::tile::RowWise;
+
+#[derive(CubeType)]
+/// Flash Attention's running state, per row
+pub struct RunningState<E: Float> {
+    m: RowWise<E>,
+    l: RowWise<E>,
+}
+
+#[cube]
+impl<E: Float> RunningState<E> {
+    /// Init the state with neutral values
+    pub fn init(#[comptime] num_rows: u32) -> RunningState<E> {
+        RunningState::<E> {
+            m: RowWise::new_min_value(num_rows),
+            l:
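/* m is the running row max (identity: -infinity); l, initialized below, is the running normalizer (identity: zero) */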
RowWise::new_zero(num_rows), + } + } + + /// Update the state for next iteration + pub fn update(&mut self, new_m: &RowWise, new_l: &RowWise) { + RowWise::copy_from(&mut self.m, new_m); + RowWise::copy_from(&mut self.l, new_l); + } + + /// Get the running m + pub fn m(&self) -> &RowWise { + &self.m + } + + /// Get the running l + pub fn l(&self) -> &RowWise { + &self.l + } +} diff --git a/crates/cubecl-attention/src/components/tile/rowwise.rs b/crates/cubecl-attention/src/components/tile/rowwise.rs deleted file mode 100644 index 2c3b582e7..000000000 --- a/crates/cubecl-attention/src/components/tile/rowwise.rs +++ /dev/null @@ -1,116 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; - -#[derive(CubeType)] -pub struct RowWise { - #[cube(comptime)] - num_rows: u32, - vals: Sequence>, -} - -#[derive(CubeType, Copy, Clone)] -pub struct RowVal { - val: E, -} - -#[cube] -impl RowVal { - pub fn new(val: E) -> RowVal { - RowVal:: { val } - } - - pub fn cast(&self) -> RowVal { - RowVal:: { - val: E2::cast_from(self.val), - } - } -} - -#[cube] -impl RowWise { - pub fn new(#[comptime] num_rows: u32, vals: Sequence>) -> RowWise { - RowWise:: { num_rows, vals } - } - - pub fn single(val: E) -> RowWise { - let mut vals = Sequence::new(); - vals.push(RowVal:: { val }); - RowWise::::new(1u32, vals) - } - - pub fn index(&self, #[comptime] i: u32) -> E { - self.vals.index(i).val - } - - pub fn copy(&self) -> RowWise { - let mut vals = Sequence::new(); - #[unroll] - for i in 0..self.num_rows { - vals.push(*self.vals.index(i)); - } - RowWise::::new(self.num_rows, vals) - } - - pub fn copy_into(&self, other: &mut RowWise) { - #[unroll] - for i in 0..self.num_rows { - let place = other.vals.index_mut(i); - let value = self.vals.index(i); - place.val = value.val; - } - } - - pub fn cast(&self) -> RowWise { - let mut vals = Sequence::new(); - #[unroll] - for i in 0..self.num_rows { - vals.push(self.vals.index(i).cast::()); - } - RowWise::::new(self.num_rows, vals) - } -} - -#[derive(CubeType)] -pub struct RunningState { - pub m: RowWise, - pub l: RowWise, -} - -#[cube] -impl RunningState { - pub fn init(#[comptime] num_rows: u32) -> RunningState { - let mut m = Sequence::new(); - let mut l = Sequence::new(); - #[unroll] - for _ in 0..num_rows { - m.push(RowVal::new(E::from_int(-99999999999))); - l.push(RowVal::new(E::from_int(0))); - } - - RunningState:: { - m: RowWise::::new(num_rows, m), - l: RowWise::::new(num_rows, l), - } - } - - pub fn update(&mut self, new_m: RowWise, new_l: RowWise) { - new_m.copy_into(&mut self.m); - new_l.copy_into(&mut self.l); - } -} - -#[derive(CubeType)] -pub struct RowStats { - pub score_max: RowWise, - pub prob_sum: RowWise, -} - -#[cube] -impl RowStats { - pub fn new(score_max: RowWise, prob_sum: RowWise) -> RowStats { - RowStats:: { - score_max, - prob_sum, - } - } -} diff --git a/crates/cubecl-attention/src/components/tile/tiles/accumulator.rs b/crates/cubecl-attention/src/components/tile/tiles/accumulator.rs index f886f8bf6..1fe22cef4 100644 --- a/crates/cubecl-attention/src/components/tile/tiles/accumulator.rs +++ b/crates/cubecl-attention/src/components/tile/tiles/accumulator.rs @@ -1,13 +1,40 @@ -use crate::components::tile::RowWise; use cubecl_core as cubecl; use cubecl_core::prelude::*; +use crate::components::AttentionPrecision; +use crate::components::attention_types::*; +use crate::components::fragment::FragmentAttention; +use crate::components::fragment::{FragmentOps, FragmentOpsExpand}; +use crate::components::tile::RowWise; + +#[derive(CubeType)] +/// 
Accumulator tile for Tile Attention
+pub struct AccumulatorTile<AP: AttentionPrecision, FA: FragmentAttention<AP>> {
+    pub fragment: FA::Accumulator,
+}
+
+#[cube]
+impl<AP: AttentionPrecision, FA: FragmentAttention<AP>> AccumulatorTile<AP, FA> {
+    pub fn new(#[comptime] config: FA::Config) -> AccumulatorTile<AP, FA> {
+        let mut fragment = FA::allocate_accumulator(config);
+        FA::zero_accumulator(&mut fragment);
+
+        AccumulatorTile::<AP, FA> { fragment }
+    }
+}
+
+#[cube]
+impl<AP: AttentionPrecision, FA: FragmentAttention<AP>> AccumulatorTile<AP, FA> {
+    /// Multiplies each row by a scale
+    pub fn scale_mul(&mut self, scale: &RowWise<SM<AP>>) {
+        self.fragment
+            .rowwise_scale(&RowWise::<SM<AP>>::cast_from(scale));
+    }
+
+    /// Divides each row by a scale
+    pub fn scale_div(&mut self, scale: &RowWise<SM<AP>>) {
+        let mut scale = RowWise::<SM<AP>>::cast_from(scale);
+        scale.recip_inplace();
+        self.fragment.rowwise_scale(&scale);
+    }
+}
diff --git a/crates/cubecl-attention/src/components/tile/tiles/key_value.rs b/crates/cubecl-attention/src/components/tile/tiles/key_value.rs
new file mode 100644
index 000000000..a35c2ec4e
--- /dev/null
+++ b/crates/cubecl-attention/src/components/tile/tiles/key_value.rs
@@ -0,0 +1,108 @@
+use cubecl_core as cubecl;
+use cubecl_core::prelude::*;
+
+use crate::components::AttentionPrecision;
+use crate::components::fragment::FragmentAttention;
+
+#[derive(CubeType)]
+/// Key and Value inputs to the Tile Attention
+///
+/// Key and Value share the same fragment type because they may
+/// reuse the same underlying fragment
+pub enum KeyValueTile<AP: AttentionPrecision, FA: FragmentAttention<AP>> {
+    Reuse(ReuseKV<AP, FA>),
+    Key(Key<AP, FA>),
+    Value(Value<AP, FA>),
+}
+
+#[cube]
+impl<AP: AttentionPrecision, FA: FragmentAttention<AP>> KeyValueTile<AP, FA> {
+    pub fn new_key_value(#[comptime] config: FA::Config) -> Self {
+        Self::new_Reuse(ReuseKV::new(config))
+    }
+
+    pub fn new_key(#[comptime] config: FA::Config) -> Self {
+        Self::new_Key(Key::new(config))
+    }
+
+    pub fn new_value(#[comptime] config: FA::Config) -> Self {
+        Self::new_Value(Value::new(config))
+    }
+
+    /// Get the underlying key as readable
+    pub fn key(&self) -> &FA::KeyValue {
+        match self {
+            KeyValueTile::Reuse(reuse_kv) => &reuse_kv.fragment,
+            KeyValueTile::Key(key) => &key.fragment,
+            KeyValueTile::Value(_) => panic!("Tried to access key on value-only fragment"),
+        }
+    }
+
+    /// Get the underlying key as writable
+    pub fn key_mut(&mut self) -> &mut FA::KeyValue {
+        match self {
+            KeyValueTile::Reuse(reuse_kv) => &mut reuse_kv.fragment,
+            KeyValueTile::Key(key) => &mut key.fragment,
+            KeyValueTile::Value(_) => panic!("Tried to access key on value-only fragment"),
+        }
+    }
+
+    /// Get the underlying value as readable
+    pub fn value(&self) -> &FA::KeyValue {
+        match self {
+            KeyValueTile::Reuse(reuse_kv) => &reuse_kv.fragment,
+            KeyValueTile::Key(_) => panic!("Tried to access value on key-only fragment"),
+            KeyValueTile::Value(value) => &value.fragment,
+        }
+    }
+
+    /// Get the underlying value as writable
+    pub fn value_mut(&mut self) -> &mut FA::KeyValue {
+        match self {
+            KeyValueTile::Reuse(reuse_kv) => &mut reuse_kv.fragment,
+            KeyValueTile::Key(_) => panic!("Tried to access value on key-only fragment"),
+            KeyValueTile::Value(value) => &mut value.fragment,
+        }
+    }
+}
+
+#[derive(CubeType)]
+pub struct ReuseKV<AP: AttentionPrecision, FA: FragmentAttention<AP>> {
+    pub fragment: FA::KeyValue,
+}
+
+#[cube]
+impl<AP: AttentionPrecision, FA: FragmentAttention<AP>> ReuseKV<AP, FA> {
+    pub fn new(#[comptime] config: FA::Config) -> Self {
+        let fragment = FA::allocate_key_value(config);
+        ReuseKV::<AP, FA> { fragment }
+    }
+}
+
+#[derive(CubeType)]
+pub struct Key<AP: AttentionPrecision, FA: FragmentAttention<AP>> {
+    pub fragment: FA::KeyValue,
+}
+
+#[cube]
+impl<AP: AttentionPrecision, FA: FragmentAttention<AP>> Key<AP, FA> {
+    pub fn new(#[comptime] config: FA::Config) -> Self {
+        Key::<AP, FA> {
+            fragment:
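/* key-only allocation: the value()/value_mut() accessors panic on this variant */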
FA::allocate_key(config), + } + } +} + +#[derive(CubeType)] +pub struct Value> { + pub fragment: FA::KeyValue, +} + +#[cube] +impl> Value { + pub fn new(#[comptime] config: FA::Config) -> Self { + Value:: { + fragment: FA::allocate_value(config), + } + } +} diff --git a/crates/cubecl-attention/src/components/tile/tiles/mask.rs b/crates/cubecl-attention/src/components/tile/tiles/mask.rs new file mode 100644 index 000000000..216e88b75 --- /dev/null +++ b/crates/cubecl-attention/src/components/tile/tiles/mask.rs @@ -0,0 +1,159 @@ +use cubecl_core as cubecl; +use cubecl_core::prelude::*; +use cubecl_std::tensor::layout::Coords2d; +use cubecl_std::{CubeOption, CubeOptionExpand}; + +use crate::components::AttentionPrecision; +use crate::components::attention_types::MSK; +use crate::components::fragment::{FragmentAttention, FragmentAttentionConfig}; +use crate::components::fragment::{ + FragmentLayout, FragmentLayoutExpand, FragmentMask, FragmentMaskExpand, +}; +use cubecl_matmul::components::tile::StridedTile; + +use cubecl_std::tensor::layout::Coordinates; + +#[derive(CubeType)] +/// Mask tile for Tile Attention +/// It is an additive mask, which means the result of apply should be added, not multiplied +pub enum MaskTile> { + Materialized(MaterializedTileMask), + Logical(LogicalTileMask), +} + +#[cube] +impl> MaskTile { + pub fn new( + out_of_bounds: CubeOption, + #[comptime] partition_pos: Coords2d, + #[comptime] config: FA::Config, + ) -> MaskTile { + let logical_mask = LogicalTileMask:: { + logical_iter_origin: LogicalIterOrigin::init(), + partition_pos, + causal: config.causal_mask(), + out_of_bounds, + fragment_layout: FA::softmax_layout(config), + }; + + if config.materialized_mask() { + MaskTile::new_Materialized(MaterializedTileMask:: { + fragment: FA::allocate_mask(config), + logical_mask, + config, + }) + } else { + MaskTile::new_Logical(logical_mask) + } + } + + /// Loads the mask data into the fragment, if a tile is given, otherwise only + /// updates the logical mask + pub fn update(&mut self, new_origin: Coords2d, tile: CubeOption>>) { + match self { + MaskTile::Materialized(materialized_tile_mask) => { + materialized_tile_mask + .logical_mask + .update_origin(new_origin); + + materialized_tile_mask.update_tile(tile.unwrap()) + } + MaskTile::Logical(logical_tile_mask) => logical_tile_mask.update_origin(new_origin), + } + } +} + +#[derive(CubeType)] +pub struct LogicalIterOrigin { + row: RuntimeCell, + col: RuntimeCell, +} + +#[cube] +impl LogicalIterOrigin { + fn init() -> LogicalIterOrigin { + LogicalIterOrigin { + row: RuntimeCell::new(0), + col: RuntimeCell::new(0), + } + } + + fn read(&self) -> Coords2d { + (self.row.read(), self.col.read()) + } + + fn update(&mut self, new: Coords2d) { + self.row.store(new.0); + self.col.store(new.1); + } +} + +#[derive(CubeType)] +pub struct LogicalTileMask { + logical_iter_origin: LogicalIterOrigin, + #[cube(comptime)] + partition_pos: Coords2d, + #[cube(comptime)] + causal: bool, + out_of_bounds: CubeOption, + fragment_layout: F, +} + +#[cube] +impl LogicalTileMask { + pub fn should_mask(&self, local_pos: Coords2d) -> bool { + let pos_in_tile = self.fragment_layout.absolute_pos(local_pos); + + let pos = Coords2d::add( + self.logical_iter_origin.read(), + Coords2d::add(self.partition_pos.runtime(), pos_in_tile), + ); + + let causal_masked = self.causal && pos.0 < pos.1; + + let oob_masked = match self.out_of_bounds { + CubeOption::Some(bounds) => !Coords2d::is_in_bounds(&pos, &bounds), + CubeOption::None => false, + }; + + causal_masked || 
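/* mask if above the diagonal (causal) or, below, if outside the optional runtime bounds */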
oob_masked + } + + pub fn update_origin(&mut self, new_origin: Coords2d) { + self.logical_iter_origin.update(new_origin); + } +} + +#[derive(CubeType)] +pub struct MaterializedTileMask> { + fragment: FA::Mask, + logical_mask: LogicalTileMask, + #[cube(comptime)] + config: FA::Config, +} + +#[cube] +impl> MaterializedTileMask { + pub fn should_mask(&self, local_pos: Coords2d) -> bool { + let logical_masked = self.logical_mask.should_mask(local_pos); + let materialized_masked = self.fragment.should_mask(local_pos); + + logical_masked || materialized_masked + } + + pub fn update_tile(&mut self, tile: StridedTile>) { + FA::fill_mask(&tile, &mut self.fragment, self.config); + } +} + +#[cube] +impl> FragmentMask for MaskTile { + fn should_mask(&self, local_pos: (u32, u32)) -> bool { + match self { + MaskTile::Materialized(materialized_tile_mask) => { + materialized_tile_mask.should_mask(local_pos) + } + MaskTile::Logical(logical_tile_mask) => logical_tile_mask.should_mask(local_pos), + } + } +} diff --git a/crates/cubecl-attention/src/components/tile/tiles/mod.rs b/crates/cubecl-attention/src/components/tile/tiles/mod.rs index c9f3886a8..d4d205494 100644 --- a/crates/cubecl-attention/src/components/tile/tiles/mod.rs +++ b/crates/cubecl-attention/src/components/tile/tiles/mod.rs @@ -1,5 +1,11 @@ mod accumulator; +mod key_value; +mod mask; +mod query; mod softmax; pub use accumulator::*; +pub use key_value::*; +pub use mask::*; +pub use query::*; pub use softmax::*; diff --git a/crates/cubecl-attention/src/components/tile/tiles/query.rs b/crates/cubecl-attention/src/components/tile/tiles/query.rs new file mode 100644 index 000000000..8123ec5bc --- /dev/null +++ b/crates/cubecl-attention/src/components/tile/tiles/query.rs @@ -0,0 +1,27 @@ +use cubecl_core as cubecl; +use cubecl_core::prelude::*; + +use crate::components::AttentionPrecision; +use crate::components::attention_types::*; +use crate::components::fragment::FragmentAttention; +use cubecl_matmul::components::tile::StridedTile; + +#[derive(CubeType)] +/// Query input to the Tile Attention +pub struct QueryTile> { + pub fragment: FA::Query, +} + +#[cube] +impl> QueryTile { + pub fn new(#[comptime] config: FA::Config) -> QueryTile { + QueryTile:: { + fragment: FA::allocate_query(config), + } + } + + /// Loads the query data into the fragment + pub fn update(&mut self, tile: &StridedTile>) { + FA::fill_query(tile, &mut self.fragment) + } +} diff --git a/crates/cubecl-attention/src/components/tile/tiles/softmax.rs b/crates/cubecl-attention/src/components/tile/tiles/softmax.rs index f08624215..3c7d163b8 100644 --- a/crates/cubecl-attention/src/components/tile/tiles/softmax.rs +++ b/crates/cubecl-attention/src/components/tile/tiles/softmax.rs @@ -2,25 +2,74 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; use crate::components::AttentionPrecision; -use crate::components::TileMask; use crate::components::attention_types::*; -use crate::components::tile::{RowWise, RunningState}; +use crate::components::fragment::FragmentAttention; +use crate::components::fragment::FragmentAttentionConfig; +use crate::components::fragment::{FragmentOps, FragmentOpsExpand}; +use crate::components::tile::MaskTile; +use crate::components::tile::Reducer; +use crate::components::tile::RowWise; +use crate::components::tile::RunningState; +use crate::components::tile::{row_max, row_sum}; + +#[derive(CubeType)] +/// Softmax tile for the Tile Attention +/// +/// This tile is neither an input nor an output, +/// but the intermediate step where the softmax part of 
attention happens +pub struct SoftmaxTile> { + pub fragment: FA::Softmax, +} #[cube] -pub trait SoftmaxTile: CubeType { - fn init_state() -> RunningState>; +impl> SoftmaxTile { + pub fn new(#[comptime] config: FA::Config) -> Self { + let mut fragment = FA::allocate_softmax(config); + FA::zero_softmax(&mut fragment, config); - fn zero(&mut self); + SoftmaxTile:: { fragment } + } - fn scale_and_mask(&mut self, scale: SM, mask: TileMask); + /// Init the running state used in softmax + pub fn init_state(#[comptime] num_rows: u32) -> RunningState> { + RunningState::>::init(num_rows) + } - fn row_max(&self, base: RowWise>) -> RowWise>; + /// Scale the tile by a constant factor and apply the mask + pub fn scale_and_mask(&mut self, scale: SM, mask: &MaskTile) { + FA::Softmax::scale_and_mask::>(&mut self.fragment, scale, mask); + } - /// Converts scores → probabilities, updates running state, + /// Compute the max of each row, starting with base + /// as first element of the reduction, and storing result in placeholder + pub fn row_max( + &self, + placeholder: &mut RowWise>, + base: &RowWise>, + #[comptime] config: TC, + ) { + row_max::, FA::Softmax, R, TC>(placeholder, base, &self.fragment, config) + } + + /// Converts scores into (unnormalized) probabilities, updates running state, /// and returns the factor needed to scale the accumulator - fn to_prob( + pub fn to_prob( &mut self, state: &mut RunningState>, - max: &RowWise>, - ) -> RowWise>; + new_m: &RowWise>, + rowsum_placeholder: &mut RowWise>, + #[comptime] config: TC, + ) -> RowWise> { + self.fragment.exp_diff(new_m); + + row_sum::, FA::Softmax, R, TC>(rowsum_placeholder, &self.fragment, config); + + let exp_m_diff = state.m().exp_diff(new_m); + + let new_l = exp_m_diff.mul(state.l()).add(rowsum_placeholder); + + state.update(new_m, &new_l); + + exp_m_diff + } } diff --git a/crates/cubecl-attention/src/kernels/algorithm.rs b/crates/cubecl-attention/src/kernels/algorithm.rs index b40b0c3c6..47fc41e28 100644 --- a/crates/cubecl-attention/src/kernels/algorithm.rs +++ b/crates/cubecl-attention/src/kernels/algorithm.rs @@ -1,13 +1,14 @@ use cubecl_core::{Runtime, client::ComputeClient}; +use crate::components::fragment::FragmentAttentionFamily; use crate::components::{ AttentionLineSizes, AttentionPrecision, AttentionProblem, AttentionSelection, AttentionSetupError, AvailableLineSizes, batch::BatchAttentionFamily, - global::GlobalAttentionFamily, stage::StageAttentionFamily, tile::TileAttentionFamily, + global::GlobalAttentionFamily, stage::StageAttentionFamily, }; pub trait Algorithm { - type TileAttention: TileAttentionFamily; + type FragmentAttention: FragmentAttentionFamily; type StageAttention: StageAttentionFamily; type GlobalAttention: GlobalAttentionFamily; type BatchAttention: BatchAttentionFamily; @@ -17,7 +18,7 @@ pub trait Algorithm { } fn setup( - client: &ComputeClient, + client: &ComputeClient, problem: &AttentionProblem, selection: &AttentionSelection, line_sizes: &AttentionLineSizes, diff --git a/crates/cubecl-attention/src/kernels/dummy.rs b/crates/cubecl-attention/src/kernels/dummy.rs index 09de33209..dfcc11715 100644 --- a/crates/cubecl-attention/src/kernels/dummy.rs +++ b/crates/cubecl-attention/src/kernels/dummy.rs @@ -1,29 +1,51 @@ use cubecl_matmul::components::{global::PartitionedStageFamily, stage::StridedStageFamily}; +use crate::components::fragment::accelerated::AcceleratedFragmentAttention; +use crate::components::fragment::dummy_register::DummyRegisterFragmentAttention; +use 
crate::components::stage::plane::PlaneKVReuseStageAttentionFamily; use crate::{ components::{ - AvailableLineSizes, - batch::dummy::DummyBatchAttentionFamily, - global::dummy::DummyGlobalAttentionFamily, - stage::dummy::DummyStageAttentionFamily, - tile::dummy::{DummyTileAttentionFamily, dummy_register::DummyRegisterAttentionMatmul}, + AvailableLineSizes, batch::simple::SimpleBatchAttentionFamily, + global::simple::SimpleGlobalAttentionFamily, }, kernels::Algorithm, }; -pub struct DummyAlgorithm {} +pub struct DummyRegisterAlgorithm {} +pub struct DummyAcceleratedAlgorithm {} -impl Algorithm for DummyAlgorithm { - // type TileAttention = DummyTileAttentionFamily; - type TileAttention = DummyTileAttentionFamily; - type StageAttention = DummyStageAttentionFamily< - Self::TileAttention, +impl Algorithm for DummyRegisterAlgorithm { + type FragmentAttention = DummyRegisterFragmentAttention; + type StageAttention = PlaneKVReuseStageAttentionFamily< + Self::FragmentAttention, StridedStageFamily, StridedStageFamily, PartitionedStageFamily, >; - type GlobalAttention = DummyGlobalAttentionFamily; - type BatchAttention = DummyBatchAttentionFamily; + type GlobalAttention = SimpleGlobalAttentionFamily; + type BatchAttention = SimpleBatchAttentionFamily; + + fn filter_line_sizes(_available_line_sizes: AvailableLineSizes) -> AvailableLineSizes { + AvailableLineSizes { + query: vec![1], + key: vec![1], + value: vec![1], + mask: vec![1], + out: vec![1], + } + } +} + +impl Algorithm for DummyAcceleratedAlgorithm { + type FragmentAttention = AcceleratedFragmentAttention; + type StageAttention = PlaneKVReuseStageAttentionFamily< + Self::FragmentAttention, + StridedStageFamily, + StridedStageFamily, + PartitionedStageFamily, + >; + type GlobalAttention = SimpleGlobalAttentionFamily; + type BatchAttention = SimpleBatchAttentionFamily; fn filter_line_sizes(_available_line_sizes: AvailableLineSizes) -> AvailableLineSizes { AvailableLineSizes { diff --git a/crates/cubecl-attention/src/kernels/mod.rs b/crates/cubecl-attention/src/kernels/mod.rs index fe754bd2d..c54a2f9ce 100644 --- a/crates/cubecl-attention/src/kernels/mod.rs +++ b/crates/cubecl-attention/src/kernels/mod.rs @@ -1,5 +1,6 @@ /// Very slow attention implementation. 
Temporary pub mod dummy; +pub mod unit; mod algorithm; diff --git a/crates/cubecl-attention/src/kernels/unit.rs b/crates/cubecl-attention/src/kernels/unit.rs new file mode 100644 index 000000000..4753d94f8 --- /dev/null +++ b/crates/cubecl-attention/src/kernels/unit.rs @@ -0,0 +1,35 @@ +use cubecl_matmul::components::{global::PartitionedStageFamily, stage::StridedStageFamily}; + +use crate::components::fragment::unit_register::UnitRegisterFragmentAttention; +use crate::components::stage::unit::UnitKVReuseStageAttentionFamily; +use crate::{ + components::{ + AvailableLineSizes, batch::simple::SimpleBatchAttentionFamily, + global::simple::SimpleGlobalAttentionFamily, + }, + kernels::Algorithm, +}; + +pub struct UnitAlgorithm {} + +impl Algorithm for UnitAlgorithm { + type FragmentAttention = UnitRegisterFragmentAttention; + type StageAttention = UnitKVReuseStageAttentionFamily< + Self::FragmentAttention, + StridedStageFamily, + StridedStageFamily, + PartitionedStageFamily, + >; + type GlobalAttention = SimpleGlobalAttentionFamily; + type BatchAttention = SimpleBatchAttentionFamily; + + fn filter_line_sizes(_available_line_sizes: AvailableLineSizes) -> AvailableLineSizes { + AvailableLineSizes { + query: vec![1], + key: vec![1], + value: vec![1], + mask: vec![1], + out: vec![1], + } + } +} diff --git a/crates/cubecl-attention/src/lib.rs b/crates/cubecl-attention/src/lib.rs index 9730e998d..96f0a476e 100644 --- a/crates/cubecl-attention/src/lib.rs +++ b/crates/cubecl-attention/src/lib.rs @@ -1,9 +1,12 @@ +#![allow(clippy::explicit_counter_loop)] +#![allow(clippy::manual_is_multiple_of)] + mod base; /// Components for matrix multiplication pub mod components; -/// Contains matmul kernels +/// Contains attention kernels pub mod kernels; -/// Tests for matmul kernels +/// Tests for attention kernels #[cfg(feature = "export_tests")] pub mod tests; diff --git a/crates/cubecl-attention/src/tests/attention_test_launcher.rs b/crates/cubecl-attention/src/tests/attention_test_launcher.rs index 6e4682d9a..5e1201fb1 100644 --- a/crates/cubecl-attention/src/tests/attention_test_launcher.rs +++ b/crates/cubecl-attention/src/tests/attention_test_launcher.rs @@ -1,6 +1,7 @@ use cubecl_core::prelude::*; use cubecl_core::server::Allocation; use cubecl_core::{CubeElement, server}; +use cubecl_std::CubeOptionArgs; use crate::components::args::TensorInputsLaunch; use crate::components::batch::BatchAttentionConfig; @@ -22,7 +23,7 @@ pub struct TensorRawParts { /// Test the correctness of the specified Attention on the given device, /// against a naive CPU implementation over the given problem pub fn test_attention_algorithm( - client: ComputeClient, + client: ComputeClient, problem: AttentionProblem, selection: AttentionSelection, ) where @@ -45,20 +46,27 @@ pub fn test_attention_algorithm( let query = tensor_raw_parts_input::(&client, &problem, AttentionIdent::Query, 12); let key = tensor_raw_parts_input::(&client, &problem, AttentionIdent::Key, 34); let value = tensor_raw_parts_input::(&client, &problem, AttentionIdent::Value, 56); - // let mask = tensor_raw_parts_input::(&client, &problem, Ident::Mask, 78); + let mask = match problem.masked { + true => Some(tensor_raw_parts_input::( + &client, + &problem, + AttentionIdent::Mask, + 78, + )), + false => None, + }; let out = tensor_raw_parts_output::(&client, &problem); let line_sizes = AvailableLineSizes::from_elem_types::( - &P::EG::as_type_native_unchecked(), - &P::EM::as_type_native_unchecked(), - &P::EG::as_type_native_unchecked(), + size_of::(), + 
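/* presumably the mask element byte size (EM), between the query and output sizes (EG), mirroring the replaced type-handle arguments */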
size_of::(), + size_of::(), ); let line_sizes = A::filter_line_sizes(line_sizes); let line_sizes = line_sizes .filter_with_tensor(AttentionIdent::Query, &query.strides, &query.shape) .filter_with_tensor(AttentionIdent::Key, &key.strides, &key.shape) .filter_with_tensor(AttentionIdent::Value, &value.strides, &value.shape) - // .filter_with_tensor(Ident::Mask, &mask.strides, &mask.shape) .filter_with_tensor(AttentionIdent::Out, &out.strides, &out.shape) .pick_max() .unwrap(); @@ -104,12 +112,15 @@ pub fn test_attention_algorithm( &value.shape, line_sizes.value, ), - // TensorArg::::from_raw_parts::( - // &mask.handle, - // &mask.strides, - // &mask.shape, - // line_sizes.mask, - // ), + match mask.as_ref() { + Some(m) => CubeOptionArgs::Some(TensorArg::::from_raw_parts::( + &m.handle, + &m.strides, + &m.shape, + line_sizes.mask, + )), + None => CubeOptionArgs::None, + }, ), TensorArg::::from_raw_parts::( &out.handle, @@ -126,7 +137,8 @@ pub fn test_attention_algorithm( &query.original_data.unwrap(), &key.original_data.unwrap(), &value.original_data.unwrap(), - None, + mask.as_ref() + .map(|m| m.original_data.as_ref().unwrap().as_slice()), &problem, &client, out.handle, @@ -136,7 +148,7 @@ pub fn test_attention_algorithm( } fn tensor_raw_parts_input( - client: &ComputeClient, + client: &ComputeClient, problem: &AttentionProblem, ident: AttentionIdent, sample_seed: u64, @@ -163,7 +175,7 @@ where } fn tensor_raw_parts_output( - client: &ComputeClient, + client: &ComputeClient, problem: &AttentionProblem, ) -> TensorRawParts { let zero = P::EG::from_int(0); diff --git a/crates/cubecl-attention/src/tests/macros/mod.rs b/crates/cubecl-attention/src/tests/macros/mod.rs index a593b1e5d..5cf294c28 100644 --- a/crates/cubecl-attention/src/tests/macros/mod.rs +++ b/crates/cubecl-attention/src/tests/macros/mod.rs @@ -1,801 +1,96 @@ use cubecl_core::{Runtime, client::ComputeClient}; +mod suite; + use crate::{ components::{ AttentionProblem, AttentionSelection, AttentionTilingScheme, batch::HypercubeSelection, }, - kernels::dummy::DummyAlgorithm, + kernels::Algorithm, tests::attention_test_launcher::test_attention_algorithm, }; -pub fn attention_test_launch( - client: ComputeClient, +#[derive(Default)] +pub struct TestOptions { + pub reuse_key_value: bool, + pub two_rows_in_array_tile: bool, +} + +pub fn attention_test_launch( + client: ComputeClient, tiling_scheme: AttentionTilingScheme, problem: AttentionProblem, - reuse_key_value: bool, + test_options: TestOptions, ) { let selection = AttentionSelection { hypercube_selection: HypercubeSelection {}, plane_dim: 32, tiling_scheme, - reuse_key_value, + reuse_key_value: test_options.reuse_key_value, + two_rows_in_array_tile: test_options.two_rows_in_array_tile, }; - test_attention_algorithm::(client, problem, selection); + test_attention_algorithm::(client, problem, selection); } #[macro_export] macro_rules! 
testgen_attention { () => { - #[cfg(feature = "attention_tests")] - mod attention { - use super::*; - use cubecl_attention::components::{ - AttentionPartitionSize, AttentionProblem, AttentionStageSize, AttentionTileSize, - AttentionTilingScheme, - }; + use super::*; - #[test] - fn attention_8_8_8_8() { - let client = TestRuntime::client(&Default::default()); - let tile_size = AttentionTileSize { + #[cfg(feature = "attention_tests")] + mod attention_dummy_register { + type Algorithm = cubecl_attention::kernels::dummy::DummyRegisterAlgorithm; + const TILE_SIZE: cubecl_attention::components::AttentionTileSize = + cubecl_attention::components::AttentionTileSize { seq_q: 8, seq_kv: 8, head_dim: 8, val_dim: 8, }; - let partition_size = AttentionPartitionSize { - seq_q: 1, - seq_kv: 1, - head_dim: 1, - val_dim: 1, - }; - let stage_size = AttentionStageSize { seq_q: 1 }; - let tiling_scheme = AttentionTilingScheme { - tile_size, - partition_size, - stage_size, - }; - let problem = AttentionProblem { - batch: 1, - num_heads: 1, - seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, - seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, - head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, - val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, - masked: false, - }; - $crate::tests::macros::attention_test_launch::( - client, - tiling_scheme, - problem, - false, - ) - } + const STAGE_Q_BASE: u32 = 1; - #[test] - fn attention_9_9_9_9() { - let client = TestRuntime::client(&Default::default()); - let tile_size = AttentionTileSize { - seq_q: 9, - seq_kv: 9, - head_dim: 9, - val_dim: 9, - }; - let partition_size = AttentionPartitionSize { - seq_q: 1, - seq_kv: 1, - head_dim: 1, - val_dim: 1, - }; - let stage_size = AttentionStageSize { seq_q: 1 }; - let tiling_scheme = AttentionTilingScheme { - tile_size, - partition_size, - stage_size, - }; - let problem = AttentionProblem { - batch: 1, - num_heads: 1, - seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, - seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, - head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, - val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, - masked: false, - }; - $crate::tests::macros::attention_test_launch::( - client, - tiling_scheme, - problem, - false, - ) - } - - #[test] - fn attention_7_3_10_10() { - let client = TestRuntime::client(&Default::default()); - let tile_size = AttentionTileSize { - seq_q: 7, - seq_kv: 3, - head_dim: 10, - val_dim: 10, - }; - let partition_size = AttentionPartitionSize { - seq_q: 1, - seq_kv: 1, - head_dim: 1, - val_dim: 1, - }; - let stage_size = AttentionStageSize { seq_q: 1 }; - let tiling_scheme = AttentionTilingScheme { - tile_size, - partition_size, - stage_size, - }; - let problem = AttentionProblem { - batch: 1, - num_heads: 1, - seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, - seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, - head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, - val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, - masked: false, - }; - $crate::tests::macros::attention_test_launch::( - client, - tiling_scheme, - problem, - false, - ) - } - - #[test] - fn attention_8_q16() { - let client = TestRuntime::client(&Default::default()); - let tile_size = AttentionTileSize { - seq_q: 8, - seq_kv: 8, - head_dim: 8, - val_dim: 8, - }; - let partition_size = AttentionPartitionSize { - seq_q: 1, - seq_kv: 1, - head_dim: 1, - val_dim: 1, - }; - let 
stage_size = AttentionStageSize { seq_q: 1 }; - let tiling_scheme = AttentionTilingScheme { - tile_size, - partition_size, - stage_size, - }; - let problem = AttentionProblem { - batch: 1, - num_heads: 1, - seq_q: 16, - seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, - head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, - val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, - masked: false, - }; - $crate::tests::macros::attention_test_launch::( - client, - tiling_scheme, - problem, - false, - ) - } + $crate::testgen_attention_suite!(); + } - #[test] - fn attention_8_q4() { - let client = TestRuntime::client(&Default::default()); - let tile_size = AttentionTileSize { - seq_q: 8, - seq_kv: 8, - head_dim: 8, - val_dim: 8, - }; - let partition_size = AttentionPartitionSize { - seq_q: 1, - seq_kv: 1, - head_dim: 1, - val_dim: 1, - }; - let stage_size = AttentionStageSize { seq_q: 1 }; - let tiling_scheme = AttentionTilingScheme { - tile_size, - partition_size, - stage_size, - }; - let problem = AttentionProblem { - batch: 1, - num_heads: 1, + #[cfg(feature = "attention_tests")] + mod attention_unit { + type Algorithm = cubecl_attention::kernels::unit::UnitAlgorithm; + const TILE_SIZE: cubecl_attention::components::AttentionTileSize = + cubecl_attention::components::AttentionTileSize { seq_q: 4, - seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, - head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, - val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, - masked: false, - }; - $crate::tests::macros::attention_test_launch::( - client, - tiling_scheme, - problem, - false, - ) - } - - #[test] - fn attention_partition_q2() { - let client = TestRuntime::client(&Default::default()); - let tile_size = AttentionTileSize { - seq_q: 8, - seq_kv: 8, - head_dim: 8, - val_dim: 8, - }; - let partition_size = AttentionPartitionSize { - seq_q: 2, - seq_kv: 1, - head_dim: 1, - val_dim: 1, - }; - let stage_size = AttentionStageSize { seq_q: 1 }; - let tiling_scheme = AttentionTilingScheme { - tile_size, - partition_size, - stage_size, - }; - let problem = AttentionProblem { - batch: 1, - num_heads: 1, - seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, - seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, - head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, - val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, - masked: false, - }; - $crate::tests::macros::attention_test_launch::( - client, - tiling_scheme, - problem, - false, - ) - } - - #[test] - fn attention_partition_hd2() { - let client = TestRuntime::client(&Default::default()); - let tile_size = AttentionTileSize { - seq_q: 8, - seq_kv: 8, - head_dim: 8, - val_dim: 8, - }; - let partition_size = AttentionPartitionSize { - seq_q: 1, - seq_kv: 1, - head_dim: 2, - val_dim: 1, - }; - let stage_size = AttentionStageSize { seq_q: 1 }; - let tiling_scheme = AttentionTilingScheme { - tile_size, - partition_size, - stage_size, - }; - let problem = AttentionProblem { - batch: 1, - num_heads: 1, - seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, - seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, - head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, - val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, - masked: false, - }; - $crate::tests::macros::attention_test_launch::( - client, - tiling_scheme, - problem, - false, - ) - } - - #[test] - fn attention_partition_kv2() { - let client = 
TestRuntime::client(&Default::default()); - let tile_size = AttentionTileSize { - seq_q: 8, - seq_kv: 8, - head_dim: 8, - val_dim: 8, - }; - let partition_size = AttentionPartitionSize { - seq_q: 1, - seq_kv: 3, - head_dim: 1, - val_dim: 1, - }; - let stage_size = AttentionStageSize { seq_q: 1 }; - let tiling_scheme = AttentionTilingScheme { - tile_size, - partition_size, - stage_size, - }; - let problem = AttentionProblem { - batch: 1, - num_heads: 1, - seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, - seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, - head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, - val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, - masked: false, - }; - $crate::tests::macros::attention_test_launch::( - client, - tiling_scheme, - problem, - false, - ); - } - - #[test] - fn attention_partition_vd2() { - let client = TestRuntime::client(&Default::default()); - let tile_size = AttentionTileSize { - seq_q: 8, - seq_kv: 5, - head_dim: 7, + seq_kv: 4, + head_dim: 4, val_dim: 4, }; - let partition_size = AttentionPartitionSize { - seq_q: 1, - seq_kv: 1, - head_dim: 7, - val_dim: 8, - }; - let stage_size = AttentionStageSize { seq_q: 1 }; - let tiling_scheme = AttentionTilingScheme { - tile_size, - partition_size, - stage_size, - }; - let problem = AttentionProblem { - batch: 1, - num_heads: 1, - seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, - seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, - head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, - val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, - masked: false, - }; - $crate::tests::macros::attention_test_launch::( - client, - tiling_scheme, - problem, - false, - ); - } - - #[test] - fn attention_partition_all2() { - let client = TestRuntime::client(&Default::default()); - let tile_size = AttentionTileSize { - seq_q: 8, - seq_kv: 8, - head_dim: 8, - val_dim: 8, - }; - let partition_size = AttentionPartitionSize { - seq_q: 2, - seq_kv: 2, - head_dim: 2, - val_dim: 2, - }; - let stage_size = AttentionStageSize { seq_q: 1 }; - let tiling_scheme = AttentionTilingScheme { - tile_size, - partition_size, - stage_size, - }; - let problem = AttentionProblem { - batch: 1, - num_heads: 1, - seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, - seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, - head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, - val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, - masked: false, - }; - $crate::tests::macros::attention_test_launch::( - client, - tiling_scheme, - problem, - false, - ); - } + const STAGE_Q_BASE: u32 = 32; - #[test] - fn attention_global_2() { - let client = TestRuntime::client(&Default::default()); - let tile_size = AttentionTileSize { - seq_q: 8, - seq_kv: 8, - head_dim: 8, - val_dim: 8, - }; - let partition_size = AttentionPartitionSize { - seq_q: 1, - seq_kv: 1, - head_dim: 1, - val_dim: 1, - }; - let stage_size = AttentionStageSize { seq_q: 1 }; - let num_iterations = 2; - let tiling_scheme = AttentionTilingScheme { - tile_size, - partition_size, - stage_size, - }; - let problem = AttentionProblem { - batch: 1, - num_heads: 1, - seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, - seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize * num_iterations, - head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, - val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, - masked: false, - }; - 
$crate::tests::macros::attention_test_launch::( - client, - tiling_scheme, - problem, - false, - ); - } - - #[test] - fn attention_partition_kv2_global_2() { - let client = TestRuntime::client(&Default::default()); - let tile_size = AttentionTileSize { - seq_q: 8, - seq_kv: 8, - head_dim: 8, - val_dim: 8, - }; - let partition_size = AttentionPartitionSize { - seq_q: 1, - seq_kv: 2, - head_dim: 1, - val_dim: 1, - }; - let stage_size = AttentionStageSize { seq_q: 1 }; - let num_iterations = 2; - let tiling_scheme = AttentionTilingScheme { - tile_size, - partition_size, - stage_size, - }; - let problem = AttentionProblem { - batch: 1, - num_heads: 1, - seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, - seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize * num_iterations, - head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, - val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, - masked: false, - }; - $crate::tests::macros::attention_test_launch::( - client, - tiling_scheme, - problem, - false, - ); - } - - #[test] - fn attention_partition_kv1_global2_with_oob() { - let client = TestRuntime::client(&Default::default()); - let tile_size = AttentionTileSize { - seq_q: 8, - seq_kv: 8, - head_dim: 8, - val_dim: 8, - }; - let partition_size = AttentionPartitionSize { - seq_q: 1, - seq_kv: 1, - head_dim: 1, - val_dim: 1, - }; - let stage_size = AttentionStageSize { seq_q: 2 }; - let tiling_scheme = AttentionTilingScheme { - tile_size, - partition_size, - stage_size, - }; - let problem = AttentionProblem { - batch: 1, - num_heads: 1, - seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, - seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize * 2 + 1, - head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, - val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, - masked: false, - }; - $crate::tests::macros::attention_test_launch::( - client, - tiling_scheme, - problem, - false, - ); - } - - #[test] - fn attention_partition_oob_in_q() { - let client = TestRuntime::client(&Default::default()); - let tile_size = AttentionTileSize { - seq_q: 8, - seq_kv: 8, - head_dim: 8, - val_dim: 8, - }; - let partition_size = AttentionPartitionSize { - seq_q: 2, - seq_kv: 1, - head_dim: 1, - val_dim: 1, - }; - let stage_size = AttentionStageSize { seq_q: 1 }; - let tiling_scheme = AttentionTilingScheme { - tile_size, - partition_size, - stage_size, - }; - let problem = AttentionProblem { - batch: 1, - num_heads: 1, - seq_q: 1, - seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, - head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, - val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, - masked: false, - }; - $crate::tests::macros::attention_test_launch::( - client, - tiling_scheme, - problem, - false, - ); - } - - #[test] - fn attention_partition_kv2_with_oob() { - let client = TestRuntime::client(&Default::default()); - let tile_size = AttentionTileSize { - seq_q: 8, - seq_kv: 8, - head_dim: 8, - val_dim: 8, - }; - let partition_size = AttentionPartitionSize { - seq_q: 1, - seq_kv: 2, - head_dim: 1, - val_dim: 1, - }; - let stage_size = AttentionStageSize { seq_q: 1 }; - let tiling_scheme = AttentionTilingScheme { - tile_size, - partition_size, - stage_size, - }; - let problem = AttentionProblem { - batch: 1, - num_heads: 1, - seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, - seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize + 1, - head_dim: 
tiling_scheme.elements_in_partition_head_dim() as usize, - val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, - masked: false, - }; - $crate::tests::macros::attention_test_launch::( - client, - tiling_scheme, - problem, - false, - ); - } - - #[test] - fn attention_stage2() { - let client = TestRuntime::client(&Default::default()); - let tile_size = AttentionTileSize { - seq_q: 8, - seq_kv: 8, - head_dim: 8, - val_dim: 8, - }; - let partition_size = AttentionPartitionSize { - seq_q: 1, - seq_kv: 1, - head_dim: 1, - val_dim: 1, - }; - let stage_size = AttentionStageSize { seq_q: 2 }; - let tiling_scheme = AttentionTilingScheme { - tile_size, - partition_size, - stage_size, - }; - let problem = AttentionProblem { - batch: 1, - num_heads: 1, - seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, - seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, - head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, - val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, - masked: false, - }; - $crate::tests::macros::attention_test_launch::( - client, - tiling_scheme, - problem, - false, - ); - } - - #[test] - fn attention_stage4() { - let client = TestRuntime::client(&Default::default()); - let tile_size = AttentionTileSize { - seq_q: 8, - seq_kv: 8, - head_dim: 8, - val_dim: 8, - }; - let partition_size = AttentionPartitionSize { - seq_q: 1, - seq_kv: 1, - head_dim: 1, - val_dim: 1, - }; - let stage_size = AttentionStageSize { seq_q: 4 }; - let tiling_scheme = AttentionTilingScheme { - tile_size, - partition_size, - stage_size, - }; - let problem = AttentionProblem { - batch: 1, - num_heads: 1, - seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, - seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, - head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, - val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, - masked: false, - }; - $crate::tests::macros::attention_test_launch::( - client, - tiling_scheme, - problem, - false, - ); - } - - #[test] - fn attention_stage2_problem4() { - let client = TestRuntime::client(&Default::default()); - let tile_size = AttentionTileSize { - seq_q: 8, - seq_kv: 8, - head_dim: 8, - val_dim: 8, - }; - let partition_size = AttentionPartitionSize { - seq_q: 1, - seq_kv: 1, - head_dim: 1, - val_dim: 1, - }; - let stage_size = AttentionStageSize { seq_q: 2 }; - let tiling_scheme = AttentionTilingScheme { - tile_size, - partition_size, - stage_size, - }; - let problem = AttentionProblem { - batch: 1, - num_heads: 1, - seq_q: tiling_scheme.elements_in_stage_seq_q() as usize * 2, - seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize * 2, - head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, - val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, - masked: false, - }; - $crate::tests::macros::attention_test_launch::( - client, - tiling_scheme, - problem, - false, - ); - } + $crate::testgen_attention_suite!(); + } - #[test] - fn attention_stage2_partition_all2() { - let client = TestRuntime::client(&Default::default()); - let tile_size = AttentionTileSize { + #[cfg(feature = "attention_tests")] + mod attention_dummy_accelerated { + type Algorithm = cubecl_attention::kernels::dummy::DummyAcceleratedAlgorithm; + #[cfg(target_os = "macos")] + const TILE_SIZE: cubecl_attention::components::AttentionTileSize = + cubecl_attention::components::AttentionTileSize { seq_q: 8, seq_kv: 8, head_dim: 8, val_dim: 8, }; - let partition_size = AttentionPartitionSize { - 
seq_q: 2, - seq_kv: 2, - head_dim: 2, - val_dim: 2, - }; - let stage_size = AttentionStageSize { seq_q: 2 }; - let tiling_scheme = AttentionTilingScheme { - tile_size, - partition_size, - stage_size, - }; - let problem = AttentionProblem { - batch: 1, - num_heads: 1, - seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, - seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, - head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, - val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, - masked: false, + #[cfg(not(target_os = "macos"))] + const TILE_SIZE: cubecl_attention::components::AttentionTileSize = + cubecl_attention::components::AttentionTileSize { + seq_q: 16, + seq_kv: 16, + head_dim: 16, + val_dim: 16, }; - $crate::tests::macros::attention_test_launch::( - client, - tiling_scheme, - problem, - false, - ); - } + const STAGE_Q_BASE: u32 = 1; - #[test] - fn attention_reuse_key_value() { - let client = TestRuntime::client(&Default::default()); - let tile_size = AttentionTileSize { - seq_q: 8, - seq_kv: 8, - head_dim: 8, - val_dim: 8, - }; - let partition_size = AttentionPartitionSize { - seq_q: 1, - seq_kv: 1, - head_dim: 2, - val_dim: 2, - }; - let stage_size = AttentionStageSize { seq_q: 1 }; - let tiling_scheme = AttentionTilingScheme { - tile_size, - partition_size, - stage_size, - }; - let problem = AttentionProblem { - batch: 1, - num_heads: 1, - seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, - seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, - head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, - val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, - masked: false, - }; - $crate::tests::macros::attention_test_launch::( - client, - tiling_scheme, - problem, - true, - ); - } + // Deactivated + // $crate::testgen_attention_suite!(); } }; } diff --git a/crates/cubecl-attention/src/tests/macros/suite.rs b/crates/cubecl-attention/src/tests/macros/suite.rs new file mode 100644 index 000000000..1683021fc --- /dev/null +++ b/crates/cubecl-attention/src/tests/macros/suite.rs @@ -0,0 +1,958 @@ +#[macro_export] +macro_rules! 
testgen_attention_suite { + () => { + use super::*; + use cubecl_attention::components::{ + AttentionPartitionSize, AttentionProblem, AttentionStageSize, AttentionTileSize, + AttentionTilingScheme, + }; + use $crate::tests::macros::{TestOptions, attention_test_launch}; + + #[test] + fn attention_one_tile() { + let client = TestRuntime::client(&Default::default()); + + let partition_size = AttentionPartitionSize { + seq_q: 1, + seq_kv: 1, + head_dim: 1, + val_dim: 1, + }; + let stage_size = AttentionStageSize { + seq_q: STAGE_Q_BASE, + }; + let tiling_scheme = AttentionTilingScheme { + tile_size: TILE_SIZE, + partition_size, + stage_size, + }; + let problem = AttentionProblem { + batch: 1, + num_heads: 1, + seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, + seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, + head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, + val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, + masked: false, + causal: false, + }; + attention_test_launch::( + client, + tiling_scheme, + problem, + Default::default(), + ) + } + + #[test] + fn attention_two_rows_in_array_tile() { + let client = TestRuntime::client(&Default::default()); + + let partition_size = AttentionPartitionSize { + seq_q: 1, + seq_kv: 1, + head_dim: 1, + val_dim: 1, + }; + let stage_size = AttentionStageSize { + seq_q: STAGE_Q_BASE, + }; + let tiling_scheme = AttentionTilingScheme { + tile_size: TILE_SIZE, + partition_size, + stage_size, + }; + let problem = AttentionProblem { + batch: 1, + num_heads: 1, + seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, + seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, + head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, + val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, + masked: false, + causal: false, + }; + attention_test_launch::( + client, + tiling_scheme, + problem, + TestOptions { + two_rows_in_array_tile: true, + ..Default::default() + }, + ) + } + + #[test] + fn attention_one_tile_seqq16() { + let client = TestRuntime::client(&Default::default()); + + let partition_size = AttentionPartitionSize { + seq_q: 1, + seq_kv: 1, + head_dim: 1, + val_dim: 1, + }; + let stage_size = AttentionStageSize { + seq_q: STAGE_Q_BASE, + }; + let tiling_scheme = AttentionTilingScheme { + tile_size: TILE_SIZE, + partition_size, + stage_size, + }; + let problem = AttentionProblem { + batch: 1, + num_heads: 1, + seq_q: 16, + seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, + head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, + val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, + masked: false, + causal: false, + }; + attention_test_launch::( + client, + tiling_scheme, + problem, + Default::default(), + ) + } + + #[test] + fn attention_one_tile_seqq4() { + let client = TestRuntime::client(&Default::default()); + + let partition_size = AttentionPartitionSize { + seq_q: 1, + seq_kv: 1, + head_dim: 1, + val_dim: 1, + }; + let stage_size = AttentionStageSize { + seq_q: STAGE_Q_BASE, + }; + let tiling_scheme = AttentionTilingScheme { + tile_size: TILE_SIZE, + partition_size, + stage_size, + }; + let problem = AttentionProblem { + batch: 1, + num_heads: 1, + seq_q: 4, + seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, + head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, + val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, + masked: false, + causal: false, + }; + attention_test_launch::( + client, + 
tiling_scheme,
+                problem,
+                Default::default(),
+            )
+        }
+
+        #[test]
+        fn attention_partition_seqq2() {
+            let client = TestRuntime::client(&Default::default());
+
+            let partition_size = AttentionPartitionSize {
+                seq_q: 2,
+                seq_kv: 1,
+                head_dim: 1,
+                val_dim: 1,
+            };
+            let stage_size = AttentionStageSize {
+                seq_q: STAGE_Q_BASE,
+            };
+            let tiling_scheme = AttentionTilingScheme {
+                tile_size: TILE_SIZE,
+                partition_size,
+                stage_size,
+            };
+            let problem = AttentionProblem {
+                batch: 1,
+                num_heads: 1,
+                seq_q: tiling_scheme.elements_in_stage_seq_q() as usize,
+                seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize,
+                head_dim: tiling_scheme.elements_in_partition_head_dim() as usize,
+                val_dim: tiling_scheme.elements_in_partition_val_dim() as usize,
+                masked: false,
+                causal: false,
+            };
+            attention_test_launch::(
+                client,
+                tiling_scheme,
+                problem,
+                Default::default(),
+            )
+        }
+
+        #[test]
+        fn attention_partition_hd2() {
+            let client = TestRuntime::client(&Default::default());
+
+            let partition_size = AttentionPartitionSize {
+                seq_q: 1,
+                seq_kv: 1,
+                head_dim: 2,
+                val_dim: 1,
+            };
+            let stage_size = AttentionStageSize {
+                seq_q: STAGE_Q_BASE,
+            };
+            let tiling_scheme = AttentionTilingScheme {
+                tile_size: TILE_SIZE,
+                partition_size,
+                stage_size,
+            };
+            let problem = AttentionProblem {
+                batch: 1,
+                num_heads: 1,
+                seq_q: tiling_scheme.elements_in_stage_seq_q() as usize,
+                seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize,
+                head_dim: tiling_scheme.elements_in_partition_head_dim() as usize,
+                val_dim: tiling_scheme.elements_in_partition_val_dim() as usize,
+                masked: false,
+                causal: false,
+            };
+            attention_test_launch::(
+                client,
+                tiling_scheme,
+                problem,
+                Default::default(),
+            )
+        }
+
+        #[test]
+        fn attention_partition_kv2() {
+            let client = TestRuntime::client(&Default::default());
+
+            let partition_size = AttentionPartitionSize {
+                seq_q: 1,
+                seq_kv: 2,
+                head_dim: 1,
+                val_dim: 1,
+            };
+            let stage_size = AttentionStageSize {
+                seq_q: STAGE_Q_BASE,
+            };
+            let tiling_scheme = AttentionTilingScheme {
+                tile_size: TILE_SIZE,
+                partition_size,
+                stage_size,
+            };
+            let problem = AttentionProblem {
+                batch: 1,
+                num_heads: 1,
+                seq_q: tiling_scheme.elements_in_stage_seq_q() as usize,
+                seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize,
+                head_dim: tiling_scheme.elements_in_partition_head_dim() as usize,
+                val_dim: tiling_scheme.elements_in_partition_val_dim() as usize,
+                masked: false,
+                causal: false,
+            };
+            attention_test_launch::(
+                client,
+                tiling_scheme,
+                problem,
+                Default::default(),
+            );
+        }
+
+        #[test]
+        fn attention_partition_vd2() {
+            let client = TestRuntime::client(&Default::default());
+
+            let partition_size = AttentionPartitionSize {
+                seq_q: 1,
+                seq_kv: 1,
+                head_dim: 1,
+                val_dim: 2,
+            };
+            let stage_size = AttentionStageSize {
+                seq_q: STAGE_Q_BASE,
+            };
+            let tiling_scheme = AttentionTilingScheme {
+                tile_size: TILE_SIZE,
+                partition_size,
+                stage_size,
+            };
+            let problem = AttentionProblem {
+                batch: 1,
+                num_heads: 1,
+                seq_q: tiling_scheme.elements_in_stage_seq_q() as usize,
+                seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize,
+                head_dim: tiling_scheme.elements_in_partition_head_dim() as usize,
+                val_dim: tiling_scheme.elements_in_partition_val_dim() as usize,
+                masked: false,
+                causal: false,
+            };
+            attention_test_launch::(
+                client,
+                tiling_scheme,
+                problem,
+                Default::default(),
+            );
+        }
+
+        #[test]
+        fn attention_partition_all2() {
+            let client = TestRuntime::client(&Default::default());
+
+            let
partition_size = AttentionPartitionSize { + seq_q: 2, + seq_kv: 2, + head_dim: 2, + val_dim: 2, + }; + let stage_size = AttentionStageSize { + seq_q: STAGE_Q_BASE, + }; + let tiling_scheme = AttentionTilingScheme { + tile_size: TILE_SIZE, + partition_size, + stage_size, + }; + let problem = AttentionProblem { + batch: 1, + num_heads: 1, + seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, + seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, + head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, + val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, + masked: false, + causal: false, + }; + attention_test_launch::( + client, + tiling_scheme, + problem, + Default::default(), + ); + } + + #[test] + fn attention_global_2() { + let client = TestRuntime::client(&Default::default()); + + let partition_size = AttentionPartitionSize { + seq_q: 1, + seq_kv: 1, + head_dim: 1, + val_dim: 1, + }; + let stage_size = AttentionStageSize { + seq_q: STAGE_Q_BASE, + }; + let num_iterations = 2; + let tiling_scheme = AttentionTilingScheme { + tile_size: TILE_SIZE, + partition_size, + stage_size, + }; + let problem = AttentionProblem { + batch: 1, + num_heads: 1, + seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, + seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize * num_iterations, + head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, + val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, + masked: false, + causal: false, + }; + attention_test_launch::( + client, + tiling_scheme, + problem, + Default::default(), + ); + } + + #[test] + fn attention_partition_kv2_global_2() { + let client = TestRuntime::client(&Default::default()); + + let partition_size = AttentionPartitionSize { + seq_q: 1, + seq_kv: 2, + head_dim: 1, + val_dim: 1, + }; + let stage_size = AttentionStageSize { + seq_q: STAGE_Q_BASE, + }; + let num_iterations = 2; + let tiling_scheme = AttentionTilingScheme { + tile_size: TILE_SIZE, + partition_size, + stage_size, + }; + let problem = AttentionProblem { + batch: 1, + num_heads: 1, + seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, + seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize * num_iterations, + head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, + val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, + masked: false, + causal: false, + }; + attention_test_launch::( + client, + tiling_scheme, + problem, + Default::default(), + ); + } + + #[test] + fn attention_partition_kv1_global1_with_oob() { + let client = TestRuntime::client(&Default::default()); + + let partition_size = AttentionPartitionSize { + seq_q: 1, + seq_kv: 1, + head_dim: 1, + val_dim: 1, + }; + let stage_size = AttentionStageSize { + seq_q: STAGE_Q_BASE, + }; + let tiling_scheme = AttentionTilingScheme { + tile_size: TILE_SIZE, + partition_size, + stage_size, + }; + let problem = AttentionProblem { + batch: 1, + num_heads: 1, + seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, + seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize - 1, + head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, + val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, + masked: false, + causal: false, + }; + attention_test_launch::( + client, + tiling_scheme, + problem, + Default::default(), + ); + } + + #[test] + fn attention_partition_kv1_global2_with_oob() { + let client = TestRuntime::client(&Default::default()); + + let partition_size = AttentionPartitionSize { + seq_q: 
1, + seq_kv: 1, + head_dim: 1, + val_dim: 1, + }; + let stage_size = AttentionStageSize { + seq_q: 2 * STAGE_Q_BASE, + }; + let tiling_scheme = AttentionTilingScheme { + tile_size: TILE_SIZE, + partition_size, + stage_size, + }; + let problem = AttentionProblem { + batch: 1, + num_heads: 1, + seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, + seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize * 2 + 1, + head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, + val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, + masked: false, + causal: false, + }; + attention_test_launch::( + client, + tiling_scheme, + problem, + Default::default(), + ); + } + + #[test] + fn attention_partition_oob_in_q() { + let client = TestRuntime::client(&Default::default()); + + let partition_size = AttentionPartitionSize { + seq_q: 2, + seq_kv: 1, + head_dim: 1, + val_dim: 1, + }; + let stage_size = AttentionStageSize { + seq_q: STAGE_Q_BASE, + }; + let tiling_scheme = AttentionTilingScheme { + tile_size: TILE_SIZE, + partition_size, + stage_size, + }; + let problem = AttentionProblem { + batch: 1, + num_heads: 1, + seq_q: 1, + seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, + head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, + val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, + masked: false, + causal: false, + }; + attention_test_launch::( + client, + tiling_scheme, + problem, + Default::default(), + ); + } + + #[test] + fn attention_partition_kv2_with_oob() { + let client = TestRuntime::client(&Default::default()); + + let partition_size = AttentionPartitionSize { + seq_q: 1, + seq_kv: 2, + head_dim: 1, + val_dim: 1, + }; + let stage_size = AttentionStageSize { + seq_q: STAGE_Q_BASE, + }; + let tiling_scheme = AttentionTilingScheme { + tile_size: TILE_SIZE, + partition_size, + stage_size, + }; + let problem = AttentionProblem { + batch: 1, + num_heads: 1, + seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, + seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize + 9, + head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, + val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, + masked: false, + causal: false, + }; + attention_test_launch::( + client, + tiling_scheme, + problem, + Default::default(), + ); + } + + #[test] + fn attention_stage2() { + let client = TestRuntime::client(&Default::default()); + + let partition_size = AttentionPartitionSize { + seq_q: 1, + seq_kv: 1, + head_dim: 1, + val_dim: 1, + }; + let stage_size = AttentionStageSize { + seq_q: 2 * STAGE_Q_BASE, + }; + let tiling_scheme = AttentionTilingScheme { + tile_size: TILE_SIZE, + partition_size, + stage_size, + }; + let problem = AttentionProblem { + batch: 1, + num_heads: 1, + seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, + seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, + head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, + val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, + masked: false, + causal: false, + }; + attention_test_launch::( + client, + tiling_scheme, + problem, + Default::default(), + ); + } + + #[test] + fn attention_stage4() { + let client = TestRuntime::client(&Default::default()); + + let partition_size = AttentionPartitionSize { + seq_q: 1, + seq_kv: 1, + head_dim: 1, + val_dim: 1, + }; + let stage_size = AttentionStageSize { + seq_q: 4 * STAGE_Q_BASE, + }; + let tiling_scheme = AttentionTilingScheme { + tile_size: TILE_SIZE, + 
partition_size, + stage_size, + }; + let problem = AttentionProblem { + batch: 1, + num_heads: 1, + seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, + seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, + head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, + val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, + masked: false, + causal: false, + }; + attention_test_launch::( + client, + tiling_scheme, + problem, + Default::default(), + ); + } + + #[test] + fn attention_stage2_problem4() { + let client = TestRuntime::client(&Default::default()); + + let partition_size = AttentionPartitionSize { + seq_q: 1, + seq_kv: 1, + head_dim: 1, + val_dim: 1, + }; + let stage_size = AttentionStageSize { + seq_q: 2 * STAGE_Q_BASE, + }; + let tiling_scheme = AttentionTilingScheme { + tile_size: TILE_SIZE, + partition_size, + stage_size, + }; + let problem = AttentionProblem { + batch: 1, + num_heads: 1, + seq_q: tiling_scheme.elements_in_stage_seq_q() as usize * 2, + seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize * 2, + head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, + val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, + masked: false, + causal: false, + }; + attention_test_launch::( + client, + tiling_scheme, + problem, + Default::default(), + ); + } + + #[test] + fn attention_stage2_partition_all2() { + let client = TestRuntime::client(&Default::default()); + + let partition_size = AttentionPartitionSize { + seq_q: 2, + seq_kv: 2, + head_dim: 2, + val_dim: 2, + }; + let stage_size = AttentionStageSize { + seq_q: 2 * STAGE_Q_BASE, + }; + let tiling_scheme = AttentionTilingScheme { + tile_size: TILE_SIZE, + partition_size, + stage_size, + }; + let problem = AttentionProblem { + batch: 1, + num_heads: 1, + seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, + seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, + head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, + val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, + masked: false, + causal: false, + }; + attention_test_launch::( + client, + tiling_scheme, + problem, + Default::default(), + ); + } + + #[test] + fn attention_reuse_key_value() { + let client = TestRuntime::client(&Default::default()); + + let partition_size = AttentionPartitionSize { + seq_q: 1, + seq_kv: 1, + head_dim: 2, + val_dim: 2, + }; + let stage_size = AttentionStageSize { + seq_q: STAGE_Q_BASE, + }; + let tiling_scheme = AttentionTilingScheme { + tile_size: TILE_SIZE, + partition_size, + stage_size, + }; + let problem = AttentionProblem { + batch: 1, + num_heads: 1, + seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, + seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, + head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, + val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, + masked: false, + causal: false, + }; + attention_test_launch::( + client, + tiling_scheme, + problem, + TestOptions { + reuse_key_value: true, + ..Default::default() + }, + ); + } + + #[test] + fn attention_double_row_wise() { + let client = TestRuntime::client(&Default::default()); + + let partition_size = AttentionPartitionSize { + seq_q: 1, + seq_kv: 1, + head_dim: 1, + val_dim: 1, + }; + let stage_size = AttentionStageSize { + seq_q: STAGE_Q_BASE, + }; + let tiling_scheme = AttentionTilingScheme { + tile_size: TILE_SIZE, + partition_size, + stage_size, + }; + let problem = AttentionProblem { + batch: 1, + num_heads: 1, + 
seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, + seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, + head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, + val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, + masked: false, + causal: false, + }; + attention_test_launch::( + client, + tiling_scheme, + problem, + TestOptions { + two_rows_in_array_tile: true, + ..Default::default() + }, + ); + } + + #[test] + fn attention_one_tile_masked() { + let client = TestRuntime::client(&Default::default()); + + let partition_size = AttentionPartitionSize { + seq_q: 1, + seq_kv: 1, + head_dim: 1, + val_dim: 1, + }; + let stage_size = AttentionStageSize { + seq_q: STAGE_Q_BASE, + }; + let tiling_scheme = AttentionTilingScheme { + tile_size: TILE_SIZE, + partition_size, + stage_size, + }; + let problem = AttentionProblem { + batch: 1, + num_heads: 1, + seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, + seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, + head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, + val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, + masked: true, + causal: false, + }; + attention_test_launch::( + client, + tiling_scheme, + problem, + Default::default(), + ) + } + + #[test] + fn attention_one_tile_causal() { + let client = TestRuntime::client(&Default::default()); + + let partition_size = AttentionPartitionSize { + seq_q: 1, + seq_kv: 1, + head_dim: 1, + val_dim: 1, + }; + let stage_size = AttentionStageSize { + seq_q: STAGE_Q_BASE, + }; + let tiling_scheme = AttentionTilingScheme { + tile_size: TILE_SIZE, + partition_size, + stage_size, + }; + let problem = AttentionProblem { + batch: 1, + num_heads: 1, + seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, + seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, + head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, + val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, + masked: false, + causal: true, + }; + attention_test_launch::( + client, + tiling_scheme, + problem, + Default::default(), + ) + } + + #[test] + fn attention_one_tile_masked_causal() { + let client = TestRuntime::client(&Default::default()); + + let partition_size = AttentionPartitionSize { + seq_q: 1, + seq_kv: 1, + head_dim: 1, + val_dim: 1, + }; + let stage_size = AttentionStageSize { + seq_q: STAGE_Q_BASE, + }; + let tiling_scheme = AttentionTilingScheme { + tile_size: TILE_SIZE, + partition_size, + stage_size, + }; + let problem = AttentionProblem { + batch: 1, + num_heads: 1, + seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, + seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize, + head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, + val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, + masked: true, + causal: true, + }; + attention_test_launch::( + client, + tiling_scheme, + problem, + Default::default(), + ) + } + + #[test] + fn attention_masked_oob() { + let client = TestRuntime::client(&Default::default()); + + let partition_size = AttentionPartitionSize { + seq_q: 1, + seq_kv: 1, + head_dim: 1, + val_dim: 1, + }; + let stage_size = AttentionStageSize { + seq_q: STAGE_Q_BASE, + }; + let tiling_scheme = AttentionTilingScheme { + tile_size: TILE_SIZE, + partition_size, + stage_size, + }; + let problem = AttentionProblem { + batch: 1, + num_heads: 1, + seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, + seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize 
- 1, + head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, + val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, + masked: true, + causal: false, + }; + attention_test_launch::( + client, + tiling_scheme, + problem, + Default::default(), + ) + } + + #[test] + fn attention_masked_larger() { + let client = TestRuntime::client(&Default::default()); + + let partition_size = AttentionPartitionSize { + seq_q: 1, + seq_kv: 1, + head_dim: 1, + val_dim: 1, + }; + let stage_size = AttentionStageSize { + seq_q: STAGE_Q_BASE, + }; + let tiling_scheme = AttentionTilingScheme { + tile_size: TILE_SIZE, + partition_size, + stage_size, + }; + let problem = AttentionProblem { + batch: 1, + num_heads: 1, + seq_q: tiling_scheme.elements_in_stage_seq_q() as usize, + seq_kv: tiling_scheme.elements_in_partition_seq_kv() as usize * 2, + head_dim: tiling_scheme.elements_in_partition_head_dim() as usize, + val_dim: tiling_scheme.elements_in_partition_val_dim() as usize, + masked: true, + causal: false, + }; + attention_test_launch::( + client, + tiling_scheme, + problem, + Default::default(), + ) + } + }; +} diff --git a/crates/cubecl-attention/src/tests/test_utils.rs b/crates/cubecl-attention/src/tests/test_utils.rs index d90467101..180e7944d 100644 --- a/crates/cubecl-attention/src/tests/test_utils.rs +++ b/crates/cubecl-attention/src/tests/test_utils.rs @@ -33,7 +33,7 @@ pub trait TestPrecision { value: &[Self::EG], mask: Option<&[Self::EM]>, problem: &AttentionProblem, - client: &ComputeClient, + client: &ComputeClient, out: server::Handle, shape: &[usize], strides: &[usize], @@ -56,9 +56,9 @@ where query: &[EG], key: &[EG], value: &[EG], - mask: Option<&[u8]>, + mask: Option<&[Self::EM]>, problem: &AttentionProblem, - client: &ComputeClient, + client: &ComputeClient, out: server::Handle, shape: &[usize], strides: &[usize], @@ -82,8 +82,8 @@ where // Need to compensate for the temporary conversion to f16/tf32 let epsilon = match maybe_f16 || maybe_tf32 { - true => 10e-5 / EG::EPSILON.to_f32().unwrap() * half::f16::EPSILON.to_f32(), - false => 10e-5, + true => 10e-3 / EG::EPSILON.to_f32().unwrap() * half::f16::EPSILON.to_f32(), + false => 10e-3, }; let expected = flash_attention_v2_cpu::(query, key, value, mask, problem) @@ -100,7 +100,7 @@ where /// Compares the content of a handle to a given slice of f32. pub(crate) fn assert_equals_approx( - client: &ComputeClient, + client: &ComputeClient, output: server::Handle, shape: &[usize], strides: &[usize], @@ -238,7 +238,7 @@ impl CastInto for i32 { pub trait Sampleable: Sized + CubePrimitive { fn sample( - client: &ComputeClient, + client: &ComputeClient, shape: &[usize], seed: u64, ) -> TensorHandle; @@ -249,11 +249,11 @@ macro_rules! 
sample_float { $( impl Sampleable for $t { - fn sample(client: &ComputeClient, shape: &[usize], seed: u64) -> TensorHandle:: { + fn sample(client: &ComputeClient, shape: &[usize], seed: u64) -> TensorHandle:: { cubecl_random::seed(seed); let output = TensorHandle::::empty(client, shape.to_vec()); - cubecl_random::random_uniform::(&client, Self::from_int(-1), Self::from_int(1), output.as_ref()); + cubecl_random::random_uniform::(&client, Self::from_int(-50), Self::from_int(50), output.as_ref()); output } @@ -266,11 +266,10 @@ sample_float!(half::f16); sample_float!(half::bf16); sample_float!(f32); sample_float!(f64); -sample_float!(u8); impl Sampleable for flex32 { fn sample( - client: &ComputeClient, + client: &ComputeClient, shape: &[usize], seed: u64, ) -> TensorHandle { @@ -290,7 +289,7 @@ impl Sampleable for flex32 { impl Sampleable for tf32 { fn sample( - client: &ComputeClient, + client: &ComputeClient, shape: &[usize], seed: u64, ) -> TensorHandle { @@ -310,7 +309,7 @@ impl Sampleable for tf32 { impl Sampleable for bool { fn sample( - client: &ComputeClient, + client: &ComputeClient, shape: &[usize], seed: u64, ) -> TensorHandle { @@ -323,6 +322,21 @@ impl Sampleable for bool { } } +impl Sampleable for u8 { + fn sample( + client: &ComputeClient, + shape: &[usize], + seed: u64, + ) -> TensorHandle { + cubecl_random::seed(seed); + let output = TensorHandle::::empty(client, shape.to_vec()); + + cubecl_random::random_bernoulli::(client, 0.5, output.as_ref()); + + output + } +} + pub(crate) fn flash_attention_v2_cpu( query: &[P::EG], key: &[P::EG], @@ -334,11 +348,13 @@ where { let batch = problem.batch; let seq_q = problem.seq_q; - let seq_k = problem.seq_kv; + let seq_kv = problem.seq_kv; let num_heads = problem.num_heads; let head_dim = problem.head_dim; let val_dim = problem.val_dim; + let masked = mask.is_some(); + assert!(problem.masked == masked); // Precompute strides for indexing let query_strides = strides(problem, AttentionIdent::Query); @@ -364,8 +380,8 @@ where // For each K/V block let mut k_block_start = 0usize; - while k_block_start < seq_k { - let k_block_end = std::cmp::min(seq_k, k_block_start + seq_k); + while k_block_start < seq_kv { + let k_block_end = std::cmp::min(seq_kv, k_block_start + seq_kv); let cur_block_len = k_block_end - k_block_start; // Step A: compute S_block[j'] = Q_i · K_{j'} for j' in block @@ -390,13 +406,15 @@ where // apply scale (1/sqrt(dk)) dot *= scale; - // apply mask (for masked positions set -inf) - let s_val = if masked { + let s_val = if problem.causal && j > i { + P::EA::new(f32::NEG_INFINITY) + } else if masked { let m_idx = b * mask_strides[0] + i * mask_strides[1] + h * mask_strides[2] + j * mask_strides[3]; let m_val = mask.unwrap()[m_idx].cast_into(); + if m_val != P::EM::from_int(0) { P::EA::new(f32::NEG_INFINITY) } else { diff --git a/crates/cubecl-common/Cargo.toml b/crates/cubecl-common/Cargo.toml index 011f377e9..1806c519b 100644 --- a/crates/cubecl-common/Cargo.toml +++ b/crates/cubecl-common/Cargo.toml @@ -19,7 +19,13 @@ default = ["std"] fp4 = ["float4"] fp8 = ["float8"] serde = ["serde_bytes"] -std = ["rand/std", "futures-lite", "rand/thread_rng", "serde_json?/std"] +std = [ + "rand/std", + "futures-lite", + "rand/thread_rng", + "serde_json?/std", + "parking_lot", +] [dependencies] @@ -58,6 +64,8 @@ futures-lite = { workspace = true, features = [ "std", ], default-features = false, optional = true } +parking_lot = { version = "0.12.5", default-features = false, optional = true } + [target.'cfg(target_has_atomic = 
"ptr")'.dependencies] spin = { workspace = true, features = ["mutex", "spin_mutex"] } diff --git a/crates/cubecl-common/src/device.rs b/crates/cubecl-common/src/device.rs index 0eb3f07de..94fe95ff8 100644 --- a/crates/cubecl-common/src/device.rs +++ b/crates/cubecl-common/src/device.rs @@ -10,7 +10,7 @@ pub struct DeviceId { } /// Device trait for all cubecl devices. -pub trait Device: Default + Clone + core::fmt::Debug + Send + Sync { +pub trait Device: Default + Clone + core::fmt::Debug + Send + Sync + 'static { /// Create a device from its [id](DeviceId). fn from_id(device_id: DeviceId) -> Self; /// Retrieve the [device id](DeviceId) from the device. @@ -43,3 +43,567 @@ impl PartialOrd for DeviceId { Some(self.cmp(other)) } } + +pub use context::*; + +#[cfg(feature = "std")] +mod reentrant { + pub use parking_lot::{ReentrantMutex, ReentrantMutexGuard}; +} + +// MutCell and MutGuard differs in implementation whether `std` is activated. + +#[cfg(feature = "std")] +mod cell { + use core::cell::{RefCell, RefMut}; + use core::ops::DerefMut; + + pub type MutCell = RefCell; + pub type MutGuard<'a, T> = RefMut<'a, T>; + + pub unsafe fn borrow_mut_split<'a, T>(cell: &MutCell) -> (&'a mut T, MutGuard<'_, T>) { + let mut guard = cell.borrow_mut(); + let item = guard.deref_mut(); + let item: &'a mut T = unsafe { core::mem::transmute(item) }; + + (item, guard) + } +} + +#[cfg(not(feature = "std"))] +mod cell { + use core::ops::{Deref, DerefMut}; + + pub struct MutGuard<'a, T> { + guard: spin::MutexGuard<'a, T>, + } + + pub struct MutCell { + lock: spin::Mutex, + } + + impl MutCell { + pub fn new(item: T) -> Self { + Self { + lock: spin::Mutex::new(item), + } + } + } + + impl<'a, T> Deref for MutGuard<'a, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + self.guard.deref() + } + } + + impl<'a, T> DerefMut for MutGuard<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.guard.deref_mut() + } + } + + impl MutCell { + pub fn try_borrow_mut(&self) -> Result, ()> { + match self.lock.try_lock() { + Some(guard) => Ok(MutGuard { guard }), + None => Err(()), + } + } + } + + pub unsafe fn borrow_mut_split<'a, T>( + cell: &MutCell, + ) -> (&'a mut T, spin::MutexGuard<'_, T>) { + let mut guard = cell.lock.lock(); + let item = guard.deref_mut(); + let item: &'a mut T = unsafe { core::mem::transmute(item) }; + + (item, guard) + } +} + +#[cfg(not(feature = "std"))] +mod reentrant { + use core::ops::Deref; + + pub struct ReentrantMutex { + inner: spin::RwLock, + } + + impl ReentrantMutex { + pub fn new(item: T) -> Self { + Self { + inner: spin::RwLock::new(item), + } + } + } + + pub struct ReentrantMutexGuard<'a, T> { + guard: spin::RwLockReadGuard<'a, T>, + } + + impl<'a, T> Deref for ReentrantMutexGuard<'a, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + self.guard.deref() + } + } + + impl ReentrantMutex { + pub fn lock(&self) -> ReentrantMutexGuard<'_, T> { + let guard = self.inner.read(); + ReentrantMutexGuard { guard } + } + } +} + +mod context { + use super::cell::{MutCell, MutGuard}; + use alloc::boxed::Box; + use core::{ + any::{Any, TypeId}, + marker::PhantomData, + }; + use hashbrown::HashMap; + + use super::reentrant::{ReentrantMutex, ReentrantMutexGuard}; + + use crate::{device::cell::borrow_mut_split, stub::Arc}; + + use super::{Device, DeviceId}; + + /// A state that can be saved inside the [DeviceContext]. + pub trait DeviceState: Send + 'static { + /// Initialize a new state on the given device. 
+ fn init(device_id: DeviceId) -> Self; + } + + /// Handle for accessing a [DeviceState] associated with a specific device. + pub struct DeviceContext { + lock: DeviceStateLock, + device_id: DeviceId, + _phantom: PhantomData, + } + + /// There is nothing to read without a lock, and it's fine to allow locking a context reference. + unsafe impl Sync for DeviceContext {} + + impl Clone for DeviceContext { + fn clone(&self) -> Self { + Self { + lock: self.lock.clone(), + _phantom: self._phantom, + device_id: self.device_id, + } + } + } + + /// Guard providing mutable access to [DeviceState]. + /// + /// Automatically releases the lock when dropped. + pub struct DeviceStateGuard<'a, S: DeviceState> { + guard_ref: Option>>, + guard_mutex: Option>, + _phantom: PhantomData, + } + + /// Guard making sure only the locked device can be used. + /// + /// Automatically releases the lock when dropped. + pub struct DeviceGuard<'a> { + guard_mutex: Option>, + } + + impl<'a, S: DeviceState> Drop for DeviceStateGuard<'a, S> { + fn drop(&mut self) { + // Important to drop the ref before. + self.guard_ref = None; + self.guard_mutex = None; + } + } + + impl<'a> Drop for DeviceGuard<'a> { + fn drop(&mut self) { + self.guard_mutex = None; + } + } + + impl<'a, S: DeviceState> core::ops::Deref for DeviceStateGuard<'a, S> { + type Target = S; + + fn deref(&self) -> &Self::Target { + self.guard_ref + .as_ref() + .expect("The guard to not be dropped") + .downcast_ref() + .expect("The type to be correct") + } + } + + impl<'a, S: DeviceState> core::ops::DerefMut for DeviceStateGuard<'a, S> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.guard_ref + .as_mut() + .expect("The guard to not be dropped") + .downcast_mut() + .expect("The type to be correct") + } + } + + impl DeviceContext { + /// Creates a [DeviceState] handle for the given device. + /// + /// Registers the device-type combination globally if needed. + pub fn locate(device: &D) -> Self { + DeviceStateLock::locate(device) + } + + /// Inserts a new state associated with the device. + /// + /// # Returns + /// + /// An error if the device already has a registered state. + pub fn insert( + device: &D, + state_new: S, + ) -> Result { + let lock = Self::locate(device); + let id = TypeId::of::(); + + let state = lock.lock.lock.lock(); + + // It is safe for the same reasons enumerated in the lock function. + let (map, map_guard) = unsafe { borrow_mut_split(&state.map) }; + + if map.contains_key(&id) { + return Err(alloc::format!( + "A server is still registered for device {device:?}" + )); + } + + let any: Box = Box::new(state_new); + let cell = MutCell::new(any); + + map.insert(id, cell); + + core::mem::drop(map_guard); + core::mem::drop(state); + + Ok(lock) + } + + /// Locks the current device making sure this device can be used. + pub fn lock_device(&self) -> DeviceGuard<'_> { + let state = self.lock.lock.lock(); + + DeviceGuard { + guard_mutex: Some(state), + } + } + + /// Acquires exclusive mutable access to the [DeviceState]. + /// + /// The same device can lock multiple types at the same time. + /// + /// # Panics + /// + /// If the same state type is locked multiple times on the same thread. + /// This can only happen with recursive locking of the same state, which isn't allowed + /// since having multiple mutable references to the same state isn't valid. + pub fn lock(&self) -> DeviceStateGuard<'_, S> { + let key = TypeId::of::(); + let state = self.lock.lock.lock(); + + // It is safe for multiple reasons. + // + // 1. 
The mutability of the map is handled by each map entry with a RefCell. + // Therefore, multiple mutable references to a map entry are checked. + // 2. Map items are never cleaned up, therefore it's impossible to remove the validity of + // an entry. + // 3. Because of the lock, no race condition is possible. + // + // The reason why unsafe is necessary is that the [DeviceStateGuard] doesn't keep track + // of the borrowed map entry lifetime. But since it keeps track of both the [RefCell] + // and the [ReentrantMutex] guards, it is fine to erase the lifetime here. + let (map, map_guard) = unsafe { borrow_mut_split(&state.map) }; + + if !map.contains_key(&key) { + let state_default = S::init(self.device_id); + let any: Box = Box::new(state_default); + let cell = MutCell::new(any); + + map.insert(key, cell); + } + + let value = map + .get(&key) + .expect("Just validated the map contains the key."); + let ref_guard = match value.try_borrow_mut() { + Ok(guard) => guard, + #[cfg(feature = "std")] + Err(_) => panic!( + "State {} is already borrowed by the current thread {:?}", + core::any::type_name::(), + std::thread::current().id() + ), + #[cfg(not(feature = "std"))] + Err(_) => panic!("State {} is already borrowed", core::any::type_name::(),), + }; + + core::mem::drop(map_guard); + + DeviceStateGuard { + guard_ref: Some(ref_guard), + guard_mutex: Some(state), + _phantom: PhantomData, + } + } + } + + type Key = (DeviceId, TypeId); + + static GLOBAL: spin::Mutex = spin::Mutex::new(DeviceLocator { state: None }); + + struct DeviceLocator { + state: Option>, + } + + #[derive(Clone)] + struct DeviceStateLock { + lock: Arc>, + } + + struct DeviceStateMap { + map: MutCell>>>, + } + + impl DeviceStateLock { + fn locate(device: &D) -> DeviceContext { + let id = device.to_id(); + let key = (id, TypeId::of::()); + let mut global = GLOBAL.lock(); + + let map = match &mut global.state { + Some(state) => state, + None => { + global.state = Some(HashMap::default()); + global.state.as_mut().expect("Just created Option::Some") + } + }; + + let lock = match map.get(&key) { + Some(value) => value.clone(), + None => { + let state = DeviceStateMap::new(); + + let value = DeviceStateLock { + lock: Arc::new(ReentrantMutex::new(state)), + }; + + map.insert(key, value); + map.get(&key).expect("Just inserted the key/value").clone() + } + }; + + DeviceContext { + lock, + device_id: id, + _phantom: PhantomData, + } + } + } + + impl DeviceStateMap { + fn new() -> Self { + Self { + map: MutCell::new(HashMap::new()), + } + } + } + + #[cfg(test)] + mod tests { + use core::{ + ops::{Deref, DerefMut}, + time::Duration, + }; + + use super::*; + + #[test] + fn can_have_multiple_mutate_state() { + let device1 = TestDevice::<0>::new(0); + let device2 = TestDevice::<1>::new(0); + + let state1_usize = DeviceContext::::locate(&device1); + let state1_u32 = DeviceContext::::locate(&device1); + let state2_usize = DeviceContext::::locate(&device2); + + let mut guard_usize = state1_usize.lock(); + let mut guard_u32 = state1_u32.lock(); + + let val_usize = guard_usize.deref_mut(); + let val_u32 = guard_u32.deref_mut(); + + *val_usize += 1; + *val_u32 += 2; + + assert_eq!(*val_usize, 1); + assert_eq!(*val_u32, 2); + + core::mem::drop(guard_usize); + core::mem::drop(guard_u32); + + let mut guard_usize = state2_usize.lock(); + + let val_usize = guard_usize.deref_mut(); + *val_usize += 1; + + assert_eq!(*val_usize, 1); + + core::mem::drop(guard_usize); + + let guard_usize = state1_usize.lock(); + let guard_u32 = state1_u32.lock(); + + let 
val_usize = guard_usize.deref(); + let val_u32 = guard_u32.deref(); + + assert_eq!(*val_usize, 1); + assert_eq!(*val_u32, 2); + } + + #[test] + #[should_panic] + fn can_not_have_multiple_mut_ref_to_same_state() { + let device1 = TestDevice::<0>::new(0); + + struct DummyState; + + impl DeviceState for DummyState { + fn init(_device_id: DeviceId) -> Self { + DummyState + } + } + + fn recursive(total: usize, state: &DeviceContext) { + let _guard = state.lock(); + + if total > 0 { + recursive(total - 1, state); + } + } + + recursive(5, &DeviceContext::locate(&device1)); + } + + #[test] + fn work_with_many_threads() { + let num_threads = 32; + let handles: Vec<_> = (0..num_threads) + .map(|i| std::thread::spawn(move || thread_main((num_threads * 4) - i))) + .collect(); + + handles.into_iter().for_each(|h| h.join().unwrap()); + + let device1 = TestDevice::<0>::new(0); + let device2 = TestDevice::<1>::new(0); + + let state1_i64 = DeviceContext::::locate(&device1); + let state1_i32 = DeviceContext::::locate(&device1); + let state2_i32 = DeviceContext::::locate(&device2); + + let guard_i64 = state1_i64.lock(); + let guard_i32 = state1_i32.lock(); + + assert_eq!(*guard_i64, num_threads as i64); + assert_eq!(*guard_i32, num_threads as i32 * 2); + + core::mem::drop(guard_i64); + core::mem::drop(guard_i32); + + let guard_i32 = state2_i32.lock(); + assert_eq!(*guard_i32, num_threads as i32); + } + + fn thread_main(sleep: u64) { + let device1 = TestDevice::<0>::new(0); + let device2 = TestDevice::<1>::new(0); + + let state1_i64 = DeviceContext::::locate(&device1); + let state1_i32 = DeviceContext::::locate(&device1); + let state2_i32 = DeviceContext::::locate(&device2); + + let mut guard_i64 = state1_i64.lock(); + let mut guard_i32 = state1_i32.lock(); + + let val_i64 = guard_i64.deref_mut(); + let val_i32 = guard_i32.deref_mut(); + + *val_i64 += 1; + *val_i32 += 2; + + core::mem::drop(guard_i64); + core::mem::drop(guard_i32); + + std::thread::sleep(Duration::from_millis(sleep)); + + let mut guard_i32 = state2_i32.lock(); + + let val_i32 = guard_i32.deref_mut(); + *val_i32 += 1; + + core::mem::drop(guard_i32); + } + + #[derive(Debug, Clone, Default, new)] + /// Type is only to create different type ids. + pub struct TestDevice { + index: u32, + } + + impl Device for TestDevice { + fn from_id(device_id: DeviceId) -> Self { + Self { + index: device_id.index_id, + } + } + + fn to_id(&self) -> DeviceId { + DeviceId { + type_id: 0, + index_id: self.index, + } + } + + fn device_count(_type_id: u16) -> usize { + TYPE as usize + 1 + } + } + + impl DeviceState for usize { + fn init(_device_id: DeviceId) -> Self { + 0 + } + } + + impl DeviceState for u32 { + fn init(_device_id: DeviceId) -> Self { + 0 + } + } + impl DeviceState for i32 { + fn init(_device_id: DeviceId) -> Self { + 0 + } + } + impl DeviceState for i64 { + fn init(_device_id: DeviceId) -> Self { + 0 + } + } + } +} diff --git a/crates/cubecl-common/src/lib.rs b/crates/cubecl-common/src/lib.rs index 7412dafff..4f87dd9c0 100644 --- a/crates/cubecl-common/src/lib.rs +++ b/crates/cubecl-common/src/lib.rs @@ -48,6 +48,9 @@ pub mod reader; /// Future utils with a compatible API for native, non-std and wasm environments. pub mod future; +/// Quantization primitives required outside of `cubecl-quant` +pub mod quant; + /// Various utilities to create ID's. 
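For reference, here is how the device-context API added above is meant to be consumed, mirroring what the unit tests in this diff do. This is a minimal sketch under assumptions: the `cubecl_common::device` module path is inferred from the diff, and `Counter` is a hypothetical state type that is not part of the patch.

```rust
use cubecl_common::device::{Device, DeviceContext, DeviceId, DeviceState};

/// Hypothetical per-device state; any `Send + 'static` type can implement it.
struct Counter(u64);

impl DeviceState for Counter {
    fn init(_device_id: DeviceId) -> Self {
        Counter(0)
    }
}

fn bump<D: Device>(device: &D) -> u64 {
    // `locate` registers the (device, state type) pair globally on first use;
    // the state itself is created lazily by `DeviceState::init` on first `lock`.
    let ctx = DeviceContext::<Counter>::locate(device);
    // `lock` yields exclusive mutable access. Re-locking the same state type on
    // the same thread panics, but different state types can be held at once.
    let mut guard = ctx.lock();
    guard.0 += 1;
    guard.0
}
```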
extern crate alloc; diff --git a/crates/cubecl-common/src/quant/mod.rs b/crates/cubecl-common/src/quant/mod.rs new file mode 100644 index 000000000..f4e8250ce --- /dev/null +++ b/crates/cubecl-common/src/quant/mod.rs @@ -0,0 +1,2 @@ +/// Types representing the quantization scheme +pub mod scheme; diff --git a/crates/cubecl-quant/src/scheme.rs b/crates/cubecl-common/src/quant/scheme.rs similarity index 89% rename from crates/cubecl-quant/src/scheme.rs rename to crates/cubecl-common/src/quant/scheme.rs index 9f4af3513..847421a11 100644 --- a/crates/cubecl-quant/src/scheme.rs +++ b/crates/cubecl-common/src/quant/scheme.rs @@ -1,7 +1,6 @@ use alloc::vec; use alloc::vec::Vec; use core::{default::Default, ops::Deref}; -use cubecl_common::{e4m3, e5m2}; use serde::{Deserialize, Serialize}; /// Describes a quantization scheme/configuration. @@ -79,6 +78,12 @@ impl QuantScheme { pub fn num_quants(&self) -> usize { self.size_bits_stored() / self.value.size_bits() } + + /// Returns the native packing factor for the values. When native packing > 1, the packed + /// representation stores `num_quants` elements grouped into packs of `native_packing` size. + pub fn native_packing(&self) -> usize { + self.value.native_packing() + } } /// Level or granularity of quantization. @@ -91,6 +96,7 @@ pub enum QuantLevel { } impl QuantLevel { + /// Converting constructor for [`QuantLevel::Block`] pub fn block(values: impl AsRef<[u8]>) -> Self { QuantLevel::Block(BlockSize::new(values)) } @@ -129,6 +135,15 @@ impl QuantValue { } } + /// Packing factor for the native representation used for intermediate values. If > 1, values + /// should always be processed in `native_packing` sized chunks. + pub fn native_packing(&self) -> usize { + match self { + QuantValue::E2M1 => 2, + _ => 1, + } + } + /// The possible range of values allowed by the quant value. pub fn range(&self) -> (f32, f32) { match self { @@ -138,8 +153,8 @@ impl QuantValue { QuantValue::Q8S => (-i8::MAX as f32, i8::MAX as f32), QuantValue::Q4S => (-7.0, 7.0), QuantValue::Q2S => (-1.0, 1.0), - QuantValue::E4M3 => (e4m3::MIN as f32, e4m3::MAX as f32), - QuantValue::E5M2 => (e5m2::MIN as f32, e5m2::MAX as f32), + QuantValue::E4M3 => (-448.0, 448.0), + QuantValue::E5M2 => (-57344.0, 57344.0), QuantValue::E2M1 => (-6.0, 6.0), // Hardcoded because of no-std } } @@ -209,6 +224,9 @@ pub struct BlockSize { } impl BlockSize { + /// Max number of dimensions for block size + pub const MAX_DIMS: usize = MAX_DIMS; + /// Create a new blocksize from a set of values. The number of values must be `<= MAX_DIMS`. 
pub fn new(values: impl AsRef<[u8]>) -> Self { let values = values.as_ref(); @@ -253,10 +271,12 @@ impl BlockSize { out } + /// Create an iterator over all stored dimensions pub fn iter(&self) -> impl Iterator { self.as_slice().iter() } + /// Returns the total number of elements in each block pub fn num_elements(&self) -> usize { self.iter().map(|it| *it as usize).product() } diff --git a/crates/cubecl-convolution/Cargo.toml b/crates/cubecl-convolution/Cargo.toml index c25b58162..432ea3a59 100644 --- a/crates/cubecl-convolution/Cargo.toml +++ b/crates/cubecl-convolution/Cargo.toml @@ -19,13 +19,13 @@ conv_tests = [] [dependencies] bytemuck = { workspace = true } -cubecl-common = { path = "../cubecl-common", version = "0.7.0", default-features = false } -cubecl-core = { path = "../cubecl-core", version = "0.7.0", default-features = false } -cubecl-matmul = { path = "../cubecl-matmul", version = "0.7.0", default-features = false } -cubecl-runtime = { path = "../cubecl-runtime", version = "0.7.0", default-features = false } -cubecl-std = { path = "../cubecl-std", version = "0.7.0", default-features = false } -cubecl-reduce = { path = "../cubecl-reduce", version = "0.7.0", default-features = false } -cubecl-random = { path = "../cubecl-random", version = "0.7.0", default-features = false } +cubecl-common = { path = "../cubecl-common", version = "0.9.0", default-features = false } +cubecl-core = { path = "../cubecl-core", version = "0.9.0", default-features = false } +cubecl-matmul = { path = "../cubecl-matmul", version = "0.9.0", default-features = false } +cubecl-runtime = { path = "../cubecl-runtime", version = "0.9.0", default-features = false } +cubecl-std = { path = "../cubecl-std", version = "0.9.0", default-features = false } +cubecl-reduce = { path = "../cubecl-reduce", version = "0.9.0", default-features = false } +cubecl-random = { path = "../cubecl-random", version = "0.9.0", default-features = false } half = { workspace = true, features = ["bytemuck"] } pretty_assertions = { workspace = true, optional = true } serde = { workspace = true } diff --git a/crates/cubecl-convolution/src/components/config.rs b/crates/cubecl-convolution/src/components/config.rs index a1924c55b..683092c5c 100644 --- a/crates/cubecl-convolution/src/components/config.rs +++ b/crates/cubecl-convolution/src/components/config.rs @@ -15,16 +15,7 @@ use super::*; /// Convolution specific config, extends regular matmul [`Config`](global::Config) pub trait ConvGemmConfig: GlobalConfig { /// The size of the convolution kernel at `dim` - fn kernel_size(&self, dim: u32) -> u32; - /// The dilation of the kernel at `dim` - fn dilation(&self, dim: u32) -> u32; - /// The stride of the kernel at `dim` - fn stride(&self, dim: u32) -> u32; - /// The padding of the kernel at `dim` - fn padding(&self, dim: u32) -> i32; - /// The dimensionality of the kernel - fn dimensionality(&self) -> Dimensionality; - + fn convolution_params(&self) -> ConvolutionParams; fn line_sizes(&self) -> MatmulLineSizes; fn check_spatial_bounds(&self) -> bool; } @@ -32,12 +23,17 @@ pub trait ConvGemmConfig: GlobalConfig { #[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] pub struct ConvolutionConfig { matmul: M, + params: ConvolutionParams, + num_stages: u32, +} + +#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] +pub struct ConvolutionParams { pub kernel_size: [u32; 3], pub stride: [u32; 3], pub dilation: [u32; 3], pub padding: [i32; 3], - dimensionality: Dimensionality, - num_stages: u32, + pub dimensionality: Dimensionality, } impl Deref for 
ConvolutionConfig { @@ -121,24 +117,8 @@ impl GlobalConfig for ConvolutionConfig { } impl ConvGemmConfig for ConvolutionConfig { - fn kernel_size(&self, dim: u32) -> u32 { - self.kernel_size[dim as usize] - } - - fn dilation(&self, dim: u32) -> u32 { - self.dilation[dim as usize] - } - - fn stride(&self, dim: u32) -> u32 { - self.stride[dim as usize] - } - - fn padding(&self, dim: u32) -> i32 { - self.padding[dim as usize] - } - - fn dimensionality(&self) -> Dimensionality { - self.dimensionality + fn convolution_params(&self) -> ConvolutionParams { + self.params } fn line_sizes(&self) -> cubecl_matmul::components::MatmulLineSizes { @@ -150,10 +130,10 @@ impl ConvGemmConfig for ConvolutionConfig { } fn check_spatial_bounds(&self) -> bool { - let spatial_dims = self.dimensionality.num_dims(); + let spatial_dims = self.params.dimensionality.num_dims(); let mut has_padding = false; for i in 0..spatial_dims { - has_padding |= self.padding[i as usize] != 0; + has_padding |= self.params.padding[i as usize] != 0; } has_padding } @@ -172,20 +152,22 @@ impl ConvolutionConfig { ) -> Result { let dims = kernel_size.len(); - let mut this = Self { - matmul, + let mut params = ConvolutionParams { kernel_size: [0; 3], stride: [0; 3], dilation: [0; 3], padding: [0; 3], dimensionality: dim, - num_stages, }; - this.kernel_size[0..dims].copy_from_slice(kernel_size); - this.stride[0..dims].copy_from_slice(stride); - this.dilation[0..dims].copy_from_slice(dilation); - this.padding[0..dims].copy_from_slice(padding); - Ok(this) + params.kernel_size[0..dims].copy_from_slice(kernel_size); + params.stride[0..dims].copy_from_slice(stride); + params.dilation[0..dims].copy_from_slice(dilation); + params.padding[0..dims].copy_from_slice(padding); + Ok(Self { + matmul, + params, + num_stages, + }) } pub fn to_matmul_config(self) -> M { diff --git a/crates/cubecl-convolution/src/components/global/args.rs b/crates/cubecl-convolution/src/components/global/args.rs index 5fce0f774..d0d6ca477 100644 --- a/crates/cubecl-convolution/src/components/global/args.rs +++ b/crates/cubecl-convolution/src/components/global/args.rs @@ -2,61 +2,167 @@ use std::any::TypeId; use cubecl::prelude::*; use cubecl_core as cubecl; +use cubecl_std::{ + CubeOptionArgs, FastDivmodArgs, + tensor::{ + launch::ViewArg, + layout::{ + VirtualLayoutLaunch, + chain::{Chain, ChainLaunch}, + }, + }, +}; use crate::{ - components::ConvolutionProblem, + components::{ + ConvGemmConfig, ConvolutionProblem, + global::{ + layout::{ + BiasLayout, BiasLayoutLaunch, Im2colLayout, Im2colLayoutLaunch, NhwcLayout, + NhwcLayoutLaunch, OutLayout, OutLayoutLaunch, WeightLayout, WeightLayoutLaunch, + }, + read::layout::{ + TmaDummyLayout, TmaDummyLayoutLaunch, TmaWeightLayout, TmaWeightLayoutLaunch, + }, + }, + }, kernels::layered::algorithm::simple_tma::{calculate_lower_corner, calculate_upper_corner}, }; use cubecl_matmul::{ MatmulInputHandleRef, components::{ - MatmulLineSizes, MatmulSelection, - global::args::{TensorInputs, TensorInputsLaunch, TensorMapInputs, TensorMapInputsLaunch}, + MatmulIdent, MatmulLineSizes, MatmulSelection, + global::{ + args::{ + TensorInputs, TensorInputsLaunch, TensorMapInputs, TensorMapInputsLaunch, + TensorOutput, TensorOutputLaunch, + }, + memory::{NoopLayout, NoopLayoutLaunch}, + }, }, }; /// Create the input runtime arguments for a matmul kernel that works on concrete inputs and /// output (not fused). 
pub trait ConcreteInputsFactory: LaunchArg { + #[allow(clippy::too_many_arguments)] fn create<'a, R: Runtime>( + client: &ComputeClient, lhs: &'a MatmulInputHandleRef<'a, R>, rhs: &'a MatmulInputHandleRef<'a, R>, bias: Option<&'a TensorHandleRef<'a, R>>, selection: &MatmulSelection, problem: &ConvolutionProblem, line_sizes: &MatmulLineSizes, + config: impl ConvGemmConfig, + ) -> Self::RuntimeArg<'a, R>; +} + +/// Create the output runtime arguments for a matmul kernel that works on concrete inputs and +/// output (not fused). +pub trait ConcreteOutputFactory: LaunchArg { + fn create<'a, R: Runtime>( + client: &ComputeClient, + out: &'a TensorHandleRef<'a, R>, + selection: &MatmulSelection, + problem: &ConvolutionProblem, + line_sizes: &MatmulLineSizes, + config: impl ConvGemmConfig, ) -> Self::RuntimeArg<'a, R>; } impl ConcreteInputsFactory for TensorInputs { fn create<'a, R: Runtime>( + client: &ComputeClient, lhs: &'a MatmulInputHandleRef<'a, R>, rhs: &'a MatmulInputHandleRef<'a, R>, bias: Option<&'a TensorHandleRef<'a, R>>, _selection: &MatmulSelection, - _problem: &ConvolutionProblem, + problem: &ConvolutionProblem, line_sizes: &MatmulLineSizes, + config: impl ConvGemmConfig, ) -> Self::RuntimeArg<'a, R> { + type LhsLayout = Chain; + type RhsLayout = Chain; + + let layout_nhwc = |handle, line_size, check| { + NhwcLayoutLaunch::from_handle(handle, line_size as u32, check) + }; + let layout_lhs = Im2colLayoutLaunch::from_args( + client, + problem, + config.convolution_params(), + config.global_memory_config(MatmulIdent::Lhs), + ); + let layout_rhs = WeightLayoutLaunch::from_args( + client, + problem, + config.convolution_params(), + config.global_memory_config(MatmulIdent::Rhs), + ); + let layout_bias = + BiasLayoutLaunch::new(ScalarArg::new(problem.n as u32), line_sizes.out as u32); + + let layout_lhs = { + let global = layout_nhwc(lhs.data(), line_sizes.lhs, config.check_spatial_bounds()); + ChainLaunch::new(global, layout_lhs) + }; + let layout_rhs = { + let global = layout_nhwc(rhs.data(), line_sizes.rhs, false); + ChainLaunch::new(global, layout_rhs) + }; + TensorInputsLaunch::new( - lhs.data().as_tensor_arg(line_sizes.lhs), - lhs.scale().map(|it| it.as_tensor_arg(1)).into(), - rhs.data().as_tensor_arg(line_sizes.rhs), - rhs.scale().map(|it| it.as_tensor_arg(1)).into(), - bias.map(|it| it.as_tensor_arg(line_sizes.out)).into(), + ViewArg::new::(lhs.data().as_array_arg(line_sizes.lhs), layout_lhs), + VirtualLayoutLaunch::new::(NoopLayoutLaunch::new()), + ViewArg::new::(rhs.data().as_array_arg(line_sizes.rhs), layout_rhs), + VirtualLayoutLaunch::new::(NoopLayoutLaunch::new()), + bias.map(|bias| { + ViewArg::new::(bias.as_array_arg(line_sizes.out), layout_bias) + }) + .into(), + bias.map(|_| VirtualLayoutLaunch::new::(NoopLayoutLaunch::new())) + .into(), ) } } +impl ConcreteOutputFactory for TensorOutput { + fn create<'a, R: Runtime>( + client: &ComputeClient, + out: &'a TensorHandleRef<'a, R>, + _selection: &MatmulSelection, + problem: &ConvolutionProblem, + line_sizes: &MatmulLineSizes, + config: impl ConvGemmConfig, + ) -> Self::RuntimeArg<'a, R> { + type Layout = Chain; + + let global = NhwcLayoutLaunch::from_handle(out, line_sizes.out as u32, false); + let layout = OutLayoutLaunch::from_args( + client, + problem, + config.global_memory_config(MatmulIdent::Out), + ); + let layout = ChainLaunch::new(global, layout); + let view = ViewArg::new::(out.as_array_arg(line_sizes.out), layout); + let batch = VirtualLayoutLaunch::new::(NoopLayoutLaunch::new()); + TensorOutputLaunch::new(view, 
batch) + } +} + impl ConcreteInputsFactory for TensorMapInputs { fn create<'a, R: Runtime>( + client: &ComputeClient, lhs: &'a MatmulInputHandleRef<'a, R>, rhs: &'a MatmulInputHandleRef<'a, R>, bias: Option<&'a TensorHandleRef<'a, R>>, selection: &MatmulSelection, problem: &ConvolutionProblem, line_sizes: &MatmulLineSizes, + config: impl ConvGemmConfig, ) -> Self::RuntimeArg<'a, R> { let tiling_scheme = selection.tiling_scheme; let stage_m = tiling_scheme.elements_in_stage_m(); @@ -119,9 +225,26 @@ impl ConcreteInputsFactory ) .with_prefetch(prefetch_rhs); - let bias = bias.map(|it| it.as_tensor_arg(line_sizes.out)); + let padded_channels = + (problem.channels as u32).next_multiple_of(config.tiling_scheme().elements_in_tile_k()); + + // Dummy layout since we don't support im2col loading rn + let lhs_layout = TmaDummyLayoutLaunch::new(); + let rhs_layout = TmaWeightLayoutLaunch::new(FastDivmodArgs::new(client, padded_channels)); - // TODO: Think about how to handle scales with TMA - TensorMapInputsLaunch::new(lhs, rhs, bias.into()) + let bias = bias.map(|bias| { + let layout = + BiasLayoutLaunch::new(ScalarArg::new(problem.n as u32), line_sizes.out as u32); + ViewArg::new::(bias.as_array_arg(line_sizes.out), layout) + }); + + TensorMapInputsLaunch::new( + ViewArg::new_tensor_map::(lhs, lhs_layout), + ViewArg::new_tensor_map::(rhs, rhs_layout), + bias.into(), + CubeOptionArgs::Some(VirtualLayoutLaunch::new::( + NoopLayoutLaunch::new(), + )), + ) } } diff --git a/crates/cubecl-convolution/src/components/global/base.rs b/crates/cubecl-convolution/src/components/global/base.rs index 1cf4cda75..f1bb8bf3d 100644 --- a/crates/cubecl-convolution/src/components/global/base.rs +++ b/crates/cubecl-convolution/src/components/global/base.rs @@ -8,7 +8,7 @@ use cubecl_matmul::components::{ }; use cubecl_std::{ CubeOption, - tensor::{layout::Coords2d, r#virtual::VirtualTensor}, + tensor::{View, layout::Coords2d}, }; use crate::{ @@ -28,7 +28,7 @@ pub trait GlobalConvolutionFamily: ConvolutionLaunch + 'static { fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes; fn setup( - client: &ComputeClient, + client: &ComputeClient, problem: &ConvolutionProblem, selection: &MatmulSelection, line_sizes: &MatmulLineSizes, @@ -69,36 +69,28 @@ pub trait GlobalConvolution: 'static + Send + Sync { /// Initializes the global reader for the input feature map with an appropriate layout fn init_lhs_global_reader( - lhs: VirtualTensor>, + lhs: View>, Coords2d>, offset: Coords2d, - view_shape: Coords2d, + slice_size: Coords2d, runtime_args: &RuntimeArgs, #[comptime] config: Self::Config, ) -> Self::LhsGlobalReader; /// Initializes the global reader for the weights with an appropriate layout fn init_rhs_global_reader( - rhs: VirtualTensor>, - offset: Coords2d, - view_shape: Coords2d, - runtime_args: &RuntimeArgs, + rhs: View>, Coords2d>, #[comptime] config: Self::Config, ) -> Self::RhsGlobalReader; /// Initializes the global reader for the bias with an appropriate layout fn init_bias_global_reader( - bias: CubeOption>>, - n_offset: u32, - slice_size: u32, + bias: CubeOption>, Coords2d>>, #[comptime] config: Self::Config, ) -> Self::AccGlobalReader; /// Initializes the output feature map global writer with an appropriate layout fn init_global_writer( - out: VirtualTensor, ReadWrite>, - offset: Coords2d, - view_shape: Coords2d, - runtime_args: &RuntimeArgs, + out: View>, Coords2d, ReadWrite>, #[comptime] config: Self::Config, ) -> Self::GlobalWriter; diff --git 
a/crates/cubecl-convolution/src/components/global/entry_point.rs b/crates/cubecl-convolution/src/components/global/entry_point.rs index 8ddfc3a89..ce7155a1c 100644 --- a/crates/cubecl-convolution/src/components/global/entry_point.rs +++ b/crates/cubecl-convolution/src/components/global/entry_point.rs @@ -3,14 +3,10 @@ use cubecl_core as cubecl; use cubecl_core::{Runtime, client::ComputeClient}; use cubecl_matmul::components::{ InputRuntimeArg, MatmulSpec, OutputRuntimeArg, - global::{ - GlobalConfig as _, - args::{MatmulArgs, TensorAcc, TensorLhs, TensorOutput, TensorRhs}, - }, -}; -use cubecl_std::{ - CubeOption, CubeOptionExpand, FastDivmod, FastDivmodArgs, tensor::r#virtual::VirtualTensor, + batch::SliceIndex, + global::{GlobalConfig as _, args::MatmulArgs}, }; +use cubecl_std::{CubeOption, CubeOptionExpand, FastDivmod, FastDivmodArgs}; use crate::{ components::{ @@ -32,7 +28,7 @@ pub trait ConvolutionLaunch { /// Out-of-bounds can happen #[allow(clippy::too_many_arguments)] unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>( - client: &ComputeClient<::Server, ::Channel>, + client: &ComputeClient<::Server>, cube_dim: CubeDim, cube_count: CubeCount, input: InputRuntimeArg<'a, MS, R>, @@ -58,26 +54,12 @@ pub(crate) fn implicit_conv< runtime_args: RuntimeArgs, #[comptime] config: GMM::Config, ) { - let mut state = Args::init_state(inputs, output); - - let lhs = TensorLhs::::new(&state); - let rhs = TensorRhs::::new(&state); - let mut out = TensorOutput::::new(&mut state); - - let has_acc = Args::has_acc(&state); - let bias: CubeOption> = match has_acc { - CubeOption::Some(_) => { - let bias = TensorAcc::::new(&state); - let bias = VirtualTensor::::new::>(&bias); - CubeOption::new_Some(bias) - } - CubeOption::None => CubeOption::new_None(), - }; + let mut state = Args::init_state::(inputs, output, config); - let lhs = VirtualTensor::::new::>(&lhs); - let rhs = VirtualTensor::::new::>(&rhs); - let out = - VirtualTensor::::new::>(&mut out); + let lhs = Args::view_lhs(&state); + let rhs = Args::view_rhs(&state); + let bias = Args::view_acc(&state); + let out = Args::view_out(&mut state); let stage_m = config.tiling_scheme().elements_in_stage_m().runtime(); let stage_n = config.tiling_scheme().elements_in_stage_n().runtime(); @@ -88,6 +70,17 @@ pub(crate) fn implicit_conv< let k_range = (0, runtime_args.shape_k); let k_size = runtime_args.shape_k; + let lhs = lhs.view(SliceIndex::new(0, lhs.shape())); + let rhs = rhs.view(SliceIndex::new(0, rhs.shape())); + let bias = match bias { + CubeOption::Some(bias) => { + let view = bias.view(SliceIndex::new(0, bias.shape())); + CubeOption::new_Some(view.slice_unchecked((0, n_offset), (1, stage_n))) + } + CubeOption::None => CubeOption::new_None(), + }; + let out = out.view_mut(SliceIndex::new(0, out.shape())); + GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::execute( GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_lhs_global_reader( lhs, @@ -97,20 +90,14 @@ pub(crate) fn implicit_conv< config, ), GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_rhs_global_reader( - rhs, - (k_range.0, n_offset), - (k_size, stage_n), - &runtime_args, + rhs.slice_unchecked((k_range.0, n_offset), (k_size, stage_n)), config, ), GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_bias_global_reader( - bias, n_offset, stage_n, config, + bias, config, ), GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_global_writer( - out, - (m_offset, n_offset), - (stage_m, stage_n), - &runtime_args, + 
out.slice_mut_unchecked((m_offset, n_offset), (stage_m, stage_n)), config, ), &mut GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_accumulator(config), @@ -120,12 +107,11 @@ pub(crate) fn implicit_conv< } pub(crate) fn shape_divmod<'a, R: Runtime>( - client: &ComputeClient, + client: &ComputeClient, shape: &[usize], ) -> SequenceArg<'a, R, FastDivmod> { - let shape = shape + shape .iter() .map(|s| FastDivmodArgs::new(client, *s as u32)) - .collect(); - SequenceArg { values: shape } + .collect() } diff --git a/crates/cubecl-convolution/src/components/global/layout/bias.rs b/crates/cubecl-convolution/src/components/global/layout/bias.rs new file mode 100644 index 000000000..985d0d1a5 --- /dev/null +++ b/crates/cubecl-convolution/src/components/global/layout/bias.rs @@ -0,0 +1,34 @@ +use cubecl::prelude::*; +use cubecl_core as cubecl; +use cubecl_std::tensor::layout::*; + +#[derive(CubeType, CubeLaunch)] +pub struct BiasLayout { + shape: u32, + #[cube(comptime)] + line_size: u32, +} + +#[cube] +impl Layout for BiasLayout { + type Coordinates = Coords3d; + type SourceCoordinates = Coords1d; + + fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates { + let (_, _, n) = pos; + n / self.line_size + } + + fn is_in_bounds(&self, pos: Self::Coordinates) -> bool { + let (_, _, n) = pos; + n < self.shape + } + + fn shape(&self) -> Self::Coordinates { + (1, 1, self.shape) + } + + fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) { + (self.to_source_pos(pos), self.is_in_bounds(pos)) + } +} diff --git a/crates/cubecl-convolution/src/components/global/layout/im2col.rs b/crates/cubecl-convolution/src/components/global/layout/im2col.rs index 37c432192..b2ca85216 100644 --- a/crates/cubecl-convolution/src/components/global/layout/im2col.rs +++ b/crates/cubecl-convolution/src/components/global/layout/im2col.rs @@ -5,17 +5,14 @@ use cubecl_matmul::components::{ global::{GlobalConfig, memory::GlobalMemoryConfig}, }; use cubecl_std::{ - FastDivmod, - tensor::layout::{Coords2d, Layout, LayoutExpand}, + FastDivmod, FastDivmodArgs, + tensor::layout::{Coords3d, Layout, LayoutExpand}, }; use crate::{ components::{ - ConvolutionConfig, - global::{ - layout::{NhwcCoords, unwrap}, - read::im2col_tma::div_mod_seq, - }, + ConvGemmConfig, ConvolutionConfig, ConvolutionParams, ConvolutionProblem, + global::{layout::NhwcCoords, read::im2col_tma::div_mod_seq}, }, kernels::layered::selector::RuntimeArgs, }; @@ -23,7 +20,7 @@ use crate::{ /// Maps a 4D NHWC tensor to a 2D column matrix using the im2col transformation /// It first decomposes the `(m, k)` matrix into `((n, out_h, out_w), (k_h, k_w, c))`, then applies /// the convolution parameters to calculate the position in the input tensor for that kernel element. 
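// A minimal standalone sketch of the index mapping described above, assuming a
// 2D convolution and plain scalar shapes in place of the `FastDivmod` sequences
// and comptime parameters the real layout uses; the helper below is
// hypothetical and only illustrates the arithmetic.
fn im2col_source_pos(
    m: usize,                  // row in the (m, k) column matrix
    k: usize,                  // column in the (m, k) column matrix
    out_shape: (usize, usize), // (out_h, out_w)
    kernel: (usize, usize),    // (k_h, k_w)
    channels: usize,           // c
    stride: (usize, usize),
    dilation: (usize, usize),
    padding: (isize, isize),
) -> (usize, (isize, isize), usize) {
    // m decomposes row-major as (n, out_h, out_w).
    let ow = m % out_shape.1;
    let oh = (m / out_shape.1) % out_shape.0;
    let n = m / (out_shape.0 * out_shape.1);
    // k decomposes row-major as (k_h, k_w, c), with the channel innermost.
    let c = k % channels;
    let kw = (k / channels) % kernel.1;
    let kh = k / (channels * kernel.1);
    // Input position: the output position scaled by the stride, plus the
    // dilated kernel offset, shifted back by the padding.
    let h = (oh * stride.0 + kh * dilation.0) as isize - padding.0;
    let w = (ow * stride.1 + kw * dilation.1) as isize - padding.1;
    (n, (h, w), c) // a negative or too-large (h, w) falls into the padding
}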
-#[derive(CubeType, Clone)] +#[derive(CubeType, CubeLaunch, Clone)] pub struct Im2colLayout { /// Shape of output DHW pub shape_out: Sequence, @@ -35,19 +32,9 @@ pub struct Im2colLayout { /// Shape of the combined `k` dimension, including padding pub shape_k: u32, - /// Size of the convolution kernel in DHW + /// Comptime parameters for the convolution #[cube(comptime)] - pub kernel_size: [u32; 3], - /// Stride of the convolution in DHW - #[cube(comptime)] - pub stride: [u32; 3], - /// Dilation applied to the kernel positions in DHW - #[cube(comptime)] - pub dilation: [u32; 3], - /// Padding applied to the convolution in DHW - /// The input position will be offset from the output by `-padding` - #[cube(comptime)] - pub padding: [i32; 3], + pub params: ConvolutionParams, /// Global memory config for the backing tensor #[cube(comptime)] pub config: GlobalMemoryConfig, @@ -66,10 +53,7 @@ impl Im2colLayout { shape_channel: args.shape_channel, shape_m: args.shape_m, shape_k: args.shape_k, - kernel_size: config.kernel_size, - stride: config.stride, - dilation: config.dilation, - padding: config.padding, + params: config.convolution_params(), config: config.global_memory_config(MatmulIdent::Lhs), } } @@ -77,11 +61,12 @@ impl Im2colLayout { #[cube] impl Layout for Im2colLayout { - type Coordinates = Coords2d; + type Coordinates = Coords3d; type SourceCoordinates = NhwcCoords; fn to_source_pos(&self, pos: Self::Coordinates) -> NhwcCoords { - let (view_m, view_k) = pos; + let params = comptime![self.params]; + let (_, view_m, view_k) = pos; let (batch, out_offs) = div_mod_seq(view_m, &self.shape_out); @@ -92,16 +77,15 @@ impl Layout for Im2colLayout { #[unroll] for i in 0..spatial_dims { - let i = unwrap(i); let dim = comptime![spatial_dims - i - 1]; - let ksize = comptime![self.kernel_size[dim as usize]]; + let ksize = comptime![params.kernel_size[dim as usize]]; let k_pos = rem % ksize; rem /= ksize; let out_pos = *out_offs.index(dim); - let stride = comptime![self.stride[dim as usize]]; - let dilate = comptime![self.dilation[dim as usize]]; - let pad = comptime![self.padding[dim as usize]]; + let stride = comptime![params.stride[dim as usize]]; + let dilate = comptime![params.dilation[dim as usize]]; + let pad = comptime![params.padding[dim as usize]]; let pos = (out_pos * stride + k_pos * dilate) as i32 - pad; in_pos.push(pos); @@ -117,7 +101,7 @@ impl Layout for Im2colLayout { } fn shape(&self) -> Self::Coordinates { - (self.shape_m, self.shape_k) + (1, self.shape_m, self.shape_k) } fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (NhwcCoords, bool) { @@ -125,10 +109,31 @@ impl Layout for Im2colLayout { } fn is_in_bounds(&self, pos: Self::Coordinates) -> bool { - let (view_m, view_k) = pos; + let (_, view_m, view_k) = pos; // Shouldn't be relied on because it doesn't check spatial let m_in_bounds = comptime!(!self.config.check_row_bounds) || view_m < self.shape_m; let k_in_bounds = comptime!(!self.config.check_col_bounds) || view_k < self.shape_k; m_in_bounds && k_in_bounds } } + +impl<'a, R: Runtime> Im2colLayoutLaunch<'a, R> { + pub fn from_args( + client: &ComputeClient, + problem: &ConvolutionProblem, + params: ConvolutionParams, + config: GlobalMemoryConfig, + ) -> Self { + let shape_out = problem + .out_shape + .iter() + .map(|s| FastDivmodArgs::new(client, *s as u32)) + .collect(); + let shape_channel = FastDivmodArgs::new(client, problem.channels as u32); + + let shape_m = ScalarArg::new(problem.m as u32); + let shape_k = ScalarArg::new(problem.k as u32); + + 
Im2colLayoutLaunch::new(shape_out, shape_channel, shape_m, shape_k, params, config) + } +} diff --git a/crates/cubecl-convolution/src/components/global/layout/mod.rs b/crates/cubecl-convolution/src/components/global/layout/mod.rs index e4436b1b5..e34f1f098 100644 --- a/crates/cubecl-convolution/src/components/global/layout/mod.rs +++ b/crates/cubecl-convolution/src/components/global/layout/mod.rs @@ -1,8 +1,10 @@ +mod bias; mod im2col; mod spatial; mod weight; mod write; +pub use bias::*; pub use im2col::*; pub use spatial::*; pub use weight::*; diff --git a/crates/cubecl-convolution/src/components/global/layout/spatial.rs b/crates/cubecl-convolution/src/components/global/layout/spatial.rs index ad22c17b3..e33639235 100644 --- a/crates/cubecl-convolution/src/components/global/layout/spatial.rs +++ b/crates/cubecl-convolution/src/components/global/layout/spatial.rs @@ -1,5 +1,5 @@ use cubecl::prelude::*; -use cubecl_core::{self as cubecl, intrinsic}; +use cubecl_core::{self as cubecl}; use cubecl_std::tensor::{ layout::{Coordinates, Coords1d, Layout, LayoutExpand}, r#virtual::VirtualTensor, @@ -79,7 +79,7 @@ impl Coordinates for NhwcCoords { /// Layout for a spatial (i.e. NHWC) tensor. Bounds check only applies to spatial dimensions, not /// channel or batch (because these are implicitly checked in the layouts used with spatial tensors). -#[derive(CubeType, Clone)] +#[derive(CubeType, CubeLaunch, Clone)] pub struct NhwcLayout { /// Stride for N pub stride_batch: u32, @@ -154,7 +154,6 @@ impl Layout for NhwcLayout { #[unroll] for i in 0..spatial_dims { - let i = unwrap(i); read_pos += *spatial.index(i) as u32 * *self.strides_spatial.index(i); } @@ -172,7 +171,6 @@ impl Layout for NhwcLayout { #[unroll] for i in 0..spatial_dims { - let i = unwrap(i); let pos = *pos.spatial.index(i); spatial_in_bounds &= pos >= 0 && (pos as u32) < *self.shapes_spatial.index(i); } @@ -192,12 +190,6 @@ impl Layout for NhwcLayout { } } -#[allow(unused_variables)] -#[cube] -pub(crate) fn unwrap(v: u32) -> comptime_type!(u32) { - intrinsic!(|_| v.constant().expect("Must be constant").as_u32()) -} - #[cube] pub(crate) fn cast_seq( seq: Sequence, @@ -206,9 +198,44 @@ pub(crate) fn cast_seq( let mut out_seq = Sequence::new(); #[unroll] for i in 0..num_elems { - let i = unwrap(i); let elem = To::cast_from(*seq.index(i)); out_seq.push(elem); } out_seq } + +impl<'a, R: Runtime> NhwcLayoutLaunch<'a, R> { + pub fn from_handle( + handle: &TensorHandleRef<'a, R>, + line_size: u32, + check_spatial: bool, + ) -> Self { + let rank = handle.shape.len(); + let dim_c = rank - 1; + + let stride_batch = ScalarArg::new(handle.strides[0] as u32); + let strides_spatial = handle.strides[1..dim_c] + .iter() + .map(|s| ScalarArg::new(*s as u32)) + .collect(); + let stride_channel = ScalarArg::new(handle.strides[dim_c] as u32); + + let shape_batch = ScalarArg::new(handle.shape[0] as u32); + let shapes_spatial = handle.shape[1..dim_c] + .iter() + .map(|s| ScalarArg::new(*s as u32)) + .collect(); + let shape_channel = ScalarArg::new(handle.shape[dim_c] as u32); + + Self::new( + stride_batch, + strides_spatial, + stride_channel, + shape_batch, + shapes_spatial, + shape_channel, + line_size, + check_spatial, + ) + } +} diff --git a/crates/cubecl-convolution/src/components/global/layout/weight.rs b/crates/cubecl-convolution/src/components/global/layout/weight.rs index f1390a8b7..a95093738 100644 --- a/crates/cubecl-convolution/src/components/global/layout/weight.rs +++ b/crates/cubecl-convolution/src/components/global/layout/weight.rs @@ -5,32 
+5,22 @@ use cubecl_matmul::components::{ global::{GlobalConfig, memory::GlobalMemoryConfig}, }; use cubecl_std::{ - FastDivmod, - tensor::{ - layout::{Coords2d, Layout, LayoutExpand}, - r#virtual::VirtualTensor, - }, + FastDivmod, FastDivmodArgs, + tensor::layout::{Coords3d, Layout, LayoutExpand}, }; use crate::{ components::{ - ConvGemmConfig, ConvolutionConfig, - global::layout::{NhwcCoords, unwrap}, + ConvGemmConfig, ConvolutionConfig, ConvolutionParams, ConvolutionProblem, + global::layout::NhwcCoords, }, kernels::layered::selector::RuntimeArgs, }; /// Maps a 4D weight tensor of shape `(out_c, (k_h, k_w, in_c))` to a col-major 2D matmul tile with /// shape `(n, k)` -#[derive(CubeType, Clone)] +#[derive(CubeType, CubeLaunch, Clone)] pub struct WeightLayout { - /// Stride of `out_c` - pub stride_out_c: u32, - /// Stride of `k_h`, `k_w` - pub strides_spatial: Sequence, - /// Stride of `in_c` - pub stride_in_c: u32, - /// Number of channels, including padding, used for decomposing `k` pub channels: FastDivmod, @@ -41,7 +31,7 @@ pub struct WeightLayout { /// Size of the convolution kernel #[cube(comptime)] - pub kernel_size: [u32; 3], + pub params: ConvolutionParams, /// Global memory config for the backing tensor #[cube(comptime)] pub config: GlobalMemoryConfig, @@ -50,29 +40,14 @@ pub struct WeightLayout { #[cube] impl WeightLayout { pub fn new( - tensor: &VirtualTensor, args: &RuntimeArgs, #[comptime] config: ConvolutionConfig, ) -> WeightLayout { - let spatial_dims = comptime![config.dimensionality().num_dims()]; - let mut strides_spatial = Sequence::new(); - - #[unroll] - for i in 0..spatial_dims { - strides_spatial.push(tensor.stride(i + 1)); - } - - let stride_out_c = tensor.stride(0); - let stride_in_c = tensor.stride(spatial_dims + 1); - WeightLayout { - stride_out_c, - strides_spatial, - stride_in_c, shape_k: args.shape_k, shape_n: args.shape_n, channels: args.padded_channels, - kernel_size: config.kernel_size, + params: config.convolution_params(), config: config.global_memory_config(MatmulIdent::Rhs), } } @@ -80,22 +55,22 @@ impl WeightLayout { #[cube] impl Layout for WeightLayout { - type Coordinates = Coords2d; + type Coordinates = Coords3d; type SourceCoordinates = NhwcCoords; fn to_source_pos(&self, coords: Self::Coordinates) -> NhwcCoords { - let (k, n) = coords; + let params = comptime![self.params]; + let (_, k, n) = coords; let (mut rem, in_c) = self.channels.div_mod(k); - let spatial_dims = comptime![self.strides_spatial.len()]; + let spatial_dims = comptime![params.dimensionality.num_dims()]; let mut kernel_pos = Sequence::::new(); #[unroll] for i in 0..spatial_dims { - let i = unwrap(i); let dim = comptime![spatial_dims - i - 1]; - let ksize = comptime![self.kernel_size[dim as usize]]; + let ksize = comptime![params.kernel_size[dim as usize]]; let k_pos = rem % ksize; rem /= ksize; @@ -116,13 +91,28 @@ impl Layout for WeightLayout { } fn shape(&self) -> Self::Coordinates { - (self.shape_k, self.shape_n) + (1, self.shape_k, self.shape_n) } fn is_in_bounds(&self, pos: Self::Coordinates) -> bool { - let (k, n) = pos; + let (_, k, n) = pos; let check_k = comptime![self.config.check_row_bounds]; let check_n = comptime![self.config.check_col_bounds]; (!check_k || k < self.shape_k) && (!check_n || n < self.shape_n) } } + +impl<'a, R: Runtime> WeightLayoutLaunch<'a, R> { + pub fn from_args( + client: &ComputeClient, + problem: &ConvolutionProblem, + params: ConvolutionParams, + config: GlobalMemoryConfig, + ) -> Self { + let channels = FastDivmodArgs::new(client, 
problem.channels as u32); + let shape_k = ScalarArg::new(problem.k as u32); + let shape_n = ScalarArg::new(problem.n as u32); + + WeightLayoutLaunch::new(channels, shape_k, shape_n, params, config) + } +} diff --git a/crates/cubecl-convolution/src/components/global/layout/write.rs b/crates/cubecl-convolution/src/components/global/layout/write.rs index ea98a1998..c89709517 100644 --- a/crates/cubecl-convolution/src/components/global/layout/write.rs +++ b/crates/cubecl-convolution/src/components/global/layout/write.rs @@ -2,21 +2,24 @@ use cubecl::prelude::*; use cubecl_core::{self as cubecl}; use cubecl_matmul::components::global::memory::GlobalMemoryConfig; use cubecl_std::{ - FastDivmod, - tensor::layout::{Coords2d, Layout, LayoutExpand}, + FastDivmod, FastDivmodArgs, + tensor::layout::{Coords3d, Layout, LayoutExpand}, }; use crate::{ - components::global::{ - layout::{NhwcCoords, cast_seq}, - read::im2col_tma::div_mod_seq, + components::{ + ConvolutionProblem, + global::{ + layout::{NhwcCoords, cast_seq}, + read::im2col_tma::div_mod_seq, + }, }, kernels::layered::selector::RuntimeArgs, }; /// Maps a 4D NHWC out tensor of shape `((n, h, w), c)` to a col-major 2D matmul tile with /// shape `(m, n)` -#[derive(CubeType, Clone)] +#[derive(CubeType, CubeLaunch, Clone)] pub struct OutLayout { /// Shape of DHW pub shape_out: Sequence, @@ -45,11 +48,11 @@ impl OutLayout { #[cube] impl Layout for OutLayout { - type Coordinates = Coords2d; + type Coordinates = Coords3d; type SourceCoordinates = NhwcCoords; fn to_source_pos(&self, coords: Self::Coordinates) -> NhwcCoords { - let (view_m, view_n) = coords; + let (_, view_m, view_n) = coords; let (batch, spatial) = div_mod_seq(view_m, &self.shape_out); NhwcCoords { @@ -64,13 +67,31 @@ impl Layout for OutLayout { } fn shape(&self) -> Self::Coordinates { - (self.shape_m, self.shape_n) + (1, self.shape_m, self.shape_n) } fn is_in_bounds(&self, pos: Self::Coordinates) -> bool { - let (m, n) = pos; + let (_, m, n) = pos; let check_m = comptime![self.config.check_row_bounds]; let check_n = comptime![self.config.check_col_bounds]; (!check_m || m < self.shape_m) && (!check_n || n < self.shape_n) } } + +impl<'a, R: Runtime> OutLayoutLaunch<'a, R> { + pub fn from_args( + client: &ComputeClient, + problem: &ConvolutionProblem, + config: GlobalMemoryConfig, + ) -> Self { + let shape_out = problem + .out_shape + .iter() + .map(|s| FastDivmodArgs::new(client, *s as u32)) + .collect(); + let shape_m = ScalarArg::new(problem.m as u32); + let shape_n = ScalarArg::new(problem.n as u32); + + Self::new(shape_out, shape_m, shape_n, config) + } +} diff --git a/crates/cubecl-convolution/src/components/global/memory/tma.rs b/crates/cubecl-convolution/src/components/global/memory/tma.rs index 0a15fb0ef..87e6b12f2 100644 --- a/crates/cubecl-convolution/src/components/global/memory/tma.rs +++ b/crates/cubecl-convolution/src/components/global/memory/tma.rs @@ -1,13 +1,12 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; -use cubecl_std::tensor::r#virtual::VirtualTensor; #[derive(CubeType)] /// A view of a feature map tensor that starts reading data from a specified offset. /// Ensures safe access by preventing out-of-bounds errors. /// Includes pre-fetched shapes and strides for optimized performance. 
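// The `n_offset`/`spatial_offsets` fields below are produced by splitting a
// linear row offset with `div_mod_seq` (defined further down in im2col_tma.rs).
// A minimal sketch of that decomposition, assuming plain usize shapes where the
// kernel version uses `FastDivmod`:
fn split_row_offset(mut pos: usize, out_shape: &[usize]) -> (usize, Vec<usize>) {
    // Peel spatial dimensions from the innermost (last) to the outermost.
    let mut offsets = vec![0; out_shape.len()];
    for (i, &dim) in out_shape.iter().enumerate().rev() {
        offsets[i] = pos % dim;
        pos /= dim;
    }
    (pos, offsets) // the remaining quotient is the batch index
}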
pub struct Im2colTmaReader { - pub tensor: TensorMap, + pub tensor: TensorMap>, pub n_offset: u32, pub spatial_offsets: Sequence, pub k_offset: u32, @@ -17,15 +16,13 @@ pub struct Im2colTmaReader { impl Im2colTmaReader { #[allow(clippy::too_many_arguments)] pub fn new( - tensor: VirtualTensor, + tensor: TensorMap>, n_offset: u32, spatial_offsets: Sequence, k_offset: u32, ) -> Im2colTmaReader { - let map = tensor.as_tensor_map().unwrap(); - Im2colTmaReader:: { - tensor: map, + tensor, n_offset, spatial_offsets, k_offset, diff --git a/crates/cubecl-convolution/src/components/global/multi_stage/tma/config.rs b/crates/cubecl-convolution/src/components/global/multi_stage/tma/config.rs index 63c1496ef..5bfb6b826 100644 --- a/crates/cubecl-convolution/src/components/global/multi_stage/tma/config.rs +++ b/crates/cubecl-convolution/src/components/global/multi_stage/tma/config.rs @@ -10,7 +10,7 @@ const NUM_STAGES_MAX: u32 = 8; const MIN_STAGES_PER_PIPELINE: u32 = 32; pub(crate) fn num_stages( - client: &ComputeClient, + client: &ComputeClient, problem: &ConvolutionProblem, num_planes: u32, tiling_scheme: &TilingScheme, diff --git a/crates/cubecl-convolution/src/components/global/multi_stage/tma/convolution.rs b/crates/cubecl-convolution/src/components/global/multi_stage/tma/convolution.rs index e34827bbd..67b806c0e 100644 --- a/crates/cubecl-convolution/src/components/global/multi_stage/tma/convolution.rs +++ b/crates/cubecl-convolution/src/components/global/multi_stage/tma/convolution.rs @@ -15,7 +15,7 @@ use cubecl_matmul::components::{ }; use cubecl_std::{ CubeOption, - tensor::{AsTensorView, AsTensorViewExpand, layout::Coords2d, r#virtual::VirtualTensor}, + tensor::{View, layout::Coords2d}, }; use crate::{ @@ -23,11 +23,9 @@ use crate::{ ConvGemmConfig, ConvolutionConfig, global::{ GlobalConvolution, - layout::{NhwcLayout, OutLayout}, read::{ bias::{BiasGlobalReader, BiasStage}, im2col_tma::{TmaIm2colGlobalReader, TmaIm2colTiling}, - layout::TmaWeightLayout, weight_tma::{TmaWeightGlobalReader, TmaWeightTiling}, }, }, @@ -69,7 +67,7 @@ where { type Config = ConvolutionConfig>; - type LhsGlobalReader = TmaIm2colGlobalReader; + type LhsGlobalReader = TmaIm2colGlobalReader; type RhsGlobalReader = TmaWeightGlobalReader; type AccGlobalReader = BiasGlobalReader; type GlobalWriter = PlaneWriter; @@ -113,12 +111,9 @@ where let (mut tile_lhs, mut tile_rhs) = SMM::init_tile_inputs(stage_config); let partition_scheduler = SMM::init_scheduler(config.stage_config()); - let mut stage = comptime![0u32]; - // Create barriers and prefetch each stage #[unroll] - #[allow(clippy::explicit_counter_loop)] - for _ in 0..num_stages { + for stage in 0..num_stages { let barrier = Barrier::new_with_tma_proxy(BarrierLevel::cube_coop(0u32)); lhs_reader.fill_stage(&barrier, stage); @@ -130,19 +125,14 @@ where rhs_reader.advance_view(); barriers.push(barrier); - - comptime![stage += 1]; } for k in 0..num_loops { let k = k * num_stages; - let mut stage = comptime![0u32]; - // Loop through all stages #[unroll] - #[allow(clippy::explicit_counter_loop)] - for _ in 0..num_stages { + for stage in 0..num_stages { let k = k + stage; let next_k = k + num_stages; @@ -177,8 +167,6 @@ where rhs_reader.advance_view(); } } - - comptime![stage += 1]; } } @@ -197,7 +185,7 @@ where } fn init_lhs_global_reader( - lhs: VirtualTensor>, + lhs: View>, Coords2d>, offset: Coords2d, _slice_size: Coords2d, runtime_args: &RuntimeArgs, @@ -205,26 +193,22 @@ where ) -> Self::LhsGlobalReader { let (x_offset, y_offset) = offset; Self::LhsGlobalReader::new( 
- lhs, + lhs.as_tensor_map().unwrap(), x_offset, y_offset, runtime_args, config.num_stages(MatmulIdent::Lhs), - config, + config.convolution_params(), + config.stage_memory_config(MatmulIdent::Lhs), ) } fn init_rhs_global_reader( - rhs: VirtualTensor>, - offset: Coords2d, - slice_size: Coords2d, - runtime_args: &RuntimeArgs, + rhs: View>, Coords2d>, #[comptime] config: Self::Config, ) -> Self::RhsGlobalReader { - let layout = TmaWeightLayout::new(runtime_args.padded_channels); - let rhs = rhs.as_tensor_map().unwrap().view_3d(layout); Self::RhsGlobalReader::new( - rhs.slice(offset, slice_size), + rhs, config.k_step, config.num_stages(MatmulIdent::Rhs), config.stage_memory_config(MatmulIdent::Rhs), @@ -232,35 +216,18 @@ where } fn init_bias_global_reader( - bias: CubeOption>>, - n_offset: u32, - slice_size: u32, + bias: CubeOption>, Coords2d>>, #[comptime] config: Self::Config, ) -> Self::AccGlobalReader { - Self::AccGlobalReader::new( - bias, - n_offset, - slice_size, - config.stage_memory_config(MatmulIdent::Out), - ) + Self::AccGlobalReader::new(bias, config.stage_memory_config(MatmulIdent::Out)) } fn init_global_writer( - out: VirtualTensor, ReadWrite>, - offset: Coords2d, - slice_size: Coords2d, - runtime_args: &RuntimeArgs, + out: View>, Coords2d, ReadWrite>, #[comptime] config: Self::Config, ) -> Self::GlobalWriter { let global_conf = config.global_memory_config(MatmulIdent::Out); - let layout_global = NhwcLayout::new(out, comptime![config.dimensionality()], false); - let layout_out = OutLayout::new(runtime_args, global_conf); - let out = out.view_mut(layout_global).view_mut(layout_out); - Self::GlobalWriter::new::( - out.slice_mut_unchecked(offset, slice_size), - global_conf, - config.stage_config(), - ) + Self::GlobalWriter::new::(out, global_conf, config.stage_config()) } fn init_accumulator(#[comptime] config: Self::Config) -> Self::Accumulators { diff --git a/crates/cubecl-convolution/src/components/global/multi_stage/tma/launch.rs b/crates/cubecl-convolution/src/components/global/multi_stage/tma/launch.rs index c89266735..55ff606f2 100644 --- a/crates/cubecl-convolution/src/components/global/multi_stage/tma/launch.rs +++ b/crates/cubecl-convolution/src/components/global/multi_stage/tma/launch.rs @@ -28,7 +28,7 @@ impl< > ConvolutionLaunch> for MultiStageTmaConvolutionFamily { unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>( - client: &ComputeClient<::Server, ::Channel>, + client: &ComputeClient<::Server>, cube_dim: CubeDim, cube_count: CubeCount, input: InputRuntimeArg<'a, MS, R>, diff --git a/crates/cubecl-convolution/src/components/global/multi_stage/tma/setup.rs b/crates/cubecl-convolution/src/components/global/multi_stage/tma/setup.rs index ec49a4cc5..4de7da306 100644 --- a/crates/cubecl-convolution/src/components/global/multi_stage/tma/setup.rs +++ b/crates/cubecl-convolution/src/components/global/multi_stage/tma/setup.rs @@ -49,7 +49,7 @@ where } fn setup( - client: &ComputeClient, + client: &ComputeClient, problem: &ConvolutionProblem, selection: &MatmulSelection, line_sizes: &MatmulLineSizes, diff --git a/crates/cubecl-convolution/src/components/global/read/reader/bias.rs b/crates/cubecl-convolution/src/components/global/read/reader/bias.rs index 2efaa03ed..6ce061584 100644 --- a/crates/cubecl-convolution/src/components/global/read/reader/bias.rs +++ b/crates/cubecl-convolution/src/components/global/read/reader/bias.rs @@ -2,11 +2,11 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; use cubecl_std::{ CubeOption, CubeOptionExpand, - tensor::{View, 
layout::Coords1d, r#virtual::VirtualTensor}, + tensor::{View, layout::Coords2d}, }; use cubecl_matmul::components::{ - MatmulIdent, MatrixPrecision, StageIdent, + MatrixPrecision, StageIdent, global::GlobalConfig, stage::{StageMemoryConfig, StridedStage}, }; @@ -17,7 +17,7 @@ use crate::components::stage::reader::BiasTilingLayout; #[derive(CubeType)] pub enum BiasGlobalReader { Some { - view: View, Coords1d>, + view: View, Coords2d>, stage: StridedStage, }, None, @@ -33,7 +33,7 @@ impl BiasGlobalReader { pub fn load_stage(&mut self, #[comptime] config: G) { match self { BiasGlobalReader::Some { view, stage } => { - let line_size = config.global_line_size(MatmulIdent::Out); + let line_size = view.line_size(); let num_stage_elements = config.tiling_scheme().elements_in_stage_n(); let unit_id = UNIT_POS_Y * config.plane_dim() + UNIT_POS_X; @@ -42,7 +42,7 @@ impl BiasGlobalReader { let mut slice = stage.as_slice_mut(line_size); if unit_pos < num_stage_elements { - let read_line = view.read_checked(unit_pos); + let read_line = view.read_checked((0, unit_pos)); slice[unit_id] = Line::cast_from(read_line); } } @@ -64,15 +64,12 @@ impl BiasGlobalReader { impl BiasGlobalReader { /// Create a new bias reader from the bias tensor and a global offset `n_offset`. pub fn new( - tensor: CubeOption>, - n_offset: u32, - slice_size: u32, + view: CubeOption, Coords2d>>, #[comptime] config: StageMemoryConfig, ) -> Self { - match tensor { - CubeOption::Some(tensor) => { + match view { + CubeOption::Some(view) => { let stage = init_stage::(config); - let view = tensor.as_view().slice_unchecked(n_offset, slice_size); BiasGlobalReader::::new_Some(view, stage) } diff --git a/crates/cubecl-convolution/src/components/global/read/reader/im2col_tma.rs b/crates/cubecl-convolution/src/components/global/read/reader/im2col_tma.rs index 2ce695404..997feb450 100644 --- a/crates/cubecl-convolution/src/components/global/read/reader/im2col_tma.rs +++ b/crates/cubecl-convolution/src/components/global/read/reader/im2col_tma.rs @@ -1,11 +1,11 @@ +use cubecl_core::prelude::*; use cubecl_core::{self as cubecl, prelude::barrier::Barrier}; -use cubecl_core::{intrinsic, prelude::*}; -use cubecl_matmul::components::{MatmulIdent, MatrixPrecision, StageIdent}; -use cubecl_std::{FastDivmod, tensor::r#virtual::VirtualTensor}; +use cubecl_matmul::components::{MatrixPrecision, StageIdent, stage::StageMemoryConfig}; +use cubecl_std::FastDivmod; use crate::{ - components::{ConvGemmConfig, Dimensionality, global::memory::Im2colTmaReader}, + components::{ConvolutionParams, Dimensionality, global::memory::Im2colTmaReader}, kernels::layered::selector::RuntimeArgs, }; use cubecl_matmul::components::stage::{ColMajorTilingOrder, ContiguousTilingLayout, StridedStage}; @@ -15,54 +15,55 @@ pub type TmaIm2colStage = StridedStage<::Stage, TmaIm /// Reader that translates matrix coordinates to input coordinates using the `im2col` algorithm #[derive(CubeType)] -pub struct TmaIm2colGlobalReader { +pub struct TmaIm2colGlobalReader { pub map: Im2colTmaReader, pub stages: Sequence>, padded_channels: FastDivmod, #[cube(comptime)] - config: G, + params: ConvolutionParams, + #[cube(comptime)] + config: StageMemoryConfig, } #[cube] -impl TmaIm2colGlobalReader { +impl TmaIm2colGlobalReader { pub fn new( - tensor: VirtualTensor, + tensor: TensorMap>, x_offset: u32, y_offset: u32, runtime_args: &RuntimeArgs, #[comptime] num_stages: u32, - #[comptime] config: G, + #[comptime] params: ConvolutionParams, + #[comptime] config: StageMemoryConfig, ) -> Self { let mut stages 
= Sequence::new(); #[unroll] for _ in 0..num_stages { - stages.push(StridedStage::new_aligned( - StageIdent::Lhs, - 128u32, - comptime!(config.stage_memory_config(MatmulIdent::Lhs)), - )) + stages.push(StridedStage::new_aligned(StageIdent::Lhs, 128u32, config)) } let (n_offs, spatial_offsets) = div_mod_seq(x_offset, &runtime_args.shape_out); let map = Im2colTmaReader::::new(tensor, n_offs, spatial_offsets, y_offset); - TmaIm2colGlobalReader:: { + TmaIm2colGlobalReader:: { map, stages, padded_channels: runtime_args.padded_channels, + params, config, } } pub fn fill_stage(&mut self, bar: &Barrier, #[comptime] stage_idx: u32) { let stage = self.stages.index_mut(stage_idx); + let params = comptime![self.params]; let config = comptime![self.config]; if UNIT_POS == 0 { - let m_size = config.tiling_scheme().elements_in_stage_m(); - let k_size = config.tiling_scheme().elements_in_tile_k(); + let m_size = config.elements_in_stage_row(); + let k_size = config.elements_in_tile_col; let slice_size = m_size * k_size; let mut full_stage = stage.as_slice_mut(1u32); let tensor = self.map.tensor.try_cast_unchecked(); @@ -72,22 +73,22 @@ impl TmaIm2colGlobalReader { #[unroll] for dim in 0..spatial_dims { - let dim = unwrap(dim); - let offs = self.map.spatial_offsets.index(dim) * comptime![config.stride(dim)]; - let offs = offs as i32 - comptime![config.padding(dim)]; + let offs = + self.map.spatial_offsets.index(dim) * comptime![params.stride[dim as usize]]; + let offs = offs as i32 - comptime![params.padding[dim as usize]]; in_offs.push(offs); } #[unroll] - for tile_k in 0..config.tiling_scheme().tiles_in_stage_k() { + for tile_k in 0..config.tiles_in_stage_col { let k = self.map.k_offset + tile_k * k_size; let (k_idx, channel_start) = self.padded_channels.div_mod(k); let slice_start = tile_k * slice_size; let mut stage = full_stage.slice_mut(slice_start, slice_start + slice_size); - match config.dimensionality() { + match params.dimensionality { Dimensionality::Dim1 => { - let offset = k_idx * config.dilation(0); + let offset = k_idx * comptime![params.dilation[0]]; bar.tma_load_im2col_3d( &tensor, @@ -99,11 +100,13 @@ impl TmaIm2colGlobalReader { ); } Dimensionality::Dim2 => { - let (k_x, k_y) = - (k_idx % config.kernel_size(1), k_idx / config.kernel_size(1)); + let (k_x, k_y) = ( + k_idx % comptime![params.kernel_size[1]], + k_idx / comptime![params.kernel_size[1]], + ); - let offset_y = k_y * config.dilation(0); - let offset_x = k_x * config.dilation(1); + let offset_y = k_y * comptime![params.dilation[0]]; + let offset_x = k_x * comptime![params.dilation[1]]; bar.tma_load_im2col_4d( &tensor, @@ -117,13 +120,18 @@ impl TmaIm2colGlobalReader { ); } Dimensionality::Dim3 => { - let (k_x, rem) = - (k_idx % config.kernel_size(2), k_idx / config.kernel_size(2)); - let (k_y, k_z) = (rem % config.kernel_size(1), rem / config.kernel_size(1)); + let (k_x, rem) = ( + k_idx % comptime![params.kernel_size[2]], + k_idx / comptime![params.kernel_size[2]], + ); + let (k_y, k_z) = ( + rem % comptime![params.kernel_size[1]], + rem / comptime![params.kernel_size[1]], + ); - let offset_z = k_z * config.dilation(0); - let offset_y = k_y * config.dilation(1); - let offset_x = k_x * config.dilation(2); + let offset_z = k_z * comptime![params.dilation[0]]; + let offset_y = k_y * comptime![params.dilation[1]]; + let offset_x = k_x * comptime![params.dilation[2]]; bar.tma_load_im2col_5d( &tensor, @@ -162,7 +170,6 @@ pub(crate) fn div_mod_seq(pos: u32, shape: &Sequence) -> (u32, Seque #[unroll] for i in 0..rank { - let i = 
unwrap(i); let dim = comptime![rank - i - 1]; let (rem, offs_local) = shape.index(dim).div_mod(offs); out.push(offs_local); @@ -171,9 +178,3 @@ pub(crate) fn div_mod_seq(pos: u32, shape: &Sequence) -> (u32, Seque (offs, out.rev()) } - -#[allow(unused_variables)] -#[cube] -fn unwrap(v: u32) -> comptime_type!(u32) { - intrinsic!(|_| v.constant().expect("Must be constant").as_u32()) -} diff --git a/crates/cubecl-convolution/src/components/global/read/reader/layout.rs b/crates/cubecl-convolution/src/components/global/read/reader/layout.rs index b3e9a244d..0bcb19004 100644 --- a/crates/cubecl-convolution/src/components/global/read/reader/layout.rs +++ b/crates/cubecl-convolution/src/components/global/read/reader/layout.rs @@ -2,10 +2,10 @@ use cubecl::prelude::*; use cubecl_core as cubecl; use cubecl_std::{ FastDivmod, - tensor::layout::{Coords2d, Coords3d, Layout, LayoutExpand}, + tensor::layout::{Coords3d, Layout, LayoutExpand}, }; -#[derive(CubeType)] +#[derive(CubeType, CubeLaunch)] pub struct TmaWeightLayout { padded_channels: FastDivmod, } @@ -19,11 +19,11 @@ impl TmaWeightLayout { #[cube] impl Layout for TmaWeightLayout { - type Coordinates = Coords2d; + type Coordinates = Coords3d; type SourceCoordinates = Coords3d; fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates { - let (k, n) = pos; + let (_, k, n) = pos; let (k_idx, in_c) = self.padded_channels.div_mod(k); (n, k_idx, in_c) } @@ -33,7 +33,33 @@ impl Layout for TmaWeightLayout { fn shape(&self) -> Self::Coordinates { - (u32::MAX, u32::MAX).runtime() + (u32::MAX, u32::MAX, u32::MAX).runtime() + } + + fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) { + (self.to_source_pos(pos), self.is_in_bounds(pos)) + } +} + +/// Dummy layout for launching, to be unwrapped again later with `as_tensor_map`.
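// A standalone sketch of what this identity ("dummy") layout amounts to,
// assuming plain tuples instead of CubeCL's `Layout` trait: coordinates pass
// straight through and no position is ever rejected, so the wrapper adds no
// indexing logic of its own and can be discarded with `as_tensor_map` at
// kernel entry.
struct IdentityLayout3d;

impl IdentityLayout3d {
    // Coordinates map to themselves; the backing tensor map does the addressing.
    fn to_source_pos(&self, pos: (u32, u32, u32)) -> (u32, u32, u32) {
        pos
    }
    // Bounds checks are a no-op; the TMA path enforces bounds itself.
    fn is_in_bounds(&self, _pos: (u32, u32, u32)) -> bool {
        true
    }
}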
+#[derive(CubeType, CubeLaunch)] +pub struct TmaDummyLayout {} + +#[cube] +impl Layout for TmaDummyLayout { + type Coordinates = Coords3d; + type SourceCoordinates = Coords3d; + + fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates { + pos + } + + fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool { + true.runtime() + } + + fn shape(&self) -> Self::Coordinates { + (u32::MAX, u32::MAX, u32::MAX).runtime() } fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) { diff --git a/crates/cubecl-convolution/src/components/global/read/reader/weight_tma.rs b/crates/cubecl-convolution/src/components/global/read/reader/weight_tma.rs index d473356af..a60b43106 100644 --- a/crates/cubecl-convolution/src/components/global/read/reader/weight_tma.rs +++ b/crates/cubecl-convolution/src/components/global/read/reader/weight_tma.rs @@ -15,7 +15,7 @@ pub type TmaWeightStage = StridedStage<::Stage, TmaWe #[derive(CubeType)] pub struct TmaWeightGlobalReader { - pub global_iter: GlobalIterator, + pub global_iter: GlobalIterator>, pub stages: Sequence>, #[cube(comptime)] config: StageMemoryConfig, @@ -24,7 +24,7 @@ pub struct TmaWeightGlobalReader { #[cube] impl TmaWeightGlobalReader { pub fn new( - global_view: View, + global_view: View, Coords2d>, k_step: u32, #[comptime] num_stages: u32, #[comptime] config: StageMemoryConfig, diff --git a/crates/cubecl-convolution/src/components/global/single_stage/simple/convolution.rs b/crates/cubecl-convolution/src/components/global/single_stage/simple/convolution.rs index eeafe5724..793872ae9 100644 --- a/crates/cubecl-convolution/src/components/global/single_stage/simple/convolution.rs +++ b/crates/cubecl-convolution/src/components/global/single_stage/simple/convolution.rs @@ -13,15 +13,14 @@ use cubecl_matmul::components::{ }; use cubecl_std::{ CubeOption, - tensor::{layout::Coords2d, r#virtual::VirtualTensor}, + tensor::{View, layout::Coords2d}, }; use crate::{ components::{ - ConvGemmConfig, ConvolutionConfig, + ConvolutionConfig, global::{ ConvTilingLayout, GlobalConvolution, - layout::{Im2colLayout, NhwcLayout, OutLayout, WeightLayout}, read::bias::{BiasGlobalReader, BiasStage}, }, }, @@ -121,16 +120,12 @@ where } fn init_lhs_global_reader( - lhs: VirtualTensor>, + lhs: View>, Coords2d>, offset: Coords2d, slice_size: Coords2d, - runtime_args: &RuntimeArgs, + _runtime_args: &RuntimeArgs, #[comptime] config: Self::Config, ) -> Self::LhsGlobalReader { - let check_spatial = comptime![config.check_spatial_bounds()]; - let layout_global = NhwcLayout::new(lhs, comptime![config.dimensionality()], check_spatial); - let layout_im2col = Im2colLayout::new(runtime_args, config); - let lhs = lhs.view(layout_global).view(layout_im2col); Self::LhsGlobalReader::new( lhs.slice_unchecked(offset, slice_size), config.k_step, @@ -140,53 +135,25 @@ where } fn init_rhs_global_reader( - rhs: VirtualTensor>, - offset: Coords2d, - slice_size: Coords2d, - runtime_args: &RuntimeArgs, + rhs: View>, Coords2d>, #[comptime] config: Self::Config, ) -> Self::RhsGlobalReader { - let layout_global = NhwcLayout::new(rhs, comptime![config.dimensionality()], false); - let layout_weight = WeightLayout::new(&rhs, runtime_args, config); - let rhs = rhs.view(layout_global).view(layout_weight); - Self::RhsGlobalReader::new( - rhs.slice_unchecked(offset, slice_size), - config.k_step, - MatmulIdent::Rhs, - config, - ) + Self::RhsGlobalReader::new(rhs, config.k_step, MatmulIdent::Rhs, config) } fn init_bias_global_reader( - bias: CubeOption>>, - 
n_offset: u32, - slice_size: u32, + bias: CubeOption>, Coords2d>>, #[comptime] config: Self::Config, ) -> Self::AccGlobalReader { - Self::AccGlobalReader::new( - bias, - n_offset, - slice_size, - config.stage_memory_config(MatmulIdent::Out), - ) + Self::AccGlobalReader::new(bias, config.stage_memory_config(MatmulIdent::Out)) } fn init_global_writer( - out: VirtualTensor, ReadWrite>, - offset: Coords2d, - slice_size: Coords2d, - runtime_args: &RuntimeArgs, + out: View>, Coords2d, ReadWrite>, #[comptime] config: Self::Config, ) -> Self::GlobalWriter { let global_conf = config.global_memory_config(MatmulIdent::Out); - let layout_global = NhwcLayout::new(out, comptime![config.dimensionality()], false); - let layout_out = OutLayout::new(runtime_args, global_conf); - let out = out.view_mut(layout_global).view_mut(layout_out); - Self::GlobalWriter::new::( - out.slice_mut_unchecked(offset, slice_size), - global_conf, - config.stage_config(), - ) + Self::GlobalWriter::new::(out, global_conf, config.stage_config()) } fn init_accumulator(#[comptime] config: Self::Config) -> Self::Accumulators { diff --git a/crates/cubecl-convolution/src/components/global/single_stage/simple/launch.rs b/crates/cubecl-convolution/src/components/global/single_stage/simple/launch.rs index 9a947b169..45716e054 100644 --- a/crates/cubecl-convolution/src/components/global/single_stage/simple/launch.rs +++ b/crates/cubecl-convolution/src/components/global/single_stage/simple/launch.rs @@ -28,7 +28,7 @@ impl< > ConvolutionLaunch> for SimpleConvolutionFamily { unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>( - client: &ComputeClient<::Server, ::Channel>, + client: &ComputeClient<::Server>, cube_dim: CubeDim, cube_count: CubeCount, input: InputRuntimeArg<'a, MS, R>, diff --git a/crates/cubecl-convolution/src/components/global/single_stage/simple/setup.rs b/crates/cubecl-convolution/src/components/global/single_stage/simple/setup.rs index 6aaab8827..4a993a993 100644 --- a/crates/cubecl-convolution/src/components/global/single_stage/simple/setup.rs +++ b/crates/cubecl-convolution/src/components/global/single_stage/simple/setup.rs @@ -45,7 +45,7 @@ where } fn setup( - client: &ComputeClient, + client: &ComputeClient, problem: &ConvolutionProblem, selection: &MatmulSelection, line_sizes: &MatmulLineSizes, diff --git a/crates/cubecl-convolution/src/components/global/single_stage/tma/convolution.rs b/crates/cubecl-convolution/src/components/global/single_stage/tma/convolution.rs index 092053192..1196dfc8e 100644 --- a/crates/cubecl-convolution/src/components/global/single_stage/tma/convolution.rs +++ b/crates/cubecl-convolution/src/components/global/single_stage/tma/convolution.rs @@ -15,7 +15,7 @@ use cubecl_matmul::components::{ }; use cubecl_std::{ CubeOption, - tensor::{AsTensorView, AsTensorViewExpand, layout::Coords2d, r#virtual::VirtualTensor}, + tensor::{View, layout::Coords2d}, }; use crate::{ @@ -23,11 +23,9 @@ use crate::{ ConvGemmConfig, ConvolutionConfig, global::{ GlobalConvolution, - layout::{NhwcLayout, OutLayout}, read::{ bias::{BiasGlobalReader, BiasStage}, im2col_tma::{TmaIm2colGlobalReader, TmaIm2colTiling}, - layout::TmaWeightLayout, weight_tma::{TmaWeightGlobalReader, TmaWeightTiling}, }, }, @@ -56,7 +54,7 @@ where { type Config = ConvolutionConfig>; - type LhsGlobalReader = TmaIm2colGlobalReader; + type LhsGlobalReader = TmaIm2colGlobalReader; type RhsGlobalReader = TmaWeightGlobalReader; type AccGlobalReader = BiasGlobalReader; type GlobalWriter = PlaneWriter; @@ -133,27 +131,30 @@ where } fn 
init_lhs_global_reader( - lhs: VirtualTensor>, + lhs: View>, Coords2d>, offset: Coords2d, _slice_size: Coords2d, runtime_args: &RuntimeArgs, #[comptime] config: Self::Config, ) -> Self::LhsGlobalReader { let (x_offset, y_offset) = offset; - Self::LhsGlobalReader::new(lhs, x_offset, y_offset, runtime_args, 1u32, config) + Self::LhsGlobalReader::new( + lhs.as_tensor_map().unwrap(), + x_offset, + y_offset, + runtime_args, + 1u32, + config.convolution_params(), + config.stage_memory_config(MatmulIdent::Lhs), + ) } fn init_rhs_global_reader( - rhs: VirtualTensor>, - offset: Coords2d, - slice_size: Coords2d, - runtime_args: &RuntimeArgs, + rhs: View>, Coords2d>, #[comptime] config: Self::Config, ) -> Self::RhsGlobalReader { - let layout = TmaWeightLayout::new(runtime_args.padded_channels); - let rhs = rhs.as_tensor_map().unwrap().view_3d(layout); Self::RhsGlobalReader::new( - rhs.slice(offset, slice_size), + rhs, config.k_step, 1u32, config.stage_memory_config(MatmulIdent::Rhs), @@ -161,35 +162,18 @@ where } fn init_bias_global_reader( - bias: CubeOption>>, - n_offset: u32, - slice_size: u32, + bias: CubeOption>, Coords2d>>, #[comptime] config: Self::Config, ) -> Self::AccGlobalReader { - Self::AccGlobalReader::new( - bias, - n_offset, - slice_size, - config.stage_memory_config(MatmulIdent::Out), - ) + Self::AccGlobalReader::new(bias, config.stage_memory_config(MatmulIdent::Out)) } fn init_global_writer( - out: VirtualTensor, ReadWrite>, - offset: Coords2d, - slice_size: Coords2d, - runtime_args: &RuntimeArgs, + out: View>, Coords2d, ReadWrite>, #[comptime] config: Self::Config, ) -> Self::GlobalWriter { let global_conf = config.global_memory_config(MatmulIdent::Out); - let layout_global = NhwcLayout::new(out, comptime![config.dimensionality()], false); - let layout_out = OutLayout::new(runtime_args, global_conf); - let out = out.view_mut(layout_global).view_mut(layout_out); - Self::GlobalWriter::new::( - out.slice_mut_unchecked(offset, slice_size), - global_conf, - config.stage_config(), - ) + Self::GlobalWriter::new::(out, global_conf, config.stage_config()) } fn init_accumulator(#[comptime] config: Self::Config) -> Self::Accumulators { diff --git a/crates/cubecl-convolution/src/components/global/single_stage/tma/launch.rs b/crates/cubecl-convolution/src/components/global/single_stage/tma/launch.rs index 817edefc9..0a946cd73 100644 --- a/crates/cubecl-convolution/src/components/global/single_stage/tma/launch.rs +++ b/crates/cubecl-convolution/src/components/global/single_stage/tma/launch.rs @@ -28,7 +28,7 @@ impl< > ConvolutionLaunch> for SimpleTmaConvolutionFamily { unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>( - client: &ComputeClient<::Server, ::Channel>, + client: &ComputeClient<::Server>, cube_dim: CubeDim, cube_count: CubeCount, input: InputRuntimeArg<'a, MS, R>, diff --git a/crates/cubecl-convolution/src/components/global/single_stage/tma/setup.rs b/crates/cubecl-convolution/src/components/global/single_stage/tma/setup.rs index 40bc4cc79..90dc289ae 100644 --- a/crates/cubecl-convolution/src/components/global/single_stage/tma/setup.rs +++ b/crates/cubecl-convolution/src/components/global/single_stage/tma/setup.rs @@ -49,7 +49,7 @@ where } fn setup( - client: &ComputeClient, + client: &ComputeClient, problem: &ConvolutionProblem, selection: &MatmulSelection, line_sizes: &MatmulLineSizes, diff --git a/crates/cubecl-convolution/src/components/mod.rs b/crates/cubecl-convolution/src/components/mod.rs index c6b3966d6..f47d43677 100644 --- 
a/crates/cubecl-convolution/src/components/mod.rs +++ b/crates/cubecl-convolution/src/components/mod.rs @@ -7,11 +7,11 @@ mod problem; mod selection; pub use config::*; -use cubecl_matmul::components::tile::{accelerated::AcceleratedMatmul, io::Strided}; +use cubecl_matmul::components::tile::{cmma::CmmaMatmul, io::Strided}; use cubecl_std::CubeOption; pub use error::*; pub use problem::*; pub use selection::*; /// Convolution using `AcceleratedMatmul` -pub type AcceleratedConv = AcceleratedMatmul>; +pub type AcceleratedConv = CmmaMatmul>; diff --git a/crates/cubecl-convolution/src/components/problem.rs b/crates/cubecl-convolution/src/components/problem.rs index f893abfce..d90ef222e 100644 --- a/crates/cubecl-convolution/src/components/problem.rs +++ b/crates/cubecl-convolution/src/components/problem.rs @@ -30,6 +30,7 @@ impl ConvolutionProblem { k: self.k, lhs_batches: vec![], rhs_batches: vec![], + out_batches: vec![], lhs_layout: self.lhs_layout, rhs_layout: self.rhs_layout, } diff --git a/crates/cubecl-convolution/src/components/selection.rs b/crates/cubecl-convolution/src/components/selection.rs index c7dae451e..7fbbfe9bb 100644 --- a/crates/cubecl-convolution/src/components/selection.rs +++ b/crates/cubecl-convolution/src/components/selection.rs @@ -75,7 +75,7 @@ pub(crate) fn find_stage_size_m_n( } pub fn convolution_matmul_selection( - client: &ComputeClient, + client: &ComputeClient, problem: &ConvolutionProblem, plane_dim: u32, matmul_elems: MatmulElems, @@ -84,22 +84,7 @@ pub fn convolution_matmul_selection( // to be the rough cutoff for the k=4 size. let stage_k = if problem.k >= 4096 { 4 } else { 2 }; - let tile_size = find_instruction_size( - if TMM::requires_accelerator() { - Some(( - client.properties(), - ( - matmul_elems.lhs_register, - matmul_elems.rhs_register, - matmul_elems.acc_register, - ), - )) - } else { - None - }, - problem.m, - problem.n, - ); + let tile_size = find_instruction_size::(client, &matmul_elems, problem.m, problem.n); let hardware = &client.properties().hardware; let num_sm = hardware diff --git a/crates/cubecl-convolution/src/components/stage/reader.rs b/crates/cubecl-convolution/src/components/stage/reader.rs index c754a030e..9f0ae0160 100644 --- a/crates/cubecl-convolution/src/components/stage/reader.rs +++ b/crates/cubecl-convolution/src/components/stage/reader.rs @@ -1,8 +1,9 @@ use cubecl::prelude::*; use cubecl_core as cubecl; use cubecl_matmul::components::{ - MatrixLayout, StageIdent, - stage::{StageMemoryConfig, StridedStage, TilingLayout}, + InvalidConfigError, MatrixLayout, StageIdent, + global::memory::GlobalMemoryConfig, + stage::{StageMemoryConfig, StridedStage, TilingLayout, TilingValidation}, tile::StridedTile, }; use cubecl_std::tensor::layout::Coords2d; @@ -39,3 +40,13 @@ impl TilingLayout for BiasTilingLayout { ) } } + +impl TilingValidation for BiasTilingLayout { + fn check(config: GlobalMemoryConfig) -> Result<(), InvalidConfigError> { + let stage_width = config.elements_in_stage_col; + if config.global_line_size > stage_width { + return Err(Box::new("Invalid line size")); + } + Ok(()) + } +} diff --git a/crates/cubecl-convolution/src/kernels/layered/algorithm/mod.rs b/crates/cubecl-convolution/src/kernels/layered/algorithm/mod.rs index fae300e9d..2cc9f944d 100644 --- a/crates/cubecl-convolution/src/kernels/layered/algorithm/mod.rs +++ b/crates/cubecl-convolution/src/kernels/layered/algorithm/mod.rs @@ -60,7 +60,7 @@ pub trait Algorithm { /// Make a convolution config from a convolution problem, and launch options fn setup( - 
client: &ComputeClient, + client: &ComputeClient, problem: &ConvolutionProblem, selection: &MatmulSelection, line_sizes: &MatmulLineSizes, @@ -75,13 +75,13 @@ pub trait Algorithm { } fn into_tensor_handle( - client: &ComputeClient, + client: &ComputeClient, handle: &TensorHandleRef<'_, R>, ident: MatmulIdent, ) -> TensorHandle; fn selection( - client: &ComputeClient, + client: &ComputeClient, problem: &ConvolutionProblem, plane_dim: u32, matmul_elems: MatmulElems, diff --git a/crates/cubecl-convolution/src/kernels/layered/algorithm/multi_stage_tma.rs b/crates/cubecl-convolution/src/kernels/layered/algorithm/multi_stage_tma.rs index ab5817635..f65c3d31d 100644 --- a/crates/cubecl-convolution/src/kernels/layered/algorithm/multi_stage_tma.rs +++ b/crates/cubecl-convolution/src/kernels/layered/algorithm/multi_stage_tma.rs @@ -50,7 +50,7 @@ impl< type Args = TensorMapArgs; fn into_tensor_handle( - client: &ComputeClient, + client: &ComputeClient, handle: &TensorHandleRef<'_, R>, ident: MatmulIdent, ) -> TensorHandle { @@ -63,7 +63,7 @@ impl< } fn selection( - client: &ComputeClient, + client: &ComputeClient, problem: &ConvolutionProblem, plane_dim: u32, matmul_elems: MatmulElems, diff --git a/crates/cubecl-convolution/src/kernels/layered/algorithm/simple.rs b/crates/cubecl-convolution/src/kernels/layered/algorithm/simple.rs index 91a81da7c..df850b6b8 100644 --- a/crates/cubecl-convolution/src/kernels/layered/algorithm/simple.rs +++ b/crates/cubecl-convolution/src/kernels/layered/algorithm/simple.rs @@ -52,7 +52,7 @@ impl< type Args = TensorArgs; fn into_tensor_handle( - client: &ComputeClient, + client: &ComputeClient, handle: &TensorHandleRef<'_, R>, ident: MatmulIdent, ) -> TensorHandle { @@ -69,7 +69,7 @@ impl< } fn selection( - client: &ComputeClient, + client: &ComputeClient, problem: &ConvolutionProblem, plane_dim: u32, matmul_elems: MatmulElems, diff --git a/crates/cubecl-convolution/src/kernels/layered/algorithm/simple_tma.rs b/crates/cubecl-convolution/src/kernels/layered/algorithm/simple_tma.rs index 84b6b17d0..e8be3a41d 100644 --- a/crates/cubecl-convolution/src/kernels/layered/algorithm/simple_tma.rs +++ b/crates/cubecl-convolution/src/kernels/layered/algorithm/simple_tma.rs @@ -64,7 +64,7 @@ impl< } fn into_tensor_handle( - client: &ComputeClient, + client: &ComputeClient, handle: &TensorHandleRef<'_, R>, ident: MatmulIdent, ) -> TensorHandle { @@ -77,7 +77,7 @@ impl< } fn selection( - client: &ComputeClient, + client: &ComputeClient, problem: &ConvolutionProblem, plane_dim: u32, matmul_elems: MatmulElems, @@ -92,7 +92,7 @@ impl< } pub(crate) fn into_tensor_handle_tma( - client: &ComputeClient, + client: &ComputeClient, handle: &TensorHandleRef<'_, R>, ident: MatmulIdent, ) -> TensorHandle { diff --git a/crates/cubecl-convolution/src/kernels/layered/selector/select_kernel.rs b/crates/cubecl-convolution/src/kernels/layered/selector/select_kernel.rs index 413826540..321dd5dac 100644 --- a/crates/cubecl-convolution/src/kernels/layered/selector/select_kernel.rs +++ b/crates/cubecl-convolution/src/kernels/layered/selector/select_kernel.rs @@ -4,15 +4,17 @@ use cubecl_matmul::{ MatmulInputHandleRef, components::{ InputArg, InputRuntimeArg, MatmulLineSizes, MatmulSelection, MatmulSpec, OutputArg, - OutputRuntimeArg, - global::{GlobalConfig as _, args::ConcreteOutputFactory}, + OutputRuntimeArg, global::GlobalConfig as _, }, }; use crate::{ components::{ ConvSetupError, ConvolutionProblem, - global::{args::ConcreteInputsFactory, entry_point::ConvolutionLaunch}, + global::{ + 
args::{ConcreteInputsFactory, ConcreteOutputFactory}, + entry_point::ConvolutionLaunch, + }, }, kernels::layered::algorithm::Algorithm, }; @@ -22,7 +24,7 @@ use crate::{ /// Only works for concrete tensor inputs and output. #[allow(clippy::result_large_err, clippy::too_many_arguments)] pub fn launch_kernel_concrete( - client: &ComputeClient, + client: &ComputeClient, input: &MatmulInputHandleRef<'_, R>, weight: &MatmulInputHandleRef<'_, R>, bias: &Option>, @@ -38,18 +40,22 @@ where let config = A::setup::(client, &problem, &selection, &line_sizes)?; let input = as ConcreteInputsFactory>::create( + client, input, weight, bias.as_ref(), &selection, &problem, &line_sizes, + config, ); let output = as ConcreteOutputFactory>::create( + client, out, &selection, - &problem.as_matmul_problem(), + &problem, &line_sizes, + config, ); unsafe { @@ -69,7 +75,7 @@ where /// Select which kernel to launch for the given Algorithm. pub fn launch_kernel_virtual<'a, MS: MatmulSpec, R: Runtime, A: Algorithm>( - client: &ComputeClient, + client: &ComputeClient, input: InputRuntimeArg<'a, MS, R>, output: OutputRuntimeArg<'a, MS, R>, problem: ConvolutionProblem, diff --git a/crates/cubecl-convolution/src/launch.rs b/crates/cubecl-convolution/src/launch.rs index a766ac67a..ebf0190bf 100644 --- a/crates/cubecl-convolution/src/launch.rs +++ b/crates/cubecl-convolution/src/launch.rs @@ -10,10 +10,13 @@ use crate::{ kernels::layered::selector::launch_kernel_concrete, }; use crate::{ - components::{ConvolutionProblem, Dimensionality, global::args::ConcreteInputsFactory}, + components::{ + ConvolutionProblem, Dimensionality, + global::args::{ConcreteInputsFactory, ConcreteOutputFactory}, + }, kernels::layered::algorithm::Algorithm, }; -use cubecl_matmul::components::global::args::{ConcreteOutputFactory, MatmulArgs}; +use cubecl_matmul::components::global::args::MatmulArgs; use cubecl_matmul::components::{ self, AvailableLineSizes, LhsG, MatmulElems, MatmulIdent, MatmulPrecision, MatmulSelection, MatrixPrecision, RhsG, @@ -45,7 +48,7 @@ pub struct ConvolutionArgs { /// * `options` - The options to use for the convolution #[allow(clippy::result_large_err)] pub fn launch_conv( - client: &ComputeClient, + client: &ComputeClient, input: &MatmulInputHandleRef<'_, R>, weight: &MatmulInputHandleRef<'_, R>, bias: &Option>, @@ -81,7 +84,7 @@ where } fn launch( - client: &ComputeClient, + client: &ComputeClient, input: &MatmulInputHandleRef<'_, R>, weight: &MatmulInputHandleRef<'_, R>, bias: &Option>, @@ -166,7 +169,7 @@ where #[allow(clippy::result_large_err, clippy::too_many_arguments)] pub fn launch_kernel( - client: &ComputeClient, + client: &ComputeClient, input: &MatmulInputHandleRef<'_, R>, weight: &MatmulInputHandleRef<'_, R>, bias: &Option>, @@ -178,10 +181,10 @@ where Input: ConcreteInputsFactory, Output: ConcreteOutputFactory, { - let line_sizes = AvailableLineSizes::from_types::( - &LhsG::::as_type_native_unchecked(), - &RhsG::::as_type_native_unchecked(), - &AccG::::as_type_native_unchecked(), + let line_sizes = AvailableLineSizes::from_type_sizes::( + input.data().elem_size, + weight.data().elem_size, + out.elem_size, ) .filter_lhs_with_tensor(input.data().strides, input.data().shape, problem.lhs_layout) .filter_rhs_with_tensor( diff --git a/crates/cubecl-convolution/src/tests/convolution_test_launcher.rs b/crates/cubecl-convolution/src/tests/convolution_test_launcher.rs index 6fa7000a0..e5ca52255 100644 --- a/crates/cubecl-convolution/src/tests/convolution_test_launcher.rs +++ 
b/crates/cubecl-convolution/src/tests/convolution_test_launcher.rs @@ -5,14 +5,17 @@ use cubecl_matmul::components::MatmulSelection; use cubecl_matmul::components::global::GlobalConfig; use cubecl_matmul::{MatmulInputHandleRef, components::AvailableLineSizes}; -use cubecl_matmul::components::global::args::{ConcreteOutputFactory, MatmulArgs}; +use cubecl_matmul::components::global::args::MatmulArgs; use cubecl_matmul::tests::layered::matmul_test_launcher::TensorRawParts; use cubecl_matmul::tests::test_utils::Sample; use crate::{ components::{ ConvGemmConfig as _, ConvolutionProblem, - global::{args::ConcreteInputsFactory, entry_point::ConvolutionLaunch}, + global::{ + args::{ConcreteInputsFactory, ConcreteOutputFactory}, + entry_point::ConvolutionLaunch, + }, }, kernels::layered::algorithm::Algorithm, }; @@ -25,7 +28,7 @@ type Output = ::Output; /// Test the correctness of the specified Matmul on the given device, /// against a naive CPU implementation over the given problem pub fn test_convolution_algorithm( - client: ComputeClient, + client: ComputeClient, problem: ConvolutionProblem, selection: MatmulSelection, ) where @@ -53,7 +56,7 @@ pub fn test_convolution_algorithm( let line_sizes = AvailableLineSizes { lhs: vec![1], rhs: vec![1], - out: R::io_optimized_line_sizes_unchecked(&P::EG::as_type_native_unchecked()).collect(), + out: R::io_optimized_line_sizes_unchecked(size_of::()).collect(), } .filter_lhs_with_tensor(&lhs.strides, &lhs.shape, problem.lhs_layout) .filter_rhs_with_tensor(&rhs.strides, &rhs.shape, problem.rhs_layout) @@ -105,18 +108,22 @@ pub fn test_convolution_algorithm( let rhs_handle = MatmulInputHandleRef::new(rhs_handle.as_ref()); let inputs = as ConcreteInputsFactory>::create( + &client, &lhs_handle, &rhs_handle, None, &selection, &problem, &config.line_sizes(), + config, ); let output = as ConcreteOutputFactory>::create( + &client, &out_handle, &selection, - &problem.as_matmul_problem(), + &problem, &config.line_sizes(), + config, ); unsafe { @@ -146,7 +153,7 @@ pub fn test_convolution_algorithm( } fn tensor_raw_parts( - client: &ComputeClient, + client: &ComputeClient, problem: &ConvolutionProblem, ident: MatmulIdent, ) -> TensorRawParts { diff --git a/crates/cubecl-convolution/src/tests/test_macros/mod.rs b/crates/cubecl-convolution/src/tests/test_macros/mod.rs index 4ae873121..ad57d9b8c 100644 --- a/crates/cubecl-convolution/src/tests/test_macros/mod.rs +++ b/crates/cubecl-convolution/src/tests/test_macros/mod.rs @@ -10,7 +10,7 @@ macro_rules! testgen_conv2d_accelerated { use super::*; use cubecl_std::CubeOption; use cubecl_matmul::components::tile::io::Strided; - type TMM = cubecl_matmul::components::tile::accelerated::AcceleratedMatmul>; + type TMM = cubecl_matmul::components::tile::cmma::CmmaMatmul>; ::paste::paste! 
{ $(mod [<$float _ty>] { diff --git a/crates/cubecl-convolution/src/tests/test_macros/suite.rs b/crates/cubecl-convolution/src/tests/test_macros/suite.rs index e7a8e7335..a86abba78 100644 --- a/crates/cubecl-convolution/src/tests/test_macros/suite.rs +++ b/crates/cubecl-convolution/src/tests/test_macros/suite.rs @@ -1,5 +1,8 @@ use crate::{ - components::{ConvolutionProblem, Dimensionality, global::args::ConcreteInputsFactory}, + components::{ + ConvolutionProblem, Dimensionality, + global::args::{ConcreteInputsFactory, ConcreteOutputFactory}, + }, tests::test_utils::TestPrecision, }; use crate::{ @@ -7,7 +10,6 @@ use crate::{ tests::convolution_test_launcher::test_convolution_algorithm, }; use cubecl_core::Runtime; -use cubecl_matmul::components::global::args::ConcreteOutputFactory; use cubecl_matmul::components::global::args::MatmulArgs; use cubecl_matmul::components::stage::PartitionBuffering; use cubecl_matmul::components::{ diff --git a/crates/cubecl-convolution/src/tests/test_utils.rs b/crates/cubecl-convolution/src/tests/test_utils.rs index a30230d96..6c2ab9d1c 100644 --- a/crates/cubecl-convolution/src/tests/test_utils.rs +++ b/crates/cubecl-convolution/src/tests/test_utils.rs @@ -20,7 +20,7 @@ pub trait TestPrecision { lhs: &[Self::EG], rhs: &[Self::EG], problem: &ConvolutionProblem, - client: &ComputeClient, + client: &ComputeClient, out: server::Handle, shape: &[usize], strides: &[usize], @@ -41,7 +41,7 @@ where lhs: &[EG], rhs: &[EG], problem: &ConvolutionProblem, - client: &ComputeClient, + client: &ComputeClient, out: server::Handle, shape: &[usize], strides: &[usize], @@ -84,7 +84,7 @@ where /// Compares the content of a handle to a given slice of f32. pub(crate) fn assert_equals_approx( - client: &ComputeClient, + client: &ComputeClient, output: server::Handle, shape: &[usize], strides: &[usize], diff --git a/crates/cubecl-core/Cargo.toml b/crates/cubecl-core/Cargo.toml index 5f3497039..0f07f72df 100644 --- a/crates/cubecl-core/Cargo.toml +++ b/crates/cubecl-core/Cargo.toml @@ -20,15 +20,15 @@ std = ["cubecl-runtime/std"] template = [] [dependencies] -cubecl-ir = { path = "../cubecl-ir", version = "0.7.0", default-features = false, features = [ +cubecl-ir = { path = "../cubecl-ir", version = "0.9.0", default-features = false, features = [ "serde", ] } -cubecl-runtime = { path = "../cubecl-runtime", version = "0.7.0", default-features = false } +cubecl-runtime = { path = "../cubecl-runtime", version = "0.9.0", default-features = false } bitflags = { workspace = true } bytemuck = { workspace = true } -cubecl-common = { path = "../cubecl-common", version = "0.7.0", default-features = false } -cubecl-macros = { path = "../cubecl-macros", version = "0.7.0", default-features = false } +cubecl-common = { path = "../cubecl-common", version = "0.9.0", default-features = false } +cubecl-macros = { path = "../cubecl-macros", version = "0.9.0", default-features = false } derive-new = { workspace = true } derive_more = { workspace = true, features = [ "not", diff --git a/crates/cubecl-core/src/codegen/integrator.rs b/crates/cubecl-core/src/codegen/integrator.rs index 72026ffb8..0fc9ba89b 100644 --- a/crates/cubecl-core/src/codegen/integrator.rs +++ b/crates/cubecl-core/src/codegen/integrator.rs @@ -1,11 +1,7 @@ use cubecl_common::CubeDim; use cubecl_ir::{Id, Scope, StorageType, Type}; -use enumset::EnumSet; -use crate::{ - compute::{Binding, KernelDefinition, Location, ScalarBinding, Visibility}, - prelude::FastMath, -}; +use crate::compute::{Binding, KernelDefinition, Location, 
ScalarBinding, Visibility}; /// The kernel integrator allows you to create a [kernel definition](KernelDefinition) based on /// [kernel expansion](KernelExpansion) and [kernel settings](KernelSettings). @@ -36,7 +32,6 @@ pub struct KernelSettings { pub struct KernelOptions { pub kernel_name: String, pub debug_symbols: bool, - pub fp_math_mode: EnumSet, pub cluster_dim: Option, } @@ -61,12 +56,6 @@ impl KernelSettings { self } - /// Set FP math mode - pub fn fp_math_mode(mut self, mode: EnumSet) -> Self { - self.options.fp_math_mode = mode; - self - } - /// Set cluster dim pub fn cluster_dim(mut self, cluster_dim: CubeDim) -> Self { self.options.cluster_dim = Some(cluster_dim); diff --git a/crates/cubecl-core/src/compute/launcher.rs b/crates/cubecl-core/src/compute/launcher.rs index 4cd05b43b..1449c1553 100644 --- a/crates/cubecl-core/src/compute/launcher.rs +++ b/crates/cubecl-core/src/compute/launcher.rs @@ -112,7 +112,7 @@ impl KernelLauncher { self, cube_count: CubeCount, kernel: K, - client: &ComputeClient, + client: &ComputeClient, ) { let bindings = self.into_bindings(); let kernel = Box::new(KernelTask::::new(kernel)); @@ -133,7 +133,7 @@ impl KernelLauncher { self, cube_count: CubeCount, kernel: K, - client: &ComputeClient, + client: &ComputeClient, ) { unsafe { let bindings = self.into_bindings(); diff --git a/crates/cubecl-core/src/frontend/branch.rs b/crates/cubecl-core/src/frontend/branch.rs index 1388c1555..cae0cf146 100644 --- a/crates/cubecl-core/src/frontend/branch.rs +++ b/crates/cubecl-core/src/frontend/branch.rs @@ -27,6 +27,10 @@ pub trait Iterable: Sized { scope: &mut Scope, body: impl FnMut(&mut Scope, ::ExpandType), ); + /// Return the comptime length of this iterable, if possible + fn const_len(&self) -> Option { + None + } } pub struct RangeExpand { @@ -109,6 +113,12 @@ impl Iterable for RangeExpand { inclusive: self.inclusive, }))); } + + fn const_len(&self) -> Option { + let start = self.start.expand.as_const()?.as_i64(); + let end = self.end.expand.as_const()?.as_i64(); + Some(start.abs_diff(end) as usize) + } } pub struct SteppedRangeExpand { @@ -176,6 +186,13 @@ impl> Iterable for SteppedRangeExpand { } } } + + fn const_len(&self) -> Option { + let start = self.start.constant()?.as_i64(); + let end = self.end.constant()?.as_i64(); + let step = self.step.constant()?.as_u64(); + Some((start.abs_diff(end) / step) as usize) + } } /// integer range. 
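// Aside (illustration only, not part of this patch): the trip-count rule the
// new `const_len` methods implement, assuming the bounds and step resolve to
// comptime constants. `for_expand` then uses `const_len() == Some(1)` to
// unroll single-iteration loops even when `unroll` is false.
fn range_len(start: i64, end: i64, step: u64) -> usize {
    (start.abs_diff(end) / step) as usize
}

fn main() {
    assert_eq!(range_len(2, 3, 1), 1); // `for i in 2..3` now unrolls implicitly
    assert_eq!(range_len(0, 8, 2), 4); // `range_stepped(0, 8, 2)` has 4 iterations
}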
Equivalent to: @@ -254,7 +271,7 @@ pub fn for_expand( unroll: bool, body: impl FnMut(&mut Scope, ExpandElementTyped), ) { - if unroll { + if unroll || range.const_len() == Some(1) { range.expand_unroll(scope, body); } else { range.expand(scope, body); diff --git a/crates/cubecl-core/src/frontend/comment.rs b/crates/cubecl-core/src/frontend/comment.rs deleted file mode 100644 index f6681a343..000000000 --- a/crates/cubecl-core/src/frontend/comment.rs +++ /dev/null @@ -1,10 +0,0 @@ -pub mod cube_comment { - use crate::ir::NonSemantic; - use cubecl_ir::Scope; - - pub fn expand(scope: &mut Scope, content: &str) { - scope.register(NonSemantic::Comment { - content: content.to_string(), - }); - } -} diff --git a/crates/cubecl-core/src/frontend/container/line/base.rs b/crates/cubecl-core/src/frontend/container/line/base.rs index 31c2b1419..0a2dc912d 100644 --- a/crates/cubecl-core/src/frontend/container/line/base.rs +++ b/crates/cubecl-core/src/frontend/container/line/base.rs @@ -8,7 +8,7 @@ use crate::{ prelude::{Dot, Numeric, binary_expand_fixed_output}, unexpanded, }; -use cubecl_ir::{Comparison, ExpandElement, StorageType}; +use cubecl_ir::{Comparison, ConstantScalarValue, ExpandElement, StorageType}; use cubecl_macros::{cube, intrinsic}; use derive_more::derive::Neg; /// A contiguous list of elements that supports auto-vectorized operations. @@ -265,6 +265,10 @@ impl CubePrimitive for Line

{ fn size() -> Option<usize> { P::size() } + + fn from_const_value(value: ConstantScalarValue) -> Self { + Self::new(P::from_const_value(value)) + } } impl<P: CubePrimitive> Dot for Line<P> {
diff --git a/crates/cubecl-core/src/frontend/container/line/ops.rs b/crates/cubecl-core/src/frontend/container/line/ops.rs index cac015871..6b7857296 100644 --- a/crates/cubecl-core/src/frontend/container/line/ops.rs +++ b/crates/cubecl-core/src/frontend/container/line/ops.rs @@ -4,7 +4,7 @@ use num_traits::{NumCast, ToPrimitive}; use crate::{ self as cubecl, - prelude::{IsInf, IsNan, Powi, SaturatingAdd, SaturatingSub}, + prelude::{InverseSqrt, IsInf, IsNan, Powi, SaturatingAdd, SaturatingSub, Trunc}, }; use crate::{ frontend::{
@@ -247,6 +247,7 @@ impl<P: Exp> Exp for Line<P> {} impl<P: Powf> Powf for Line<P> {} impl<P: Powi<I>, I: CubePrimitive> Powi<Line<I>> for Line<P> {} impl<P: Sqrt> Sqrt for Line<P> {} +impl<P: InverseSqrt> InverseSqrt for Line<P> {} impl<P: Cos> Cos for Line<P> {} impl<P: Sin> Sin for Line<P> {} impl<P: Tanh> Tanh for Line<P> {}
@@ -255,6 +256,7 @@ impl<P: Remainder> Remainder for Line<P> {} impl<P: Round> Round for Line<P> {} impl<P: Floor> Floor for Line<P> {} impl<P: Ceil> Ceil for Line<P> {} +impl<P: Trunc> Trunc for Line<P> {} impl<P: ReverseBits> ReverseBits for Line<P> {} impl<P: BitwiseNot> BitwiseNot for Line<P> {} impl<P: SaturatingAdd> SaturatingAdd for Line<P>
{} diff --git a/crates/cubecl-core/src/frontend/container/sequence/base.rs b/crates/cubecl-core/src/frontend/container/sequence/base.rs index 17b513eb3..d5dffa7a7 100644 --- a/crates/cubecl-core/src/frontend/container/sequence/base.rs +++ b/crates/cubecl-core/src/frontend/container/sequence/base.rs @@ -139,6 +139,10 @@ impl Iterable for SequenceExpand { func(scope, elem); } } + + fn const_len(&self) -> Option { + Some(self.values.borrow().len()) + } } impl IntoMut for SequenceExpand { @@ -258,8 +262,11 @@ impl SequenceExpand { } pub fn __expand_rev_method(self, _scope: &mut Scope) -> Self { - self.values.borrow_mut().reverse(); - self + let mut values = self.values.borrow().clone(); + values.reverse(); + Self { + values: Rc::new(RefCell::new(values)), + } } pub fn __expand_clone_method(&self, _scope: &mut Scope) -> Self { diff --git a/crates/cubecl-core/src/frontend/container/sequence/launch.rs b/crates/cubecl-core/src/frontend/container/sequence/launch.rs index 967ab835c..48d6cd1ea 100644 --- a/crates/cubecl-core/src/frontend/container/sequence/launch.rs +++ b/crates/cubecl-core/src/frontend/container/sequence/launch.rs @@ -104,3 +104,11 @@ impl ArgSettings for SequenceArg<'_, R, T> { self.values.iter().for_each(|arg| arg.register(launcher)); } } + +impl<'a, R: Runtime, E: LaunchArg> FromIterator> for SequenceArg<'a, R, E> { + fn from_iter>>(iter: T) -> Self { + SequenceArg { + values: iter.into_iter().collect(), + } + } +} diff --git a/crates/cubecl-core/src/frontend/container/shared_memory.rs b/crates/cubecl-core/src/frontend/container/shared_memory.rs index 9d2818268..c44d4c8b6 100644 --- a/crates/cubecl-core/src/frontend/container/shared_memory.rs +++ b/crates/cubecl-core/src/frontend/container/shared_memory.rs @@ -5,7 +5,7 @@ use crate::{ prelude::{Lined, LinedExpand}, unexpanded, }; -use cubecl_ir::{Instruction, Operation, VariableKind}; +use cubecl_ir::{Marker, VariableKind}; use cubecl_macros::{cube, intrinsic}; use crate::{ @@ -142,7 +142,7 @@ impl SharedMemory { /// *Must* be used in uniform control flow /// *Must not* have any dangling references to this shared memory pub unsafe fn free(self) { - intrinsic!(|scope| { scope.register(Instruction::no_out(Operation::Free(*self.expand))) }) + intrinsic!(|scope| { scope.register(Marker::Free(*self.expand)) }) } } diff --git a/crates/cubecl-core/src/frontend/container/tensor/base.rs b/crates/cubecl-core/src/frontend/container/tensor/base.rs index e1c31c9af..cd7e9778e 100644 --- a/crates/cubecl-core/src/frontend/container/tensor/base.rs +++ b/crates/cubecl-core/src/frontend/container/tensor/base.rs @@ -15,7 +15,7 @@ use crate as cubecl; /// The tensor type is similar to the [array type](crate::prelude::Array), however it comes with more /// metadata such as [stride](Tensor::stride) and [shape](Tensor::shape). 
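// Aside (hypothetical helper, not in this patch): the new `FromIterator` impl
// lets launch code collect per-tensor arguments directly. Assumes line size 1,
// that `TensorHandleRef::as_array_arg` (added in the tensor launch hunk below)
// yields `ArrayArg`, and that these types are re-exported by the prelude.
use cubecl_core::prelude::*;

fn as_sequence<'a, R: Runtime, F: Float>(
    handles: &'a [TensorHandleRef<'a, R>],
) -> SequenceArg<'a, R, Array<F>> {
    handles.iter().map(|h| h.as_array_arg(1)).collect()
}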
-#[derive(new)] +#[derive(new, Clone, Copy)] pub struct Tensor { _val: PhantomData, } diff --git a/crates/cubecl-core/src/frontend/container/tensor/launch.rs b/crates/cubecl-core/src/frontend/container/tensor/launch.rs index b9289eb37..b92565067 100644 --- a/crates/cubecl-core/src/frontend/container/tensor/launch.rs +++ b/crates/cubecl-core/src/frontend/container/tensor/launch.rs @@ -6,7 +6,9 @@ use crate::{ Runtime, compute::{KernelBuilder, KernelLauncher}, ir::{Id, LineSize, Type}, - prelude::{ArgSettings, CompilationArg, CubePrimitive, ExpandElementTyped, LaunchArg}, + prelude::{ + ArgSettings, ArrayArg, CompilationArg, CubePrimitive, ExpandElementTyped, LaunchArg, + }, }; use super::Tensor; @@ -184,6 +186,11 @@ impl<'a, R: Runtime> TensorHandleRef<'a, R> { ) } } + /// Convert the handle into an [array argument](ArrayArg). + pub fn as_array_arg(&'a self, line_size: u8) -> ArrayArg<'a, R> { + let length = self.shape.iter().product(); + unsafe { ArrayArg::from_raw_parts_and_size(self.handle, length, line_size, self.elem_size) } + } /// Create a handle from raw parts. /// /// # Safety diff --git a/crates/cubecl-core/src/frontend/debug.rs b/crates/cubecl-core/src/frontend/debug.rs index f2036881d..5c3b5494b 100644 --- a/crates/cubecl-core/src/frontend/debug.rs +++ b/crates/cubecl-core/src/frontend/debug.rs @@ -98,3 +98,14 @@ macro_rules! debug_print_expand { $crate::debug_print_expand!($format, $($args),*) }; } + +pub mod cube_comment { + use crate::ir::NonSemantic; + use cubecl_ir::Scope; + + pub fn expand(scope: &mut Scope, content: &str) { + scope.register(NonSemantic::Comment { + content: content.to_string(), + }); + } +} diff --git a/crates/cubecl-core/src/frontend/element/atomic.rs b/crates/cubecl-core/src/frontend/element/atomic.rs index ee43d9448..b0d97b796 100644 --- a/crates/cubecl-core/src/frontend/element/atomic.rs +++ b/crates/cubecl-core/src/frontend/element/atomic.rs @@ -1,4 +1,4 @@ -use cubecl_ir::{AtomicOp, ExpandElement, StorageType}; +use cubecl_ir::{AtomicOp, ConstantScalarValue, ExpandElement, StorageType}; use super::{ExpandElementIntoMut, ExpandElementTyped, Int, Numeric, into_mut_expand_element}; use crate::{ @@ -306,6 +306,10 @@ impl CubePrimitive for Atomic { fn from_expand_elem(elem: ExpandElement) -> Self::ExpandType { ExpandElementTyped::new(elem) } + + fn from_const_value(_value: ConstantScalarValue) -> Self { + panic!("Can't have constant atomic"); + } } impl ExpandElementIntoMut for Atomic { diff --git a/crates/cubecl-core/src/frontend/element/base.rs b/crates/cubecl-core/src/frontend/element/base.rs index 881614f49..804b99b39 100644 --- a/crates/cubecl-core/src/frontend/element/base.rs +++ b/crates/cubecl-core/src/frontend/element/base.rs @@ -342,6 +342,11 @@ impl ExpandElementTyped { _ => None, } } + + pub fn __expand_into_lit_unchecked_method(self, _scope: &mut Scope) -> T { + let value = self.constant().unwrap(); + T::from_const_value(value) + } } pub(crate) fn into_runtime_expand_element>( diff --git a/crates/cubecl-core/src/frontend/element/bool.rs b/crates/cubecl-core/src/frontend/element/bool.rs index 640a93f08..8ed4ee9c0 100644 --- a/crates/cubecl-core/src/frontend/element/bool.rs +++ b/crates/cubecl-core/src/frontend/element/bool.rs @@ -1,4 +1,4 @@ -use cubecl_ir::{ExpandElement, Scope, StorageType}; +use cubecl_ir::{ConstantScalarValue, ExpandElement, Scope, StorageType}; use crate::frontend::{CubePrimitive, CubeType}; use crate::ir::ElemType; @@ -31,6 +31,13 @@ impl CubePrimitive for bool { fn as_type_native() -> Option { 
Some(StorageType::Scalar(ElemType::Bool)) } + + fn from_const_value(value: ConstantScalarValue) -> Self { + let ConstantScalarValue::Bool(value) = value else { + unreachable!() + }; + value + } } impl IntoRuntime for bool { diff --git a/crates/cubecl-core/src/frontend/element/cube_elem.rs b/crates/cubecl-core/src/frontend/element/cube_elem.rs index 4211499ed..f85538b9c 100644 --- a/crates/cubecl-core/src/frontend/element/cube_elem.rs +++ b/crates/cubecl-core/src/frontend/element/cube_elem.rs @@ -1,7 +1,5 @@ -use cubecl_ir::{ExpandElement, StorageType}; -use cubecl_runtime::{ - TypeUsage, channel::ComputeChannel, client::ComputeClient, server::ComputeServer, -}; +use cubecl_ir::{ConstantScalarValue, ExpandElement, StorageType}; +use cubecl_runtime::{TypeUsage, client::ComputeClient, server::ComputeServer}; use enumset::EnumSet; use crate::frontend::CubeType; @@ -47,12 +45,23 @@ pub trait CubePrimitive: Self::as_type_native().map(|t| t.size_bits()) } + /// Only native element types have a size. + fn size_bits_unchecked() -> usize { + Self::as_type_native_unchecked().size_bits() + } + fn from_expand_elem(elem: ExpandElement) -> Self::ExpandType { ExpandElementTyped::new(elem) } - fn supported_uses>( - client: &ComputeClient, + fn from_const_value(value: ConstantScalarValue) -> Self; + + fn into_lit_unchecked(self) -> Self { + self + } + + fn supported_uses( + client: &ComputeClient, ) -> EnumSet { let elem = Self::as_type_native_unchecked(); client.properties().features.type_usage(elem) diff --git a/crates/cubecl-core/src/frontend/element/float.rs b/crates/cubecl-core/src/frontend/element/float.rs index 69b558619..9b90cc471 100644 --- a/crates/cubecl-core/src/frontend/element/float.rs +++ b/crates/cubecl-core/src/frontend/element/float.rs @@ -1,4 +1,4 @@ -use cubecl_ir::{Scope, StorageType}; +use cubecl_ir::{ConstantScalarValue, Scope, StorageType}; use half::{bf16, f16}; use crate::{ @@ -41,10 +41,11 @@ pub trait Float: + Powf + Powi + Sqrt - + Rsqrt + + InverseSqrt + Round + Floor + Ceil + + Trunc + Erf + Recip + Magnitude @@ -86,7 +87,7 @@ pub trait Float: macro_rules! impl_float { (half $primitive:ident, $kind:ident) => { - impl_float!($primitive, $kind, |val| $primitive::from_f32(val)); + impl_float!($primitive, $kind, |val| $primitive::from_f64(val)); }; ($primitive:ident, $kind:ident) => { impl_float!($primitive, $kind, |val| val as $primitive); @@ -101,6 +102,13 @@ macro_rules! impl_float { fn as_type_native() -> Option { Some(StorageType::Scalar(ElemType::Float(FloatKind::$kind))) } + + fn from_const_value(value: ConstantScalarValue) -> Self { + let ConstantScalarValue::Float(value, _) = value else { + unreachable!() + }; + $new(value) + } } impl IntoRuntime for $primitive { @@ -146,7 +154,7 @@ macro_rules! 
impl_float { const RADIX: u32 = $primitive::RADIX; fn new(val: f32) -> Self { - $new(val) + $new(val as f64) } } }; diff --git a/crates/cubecl-core/src/frontend/element/float/fp4.rs b/crates/cubecl-core/src/frontend/element/float/fp4.rs index 7fdc40e06..3c2b9dfb9 100644 --- a/crates/cubecl-core/src/frontend/element/float/fp4.rs +++ b/crates/cubecl-core/src/frontend/element/float/fp4.rs @@ -1,5 +1,5 @@ use cubecl_common::{e2m1, e2m1x2}; -use cubecl_ir::{ElemType, ExpandElement, FloatKind, Scope, StorageType}; +use cubecl_ir::{ConstantScalarValue, ElemType, ExpandElement, FloatKind, Scope, StorageType}; use crate::{ Runtime, @@ -19,6 +19,13 @@ impl CubePrimitive for e2m1 { fn as_type_native() -> Option { Some(StorageType::Scalar(ElemType::Float(FloatKind::E2M1))) } + + fn from_const_value(value: ConstantScalarValue) -> Self { + let ConstantScalarValue::Float(value, _) = value else { + unreachable!() + }; + e2m1::from_f64(value) + } } impl IntoRuntime for e2m1 { @@ -43,6 +50,15 @@ impl CubePrimitive for e2m1x2 { fn as_type_native() -> Option { Some(StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2)) } + + fn from_const_value(value: ConstantScalarValue) -> Self { + let ConstantScalarValue::Float(value, _) = value else { + unreachable!() + }; + let val = e2m1::from_f64(value).to_bits(); + // Fill both values, not sure this is ever useful but it works + e2m1x2::from_bits(val | (val << 4)) + } } impl IntoRuntime for e2m1x2 { diff --git a/crates/cubecl-core/src/frontend/element/float/fp6.rs b/crates/cubecl-core/src/frontend/element/float/fp6.rs index 359aaddd2..0ba628b6f 100644 --- a/crates/cubecl-core/src/frontend/element/float/fp6.rs +++ b/crates/cubecl-core/src/frontend/element/float/fp6.rs @@ -1,5 +1,5 @@ use cubecl_common::{e2m3, e3m2}; -use cubecl_ir::{ElemType, ExpandElement, FloatKind, Scope, StorageType}; +use cubecl_ir::{ConstantScalarValue, ElemType, ExpandElement, FloatKind, Scope, StorageType}; use crate::prelude::{ CubePrimitive, CubeType, ExpandElementIntoMut, ExpandElementTyped, IntoRuntime, @@ -15,6 +15,10 @@ impl CubePrimitive for e2m3 { fn as_type_native() -> Option { Some(ElemType::Float(FloatKind::E2M3).into()) } + + fn from_const_value(_value: ConstantScalarValue) -> Self { + unimplemented!("e2m3 doesn't yet support conversion"); + } } impl IntoRuntime for e2m3 { @@ -39,6 +43,10 @@ impl CubePrimitive for e3m2 { fn as_type_native() -> Option { Some(ElemType::Float(FloatKind::E3M2).into()) } + + fn from_const_value(_value: ConstantScalarValue) -> Self { + unimplemented!("e3m2 doesn't yet support conversion"); + } } impl IntoRuntime for e3m2 { diff --git a/crates/cubecl-core/src/frontend/element/float/fp8.rs b/crates/cubecl-core/src/frontend/element/float/fp8.rs index e8de01a63..8c329c4ef 100644 --- a/crates/cubecl-core/src/frontend/element/float/fp8.rs +++ b/crates/cubecl-core/src/frontend/element/float/fp8.rs @@ -1,5 +1,5 @@ use cubecl_common::{e4m3, e5m2, ue8m0}; -use cubecl_ir::{ElemType, ExpandElement, FloatKind, Scope, StorageType}; +use cubecl_ir::{ConstantScalarValue, ElemType, ExpandElement, FloatKind, Scope, StorageType}; use crate::{ Runtime, @@ -19,6 +19,13 @@ impl CubePrimitive for e4m3 { fn as_type_native() -> Option { Some(ElemType::Float(FloatKind::E4M3).into()) } + + fn from_const_value(value: ConstantScalarValue) -> Self { + let ConstantScalarValue::Float(value, _) = value else { + unreachable!() + }; + e4m3::from_f64(value) + } } impl IntoRuntime for e4m3 { @@ -58,6 +65,13 @@ impl CubePrimitive for e5m2 { fn as_type_native() -> Option { 
Some(ElemType::Float(FloatKind::E5M2).into()) } + + fn from_const_value(value: ConstantScalarValue) -> Self { + let ConstantScalarValue::Float(value, _) = value else { + unreachable!() + }; + e5m2::from_f64(value) + } } impl IntoRuntime for e5m2 { @@ -97,6 +111,13 @@ impl CubePrimitive for ue8m0 { fn as_type_native() -> Option { Some(ElemType::Float(FloatKind::UE8M0).into()) } + + fn from_const_value(value: ConstantScalarValue) -> Self { + let ConstantScalarValue::Float(value, _) = value else { + unreachable!() + }; + ue8m0::from_f64(value) + } } impl IntoRuntime for ue8m0 { diff --git a/crates/cubecl-core/src/frontend/element/float/relaxed.rs b/crates/cubecl-core/src/frontend/element/float/relaxed.rs index ce854977a..c2bb5923e 100644 --- a/crates/cubecl-core/src/frontend/element/float/relaxed.rs +++ b/crates/cubecl-core/src/frontend/element/float/relaxed.rs @@ -1,5 +1,5 @@ use cubecl_common::flex32; -use cubecl_ir::{ElemType, ExpandElement, FloatKind, Scope, StorageType}; +use cubecl_ir::{ConstantScalarValue, ElemType, ExpandElement, FloatKind, Scope, StorageType}; use crate::prelude::{Numeric, into_runtime_expand_element}; @@ -17,6 +17,13 @@ impl CubePrimitive for flex32 { fn as_type_native() -> Option { Some(ElemType::Float(FloatKind::Flex32).into()) } + + fn from_const_value(value: ConstantScalarValue) -> Self { + let ConstantScalarValue::Float(value, _) = value else { + unreachable!() + }; + flex32::from_f64(value) + } } impl IntoRuntime for flex32 { diff --git a/crates/cubecl-core/src/frontend/element/float/tensor_float.rs b/crates/cubecl-core/src/frontend/element/float/tensor_float.rs index 3e94ba873..ac15924cd 100644 --- a/crates/cubecl-core/src/frontend/element/float/tensor_float.rs +++ b/crates/cubecl-core/src/frontend/element/float/tensor_float.rs @@ -1,5 +1,5 @@ use cubecl_common::tf32; -use cubecl_ir::{ElemType, ExpandElement, FloatKind, Scope, StorageType}; +use cubecl_ir::{ConstantScalarValue, ElemType, ExpandElement, FloatKind, Scope, StorageType}; use half::f16; use crate::prelude::{Numeric, into_runtime_expand_element}; @@ -18,6 +18,13 @@ impl CubePrimitive for tf32 { fn as_type_native() -> Option { Some(ElemType::Float(FloatKind::TF32).into()) } + + fn from_const_value(value: ConstantScalarValue) -> Self { + let ConstantScalarValue::Float(value, _) = value else { + unreachable!() + }; + tf32::from_f64(value) + } } impl IntoRuntime for tf32 { diff --git a/crates/cubecl-core/src/frontend/element/float/typemap.rs b/crates/cubecl-core/src/frontend/element/float/typemap.rs index 52a7b622b..78622a836 100644 --- a/crates/cubecl-core/src/frontend/element/float/typemap.rs +++ b/crates/cubecl-core/src/frontend/element/float/typemap.rs @@ -184,6 +184,10 @@ impl CubePrimitive for ElemExpand { fn as_type(scope: &Scope) -> StorageType { scope.resolve_type::().expect("Type to be registered") } + + fn from_const_value(_value: ConstantScalarValue) -> Self { + unimplemented!("Can't turn `ElemExpand` into a constant value") + } } impl From> for Variable { @@ -259,10 +263,11 @@ impl ArcTan2 for ElemExpand {} impl Powf for ElemExpand {} impl Powi for ElemExpand {} impl Sqrt for ElemExpand {} -impl Rsqrt for ElemExpand {} +impl InverseSqrt for ElemExpand {} impl Round for ElemExpand {} impl Floor for ElemExpand {} impl Ceil for ElemExpand {} +impl Trunc for ElemExpand {} impl IsNan for ElemExpand {} impl IsInf for ElemExpand {} diff --git a/crates/cubecl-core/src/frontend/element/int.rs b/crates/cubecl-core/src/frontend/element/int.rs index a28a440d6..3d00fccaa 100644 --- 
a/crates/cubecl-core/src/frontend/element/int.rs +++ b/crates/cubecl-core/src/frontend/element/int.rs @@ -1,4 +1,4 @@ -use cubecl_ir::{ExpandElement, StorageType}; +use cubecl_ir::{ConstantScalarValue, ExpandElement, StorageType}; use crate::Runtime; use crate::frontend::{CubeType, Numeric}; @@ -70,6 +70,13 @@ macro_rules! impl_int { fn as_type_native() -> Option { Some(ElemType::Int(IntKind::$kind).into()) } + + fn from_const_value(value: ConstantScalarValue) -> Self { + let ConstantScalarValue::Int(value, _) = value else { + unreachable!() + }; + value as $type + } } impl IntoRuntime for $type { diff --git a/crates/cubecl-core/src/frontend/element/int/typemap.rs b/crates/cubecl-core/src/frontend/element/int/typemap.rs index cd3330db0..835741b49 100644 --- a/crates/cubecl-core/src/frontend/element/int/typemap.rs +++ b/crates/cubecl-core/src/frontend/element/int/typemap.rs @@ -1,6 +1,8 @@ use bytemuck::{Pod, Zeroable}; use core::ops::*; -use cubecl_ir::{ElemType, ExpandElement, IntKind, Scope, StorageType, Variable}; +use cubecl_ir::{ + ConstantScalarValue, ElemType, ExpandElement, IntKind, Scope, StorageType, Variable, +}; use derive_more::derive::{ Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Display, Div, DivAssign, Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, @@ -141,6 +143,10 @@ impl CubePrimitive for IntExpand { fn as_type(scope: &Scope) -> StorageType { scope.resolve_type::().expect("Type to be registered") } + + fn from_const_value(_value: ConstantScalarValue) -> Self { + unimplemented!("Can't turn `IntExpand` into a constant value") + } } impl From> for Variable { diff --git a/crates/cubecl-core/src/frontend/element/uint.rs b/crates/cubecl-core/src/frontend/element/uint.rs index 837fdae84..3c409181a 100644 --- a/crates/cubecl-core/src/frontend/element/uint.rs +++ b/crates/cubecl-core/src/frontend/element/uint.rs @@ -1,4 +1,4 @@ -use cubecl_ir::{ExpandElement, Scope, StorageType, UIntKind}; +use cubecl_ir::{ConstantScalarValue, ExpandElement, Scope, StorageType, UIntKind}; use crate::Runtime; use crate::frontend::{CubePrimitive, CubeType, Numeric}; @@ -20,6 +20,13 @@ macro_rules! 
declare_uint { fn as_type_native() -> Option { Some(ElemType::UInt(UIntKind::$kind).into()) } + + fn from_const_value(value: ConstantScalarValue) -> Self { + let ConstantScalarValue::UInt(value, _) = value else { + unreachable!() + }; + value as $primitive + } } impl IntoRuntime for $primitive { diff --git a/crates/cubecl-core/src/frontend/mod.rs b/crates/cubecl-core/src/frontend/mod.rs index 7ec6f71dd..233107162 100644 --- a/crates/cubecl-core/src/frontend/mod.rs +++ b/crates/cubecl-core/src/frontend/mod.rs @@ -4,7 +4,6 @@ pub mod cmma; pub mod synchronization; mod base; -mod comment; pub mod comptime_error; mod const_expand; mod container; @@ -19,7 +18,6 @@ mod polyfills; mod topology; pub use branch::{RangeExpand, SteppedRangeExpand, range, range_stepped}; -pub use comment::*; pub use const_expand::*; pub use container::*; pub use debug::*; diff --git a/crates/cubecl-core/src/frontend/operation/unary.rs b/crates/cubecl-core/src/frontend/operation/unary.rs index 82e482249..cdb3b776a 100644 --- a/crates/cubecl-core/src/frontend/operation/unary.rs +++ b/crates/cubecl-core/src/frontend/operation/unary.rs @@ -332,10 +332,10 @@ impl_unary_func!( f64 ); impl_unary_func!( - Rsqrt, - rsqrt, - __expand_rsqrt, - Arithmetic::Rsqrt, + InverseSqrt, + inverse_sqrt, + __expand_inverse_sqrt, + Arithmetic::InverseSqrt, f16, bf16, flex32, @@ -379,6 +379,18 @@ impl_unary_func!( f32, f64 ); +impl_unary_func!( + Trunc, + trunc, + __expand_trunc, + Arithmetic::Trunc, + f16, + bf16, + flex32, + tf32, + f32, + f64 +); impl_unary_func!( Erf, erf, diff --git a/crates/cubecl-core/src/frontend/options.rs b/crates/cubecl-core/src/frontend/options.rs index 832e3defe..d20a93979 100644 --- a/crates/cubecl-core/src/frontend/options.rs +++ b/crates/cubecl-core/src/frontend/options.rs @@ -1,40 +1,15 @@ -use enumset::{EnumSet, EnumSetType}; -use serde::{Deserialize, Serialize}; +use cubecl_ir::{FastMath, Scope}; +use enumset::EnumSet; -/// Unchecked optimizations for float operations. May cause precision differences, or undefined -/// behaviour if the relevant conditions are not followed. -#[derive(Default, Debug, Hash, Serialize, Deserialize, EnumSetType)] -pub enum FastMath { - /// Disable unsafe optimizations - #[default] - None, - /// Assume values are never `NaN`. If they are, the result is considered undefined behaviour. - NotNaN, - /// Assume values are never `Inf`/`-Inf`. If they are, the result is considered undefined - /// behaviour. - NotInf, - /// Ignore sign on zero values. - UnsignedZero, - /// Allow swapping float division with a reciprocal, even if that swap would change precision. - AllowReciprocal, - /// Allow contracting float operations into fewer operations, even if the precision could - /// change. - AllowContraction, - /// Allow reassociation for float operations, even if the precision could change. - AllowReassociation, - /// Allow all mathematical transformations for float operations, including contraction and - /// reassociation, even if the precision could change. - AllowTransform, - /// Allow using lower precision intrinsics (CUDA `--use_fast_math`) - /// Also impacts `NaN`, `Inf` and signed zero handling, as well as subnormals and rounding. 
- /// - /// Notable edge case: - /// powf - Returns `NaN` for negative bases - ReducedPrecision, -} +pub fn fast_math_expand( + scope: &mut Scope, + value: EnumSet, + body: impl FnOnce(&mut Scope) -> R, +) -> R { + let prev = scope.modes.borrow().fp_math_mode; + scope.modes.borrow_mut().fp_math_mode = value; + let res = body(scope); + scope.modes.borrow_mut().fp_math_mode = prev; -impl FastMath { - pub fn all() -> EnumSet { - EnumSet::all() - } + res } diff --git a/crates/cubecl-core/src/frontend/plane.rs b/crates/cubecl-core/src/frontend/plane.rs index a9780a203..6acab78f8 100644 --- a/crates/cubecl-core/src/frontend/plane.rs +++ b/crates/cubecl-core/src/frontend/plane.rs @@ -60,6 +60,153 @@ pub mod plane_broadcast { } } +/// Perform an arbitrary lane shuffle operation across the plane. +/// Each unit reads the value from the specified source lane. +/// +/// # Example +/// `plane_shuffle(value, 0)` - all lanes read from lane 0 (same as broadcast) +/// `plane_shuffle(value, lane_id ^ 1)` - butterfly pattern (same as shuffle_xor) +#[allow(unused_variables)] +pub fn plane_shuffle(value: E, src_lane: u32) -> E { + unexpanded!() +} + +/// Module containing the expand function for [plane_shuffle()]. +pub mod plane_shuffle { + + use super::*; + + /// Expand method of [plane_shuffle()]. + pub fn expand( + scope: &mut Scope, + value: ExpandElementTyped, + src_lane: ExpandElementTyped, + ) -> ExpandElementTyped { + let output = scope.create_local(value.expand.ty); + let out = *output; + let lhs = *value.expand; + let rhs = *src_lane.expand; + + scope.register(Instruction::new( + Plane::Shuffle(crate::ir::BinaryOperator { lhs, rhs }), + out, + )); + + output.into() + } +} + +/// Perform a shuffle XOR operation across the plane. +/// Each unit exchanges its value with another unit at an index determined by XOR with the mask. +/// This is useful for butterfly reduction patterns. +/// +/// # Example +/// For a 32-lane warp with mask=1: +/// - Lane 0 gets value from lane 1, lane 1 gets value from lane 0 +/// - Lane 2 gets value from lane 3, lane 3 gets value from lane 2 +/// - etc. +#[allow(unused_variables)] +pub fn plane_shuffle_xor(value: E, mask: u32) -> E { + unexpanded!() +} + +/// Module containing the expand function for [plane_shuffle_xor()]. +pub mod plane_shuffle_xor { + + use super::*; + + /// Expand method of [plane_shuffle_xor()]. + pub fn expand( + scope: &mut Scope, + value: ExpandElementTyped, + mask: ExpandElementTyped, + ) -> ExpandElementTyped { + let output = scope.create_local(value.expand.ty); + let out = *output; + let lhs = *value.expand; + let rhs = *mask.expand; + + scope.register(Instruction::new( + Plane::ShuffleXor(crate::ir::BinaryOperator { lhs, rhs }), + out, + )); + + output.into() + } +} + +/// Perform a shuffle up operation across the plane. +/// Each unit reads the value from a unit with a lower lane ID (current_id - delta). +/// Units with lane_id < delta will read from themselves (no change). +/// +/// # Example +/// For delta=1: `[a, b, c, d] -> [a, a, b, c]` +#[allow(unused_variables)] +pub fn plane_shuffle_up(value: E, delta: u32) -> E { + unexpanded!() +} + +/// Module containing the expand function for [plane_shuffle_up()]. +pub mod plane_shuffle_up { + + use super::*; + + /// Expand method of [plane_shuffle_up()]. 
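// Aside (CPU illustration, not CubeCL code): the butterfly pattern that
// `plane_shuffle_xor` enables. In each round, lane `i` combines its value
// with lane `i ^ mask`; after log2(n) rounds every lane holds the total.
// Assumes the lane count is a power of two, like a plane.
fn butterfly_sum(lanes: &mut [f32]) {
    let mut mask = lanes.len() / 2;
    while mask > 0 {
        let snapshot = lanes.to_vec(); // all lanes read before any lane writes
        for i in 0..lanes.len() {
            lanes[i] = snapshot[i] + snapshot[i ^ mask]; // shuffle_xor + add
        }
        mask /= 2;
    }
}

fn main() {
    let mut lanes = vec![1.0, 2.0, 3.0, 4.0];
    butterfly_sum(&mut lanes);
    assert!(lanes.iter().all(|&v| v == 10.0)); // every lane ends with the sum
}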
+ pub fn expand( + scope: &mut Scope, + value: ExpandElementTyped, + delta: ExpandElementTyped, + ) -> ExpandElementTyped { + let output = scope.create_local(value.expand.ty); + let out = *output; + let lhs = *value.expand; + let rhs = *delta.expand; + + scope.register(Instruction::new( + Plane::ShuffleUp(crate::ir::BinaryOperator { lhs, rhs }), + out, + )); + + output.into() + } +} + +/// Perform a shuffle down operation across the plane. +/// Each unit reads the value from a unit with a higher lane ID (current_id + delta). +/// Units at the end will read from themselves if (lane_id + delta >= plane_dim). +/// +/// # Example +/// For delta=1: `[a, b, c, d] -> [b, c, d, d]` +#[allow(unused_variables)] +pub fn plane_shuffle_down(value: E, delta: u32) -> E { + unexpanded!() +} + +/// Module containing the expand function for [plane_shuffle_down()]. +pub mod plane_shuffle_down { + + use super::*; + + /// Expand method of [plane_shuffle_down()]. + pub fn expand( + scope: &mut Scope, + value: ExpandElementTyped, + delta: ExpandElementTyped, + ) -> ExpandElementTyped { + let output = scope.create_local(value.expand.ty); + let out = *output; + let lhs = *value.expand; + let rhs = *delta.expand; + + scope.register(Instruction::new( + Plane::ShuffleDown(crate::ir::BinaryOperator { lhs, rhs }), + out, + )); + + output.into() + } +} + /// Perform a reduce sum operation across all units in a plane. #[allow(unused_variables)] pub fn plane_sum(value: E) -> E { diff --git a/crates/cubecl-core/src/id.rs b/crates/cubecl-core/src/id.rs index 61ca5bca9..c4d950d21 100644 --- a/crates/cubecl-core/src/id.rs +++ b/crates/cubecl-core/src/id.rs @@ -10,10 +10,7 @@ pub struct CubeTuneId { impl CubeTuneId { /// Create a new ID. - pub fn new( - client: &ComputeClient, - device: &R::Device, - ) -> Self { + pub fn new(client: &ComputeClient, device: &R::Device) -> Self { Self { device: device.to_id(), name: R::name(client), diff --git a/crates/cubecl-core/src/post_processing/unroll.rs b/crates/cubecl-core/src/post_processing/unroll.rs index 66cf0c49c..a5b15e50b 100644 --- a/crates/cubecl-core/src/post_processing/unroll.rs +++ b/crates/cubecl-core/src/post_processing/unroll.rs @@ -44,6 +44,10 @@ impl UnrollProcessor { inst: &Instruction, mappings: &mut Mappings, ) -> TransformAction { + if matches!(inst.operation, Operation::Marker(_)) { + return TransformAction::Ignore; + } + if inst.operation.args().is_none() { // Detect unhandled ops that can't be reflected match &inst.operation { @@ -81,7 +85,7 @@ impl UnrollProcessor { } _ => return TransformAction::Ignore, }, - Operation::Branch(_) | Operation::NonSemantic(_) => { + Operation::Branch(_) | Operation::NonSemantic(_) | Operation::Marker(_) => { return TransformAction::Ignore; } _ => { @@ -505,6 +509,7 @@ impl UnrollProcessor { Instruction { out, source_loc: inst.source_loc.clone(), + modes: inst.modes, operation, } }) @@ -623,7 +628,7 @@ fn create_unrolled( allocator.create_local_mut(item) } VariableKind::LocalConst { .. 
} => allocator.create_local(item), - _ => panic!("Out must be local"), + other => panic!("Out must be local, found {other:?}"), }) .collect() }
diff --git a/crates/cubecl-core/src/prelude.rs b/crates/cubecl-core/src/prelude.rs index c7b55e862..4be59013c 100644 --- a/crates/cubecl-core/src/prelude.rs +++ b/crates/cubecl-core/src/prelude.rs @@ -24,4 +24,4 @@ pub use cubecl_runtime::server::CubeCount; pub use crate::frontend::*; pub use crate::{comment, comptime, comptime_type, derive_cube_comptime, terminate}; pub use cubecl_common::{CubeDim, ExecutionMode, flex32, tf32}; -pub use cubecl_ir::Scope; +pub use cubecl_ir::{FastMath, Scope};
diff --git a/crates/cubecl-core/src/runtime.rs b/crates/cubecl-core/src/runtime.rs index 025af4ed3..a9462705b 100644 --- a/crates/cubecl-core/src/runtime.rs +++ b/crates/cubecl-core/src/runtime.rs @@ -2,9 +2,8 @@ use crate::codegen::Compiler; use crate::compute::CubeTask; use cubecl_common::device::Device; use cubecl_ir::{StorageType, TargetProperties}; -use cubecl_runtime::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer}; +use cubecl_runtime::{client::ComputeClient, server::ComputeServer}; -pub use cubecl_runtime::channel; pub use cubecl_runtime::client; pub use cubecl_runtime::server; pub use cubecl_runtime::tune;
@@ -19,16 +18,14 @@ pub trait Runtime: Send + Sync + 'static + core::fmt::Debug { type Compiler: Compiler; /// The compute server used to run kernels and perform autotuning. type Server: ComputeServer>>; - /// The channel used to communicate with the compute server. - type Channel: ComputeChannel<Self::Server>; /// The device used to retrieve the compute client. type Device: Device; /// Retrieve the compute client from the runtime device. - fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel>; + fn client(device: &Self::Device) -> ComputeClient<Self::Server>; /// The runtime name on the given device. - fn name(client: &ComputeClient<Self::Server, Self::Channel>) -> &'static str; + fn name(client: &ComputeClient<Self::Server>) -> &'static str; /// Return true if global input array lengths should be added to kernel info. fn require_array_lengths() -> bool {
@@ -38,6 +35,11 @@ pub trait Runtime: Send + Sync + 'static + core::fmt::Debug { /// Returns the supported line sizes for the current runtime's compiler. fn supported_line_sizes() -> &'static [u8]; + /// The maximum line size that can be used for global buffer bindings. + fn max_global_line_size() -> u8 { + u8::MAX + } + /// Returns all line sizes that are useful to perform optimal IO operation on the given element. fn io_optimized_line_sizes(elem: &StorageType) -> impl Iterator<Item = u8> + Clone { let max = (LOAD_WIDTH / elem.size_bits()) as u8;
@@ -48,9 +50,15 @@ /// Returns all line sizes that are useful to perform optimal IO operation on the given element. /// Ignores native support, and allows all line sizes. This means the returned size may be /// unrolled, and may not support dynamic indexing. - fn io_optimized_line_sizes_unchecked(elem: &StorageType) -> impl Iterator<Item = u8> + Clone { - let max = LOAD_WIDTH / elem.size_bits(); - (1..max as u8).rev().filter(|v| v.is_power_of_two()) + fn io_optimized_line_sizes_unchecked(size: usize) -> impl Iterator<Item = u8> + Clone { + let size_bits = size * 8; + let max = LOAD_WIDTH / size_bits; + let max = usize::min(Self::max_global_line_size() as usize, max); + + // If the max is 8, we want to test 1, 2, 4, 8 which is log2(8) + 1.
+ let num_candidates = f32::log2(max as f32) as u32 + 1; + + (0..num_candidates).map(|i| 2u8.pow(i)).rev() } /// Returns the maximum cube count on each dimension that can be launched. diff --git a/crates/cubecl-core/src/runtime_tests/assign.rs b/crates/cubecl-core/src/runtime_tests/assign.rs index 44ef0bb1c..a1de2146f 100644 --- a/crates/cubecl-core/src/runtime_tests/assign.rs +++ b/crates/cubecl-core/src/runtime_tests/assign.rs @@ -32,7 +32,7 @@ pub fn kernel_add_assign_line(output: &mut Array>) { } pub fn test_kernel_assign_scalar( - client: ComputeClient, + client: ComputeClient, ) { let handle = client.create(F::as_bytes(&[F::new(0.0), F::new(1.0)])); @@ -52,7 +52,7 @@ pub fn test_kernel_assign_scalar( } pub fn test_kernel_add_assign_array( - client: ComputeClient, + client: ComputeClient, ) { let handle = client.create(F::as_bytes(&[F::new(0.0), F::new(1.0)])); @@ -72,7 +72,7 @@ pub fn test_kernel_add_assign_array( } pub fn test_kernel_add_assign_line( - client: ComputeClient, + client: ComputeClient, ) { let handle = client.create(F::as_bytes(&[F::new(0.0), F::new(1.0)])); diff --git a/crates/cubecl-core/src/runtime_tests/atomic.rs b/crates/cubecl-core/src/runtime_tests/atomic.rs index 99d8425d0..b5e7ebd8b 100644 --- a/crates/cubecl-core/src/runtime_tests/atomic.rs +++ b/crates/cubecl-core/src/runtime_tests/atomic.rs @@ -12,7 +12,7 @@ pub fn kernel_atomic_add(output: &mut Array>) { } fn supports_feature( - client: &ComputeClient, + client: &ComputeClient, feat: TypeUsage, ) -> bool { let ty = StorageType::Atomic(F::as_type_native_unchecked().elem_type()); @@ -20,7 +20,7 @@ fn supports_feature( } pub fn test_kernel_atomic_add( - client: ComputeClient, + client: ComputeClient, ) { if !supports_feature::(&client, TypeUsage::AtomicAdd) { println!( @@ -52,7 +52,7 @@ pub fn kernel_atomic_min(output: &mut Array>) { } pub fn test_kernel_atomic_min( - client: ComputeClient, + client: ComputeClient, ) { if !supports_feature::(&client, TypeUsage::AtomicMinMax) { println!( @@ -84,7 +84,7 @@ pub fn kernel_atomic_max(output: &mut Array>) { } pub fn test_kernel_atomic_max( - client: ComputeClient, + client: ComputeClient, ) { if !supports_feature::(&client, TypeUsage::AtomicMinMax) { println!( diff --git a/crates/cubecl-core/src/runtime_tests/barrier.rs b/crates/cubecl-core/src/runtime_tests/barrier.rs index b09595194..2106863b4 100644 --- a/crates/cubecl-core/src/runtime_tests/barrier.rs +++ b/crates/cubecl-core/src/runtime_tests/barrier.rs @@ -17,9 +17,7 @@ pub fn async_copy_test(input: &Array>, output: &mut Array( - client: ComputeClient, -) { +pub fn test_async_copy(client: ComputeClient) { if !client.properties().supports_type(SemanticType::Barrier) { // We can't execute the test, skip. return; @@ -130,9 +128,7 @@ fn two_independent_loads( output[UNIT_POS_X] = dot; } -pub fn test_memcpy_one_load( - client: ComputeClient, -) { +pub fn test_memcpy_one_load(client: ComputeClient) { if !client.properties().supports_type(SemanticType::Barrier) { // We can't execute the test, skip. return; @@ -160,7 +156,7 @@ pub fn test_memcpy_one_load( pub fn test_memcpy_two_loads( independent: bool, - client: ComputeClient, + client: ComputeClient, ) { if !client.properties().supports_type(SemanticType::Barrier) { // We can't execute the test, skip. 
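The binary-op test hunk below builds its `FastMath` set through a `LazyLock` because enumset's set operations are not `const`. A minimal standalone sketch of the same set arithmetic, using a stand-in enum rather than cubecl's `FastMath` (assumes the `enumset` crate):

use enumset::{EnumSet, EnumSetType};
use std::sync::LazyLock;

// Stand-in for cubecl's FastMath; illustration only.
#[derive(EnumSetType, Debug)]
enum Mode {
    NotNaN,
    NotInf,
    AllowTransform,
}

// Everything except NotNaN, mirroring `FastMath::all().difference(FastMath::NotNaN.into())`.
static MODES: LazyLock<EnumSet<Mode>> =
    LazyLock::new(|| EnumSet::all().difference(Mode::NotNaN.into()));

fn main() {
    assert!(!MODES.contains(Mode::NotNaN)); // NaN-producing cases remain well-defined
    assert!(MODES.contains(Mode::AllowTransform));
}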
diff --git a/crates/cubecl-core/src/runtime_tests/binary.rs b/crates/cubecl-core/src/runtime_tests/binary.rs index df88cb787..3c7a9d8e8 100644 --- a/crates/cubecl-core/src/runtime_tests/binary.rs +++ b/crates/cubecl-core/src/runtime_tests/binary.rs @@ -1,16 +1,17 @@ -use std::fmt::Display; +use std::{fmt::Display, sync::LazyLock}; use crate::{self as cubecl, as_type}; use cubecl::prelude::*; use cubecl_runtime::server::Handle; +use enumset::EnumSet; #[track_caller] pub(crate) fn assert_equals_approx< R: Runtime, F: Float + num_traits::Float + CubeElement + Display, >( - client: &ComputeClient, + client: &ComputeClient, output: Handle, expected: &[F], epsilon: f32, @@ -45,6 +46,10 @@ expected: {:?}", } } +// Needs lazy because const trait fns aren't stable +static FAST_MATH: LazyLock> = + LazyLock::new(|| FastMath::all().difference(FastMath::NotNaN.into())); + macro_rules! test_binary_impl { ( $test_name:ident, @@ -57,8 +62,8 @@ macro_rules! test_binary_impl { rhs: $rhs:expr, expected: $expected:expr }),*]) => { - pub fn $test_name(client: ComputeClient) { - #[cube(launch_unchecked, fast_math = FastMath::AllowTransform | FastMath::UnsignedZero)] + pub fn $test_name(client: ComputeClient) { + #[cube(launch_unchecked, fast_math = *FAST_MATH)] fn test_function<$float_type: Float>(lhs: &Array<$float_type>, rhs: &Array<$float_type>, output: &mut Array<$float_type>) { if ABSOLUTE_POS < rhs.len() { output[ABSOLUTE_POS] = $binary_func(lhs[ABSOLUTE_POS], rhs[ABSOLUTE_POS]); @@ -201,7 +206,7 @@ macro_rules! test_powi_impl { rhs: $rhs:expr, expected: $expected:expr }),*]) => { - pub fn $test_name(client: ComputeClient) { + pub fn $test_name(client: ComputeClient) { $( { let lhs = $lhs; @@ -270,7 +275,7 @@ macro_rules! test_mulhi_impl { rhs: $rhs:expr, expected: $expected:expr }),*]) => { - pub fn $test_name(client: ComputeClient) { + pub fn $test_name(client: ComputeClient) { $( { let lhs = $lhs; diff --git a/crates/cubecl-core/src/runtime_tests/branch.rs b/crates/cubecl-core/src/runtime_tests/branch.rs index d602676a0..466de7f2b 100644 --- a/crates/cubecl-core/src/runtime_tests/branch.rs +++ b/crates/cubecl-core/src/runtime_tests/branch.rs @@ -48,9 +48,7 @@ pub fn kernel_select(output: &mut Array, cond: u32) { } } -pub fn test_switch_statement( - client: ComputeClient, -) { +pub fn test_switch_statement(client: ComputeClient) { let handle = client.create(as_bytes![F: 0.0, 1.0]); let vectorization = 1; @@ -72,7 +70,7 @@ pub fn test_switch_statement( } pub fn test_switch_used_as_value( - client: ComputeClient, + client: ComputeClient, ) { let handle = client.create(as_bytes![F: 0.0, 1.0]); @@ -92,9 +90,7 @@ pub fn test_switch_used_as_value( assert_eq!(actual[0], F::new(3.0)); } -pub fn test_switch_default( - client: ComputeClient, -) { +pub fn test_switch_default(client: ComputeClient) { let handle = client.create(as_bytes![F: 0.0, 1.0]); let vectorization = 2; @@ -113,9 +109,7 @@ pub fn test_switch_default( assert_eq!(actual[0], F::new(5.0)); } -pub fn test_switch_or_branch( - client: ComputeClient, -) { +pub fn test_switch_or_branch(client: ComputeClient) { let handle = client.create(as_bytes![F: 0.0, 1.0]); let vectorization = 2; @@ -135,7 +129,7 @@ pub fn test_switch_or_branch( } pub fn test_select( - client: ComputeClient, + client: ComputeClient, cond: bool, ) { let handle = client.create(as_bytes![F: 0.0]); diff --git a/crates/cubecl-core/src/runtime_tests/cluster.rs b/crates/cubecl-core/src/runtime_tests/cluster.rs index bb3c995ed..3194e9fe9 100644 --- 
a/crates/cubecl-core/src/runtime_tests/cluster.rs +++ b/crates/cubecl-core/src/runtime_tests/cluster.rs @@ -20,7 +20,7 @@ fn cluster_meta_kernel(out: &mut Array) { } } -pub fn test_cluster_meta(client: ComputeClient) { +pub fn test_cluster_meta(client: ComputeClient) { if !client.properties().features.cube_cluster { return; } diff --git a/crates/cubecl-core/src/runtime_tests/cmma.rs b/crates/cubecl-core/src/runtime_tests/cmma.rs index c37f466ca..eb84e6ed3 100644 --- a/crates/cubecl-core/src/runtime_tests/cmma.rs +++ b/crates/cubecl-core/src/runtime_tests/cmma.rs @@ -301,10 +301,7 @@ pub fn cast_matrix_bf16(input: &Array, out: &mut Array) { ); } -pub fn test_simple_1_lined( - client: ComputeClient, - cube_dimensions: CubeDim, -) { +pub fn test_simple_1_lined(client: ComputeClient, cube_dimensions: CubeDim) { if !client.properties().features.cmma.contains(&MmaConfig { a_type: ElemType::Float(FloatKind::F16).into(), b_type: ElemType::Float(FloatKind::F16).into(), @@ -342,7 +339,7 @@ pub fn test_simple_1_lined( } pub fn test_simple_1_lined_offset( - client: ComputeClient, + client: ComputeClient, cube_dimensions: CubeDim, ) { if !client.properties().features.cmma.contains(&MmaConfig { @@ -399,10 +396,7 @@ pub fn test_simple_1_lined_offset( ); } -pub fn test_simple_1( - client: ComputeClient, - cube_dimensions: CubeDim, -) { +pub fn test_simple_1(client: ComputeClient, cube_dimensions: CubeDim) { if !client.properties().features.cmma.contains(&MmaConfig { a_type: ElemType::Float(FloatKind::F16).into(), b_type: ElemType::Float(FloatKind::F16).into(), @@ -466,7 +460,7 @@ pub fn test_simple_1_expected() -> Vec { } // pub fn test_simple_2( -// client: ComputeClient, +// client: ComputeClient, // cube_dimensions: CubeDim, // ) { // if !client.properties().features.cmma.contains(&MmaConfig { @@ -507,10 +501,7 @@ pub fn test_simple_1_expected() -> Vec { // assert_eq!(expected, actual); // } -pub fn test_cmma_cast_f16( - client: ComputeClient, - cube_dimensions: CubeDim, -) { +pub fn test_cmma_cast_f16(client: ComputeClient, cube_dimensions: CubeDim) { if !client.properties().features.cmma.contains(&MmaConfig { a_type: ElemType::Float(FloatKind::F16).into(), b_type: ElemType::Float(FloatKind::F16).into(), @@ -544,10 +535,7 @@ pub fn test_cmma_cast_f16( assert_eq!(actual, expected); } -pub fn test_cmma_cast_bf16( - client: ComputeClient, - cube_dimensions: CubeDim, -) { +pub fn test_cmma_cast_bf16(client: ComputeClient, cube_dimensions: CubeDim) { if !client.properties().features.cmma.contains(&MmaConfig { a_type: ElemType::Float(FloatKind::BF16).into(), b_type: ElemType::Float(FloatKind::BF16).into(), @@ -581,10 +569,7 @@ pub fn test_cmma_cast_bf16( assert_eq!(actual, expected); } -pub fn test_simple_tf32( - client: ComputeClient, - cube_dimensions: CubeDim, -) { +pub fn test_simple_tf32(client: ComputeClient, cube_dimensions: CubeDim) { if !client.properties().features.cmma.contains(&MmaConfig { a_type: ElemType::Float(FloatKind::TF32).into(), b_type: ElemType::Float(FloatKind::TF32).into(), @@ -687,10 +672,7 @@ pub fn kernel_strided( ); } -pub fn test_cmma_strided( - client: ComputeClient, - cube_dimensions: CubeDim, -) { +pub fn test_cmma_strided(client: ComputeClient, cube_dimensions: CubeDim) { // Lhs (row major) will have strided tiles let (m, n, k) = (16, 16, 32); let (t_m, t_n, t_k) = (16, 16, 16); @@ -860,7 +842,7 @@ pub fn test_cmma_manual< B: CubeElement + Numeric, CD: CubeElement + Numeric, >( - client: ComputeClient, + client: ComputeClient, cube_dimensions: CubeDim, (m, n, k): (usize, 
usize, usize), ) { @@ -1057,7 +1039,7 @@ pub fn kernel_scaled( - client: ComputeClient, + client: ComputeClient, cube_dimensions: CubeDim, (m, n, k): (usize, usize, usize), scales_factor: usize, @@ -1169,7 +1151,7 @@ pub fn test_cmma_scaled( - client: ComputeClient, + client: ComputeClient, cube_dimensions: CubeDim, (m, n, k): (usize, usize, usize), scales_factor: usize, @@ -1442,7 +1424,7 @@ macro_rules! testgen_cmma { test(16, 8, 64, 2); } - fn cube_dim(client: &ComputeClient) -> CubeDim { + fn cube_dim(client: &ComputeClient) -> CubeDim { let plane_dim = client.properties().hardware.plane_size_max; CubeDim::new(plane_dim, 1, 1) } diff --git a/crates/cubecl-core/src/runtime_tests/comparison.rs b/crates/cubecl-core/src/runtime_tests/comparison.rs index 324a77b61..f8b8c38c8 100644 --- a/crates/cubecl-core/src/runtime_tests/comparison.rs +++ b/crates/cubecl-core/src/runtime_tests/comparison.rs @@ -11,7 +11,7 @@ macro_rules! test_binary_impl { lhs: $lhs:expr, rhs: $rhs:expr, }),*]) => { - pub fn $test_name(client: ComputeClient) { + pub fn $test_name(client: ComputeClient) { #[cube(launch_unchecked, fast_math = FastMath::all())] fn test_function(lhs: &Array<$primitive_type>, rhs: &Array<$primitive_type>, output: &mut Array) { if ABSOLUTE_POS < rhs.len() { diff --git a/crates/cubecl-core/src/runtime_tests/const_match.rs b/crates/cubecl-core/src/runtime_tests/const_match.rs index c86f6a0c9..79c6b1309 100644 --- a/crates/cubecl-core/src/runtime_tests/const_match.rs +++ b/crates/cubecl-core/src/runtime_tests/const_match.rs @@ -27,7 +27,7 @@ pub fn test_kernel_const_match< F: Float + CubeElement, U: Int + hash::Hash + Eq + Debug, >( - client: ComputeClient, + client: ComputeClient, ) { let handle = client.create(as_bytes![F: 0.0, 1.0]); diff --git a/crates/cubecl-core/src/runtime_tests/constants.rs b/crates/cubecl-core/src/runtime_tests/constants.rs index 1fb551516..7aeeb2c36 100644 --- a/crates/cubecl-core/src/runtime_tests/constants.rs +++ b/crates/cubecl-core/src/runtime_tests/constants.rs @@ -10,7 +10,7 @@ fn constant_array_kernel(out: &mut Array, #[comptime] data: Vec(client: ComputeClient) { +pub fn test_constant_array(client: ComputeClient) { let handle = client.create(f32::as_bytes(&[0.0, 1.0])); let vectorization = 1; diff --git a/crates/cubecl-core/src/runtime_tests/debug.rs b/crates/cubecl-core/src/runtime_tests/debug.rs index 6320167e4..d8230db2d 100644 --- a/crates/cubecl-core/src/runtime_tests/debug.rs +++ b/crates/cubecl-core/src/runtime_tests/debug.rs @@ -13,7 +13,7 @@ fn simple_call_kernel(out: &mut Array) { } } -pub fn test_simple_call(client: ComputeClient) { +pub fn test_simple_call(client: ComputeClient) { let handle = client.create(f32::as_bytes(&[10.0, 1.0])); let vectorization = 1; @@ -43,7 +43,7 @@ fn nested_call_kernel(out: &mut Array) { } } -pub fn test_nested_call(client: ComputeClient) { +pub fn test_nested_call(client: ComputeClient) { let handle = client.create(f32::as_bytes(&[10.0, 1.0])); let vectorization = 1; @@ -71,7 +71,7 @@ fn debug_print_kernel(out: &mut Array) { } #[cfg(not(all(target_os = "macos")))] -pub fn test_debug_print(client: ComputeClient) { +pub fn test_debug_print(client: ComputeClient) { //let logger = MemoryLogger::setup(log::Level::Info); let handle = client.create(f32::as_bytes(&[10.0, 1.0])); diff --git a/crates/cubecl-core/src/runtime_tests/different_rank.rs b/crates/cubecl-core/src/runtime_tests/different_rank.rs index 5e90e8802..e38afa786 100644 --- a/crates/cubecl-core/src/runtime_tests/different_rank.rs +++ 
b/crates/cubecl-core/src/runtime_tests/different_rank.rs @@ -8,7 +8,7 @@ pub fn kernel_different_rank(lhs: &Tensor, rhs: &Tensor, output: } pub fn test_kernel_different_rank_first_biggest( - client: ComputeClient, + client: ComputeClient, ) { let shape_lhs = vec![2, 2, 2]; let shape_rhs = vec![8]; @@ -26,7 +26,7 @@ pub fn test_kernel_different_rank_first_biggest( - client: ComputeClient, + client: ComputeClient, ) { let shape_lhs = vec![2, 4]; let shape_rhs = vec![8]; @@ -44,7 +44,7 @@ pub fn test_kernel_different_rank_last_biggest( - client: ComputeClient, + client: ComputeClient, (shape_lhs, shape_rhs, shape_out): (Vec, Vec, Vec), (strides_lhs, strides_rhs, strides_out): (Vec, Vec, Vec), ) { diff --git a/crates/cubecl-core/src/runtime_tests/enums.rs b/crates/cubecl-core/src/runtime_tests/enums.rs index 21b42ad65..9a6c41f6a 100644 --- a/crates/cubecl-core/src/runtime_tests/enums.rs +++ b/crates/cubecl-core/src/runtime_tests/enums.rs @@ -69,7 +69,7 @@ pub fn kernel_scalar_enum(test: TestEnum, output: &mut Array) { }; } -pub fn test_scalar_enum(client: ComputeClient) { +pub fn test_scalar_enum(client: ComputeClient) { let array = client.empty(std::mem::size_of::()); kernel_scalar_enum::launch::( @@ -102,7 +102,7 @@ fn kernel_array_float_int(array: &mut ArrayFloatInt) { } pub fn test_array_float_int( - client: &ComputeClient, + client: &ComputeClient, expected: T, ) { let array = client.empty(std::mem::size_of::()); @@ -140,7 +140,7 @@ fn kernel_tuple_enum(first: &mut SimpleEnum>, second: SimpleEnum(client: &ComputeClient) { +pub fn test_tuple_enum(client: &ComputeClient) { let first = client.create(as_bytes![u32: 20]); let second = client.create(as_bytes![u32: 5]); diff --git a/crates/cubecl-core/src/runtime_tests/index.rs b/crates/cubecl-core/src/runtime_tests/index.rs index 3c064c20c..836895f9e 100644 --- a/crates/cubecl-core/src/runtime_tests/index.rs +++ b/crates/cubecl-core/src/runtime_tests/index.rs @@ -18,7 +18,7 @@ pub fn kernel_assign(output: &mut Array) { } pub fn test_kernel_index_scalar( - client: ComputeClient, + client: ComputeClient, ) { let handle = client.create(F::as_bytes(as_type![F: 0.0, 1.0, 123.0, 6.0])); let handle_slice = handle diff --git a/crates/cubecl-core/src/runtime_tests/launch.rs b/crates/cubecl-core/src/runtime_tests/launch.rs index efbe146b4..ef082d11c 100644 --- a/crates/cubecl-core/src/runtime_tests/launch.rs +++ b/crates/cubecl-core/src/runtime_tests/launch.rs @@ -53,7 +53,7 @@ pub fn kernel_with_max_shared( } } -pub fn test_kernel_with_comptime_tag(client: ComputeClient) { +pub fn test_kernel_with_comptime_tag(client: ComputeClient) { let handle = client.create(f32::as_bytes(&[5.0])); let array_arg = unsafe { ArrayArg::from_raw_parts::(&handle, 1, 1) }; @@ -86,7 +86,7 @@ pub fn test_kernel_with_comptime_tag(client: ComputeClient( - client: ComputeClient, + client: ComputeClient, ) { let handle = client.create(as_bytes![F: 0.0, 1.0]); @@ -103,7 +103,7 @@ pub fn test_kernel_with_generics( assert_eq!(actual[0], F::new(5.0)); } -pub fn test_kernel_without_generics(client: ComputeClient) { +pub fn test_kernel_without_generics(client: ComputeClient) { let handle = client.create(f32::as_bytes(&[0.0, 1.0])); kernel_without_generics::launch::( @@ -119,7 +119,7 @@ pub fn test_kernel_without_generics(client: ComputeClient(client: ComputeClient) { +pub fn test_kernel_max_shared(client: ComputeClient) { let total_shared_size = client.properties().hardware.max_shared_memory_size; let handle = client.create(u32::as_bytes(&[0, 1, 2, 3, 4, 5, 6, 7])); diff --git 
a/crates/cubecl-core/src/runtime_tests/line.rs b/crates/cubecl-core/src/runtime_tests/line.rs index 7f3a8efe5..003f06dbd 100644 --- a/crates/cubecl-core/src/runtime_tests/line.rs +++ b/crates/cubecl-core/src/runtime_tests/line.rs @@ -12,9 +12,7 @@ pub fn kernel_line_index(output: &mut Array, #[comptime] line_size: } #[allow(clippy::needless_range_loop)] -pub fn test_line_index( - client: ComputeClient, -) { +pub fn test_line_index(client: ComputeClient) { for line_size in R::io_optimized_line_sizes(&F::as_type_native().unwrap()) { if line_size < 4 { continue; @@ -51,7 +49,7 @@ pub fn kernel_line_index_assign(output: &mut Array>) { } pub fn test_line_index_assign( - client: ComputeClient, + client: ComputeClient, ) { for line_size in R::io_optimized_line_sizes(&F::as_type_native().unwrap()) { let handle = client.create(F::as_bytes(&vec![F::new(0.0); line_size as usize])); @@ -86,9 +84,7 @@ pub fn kernel_line_loop_unroll(output: &mut Array>, #[comptime } } -pub fn test_line_loop_unroll( - client: ComputeClient, -) { +pub fn test_line_loop_unroll(client: ComputeClient) { for line_size in R::io_optimized_line_sizes(&F::as_type_native_unchecked()) { let handle = client.create(F::as_bytes(&vec![F::new(0.0); line_size as usize])); unsafe { @@ -119,9 +115,7 @@ pub fn kernel_shared_memory(output: &mut Array>) { output[0] = smem1[0]; } -pub fn test_shared_memory( - client: ComputeClient, -) { +pub fn test_shared_memory(client: ComputeClient) { for line_size in R::io_optimized_line_sizes(&F::as_type_native().unwrap()) { let output = client.create(F::as_bytes(&vec![F::new(0.0); line_size as usize])); unsafe { @@ -155,7 +149,7 @@ macro_rules! impl_line_comparison { } pub fn [< test_line_ $cmp >] ( - client: ComputeClient, + client: ComputeClient, ) { let lhs = client.create(as_bytes![F: 0.0, 1.0, 2.0, 3.0]); let rhs = client.create(as_bytes![F: 0.0, 2.0, 1.0, 3.0]); diff --git a/crates/cubecl-core/src/runtime_tests/metadata.rs b/crates/cubecl-core/src/runtime_tests/metadata.rs index 6fbcfd7f0..a429479cb 100644 --- a/crates/cubecl-core/src/runtime_tests/metadata.rs +++ b/crates/cubecl-core/src/runtime_tests/metadata.rs @@ -79,7 +79,7 @@ pub fn kernel_buffer_len(out: &mut Tensor) { out[0] = out.buffer_len(); } -pub fn test_shape_dim_4(client: ComputeClient) { +pub fn test_shape_dim_4(client: ComputeClient) { let handle1 = client.empty(12 * core::mem::size_of::()); let handle2 = client.empty(12 * core::mem::size_of::()); let handle3 = client.empty(12 * core::mem::size_of::()); @@ -102,7 +102,7 @@ pub fn test_shape_dim_4(client: ComputeClient assert_eq!(actual, &expect); } -pub fn test_shape_different_ranks(client: ComputeClient) { +pub fn test_shape_different_ranks(client: ComputeClient) { let handle1 = client.empty(12 * core::mem::size_of::()); let handle2 = client.empty(12 * core::mem::size_of::()); let handle3 = client.empty(12 * core::mem::size_of::()); @@ -125,7 +125,7 @@ pub fn test_shape_different_ranks(client: ComputeClient(client: ComputeClient) { +pub fn test_stride_different_ranks(client: ComputeClient) { let handle1 = client.empty(9 * core::mem::size_of::()); let handle2 = client.empty(9 * core::mem::size_of::()); let handle3 = client.empty(9 * core::mem::size_of::()); @@ -148,7 +148,7 @@ pub fn test_stride_different_ranks(client: ComputeClient(client: ComputeClient) { +pub fn test_len_different_ranks(client: ComputeClient) { let handle1 = client.empty(3 * core::mem::size_of::()); let handle2 = client.empty(3 * core::mem::size_of::()); let handle3 = client.empty(3 * core::mem::size_of::()); @@ 
-171,7 +171,7 @@ pub fn test_len_different_ranks(client: ComputeClient(client: ComputeClient) { +pub fn test_buffer_len_discontiguous(client: ComputeClient) { let handle1 = client.empty(64 * core::mem::size_of::()); unsafe { @@ -189,7 +189,7 @@ pub fn test_buffer_len_discontiguous(client: ComputeClient(client: ComputeClient) { +pub fn test_buffer_len_vectorized(client: ComputeClient) { let handle1 = client.empty(32 * core::mem::size_of::()); unsafe { @@ -207,7 +207,7 @@ pub fn test_buffer_len_vectorized(client: ComputeClient(client: ComputeClient) { +pub fn test_buffer_len_offset(client: ComputeClient) { let handle1 = client.empty(256 * core::mem::size_of::()); // We use an offset of 256 bytes here because this is the default in WebGPU and // as of wgpu 22+, 256 is the value of 'min_storage_buffer_offset_alignment' for metal GPUs. diff --git a/crates/cubecl-core/src/runtime_tests/minifloat.rs b/crates/cubecl-core/src/runtime_tests/minifloat.rs index 14729b1b0..0efff263b 100644 --- a/crates/cubecl-core/src/runtime_tests/minifloat.rs +++ b/crates/cubecl-core/src/runtime_tests/minifloat.rs @@ -48,7 +48,7 @@ pub fn kernel_scale(input: &mut Array>, out: &mut Array>) #[allow(clippy::unusual_byte_groupings, reason = "Split by float components")] pub fn test_fp8( - client: ComputeClient, + client: ComputeClient, vectorization: u8, ) { if !e4m3::supported_uses(&client).contains(TypeUsage::Conversion) { @@ -93,7 +93,7 @@ pub fn test_fp8( #[allow(clippy::unusual_byte_groupings, reason = "Split by float components")] pub fn test_fp6( - client: ComputeClient, + client: ComputeClient, vectorization: u8, ) { if !e2m3::supported_uses(&client).contains(TypeUsage::Conversion) { @@ -138,7 +138,7 @@ pub fn test_fp6( #[allow(clippy::unusual_byte_groupings, reason = "Split by float components")] pub fn test_fp4( - client: ComputeClient, + client: ComputeClient, vectorization: u8, ) { if !e2m1x2::supported_uses(&client).contains(TypeUsage::Conversion) { @@ -178,7 +178,7 @@ pub fn test_fp4( assert_eq!(&actual_2[..num_out], &expected_data[..num_out]); } -pub fn test_scale(client: ComputeClient, vectorization: u8) { +pub fn test_scale(client: ComputeClient, vectorization: u8) { if !ue8m0::supported_uses(&client).contains(TypeUsage::Conversion) { println!("Unsupported, skipping"); return; diff --git a/crates/cubecl-core/src/runtime_tests/mod.rs b/crates/cubecl-core/src/runtime_tests/mod.rs index 431caaec5..f36af5620 100644 --- a/crates/cubecl-core/src/runtime_tests/mod.rs +++ b/crates/cubecl-core/src/runtime_tests/mod.rs @@ -16,6 +16,7 @@ pub mod launch; pub mod line; pub mod metadata; pub mod minifloat; +pub mod numeric; pub mod plane; pub mod saturating; pub mod sequence; @@ -129,6 +130,7 @@ macro_rules! testgen_uint { macro_rules! 
testgen_untyped { () => { cubecl_core::testgen_cmma!(); + cubecl_core::testgen_numeric!(); cubecl_core::testgen_metadata!(); cubecl_core::testgen_topology!(); diff --git a/crates/cubecl-core/src/runtime_tests/numeric.rs b/crates/cubecl-core/src/runtime_tests/numeric.rs new file mode 100644 index 000000000..f2e669b39 --- /dev/null +++ b/crates/cubecl-core/src/runtime_tests/numeric.rs @@ -0,0 +1,44 @@ +use crate::{self as cubecl}; +use cubecl::prelude::*; +use cubecl_ir::{ElemType, FloatKind}; + +#[cube(launch)] +pub fn kernel_define(array: &mut Array, #[define(N)] _elem: ElemType) { + array[UNIT_POS] += N::cast_from(5.0f32); +} + +pub fn test_kernel_define(client: ComputeClient) { + let handle = client.create(f32::as_bytes(&[f32::new(0.0), f32::new(1.0)])); + + let elem = ElemType::Float(FloatKind::F32); + + kernel_define::launch::( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::new_1d(2), + unsafe { ArrayArg::from_raw_parts_and_size(&handle, 2, 1, elem.size()) }, + elem, + ); + + let actual = client.read_one(handle); + let actual = f32::from_bytes(&actual); + + assert_eq!(actual[0], f32::new(5.0)); + assert_eq!(actual[1], f32::new(6.0)); +} + +#[allow(missing_docs)] +#[macro_export] +macro_rules! testgen_numeric { + () => { + use super::*; + use cubecl_core::prelude::*; + + #[test] + fn test_kernel_define() { + let client = TestRuntime::client(&Default::default()); + let cube_dimensions = cube_dim::(&client); + cubecl_core::runtime_tests::numeric::test_kernel_define::(client); + } + }; +} diff --git a/crates/cubecl-core/src/runtime_tests/plane.rs b/crates/cubecl-core/src/runtime_tests/plane.rs index 7b892d174..c14bf5759 100644 --- a/crates/cubecl-core/src/runtime_tests/plane.rs +++ b/crates/cubecl-core/src/runtime_tests/plane.rs @@ -118,11 +118,45 @@ pub fn kernel_ballot(output: &mut Tensor>) { } } +#[cube(launch)] +pub fn kernel_shuffle(output: &mut Tensor) { + let val = output[UNIT_POS]; + let val2 = plane_shuffle(val, 0); // All lanes read from lane 0 + + if UNIT_POS == 0 { + output[0] = val2; + } +} + +#[cube(launch)] +pub fn kernel_shuffle_xor(output: &mut Tensor) { + let val = output[UNIT_POS]; + let val2 = plane_shuffle_xor(val, 1); + + output[UNIT_POS] = val2; +} + +#[cube(launch)] +pub fn kernel_shuffle_up(output: &mut Tensor) { + let val = output[UNIT_POS]; + let val2 = plane_shuffle_up(val, 1); + + output[UNIT_POS] = val2; +} + +#[cube(launch)] +pub fn kernel_shuffle_down(output: &mut Tensor) { + let val = output[UNIT_POS]; + let val2 = plane_shuffle_down(val, 1); + + output[UNIT_POS] = val2; +} + pub fn test_plane_sum< TestRuntime: Runtime, F: Float + num_traits::Float + CubeElement + Display, >( - client: ComputeClient, + client: ComputeClient, vectorization: u8, ) { let plane_size = 32; @@ -159,7 +193,7 @@ pub fn test_plane_inclusive_sum< TestRuntime: Runtime, F: Float + num_traits::Float + CubeElement + Display, >( - client: ComputeClient, + client: ComputeClient, vectorization: u8, ) { let plane_size = 32; @@ -201,7 +235,7 @@ pub fn test_plane_exclusive_sum< TestRuntime: Runtime, F: Float + num_traits::Float + CubeElement + Display, >( - client: ComputeClient, + client: ComputeClient, vectorization: u8, ) { let plane_size = 32; @@ -243,7 +277,7 @@ pub fn test_plane_prod< TestRuntime: Runtime, F: Float + num_traits::Float + CubeElement + Display, >( - client: ComputeClient, + client: ComputeClient, vectorization: u8, ) { let plane_size = 32; @@ -285,7 +319,7 @@ pub fn test_plane_inclusive_prod< TestRuntime: Runtime, F: Float + num_traits::Float + CubeElement + Display, 
>( - client: ComputeClient, + client: ComputeClient, vectorization: u8, ) { let plane_size = 32; @@ -331,7 +365,7 @@ pub fn test_plane_exclusive_prod< TestRuntime: Runtime, F: Float + num_traits::Float + CubeElement + Display, >( - client: ComputeClient, + client: ComputeClient, vectorization: u8, ) { let plane_size = 32; @@ -377,7 +411,7 @@ pub fn test_plane_max< TestRuntime: Runtime, F: Float + num_traits::Float + CubeElement + Display, >( - client: ComputeClient, + client: ComputeClient, vectorization: u8, ) { let plane_size = 32; @@ -416,7 +450,7 @@ pub fn test_plane_min< TestRuntime: Runtime, F: Float + num_traits::Float + CubeElement + Display, >( - client: ComputeClient, + client: ComputeClient, vectorization: u8, ) { let plane_size = 32; @@ -455,7 +489,7 @@ pub fn test_plane_all< TestRuntime: Runtime, F: Float + num_traits::Float + CubeElement + Display, >( - client: ComputeClient, + client: ComputeClient, vectorization: u8, ) { let plane_size = 32; @@ -496,7 +530,7 @@ pub fn test_plane_any< TestRuntime: Runtime, F: Float + num_traits::Float + CubeElement + Display, >( - client: ComputeClient, + client: ComputeClient, vectorization: u8, ) { let plane_size = 32; @@ -533,9 +567,7 @@ pub fn test_plane_any< ); } -pub fn test_plane_ballot( - client: ComputeClient, -) { +pub fn test_plane_ballot(client: ComputeClient) { if !client.properties().features.plane.contains(Plane::Ops) { // Can't execute the test. return; @@ -563,7 +595,7 @@ pub fn test_plane_elect< TestRuntime: Runtime, F: Float + num_traits::Float + CubeElement + Display, >( - client: ComputeClient, + client: ComputeClient, vectorization: u8, ) { let plane_size = 32; @@ -595,7 +627,7 @@ pub fn test_plane_broadcast< TestRuntime: Runtime, F: Float + num_traits::Float + CubeElement + Display, >( - client: ComputeClient, + client: ComputeClient, vectorization: u8, ) { let plane_size = 32; @@ -626,6 +658,165 @@ pub fn test_plane_broadcast< ); } +pub fn test_plane_shuffle< + TestRuntime: Runtime, + F: Float + num_traits::Float + CubeElement + Display, +>( + client: ComputeClient, + vectorization: u8, +) { + let plane_size = 32; + let input: Vec = (0..plane_size * vectorization as u32) + .map(|x| x as f32) + .collect(); + let mut expected = input.clone(); + + // All lanes read from lane 0 (same as broadcast(value, 0)) + expected[..vectorization as usize].copy_from_slice(&input[..vectorization as usize]); + + let input: Vec = input.into_iter().map(|x| F::new(x)).collect(); + let expected: Vec = expected.into_iter().map(|x| F::new(x)).collect(); + + test_plane_operation::( + &input, + &expected, + vectorization, + client.clone(), + |cube_count, handle| { + kernel_shuffle::launch::( + &client, + cube_count, + CubeDim::new(plane_size, 1, 1), + handle, + ) + }, + ); +} + +pub fn test_plane_shuffle_xor< + TestRuntime: Runtime, + F: Float + num_traits::Float + CubeElement + Display, +>( + client: ComputeClient, + vectorization: u8, +) { + let plane_size = 32; + let input: Vec = (0..plane_size * vectorization as u32) + .map(|x| x as f32) + .collect(); + let mut expected = input.clone(); + + // XOR with mask=1: lane i gets value from lane (i XOR 1) + // So lane 0 <-> 1, lane 2 <-> 3, lane 4 <-> 5, etc. 
+ for lane in 0..plane_size as usize { + let partner = lane ^ 1; + for v in 0..vectorization as usize { + expected[lane * vectorization as usize + v] = + input[partner * vectorization as usize + v]; + } + } + + let input: Vec = input.into_iter().map(|x| F::new(x)).collect(); + let expected: Vec = expected.into_iter().map(|x| F::new(x)).collect(); + + test_plane_operation::( + &input, + &expected, + vectorization, + client.clone(), + |cube_count, handle| { + kernel_shuffle_xor::launch::( + &client, + cube_count, + CubeDim::new(plane_size, 1, 1), + handle, + ) + }, + ); +} + +pub fn test_plane_shuffle_up< + TestRuntime: Runtime, + F: Float + num_traits::Float + CubeElement + Display, +>( + client: ComputeClient, + vectorization: u8, +) { + let plane_size = 32; + let input: Vec = (0..plane_size * vectorization as u32) + .map(|x| x as f32) + .collect(); + let mut expected = input.clone(); + + // Shuffle up with delta=1: lane i gets value from lane (i - 1) + // Lane 0 stays the same, lanes 1..31 shift down + for lane in 1..plane_size as usize { + for v in 0..vectorization as usize { + expected[lane * vectorization as usize + v] = + input[(lane - 1) * vectorization as usize + v]; + } + } + + let input: Vec = input.into_iter().map(|x| F::new(x)).collect(); + let expected: Vec = expected.into_iter().map(|x| F::new(x)).collect(); + + test_plane_operation::( + &input, + &expected, + vectorization, + client.clone(), + |cube_count, handle| { + kernel_shuffle_up::launch::( + &client, + cube_count, + CubeDim::new(plane_size, 1, 1), + handle, + ) + }, + ); +} + +pub fn test_plane_shuffle_down< + TestRuntime: Runtime, + F: Float + num_traits::Float + CubeElement + Display, +>( + client: ComputeClient, + vectorization: u8, +) { + let plane_size = 32; + let input: Vec = (0..plane_size * vectorization as u32) + .map(|x| x as f32) + .collect(); + let mut expected = input.clone(); + + // Shuffle down with delta=1: lane i gets value from lane (i + 1) + // Lanes 0..30 shift up, lane 31 stays the same + for lane in 0..(plane_size - 1) as usize { + for v in 0..vectorization as usize { + expected[lane * vectorization as usize + v] = + input[(lane + 1) * vectorization as usize + v]; + } + } + + let input: Vec = input.into_iter().map(|x| F::new(x)).collect(); + let expected: Vec = expected.into_iter().map(|x| F::new(x)).collect(); + + test_plane_operation::( + &input, + &expected, + vectorization, + client.clone(), + |cube_count, handle| { + kernel_shuffle_down::launch::( + &client, + cube_count, + CubeDim::new(plane_size, 1, 1), + handle, + ) + }, + ); +} + fn test_plane_operation< TestRuntime: Runtime, F: Float + num_traits::Float + CubeElement + Display, @@ -634,7 +825,7 @@ fn test_plane_operation< input: &[F], expected: &[F], vectorization: u8, - client: ComputeClient, + client: ComputeClient, launch: Launch, ) where Launch: Fn(CubeCount, TensorArg<'_, TestRuntime>), @@ -911,5 +1102,85 @@ macro_rules! 
testgen_plane { let client = TestRuntime::client(&Default::default()); cubecl_core::runtime_tests::plane::test_plane_ballot::(client.clone()); } + + fn impl_test_plane_shuffle(vectorization: u8) { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::plane::test_plane_shuffle::( + client.clone(), + vectorization, + ); + } + #[test] + fn test_plane_shuffle_vec1() { + impl_test_plane_shuffle(1); + } + #[test] + fn test_plane_shuffle_vec2() { + impl_test_plane_shuffle(2); + } + #[test] + fn test_plane_shuffle_vec4() { + impl_test_plane_shuffle(4); + } + + fn impl_test_plane_shuffle_xor(vectorization: u8) { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::plane::test_plane_shuffle_xor::( + client.clone(), + vectorization, + ); + } + #[test] + fn test_plane_shuffle_xor_vec1() { + impl_test_plane_shuffle_xor(1); + } + #[test] + fn test_plane_shuffle_xor_vec2() { + impl_test_plane_shuffle_xor(2); + } + #[test] + fn test_plane_shuffle_xor_vec4() { + impl_test_plane_shuffle_xor(4); + } + + fn impl_test_plane_shuffle_up(vectorization: u8) { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::plane::test_plane_shuffle_up::( + client.clone(), + vectorization, + ); + } + #[test] + fn test_plane_shuffle_up_vec1() { + impl_test_plane_shuffle_up(1); + } + #[test] + fn test_plane_shuffle_up_vec2() { + impl_test_plane_shuffle_up(2); + } + #[test] + fn test_plane_shuffle_up_vec4() { + impl_test_plane_shuffle_up(4); + } + + fn impl_test_plane_shuffle_down(vectorization: u8) { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::plane::test_plane_shuffle_down::( + client.clone(), + vectorization, + ); + } + #[test] + fn test_plane_shuffle_down_vec1() { + impl_test_plane_shuffle_down(1); + } + #[test] + fn test_plane_shuffle_down_vec2() { + impl_test_plane_shuffle_down(2); + } + #[test] + fn test_plane_shuffle_down_vec4() { + impl_test_plane_shuffle_down(4); + } }; } diff --git a/crates/cubecl-core/src/runtime_tests/saturating.rs b/crates/cubecl-core/src/runtime_tests/saturating.rs index 2561872ec..4efa5952a 100644 --- a/crates/cubecl-core/src/runtime_tests/saturating.rs +++ b/crates/cubecl-core/src/runtime_tests/saturating.rs @@ -25,7 +25,7 @@ pub fn kernel_saturating_sub( #[allow(clippy::needless_range_loop)] pub fn test_saturating_add_unsigned( - client: ComputeClient, + client: ComputeClient, line_size: u32, ) { let lhs = vec![ @@ -64,7 +64,7 @@ pub fn test_saturating_add_unsigned( #[allow(clippy::needless_range_loop)] pub fn test_saturating_sub_unsigned( - client: ComputeClient, + client: ComputeClient, line_size: u32, ) { let lhs = vec![ @@ -99,7 +99,7 @@ pub fn test_saturating_sub_unsigned( // Signed has a lot more possible cases due to overflow/underflow #[allow(clippy::needless_range_loop)] pub fn test_saturating_add_signed( - client: ComputeClient, + client: ComputeClient, line_size: u32, ) { let lhs = vec![ @@ -180,7 +180,7 @@ pub fn test_saturating_add_signed( // Signed has a lot more possible cases due to overflow/underflow #[allow(clippy::needless_range_loop)] pub fn test_saturating_sub_signed( - client: ComputeClient, + client: ComputeClient, line_size: u32, ) { let lhs = vec![ diff --git a/crates/cubecl-core/src/runtime_tests/sequence.rs b/crates/cubecl-core/src/runtime_tests/sequence.rs index 6c0e5c292..c303db8e7 100644 --- a/crates/cubecl-core/src/runtime_tests/sequence.rs +++ b/crates/cubecl-core/src/runtime_tests/sequence.rs @@ -31,7 +31,7 @@ pub fn 
sequence_index(output: &mut Array) { } pub fn test_sequence_for_loop( - client: ComputeClient, + client: ComputeClient, ) { let handle = client.create(as_bytes![F: 0.0]); @@ -48,9 +48,7 @@ pub fn test_sequence_for_loop( assert_eq!(actual[0], F::new(5.0)); } -pub fn test_sequence_index( - client: ComputeClient, -) { +pub fn test_sequence_index(client: ComputeClient) { let handle = client.create(as_bytes![F: 0.0]); sequence_index::launch::( diff --git a/crates/cubecl-core/src/runtime_tests/slice.rs b/crates/cubecl-core/src/runtime_tests/slice.rs index 0a7eaaf69..ca0a21c3e 100644 --- a/crates/cubecl-core/src/runtime_tests/slice.rs +++ b/crates/cubecl-core/src/runtime_tests/slice.rs @@ -46,9 +46,7 @@ pub fn slice_mut_len(output: &mut Array) { } } -pub fn test_slice_select( - client: ComputeClient, -) { +pub fn test_slice_select(client: ComputeClient) { let input = client.create(as_bytes![F: 0.0, 1.0, 2.0, 3.0, 4.0]); let output = client.empty(core::mem::size_of::()); @@ -68,9 +66,7 @@ pub fn test_slice_select( assert_eq!(actual[0], F::new(2.0)); } -pub fn test_slice_len( - client: ComputeClient, -) { +pub fn test_slice_len(client: ComputeClient) { let input = client.create(as_bytes![F: 0.0, 1.0, 2.0, 3.0, 4.0]); let output = client.empty(core::mem::size_of::()); @@ -90,9 +86,7 @@ pub fn test_slice_len( assert_eq!(actual, &[2]); } -pub fn test_slice_for( - client: ComputeClient, -) { +pub fn test_slice_for(client: ComputeClient) { let input = client.create(as_bytes![F: 0.0, 1.0, 2.0, 3.0, 4.0]); let output = client.create(as_bytes![F: 0.0]); @@ -112,9 +106,7 @@ pub fn test_slice_for( assert_eq!(actual[0], F::new(5.0)); } -pub fn test_slice_mut_assign( - client: ComputeClient, -) { +pub fn test_slice_mut_assign(client: ComputeClient) { let input = client.create(as_bytes![F: 15.0]); let output = client.create(as_bytes![F: 0.0, 1.0, 2.0, 3.0, 4.0]); @@ -134,7 +126,7 @@ pub fn test_slice_mut_assign( assert_eq!(&actual[0..5], as_type![F: 0.0, 1.0, 15.0, 3.0, 4.0]); } -pub fn test_slice_mut_len(client: ComputeClient) { +pub fn test_slice_mut_len(client: ComputeClient) { let output = client.empty(core::mem::size_of::() * 4); unsafe { diff --git a/crates/cubecl-core/src/runtime_tests/stream.rs b/crates/cubecl-core/src/runtime_tests/stream.rs index 3d8905373..c111164e9 100644 --- a/crates/cubecl-core/src/runtime_tests/stream.rs +++ b/crates/cubecl-core/src/runtime_tests/stream.rs @@ -14,9 +14,7 @@ pub fn big_task(input: &Array, output: &mut Array, num_loop: u } } -pub fn test_stream( - client: ComputeClient, -) { +pub fn test_stream(client: ComputeClient) { let client_1 = unsafe { let mut c = client.clone(); c.set_stream(StreamId { value: 10000 }); diff --git a/crates/cubecl-core/src/runtime_tests/synchronization.rs b/crates/cubecl-core/src/runtime_tests/synchronization.rs index a0bad5f76..536979d68 100644 --- a/crates/cubecl-core/src/runtime_tests/synchronization.rs +++ b/crates/cubecl-core/src/runtime_tests/synchronization.rs @@ -13,7 +13,7 @@ fn kernel_test_sync_cube(buffer: &mut Array, out: &mut Array) { } } -pub fn test_sync_cube(client: ComputeClient) { +pub fn test_sync_cube(client: ComputeClient) { let handle = client.empty(32 * core::mem::size_of::()); let test = client.empty(32 * core::mem::size_of::()); @@ -52,7 +52,7 @@ fn kernel_test_finished_sync_cube(buffer: &mut Array, out: &mut Array) sync_cube(); } -pub fn test_finished_sync_cube(client: ComputeClient) { +pub fn test_finished_sync_cube(client: ComputeClient) { let handle = client.empty(32 * core::mem::size_of::()); let test = 
client.empty(32 * core::mem::size_of::()); @@ -88,7 +88,7 @@ fn kernel_test_sync_plane(out: &mut Array) { out[UNIT_POS] = shared_memory[0]; } -pub fn test_sync_plane(client: ComputeClient) { +pub fn test_sync_plane(client: ComputeClient) { if !client.properties().features.plane.contains(Plane::Sync) { // We can't execute the test, skip. return; diff --git a/crates/cubecl-core/src/runtime_tests/tensor.rs b/crates/cubecl-core/src/runtime_tests/tensor.rs index 67be69d7b..ce832a70a 100644 --- a/crates/cubecl-core/src/runtime_tests/tensor.rs +++ b/crates/cubecl-core/src/runtime_tests/tensor.rs @@ -8,7 +8,7 @@ pub fn tensor_coordinate(input: &Tensor, output: &mut Array) { output[UNIT_POS] = input.coordinate(index, dim); } -pub fn test_tensor_coordinate(client: ComputeClient) { +pub fn test_tensor_coordinate(client: ComputeClient) { let stride = [2, 1, 4]; let shape = [2, 2, 3]; diff --git a/crates/cubecl-core/src/runtime_tests/tensormap.rs b/crates/cubecl-core/src/runtime_tests/tensormap.rs index 3873544dc..fe140c0a1 100644 --- a/crates/cubecl-core/src/runtime_tests/tensormap.rs +++ b/crates/cubecl-core/src/runtime_tests/tensormap.rs @@ -106,9 +106,8 @@ fn tensormap_metadata( output_2[3] = output_2.shape(0); } -pub fn test_tensormap_load( - client: ComputeClient, -) where +pub fn test_tensormap_load(client: ComputeClient) +where <::Storage as ComputeStorage>::Resource: Debug, { if !client.properties().features.tma.contains(Tma::Base) { @@ -147,9 +146,8 @@ pub fn test_tensormap_load( assert_eq!(actual, &expected); } -pub fn test_tensormap_store( - client: ComputeClient, -) where +pub fn test_tensormap_store(client: ComputeClient) +where <::Storage as ComputeStorage>::Resource: Debug, { if !client.properties().features.tma.contains(Tma::Base) { @@ -202,7 +200,7 @@ pub fn test_tensormap_store( } pub fn test_tensormap_load_im2col( - client: ComputeClient, + client: ComputeClient, ) where <::Storage as ComputeStorage>::Resource: Debug, { @@ -289,9 +287,8 @@ pub fn test_tensormap_load_im2col( assert_eq!(actual, &expected_actual); } -pub fn test_tensormap_metadata( - client: ComputeClient, -) where +pub fn test_tensormap_metadata(client: ComputeClient) +where <::Storage as ComputeStorage>::Resource: Debug, { if !client.properties().features.tma.contains(Tma::Base) { diff --git a/crates/cubecl-core/src/runtime_tests/topology.rs b/crates/cubecl-core/src/runtime_tests/topology.rs index 51a724400..f341bbdc7 100644 --- a/crates/cubecl-core/src/runtime_tests/topology.rs +++ b/crates/cubecl-core/src/runtime_tests/topology.rs @@ -11,7 +11,7 @@ pub fn kernel_absolute_pos(output1: &mut Array) { output1[ABSOLUTE_POS] = ABSOLUTE_POS; } -pub fn test_kernel_topology_absolute_pos(client: ComputeClient) { +pub fn test_kernel_topology_absolute_pos(client: ComputeClient) { let cube_count = (3, 5, 7); let cube_dim = (16, 16, 1); diff --git a/crates/cubecl-core/src/runtime_tests/unary.rs b/crates/cubecl-core/src/runtime_tests/unary.rs index b22e26f70..f71504920 100644 --- a/crates/cubecl-core/src/runtime_tests/unary.rs +++ b/crates/cubecl-core/src/runtime_tests/unary.rs @@ -10,7 +10,7 @@ pub(crate) fn assert_equals_approx< R: Runtime, F: Float + num_traits::Float + CubeElement + Display, >( - client: &ComputeClient, + client: &ComputeClient, output: Handle, expected: &[F], epsilon: F, @@ -20,7 +20,11 @@ pub(crate) fn assert_equals_approx< for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() { assert!( - (*a - *e).abs() < epsilon || (a.is_nan() && e.is_nan()), + (*a - *e).abs() < epsilon + || (a.is_nan() && 
e.is_nan()) + || (a.is_infinite() + && e.is_infinite() + && a.is_sign_positive() == e.is_sign_positive()), "Values differ more than epsilon: actual={}, expected={}, difference={}, epsilon={} index: {} actual: {:?} @@ -65,8 +69,8 @@ macro_rules! test_unary_impl { expected: $expected:expr }),*], $epsilon:expr) => { - pub fn $test_name(client: ComputeClient) { - #[cube(launch_unchecked)] + pub fn $test_name(client: ComputeClient) { + #[cube(launch_unchecked, fast_math = FastMath::all())] fn test_function<$float_type: Float>(input: &Array<$float_type>, output: &mut Array<$float_type>) { if ABSOLUTE_POS < input.len() { output[ABSOLUTE_POS] = $unary_func(input[ABSOLUTE_POS]); @@ -108,7 +112,7 @@ macro_rules! test_unary_impl_fixed { input: $input:expr, expected: $expected:expr }),*]) => { - pub fn $test_name(client: ComputeClient) { + pub fn $test_name(client: ComputeClient) { #[cube(launch_unchecked)] fn test_function<$float_type: Float>(input: &Array<$float_type>, output: &mut Array<$out_type>) { if ABSOLUTE_POS < input.len() { @@ -153,7 +157,7 @@ macro_rules! test_unary_impl_int { input: $input:expr, expected: $expected:expr }),*]) => { - pub fn $test_name(client: ComputeClient) { + pub fn $test_name(client: ComputeClient) { #[cube(launch_unchecked)] fn test_function<$int_type: Int>(input: &Array<$int_type>, output: &mut Array<$int_type>) { if ABSOLUTE_POS < input.len() { @@ -199,7 +203,7 @@ macro_rules! test_unary_impl_int_fixed { input: $input:expr, expected: $expected:expr }),*]) => { - pub fn $test_name(client: ComputeClient) { + pub fn $test_name(client: ComputeClient) { #[cube(launch_unchecked)] fn test_function<$int_type: Int>(input: &Array<$int_type>, output: &mut Array<$out_type>) { if ABSOLUTE_POS < input.len() { @@ -506,27 +510,6 @@ test_unary_impl!(test_sqrt, F, F::sqrt, [ } ]); -test_unary_impl!(test_rsqrt, F, F::rsqrt, [ - { - input_vectorization: 1, - out_vectorization: 1, - input: as_type![F: 1., 4., 9., 16., 25.], - expected: as_type![F: 1., 0.5, 0.33333333333, 0.25, 0.2] - }, - { - input_vectorization: 2, - out_vectorization: 2, - input: as_type![F: 1., 4., 9., 16.], - expected: as_type![F: 1., 0.5, 0.33333333333, 0.25] - }, - { - input_vectorization: 4, - out_vectorization: 4, - input: as_type![F: 1., 4., 9., 16.], - expected: as_type![F: 1., 0.5, 0.33333333333, 0.25] - } -]); - test_unary_impl!(test_degrees, F, F::to_degrees, [ { input_vectorization: 1, @@ -622,6 +605,27 @@ test_unary_impl!(test_abs, F, F::abs, [ } ]); +test_unary_impl!(test_inverse_sqrt, F, F::inverse_sqrt, [ + { + input_vectorization: 1, + out_vectorization: 1, + input: as_type![F: 1.0, 4.0, 16.0, 0.25], + expected: as_type![F: 1.0, 0.5, 0.25, 2.0] + }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![F: 9.0, 25.0, 0.0625, 100.0], + expected: as_type![F: 0.333_333_34, 0.2, 4.0, 0.1] + }, + { + input_vectorization: 4, + out_vectorization: 4, + input: as_type![F: 0.0, 0.01, 64.0, 0.111111], + expected: as_type![F: f32::INFINITY, 10.0, 0.125, 3.0] + } +]); + test_unary_impl!( test_normalize, F, @@ -660,6 +664,29 @@ test_unary_impl!( ] ); +test_unary_impl!( + test_trunc, + F, + F::trunc, + [{ + input_vectorization: 1, + out_vectorization: 1, + input: as_type![F: -1.2, -1., -0., 0.], + expected: as_type![F: -1., -1., -0., 0.] 
+ }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![F: f32::NAN, 1., 1.2, 1.9], + expected: as_type![F: f32::NAN, 1., 1., 1.0] + },{ + input_vectorization: 4, + out_vectorization: 4, + input: as_type![F: -0.9, 0.2, f32::NAN, 1.99], + expected: as_type![F: -0., 0., f32::NAN, 1.] + }] +); + test_unary_impl_fixed!( test_is_nan, F, @@ -848,8 +875,9 @@ macro_rules! testgen_unary { add_test!(test_normalize); add_test!(test_magnitude); add_test!(test_sqrt); - add_test!(test_rsqrt); + add_test!(test_inverse_sqrt); add_test!(test_abs); + add_test!(test_trunc); add_test!(test_is_nan); add_test!(test_is_inf); } diff --git a/crates/cubecl-core/src/runtime_tests/unroll.rs b/crates/cubecl-core/src/runtime_tests/unroll.rs index ccef6bc62..45bf3b22f 100644 --- a/crates/cubecl-core/src/runtime_tests/unroll.rs +++ b/crates/cubecl-core/src/runtime_tests/unroll.rs @@ -14,7 +14,7 @@ pub fn unroll_add(output: &mut Array>) { let mut out = Line::empty(4u32); #[unroll] - for i in 0..4 { + for i in 0..4u32 { out[i] = c[i]; } @@ -35,9 +35,7 @@ pub fn unroll_load_store(output: &mut Array>) { output[0] = c; } -pub fn test_unroll_add( - client: ComputeClient, -) { +pub fn test_unroll_add(client: ComputeClient) { let handle = client.empty(4 * size_of::()); unroll_add::launch::( @@ -54,7 +52,7 @@ pub fn test_unroll_add( } pub fn test_unroll_load_store( - client: ComputeClient, + client: ComputeClient, ) { let handle = client.create(as_bytes!(F: 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0)); diff --git a/crates/cubecl-cpp/Cargo.toml b/crates/cubecl-cpp/Cargo.toml index 4daad86fd..18ec2b856 100644 --- a/crates/cubecl-cpp/Cargo.toml +++ b/crates/cubecl-cpp/Cargo.toml @@ -23,11 +23,11 @@ metal = [] std = ["cubecl-runtime/std", "cubecl-common/std", "cubecl-core/std"] [dependencies] -cubecl-common = { path = "../cubecl-common", version = "0.7.0", default-features = false } -cubecl-core = { path = "../cubecl-core", version = "0.7.0", default-features = false } +cubecl-common = { path = "../cubecl-common", version = "0.9.0", default-features = false } +cubecl-core = { path = "../cubecl-core", version = "0.9.0", default-features = false } # For shared allocation -cubecl-opt = { path = "../cubecl-opt", version = "0.7.0", default-features = false } -cubecl-runtime = { path = "../cubecl-runtime", version = "0.7.0", default-features = false, features = [ +cubecl-opt = { path = "../cubecl-opt", version = "0.9.0", default-features = false } +cubecl-runtime = { path = "../cubecl-runtime", version = "0.9.0", default-features = false, features = [ "channel-mutex", ] } itertools = { version = "0.14.0", default-features = false } diff --git a/crates/cubecl-cpp/src/metal/dialect.rs b/crates/cubecl-cpp/src/metal/dialect.rs index 9ba1a9316..8a0052596 100644 --- a/crates/cubecl-cpp/src/metal/dialect.rs +++ b/crates/cubecl-cpp/src/metal/dialect.rs @@ -7,8 +7,8 @@ use crate::{ self, AtomicKind, Binding, Component, CubeIndexFlags, DialectBindings, DialectCubeBuiltins, DialectIncludes, DialectInstructions, DialectProcessors, DialectTypes, DialectWarpReduceCompiler, DialectWmmaCompiler, Elem, Flags, FmtLeft, Fragment, - FragmentIdent, FragmentLayout, Instruction, Item, ManualMma, SupportedMmaCombinations, - Variable, WarpInstruction, WmmaInstruction, wmma_api_base, + FragmentIdent, FragmentLayout, Instruction, Item, ManualMma, SharedMemory, + SupportedMmaCombinations, Variable, WarpInstruction, WmmaInstruction, wmma_api_base, }, }; use cubecl_core::{ @@ -33,94 +33,104 @@ impl Dialect for MslDialect { type Architecture = 
MetalArchitecture; } +impl MslDialect { + fn warp_op_vectorized( + f: &mut core::fmt::Formatter<'_>, + input: &Variable, + out: &Variable, + simd_op_prefix: &str, + simd_op_suffix: &str, + ) -> core::fmt::Result { + let out = out.fmt_left(); + let vectorization = input.item().vectorization; + + f.write_fmt(format_args!("{out} = {} {{", input.item()))?; + + for k in 0..vectorization { + let index = if vectorization > 1 { + format!(".i_{k}") + } else { + String::new() + }; + let comma = if k + 1 < vectorization { "," } else { "" }; + + writeln!(f, "{simd_op_prefix}{input}{index}{simd_op_suffix}{comma}")?; + } + + f.write_fmt(format_args!("}};\n")) + } +} + impl DialectWarpReduceCompiler for MslDialect { fn warp_reduce_sum( f: &mut core::fmt::Formatter<'_>, input: &Variable, out: &Variable, ) -> core::fmt::Result { - let out = out.fmt_left(); - f.write_fmt(format_args!("{out} = simd_sum({input});\n")) + Self::warp_op_vectorized(f, input, out, "simd_sum(", ")") } fn warp_reduce_prod( f: &mut core::fmt::Formatter<'_>, input: &Variable, out: &Variable, ) -> core::fmt::Result { - let out = out.fmt_left(); - f.write_fmt(format_args!("{out} = simd_product({input});\n")) + Self::warp_op_vectorized(f, input, out, "simd_product(", ")") } fn warp_reduce_max( f: &mut core::fmt::Formatter<'_>, input: &Variable, out: &Variable, ) -> core::fmt::Result { - let out = out.fmt_left(); - f.write_fmt(format_args!("{out} = simd_max({input});\n")) + Self::warp_op_vectorized(f, input, out, "simd_max(", ")") } fn warp_reduce_min( f: &mut core::fmt::Formatter<'_>, input: &Variable, out: &Variable, ) -> core::fmt::Result { - let out = out.fmt_left(); - f.write_fmt(format_args!("{out} = simd_min({input});\n")) + Self::warp_op_vectorized(f, input, out, "simd_min(", ")") } fn warp_reduce_all( f: &mut core::fmt::Formatter<'_>, input: &Variable, out: &Variable, ) -> core::fmt::Result { - let out = out.fmt_left(); - f.write_fmt(format_args!("{out} = simd_and({input});\n")) + Self::warp_op_vectorized(f, input, out, "simd_and(", "? 1u : 0u) != 0u") } fn warp_reduce_any( f: &mut core::fmt::Formatter<'_>, input: &Variable, out: &Variable, ) -> core::fmt::Result { - let out = out.fmt_left(); - f.write_fmt(format_args!("{out} = simd_or({input});\n")) + Self::warp_op_vectorized(f, input, out, "simd_or(", "? 
1u : 0u) != 0u") } fn warp_reduce_sum_inclusive( f: &mut core::fmt::Formatter<'_>, input: &Variable, out: &Variable, ) -> core::fmt::Result { - let out = out.fmt_left(); - f.write_fmt(format_args!( - "{out} = simd_prefix_inclusive_sum({input});\n" - )) + Self::warp_op_vectorized(f, input, out, "simd_prefix_inclusive_sum(", ")") } fn warp_reduce_prod_inclusive( f: &mut core::fmt::Formatter<'_>, input: &Variable, out: &Variable, ) -> core::fmt::Result { - let out = out.fmt_left(); - f.write_fmt(format_args!( - "{out} = simd_prefix_inclusive_product({input});\n" - )) + Self::warp_op_vectorized(f, input, out, "simd_prefix_inclusive_product(", ")") } fn warp_reduce_sum_exclusive( f: &mut core::fmt::Formatter<'_>, input: &Variable, out: &Variable, ) -> core::fmt::Result { - let out = out.fmt_left(); - f.write_fmt(format_args!( - "{out} = simd_prefix_exclusive_sum({input});\n" - )) + Self::warp_op_vectorized(f, input, out, "simd_prefix_exclusive_sum(", ")") } fn warp_reduce_prod_exclusive( f: &mut core::fmt::Formatter<'_>, input: &Variable, out: &Variable, ) -> core::fmt::Result { - let out = out.fmt_left(); - f.write_fmt(format_args!( - "{out} = simd_prefix_exclusive_product({input});\n" - )) + Self::warp_op_vectorized(f, input, out, "simd_prefix_exclusive_product(", ")") } } @@ -316,6 +326,22 @@ struct alignas({alignment}) {item} {{" fn compile_local_memory_qualifier(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "thread") } + + fn compile_shared_memory_declaration( + f: &mut std::fmt::Formatter<'_>, + shared: &SharedMemory, + ) -> std::fmt::Result { + let item = shared.item; + let index = shared.index; + let offset = shared.offset; + let size = shared.length; + let size_bytes = size * shared.item.size() as u32; + writeln!(f, "// Shared memory size: {size}, {size_bytes} bytes")?; + writeln!( + f, + "threadgroup {item}* shared_memory_{index} = reinterpret_cast(&dynamic_shared_mem[{offset}]);" + ) + } } // Kernel argument bindings @@ -413,16 +439,8 @@ void {kernel_name}(" .map(|it| it.offset + it.size()) .max() .unwrap(); - let max_align = body - .shared_memories - .iter() - .map(|smem| smem.align) - .max() - .unwrap(); - writeln!( - f, - "threadgroup alignas({max_align}) uchar dynamic_shared_mem[{size}];", - )?; + + writeln!(f, "threadgroup uchar dynamic_shared_mem[{size}];",)?; } Ok(()) } diff --git a/crates/cubecl-cpp/src/shared/base.rs b/crates/cubecl-cpp/src/shared/base.rs index b46c68e36..5f2d368fc 100644 --- a/crates/cubecl-cpp/src/shared/base.rs +++ b/crates/cubecl-cpp/src/shared/base.rs @@ -1,19 +1,19 @@ use std::{collections::HashSet, fmt::Debug}; use cubecl_common::ExecutionMode; -use cubecl_core::CubeDim; -use cubecl_core::ir::{FloatKind, Processor, UIntKind, VariableKind}; +use cubecl_core::ir::{FloatKind, InstructionModes, Processor, UIntKind, VariableKind}; use cubecl_core::post_processing::checked_io::CheckedIoProcessor; use cubecl_core::{ Compiler, ir::{self as gpu}, }; +use cubecl_core::{CubeDim, ir::ElemType}; use cubecl_core::{ ir::{Operation, SourceLoc}, prelude::{FastMath, KernelDefinition}, }; use cubecl_opt::{Optimizer, SharedLiveness}; -use cubecl_runtime::{DeviceProperties, TypeUsage}; +use cubecl_runtime::{DeviceProperties, EnumSet, TypeUsage}; use crate::shared::MmaShape; @@ -33,6 +33,7 @@ pub struct CompilationOptions { pub warp_size: u32, pub grid_constants: bool, pub supports_clusters: bool, + pub supports_fast_math: bool, } impl Default for CompilationOptions { @@ -41,6 +42,7 @@ impl Default for CompilationOptions { warp_size: 32, grid_constants: 
false, supports_clusters: false, + supports_fast_math: false, } } } @@ -78,7 +80,6 @@ pub struct Flags { pub indexes: CubeIndexFlags, pub op_barrier: bool, pub op_pipeline: bool, - pub inst_fast_math: bool, pub inst_tma: bool, pub inst_tma_im2col: bool, pub inst_wmma: bool, @@ -172,10 +173,6 @@ impl CppCompiler { elem_bf16: self.flags.elem_bf16, elem_f16: self.flags.elem_f16, elem_tf32: self.flags.elem_tf32, - inst_fast_math: value - .options - .fp_math_mode - .contains(FastMath::ReducedPrecision), inst_tma: self.flags.inst_tma, inst_tma_im2col: self.flags.inst_tma_im2col, use_grid_constants: self.compilation_options.grid_constants, @@ -312,7 +309,9 @@ impl CppCompiler { out: self.compile_variable(out.unwrap()), })); } - gpu::Operation::Arithmetic(op) => self.compile_arithmetic(op, out, instructions), + gpu::Operation::Arithmetic(op) => { + self.compile_arithmetic(op, out, instruction.modes, instructions) + } gpu::Operation::Comparison(op) => self.compile_comparison(op, out, instructions), gpu::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions), gpu::Operation::Operator(op) => self.compile_operator(op, out, instructions), @@ -420,6 +419,34 @@ impl CppCompiler { out, })) } + gpu::Plane::Shuffle(op) => { + instructions.push(Instruction::Warp(WarpInstruction::Shuffle { + input: self.compile_variable(op.lhs), + src_lane: self.compile_variable(op.rhs), + out, + })) + } + gpu::Plane::ShuffleXor(op) => { + instructions.push(Instruction::Warp(WarpInstruction::ShuffleXor { + input: self.compile_variable(op.lhs), + mask: self.compile_variable(op.rhs), + out, + })) + } + gpu::Plane::ShuffleUp(op) => { + instructions.push(Instruction::Warp(WarpInstruction::ShuffleUp { + input: self.compile_variable(op.lhs), + delta: self.compile_variable(op.rhs), + out, + })) + } + gpu::Plane::ShuffleDown(op) => { + instructions.push(Instruction::Warp(WarpInstruction::ShuffleDown { + input: self.compile_variable(op.lhs), + delta: self.compile_variable(op.rhs), + out, + })) + } } } gpu::Operation::CoopMma(cmma) => instructions.push(self.compile_cmma(cmma, out)), @@ -600,7 +627,7 @@ impl CppCompiler { } } } - gpu::Operation::Free(_) => {} + gpu::Operation::Marker(_) => {} } } @@ -905,6 +932,7 @@ impl CppCompiler { &mut self, value: gpu::Arithmetic, out: Option, + modes: InstructionModes, instructions: &mut Vec>, ) { let out = out.unwrap(); @@ -919,7 +947,17 @@ impl CppCompiler { instructions.push(Instruction::Mul(self.compile_binary(op, out))) } gpu::Arithmetic::Div(op) => { - instructions.push(Instruction::Div(self.compile_binary(op, out))) + let op = self.compile_binary(op, out); + instructions.push(self.select_fast_float( + out.ty, + modes, + FastMath::AllowReciprocal + | FastMath::ReducedPrecision + | FastMath::UnsignedZero + | FastMath::NotInf, + Instruction::Div(op), + Instruction::FastDiv(op), + )) } gpu::Arithmetic::Sub(op) => { instructions.push(Instruction::Sub(self.compile_binary(op, out))) @@ -939,27 +977,62 @@ impl CppCompiler { instructions.push(Instruction::Abs(self.compile_unary(op, out))) } gpu::Arithmetic::Exp(op) => { - instructions.push(Instruction::Exp(self.compile_unary(op, out))) + let op = self.compile_unary(op, out); + instructions.push(self.select_fast_float( + out.ty, + modes, + FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf, + Instruction::Exp(op), + Instruction::FastExp(op), + )); } gpu::Arithmetic::Log(op) => { - instructions.push(Instruction::Log(self.compile_unary(op, out))) + let op = self.compile_unary(op, out); + 
instructions.push(self.select_fast_float( + out.ty, + modes, + FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf, + Instruction::Log(op), + Instruction::FastLog(op), + )); } gpu::Arithmetic::Log1p(op) => { instructions.push(Instruction::Log1p(self.compile_unary(op, out))) } gpu::Arithmetic::Cos(op) => { - instructions.push(Instruction::Cos(self.compile_unary(op, out))) + let op = self.compile_unary(op, out); + instructions.push(self.select_fast_float( + out.ty, + modes, + FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf, + Instruction::Cos(op), + Instruction::FastCos(op), + )); } gpu::Arithmetic::Sin(op) => { - instructions.push(Instruction::Sin(self.compile_unary(op, out))) + let op = self.compile_unary(op, out); + instructions.push(self.select_fast_float( + out.ty, + modes, + FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf, + Instruction::Sin(op), + Instruction::FastSin(op), + )); } gpu::Arithmetic::Tan(op) => { instructions.push(Instruction::Tan(self.compile_unary(op, out))) } gpu::Arithmetic::Tanh(op) => { - let instruction = Instruction::Tanh(self.compile_unary(op, out)); + let op = self.compile_unary(op, out); + let instruction = Instruction::Tanh(op); D::register_instruction_extension(&mut self.extensions, &instruction); - instructions.push(instruction) + instructions.push(self.select_fast_float( + out.ty, + modes, + FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf, + instruction, + Instruction::FastTanh(op), + )) } gpu::Arithmetic::Sinh(op) => { let instruction = Instruction::Sinh(self.compile_unary(op, out)); @@ -1017,16 +1090,37 @@ impl CppCompiler { instructions.push(instruction) } gpu::Arithmetic::Powf(op) => { - instructions.push(Instruction::Powf(self.compile_binary(op, out))) + let op = self.compile_binary(op, out); + instructions.push(self.select_fast_float( + out.ty, + modes, + FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf, + Instruction::Powf(op), + Instruction::FastPowf(op), + )) } gpu::Arithmetic::Powi(op) => { instructions.push(Instruction::Powi(self.compile_binary(op, out))) } gpu::Arithmetic::Sqrt(op) => { - instructions.push(Instruction::Sqrt(self.compile_unary(op, out))) - } - gpu::Arithmetic::Rsqrt(op) => { - instructions.push(Instruction::Rsqrt(self.compile_unary(op, out))) + let op = self.compile_unary(op, out); + instructions.push(self.select_fast_float( + out.ty, + modes, + FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf, + Instruction::Sqrt(op), + Instruction::FastSqrt(op), + )) + } + gpu::Arithmetic::InverseSqrt(op) => { + let op = self.compile_unary(op, out); + instructions.push(self.select_fast_float( + out.ty, + modes, + FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf, + Instruction::InverseSqrt(op), + Instruction::FastInverseSqrt(op), + )) } gpu::Arithmetic::Erf(op) => { let instruction = Instruction::Erf(self.compile_unary(op, out)); @@ -1051,18 +1145,31 @@ impl CppCompiler { }), gpu::Arithmetic::Recip(op) => { let elem = op.input.ty.elem_type(); + let input = self.compile_variable(op.input); + let out = self.compile_variable(out); let lhs = match elem { gpu::ElemType::Float(kind) => gpu::ConstantScalarValue::Float(1.0, kind), gpu::ElemType::Int(kind) => gpu::ConstantScalarValue::Int(1, kind), gpu::ElemType::UInt(kind) => gpu::ConstantScalarValue::UInt(1, kind), gpu::ElemType::Bool => gpu::ConstantScalarValue::Bool(true), }; - - instructions.push(Instruction::Div(BinaryInstruction { + let div = Instruction::Div(BinaryInstruction { lhs: 
Variable::ConstantScalar(lhs, self.compile_elem(elem)), - rhs: self.compile_variable(op.input), - out: self.compile_variable(out), - })) + rhs: input, + out, + }); + let recip = Instruction::FastRecip(UnaryInstruction { input, out }); + + instructions.push(self.select_fast_float( + elem.into(), + modes, + FastMath::AllowReciprocal + | FastMath::ReducedPrecision + | FastMath::UnsignedZero + | FastMath::NotInf, + div, + recip, + )) } gpu::Arithmetic::Round(op) => { instructions.push(Instruction::Round(self.compile_unary(op, out))) @@ -1073,6 +1180,9 @@ impl CppCompiler { gpu::Arithmetic::Ceil(op) => { instructions.push(Instruction::Ceil(self.compile_unary(op, out))) } + gpu::Arithmetic::Trunc(op) => { + instructions.push(Instruction::Trunc(self.compile_unary(op, out))) + } gpu::Arithmetic::Remainder(op) => { instructions.push(Instruction::Remainder(self.compile_binary(op, out))) } @@ -1086,10 +1196,24 @@ impl CppCompiler { instructions.push(Instruction::Neg(self.compile_unary(op, out))) } gpu::Arithmetic::Normalize(op) => { - instructions.push(Instruction::Normalize(self.compile_unary(op, out))) + let op = self.compile_unary(op, out); + instructions.push(self.select_fast_float( + out.ty, + modes, + FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf, + Instruction::Normalize(op), + Instruction::FastNormalize(op), + )) } gpu::Arithmetic::Magnitude(op) => { - instructions.push(Instruction::Magnitude(self.compile_unary(op, out))) + let op = self.compile_unary(op, out); + instructions.push(self.select_fast_float( + out.ty, + modes, + FastMath::ReducedPrecision | FastMath::NotNaN | FastMath::NotInf, + Instruction::Magnitude(op), + Instruction::FastMagnitude(op), + )) } gpu::Arithmetic::Dot(op) => { instructions.push(Instruction::Dot(self.compile_binary(op, out))) @@ -1097,6 +1221,27 @@ impl CppCompiler { }; } + fn select_fast_float( + &self, + ty: gpu::Type, + modes: InstructionModes, + required_flags: EnumSet, + default: Instruction, + fast: Instruction, + ) -> Instruction { + if !self.compilation_options.supports_fast_math + || !matches!(ty.elem_type(), ElemType::Float(FloatKind::F32)) + { + return default; + } + + if modes.fp_math_mode.is_superset(required_flags) { + fast + } else { + default + } + } + fn compile_comparison( &mut self, value: gpu::Comparison, diff --git a/crates/cubecl-cpp/src/shared/binary.rs b/crates/cubecl-cpp/src/shared/binary.rs index 3f7c28f3b..35e918c63 100644 --- a/crates/cubecl-cpp/src/shared/binary.rs +++ b/crates/cubecl-cpp/src/shared/binary.rs @@ -129,6 +129,20 @@ operator!(BitwiseXor, "^"); operator!(Or, "||"); operator!(And, "&&"); +pub struct FastDiv; + +impl Binary for FastDiv { + fn format_scalar( + f: &mut std::fmt::Formatter<'_>, + lhs: Lhs, + rhs: Rhs, + _out_item: Item, + ) -> std::fmt::Result { + // f32 only + write!(f, "__fdividef({lhs}, {rhs})") + } +} + pub struct HiMul; impl Binary for HiMul { @@ -248,6 +262,20 @@ impl Binary for Powf { } } +pub struct FastPowf; + +impl Binary for FastPowf { + // Only executed for f32 + fn format_scalar( + f: &mut std::fmt::Formatter<'_>, + lhs: Lhs, + rhs: Rhs, + _item: Item, + ) -> std::fmt::Result { + write!(f, "__powf({lhs}, {rhs})") + } +} + pub struct Powi; impl Binary for Powi { diff --git a/crates/cubecl-cpp/src/shared/instruction.rs b/crates/cubecl-cpp/src/shared/instruction.rs index d7a12d20c..1329cd7e5 100644 --- a/crates/cubecl-cpp/src/shared/instruction.rs +++ b/crates/cubecl-cpp/src/shared/instruction.rs @@ -13,7 +13,7 @@ use std::{ pub(crate) const INFO_NAME: &str = "info"; pub(crate) const 
STATIC_INFO_NAME: &str = "static_info"; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy)] pub struct BinaryInstruction { pub lhs: Variable, pub rhs: Variable, @@ -36,7 +36,7 @@ pub struct IndexAssignInstruction { pub out: Variable, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy)] pub struct UnaryInstruction { pub input: Variable, pub out: Variable, @@ -78,6 +78,8 @@ pub enum Instruction { out: Variable, }, Div(BinaryInstruction), + FastDiv(BinaryInstruction), + FastRecip(UnaryInstruction), Mul(BinaryInstruction), Sub(BinaryInstruction), SaturatingSub(BinaryInstruction), @@ -160,9 +162,12 @@ pub enum Instruction { FindFirstSet(UnaryInstruction), Abs(UnaryInstruction), Exp(UnaryInstruction), + FastExp(UnaryInstruction), Log(UnaryInstruction), + FastLog(UnaryInstruction), Log1p(UnaryInstruction), Cos(UnaryInstruction), + FastCos(UnaryInstruction), Sin(UnaryInstruction), Tan(UnaryInstruction), Tanh(UnaryInstruction), @@ -177,10 +182,15 @@ pub enum Instruction { Degrees(UnaryInstruction), Radians(UnaryInstruction), ArcTan2(BinaryInstruction), + FastSin(UnaryInstruction), + FastTanh(UnaryInstruction), Powf(BinaryInstruction), + FastPowf(BinaryInstruction), Powi(BinaryInstruction), Sqrt(UnaryInstruction), - Rsqrt(UnaryInstruction), + FastSqrt(UnaryInstruction), + InverseSqrt(UnaryInstruction), + FastInverseSqrt(UnaryInstruction), Min(BinaryInstruction), Max(BinaryInstruction), Not(UnaryInstruction), @@ -213,6 +223,7 @@ pub enum Instruction { }, Round(UnaryInstruction), Ceil(UnaryInstruction), + Trunc(UnaryInstruction), Floor(UnaryInstruction), Warp(WarpInstruction), Wmma(WmmaInstruction), @@ -235,7 +246,9 @@ pub enum Instruction { }, Neg(UnaryInstruction), Magnitude(UnaryInstruction), + FastMagnitude(UnaryInstruction), Normalize(UnaryInstruction), + FastNormalize(UnaryInstruction), Dot(BinaryInstruction), Copy { input: Variable, @@ -324,6 +337,8 @@ impl Display for Instruction { } Instruction::Mul(it) => Mul::format(f, &it.lhs, &it.rhs, &it.out), Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out), + Instruction::FastDiv(it) => FastDiv::format(f, &it.lhs, &it.rhs, &it.out), + Instruction::FastRecip(it) => FastRecip::format(f, &it.input, &it.out), Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out), Instruction::SaturatingSub(it) => SaturatingSub::format(f, &it.lhs, &it.rhs, &it.out), Instruction::HiMul(it) => HiMul::format(f, &it.lhs, &it.rhs, &it.out), @@ -524,9 +539,12 @@ for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{ Instruction::Erf(it) => Erf::format(f, &it.input, &it.out), Instruction::Abs(it) => Abs::format(f, &it.input, &it.out), Instruction::Exp(it) => Exp::format(f, &it.input, &it.out), + Instruction::FastExp(it) => FastExp::format(f, &it.input, &it.out), Instruction::Log(it) => Log::format(f, &it.input, &it.out), + Instruction::FastLog(it) => FastLog::format(f, &it.input, &it.out), Instruction::Log1p(it) => Log1p::format(f, &it.input, &it.out), Instruction::Cos(it) => Cos::format(f, &it.input, &it.out), + Instruction::FastCos(it) => FastCos::format(f, &it.input, &it.out), Instruction::Sin(it) => Sin::format(f, &it.input, &it.out), Instruction::Tan(it) => Tan::format(f, &it.input, &it.out), Instruction::Tanh(it) => Tanh::format(f, &it.input, &it.out), @@ -541,10 +559,15 @@ for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{ Instruction::Degrees(it) => Degrees::format(f, &it.input, &it.out), Instruction::Radians(it) => Radians::format(f, &it.input, &it.out), Instruction::ArcTan2(it) => ArcTan2::format(f, &it.lhs, &it.rhs, 
&it.out), + Instruction::FastSin(it) => FastSin::format(f, &it.input, &it.out), + Instruction::FastTanh(it) => FastTanh::format(f, &it.input, &it.out), Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out), + Instruction::FastPowf(it) => FastPowf::format(f, &it.lhs, &it.rhs, &it.out), Instruction::Powi(it) => Powi::format(f, &it.lhs, &it.rhs, &it.out), Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out), - Instruction::Rsqrt(it) => Rsqrt::format(f, &it.input, &it.out), + Instruction::FastSqrt(it) => FastSqrt::format(f, &it.input, &it.out), + Instruction::InverseSqrt(it) => InverseSqrt::format(f, &it.input, &it.out), + Instruction::FastInverseSqrt(it) => FastInverseSqrt::format(f, &it.input, &it.out), Instruction::Max(it) => Max::format(f, &it.lhs, &it.rhs, &it.out), Instruction::Min(it) => Min::format(f, &it.lhs, &it.rhs, &it.out), Instruction::Not(it) => Not::format(f, &it.input, &it.out), @@ -564,6 +587,7 @@ for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{ Instruction::ThreadFence => f.write_str("__threadfence();\n"), Instruction::Round(it) => Round::format(f, &it.input, &it.out), Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out), + Instruction::Trunc(it) => Trunc::format(f, &it.input, &it.out), Instruction::Floor(it) => Floor::format(f, &it.input, &it.out), Instruction::SliceLength { input, out } => { let out = out.fmt_left(); @@ -635,8 +659,16 @@ for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{ let out = out.fmt_left(); writeln!(f, "{out} = -{input};") } - Instruction::Normalize(inst) => Normalize::format(f, &inst.input, &inst.out), - Instruction::Magnitude(inst) => Magnitude::format(f, &inst.input, &inst.out), + Instruction::Normalize(inst) => { + Normalize::::format(f, &inst.input, &inst.out) + } + Instruction::FastNormalize(inst) => { + Normalize::::format(f, &inst.input, &inst.out) + } + Instruction::Magnitude(inst) => Magnitude::::format(f, &inst.input, &inst.out), + Instruction::FastMagnitude(inst) => { + Magnitude::::format(f, &inst.input, &inst.out) + } Instruction::Dot(inst) => Dot::format(f, &inst.lhs, &inst.rhs, &inst.out), Instruction::VecInit { inputs, out } => { let item = out.item(); @@ -928,11 +960,12 @@ impl Remainder { } } -struct Magnitude { +struct Magnitude> { _dialect: PhantomData, + _sqrt: PhantomData, } -impl Magnitude { +impl> Magnitude { fn format( f: &mut core::fmt::Formatter<'_>, input: &Variable, @@ -952,16 +985,17 @@ impl Magnitude { let out = out.fmt_left(); write!(f, "{out} = ")?; - Sqrt::format_unary(f, &mag, elem)?; + S::format_unary(f, &mag, elem)?; f.write_str(";\n") } } -struct Normalize { +struct Normalize> { _dialect: PhantomData, + _rsqrt: PhantomData, } -impl Normalize { +impl> Normalize { fn format( f: &mut core::fmt::Formatter<'_>, input: &Variable, @@ -981,17 +1015,17 @@ impl Normalize { } write!(f, "{norm} = ")?; - Sqrt::format_unary(f, &norm, elem)?; + InvS::format_unary(f, &norm, elem)?; f.write_str(";\n")?; if num == 1 { - writeln!(f, "{out} = {input} / {norm};") + writeln!(f, "{out} = {input} * {norm};") } else { write!(f, "{out} = {out_item}{{")?; for i in 0..num { let input_i = input.index(i); - writeln!(f, "{input_i} / {norm},")?; + writeln!(f, "{input_i} * {norm},")?; } f.write_str("};\n") diff --git a/crates/cubecl-cpp/src/shared/unary.rs b/crates/cubecl-cpp/src/shared/unary.rs index 6a55feeb9..1e6457c75 100644 --- a/crates/cubecl-cpp/src/shared/unary.rs +++ b/crates/cubecl-cpp/src/shared/unary.rs @@ -149,23 +149,33 @@ macro_rules! 
function { } function!(Log, "log"); -function!(Cos, "cos"); +function!(FastLog, "__logf", false); function!(Sin, "sin"); +function!(Cos, "cos"); function!(Tan, "tan"); function!(Sinh, "sinh", false); function!(Cosh, "cosh", false); +// Tanh is handled separately below; unclear why. function!(ArcCos, "acos", false); function!(ArcSin, "asin", false); function!(ArcTan, "atan", false); function!(ArcSinh, "asinh", false); function!(ArcCosh, "acosh", false); function!(ArcTanh, "atanh", false); +function!(FastSin, "__sinf", false); +function!(FastCos, "__cosf", false); function!(Sqrt, "sqrt"); -function!(Rsqrt, "rsqrt"); +function!(InverseSqrt, "rsqrt"); +function!(FastSqrt, "__fsqrt_rn", false); +function!(FastInverseSqrt, "__frsqrt_rn", false); function!(Exp, "exp"); +function!(FastExp, "__expf", false); function!(Ceil, "ceil"); +function!(Trunc, "trunc"); function!(Floor, "floor"); function!(Round, "rint"); +function!(FastRecip, "__frcp_rn", false); +function!(FastTanh, "__tanhf", false); function!(Erf, "erf", false); function!(Abs, "abs", false); diff --git a/crates/cubecl-cpp/src/shared/warp.rs b/crates/cubecl-cpp/src/shared/warp.rs index 1ec101ff1..efe1b3928 100644 --- a/crates/cubecl-cpp/src/shared/warp.rs +++ b/crates/cubecl-cpp/src/shared/warp.rs @@ -58,6 +58,26 @@ pub enum WarpInstruction { id: Variable, out: Variable, }, + Shuffle { + input: Variable, + src_lane: Variable, + out: Variable, + }, + ShuffleXor { + input: Variable, + mask: Variable, + out: Variable, + }, + ShuffleUp { + input: Variable, + delta: Variable, + out: Variable, + }, + ShuffleDown { + input: Variable, + delta: Variable, + out: Variable, + }, } impl Display for WarpInstruction { @@ -98,6 +118,67 @@ impl Display for WarpInstruction { writeln!(f, ", 0, 0, 0 }};") } WarpInstruction::Broadcast { input, id, out } => reduce_broadcast(f, input, out, id), + WarpInstruction::Shuffle { + input, + src_lane, + out, + } => { + let out_fmt = out.fmt_left(); + write!(f, "{out_fmt} = {{ ")?; + for i in 0..input.item().vectorization { + let comma = if i > 0 { ", " } else { "" }; + write!(f, "{comma}")?; + D::compile_warp_shuffle( + f, + &format!("{}", input.index(i)), + &format!("{src_lane}"), + )?; + } + writeln!(f, " }};") + } + WarpInstruction::ShuffleXor { input, mask, out } => { + let out_fmt = out.fmt_left(); + write!(f, "{out_fmt} = {{ ")?; + for i in 0..input.item().vectorization { + let comma = if i > 0 { ", " } else { "" }; + write!(f, "{comma}")?; + D::compile_warp_shuffle_xor( + f, + &format!("{}", input.index(i)), + input.item().elem(), + &format!("{mask}"), + )?; + } + writeln!(f, " }};") + } + WarpInstruction::ShuffleUp { input, delta, out } => { + let out_fmt = out.fmt_left(); + write!(f, "{out_fmt} = {{ ")?; + for i in 0..input.item().vectorization { + let comma = if i > 0 { ", " } else { "" }; + write!(f, "{comma}")?; + D::compile_warp_shuffle_up( + f, + &format!("{}", input.index(i)), + &format!("{delta}"), + )?; + } + writeln!(f, " }};") + } + WarpInstruction::ShuffleDown { input, delta, out } => { + let out_fmt = out.fmt_left(); + write!(f, "{out_fmt} = {{ ")?; + for i in 0..input.item().vectorization { + let comma = if i > 0 { ", " } else { "" }; + write!(f, "{comma}")?; + D::compile_warp_shuffle_down( + f, + &format!("{}", input.index(i)), + &format!("{delta}"), + )?; + } + writeln!(f, " }};") + } WarpInstruction::Elect { out } => write!( f, " diff --git a/crates/cubecl-cpu/Cargo.toml b/crates/cubecl-cpu/Cargo.toml index bad3798a5..4f595eec4 100644 --- a/crates/cubecl-cpu/Cargo.toml +++ b/crates/cubecl-cpu/Cargo.toml @@ 
-73,20 +73,20 @@ matmul_tests_all = [ conv_tests = ["cubecl-convolution/conv_tests"] [dependencies] -cubecl-common = { path = "../cubecl-common", version = "0.7.0", default-features = false } -cubecl-std = { path = "../cubecl-std", version = "0.7.0", default-features = false } -cubecl-core = { path = "../cubecl-core", version = "0.7.0", default-features = false } -cubecl-runtime = { path = "../cubecl-runtime", version = "0.7.0", default-features = false, features = [ +cubecl-common = { path = "../cubecl-common", version = "0.9.0", default-features = false } +cubecl-std = { path = "../cubecl-std", version = "0.9.0", default-features = false } +cubecl-core = { path = "../cubecl-core", version = "0.9.0", default-features = false } +cubecl-runtime = { path = "../cubecl-runtime", version = "0.9.0", default-features = false, features = [ "channel-mutex", ] } -cubecl-opt = { path = "../cubecl-opt", version = "0.7.0", default-features = false } -cubecl-matmul = { path = "../cubecl-matmul", version = "0.7.0", features = [ +cubecl-opt = { path = "../cubecl-opt", version = "0.9.0", default-features = false } +cubecl-matmul = { path = "../cubecl-matmul", version = "0.9.0", features = [ "export_tests", ] } -cubecl-convolution = { path = "../cubecl-convolution", version = "0.7.0", features = [ +cubecl-convolution = { path = "../cubecl-convolution", version = "0.9.0", features = [ "export_tests", ] } -cubecl-reduce = { path = "../cubecl-reduce", version = "0.7.0", features = [ +cubecl-reduce = { path = "../cubecl-reduce", version = "0.9.0", features = [ "export_tests", ] } @@ -99,20 +99,20 @@ sysinfo = { workspace = true } tracel-llvm = { workspace = true } [dev-dependencies] -cubecl-core = { path = "../cubecl-core", version = "0.7.0", features = [ +cubecl-core = { path = "../cubecl-core", version = "0.9.0", features = [ "export_tests", ] } -cubecl-reduce = { path = "../cubecl-reduce", version = "0.7.0", features = [ +cubecl-reduce = { path = "../cubecl-reduce", version = "0.9.0", features = [ "export_tests", ] } -cubecl-random = { path = "../cubecl-random", version = "0.7.0", features = [ +cubecl-random = { path = "../cubecl-random", version = "0.9.0", features = [ "export_tests", ] } -cubecl-std = { path = "../cubecl-std", version = "0.7.0", features = [ +cubecl-std = { path = "../cubecl-std", version = "0.9.0", features = [ "export_tests", ] } paste = { workspace = true } pretty_assertions = { workspace = true } [build-dependencies] -tracel-llvm-bundler = { version = "20.1.4-5" } \ No newline at end of file +tracel-llvm-bundler = { version = "20.1.4-5" } diff --git a/crates/cubecl-cpu/src/compiler/mlir_data.rs b/crates/cubecl-cpu/src/compiler/mlir_data.rs index f12b17744..a41773059 100644 --- a/crates/cubecl-cpu/src/compiler/mlir_data.rs +++ b/crates/cubecl-cpu/src/compiler/mlir_data.rs @@ -41,6 +41,7 @@ impl MlirData { bindings: Bindings, shared_memories: &SharedMemories, memory_management: &mut MemoryManagement, + memory_management_shared_memory: &mut MemoryManagement, ) -> Self { let Bindings { buffers, @@ -90,17 +91,22 @@ impl MlirData { } let stream_id = StreamId::current(); + let mut smem_handles = Vec::with_capacity(shared_memories.0.len()); for shared_memory in shared_memories.0.iter() { let length = (shared_memory.ty.size() * shared_memory.length as usize) as u64; - let handle = memory_management.reserve(length).unwrap(); + let handle = memory_management_shared_memory.reserve(length).unwrap(); + smem_handles.push(handle.clone()); + let b = Handle::new(handle, None, None, stream_id, 0, 
length).binding(); - let mut handle = memory_management + let mut handle = memory_management_shared_memory .get_resource(b.memory, b.offset_start, b.offset_end) .expect("Failed to find resource"); let ptr = handle.write(); let line_memref = LineMemRef::new(ptr); push_undirected(line_memref); } + // It is important to make sure multiple shared memories don't share the same handle. + core::mem::drop(smem_handles); let ptr = shared_mlir_data.metadata.as_mut(); let line_memref = LineMemRef::new(ptr); diff --git a/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs b/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs index 423fac127..ab456cf3a 100644 --- a/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs +++ b/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs @@ -104,6 +104,15 @@ impl<'a> Visitor<'a> { )); self.insert_variable(out, result); } + Arithmetic::Trunc(trunc) => { + let value = self.get_variable(trunc.input); + let result = self.append_operation_with_result(llvm_ods::intr_trunc( + self.context, + value, + self.location, + )); + self.insert_variable(out, result); + } Arithmetic::Clamp(clamp) => { let value = self.get_variable(clamp.input); let mut min = self.get_variable(clamp.min_value); @@ -479,15 +488,6 @@ impl<'a> Visitor<'a> { )); self.insert_variable(out, output); } - Arithmetic::Rsqrt(rsqrt) => { - let input = self.get_variable(rsqrt.input); - let output = self.append_operation_with_result(math_ods::rsqrt( - self.context, - input, - self.location, - )); - self.insert_variable(out, output); - } Arithmetic::Sin(sin) => { let input = self.get_variable(sin.input); let output = self.append_operation_with_result(llvm_ods::intr_sin( @@ -506,6 +506,18 @@ impl<'a> Visitor<'a> { )); self.insert_variable(out, result); } + Arithmetic::InverseSqrt(sqrt) => { + let input = self.get_variable(sqrt.input); + let value = self.append_operation_with_result(llvm_ods::intr_sqrt( + self.context, + input, + self.location, + )); + let one = self.create_float_constant_from_item(sqrt.input.ty, 1.0); + let recip = + self.append_operation_with_result(arith::divf(one, value, self.location)); + self.insert_variable(out, recip); + } Arithmetic::Sqrt(sqrt) => { let input = self.get_variable(sqrt.input); let output = self.append_operation_with_result(llvm_ods::intr_sqrt( diff --git a/crates/cubecl-cpu/src/compiler/visitor/operation/mod.rs b/crates/cubecl-cpu/src/compiler/visitor/operation/mod.rs index 184af70e5..e87925789 100644 --- a/crates/cubecl-cpu/src/compiler/visitor/operation/mod.rs +++ b/crates/cubecl-cpu/src/compiler/visitor/operation/mod.rs @@ -98,7 +98,7 @@ impl<'a> Visitor<'a> { Operation::Branch(_) => { unreachable!("Branch operation are removed in SSA form"); } - Operation::Synchronization(_) | Operation::NonSemantic(_) | Operation::Free(_) => { + Operation::Synchronization(_) | Operation::NonSemantic(_) | Operation::Marker(_) => { unreachable!("{operation} doesn't have an out"); } } diff --git a/crates/cubecl-cpu/src/compute/scheduler.rs b/crates/cubecl-cpu/src/compute/scheduler.rs index 99450c1dc..ac6da2873 100644 --- a/crates/cubecl-cpu/src/compute/scheduler.rs +++ b/crates/cubecl-cpu/src/compute/scheduler.rs @@ -49,6 +49,7 @@ impl Scheduler { bindings: Bindings, kind: ExecutionMode, memory_management: &mut MemoryManagement, + memory_management_shared_memory: &mut MemoryManagement, ) { let kernel = self .compilation_cache @@ -65,8 +66,12 @@ impl Scheduler { let cube_dim_size = cube_dim.num_elems(); let mlir_engine = kernel.repr.clone().unwrap(); 
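// For context (not part of this patch): the CPU runtime now threads a second, dedicated
// `MemoryManagement` through the scheduler so shared-memory blocks come from their own
// pool, and `MlirData::new` collects every reserved handle in `smem_handles` until all
// reservations are done. A minimal, self-contained sketch of the hazard this avoids,
// using a hypothetical `Pool` stand-in rather than CubeCL's real allocator API: a pool
// may recycle a block as soon as its last handle is dropped, so reserving in a loop
// without holding the handles can alias every shared memory onto the same block.
use std::rc::Rc;

struct Pool { free: Vec<usize>, next: usize }

impl Pool {
    fn reserve(&mut self) -> Rc<usize> {
        // Reuse a freed block if one exists, otherwise grow the pool.
        let id = self.free.pop().unwrap_or_else(|| { self.next += 1; self.next - 1 });
        Rc::new(id)
    }
    fn release(&mut self, handle: Rc<usize>) {
        // A block becomes reusable once its last handle is gone.
        if let Ok(id) = Rc::try_unwrap(handle) { self.free.push(id); }
    }
}

fn main() {
    let mut pool = Pool { free: vec![], next: 0 };

    // Dropping each handle right away hands the same block out twice.
    let a = pool.reserve();
    pool.release(a);
    let b = pool.reserve();
    assert_eq!(*b, 0); // block 0 reused immediately
    pool.release(b);

    // Holding all handles until the loop finishes yields distinct blocks,
    // mirroring `smem_handles` plus the final `core::mem::drop(smem_handles)`.
    let handles: Vec<_> = (0..2).map(|_| pool.reserve()).collect();
    assert_ne!(*handles[0], *handles[1]);
    drop(handles);
}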
- let mut mlir_data = - MlirData::new(bindings, &mlir_engine.0.shared_memories, memory_management); + let mut mlir_data = MlirData::new( + bindings, + &mlir_engine.0.shared_memories, + memory_management, + memory_management_shared_memory, + ); mlir_data.builtin.set_cube_dim(cube_dim); mlir_data.builtin.set_cube_count(cube_count); diff --git a/crates/cubecl-cpu/src/compute/server.rs b/crates/cubecl-cpu/src/compute/server.rs index 9b51c3b33..4ecdfc794 100644 --- a/crates/cubecl-cpu/src/compute/server.rs +++ b/crates/cubecl-cpu/src/compute/server.rs @@ -7,7 +7,7 @@ use cubecl_core::{ future::DynFut, server::{ Allocation, AllocationDescriptor, Binding, Bindings, ComputeServer, CopyDescriptor, Handle, - IoError, ProfileError, ProfilingToken, ServerCommunication, + IoError, ProfileError, ProfilingToken, ServerCommunication, ServerUtilities, }, }; use cubecl_runtime::{ @@ -25,13 +25,13 @@ use super::scheduler::Scheduler; pub struct CpuServer { ctx: CpuContext, scheduler: Scheduler, - logger: Arc, + utilities: Arc>, } impl CpuServer { - pub fn new(ctx: CpuContext) -> Self { + pub fn new(ctx: CpuContext, utilities: ServerUtilities) -> Self { Self { - logger: Arc::new(ServerLogger::default()), + utilities: Arc::new(utilities), scheduler: Scheduler::default(), ctx, } @@ -41,13 +41,18 @@ impl CpuServer { #[derive(Debug)] pub struct CpuContext { memory_management: MemoryManagement, + memory_management_shared_memory: MemoryManagement, timestamps: TimestampProfiler, } impl CpuContext { - pub fn new(memory_management: MemoryManagement) -> Self { + pub fn new( + memory_management: MemoryManagement, + memory_management_shared_memory: MemoryManagement, + ) -> Self { Self { memory_management, + memory_management_shared_memory, timestamps: TimestampProfiler::default(), } } @@ -88,7 +93,11 @@ impl ComputeServer for CpuServer { type Info = (); fn logger(&self) -> Arc { - self.logger.clone() + self.utilities.logger.clone() + } + + fn utilities(&self) -> Arc> { + self.utilities.clone() } fn create( @@ -181,13 +190,14 @@ impl ComputeServer for CpuServer { bindings, kind, &mut self.ctx.memory_management, + &mut self.ctx.memory_management_shared_memory, ); } fn flush(&mut self, _stream_id: StreamId) {} fn sync(&mut self, _stream_id: StreamId) -> DynFut<()> { - self.logger.profile_summary(); + self.utilities.logger.profile_summary(); Box::pin(async move {}) } @@ -201,7 +211,7 @@ impl ComputeServer for CpuServer { stream_id: StreamId, token: ProfilingToken, ) -> Result { - self.logger.profile_summary(); + self.utilities.logger.profile_summary(); cubecl_common::future::block_on(self.sync(stream_id)); self.ctx.timestamps.stop(token) } diff --git a/crates/cubecl-cpu/src/lib.rs b/crates/cubecl-cpu/src/lib.rs index 6a995030a..0a96f2f57 100644 --- a/crates/cubecl-cpu/src/lib.rs +++ b/crates/cubecl-cpu/src/lib.rs @@ -11,6 +11,7 @@ mod tests { cubecl_core::testgen_all!(f32: [f16, f32, f64], i32: [i8, i16, i32, i64], u32: [u8, u16, u32, u64]); cubecl_std::testgen!(); cubecl_std::testgen_tensor_identity!([f16, f32, u32]); + cubecl_std::testgen_quantized_view!(f32); cubecl_random::testgen_random!(); cubecl_matmul::testgen_matmul_simple!([f16, f32]); cubecl_matmul::testgen_matmul_unit!(); diff --git a/crates/cubecl-cpu/src/runtime.rs b/crates/cubecl-cpu/src/runtime.rs index d084693e7..31a68b164 100644 --- a/crates/cubecl-cpu/src/runtime.rs +++ b/crates/cubecl-cpu/src/runtime.rs @@ -1,13 +1,16 @@ -use cubecl_common::profile::TimingMethod; +use cubecl_common::{device::DeviceState, profile::TimingMethod}; use cubecl_core::{ 
CubeCount, CubeDim, MemoryConfiguration, Runtime, - channel::MpscComputeChannel, client::ComputeClient, ir::{StorageType, TargetProperties}, + server::ServerUtilities, }; use cubecl_runtime::{ - ComputeRuntime, DeviceProperties, - memory_management::{HardwareProperties, MemoryDeviceProperties, MemoryManagement}, + DeviceProperties, + logging::ServerLogger, + memory_management::{ + HardwareProperties, MemoryDeviceProperties, MemoryManagement, MemoryManagementOptions, + }, storage::BytesStorage, }; use cubecl_std::tensor::is_contiguous; @@ -30,69 +33,78 @@ pub struct RuntimeOptions { #[derive(Debug)] pub struct CpuRuntime; -static RUNTIME: ComputeRuntime = ComputeRuntime::new(); - pub type CpuCompiler = MlirCompiler; -type Server = CpuServer; -type Channel = MpscComputeChannel; - -fn create_client(options: RuntimeOptions) -> ComputeClient { - let max_cube_dim = CubeDim::new(u32::MAX, u32::MAX, u32::MAX); - let max_cube_count = CubeCount::Static(64, 64, 64); - let system = System::new_all(); - let max_shared_memory_size = system - .cgroup_limits() - .map(|g| g.total_memory) - .unwrap_or(system.total_memory()) as usize; - - let topology = HardwareProperties { - plane_size_min: 1, - plane_size_max: 1, - max_bindings: u32::MAX, - max_shared_memory_size, - max_cube_count, - max_units_per_cube: u32::MAX, - max_cube_dim, - num_streaming_multiprocessors: None, - num_tensor_cores: None, - min_tensor_cores_dim: None, - }; - let storage = BytesStorage::default(); - - const ALIGNMENT: u64 = 4; - let mem_properties = MemoryDeviceProperties { - max_page_size: max_shared_memory_size as u64, - alignment: ALIGNMENT, - }; - - let memory_management = - MemoryManagement::from_configuration(storage, &mem_properties, options.memory_config); - let mut device_props = DeviceProperties::new( - Default::default(), - mem_properties, - topology, - TimingMethod::Device, - ); - register_supported_types(&mut device_props); - - let ctx = CpuContext::new(memory_management); - let server = CpuServer::new(ctx); - ComputeClient::new(Channel::new(server), device_props, ()) +impl DeviceState for CpuServer { + fn init(_device_id: cubecl_common::device::DeviceId) -> Self { + let options = RuntimeOptions::default(); + let max_cube_dim = CubeDim::new(u32::MAX, u32::MAX, u32::MAX); + let max_cube_count = CubeCount::Static(64, 64, 64); + let system = System::new_all(); + let max_shared_memory_size = system + .cgroup_limits() + .map(|g| g.total_memory) + .unwrap_or(system.total_memory()) as usize; + let logger = cubecl_common::stub::Arc::new(ServerLogger::default()); + + let topology = HardwareProperties { + plane_size_min: 1, + plane_size_max: 1, + max_bindings: u32::MAX, + max_shared_memory_size, + max_cube_count, + max_units_per_cube: u32::MAX, + max_cube_dim, + num_streaming_multiprocessors: None, + num_tensor_cores: None, + min_tensor_cores_dim: None, + }; + + const ALIGNMENT: u64 = 4; + let mem_properties = MemoryDeviceProperties { + max_page_size: max_shared_memory_size as u64, + alignment: ALIGNMENT, + }; + + let memory_management = MemoryManagement::from_configuration( + BytesStorage::default(), + &mem_properties, + options.memory_config, + logger.clone(), + MemoryManagementOptions::new("Main CPU"), + ); + let memory_management_shared_memory = MemoryManagement::from_configuration( + BytesStorage::default(), + &mem_properties, + MemoryConfiguration::ExclusivePages, + logger.clone(), + MemoryManagementOptions::new("Shared Memory"), + ); + + let mut device_props = DeviceProperties::new( + Default::default(), + mem_properties, + 
topology, + TimingMethod::Device, + ); + register_supported_types(&mut device_props); + + let ctx = CpuContext::new(memory_management, memory_management_shared_memory); + let utilities = ServerUtilities::new(device_props, logger, ()); + CpuServer::new(ctx, utilities) + } } impl Runtime for CpuRuntime { type Compiler = CpuCompiler; type Server = CpuServer; - - type Channel = Channel; type Device = CpuDevice; - fn client(_device: &Self::Device) -> ComputeClient { - RUNTIME.client(_device, move || create_client(RuntimeOptions::default())) + fn client(device: &Self::Device) -> ComputeClient { + ComputeClient::load(device) } - fn name(_client: &ComputeClient) -> &'static str { + fn name(_client: &ComputeClient) -> &'static str { "cpu" } @@ -106,8 +118,9 @@ impl Runtime for CpuRuntime { supported.iter().filter(move |v| **v <= max).cloned() } - fn io_optimized_line_sizes_unchecked(elem: &StorageType) -> impl Iterator + Clone { - let max = LOAD_WIDTH / elem.size_bits(); + fn io_optimized_line_sizes_unchecked(elem_size: usize) -> impl Iterator + Clone { + let elem_size_bits = elem_size * 8; + let max = LOAD_WIDTH / elem_size_bits; (1..max as u8).rev().filter(|v| v.is_power_of_two()) } diff --git a/crates/cubecl-cuda/Cargo.toml b/crates/cubecl-cuda/Cargo.toml index d6c95d283..5fed9208d 100644 --- a/crates/cubecl-cuda/Cargo.toml +++ b/crates/cubecl-cuda/Cargo.toml @@ -17,14 +17,12 @@ default = [ "cubecl-common/default", "cubecl-core/default", "cudarc/dynamic-loading", - "cuda-12050", + "cudarc/cuda-version-from-build-system", + "cudarc/fallback-latest", ] ptx-wmma = [] std = ["cubecl-runtime/std", "cubecl-common/std", "cubecl-core/std"] -cuda-12050 = ["cudarc/cuda-12050"] -cuda-12080 = ["cudarc/cuda-12080"] - attention_tests = ["cubecl-attention/attention_tests"] conv_tests = ["cubecl-convolution/conv_tests"] matmul_tests_all = [ @@ -80,14 +78,14 @@ matmul_tests_unit = ["cubecl-matmul/matmul_tests_unit"] matmul_tests_vecmat = ["cubecl-matmul/matmul_tests_vecmat"] [dependencies] -cubecl-common = { path = "../cubecl-common", version = "0.7.0", default-features = false, features = [ +cubecl-common = { path = "../cubecl-common", version = "0.9.0", default-features = false, features = [ "cache", ] } -cubecl-core = { path = "../cubecl-core", version = "0.7.0", default-features = false } -cubecl-cpp = { path = "../cubecl-cpp", version = "0.7.0", default-features = false, features = [ +cubecl-core = { path = "../cubecl-core", version = "0.9.0", default-features = false } +cubecl-cpp = { path = "../cubecl-cpp", version = "0.9.0", default-features = false, features = [ "cuda", ] } -cubecl-runtime = { path = "../cubecl-runtime", version = "0.7.0", default-features = false, features = [ +cubecl-runtime = { path = "../cubecl-runtime", version = "0.9.0", default-features = false, features = [ "channel-mutex", ] } @@ -100,30 +98,33 @@ log = { workspace = true } serde = { workspace = true } [dev-dependencies] -cubecl-attention = { path = "../cubecl-attention", version = "0.7.0", features = [ +cubecl-attention = { path = "../cubecl-attention", version = "0.9.0", features = [ "export_tests", ] } -cubecl-convolution = { path = "../cubecl-convolution", version = "0.7.0", features = [ +cubecl-convolution = { path = "../cubecl-convolution", version = "0.9.0", features = [ "export_tests", ] } -cubecl-core = { path = "../cubecl-core", version = "0.7.0", features = [ +cubecl-core = { path = "../cubecl-core", version = "0.9.0", features = [ "export_tests", ] } -cubecl-matmul = { path = "../cubecl-matmul", version = "0.7.0", 
features = [ +cubecl-matmul = { path = "../cubecl-matmul", version = "0.9.0", features = [ "export_tests", ] } -cubecl-quant = { path = "../cubecl-quant", version = "0.7.0", features = [ +cubecl-quant = { path = "../cubecl-quant", version = "0.9.0", features = [ "export_tests", "kernels", ] } -cubecl-random = { path = "../cubecl-random", version = "0.7.0", features = [ +cubecl-random = { path = "../cubecl-random", version = "0.9.0", features = [ "export_tests", ] } -cubecl-reduce = { path = "../cubecl-reduce", version = "0.7.0", features = [ +cubecl-reduce = { path = "../cubecl-reduce", version = "0.9.0", features = [ "export_tests", ] } -cubecl-std = { path = "../cubecl-std", version = "0.7.0", features = [ +cubecl-std = { path = "../cubecl-std", version = "0.9.0", features = [ "export_tests", ] } paste = { workspace = true } pretty_assertions = { workspace = true } + +[build-dependencies] +cudarc = { workspace = true } diff --git a/crates/cubecl-cuda/build.rs b/crates/cubecl-cuda/build.rs new file mode 100644 index 000000000..d802e63c2 --- /dev/null +++ b/crates/cubecl-cuda/build.rs @@ -0,0 +1,13 @@ +use cudarc::driver::sys::CUDA_VERSION; + +fn main() { + println!("cargo::rustc-check-cfg=cfg(cuda_12050)"); + println!("cargo::rustc-check-cfg=cfg(cuda_12080)"); + + if CUDA_VERSION >= 12050 { + println!("cargo:rustc-cfg=cuda_12050"); + } + if CUDA_VERSION >= 12080 { + println!("cargo:rustc-cfg=cuda_12080"); + } +} diff --git a/crates/cubecl-cuda/src/compute/command.rs b/crates/cubecl-cuda/src/compute/command.rs index 920f7ff98..2c8177d19 100644 --- a/crates/cubecl-cuda/src/compute/command.rs +++ b/crates/cubecl-cuda/src/compute/command.rs @@ -423,7 +423,14 @@ pub(crate) unsafe fn write_to_gpu( dstPitch: pitch, WidthInBytes: width_bytes, Height: dim_y, - ..Default::default() + srcXInBytes: Default::default(), + srcY: Default::default(), + srcDevice: Default::default(), + srcArray: Default::default(), + dstXInBytes: Default::default(), + dstY: Default::default(), + dstHost: Default::default(), + dstArray: Default::default(), }; unsafe { @@ -474,7 +481,14 @@ pub(crate) unsafe fn write_to_cpu( dstPitch: width_bytes, WidthInBytes: width_bytes, Height: dim_y, - ..Default::default() + srcXInBytes: Default::default(), + srcY: Default::default(), + srcArray: Default::default(), + dstXInBytes: Default::default(), + dstY: Default::default(), + dstArray: Default::default(), + srcHost: Default::default(), + dstDevice: Default::default(), }; unsafe { diff --git a/crates/cubecl-cuda/src/compute/context.rs b/crates/cubecl-cuda/src/compute/context.rs index 7a525e152..a3a7f3e2a 100644 --- a/crates/cubecl-cuda/src/compute/context.rs +++ b/crates/cubecl-cuda/src/compute/context.rs @@ -117,7 +117,6 @@ impl CudaContext { let compute_kernel = kernel_compiled.repr.as_ref().unwrap(); let cube_dim = kernel_compiled.cube_dim; - let fast_math = compute_kernel.flags.inst_fast_math; let arch = if self.arch.version >= 90 { format!("--gpu-architecture=sm_{}a", self.arch) } else { @@ -129,9 +128,6 @@ impl CudaContext { let cccl_include_path = cccl_include_path(); let cccl_include_option = format!("--include-path={}", cccl_include_path.to_str().unwrap()); let mut options = vec![arch.as_str(), include_option.as_str(), "-lineinfo"]; - if fast_math { - options.push("--use_fast_math"); - } if cccl_include_path.exists() { options.push(&cccl_include_option); } diff --git a/crates/cubecl-cuda/src/compute/server.rs b/crates/cubecl-cuda/src/compute/server.rs index fb5cd26fe..baa281157 100644 --- 
a/crates/cubecl-cuda/src/compute/server.rs +++ b/crates/cubecl-cuda/src/compute/server.rs @@ -5,8 +5,7 @@ use crate::compute::context::CudaContext; use crate::compute::stream::CudaStreamBackend; use crate::compute::sync::Fence; use cubecl_common::{bytes::Bytes, profile::ProfileDuration, stream_id::StreamId}; -use cubecl_core::ir::{ElemType, IntKind, UIntKind}; -use cubecl_core::server::{Binding, ServerCommunication}; +use cubecl_core::server::{Binding, ServerCommunication, ServerUtilities}; use cubecl_core::{MemoryConfiguration, prelude::*}; use cubecl_core::{compute::CubeTask, server::IoError}; use cubecl_core::{ @@ -21,6 +20,10 @@ use cubecl_core::{ ir::StorageType, server::{Allocation, AllocationDescriptor, ProfileError, ProfilingToken}, }; +use cubecl_core::{ + ir::{ElemType, IntKind, UIntKind}, + server::TensorMapMeta, +}; use cubecl_runtime::config::GlobalConfig; use cubecl_runtime::logging::ServerLogger; use cubecl_runtime::memory_management::{MemoryAllocationMode, offset_handles}; @@ -45,6 +48,7 @@ pub struct CudaServer { streams: MultiStream, peer_activated: bool, mem_alignment: usize, + utilities: Arc>, } unsafe impl Send for CudaServer {} @@ -58,6 +62,10 @@ impl ComputeServer for CudaServer { self.streams.logger.clone() } + fn utilities(&self) -> Arc> { + self.utilities.clone() + } + fn read( &mut self, descriptors: Vec>, @@ -231,10 +239,7 @@ impl ComputeServer for CudaServer { .resource(binding) .expect("Tensor map resource exists."); let device_ptr = resource.ptr as *mut c_void; - debug_assert!( - (device_ptr as usize).is_multiple_of(16), - "Tensor pointer must be 16 byte aligned" - ); + let mut map_ptr = MaybeUninit::zeroed(); let shape: Vec<_> = map.shape.iter().rev().map(|s| *s as u64).collect(); @@ -247,16 +252,17 @@ impl ComputeServer for CudaServer { .collect(); let elem_stride: Vec<_> = map.elem_stride.iter().rev().map(|s| *s as u32).collect(); - debug_assert!( - strides.iter().all(|it| it % 16 == 0), - "Strides must be 16 byte aligned" - ); + if cfg!(debug_assertions) { + check_tma_generic(&map, device_ptr, &shape, &strides, &elem_stride); + } match &map.format { TensorMapFormat::Tiled { tile_size } => unsafe { - debug_assert_eq!(tile_size.len(), map.rank, "Tile shape should match rank"); let tile_size: Vec<_> = tile_size.iter().rev().copied().collect(); - println!("ptr: {:x}", resource.ptr); + + if cfg!(debug_assertions) { + check_tma_tiled(&map, &tile_size); + } cuTensorMapEncodeTiled( map_ptr.as_mut_ptr(), @@ -281,14 +287,21 @@ impl ComputeServer for CudaServer { channels_per_pixel, pixels_per_column, } => unsafe { - debug_assert_eq!(pixel_box_lower_corner.len(), map.rank - 2); - debug_assert_eq!(pixel_box_upper_corner.len(), map.rank - 2); - let lower_corner: Vec<_> = pixel_box_lower_corner.iter().rev().copied().collect(); let upper_corner: Vec<_> = pixel_box_upper_corner.iter().rev().copied().collect(); + if cfg!(debug_assertions) { + check_tma_im2col( + &map, + &lower_corner, + &upper_corner, + *channels_per_pixel, + *pixels_per_column, + ); + } + cuTensorMapEncodeIm2col( map_ptr.as_mut_ptr(), elem_to_tensor_map_type(map.storage_ty), @@ -309,13 +322,16 @@ impl ComputeServer for CudaServer { .result() .unwrap() }, - #[cfg(feature = "cuda-12080")] + #[cfg(cuda_12080)] TensorMapFormat::Im2colWide { pixel_box_lower_corner_width, pixel_box_upper_corner_width, channels_per_pixel, pixels_per_column, } => unsafe { + use cudarc::driver::sys::{ + CUtensorMapIm2ColWideMode, cuTensorMapEncodeIm2colWide, + }; cuTensorMapEncodeIm2colWide( map_ptr.as_mut_ptr(), 
elem_to_tensor_map_type(map.storage_ty), @@ -337,7 +353,7 @@ impl ComputeServer for CudaServer { .result() .unwrap() }, - #[cfg(not(feature = "cuda-12080"))] + #[cfg(not(cuda_12080))] TensorMapFormat::Im2colWide { pixel_box_lower_corner_width: _, pixel_box_upper_corner_width: _, @@ -449,6 +465,7 @@ impl CudaServer { mem_config: MemoryConfiguration, mem_alignment: usize, device_id: i32, + utilities: ServerUtilities, ) -> Self { let config = GlobalConfig::get(); let max_streams = config.streaming.max_streams; @@ -469,10 +486,16 @@ impl CudaServer { ctx, peer_activated, streams: MultiStream::new( - Arc::new(ServerLogger::default()), - CudaStreamBackend::new(mem_props, mem_config, mem_alignment), + utilities.logger.clone(), + CudaStreamBackend::new( + mem_props, + mem_config, + mem_alignment, + utilities.logger.clone(), + ), max_streams, ), + utilities: Arc::new(utilities), } } @@ -623,7 +646,7 @@ fn elem_to_tensor_map_type(ty: StorageType) -> CUtensorMapDataType { match ty { // packed fp4 should be treated as single 4-bit values to simplify indexing/shape handling // So a tile of width 16 with fp4 elements is 8 x fp4x2 elements wide. - #[cfg(feature = "cuda-12080")] + #[cfg(cuda_12080)] StorageType::Packed(ty, 2) if ty.size_bits() == 4 => CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B, StorageType::Scalar(ElemType::Float(kind)) => match kind { // There's no special handling for FP8, so load as u8. `0u8 == 0.0` when reinterpreting. @@ -718,6 +741,160 @@ pub fn valid_strides(shape: &[usize], strides: &[usize]) -> bool { true } +fn check_tma_generic( + map: &TensorMapMeta, + device_ptr: *mut c_void, + shape: &[u64], + strides: &[u64], + elem_strides: &[u32], +) { + // globalAddress invariants + assert!( + (device_ptr as usize).is_multiple_of(16), + "Tensor pointer must be 16 byte aligned" + ); + if !matches!(map.interleave, TensorMapInterleave::None) { + assert!( + (device_ptr as usize).is_multiple_of(32), + "Tensor pointer must be 32 byte aligned" + ); + } + + // tensorRank invariants + assert!((1..=5).contains(&map.rank), "Rank must be between 1 and 5"); + assert!( + matches!(map.interleave, TensorMapInterleave::None) || map.rank >= 3, + "When interleave is enabled, rank must be >= 3" + ); + + // globalDim invariants + assert!( + shape.iter().all(|it| *it <= u32::MAX as u64), + "Shape must be <= u32::MAX" + ); + #[cfg(cuda_12080)] + if matches!(map.storage_ty, StorageType::Packed(ty, 2) if ty.size_bits() == 4) { + assert!( + shape[0].is_multiple_of(2), + "Packed tensor map must have multiple of 2 for the innermost dimension" + ); + } + + // globalStrides invariants + assert!( + strides.iter().all(|it| it.is_multiple_of(16)), + "Strides must be 16 byte aligned" + ); + if matches!(map.interleave, TensorMapInterleave::B32) { + assert!( + strides.iter().all(|it| it.is_multiple_of(32)), + "Strides must be 32 byte aligned when interleave is B32" + ); + } + + // elementStrides invariants + assert!( + elem_strides.iter().all(|it| *it > 0 && *it <= 8), + "Element strides must be non-zero and <= 8" + ); + if matches!(map.interleave, TensorMapInterleave::None) { + assert_eq!( + elem_strides[0], 1, + "Innermost element stride is ignored without interleaving" + ); + } + + // oobFill invariants + if matches!(map.oob_fill, OobFill::NaN) { + assert!( + map.storage_ty.is_float(), + "NaN fill is only supported for float types" + ); + } +} + +fn check_tma_tiled(map: &TensorMapMeta, tile_size: &[u32]) { + assert_eq!(tile_size.len(), map.rank, "Tile shape should match rank"); + assert!( + tile_size.iter().all(|it| *it > 0 
&& *it <= 256), + "Tile shape must be non-zero and <= 256" + ); + let tile_size_0_bytes = tile_size[0] as usize * map.storage_ty.size(); + if matches!(map.interleave, TensorMapInterleave::None) { + let align = match map.swizzle { + TensorMapSwizzle::None => 16, + TensorMapSwizzle::B32 => 32, + TensorMapSwizzle::B64 => 64, + TensorMapSwizzle::B128 => 128, + }; + assert!( + tile_size_0_bytes.is_multiple_of(align), + "Innermost tile dimension must be aligned to swizzle size" + ); + } + if matches!(map.interleave, TensorMapInterleave::B32) { + assert_eq!( + map.swizzle, + TensorMapSwizzle::B32, + "If interleave is B32, swizzle must be B32" + ); + } +} + +fn check_tma_im2col( + map: &TensorMapMeta, + lower_corner: &[i32], + upper_corner: &[i32], + channels_per_pixel: u32, + pixels_per_column: u32, +) { + assert_eq!( + lower_corner.len(), + map.rank - 2, + "Lower corner must be rank - 2 elements" + ); + assert_eq!( + upper_corner.len(), + map.rank - 2, + "Upper corner must be rank - 2 elements" + ); + + assert!( + map.rank >= 3 && map.rank <= 5, + "im2col requires rank to be between 3 and 5" + ); + + let (range_lower, range_upper) = match map.rank { + 3 => (-32768, 32767), + 4 => (-128, 127), + 5 => (-16, 15), + _ => unreachable!(), + }; + assert!( + lower_corner + .iter() + .all(|it| *it >= range_lower && *it <= range_upper), + "Lower corner must be in range [{range_lower}, {range_upper}] for {}D im2col", + map.rank + ); + assert!( + upper_corner + .iter() + .all(|it| *it >= range_lower && *it <= range_upper), + "Upper corner must be in range [{range_lower}, {range_upper}] for {}D im2col", + map.rank + ); + + assert!( + channels_per_pixel <= 256, + "Channels per pixel must be <= 256" + ); + assert!( + pixels_per_column <= 1024, + "Pixels per column must be <= 1024" + ); +} + use cudarc::driver::sys::cudaError_enum::CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED; use cudarc::driver::sys::cudaError_enum::CUDA_SUCCESS; diff --git a/crates/cubecl-cuda/src/compute/stream.rs b/crates/cubecl-cuda/src/compute/stream.rs index d26169904..0a73b9932 100644 --- a/crates/cubecl-cuda/src/compute/stream.rs +++ b/crates/cubecl-cuda/src/compute/stream.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use crate::compute::{ storage::{ cpu::{PINNED_MEMORY_ALIGNMENT, PinnedMemoryStorage}, @@ -7,7 +9,10 @@ use crate::compute::{ }; use cubecl_core::MemoryConfiguration; use cubecl_runtime::{ - memory_management::{MemoryDeviceProperties, MemoryManagement}, + logging::ServerLogger, + memory_management::{ + MemoryAllocationMode, MemoryDeviceProperties, MemoryManagement, MemoryManagementOptions, + }, stream::EventStreamBackend, }; @@ -23,6 +28,7 @@ pub struct CudaStreamBackend { mem_props: MemoryDeviceProperties, mem_config: MemoryConfiguration, mem_alignment: usize, + logger: Arc, } impl EventStreamBackend for CudaStreamBackend { @@ -36,8 +42,13 @@ impl EventStreamBackend for CudaStreamBackend { .expect("Can create a new stream."); let storage = GpuStorage::new(self.mem_alignment, stream); - let memory_management_gpu = - MemoryManagement::from_configuration(storage, &self.mem_props, self.mem_config.clone()); + let memory_management_gpu = MemoryManagement::from_configuration( + storage, + &self.mem_props, + self.mem_config.clone(), + self.logger.clone(), + MemoryManagementOptions::new("Main GPU Memory"), + ); // We use the same page size and memory pools configuration for CPU pinned memory, since we // expect the CPU to have at least the same amount of RAM as GPU memory. 
let memory_management_cpu = MemoryManagement::from_configuration( @@ -47,6 +58,8 @@ impl EventStreamBackend for CudaStreamBackend { alignment: PINNED_MEMORY_ALIGNMENT as u64, }, self.mem_config.clone(), + self.logger.clone(), + MemoryManagementOptions::new("Pinned CPU Memory").mode(MemoryAllocationMode::Auto), ); Stream { diff --git a/crates/cubecl-cuda/src/lib.rs b/crates/cubecl-cuda/src/lib.rs index ecd351dc1..2e47eab67 100644 --- a/crates/cubecl-cuda/src/lib.rs +++ b/crates/cubecl-cuda/src/lib.rs @@ -87,6 +87,7 @@ mod tests { // TODO: re-instate matmul quantized tests cubecl_matmul::testgen_matmul_simple!([f16, bf16, f32]); cubecl_std::testgen_tensor_identity!([f16, bf16, f32, u32]); + cubecl_std::testgen_quantized_view!(f16); cubecl_convolution::testgen_conv2d_accelerated!([f16: f16, bf16: bf16, f32: tf32]); cubecl_reduce::testgen_reduce!([f16, bf16, f32, f64]); cubecl_random::testgen_random!(); diff --git a/crates/cubecl-cuda/src/runtime.rs b/crates/cubecl-cuda/src/runtime.rs index 3def05c97..1da4b061a 100644 --- a/crates/cubecl-cuda/src/runtime.rs +++ b/crates/cubecl-cuda/src/runtime.rs @@ -3,13 +3,17 @@ use crate::{ compute::{CudaServer, context::CudaContext, valid_strides}, device::CudaDevice, }; -use cubecl_common::profile::TimingMethod; +use cubecl_common::{ + device::{Device, DeviceState}, + profile::TimingMethod, +}; use cubecl_core::{ CubeCount, CubeDim, MemoryConfiguration, Runtime, ir::{ ElemType, FloatKind, MatrixLayout, MmaProperties, SemanticType, StorageType, TargetProperties, }, + server::ServerUtilities, }; use cubecl_cpp::{ DialectWmmaCompiler, @@ -21,13 +25,13 @@ use cubecl_cpp::{ }, }; use cubecl_runtime::{ - ComputeRuntime, DeviceProperties, Plane, Tma, TypeUsage, - channel::MutexComputeChannel, + DeviceProperties, Plane, Tma, TypeUsage, client::ComputeClient, + logging::ServerLogger, memory_management::{HardwareProperties, MemoryDeviceProperties}, }; use cudarc::driver::sys::cuDeviceTotalMem_v2; -use std::mem::MaybeUninit; +use std::{mem::MaybeUninit, sync::Arc}; /// Options configuring the CUDA runtime. #[derive(Default)] @@ -39,207 +43,216 @@ pub struct RuntimeOptions { #[derive(Debug)] pub struct CudaRuntime; -type Server = CudaServer; -type Channel = MutexComputeChannel; - -static RUNTIME: ComputeRuntime = ComputeRuntime::new(); - -pub type CudaCompiler = CppCompiler>; - -fn create_client>>( - device: &CudaDevice, - options: RuntimeOptions, -) -> ComputeClient { - // To get the supported WMMA features, and memory properties, we have to initialize the server immediately. - cudarc::driver::result::init().unwrap(); - let device_id = device.index as i32; - let device_ptr = cudarc::driver::result::device::get(device_id).unwrap(); - let arch_major; - let arch_version = unsafe { - arch_major = cudarc::driver::result::device::get_attribute( +impl DeviceState for CudaServer { + fn init(device_id: cubecl_common::device::DeviceId) -> Self { + let options = RuntimeOptions::default(); + let device = CudaDevice::from_id(device_id); + + // To get the supported WMMA features, and memory properties, we have to initialize the server immediately. 
+ cudarc::driver::result::init().unwrap(); + let device_id = device.index as i32; + let device_ptr = cudarc::driver::result::device::get(device_id).unwrap(); + let arch_major; + let arch_version = unsafe { + arch_major = cudarc::driver::result::device::get_attribute( device_ptr, cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, ) .unwrap(); - let minor = cudarc::driver::result::device::get_attribute( + let minor = cudarc::driver::result::device::get_attribute( device_ptr, cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, ) .unwrap(); - arch_major * 10 + minor - } as u32; - - // cudamalloc and co. align to _256_ bytes. - // - // TODO: Find the correct value from the driver. - let mem_alignment = 256; - - // Ask the wmma compiler for its supported combinations - let arch = CudaArchitecture { - version: arch_version, - }; - let supported_wmma_combinations = M::supported_wmma_combinations(&arch); - let supported_mma_combinations = M::supported_mma_combinations(&arch); - let supported_scaled_mma_combinations = M::supported_scaled_mma_combinations(&arch); - - let ctx = unsafe { - let ctx = cudarc::driver::result::primary_ctx::retain(device_ptr).unwrap(); - cudarc::driver::result::ctx::set_current(ctx).unwrap(); - ctx - }; - - let max_memory = unsafe { - let mut bytes = MaybeUninit::uninit(); - cuDeviceTotalMem_v2(bytes.as_mut_ptr(), device_ptr); - bytes.assume_init() as u64 - }; - let mem_properties = MemoryDeviceProperties { - max_page_size: max_memory / 4, - alignment: mem_alignment as u64, - }; - - let mut comp_opts = CompilationOptions::default(); - - let hardware_props = unsafe { - use cudarc::driver::{result::device::get_attribute, sys::CUdevice_attribute::*}; - let warp_size = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_WARP_SIZE).unwrap() as u32; - let max_shared = get_attribute( - device_ptr, - CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, - ) - .unwrap() as usize; - let max_threads = - get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK).unwrap() as u32; - let block_dim_x = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X).unwrap(); - let block_dim_y = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y).unwrap(); - let block_dim_z = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z).unwrap(); - let max_cube_dim = - CubeDim::new_3d(block_dim_x as u32, block_dim_y as u32, block_dim_z as u32); - - let grid_dim_x = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X).unwrap(); - let grid_dim_y = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y).unwrap(); - let grid_dim_z = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z).unwrap(); - let max_cube_count = - CubeCount::new_3d(grid_dim_x as u32, grid_dim_y as u32, grid_dim_z as u32); - - let num_streaming_multiprocessors = Some( - get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT).unwrap() as u32, + arch_major * 10 + minor + } as u32; + + // cudamalloc and co. align to _256_ bytes. + // + // TODO: Find the correct value from the driver. 
+ let mem_alignment = 256; + + // Ask the wmma compiler for its supported combinations + let arch = CudaArchitecture { + version: arch_version, + }; + let supported_wmma_combinations = WmmaCompiler::supported_wmma_combinations(&arch); + let supported_mma_combinations = WmmaCompiler::supported_mma_combinations(&arch); + let supported_scaled_mma_combinations = + WmmaCompiler::supported_scaled_mma_combinations(&arch); + + let ctx = unsafe { + let ctx = cudarc::driver::result::primary_ctx::retain(device_ptr).unwrap(); + cudarc::driver::result::ctx::set_current(ctx).unwrap(); + ctx + }; + + let max_memory = unsafe { + let mut bytes = MaybeUninit::uninit(); + cuDeviceTotalMem_v2(bytes.as_mut_ptr(), device_ptr); + bytes.assume_init() as u64 + }; + let mem_properties = MemoryDeviceProperties { + max_page_size: max_memory / 4, + alignment: mem_alignment as u64, + }; + + let mut comp_opts = CompilationOptions { + supports_fast_math: true, + ..Default::default() + }; + + let hardware_props = unsafe { + use cudarc::driver::{result::device::get_attribute, sys::CUdevice_attribute::*}; + let warp_size = + get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_WARP_SIZE).unwrap() as u32; + let max_shared = get_attribute( + device_ptr, + CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, + ) + .unwrap() as usize; + let max_threads = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK) + .unwrap() as u32; + let block_dim_x = + get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X).unwrap(); + let block_dim_y = + get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y).unwrap(); + let block_dim_z = + get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z).unwrap(); + let max_cube_dim = + CubeDim::new_3d(block_dim_x as u32, block_dim_y as u32, block_dim_z as u32); + + let grid_dim_x = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X).unwrap(); + let grid_dim_y = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y).unwrap(); + let grid_dim_z = get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z).unwrap(); + let max_cube_count = + CubeCount::new_3d(grid_dim_x as u32, grid_dim_y as u32, grid_dim_z as u32); + + let num_streaming_multiprocessors = Some( + get_attribute(device_ptr, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT).unwrap() as u32, + ); + let num_tensor_cores = tensor_cores_per_sm(arch_version); + + comp_opts.warp_size = warp_size; + + HardwareProperties { + plane_size_min: warp_size, + plane_size_max: warp_size, + max_bindings: crate::device::CUDA_MAX_BINDINGS, + max_shared_memory_size: max_shared, + max_cube_count, + max_units_per_cube: max_threads, + max_cube_dim, + num_streaming_multiprocessors, + num_tensor_cores, + min_tensor_cores_dim: if supported_wmma_combinations.is_empty() { + None + } else { + Some(8) + }, + } + }; + + let mut device_props = DeviceProperties::new( + Default::default(), + mem_properties.clone(), + hardware_props, + TimingMethod::System, ); - let num_tensor_cores = tensor_cores_per_sm(arch_version); - - comp_opts.warp_size = warp_size; - - HardwareProperties { - plane_size_min: warp_size, - plane_size_max: warp_size, - max_bindings: crate::device::CUDA_MAX_BINDINGS, - max_shared_memory_size: max_shared, - max_cube_count, - max_units_per_cube: max_threads, - max_cube_dim, - num_streaming_multiprocessors, - num_tensor_cores, - min_tensor_cores_dim: if supported_wmma_combinations.is_empty() { - None - } else { - Some(8) - }, + register_supported_types(&mut device_props); + 
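// For context (not part of this patch): advertising `supports_fast_math` in the
// compilation options replaces the global `--use_fast_math` NVRTC flag that this diff
// removes from context.rs. Precision is now chosen per instruction, pairing each
// precise math function with a CUDA fast intrinsic, as the `function!` table in
// shared/unary.rs encodes. A simplified sketch of that mapping; the `UnaryOp` enum
// here is hypothetical, and only the emitted names (`exp` vs `__expf`, and so on)
// come from the patch itself.
#[derive(Clone, Copy)]
enum UnaryOp { Exp, FastExp, Sin, FastSin, InverseSqrt, FastInverseSqrt }

// Returns the C/CUDA function name to emit for a given IR instruction.
fn cuda_fn(op: UnaryOp) -> &'static str {
    match op {
        UnaryOp::Exp => "exp",                     // precise libm-style call
        UnaryOp::FastExp => "__expf",              // hardware fast-math approximation
        UnaryOp::Sin => "sin",
        UnaryOp::FastSin => "__sinf",
        UnaryOp::InverseSqrt => "rsqrt",
        UnaryOp::FastInverseSqrt => "__frsqrt_rn", // single-precision intrinsic
    }
}

fn main() {
    // A kernel that opted into fast math gets `out = __expf(in);`,
    // everything else keeps the precise `out = exp(in);`.
    for op in [UnaryOp::Exp, UnaryOp::FastExp] {
        println!("out = {}(in);", cuda_fn(op));
    }
}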
device_props.register_type_usage(ElemType::Float(FloatKind::TF32), TypeUsage::Conversion); + if arch_version >= 60 { + device_props.register_type_usage( + StorageType::Atomic(ElemType::Float(FloatKind::F64)), + TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore, + ); + } + if arch_version >= 70 { + device_props.register_type_usage( + StorageType::Atomic(ElemType::Float(FloatKind::F16)), + TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore, + ); + device_props.register_semantic_type(SemanticType::Pipeline); + device_props.register_semantic_type(SemanticType::Barrier); + device_props.features.plane.insert(Plane::Sync); + + comp_opts.grid_constants = true; } - }; - - let mut device_props = DeviceProperties::new( - Default::default(), - mem_properties.clone(), - hardware_props, - TimingMethod::System, - ); - register_supported_types(&mut device_props); - device_props.register_type_usage(ElemType::Float(FloatKind::TF32), TypeUsage::Conversion); - if arch_version >= 60 { - device_props.register_type_usage( - StorageType::Atomic(ElemType::Float(FloatKind::F64)), - TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore, - ); - } - if arch_version >= 70 { - device_props.register_type_usage( - StorageType::Atomic(ElemType::Float(FloatKind::F16)), - TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore, - ); - device_props.register_semantic_type(SemanticType::Pipeline); - device_props.register_semantic_type(SemanticType::Barrier); - device_props.features.plane.insert(Plane::Sync); - comp_opts.grid_constants = true; - } + // NOTE: I commented that since I observed synchronisation issues with atomic add for bf16. + // if arch.get_version() >= 80 { + // device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::BF16))); + // } + + if arch_version >= 89 { + device_props.register_type_usage( + ElemType::Float(FloatKind::E4M3), + TypeUsage::Conversion | TypeUsage::Buffer, + ); + device_props.register_type_usage( + ElemType::Float(FloatKind::E5M2), + TypeUsage::Conversion | TypeUsage::Buffer, + ); + } + if arch_version >= 90 { + device_props.features.tma.insert(Tma::Base); + device_props.register_semantic_type(SemanticType::TensorMap); + device_props.features.cube_cluster = true; + comp_opts.supports_clusters = true; + } - // NOTE: I commented that since I observed synchronisation issues with atomic add for bf16. - // if arch.get_version() >= 80 { - // device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::BF16))); - // } + if arch_version >= 100 { + device_props.features.tma.insert(Tma::Im2colWide); + } - if arch_version >= 89 { - device_props.register_type_usage( - ElemType::Float(FloatKind::E4M3), - TypeUsage::Conversion | TypeUsage::Buffer, - ); - device_props.register_type_usage( - ElemType::Float(FloatKind::E5M2), - TypeUsage::Conversion | TypeUsage::Buffer, - ); - } - if arch_version >= 90 { - device_props.features.tma.insert(Tma::Base); - device_props.register_semantic_type(SemanticType::TensorMap); - device_props.features.cube_cluster = true; - comp_opts.supports_clusters = true; - } + // NOTE: FP6/FP4 is explicitly not marked as forward compatible, but is compatible within a + // major version. Try to keep this up to date with new arch major revisions if they also + // implement it. 
+ if arch_major == 10 || arch_major == 12 { + device_props + .register_type_usage(ElemType::Float(FloatKind::E2M1), TypeUsage::Conversion); + device_props.register_type_usage( + StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2), + TypeUsage::Conversion | TypeUsage::Buffer, + ); + device_props.register_type_usage( + ElemType::Float(FloatKind::E2M3), + TypeUsage::Conversion | TypeUsage::Buffer, + ); + device_props.register_type_usage( + ElemType::Float(FloatKind::E3M2), + TypeUsage::Conversion | TypeUsage::Buffer, + ); + device_props.register_type_usage( + ElemType::Float(FloatKind::UE8M0), + TypeUsage::Conversion | TypeUsage::Buffer, + ); + } - if arch_version >= 100 { - device_props.features.tma.insert(Tma::Im2colWide); - } + device_props.features.dynamic_line_size = true; + device_props.features.plane.insert(Plane::Ops); - // NOTE: FP6/FP4 is explicitly not marked as forward compatible, but is compatible within a - // major version. Try to keep this up to date with new arch major revisions if they also - // implement it. - if arch_major == 10 || arch_major == 12 { - device_props.register_type_usage(ElemType::Float(FloatKind::E2M1), TypeUsage::Conversion); - device_props.register_type_usage( - StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2), - TypeUsage::Conversion | TypeUsage::Buffer, - ); - device_props.register_type_usage( - ElemType::Float(FloatKind::E2M3), - TypeUsage::Conversion | TypeUsage::Buffer, - ); - device_props.register_type_usage( - ElemType::Float(FloatKind::E3M2), - TypeUsage::Conversion | TypeUsage::Buffer, - ); - device_props.register_type_usage( - ElemType::Float(FloatKind::UE8M0), - TypeUsage::Conversion | TypeUsage::Buffer, - ); - } + register_wmma_features(supported_wmma_combinations, &mut device_props); + register_mma_features(supported_mma_combinations, &mut device_props); + register_scaled_mma_features(supported_scaled_mma_combinations, &mut device_props); - device_props.features.dynamic_line_size = true; - device_props.features.plane.insert(Plane::Ops); - - register_wmma_features(supported_wmma_combinations, &mut device_props); - register_mma_features(supported_mma_combinations, &mut device_props); - register_scaled_mma_features(supported_scaled_mma_combinations, &mut device_props); - - let cuda_ctx = CudaContext::new(comp_opts, ctx, arch); - let server = CudaServer::new( - cuda_ctx, - mem_properties, - options.memory_config, - mem_alignment, - device_id, - ); - ComputeClient::new(MutexComputeChannel::new(server), device_props, ()) + let cuda_ctx = CudaContext::new(comp_opts, ctx, arch); + let logger = Arc::new(ServerLogger::default()); + let utilities = ServerUtilities::new(device_props, logger, ()); + + CudaServer::new( + cuda_ctx, + mem_properties, + options.memory_config, + mem_alignment, + device_id, + utilities, + ) + } } +pub type CudaCompiler = CppCompiler>; + fn tensor_cores_per_sm(version: u32) -> Option { match version { 70 | 75 => Some(8), // Volta, Turing @@ -251,17 +264,13 @@ fn tensor_cores_per_sm(version: u32) -> Option { impl Runtime for CudaRuntime { type Compiler = CudaCompiler; type Server = CudaServer; - - type Channel = MutexComputeChannel; type Device = CudaDevice; - fn client(device: &Self::Device) -> ComputeClient { - RUNTIME.client(device, move || { - create_client::(device, RuntimeOptions::default()) - }) + fn client(device: &Self::Device) -> ComputeClient { + ComputeClient::load(device) } - fn name(_client: &ComputeClient) -> &'static str { + fn name(_client: &ComputeClient) -> &'static str { "cuda" } diff --git 
a/crates/cubecl-hip/Cargo.toml b/crates/cubecl-hip/Cargo.toml index c14d2cee6..effeee5f1 100644 --- a/crates/cubecl-hip/Cargo.toml +++ b/crates/cubecl-hip/Cargo.toml @@ -73,18 +73,18 @@ matmul_tests_all = [ conv_tests = ["cubecl-convolution/conv_tests"] [dependencies] -cubecl-common = { path = "../cubecl-common", version = "0.7.0", features = [ +cubecl-common = { path = "../cubecl-common", version = "0.9.0", features = [ "cache", ] } -cubecl-core = { path = "../cubecl-core", version = "0.7.0", default-features = false } -cubecl-cpp = { path = "../cubecl-cpp", version = "0.7.0", default-features = false, features = [ +cubecl-core = { path = "../cubecl-core", version = "0.9.0", default-features = false } +cubecl-cpp = { path = "../cubecl-cpp", version = "0.9.0", default-features = false, features = [ "hip", ] } -cubecl-runtime = { path = "../cubecl-runtime", version = "0.7.0", default-features = false, features = [ +cubecl-runtime = { path = "../cubecl-runtime", version = "0.9.0", default-features = false, features = [ "channel-mutex", ] } -cubecl-hip-sys = { version = "6.4.4348201" } -cubecl-quant = { path = "../cubecl-quant", version = "0.7.0", default-features = false } +cubecl-hip-sys = { version = "7.0.5183101" } +cubecl-quant = { path = "../cubecl-quant", version = "0.9.0", default-features = false } bytemuck = { workspace = true } @@ -96,26 +96,26 @@ paste = { workspace = true } serde = { workspace = true } [dev-dependencies] -cubecl-core = { path = "../cubecl-core", version = "0.7.0", features = [ +cubecl-core = { path = "../cubecl-core", version = "0.9.0", features = [ "export_tests", ] } -cubecl-std = { path = "../cubecl-std", version = "0.7.0", features = [ +cubecl-std = { path = "../cubecl-std", version = "0.9.0", features = [ "export_tests", ] } -cubecl-matmul = { path = "../cubecl-matmul", version = "0.7.0", features = [ +cubecl-matmul = { path = "../cubecl-matmul", version = "0.9.0", features = [ "export_tests", ] } -cubecl-convolution = { path = "../cubecl-convolution", version = "0.7.0", features = [ +cubecl-convolution = { path = "../cubecl-convolution", version = "0.9.0", features = [ "export_tests", ] } -cubecl-reduce = { path = "../cubecl-reduce", version = "0.7.0", features = [ +cubecl-reduce = { path = "../cubecl-reduce", version = "0.9.0", features = [ "export_tests", ] } -cubecl-random = { path = "../cubecl-random", version = "0.7.0", features = [ +cubecl-random = { path = "../cubecl-random", version = "0.9.0", features = [ "export_tests", ] } pretty_assertions = { workspace = true } -cubecl-quant = { path = "../cubecl-quant", version = "0.7.0", features = [ +cubecl-quant = { path = "../cubecl-quant", version = "0.9.0", features = [ "export_tests", "kernels", ] } diff --git a/crates/cubecl-hip/src/compute/command.rs b/crates/cubecl-hip/src/compute/command.rs index 456039503..a52aa0384 100644 --- a/crates/cubecl-hip/src/compute/command.rs +++ b/crates/cubecl-hip/src/compute/command.rs @@ -7,6 +7,7 @@ use cubecl_core::{ }; use cubecl_hip_sys::{ HIP_SUCCESS, hipMemcpyKind_hipMemcpyDeviceToHost, hipMemcpyKind_hipMemcpyHostToDevice, + ihipStream_t, }; use cubecl_runtime::{ id::KernelId, @@ -14,7 +15,7 @@ use cubecl_runtime::{ memory_management::{MemoryAllocationMode, MemoryHandle}, stream::ResolvedStreams, }; -use std::sync::Arc; +use std::{ffi::c_void, sync::Arc}; use crate::{ compute::{ @@ -31,7 +32,7 @@ use crate::{ /// registration, and task execution. 
pub struct Command<'a> { ctx: &'a mut HipContext, - streams: ResolvedStreams<'a, HipStreamBackend>, + pub(crate) streams: ResolvedStreams<'a, HipStreamBackend>, } impl<'a> Command<'a> { @@ -274,58 +275,13 @@ impl<'a> Command<'a> { return Err(IoError::UnsupportedStrides); } - let rank = shape.len(); let resource = self.resource(binding)?; let stream = match stream_id { Some(id) => self.streams.get(&id), None => self.streams.current(), }; - if rank <= 1 { - unsafe { - let status = cubecl_hip_sys::hipMemcpyDtoHAsync( - bytes.as_mut_ptr() as *mut _, - resource.ptr, - bytes.len(), - stream.sys, - ); - - if status != HIP_SUCCESS { - return Err(IoError::Unknown(format!("HIP memcpy failed: {status}"))); - } - } - return Ok(()); - } - - let dim_x = shape[rank - 1]; - let width_bytes = dim_x * elem_size; - let dim_y: usize = shape.iter().rev().skip(1).product(); - let pitch = strides[rank - 2] * elem_size; - - unsafe { - let status = cubecl_hip_sys::hipMemcpy2DAsync( - bytes.as_mut_ptr() as *mut _, - width_bytes, - resource.ptr, - pitch, - width_bytes, - dim_y, - hipMemcpyKind_hipMemcpyDeviceToHost, - stream.sys, - ); - - // Fallback, sometimes the copy doesn't work. - if status != HIP_SUCCESS { - let status = cubecl_hip_sys::hipMemcpyDtoHAsync( - bytes.as_mut_ptr() as *mut _, - resource.ptr, - bytes.len(), - stream.sys, - ); - assert_eq!(status, HIP_SUCCESS, "Should send data to device"); - } - } - Ok(()) + unsafe { write_to_cpu(shape, strides, elem_size, bytes, resource.ptr, stream.sys) } } /// Writes data from the host to GPU memory as specified by the copy descriptor. @@ -469,3 +425,61 @@ impl<'a> Command<'a> { }; } } + +pub(crate) unsafe fn write_to_cpu( + shape: &[usize], + strides: &[usize], + elem_size: usize, + bytes: &mut Bytes, + resource_ptr: *mut c_void, + stream: *mut ihipStream_t, +) -> Result<(), IoError> { + let rank = shape.len(); + + if rank <= 1 { + let status = unsafe { + cubecl_hip_sys::hipMemcpyDtoHAsync( + bytes.as_mut_ptr() as *mut _, + resource_ptr, + bytes.len(), + stream, + ) + }; + + if status != HIP_SUCCESS { + return Err(IoError::Unknown(format!("HIP memcpy failed: {status}"))); + } + return Ok(()); + } + + let dim_x = shape[rank - 1]; + let width_bytes = dim_x * elem_size; + let dim_y: usize = shape.iter().rev().skip(1).product(); + let pitch = strides[rank - 2] * elem_size; + + unsafe { + let status = cubecl_hip_sys::hipMemcpy2DAsync( + bytes.as_mut_ptr() as *mut _, + width_bytes, + resource_ptr, + pitch, + width_bytes, + dim_y, + hipMemcpyKind_hipMemcpyDeviceToHost, + stream, + ); + + // Fallback, sometimes the copy doesn't work. 
+ if status != HIP_SUCCESS { + let status = cubecl_hip_sys::hipMemcpyDtoHAsync( + bytes.as_mut_ptr() as *mut _, + resource_ptr, + bytes.len(), + stream, + ); + assert_eq!(status, HIP_SUCCESS, "Should send data to device"); + } + } + + Ok(()) +} diff --git a/crates/cubecl-hip/src/compute/server.rs b/crates/cubecl-hip/src/compute/server.rs index f9cf3e752..29097e7d4 100644 --- a/crates/cubecl-hip/src/compute/server.rs +++ b/crates/cubecl-hip/src/compute/server.rs @@ -1,7 +1,9 @@ use super::storage::gpu::GpuResource; use super::storage::gpu::GpuStorage; use crate::compute::command::Command; +use crate::compute::command::write_to_cpu; use crate::compute::context::HipContext; +use crate::compute::fence::Fence; use crate::compute::stream::HipStreamBackend; use crate::runtime::HipCompiler; use cubecl_common::bytes::Bytes; @@ -10,6 +12,7 @@ use cubecl_common::profile::ProfileDuration; use cubecl_common::stream_id::StreamId; use cubecl_core::compute::CubeTask; use cubecl_core::server::ServerCommunication; +use cubecl_core::server::ServerUtilities; use cubecl_core::server::{ Allocation, AllocationKind, CopyDescriptor, IoError, ProfileError, ProfilingToken, }; @@ -29,6 +32,7 @@ pub struct HipServer { ctx: HipContext, streams: MultiStream, mem_alignment: usize, + utilities: Arc>, } unsafe impl Send for HipServer {} @@ -42,6 +46,10 @@ impl ComputeServer for HipServer { self.streams.logger.clone() } + fn utilities(&self) -> Arc> { + self.utilities.clone() + } + fn create( &mut self, descriptors: Vec>, @@ -238,7 +246,17 @@ impl ComputeServer for HipServer { } impl ServerCommunication for HipServer { - const SERVER_COMM_ENABLED: bool = false; + const SERVER_COMM_ENABLED: bool = true; + + fn copy( + server_src: &mut Self, + server_dst: &mut Self, + src: CopyDescriptor<'_>, + stream_id_src: StreamId, + stream_id_dst: StreamId, + ) -> Result { + Self::change_server_serialized(server_src, server_dst, src, stream_id_src, stream_id_dst) + } } impl HipServer { @@ -248,6 +266,7 @@ impl HipServer { mem_props: MemoryDeviceProperties, mem_config: MemoryConfiguration, mem_alignment: usize, + utilities: ServerUtilities, ) -> Self { let config = GlobalConfig::get(); let max_streams = config.streaming.max_streams; @@ -256,10 +275,16 @@ impl HipServer { ctx, mem_alignment, streams: MultiStream::new( - Arc::new(ServerLogger::default()), - HipStreamBackend::new(mem_props, mem_config, mem_alignment), + utilities.logger.clone(), + HipStreamBackend::new( + mem_props, + mem_config, + mem_alignment, + utilities.logger.clone(), + ), max_streams, ), + utilities: Arc::new(utilities), } } @@ -276,6 +301,77 @@ impl HipServer { Command::new(&mut self.ctx, streams) } + + fn change_server_serialized( + server_src: &mut Self, + server_dst: &mut Self, + src: CopyDescriptor<'_>, + stream_id_src: StreamId, + stream_id_dst: StreamId, + ) -> Result { + let shape = src.shape.to_vec(); + let strides = src.strides.to_vec(); + let elem_size = src.elem_size; + let binding = src.binding.clone(); + let num_bytes = shape.iter().product::() * elem_size; + + // We start by creating a command on the destination server. + // + // Here we allocate the necessary bytes using pinned memory managed by the destination + // server along a new GPU handle. 
This way, the bytes can be reused later by that server,
+        // and the lifetime of that handle is aligned with the execution order of the destination
+        // server, removing the need to keep the bytes handle alive using synchronization, which
+        // would be the case if we allocated the bytes using the source server.
+        let mut command_dst = server_dst.command_no_inputs(stream_id_dst);
+        let handle = command_dst.reserve(binding.size())?;
+        let mut bytes = command_dst.reserve_cpu(num_bytes, true, None);
+        let copy_desc = handle.copy_descriptor(&shape, &strides, elem_size);
+
+        // We need to drop the command before creating another one.
+        core::mem::drop(command_dst);
+
+        // We create a command on the source server to retrieve the correct resource from the
+        // source memory pools. We also make sure the current stream is aligned with the stream of
+        // the binding, where the data was first allocated.
+        //
+        // We use the source stream to copy the data from the source server into the allocated
+        // bytes. This ensures that the source binding follows the correct execution order, meaning
+        // that we don't have to keep the source handle alive using synchronization, which would be
+        // the case if we performed the copy on the destination server.
+        let mut command_src = server_src.command(stream_id_src, [&src.binding].into_iter());
+        let resource_src = command_src.resource(binding.clone())?;
+        let stream_src = command_src.streams.current().sys;
+
+        unsafe {
+            write_to_cpu(
+                &shape,
+                &strides,
+                elem_size,
+                &mut bytes,
+                resource_src.ptr,
+                stream_src,
+            )?;
+        }
+        let fence_src = Fence::new(stream_src);
+
+        // We need to drop the command before creating another one.
+        core::mem::drop(command_src);
+
+        // Finally, we create a new command on the destination server to write the data staged in
+        // pinned memory into destination GPU memory. Here we need to wait for the initial copy
+        // made by the source server using an event. The synchronization is done lazily on the
+        // destination stream, which is very efficient.
+        let mut command_dst = server_dst.command_no_inputs(stream_id_dst);
+        let stream_dst = command_dst.streams.current().sys;
+
+        fence_src.wait_async(stream_dst);
+        command_dst.write_to_gpu(copy_desc, &bytes)?;
+
+        // We drop the last command.
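+        // Dropping it releases the mutable borrow on the destination context before returning.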
+ core::mem::drop(command_dst); + + Ok(Allocation { handle, strides }) + } } pub(crate) fn contiguous_strides(shape: &[usize]) -> Vec { diff --git a/crates/cubecl-hip/src/compute/stream.rs b/crates/cubecl-hip/src/compute/stream.rs index d725159ea..26d9d0131 100644 --- a/crates/cubecl-hip/src/compute/stream.rs +++ b/crates/cubecl-hip/src/compute/stream.rs @@ -1,7 +1,12 @@ +use std::sync::Arc; + use cubecl_core::MemoryConfiguration; use cubecl_hip_sys::HIP_SUCCESS; use cubecl_runtime::{ - memory_management::{MemoryDeviceProperties, MemoryManagement}, + logging::ServerLogger, + memory_management::{ + MemoryAllocationMode, MemoryDeviceProperties, MemoryManagement, MemoryManagementOptions, + }, stream::EventStreamBackend, }; @@ -23,6 +28,7 @@ pub struct HipStreamBackend { mem_props: MemoryDeviceProperties, mem_config: MemoryConfiguration, mem_alignment: usize, + logger: Arc, } impl EventStreamBackend for HipStreamBackend { @@ -37,8 +43,13 @@ impl EventStreamBackend for HipStreamBackend { stream }; let storage = GpuStorage::new(self.mem_alignment); - let memory_management_gpu = - MemoryManagement::from_configuration(storage, &self.mem_props, self.mem_config.clone()); + let memory_management_gpu = MemoryManagement::from_configuration( + storage, + &self.mem_props, + self.mem_config.clone(), + self.logger.clone(), + MemoryManagementOptions::new("Main GPU Memory"), + ); // We use the same page size and memory pools configuration for CPU pinned memory, since we // expect the CPU to have at least the same amount of RAM as GPU memory. let memory_management_cpu = MemoryManagement::from_configuration( @@ -48,6 +59,8 @@ impl EventStreamBackend for HipStreamBackend { alignment: PINNED_MEMORY_ALIGNMENT as u64, }, self.mem_config.clone(), + self.logger.clone(), + MemoryManagementOptions::new("Pinned CPU Memory").mode(MemoryAllocationMode::Auto), ); Stream { diff --git a/crates/cubecl-hip/src/runtime.rs b/crates/cubecl-hip/src/runtime.rs index 3b0980bc2..5c3e363b8 100644 --- a/crates/cubecl-hip/src/runtime.rs +++ b/crates/cubecl-hip/src/runtime.rs @@ -3,10 +3,14 @@ use crate::{ compute::{HipServer, context::HipContext, contiguous_strides}, device::AmdDevice, }; -use cubecl_common::profile::TimingMethod; +use cubecl_common::{ + device::{Device, DeviceState}, + profile::TimingMethod, +}; use cubecl_core::{ - CubeCount, CubeDim, MemoryConfiguration, Runtime, channel, + CubeCount, CubeDim, MemoryConfiguration, Runtime, ir::{MatrixLayout, MmaProperties, TargetProperties}, + server::ServerUtilities, }; use cubecl_cpp::{ hip::{HipDialect, arch::AMDArchitecture}, @@ -18,11 +22,12 @@ use cubecl_cpp::{ }; use cubecl_hip_sys::HIP_SUCCESS; use cubecl_runtime::{ - ComputeRuntime, DeviceProperties, Plane, + DeviceProperties, Plane, client::ComputeClient, + logging::ServerLogger, memory_management::{HardwareProperties, MemoryDeviceProperties}, }; -use std::{ffi::CStr, mem::MaybeUninit}; +use std::{ffi::CStr, mem::MaybeUninit, sync::Arc}; /// The values that control how a HIP Runtime will perform its calculations. 
#[derive(Default)] @@ -34,155 +39,153 @@ pub struct RuntimeOptions { #[derive(Debug)] pub struct HipRuntime; -static RUNTIME: ComputeRuntime = ComputeRuntime::new(); - pub type HipCompiler = CppCompiler>; -type Server = HipServer; -type Channel = channel::MutexComputeChannel; -// type Channel = channel::MpscComputeChannel; - -fn create_client>>( - device: &AmdDevice, - options: RuntimeOptions, -) -> ComputeClient { - #[allow(unused_assignments)] - let mut prop_warp_size = 0; - #[allow(unused_assignments)] - let mut prop_arch_name = ""; - #[allow(unused_assignments)] - let mut prop_max_shared_memory_size = 0; - #[allow(unused_assignments)] - let mut max_cube_count = CubeCount::new_single(); - #[allow(unused_assignments)] - let mut prop_max_threads = 0; - let mut max_cube_dim = CubeDim::new_single(); - let mut mem_alignment = 32; - unsafe { - let mut ll_device_props = MaybeUninit::uninit(); - let status = cubecl_hip_sys::hipGetDevicePropertiesR0600( - ll_device_props.as_mut_ptr(), - device.index as cubecl_hip_sys::hipDevice_t, - ); - assert_eq!(status, HIP_SUCCESS, "Should get device properties"); - let ll_device_props = ll_device_props.assume_init(); - prop_warp_size = ll_device_props.warpSize; - prop_arch_name = CStr::from_ptr(ll_device_props.gcnArchName.as_ptr()) - .to_str() - .unwrap(); - prop_max_shared_memory_size = ll_device_props.sharedMemPerBlock; - max_cube_count = CubeCount::new_3d( - ll_device_props.maxGridSize[0] as u32, - ll_device_props.maxGridSize[1] as u32, - ll_device_props.maxGridSize[2] as u32, - ); - prop_max_threads = ll_device_props.maxThreadsPerBlock as u32; - max_cube_dim.x = ll_device_props.maxThreadsDim[0] as u32; - max_cube_dim.y = ll_device_props.maxThreadsDim[1] as u32; - max_cube_dim.z = ll_device_props.maxThreadsDim[2] as u32; - - // Just to be sure we check both. 
- mem_alignment = usize::max(mem_alignment, ll_device_props.textureAlignment); - mem_alignment = usize::max(mem_alignment, ll_device_props.surfaceAlignment); - }; - let normalized_arch_name = prop_arch_name.split(':').next().unwrap_or(prop_arch_name); - let arch = AMDArchitecture::parse(normalized_arch_name).unwrap(); - assert_eq!(prop_warp_size as u32, arch.warp_size()); - - unsafe { - let status = cubecl_hip_sys::hipSetDevice(device.index as cubecl_hip_sys::hipDevice_t); - assert_eq!( - status, HIP_SUCCESS, - "Should set the default device for the current thread" - ); - } +impl DeviceState for HipServer { + fn init(device_id: cubecl_common::device::DeviceId) -> Self { + let device = AmdDevice::from_id(device_id); + + #[allow(unused_assignments)] + let mut prop_warp_size = 0; + #[allow(unused_assignments)] + let mut prop_arch_name = ""; + #[allow(unused_assignments)] + let mut prop_max_shared_memory_size = 0; + #[allow(unused_assignments)] + let mut max_cube_count = CubeCount::new_single(); + #[allow(unused_assignments)] + let mut prop_max_threads = 0; + let mut max_cube_dim = CubeDim::new_single(); + let mut mem_alignment = 32; + unsafe { + let mut ll_device_props = MaybeUninit::uninit(); + let status = cubecl_hip_sys::hipGetDevicePropertiesR0600( + ll_device_props.as_mut_ptr(), + device.index as cubecl_hip_sys::hipDevice_t, + ); + assert_eq!(status, HIP_SUCCESS, "Should get device properties"); + let ll_device_props = ll_device_props.assume_init(); + prop_warp_size = ll_device_props.warpSize; + prop_arch_name = CStr::from_ptr(ll_device_props.gcnArchName.as_ptr()) + .to_str() + .unwrap(); + prop_max_shared_memory_size = ll_device_props.sharedMemPerBlock; + max_cube_count = CubeCount::new_3d( + ll_device_props.maxGridSize[0] as u32, + ll_device_props.maxGridSize[1] as u32, + ll_device_props.maxGridSize[2] as u32, + ); + prop_max_threads = ll_device_props.maxThreadsPerBlock as u32; + max_cube_dim.x = ll_device_props.maxThreadsDim[0] as u32; + max_cube_dim.y = ll_device_props.maxThreadsDim[1] as u32; + max_cube_dim.z = ll_device_props.maxThreadsDim[2] as u32; + + // Just to be sure we check both. 
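+            // Taking the max of the texture and surface alignments satisfies both constraints.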
+ mem_alignment = usize::max(mem_alignment, ll_device_props.textureAlignment); + mem_alignment = usize::max(mem_alignment, ll_device_props.surfaceAlignment); + }; + let normalized_arch_name = prop_arch_name.split(':').next().unwrap_or(prop_arch_name); + let arch = AMDArchitecture::parse(normalized_arch_name).unwrap(); + assert_eq!(prop_warp_size as u32, arch.warp_size()); + + unsafe { + let status = cubecl_hip_sys::hipSetDevice(device.index as cubecl_hip_sys::hipDevice_t); + assert_eq!( + status, HIP_SUCCESS, + "Should set the default device for the current thread" + ); + } - let max_memory = unsafe { - let free: usize = 0; - let total: usize = 0; - let status = cubecl_hip_sys::hipMemGetInfo( - &free as *const _ as *mut usize, - &total as *const _ as *mut usize, - ); - assert_eq!( - status, HIP_SUCCESS, - "Should get the available memory of the device" + let max_memory = unsafe { + let free: usize = 0; + let total: usize = 0; + let status = cubecl_hip_sys::hipMemGetInfo( + &free as *const _ as *mut usize, + &total as *const _ as *mut usize, + ); + assert_eq!( + status, HIP_SUCCESS, + "Should get the available memory of the device" + ); + total + }; + let mem_properties = MemoryDeviceProperties { + max_page_size: max_memory as u64 / 4, + alignment: mem_alignment as u64, + }; + + let supported_wmma_combinations = HipWmmaCompiler::supported_wmma_combinations(&arch); + let supported_mma_combinations = HipWmmaCompiler::supported_mma_combinations(&arch); + let supported_scaled_mma_combinations = + HipWmmaCompiler::supported_scaled_mma_combinations(&arch); + + let topology = HardwareProperties { + plane_size_min: prop_warp_size as u32, + plane_size_max: prop_warp_size as u32, + max_bindings: crate::device::AMD_MAX_BINDINGS, + max_shared_memory_size: prop_max_shared_memory_size, + max_cube_count, + max_units_per_cube: prop_max_threads, + max_cube_dim, + num_streaming_multiprocessors: None, + num_tensor_cores: None, + min_tensor_cores_dim: if supported_wmma_combinations.is_empty() { + None + } else { + Some(16) + }, + }; + + let mut device_props = DeviceProperties::new( + Default::default(), + mem_properties.clone(), + topology, + TimingMethod::System, ); - total - }; - let mem_properties = MemoryDeviceProperties { - max_page_size: max_memory as u64 / 4, - alignment: mem_alignment as u64, - }; - - let supported_wmma_combinations = M::supported_wmma_combinations(&arch); - let supported_mma_combinations = M::supported_mma_combinations(&arch); - let supported_scaled_mma_combinations = M::supported_scaled_mma_combinations(&arch); - - let topology = HardwareProperties { - plane_size_min: prop_warp_size as u32, - plane_size_max: prop_warp_size as u32, - max_bindings: crate::device::AMD_MAX_BINDINGS, - max_shared_memory_size: prop_max_shared_memory_size, - max_cube_count, - max_units_per_cube: prop_max_threads, - max_cube_dim, - num_streaming_multiprocessors: None, - num_tensor_cores: None, - min_tensor_cores_dim: if supported_wmma_combinations.is_empty() { - None - } else { - Some(16) - }, - }; - - let mut device_props = DeviceProperties::new( - Default::default(), - mem_properties.clone(), - topology, - TimingMethod::System, - ); - register_supported_types(&mut device_props); - - // TODO look into unsafeAtomicAdd (https://github.com/ROCm/HIP/issues/3573120) - // device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::F16))); - // device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::BF16))); - - device_props.features.dynamic_line_size = true; - 
device_props.features.plane.insert(Plane::Ops); - - register_wmma_features(supported_wmma_combinations, &mut device_props); - register_mma_features(supported_mma_combinations, &mut device_props); - register_scaled_mma_features(supported_scaled_mma_combinations, &mut device_props); - - let comp_opts = CompilationOptions { - warp_size: arch.warp_size(), - grid_constants: false, - supports_clusters: false, - }; - let hip_ctx = HipContext::new(comp_opts); - let server = HipServer::new( - hip_ctx, - mem_properties, - options.memory_config, - mem_alignment, - ); - ComputeClient::new(Channel::new(server), device_props, ()) + register_supported_types(&mut device_props); + + // TODO look into unsafeAtomicAdd (https://github.com/ROCm/HIP/issues/3573120) + // device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::F16))); + // device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::BF16))); + + device_props.features.dynamic_line_size = true; + device_props.features.plane.insert(Plane::Ops); + + register_wmma_features(supported_wmma_combinations, &mut device_props); + register_mma_features(supported_mma_combinations, &mut device_props); + register_scaled_mma_features(supported_scaled_mma_combinations, &mut device_props); + + let comp_opts = CompilationOptions { + warp_size: arch.warp_size(), + grid_constants: false, + supports_clusters: false, + supports_fast_math: true, + }; + let hip_ctx = HipContext::new(comp_opts); + let logger = Arc::new(ServerLogger::default()); + let utilities = ServerUtilities::new(device_props, logger, ()); + let options = RuntimeOptions::default(); + + HipServer::new( + hip_ctx, + mem_properties, + options.memory_config, + mem_alignment, + utilities, + ) + } } impl Runtime for HipRuntime { type Compiler = HipCompiler; type Server = HipServer; - type Channel = Channel; type Device = AmdDevice; - fn client(device: &Self::Device) -> ComputeClient { - RUNTIME.client(device, move || { - create_client::(device, RuntimeOptions::default()) - }) + fn client(device: &Self::Device) -> ComputeClient { + ComputeClient::load(device) } - fn name(_client: &ComputeClient) -> &'static str { + fn name(_client: &ComputeClient) -> &'static str { "hip" } diff --git a/crates/cubecl-ir/Cargo.toml b/crates/cubecl-ir/Cargo.toml index d13abe65f..1b6b1a8be 100644 --- a/crates/cubecl-ir/Cargo.toml +++ b/crates/cubecl-ir/Cargo.toml @@ -17,18 +17,19 @@ version.workspace = true [features] default = ["serde", "std"] -serde = ["dep:serde", "hashbrown/serde"] +serde = ["dep:serde", "hashbrown/serde", "enumset/serde"] std = [] [dependencies] -cubecl-common = { path = "../cubecl-common", version = "0.7", features = [ +cubecl-common = { path = "../cubecl-common", version = "0.9", features = [ "fp4", "fp8", ] } -cubecl-macros-internal = { path = "../cubecl-macros-internal", version = "0.7" } +cubecl-macros-internal = { path = "../cubecl-macros-internal", version = "0.9" } derive-new = { workspace = true } derive_more = { workspace = true, features = ["from"] } +enumset = { workspace = true } float-ord = "0.3" fnv = { workspace = true } half = { workspace = true } diff --git a/crates/cubecl-ir/src/arithmetic.rs b/crates/cubecl-ir/src/arithmetic.rs index 088f624dd..f4ff66568 100644 --- a/crates/cubecl-ir/src/arithmetic.rs +++ b/crates/cubecl-ir/src/arithmetic.rs @@ -41,10 +41,11 @@ pub enum Arithmetic { Powf(BinaryOperator), Powi(BinaryOperator), Sqrt(UnaryOperator), - Rsqrt(UnaryOperator), + InverseSqrt(UnaryOperator), Round(UnaryOperator), Floor(UnaryOperator), 
Ceil(UnaryOperator), + Trunc(UnaryOperator), Erf(UnaryOperator), Recip(UnaryOperator), Clamp(ClampOperator), @@ -95,10 +96,11 @@ impl Display for Arithmetic { Arithmetic::Powf(op) => write!(f, "{}.pow({})", op.lhs, op.rhs), Arithmetic::Powi(op) => write!(f, "{}.powi({})", op.lhs, op.rhs), Arithmetic::Sqrt(op) => write!(f, "{}.sqrt()", op.input), - Arithmetic::Rsqrt(op) => write!(f, "{}.rsqrt()", op.input), + Arithmetic::InverseSqrt(op) => write!(f, "{}.inverse_sqrt()", op.input), Arithmetic::Round(op) => write!(f, "{}.round()", op.input), Arithmetic::Floor(op) => write!(f, "{}.floor()", op.input), Arithmetic::Ceil(op) => write!(f, "{}.ceil()", op.input), + Arithmetic::Trunc(op) => write!(f, "{}.trunc()", op.input), Arithmetic::Erf(op) => write!(f, "{}.erf()", op.input), Arithmetic::Recip(op) => write!(f, "{}.recip()", op.input), Arithmetic::Clamp(op) => { diff --git a/crates/cubecl-ir/src/lib.rs b/crates/cubecl-ir/src/lib.rs index 5477ed58d..fe4810c7e 100644 --- a/crates/cubecl-ir/src/lib.rs +++ b/crates/cubecl-ir/src/lib.rs @@ -10,6 +10,7 @@ mod bitwise; mod branch; mod cmma; mod comparison; +mod marker; mod metadata; mod non_semantic; mod operation; @@ -33,6 +34,7 @@ pub use bitwise::*; pub use branch::*; pub use cmma::*; pub use comparison::*; +pub use marker::*; pub use metadata::*; pub use non_semantic::*; pub use operation::*; diff --git a/crates/cubecl-ir/src/marker.rs b/crates/cubecl-ir/src/marker.rs new file mode 100644 index 000000000..a8e4eba0a --- /dev/null +++ b/crates/cubecl-ir/src/marker.rs @@ -0,0 +1,75 @@ +use core::fmt::Display; + +use enumset::{EnumSet, EnumSetType}; + +use crate::{Instruction, Operation, TypeHash}; + +use crate::{OperationCode, OperationReflect}; + +use super::Variable; + +/// Operations that don't change the semantics of the kernel. In other words, operations that do not +/// perform any computation, if they run at all. i.e. `println`, comments and debug symbols. +/// +/// Can be safely removed or ignored without changing the kernel result. +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationCode)] +#[operation(opcode_name = MarkerOpCode)] +pub enum Marker { + /// Frees a shared memory, allowing reuse in later blocks. + Free(Variable), +} + +impl OperationReflect for Marker { + type OpCode = MarkerOpCode; + + fn op_code(&self) -> Self::OpCode { + self.__match_opcode() + } +} + +impl Display for Marker { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Marker::Free(var) => write!(f, "free({var})"), + } + } +} + +impl From for Instruction { + fn from(value: Marker) -> Self { + Instruction::no_out(Operation::Marker(value)) + } +} + +/// Unchecked optimizations for float operations. May cause precision differences, or undefined +/// behaviour if the relevant conditions are not followed. +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Hash, TypeHash, EnumSetType)] +pub enum FastMath { + /// Assume values are never `NaN`. If they are, the result is considered undefined behaviour. + NotNaN, + /// Assume values are never `Inf`/`-Inf`. If they are, the result is considered undefined + /// behaviour. + NotInf, + /// Ignore sign on zero values. + UnsignedZero, + /// Allow swapping float division with a reciprocal, even if that swap would change precision. + AllowReciprocal, + /// Allow contracting float operations into fewer operations, even if the precision could + /// change. 
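+    /// (For example, contraction can fuse `a * b + c` into a single fused multiply-add.)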
+ AllowContraction, + /// Allow reassociation for float operations, even if the precision could change. + AllowReassociation, + /// Allow all mathematical transformations for float operations, including contraction and + /// reassociation, even if the precision could change. + AllowTransform, + /// Allow using lower precision intrinsics + ReducedPrecision, +} + +impl FastMath { + pub const fn all() -> EnumSet { + EnumSet::all() + } +} diff --git a/crates/cubecl-ir/src/operation.rs b/crates/cubecl-ir/src/operation.rs index 83690c050..1ecee54ce 100644 --- a/crates/cubecl-ir/src/operation.rs +++ b/crates/cubecl-ir/src/operation.rs @@ -2,8 +2,8 @@ use core::fmt::Display; use super::{Branch, CoopMma, NonSemantic, Plane, Synchronization, Type, Variable}; use crate::{ - Arithmetic, AtomicOp, Bitwise, Metadata, OperationArgs, OperationReflect, Operator, TmaOps, - comparison::Comparison, + Arithmetic, AtomicOp, Bitwise, InstructionModes, Metadata, OperationArgs, OperationReflect, + Operator, TmaOps, comparison::Comparison, marker::Marker, }; use crate::{BarrierOps, SourceLoc, TypeHash}; use alloc::{ @@ -55,9 +55,9 @@ pub enum Operation { /// Non-semantic instructions (i.e. comments, debug info) #[operation(nested)] NonSemantic(NonSemantic), - /// Frees a shared memory, allowing reuse in later blocks. Only used as a marker for the shared - /// memory analysis, should be ignored by compilers. - Free(Variable), + // Markers used by compilers to update state or modes, but don't emit instructions + #[operation(nested)] + Marker(Marker), } /// An instruction that contains a right hand side [`Operation`] and an optional out variable. @@ -66,6 +66,7 @@ pub enum Operation { pub struct Instruction { pub out: Option, pub source_loc: Option, + pub modes: InstructionModes, pub operation: Operation, } @@ -75,6 +76,7 @@ impl Instruction { out: Some(out), operation: operation.into(), source_loc: None, + modes: Default::default(), } } @@ -83,6 +85,7 @@ impl Instruction { out: None, operation: operation.into(), source_loc: None, + modes: Default::default(), } } @@ -192,7 +195,7 @@ impl Display for Operation { Operation::NonSemantic(non_semantic) => write!(f, "{non_semantic}"), Operation::Barrier(barrier_ops) => write!(f, "{barrier_ops}"), Operation::Tma(tma_ops) => write!(f, "{tma_ops}"), - Operation::Free(var) => write!(f, "free({var})"), + Operation::Marker(marker) => write!(f, "{marker}"), } } } diff --git a/crates/cubecl-ir/src/plane.rs b/crates/cubecl-ir/src/plane.rs index ecc6de5d5..c0ff09156 100644 --- a/crates/cubecl-ir/src/plane.rs +++ b/crates/cubecl-ir/src/plane.rs @@ -18,6 +18,10 @@ pub enum Plane { Any(UnaryOperator), Ballot(UnaryOperator), Broadcast(BinaryOperator), + Shuffle(BinaryOperator), + ShuffleXor(BinaryOperator), + ShuffleUp(BinaryOperator), + ShuffleDown(BinaryOperator), Sum(UnaryOperator), InclusiveSum(UnaryOperator), ExclusiveSum(UnaryOperator), @@ -38,6 +42,18 @@ impl Display for Plane { Plane::Broadcast(op) => { write!(f, "plane_broadcast({}, {})", op.lhs, op.rhs) } + Plane::Shuffle(op) => { + write!(f, "plane_shuffle({}, {})", op.lhs, op.rhs) + } + Plane::ShuffleXor(op) => { + write!(f, "plane_shuffle_xor({}, {})", op.lhs, op.rhs) + } + Plane::ShuffleUp(op) => { + write!(f, "plane_shuffle_up({}, {})", op.lhs, op.rhs) + } + Plane::ShuffleDown(op) => { + write!(f, "plane_shuffle_down({}, {})", op.lhs, op.rhs) + } Plane::Sum(op) => write!(f, "plane_sum({})", op.input), Plane::InclusiveSum(op) => write!(f, "plane_inclusive_sum({})", op.input), Plane::ExclusiveSum(op) => write!(f, 
"plane_exclusive_sum({})", op.input), diff --git a/crates/cubecl-ir/src/processing.rs b/crates/cubecl-ir/src/processing.rs index fa154c2a0..afc32ea34 100644 --- a/crates/cubecl-ir/src/processing.rs +++ b/crates/cubecl-ir/src/processing.rs @@ -159,7 +159,7 @@ impl ScopeProcessing { Arithmetic::Sqrt(op) => { sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); } - Arithmetic::Rsqrt(op) => { + Arithmetic::InverseSqrt(op) => { sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); } Arithmetic::Round(op) => { @@ -171,6 +171,9 @@ impl ScopeProcessing { Arithmetic::Ceil(op) => { sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); } + Arithmetic::Trunc(op) => { + sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); + } Arithmetic::Erf(op) => { sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); } @@ -451,7 +454,7 @@ impl ScopeProcessing { Operation::Tma(_) => { // Nothing to do } - Operation::Free(_) => { + Operation::Marker(_) => { // Nothing to do } }); diff --git a/crates/cubecl-ir/src/scope.rs b/crates/cubecl-ir/src/scope.rs index 4566601b3..00d757df1 100644 --- a/crates/cubecl-ir/src/scope.rs +++ b/crates/cubecl-ir/src/scope.rs @@ -1,9 +1,10 @@ use alloc::{borrow::Cow, rc::Rc, string::ToString, vec::Vec}; use core::{any::TypeId, cell::RefCell, fmt::Display}; +use enumset::EnumSet; use hashbrown::{HashMap, HashSet}; use crate::{ - BarrierLevel, CubeFnSource, ExpandElement, Matrix, Processor, SourceLoc, StorageType, + BarrierLevel, CubeFnSource, ExpandElement, FastMath, Matrix, Processor, SourceLoc, StorageType, TargetProperties, TypeHash, }; @@ -38,6 +39,7 @@ pub struct Scope { #[cfg_attr(feature = "serde", serde(skip))] pub typemap: Rc>>, pub runtime_properties: Rc, + pub modes: Rc>, } /// Debug related fields, most of these are global @@ -51,6 +53,13 @@ pub struct DebugInfo { pub entry_loc: Option, } +/// Modes set and reset during expansion +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, TypeHash)] +pub struct InstructionModes { + pub fp_math_mode: EnumSet, +} + impl core::hash::Hash for Scope { fn hash(&self, ra_expand_state: &mut H) { self.depth.hash(ra_expand_state); @@ -104,6 +113,7 @@ impl Scope { }, typemap: Default::default(), runtime_properties: Rc::new(Default::default()), + modes: Default::default(), } } @@ -178,6 +188,7 @@ impl Scope { pub fn register>(&mut self, instruction: T) { let mut inst = instruction.into(); inst.source_loc = self.debug.source_loc.clone(); + inst.modes = *self.modes.borrow(); self.instructions.push(inst) } @@ -213,6 +224,7 @@ impl Scope { debug: self.debug.clone(), typemap: self.typemap.clone(), runtime_properties: self.runtime_properties.clone(), + modes: self.modes.clone(), } } diff --git a/crates/cubecl-ir/src/type_hash.rs b/crates/cubecl-ir/src/type_hash.rs index b1805c06b..040234f6a 100644 --- a/crates/cubecl-ir/src/type_hash.rs +++ b/crates/cubecl-ir/src/type_hash.rs @@ -1,5 +1,6 @@ use alloc::borrow::ToOwned; use core::hash::Hasher; +use enumset::EnumSetType; /// A hash of a type's structure pub trait TypeHash { @@ -121,6 +122,7 @@ impl_type_hash!( portable_atomic::AtomicU64, portable_atomic::AtomicU8, portable_atomic::AtomicUsize, + enumset::EnumSet, ); macro_rules! 
impl_type_hash_tuple { diff --git a/crates/cubecl-ir/src/variable.rs b/crates/cubecl-ir/src/variable.rs index 5c9d4f56b..499943373 100644 --- a/crates/cubecl-ir/src/variable.rs +++ b/crates/cubecl-ir/src/variable.rs @@ -334,7 +334,7 @@ impl ConstantScalarValue { } } - /// Returns the value of the scalar as a u32. + /// Returns the value of the scalar as a i64. /// /// It will panic if the scalar type is a float or a bool. pub fn as_i64(&self) -> i64 { @@ -342,6 +342,24 @@ impl ConstantScalarValue { .expect("Only Int and UInt kind can be made into i64.") } + /// Returns the value of the scalar as a f64. + /// + /// It will return [None] if the scalar type is an int or a bool. + pub fn try_as_f64(&self) -> Option { + match self { + ConstantScalarValue::Float(val, _) => Some(*val), + _ => None, + } + } + + /// Returns the value of the scalar as a f64. + /// + /// It will panic if the scalar type is an int or a bool. + pub fn as_f64(&self) -> f64 { + self.try_as_f64() + .expect("Only Float kind can be made into f64.") + } + /// Returns the value of the variable as a bool if it actually is a bool. pub fn try_as_bool(&self) -> Option { match self { diff --git a/crates/cubecl-macros/Cargo.toml b/crates/cubecl-macros/Cargo.toml index 4557c1e32..ec28e81c9 100644 --- a/crates/cubecl-macros/Cargo.toml +++ b/crates/cubecl-macros/Cargo.toml @@ -30,4 +30,4 @@ proc-macro2 = { workspace = true } quote = { workspace = true } syn = { workspace = true } -cubecl-common = { path = "../cubecl-common", version = "0.7", default-features = false } +cubecl-common = { path = "../cubecl-common", version = "0.9", default-features = false } diff --git a/crates/cubecl-macros/src/generate/kernel.rs b/crates/cubecl-macros/src/generate/kernel.rs index 6c9ae7adc..9c794405c 100644 --- a/crates/cubecl-macros/src/generate/kernel.rs +++ b/crates/cubecl-macros/src/generate/kernel.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use darling::usage::{CollectLifetimes as _, CollectTypeParams as _, GenericsExt as _, Purpose}; use proc_macro2::TokenStream; use quote::{ToTokens, format_ident, quote, quote_spanned}; @@ -22,36 +24,46 @@ impl KernelFn { }; let name = &self.full_name; - let (debug_source, debug_params) = if cfg!(debug_symbols) || self.debug_symbols { - let debug_source = frontend_type("debug_source_expand"); - let cube_debug = frontend_type("CubeDebug"); - let src_file = self.src_file.as_ref().map(|file| file.value()); - let src_file = src_file.or_else(|| { - let span: proc_macro::Span = self.span.unwrap(); - let source_path = span.local_file(); - let source_file = source_path.as_ref().and_then(|path| path.file_name()); - source_file.map(|file| file.to_string_lossy().into()) - }); - let source_text = match src_file { - Some(file) => quote![include_str!(#file)], - None => quote![""], - }; + let (debug_source, debug_params) = + if cfg!(debug_symbols) || self.args.debug_symbols.is_present() { + let debug_source = frontend_type("debug_source_expand"); + let cube_debug = frontend_type("CubeDebug"); + let src_file = self.args.src_file.as_ref().map(|file| file.value()); + let src_file = src_file.or_else(|| { + let span: proc_macro::Span = self.span.unwrap(); + let source_path = span.local_file(); + let source_file = source_path.as_ref().and_then(|path| path.file_name()); + source_file.map(|file| file.to_string_lossy().into()) + }); + let source_text = match src_file { + Some(file) => quote![include_str!(#file)], + None => quote![""], + }; - let debug_source = quote_spanned! 
{self.span=> - #debug_source(scope, #name, file!(), #source_text, line!(), column!()) + let debug_source = quote_spanned! {self.span=> + #debug_source(scope, #name, file!(), #source_text, line!(), column!()) + }; + let debug_params = sig + .runtime_params() + .map(|it| &it.name) + .map(|name| { + let name_str = name.to_string(); + quote! [#cube_debug::set_debug_name(&#name, scope, #name_str);] + }) + .collect(); + (debug_source, debug_params) + } else { + (TokenStream::new(), Vec::new()) }; - let debug_params = sig - .runtime_params() - .map(|it| &it.name) - .map(|name| { - let name_str = name.to_string(); - quote! [#cube_debug::set_debug_name(&#name, scope, #name_str);] - }) - .collect(); - (debug_source, debug_params) - } else { - (TokenStream::new(), Vec::new()) - }; + let body = self + .args + .fast_math + .as_ref() + .map(|value| { + let fast_math = frontend_type("fast_math_expand"); + quote![#fast_math(scope, #value, |scope| {#body})] + }) + .unwrap_or_else(|| quote![#body]); let out = quote! { #vis #sig { @@ -225,7 +237,13 @@ impl Launch { fn define_body(&self) -> TokenStream { let kernel_builder = prelude_type("KernelBuilder"); let io_map = self.io_mappings(); - let register_type = self.analysis.register_types(); + let mut mapping = HashMap::new(); + for param in self.func.sig.parameters.iter() { + for define in param.defines.iter() { + mapping.insert(define.clone(), param.name.clone()); + } + } + let register_type = self.analysis.register_types(&mapping); let runtime_args = self.runtime_params().map(|it| &it.name); let comptime_args = self.comptime_params().map(|it| &it.name); let generics = self.analysis.process_generics(&self.func.sig.generics); @@ -355,9 +373,6 @@ impl Launch { if self.args.debug_symbols.is_present() { settings.extend(quote![.debug_symbols()]); } - if let Some(mode) = &self.args.fast_math { - settings.extend(quote![.fp_math_mode((#mode).into())]); - } if let Some(cluster_dim) = &self.args.cluster_dim { settings.extend(quote![.cluster_dim(#cluster_dim)]); } diff --git a/crates/cubecl-macros/src/generate/launch.rs b/crates/cubecl-macros/src/generate/launch.rs index 376a04db1..93af408b0 100644 --- a/crates/cubecl-macros/src/generate/launch.rs +++ b/crates/cubecl-macros/src/generate/launch.rs @@ -15,6 +15,7 @@ impl ToTokens for Launch { let name = &self.func.sig.name; let launch = self.launch(); let launch_unchecked = self.launch_unchecked(); + let aliases = self.create_type_alias(); let dummy = self.create_dummy_kernel(); let kernel = self.kernel_definition(); let mut func = self.func.clone(); @@ -25,6 +26,8 @@ impl ToTokens for Launch { #vis mod #name { use super::*; + #aliases + #[allow(unused, clippy::all)] #func @@ -63,7 +66,7 @@ impl Launch { #[allow(clippy::too_many_arguments)] #[doc = #kernel_doc] pub fn launch #generics( - __client: &#compute_client<__R::Server, __R::Channel>, + __client: &#compute_client<__R::Server>, __cube_count: #cube_count, __cube_dim: #cube_dim, #(#args),* @@ -95,7 +98,7 @@ impl Launch { #[allow(clippy::too_many_arguments)] #[doc = #kernel_doc] pub unsafe fn launch_unchecked #generics( - __client: &#compute_client<__R::Server, __R::Channel>, + __client: &#compute_client<__R::Server>, __cube_count: #cube_count, __cube_dim: #cube_dim, #(#args),* @@ -147,6 +150,22 @@ impl Launch { } } + fn create_type_alias(&self) -> TokenStream { + let mut index = 0u8; + let mut aliases = quote! {}; + + for input in self.func.sig.parameters.iter() { + for define in input.defines.iter() { + aliases.extend(quote! 
{ + /// Type to be used as a generic for launch kernel argument. + pub type #define = NumericExpand<#index>; + }); + index += 1; + } + } + + aliases + } fn create_dummy_kernel(&self) -> TokenStream { if self.args.create_dummy_kernel.is_present() { let cube_count = prelude_type("CubeCount"); diff --git a/crates/cubecl-macros/src/parse/branch.rs b/crates/cubecl-macros/src/parse/branch.rs index d33685c34..356bb5057 100644 --- a/crates/cubecl-macros/src/parse/branch.rs +++ b/crates/cubecl-macros/src/parse/branch.rs @@ -1,5 +1,7 @@ use quote::quote; -use syn::{Expr, ExprForLoop, ExprIf, ExprLoop, ExprMatch, Ident, Lit, Pat, spanned::Spanned}; +use syn::{ + Expr, ExprForLoop, ExprIf, ExprLoop, ExprMatch, Ident, Lit, Pat, parse_quote, spanned::Spanned, +}; use crate::{ expression::{Block, Expression}, @@ -9,13 +11,29 @@ use crate::{ use super::{helpers::Unroll, statement::parse_pat}; -pub fn expand_for_loop(for_loop: ExprForLoop, context: &mut Context) -> syn::Result { +pub fn expand_for_loop( + mut for_loop: ExprForLoop, + context: &mut Context, +) -> syn::Result { let span = for_loop.span(); - let unroll = Unroll::from_attributes(&for_loop.attrs, context)?.map(|it| it.value); + let unroll = Unroll::from_attributes(&for_loop.attrs, context)?; + let var = parse_pat(*for_loop.pat)?; + if let Some(Unroll { + always_true: true, .. + }) = unroll + && var.ident != "_" + { + let var_name = &var.ident; + for_loop.body.stmts.insert( + 0, + parse_quote![let #var_name = #var_name.into_lit_unchecked();], + ); + }; + + let unroll = unroll.map(|it| it.value); let right = Expression::from_expr(*for_loop.expr.clone(), context) .map_err(|_| syn::Error::new(span, "Unsupported for loop expression"))?; - let var = parse_pat(*for_loop.pat)?; if right.is_const() && !matches!(right, Expression::Range { .. }) { return expand_for_in_loop(var.ident, right, for_loop.body, context); diff --git a/crates/cubecl-macros/src/parse/cube_impl.rs b/crates/cubecl-macros/src/parse/cube_impl.rs index e55cfd582..a89814c72 100644 --- a/crates/cubecl-macros/src/parse/cube_impl.rs +++ b/crates/cubecl-macros/src/parse/cube_impl.rs @@ -137,9 +137,11 @@ impl CubeImplItem { body, full_name: func.full_name.clone(), span: func.span, - context: Context::new(func.context.return_type.clone(), func.debug_symbols), - src_file: func.src_file.clone(), - debug_symbols: func.debug_symbols, + context: Context::new( + func.context.return_type.clone(), + func.args.debug_symbols.is_present(), + ), + args: func.args.clone(), } } @@ -194,9 +196,11 @@ impl CubeImplItem { body: KernelBody::Verbatim(body), full_name: func.full_name.clone(), span: func.span, - context: Context::new(func.context.return_type.clone(), func.debug_symbols), - src_file: func.src_file.clone(), - debug_symbols: func.debug_symbols, + context: Context::new( + func.context.return_type.clone(), + func.args.debug_symbols.is_present(), + ), + args: func.args.clone(), } } } diff --git a/crates/cubecl-macros/src/parse/expression.rs b/crates/cubecl-macros/src/parse/expression.rs index aec50e567..bb5763a0e 100644 --- a/crates/cubecl-macros/src/parse/expression.rs +++ b/crates/cubecl-macros/src/parse/expression.rs @@ -482,7 +482,7 @@ fn fn_associated_type(path: &Expression) -> Option<(Path, Option, PathSeg // All supported primitives. 
Primitives don't start with an uppercase letter const PRIMITIVES: &[&str] = &[ "bool", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "f16", "bf16", "f32", "f64", - "flex32", "e2m1", "e2m3", "e3m2", "e4m3", "e5m2", "ue8m0", + "flex32", "e2m1", "e2m1x2", "e2m3", "e3m2", "e4m3", "e5m2", "ue8m0", ]; if !matches!(path, Expression::Path { .. }) { panic!("path: {path:?}"); diff --git a/crates/cubecl-macros/src/parse/helpers.rs b/crates/cubecl-macros/src/parse/helpers.rs index a97de2fb9..5081b5396 100644 --- a/crates/cubecl-macros/src/parse/helpers.rs +++ b/crates/cubecl-macros/src/parse/helpers.rs @@ -8,6 +8,7 @@ use crate::{expression::Expression, paths::prelude_path, scope::Context}; pub struct Unroll { pub value: Expression, + pub always_true: bool, } impl Unroll { @@ -29,16 +30,23 @@ impl Unroll { let res = match &attr.meta { syn::Meta::Path(_) => Self { value: Expression::from_expr(parse_quote![true], context).unwrap(), + always_true: true, }, syn::Meta::List(list) => { let expr = syn::parse2(list.tokens.clone())?; let expr = Expression::from_expr(expr, context)?; - Self { value: expr } + Self { + value: expr, + always_true: false, + } } meta => { let expr = NameVal::from_meta(meta)?; let expr = Expression::from_expr(expr.value, context)?; - Self { value: expr } + Self { + value: expr, + always_true: false, + } } }; Ok(Some(res)) @@ -65,8 +73,12 @@ pub struct RemoveHelpers; impl VisitMut for RemoveHelpers { fn visit_fn_arg_mut(&mut self, i: &mut syn::FnArg) { match i { - syn::FnArg::Receiver(recv) => recv.attrs.retain(|it| !is_comptime_attr(it)), - syn::FnArg::Typed(typed) => typed.attrs.retain(|it| !is_comptime_attr(it)), + syn::FnArg::Receiver(recv) => recv + .attrs + .retain(|it| !is_comptime_attr(it) && !is_define_attribute(it)), + syn::FnArg::Typed(typed) => typed + .attrs + .retain(|it| !is_comptime_attr(it) && !is_define_attribute(it)), } visit_mut::visit_fn_arg_mut(self, i); } @@ -197,6 +209,13 @@ pub fn is_expr_attribute(attr: &Attribute) -> bool { attr.path().is_ident("expr") } +pub fn is_define_attribute(attr: &Attribute) -> bool { + attr.path().is_ident("define") +} + pub fn is_helper(attr: &Attribute) -> bool { - is_comptime_attr(attr) || is_unroll_attr(attr) || is_expr_attribute(attr) + is_comptime_attr(attr) + || is_unroll_attr(attr) + || is_expr_attribute(attr) + || is_define_attribute(attr) } diff --git a/crates/cubecl-macros/src/parse/kernel.rs b/crates/cubecl-macros/src/parse/kernel.rs index 514a09334..f61729d12 100644 --- a/crates/cubecl-macros/src/parse/kernel.rs +++ b/crates/cubecl-macros/src/parse/kernel.rs @@ -16,7 +16,7 @@ use syn::{ use super::{desugar::Desugar, helpers::is_comptime_attr, statement::parse_pat}; -#[derive(Default, FromMeta)] +#[derive(Default, FromMeta, Clone)] pub(crate) struct KernelArgs { pub launch: Flag, pub launch_unchecked: Flag, @@ -33,7 +33,7 @@ pub(crate) struct KernelArgs { pub self_type: SelfType, } -#[derive(Default, FromMeta, PartialEq, Eq)] +#[derive(Default, FromMeta, PartialEq, Eq, Clone, Copy)] pub(crate) enum SelfType { #[default] Owned, @@ -83,14 +83,18 @@ impl GenericAnalysis { } } - pub fn register_types(&self) -> TokenStream { + pub fn register_types(&self, name_mapping: &HashMap) -> TokenStream { let mut output = quote![]; for (name, ty) in self.map.iter() { + let name = match name_mapping.get(name) { + Some(name) => quote! { self.#name.into() }, + None => quote! {#name::as_type_native_unchecked()}, + }; output.extend(quote! 
{ builder .scope - .register_type::<#ty>(#name::as_type_native_unchecked()); + .register_type::<#ty>(#name); }); } @@ -200,10 +204,9 @@ pub struct KernelFn { pub sig: KernelSignature, pub body: KernelBody, pub full_name: String, - pub debug_symbols: bool, pub span: Span, pub context: Context, - pub src_file: Option, + pub args: KernelArgs, } #[allow(clippy::large_enum_variant)] @@ -213,7 +216,7 @@ pub enum KernelBody { Verbatim(TokenStream), } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct KernelSignature { pub name: Ident, pub parameters: Vec, @@ -228,7 +231,7 @@ impl KernelSignature { } } -#[derive(Clone)] +#[derive(Clone, Debug)] pub enum KernelReturns { ExpandType(Type), Plain(Type), @@ -248,6 +251,7 @@ pub struct KernelParam { pub name: Ident, pub ty: Type, pub normalized_ty: Type, + pub defines: Vec, pub is_const: bool, pub is_mut: bool, pub is_ref: bool, @@ -276,6 +280,7 @@ impl KernelParam { name: Ident::new("self", param.span()), ty: *param.ty, normalized_ty, + defines: Vec::new(), is_const: false, is_mut, is_ref, @@ -288,13 +293,27 @@ impl KernelParam { mut is_mut, .. } = parse_pat(*param.pat.clone())?; - let is_const = param.attrs.iter().any(is_comptime_attr); + let mut is_const = false; + let mut defines = Vec::new(); + + for attr in param.attrs.iter() { + if is_comptime_attr(attr) { + is_const = true; + } + if attr.path().is_ident("define") { + let ident: Ident = attr.parse_args().unwrap(); + defines.push(ident); + is_const = true; + } + } + let ty = *param.ty.clone(); let normalized_ty = normalize_kernel_ty(*param.ty, is_const, &mut is_ref, &mut is_mut); Ok(Self { name: ident, ty, + defines, normalized_ty, is_const, is_mut, @@ -393,7 +412,6 @@ impl KernelFn { full_name: String, args: &KernelArgs, ) -> syn::Result { - let src_file = args.src_file.clone(); let debug_symbols = args.debug_symbols.is_present(); let span = Span::call_site(); @@ -413,9 +431,8 @@ impl KernelFn { body: KernelBody::Block(block), full_name, span, - src_file, context, - debug_symbols, + args: args.clone(), }) } @@ -491,18 +508,36 @@ impl Launch { } let mut kernel_generics = func.sig.generics.clone(); + kernel_generics.params.clear(); + + for param in func.sig.generics.params.iter() { + // We remove generic arguments based on defined types. 
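+            // A parameter marked `#[define(...)]` resolves its concrete type at launch time,
+            // so the matching generic must not remain on the generated kernel.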
+ if let syn::GenericParam::Type(tp) = param + && func + .sig + .parameters + .iter() + .any(|p| p.defines.contains(&tp.ident)) + { + continue; + }; + + kernel_generics.params.push(param.clone()); + } + kernel_generics.params.push(parse_quote![__R: #runtime]); - let mut expand_generics = kernel_generics.clone(); - expand_generics.params = - Punctuated::from_iter(iter::once(parse_quote!['kernel]).chain(expand_generics.params)); + let mut launch_generics = kernel_generics.clone(); + launch_generics.params = + Punctuated::from_iter(iter::once(parse_quote!['kernel]).chain(launch_generics.params)); + let analysis = GenericAnalysis::from_generics(&func.sig.generics); Ok(Launch { args, vis, func, + launch_generics, kernel_generics, - launch_generics: expand_generics, analysis, }) } diff --git a/crates/cubecl-matmul/Cargo.toml b/crates/cubecl-matmul/Cargo.toml index c89c5100c..75cbf144c 100644 --- a/crates/cubecl-matmul/Cargo.toml +++ b/crates/cubecl-matmul/Cargo.toml @@ -75,12 +75,12 @@ matmul_tests_all = [ [dependencies] bytemuck = { workspace = true } -cubecl-common = { path = "../cubecl-common", version = "0.7.0", default-features = false } -cubecl-core = { path = "../cubecl-core", version = "0.7.0", default-features = false } -cubecl-runtime = { path = "../cubecl-runtime", version = "0.7.0", default-features = false } -cubecl-std = { path = "../cubecl-std", version = "0.7.0", default-features = false } -cubecl-reduce = { path = "../cubecl-reduce", version = "0.7.0", default-features = false } -cubecl-random = { path = "../cubecl-random", version = "0.7.0", default-features = false } +cubecl-common = { path = "../cubecl-common", version = "0.9.0", default-features = false } +cubecl-core = { path = "../cubecl-core", version = "0.9.0", default-features = false } +cubecl-runtime = { path = "../cubecl-runtime", version = "0.9.0", default-features = false } +cubecl-std = { path = "../cubecl-std", version = "0.9.0", default-features = false } +cubecl-reduce = { path = "../cubecl-reduce", version = "0.9.0", default-features = false } +cubecl-random = { path = "../cubecl-random", version = "0.9.0", default-features = false } half = { workspace = true, features = ["bytemuck"] } pretty_assertions = { workspace = true, optional = true } serde = { workspace = true } diff --git a/crates/cubecl-matmul/src/base.rs b/crates/cubecl-matmul/src/base.rs index 14c986c3a..301f0b4eb 100644 --- a/crates/cubecl-matmul/src/base.rs +++ b/crates/cubecl-matmul/src/base.rs @@ -1,15 +1,17 @@ +use cubecl_common::quant::scheme::{QuantScheme, QuantStore, QuantValue}; use cubecl_core::{ Runtime, client::ComputeClient, - prelude::{Numeric, TensorHandleRef}, + prelude::{CubePrimitive, Numeric, TensorHandleRef}, }; -use cubecl_std::tensor::TensorHandle; +use cubecl_std::tensor::{TensorHandle, into_contiguous_packed, into_contiguous_pitched}; +use serde::{Deserialize, Serialize}; use crate::{ components::{ AccG, LhsG, MatmulSetupError, RhsG, - tile::{accelerated::AcceleratedMatmul, io::Filled}, + tile::{cmma::CmmaMatmul, io::Filled, mma::MmaMatmul}, }, kernels::layered::{ Selection, @@ -54,21 +56,35 @@ use super::{ /// Most strategies have a selection input that can be overwritten or inferred from minimal information /// Some strategies must have a specified loading strategy pub enum Strategy { - Simple(SyncReadingStrategy, Selection), - SimpleBarrier(AsyncReadingStrategy), - DoubleBuffering(SyncPartialReadingStrategy, Selection), + Simple { + read_strategy: SyncReadingStrategy, + selection: Selection, + tile_kind: 
AcceleratedTileKind, + }, + SimpleBarrier { + read_strategy: AsyncReadingStrategy, + tile_kind: AcceleratedTileKind, + }, + DoubleBuffering { + read_strategy: SyncPartialReadingStrategy, + selection: Selection, + tile_kind: AcceleratedTileKind, + }, SimpleUnit(Selection), DoubleUnit(Selection), SimpleVecMat(Selection<()>), DoubleVecMat(Selection<()>), - OrderedDoubleBuffering(Selection), + OrderedDoubleBuffering { + selection: Selection, + tile_kind: AcceleratedTileKind, + }, Naive, #[default] /// Tries using a Simple matmul, then a SimpleUnit if the former failed Auto, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy)] /// Which reader to use in simple algorithms pub enum SyncReadingStrategy { Cyclic, @@ -76,7 +92,7 @@ pub enum SyncReadingStrategy { Tilewise, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy)] /// Which reader to use in double buffering algorithms pub enum SyncPartialReadingStrategy { Cyclic, @@ -84,7 +100,7 @@ pub enum SyncPartialReadingStrategy { Hybrid, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy)] /// Which reader to use in barrier algorithm pub enum AsyncReadingStrategy { Cooperative, @@ -94,11 +110,36 @@ pub enum AsyncReadingStrategy { Tma, } -pub enum MatmulInputHandle { +#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)] +/// Which tile matmul to use for accelerated algorithms +pub enum AcceleratedTileKind { + #[default] + Cmma, + Mma, +} + +macro_rules! with_tile_kind { + ($kind: expr, $T: ident, $launch: expr) => { + match $kind { + AcceleratedTileKind::Cmma => { + type $T = CmmaMatmul; + ($launch)() + } + AcceleratedTileKind::Mma => { + type $T = MmaMatmul; + ($launch)() + } + } + }; +} + +pub enum MatmulInputHandle { Normal(TensorHandle), Quantized { data: TensorHandle, - scale: TensorHandle, + scale: TensorHandle, + shape: Vec, + scheme: QuantScheme, }, } @@ -106,21 +147,81 @@ impl MatmulInputHandle { pub fn as_ref(&self) -> MatmulInputHandleRef<'_, R> { match self { MatmulInputHandle::Normal(handle) => MatmulInputHandleRef::Normal(handle.as_ref()), - MatmulInputHandle::Quantized { data, scale } => MatmulInputHandleRef::Quantized { + MatmulInputHandle::Quantized { + data, + scale, + shape, + scheme, + } => MatmulInputHandleRef::Quantized { data: data.as_ref(), scale: scale.as_ref(), + shape, + scheme, + }, + } + } + + pub fn from_ref(handle: &MatmulInputHandleRef<'_, R>) -> Self { + match handle { + MatmulInputHandleRef::Normal(handle) => { + MatmulInputHandle::Normal(TensorHandle::from_ref(handle)) + } + MatmulInputHandleRef::Quantized { + data, + scale, + shape, + scheme, + } => MatmulInputHandle::Quantized { + data: TensorHandle::from_ref(data), + scale: TensorHandle::from_ref(scale), + shape: shape.to_vec(), + scheme: **scheme, }, } } + + pub fn data(&self) -> &TensorHandle { + match self { + MatmulInputHandle::Normal(handle) => handle, + MatmulInputHandle::Quantized { data, .. } => data, + } + } + + pub fn swap_dims(&mut self, dim0: usize, dim1: usize) { + match self { + MatmulInputHandle::Normal(handle) => { + handle.shape.swap(dim0, dim1); + handle.strides.swap(dim0, dim1); + } + MatmulInputHandle::Quantized { + data, scale, shape, .. 
+ } => { + data.shape.swap(dim0, dim1); + data.strides.swap(dim0, dim1); + if scale.shape.len() == data.shape.len() { + scale.shape.swap(dim0, dim1); + scale.strides.swap(dim0, dim1); + } + shape.swap(dim0, dim1); + } + } + } } -impl Clone for MatmulInputHandle { +impl Clone for MatmulInputHandle { fn clone(&self) -> Self { match self { Self::Normal(handle) => Self::Normal(handle.clone()), - Self::Quantized { data, scale } => Self::Quantized { + Self::Quantized { + data, + scale, + shape, + scheme, + } => Self::Quantized { data: data.clone(), scale: scale.clone(), + shape: shape.clone(), + scheme: *scheme, }, } } @@ -132,6 +233,9 @@ pub enum MatmulInputHandleRef<'a, R: Runtime> { Quantized { data: TensorHandleRef<'a, R>, scale: TensorHandleRef<'a, R>, + /// Unpacked shape, excluding padding + shape: &'a [usize], + scheme: &'a QuantScheme, }, } @@ -148,8 +252,18 @@ impl<'a, R: Runtime> MatmulInputHandleRef<'a, R> { Self::Normal(data) } - pub fn quantized(data: TensorHandleRef<'a, R>, scale: TensorHandleRef<'a, R>) -> Self { - Self::Quantized { data, scale } + pub fn quantized( + data: TensorHandleRef<'a, R>, + scale: TensorHandleRef<'a, R>, + shape: &'a [usize], + scheme: &'a QuantScheme, + ) -> Self { + Self::Quantized { + data, + scale, + shape, + scheme, + } } pub fn data(&self) -> &TensorHandleRef<'a, R> { @@ -172,12 +286,69 @@ impl<'a, R: Runtime> MatmulInputHandleRef<'a, R> { MatmulInputHandleRef::Quantized { scale, .. } => Some(scale), } } + + pub fn scheme(&self) -> Option<&QuantScheme> { + match self { + MatmulInputHandleRef::Normal(_) => None, + MatmulInputHandleRef::Quantized { scheme, .. } => Some(scheme), + } + } + + pub fn shape(&self) -> &[usize] { + match self { + MatmulInputHandleRef::Normal(handle) => handle.shape, + MatmulInputHandleRef::Quantized { shape, .. 
} => shape, + } + } + + pub fn into_contiguous( + &self, + client: &ComputeClient, + ) -> MatmulInputHandle { + match self { + MatmulInputHandleRef::Normal(data) => { + MatmulInputHandle::Normal(into_contiguous_pitched::(client, data)) + } + MatmulInputHandleRef::Quantized { + data, + scale, + shape, + scheme, + } => { + let data = match scheme.store { + // e2m1 has native packing (e2m1x2) so also needs to be re-packed + QuantStore::Native if scheme.value == QuantValue::E2M1 => { + let data = into_contiguous_packed::(client, data, shape, 2); + // Unsafely cast to E + TensorHandle::from_ref(&data.as_ref()) + } + QuantStore::U32 => { + let data = into_contiguous_packed::( + client, + data, + shape, + scheme.num_quants() as u32, + ); + // Unsafely cast to E + TensorHandle::from_ref(&data.as_ref()) + } + _ => into_contiguous_pitched::(client, data), + }; + MatmulInputHandle::Quantized { + data, + scale: TensorHandle::from_ref(scale), + shape: shape.to_vec(), + scheme: **scheme, + } + } + } + } } #[allow(clippy::result_large_err)] pub fn launch( strategy: &Strategy, - client: &ComputeClient, + client: &ComputeClient, lhs: MatmulInputHandle>, rhs: MatmulInputHandle>, out: TensorHandle>, @@ -194,15 +365,17 @@ pub fn launch( #[allow(clippy::result_large_err)] pub fn launch_ref( strategy: &Strategy, - client: &ComputeClient, + client: &ComputeClient, lhs: &MatmulInputHandleRef, rhs: &MatmulInputHandleRef, out: &TensorHandleRef, ) -> Result<(), MatmulSetupError> { - type Accelerated = AcceleratedMatmul; - match strategy { - Strategy::Simple(loading_strategy, selection) => match loading_strategy { + Strategy::Simple { + read_strategy, + selection, + tile_kind, + } => with_tile_kind!(tile_kind, Accelerated, || match read_strategy { SyncReadingStrategy::Cyclic => { layered::launch_ref::>( client, lhs, rhs, out, selection, @@ -228,8 +401,11 @@ pub fn launch_ref( >, >(client, lhs, rhs, out, &Default::default()) } - }, - Strategy::SimpleBarrier(loading_strategy) => match loading_strategy { + }), + Strategy::SimpleBarrier { + read_strategy, + tile_kind, + } => with_tile_kind!(tile_kind, Accelerated, || match read_strategy { AsyncReadingStrategy::Cooperative => { layered::launch_ref::< R, @@ -271,17 +447,20 @@ pub fn launch_ref( >(client, lhs, rhs, out, &Default::default()) } AsyncReadingStrategy::Tma => { - layered::matmul_cmma_tma_ref_no_check::>( + layered::launch_ref_tma::>( client, lhs, rhs, out, - (false, false), &Default::default(), ) } - }, - Strategy::DoubleBuffering(loading_strategy, selection) => match loading_strategy { + }), + Strategy::DoubleBuffering { + read_strategy, + selection, + tile_kind, + } => with_tile_kind!(tile_kind, Accelerated, || match read_strategy { SyncPartialReadingStrategy::Cyclic => { layered::launch_ref::>( client, lhs, rhs, out, selection, @@ -297,12 +476,17 @@ pub fn launch_ref( client, lhs, rhs, out, selection, ) } - }, - Strategy::OrderedDoubleBuffering(selection) => { - layered::launch_ref::>( - client, lhs, rhs, out, selection, - ) - } + }), + Strategy::OrderedDoubleBuffering { + selection, + tile_kind, + } => with_tile_kind!(tile_kind, Accelerated, || layered::launch_ref::< + R, + MP, + OrderedDoubleBufferingAlgorithm, + >( + client, lhs, rhs, out, selection, + )), Strategy::SimpleUnit(selection) => { layered::launch_ref::(client, lhs, rhs, out, selection) } @@ -310,11 +494,11 @@ pub fn launch_ref( layered::launch_ref::(client, lhs, rhs, out, selection) } Strategy::Naive => { - naive::launch_ref::, AccG>(client, lhs.data(), rhs.data(), out)?; + naive::launch_ref::, 
AccG>(client, lhs, rhs, out)?; Ok(()) } Strategy::Auto => { - if let Err(err) = layered::launch_ref::>( + if let Err(err) = layered::launch_ref::>>( client, lhs, rhs, diff --git a/crates/cubecl-matmul/src/components/batch/base.rs b/crates/cubecl-matmul/src/components/batch/base.rs index a137b44a1..789ca113f 100644 --- a/crates/cubecl-matmul/src/components/batch/base.rs +++ b/crates/cubecl-matmul/src/components/batch/base.rs @@ -3,11 +3,10 @@ use crate::components::{ MatmulProblem, MatmulSelection, MatmulSpec, OutputRuntimeArg, RhsG, TilingScheme, batch::{CubeCountInput, CubeCountInputArgs, HypercubeConfig}, error::MatmulSetupError, - global::{self, GlobalConfig as _}, + global::{self, GlobalConfig as _, args::MatmulArgs}, }; use cubecl_core as cubecl; use cubecl_core::prelude::*; -use cubecl_std::{CubeOption, tensor::r#virtual::VirtualTensor}; use std::{fmt::Debug, hash::Hash}; /// A family of [matmuls](BatchMatmul) working with any [precision](MatmulPrecision). @@ -22,7 +21,7 @@ pub trait BatchMatmulFamily: 'static + Send + Sync { /// /// This function may return an error if the configuration cannot be supported on the current runtime. fn setup( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, selection: &MatmulSelection, line_sizes: &MatmulLineSizes, @@ -34,7 +33,7 @@ pub trait BatchMatmulFamily: 'static + Send + Sync { /// /// Out-of-bounds can happen unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>( - client: &ComputeClient<::Server, ::Channel>, + client: &ComputeClient<::Server>, cube_dim: CubeDim, cube_count: CubeCount, input: InputRuntimeArg<'a, MS, R>, @@ -73,11 +72,8 @@ pub trait BatchMatmul: 'static + Send + Sync { type Config: BatchConfig; /// Performs batchwise matrix multiplication over tensors. - fn execute( - a: VirtualTensor>, - b: VirtualTensor>, - c: CubeOption>>, - out: VirtualTensor, ReadWrite>, + fn execute( + state: &mut Args::State, RhsG, AccG>, cube_count_args: CubeCountInput, #[comptime] config: Self::Config, ); diff --git a/crates/cubecl-matmul/src/components/batch/entry_point.rs b/crates/cubecl-matmul/src/components/batch/entry_point.rs index c24956f39..4b426a478 100644 --- a/crates/cubecl-matmul/src/components/batch/entry_point.rs +++ b/crates/cubecl-matmul/src/components/batch/entry_point.rs @@ -1,17 +1,17 @@ +use crate::components::batch::CubeCountInput; use crate::components::batch::base::BatchMatmul; -use crate::components::global::args::{TensorLhs, TensorRhs}; -use crate::components::{batch::CubeCountInput, global::args::TensorAcc}; use crate::components::{ batch::{BatchConfig, BatchMatmulFamily}, - global::args::{MatmulArgs, TensorOutput}, + global::args::MatmulArgs, }; use cubecl_core as cubecl; use cubecl_core::prelude::*; -use cubecl_std::{CubeOption, CubeOptionExpand, tensor::r#virtual::VirtualTensor}; type Input = ::Input; type Output = ::Output; +type GlobalConf = <::Config as BatchConfig>::GlobalConfig; + #[cube(launch_unchecked)] /// Launches the matmul kernel pub(crate) fn matmul< @@ -36,32 +36,14 @@ pub(crate) fn matmul< } } - let mut state = Args::init_state(inputs, output); - - let lhs = TensorLhs::::new(&state); - let rhs = TensorRhs::::new(&state); - let mut out = TensorOutput::::new(&mut state); - - let has_acc = Args::has_acc(&state); - let acc: CubeOption> = match has_acc { - CubeOption::Some(_) => { - let acc = TensorAcc::::new(&state); - let acc = VirtualTensor::::new::>(&acc); - CubeOption::new_Some(acc) - } - CubeOption::None => CubeOption::new_None(), - }; - - let lhs = 
VirtualTensor::::new::>(&lhs); - let rhs = VirtualTensor::::new::>(&rhs); - let out = - VirtualTensor::::new::>(&mut out); + let mut state = Args::init_state::>( + inputs, + output, + config.global_config(), + ); - BMMF::Matmul::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::execute( - lhs, - rhs, - acc, - out, + BMMF::Matmul::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::execute::( + &mut state, cube_count_args, config, ); diff --git a/crates/cubecl-matmul/src/components/batch/layout.rs b/crates/cubecl-matmul/src/components/batch/layout.rs new file mode 100644 index 000000000..36372be8b --- /dev/null +++ b/crates/cubecl-matmul/src/components/batch/layout.rs @@ -0,0 +1,46 @@ +use cubecl::prelude::*; +use cubecl_core as cubecl; +use cubecl_std::tensor::layout::*; + +/// Slice the layout at a specific batch, and reduce its dimensionality +/// Not general enough to be in cubecl-std +#[derive(CubeType, Clone, Copy)] +pub struct SliceIndex { + offset: u32, + shape: Coords2d, +} + +#[cube] +impl SliceIndex { + pub fn new(offset: u32, shape: Coords3d) -> Self { + let (_, rows, cols) = shape; + SliceIndex { + offset, + shape: (rows, cols), + } + } +} + +#[cube] +impl Layout for SliceIndex { + type Coordinates = Coords2d; + type SourceCoordinates = Coords3d; + + fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates { + let (row, col) = pos; + (self.offset, row, col) + } + + fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool { + // we don't check batch + true.runtime() + } + + fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) { + (self.to_source_pos(pos), self.is_in_bounds(pos)) + } + + fn shape(&self) -> Self::Coordinates { + self.shape + } +} diff --git a/crates/cubecl-matmul/src/components/batch/mod.rs b/crates/cubecl-matmul/src/components/batch/mod.rs index 3843ba129..d4a796da2 100644 --- a/crates/cubecl-matmul/src/components/batch/mod.rs +++ b/crates/cubecl-matmul/src/components/batch/mod.rs @@ -2,7 +2,9 @@ mod base; mod entry_point; +mod layout; mod partitioned_matmul; pub use base::*; +pub use layout::*; pub use partitioned_matmul::*; diff --git a/crates/cubecl-matmul/src/components/batch/partitioned_matmul/matmul.rs b/crates/cubecl-matmul/src/components/batch/partitioned_matmul/matmul.rs index 198f7ce9f..ea15df025 100644 --- a/crates/cubecl-matmul/src/components/batch/partitioned_matmul/matmul.rs +++ b/crates/cubecl-matmul/src/components/batch/partitioned_matmul/matmul.rs @@ -1,15 +1,17 @@ use std::marker::PhantomData; -use crate::components::batch::partitioned_matmul::partition::{ - GlobalPartitionMatmul, PartitionRangeDim, PartitionRanges, -}; use crate::components::batch::{BatchConfig as _, BatchMatmul, CubeCountInput}; use crate::components::global::{self, GlobalMatmul}; use crate::components::{AccG, batch::partitioned_matmul::config::PartitionedBatchConfig}; use crate::components::{LhsG, MatmulPrecision, RhsG}; +use crate::components::{ + batch::partitioned_matmul::partition::{ + GlobalPartitionMatmul, PartitionRangeDim, PartitionRanges, + }, + global::args::MatmulArgs, +}; use cubecl_core as cubecl; use cubecl_core::prelude::*; -use cubecl_std::{CubeOption, tensor::r#virtual::VirtualTensor}; /// Executes matrix multiplication at the batch level, /// assigning each cube to handle multiple global matmuls. 
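For orientation: the `SliceIndex` layout introduced above is the piece that collapses the batched 3D views into the 2D per-matmul views consumed by the global matmuls. A minimal usage sketch, mirroring `execute_global_matmul` later in this patch (the diff elides generic parameters, so the annotated types are assumptions):

    // Resolve the full 3D view and this cube's flat batch offset.
    let a = Args::view_lhs(state); // assumed: View<Line<LhsG<MP>>, Coords3d>
    let a_batch = Args::batch_lhs(state, nth_batch);
    // Pin the batch index: each (row, col) access on the new view maps to
    // (a_batch, row, col) in the source, and the shape drops to (rows, cols).
    let a = a.view(SliceIndex::new(a_batch, a.shape()));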
@@ -32,17 +34,12 @@ impl, GPMM: GlobalPartitionMatmul> Ba { type Config = PartitionedBatchConfig; - fn execute( - a: VirtualTensor>, - b: VirtualTensor>, - c: CubeOption>>, - out: VirtualTensor, ReadWrite>, + fn execute( + state: &mut Args::State, RhsG, AccG>, cube_count_args: CubeCountInput, #[comptime] config: Self::Config, ) { - let lhs_rank = a.rank(); - - let problem_k = a.shape(lhs_rank - 1); + let (_, _, problem_k) = Args::view_lhs(state).shape(); let k_range = (0, problem_k); let tiling_scheme = config.tiling_scheme(); @@ -70,6 +67,6 @@ impl, GPMM: GlobalPartitionMatmul> Ba let global_config = config.global_config(); let acc = GMM::init_accumulators(global_config); - GPMM::execute::(a, b, c, out, ranges, acc, k_range, global_config); + GPMM::execute::(state, ranges, acc, k_range, global_config); } } diff --git a/crates/cubecl-matmul/src/components/batch/partitioned_matmul/partition/matmul.rs b/crates/cubecl-matmul/src/components/batch/partitioned_matmul/partition/matmul.rs index 3f3cba480..1f9510974 100644 --- a/crates/cubecl-matmul/src/components/batch/partitioned_matmul/partition/matmul.rs +++ b/crates/cubecl-matmul/src/components/batch/partitioned_matmul/partition/matmul.rs @@ -3,9 +3,10 @@ use cubecl_core::prelude::*; use crate::components::{ AccG, LhsG, MatmulPrecision, RhsG, - global::{self, GlobalConfig}, + batch::SliceIndex, + global::{self, GlobalConfig, args::MatmulArgs}, }; -use cubecl_std::{CubeOption, tensor::r#virtual::VirtualTensor}; +use cubecl_std::{CubeOption, CubeOptionExpand}; #[derive(CubeType)] /// Area of a tensor a cube is responsible for performing matmul on pub struct PartitionRanges { pub row: PartitionRangeDim, @@ -27,11 +28,8 @@ pub struct PartitionRangeDim { #[cube] /// Iterates over several global matmuls across a global partition pub trait GlobalPartitionMatmul: 'static + Send + Sync { - fn execute>( - a: VirtualTensor>, - b: VirtualTensor>, - c: CubeOption>>, - out: VirtualTensor, ReadWrite>, + fn execute>( + state: &mut Args::State, RhsG, AccG>, partition_ranges: PartitionRanges, acc: GMM::Accumulators, k_range: (u32, u32), @@ -78,11 +76,8 @@ impl PartitionRangeDim { #[cube] impl GlobalPartitionMatmul for RowMajorGlobalPartitionMatmul { - fn execute>( - a: VirtualTensor>, - b: VirtualTensor>, - c: CubeOption>>, - out: VirtualTensor, ReadWrite>, + fn execute>( + state: &mut Args::State, RhsG, AccG>, ranges: PartitionRanges, mut acc: GMM::Accumulators, k_range: (u32, u32), @@ -105,8 +100,8 @@ impl GlobalPartitionMatmul for RowMajorGlobalPartitionMatmul { for col in 0..num_steps_col { let col_offset = ranges.col.start + col * ranges.col.step; - execute_global_matmul::( - a, b, c, out, row_offset, col_offset, batch_iter, &mut acc, k_range, config, + execute_global_matmul::( + state, batch_iter, row_offset, col_offset, &mut acc, k_range, config, ); } } @@ -116,11 +111,8 @@ #[cube] impl GlobalPartitionMatmul for ColMajorGlobalPartitionMatmul { - fn execute>( - a: VirtualTensor>, - b: VirtualTensor>, - c: CubeOption>>, - out: VirtualTensor, ReadWrite>, + fn execute>( + state: &mut Args::State, RhsG, AccG>, ranges: PartitionRanges, mut acc: GMM::Accumulators, k_range: (u32, u32), @@ -143,8 +135,8 @@ impl GlobalPartitionMatmul for ColMajorGlobalPartitionMatmul { for row in 0..num_steps_row { let row_offset = ranges.row.start + row * ranges.row.step; - execute_global_matmul::( - a, b, c, out, row_offset, col_offset, batch_iter, &mut acc, k_range, config, + execute_global_matmul::( + state, batch_iter, row_offset, col_offset, &mut acc, k_range, config, ); } } 
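The next hunk deletes the hand-rolled broadcast batch-offset computation from `execute_global_matmul`; that mapping now lives behind the `Args::batch_*` hooks (backed by `BatchLayout` on the launch side). For reference, a commented sketch of what the removed lines computed, using the same names and assuming a contiguous `out`:

    // Flat element offset where the nth output batch starts
    // (stride(rank - 2) * shape(rank - 2) == rows * cols when contiguous).
    let batch_out = nth_batch * out.stride(rank - 2) * out.shape(rank - 2);
    let mut batch_a = 0u32.runtime();
    for axis in 0..rank - 2 {
        // Index along `axis`, reduced modulo a's extent so that broadcast
        // (size-1) axes of `a` contribute no offset.
        let idx = (batch_out / out.stride(axis)) % a.shape(axis);
        batch_a += idx * a.stride(axis);
    }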
@@ -155,65 +147,56 @@ impl GlobalPartitionMatmul for ColMajorGlobalPartitionMatmul { #[cube] /// Execute global matmul on lhs, rhs, writing in out. /// m and n offsets are absolute rows and columns -pub(crate) fn execute_global_matmul>( - a: VirtualTensor>, - b: VirtualTensor>, - c: CubeOption>>, - out: VirtualTensor, ReadWrite>, +pub(crate) fn execute_global_matmul< + Args: MatmulArgs, + MP: MatmulPrecision, + GMM: global::GlobalMatmul, +>( + state: &mut Args::State, RhsG, AccG>, + nth_batch: u32, m_offset: u32, n_offset: u32, - nth_batch: u32, acc: &mut GMM::Accumulators, k_range: (u32, u32), #[comptime] config: GMM::Config, ) { - let rank = out.rank(); - - let batch_out = nth_batch * out.stride(rank - 2) * out.shape(rank - 2); - let mut batch_a = 0u32.runtime(); - let mut batch_b = 0u32.runtime(); - for axis in 0..rank - 2 { - let tmp = batch_out / out.stride(axis); - batch_a += tmp % a.shape(axis) * a.stride(axis); - batch_b += tmp % b.shape(axis) * b.stride(axis); - } - let tiling = config.tiling_scheme(); let stage_m = tiling.elements_in_stage_m().runtime(); let stage_n = tiling.elements_in_stage_n().runtime(); let k_size = k_range.1 - k_range.0; + let a = Args::view_lhs(state); + let b = Args::view_rhs(state); + let c = Args::view_acc(state); + let out = Args::view_out(state); + + let a_batch = Args::batch_lhs(state, nth_batch); + let a = a.view(SliceIndex::new(a_batch, a.shape())); + let b_batch = Args::batch_rhs(state, nth_batch); + let b = b.view(SliceIndex::new(b_batch, b.shape())); + let c_batch = Args::batch_acc(state, nth_batch); + let c = match c { + CubeOption::Some(c) => { + let c = c.view(SliceIndex::new(c_batch, c.shape())); + CubeOption::new_Some(c.slice_unchecked((m_offset, n_offset), (stage_m, stage_n))) + } + CubeOption::None => CubeOption::new_None(), + }; + let out_batch = Args::batch_out(state, nth_batch); + let out = out.view_mut(SliceIndex::new(out_batch, out.shape())); + GMM::execute( GMM::init_lhs_global_reader( - a, - batch_a, - (m_offset, k_range.0), - (stage_m, k_size), - nth_batch, + a.slice_unchecked((m_offset, k_range.0), (stage_m, k_size)), config, ), GMM::init_rhs_global_reader( - b, - batch_b, - (k_range.0, n_offset), - (k_size, stage_n), - nth_batch, - config, - ), - GMM::init_acc_global_reader( - c, - batch_out, - (m_offset, n_offset), - (stage_m, stage_n), - nth_batch, + b.slice_unchecked((k_range.0, n_offset), (k_size, stage_n)), config, ), + GMM::init_acc_global_reader(c, config), GMM::init_global_writer( - out, - batch_out, - (m_offset, n_offset), - (stage_m, stage_n), - nth_batch, + out.slice_mut_unchecked((m_offset, n_offset), (stage_m, stage_n)), config, ), acc, diff --git a/crates/cubecl-matmul/src/components/batch/partitioned_matmul/setup.rs b/crates/cubecl-matmul/src/components/batch/partitioned_matmul/setup.rs index 014a3241d..7fc830337 100644 --- a/crates/cubecl-matmul/src/components/batch/partitioned_matmul/setup.rs +++ b/crates/cubecl-matmul/src/components/batch/partitioned_matmul/setup.rs @@ -26,7 +26,7 @@ impl BatchMatmulFamily type Config = PartitionedBatchConfig; fn setup( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, selection: &MatmulSelection, line_sizes: &MatmulLineSizes, @@ -43,7 +43,7 @@ impl BatchMatmulFamily } unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>( - client: &ComputeClient<::Server, ::Channel>, + client: &ComputeClient<::Server>, cube_dim: CubeDim, cube_count: CubeCount, input: InputRuntimeArg<'a, MS, R>, diff --git a/crates/cubecl-matmul/src/components/global/args.rs 
b/crates/cubecl-matmul/src/components/global/args.rs index e2186eae7..6c16ba57f 100644 --- a/crates/cubecl-matmul/src/components/global/args.rs +++ b/crates/cubecl-matmul/src/components/global/args.rs @@ -1,26 +1,43 @@ use std::any::TypeId; use cubecl::prelude::*; -use cubecl_core::{self as cubecl, intrinsic, server::TensorMapMeta}; +use cubecl_core::{self as cubecl, server::TensorMapMeta, unexpanded}; use cubecl_std::{ CubeOption, CubeOptionArgs, CubeOptionExpand, - tensor::r#virtual::{VirtualTensorOperations, VirtualTensorOperationsExpand}, + tensor::{ + View, + launch::ViewArg, + layout::{Coords1d, Coords3d, VirtualLayout, VirtualLayoutLaunch}, + }, }; use crate::{ MatmulInputHandleRef, - components::{self, MatmulLineSizes, MatmulProblem, MatmulSelection}, + components::{ + self, MatmulIdent, MatmulLineSizes, MatmulProblem, MatmulSelection, + batch::BatchConfig, + global::{ + GlobalConfig, + memory::{ + BatchLayout, BatchLayoutLaunch, GlobalLayout, GlobalLayoutLaunch, + GlobalScaleLayout, NoopLayout, NoopLayoutLaunch, SimpleTmaGlobalLayout, + SimpleTmaGlobalLayoutLaunch, + }, + }, + }, }; /// Create the input runtime arguments for a matmul kernel that works on concrete inputs and /// output (not fused). pub trait ConcreteInputsFactory: LaunchArg { fn create<'a, R: Runtime>( + client: &ComputeClient, lhs: &'a MatmulInputHandleRef<'a, R>, rhs: &'a MatmulInputHandleRef<'a, R>, selection: &MatmulSelection, problem: &MatmulProblem, line_sizes: &MatmulLineSizes, + config: impl BatchConfig, ) -> Self::RuntimeArg<'a, R>; } @@ -28,10 +45,12 @@ pub trait ConcreteInputsFactory: LaunchArg { /// output (not fused). pub trait ConcreteOutputFactory: LaunchArg { fn create<'a, R: Runtime>( + client: &ComputeClient, out: &'a TensorHandleRef<'a, R>, selection: &MatmulSelection, problem: &MatmulProblem, line_sizes: &MatmulLineSizes, + config: impl BatchConfig, ) -> Self::RuntimeArg<'a, R>; } @@ -47,1058 +66,246 @@ pub trait MatmulArgs: Send + Sync + 'static + Clone { type State: CubeType; /// Init the state. - fn init_state( + fn init_state( input: &Self::Input, output: &mut Self::Output, + #[comptime] config: G, ) -> Self::State; - /// Whether the accumulator argument is present. Returns `CubeOption` to allow matching at - /// comptime - fn has_acc( - state: &Self::State, - ) -> CubeOption<()>; - - /// Read the line of the lhs tensor using the state at the given coordinate. - fn read_lhs( - state: &Self::State, - coordinate: u32, - ) -> Line; - /// Read the line of the rhs tensor using the state at the given coordinate. - fn read_rhs( - state: &Self::State, - coordinate: u32, - ) -> Line; - /// Read the line of the acc tensor using the state at the given coordinate. - fn read_acc( - state: &Self::State, - coordinate: u32, - ) -> Line; - - /// Read the line of the lhs tensor using the state at the given coordinate. - fn read_window_lhs( - state: &Self::State, - start: u32, - end: u32, - ) -> Slice>; - - /// Read the line of the rhs tensor using the state at the given coordinate. - fn read_window_rhs( - state: &Self::State, - start: u32, - end: u32, - ) -> Slice>; - - /// Read the line of the acc tensor using the state at the given coordinate. 
- fn read_window_acc( - state: &Self::State, - start: u32, - end: u32, - ) -> Slice>; - - /// Reinterpret lhs as tensor map - fn as_tensor_map_lhs( - state: &Self::State, - ) -> CubeOption>; - - /// Reinterpret rhs as tensor map - fn as_tensor_map_rhs( - state: &Self::State, - ) -> CubeOption>; - - /// Reinterpret rhs as tensor map - fn as_tensor_map_acc( - state: &Self::State, - ) -> CubeOption>; - - /// Write the line to the output at the given coordinate using the state. - fn write_out( - state: &mut Self::State, - coordinate: u32, - value: Line, - ); - - /// Get the rank of the lhs tensor using the state. - fn rank_lhs(state: &Self::State) -> u32; - /// Get the rank of the rhs tensor using the state. - fn rank_rhs(state: &Self::State) -> u32; - /// Get the rank of the acc tensor using the state. - fn rank_acc(state: &Self::State) -> u32; - /// Get the rank of the out tensor using the state. - fn rank_out(state: &Self::State) -> u32; - - /// Get the length of the lhs tensor using the state. - fn len_lhs(state: &Self::State) -> u32; - /// Get the length of the rhs tensor using the state. - fn len_rhs(state: &Self::State) -> u32; - /// Get the length of the acc tensor using the state. - fn len_acc(state: &Self::State) -> u32; - /// Get the length of the out tensor using the state. - fn len_out(state: &Self::State) -> u32; - - /// Get the buffer length of the lhs tensor using the state. - fn buffer_len_lhs( - state: &Self::State, - ) -> u32; - /// Get the buffer length of the rhs tensor using the state. - fn buffer_len_rhs( - state: &Self::State, - ) -> u32; - /// Get the buffer length of the acc tensor using the state. - fn buffer_len_acc( - state: &Self::State, - ) -> u32; - /// Get the buffer length of the out tensor using the state. - fn buffer_len_out( - state: &Self::State, - ) -> u32; - - /// Get the shape of the lhs tensor using the state. - fn shape_lhs( - state: &Self::State, - axis: u32, - ) -> u32; - /// Get the shape of the rhs tensor using the state. - fn shape_rhs( - state: &Self::State, - axis: u32, - ) -> u32; - /// Get the shape of the acc tensor using the state. - fn shape_acc( - state: &Self::State, - axis: u32, - ) -> u32; - /// Get the shape of the out tensor using the state. - fn shape_out( - state: &Self::State, - axis: u32, - ) -> u32; - - /// Get the stride of the lhs tensor using the state. - fn stride_lhs( - state: &Self::State, - axis: u32, - ) -> u32; - /// Get the stride of the rhs tensor using the state. - fn stride_rhs( - state: &Self::State, - axis: u32, - ) -> u32; - /// Get the stride of the acc tensor using the state. - fn stride_acc( - state: &Self::State, - axis: u32, - ) -> u32; - /// Get the stride of the out tensor using the state. - fn stride_out( - state: &Self::State, - axis: u32, - ) -> u32; - - /// Get the line size of the lhs tensor using the state. - fn line_size_lhs( - state: &Self::State, - ) -> comptime_type!(u32); - /// Get the line size of the rhs tensor using the state. - fn line_size_rhs( - state: &Self::State, - ) -> comptime_type!(u32); - /// Get the line size of the acc tensor using the state. - fn line_size_acc( - state: &Self::State, - ) -> comptime_type!(u32); - /// Get the line size of the out tensor using the state. - fn line_size_out( - state: &Self::State, - ) -> comptime_type!(u32); -} - -#[derive(Clone, Copy)] -/// Identification of the [tensor input](TensorInput). -pub enum TensorInputIdent { - Lhs, - Rhs, -} - -/// Tensor input representation. 
-/// -/// You can use the tensor input as if it was a pointer to the actually tensor. -pub struct TensorLhs { - state: *const GA::State, -} - -/// Tensor input representation. -/// -/// You can use the tensor input as if it was a pointer to the actually tensor. -pub struct TensorRhs { - state: *const GA::State, -} - -/// Tensor input representation. -/// -/// You can use the tensor input as if it was a pointer to the actually tensor. -pub struct TensorAcc { - state: *const GA::State, -} - -impl VirtualTensorOperations - for TensorLhs -{ -} - -impl VirtualTensorOperations - for TensorRhs -{ -} - -impl VirtualTensorOperations - for TensorAcc -{ -} - -impl VirtualTensorOperations - for TensorOutput -{ -} - -impl VirtualTensorOperationsExpand - for TensorOutputExpand -{ - fn __expand_read_method( - &self, - _scope: &mut Scope, - _index: ExpandElementTyped, - ) -> ExpandElementTyped> { - panic!("Can't read output tensor"); - } - - fn __expand_read_window_method( - &self, - _context: &mut Scope, - _start: ExpandElementTyped, - _end: ExpandElementTyped, - ) -> SliceExpand, ReadOnly> { - panic!("Can't read output tensor"); - } - - fn __expand_write_method( - &self, - scope: &mut Scope, - index: ExpandElementTyped, - value: ExpandElementTyped>, - ) { - TensorOutputExpand::__expand_write_method(self.clone(), scope, index, value) - } - - fn __expand_shape_method( - &self, - scope: &mut Scope, - axis: ExpandElementTyped, - ) -> ExpandElementTyped { - TensorOutputExpand::__expand_shape_method(self.clone(), scope, axis) - } - - fn __expand_stride_method( - &self, - scope: &mut Scope, - axis: ExpandElementTyped, - ) -> ExpandElementTyped { - TensorOutputExpand::__expand_stride_method(self.clone(), scope, axis) - } - - fn __expand_rank_method(&self, scope: &mut Scope) -> ExpandElementTyped { - TensorOutputExpand::__expand_rank_method(self.clone(), scope) - } - - fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped { - TensorOutputExpand::__expand_len_method(self.clone(), scope) - } - - fn __expand_buffer_len_method(&self, scope: &mut Scope) -> ExpandElementTyped { - TensorOutputExpand::__expand_buffer_len_method(self.clone(), scope) - } - - fn __expand_as_tensor_map_method(&self, scope: &mut Scope) -> CubeOptionExpand> { - CubeOption::__expand_new_None(scope) - } -} - -impl Lined - for TensorOutput -{ -} -impl LinedExpand - for TensorOutputExpand -{ - fn line_size(&self) -> u32 { - let mut scope = Scope::root(false); - TensorOutputExpand::__expand_line_size_method(self.clone(), &mut scope) - } -} - -impl VirtualTensorOperationsExpand - for TensorLhsExpand -{ - fn __expand_read_method( - &self, - scope: &mut Scope, - index: ExpandElementTyped, - ) -> ExpandElementTyped> { - TensorLhsExpand::__expand_read_method(self.clone(), scope, index) - } - fn __expand_read_window_method( - &self, - context: &mut Scope, - start: ExpandElementTyped, - end: ExpandElementTyped, - ) -> SliceExpand, ReadOnly> { - TensorLhsExpand::__expand_read_window_method(self.clone(), context, start, end) - } - - fn __expand_write_method( - &self, - _scope: &mut Scope, - _index: ExpandElementTyped, - _value: ExpandElementTyped>, - ) { - panic!("Can't write to input tensor"); - } - - fn __expand_shape_method( - &self, - scope: &mut Scope, - axis: ExpandElementTyped, - ) -> ExpandElementTyped { - TensorLhsExpand::__expand_shape_method(self.clone(), scope, axis) - } - - fn __expand_stride_method( - &self, - scope: &mut Scope, - axis: ExpandElementTyped, - ) -> ExpandElementTyped { - 
TensorLhsExpand::__expand_stride_method(self.clone(), scope, axis) - } - - fn __expand_rank_method(&self, scope: &mut Scope) -> ExpandElementTyped { - TensorLhsExpand::__expand_rank_method(self.clone(), scope) - } - - fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped { - TensorLhsExpand::__expand_len_method(self.clone(), scope) - } - - fn __expand_buffer_len_method(&self, scope: &mut Scope) -> ExpandElementTyped { - TensorLhsExpand::__expand_buffer_len_method(self.clone(), scope) - } - - fn __expand_as_tensor_map_method(&self, scope: &mut Scope) -> CubeOptionExpand> { - TensorLhsExpand::__expand_as_tensor_map_method(self.clone(), scope) - } -} - -impl Lined - for TensorLhs -{ -} -impl LinedExpand - for TensorLhsExpand -{ - fn line_size(&self) -> u32 { - let mut scope = Scope::root(false); - TensorLhsExpand::__expand_line_size_method(self.clone(), &mut scope) - } -} - -impl VirtualTensorOperationsExpand - for TensorRhsExpand -{ - fn __expand_read_method( - &self, - scope: &mut Scope, - index: ExpandElementTyped, - ) -> ExpandElementTyped> { - TensorRhsExpand::__expand_read_method(self.clone(), scope, index) - } - fn __expand_read_window_method( - &self, - context: &mut Scope, - start: ExpandElementTyped, - end: ExpandElementTyped, - ) -> SliceExpand, ReadOnly> { - TensorRhsExpand::__expand_read_window_method(self.clone(), context, start, end) - } - - fn __expand_write_method( - &self, - _scope: &mut Scope, - _index: ExpandElementTyped, - _value: ExpandElementTyped>, - ) { - panic!("Can't write to input tensor"); - } - - fn __expand_shape_method( - &self, - scope: &mut Scope, - axis: ExpandElementTyped, - ) -> ExpandElementTyped { - TensorRhsExpand::__expand_shape_method(self.clone(), scope, axis) - } - - fn __expand_stride_method( - &self, - scope: &mut Scope, - axis: ExpandElementTyped, - ) -> ExpandElementTyped { - TensorRhsExpand::__expand_stride_method(self.clone(), scope, axis) - } - - fn __expand_rank_method(&self, scope: &mut Scope) -> ExpandElementTyped { - TensorRhsExpand::__expand_rank_method(self.clone(), scope) - } - - fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped { - TensorRhsExpand::__expand_len_method(self.clone(), scope) - } - - fn __expand_buffer_len_method(&self, scope: &mut Scope) -> ExpandElementTyped { - TensorRhsExpand::__expand_buffer_len_method(self.clone(), scope) - } - - fn __expand_as_tensor_map_method(&self, scope: &mut Scope) -> CubeOptionExpand> { - TensorRhsExpand::__expand_as_tensor_map_method(self.clone(), scope) - } -} - -impl Lined - for TensorRhs -{ -} -impl LinedExpand - for TensorRhsExpand -{ - fn line_size(&self) -> u32 { - let mut scope = Scope::root(false); - TensorRhsExpand::__expand_line_size_method(self.clone(), &mut scope) - } -} - -impl VirtualTensorOperationsExpand - for TensorAccExpand -{ - fn __expand_read_method( - &self, - scope: &mut Scope, - index: ExpandElementTyped, - ) -> ExpandElementTyped> { - TensorAccExpand::__expand_read_method(self.clone(), scope, index) - } - fn __expand_read_window_method( - &self, - context: &mut Scope, - start: ExpandElementTyped, - end: ExpandElementTyped, - ) -> SliceExpand, ReadOnly> { - TensorAccExpand::__expand_read_window_method(self.clone(), context, start, end) - } - - fn __expand_write_method( - &self, - _scope: &mut Scope, - _index: ExpandElementTyped, - _value: ExpandElementTyped>, - ) { - panic!("Can't write to input tensor"); - } - - fn __expand_shape_method( - &self, - scope: &mut Scope, - axis: ExpandElementTyped, - ) -> ExpandElementTyped { - 
TensorAccExpand::__expand_shape_method(self.clone(), scope, axis) - } - - fn __expand_stride_method( - &self, - scope: &mut Scope, - axis: ExpandElementTyped, - ) -> ExpandElementTyped { - TensorAccExpand::__expand_stride_method(self.clone(), scope, axis) - } - - fn __expand_rank_method(&self, scope: &mut Scope) -> ExpandElementTyped { - TensorAccExpand::__expand_rank_method(self.clone(), scope) - } - - fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped { - TensorAccExpand::__expand_len_method(self.clone(), scope) - } - - fn __expand_buffer_len_method(&self, scope: &mut Scope) -> ExpandElementTyped { - TensorAccExpand::__expand_buffer_len_method(self.clone(), scope) - } - - fn __expand_as_tensor_map_method(&self, scope: &mut Scope) -> CubeOptionExpand> { - TensorAccExpand::__expand_as_tensor_map_method(self.clone(), scope) - } -} - -impl Lined - for TensorAcc -{ -} -impl LinedExpand - for TensorAccExpand -{ - fn line_size(&self) -> u32 { - let mut scope = Scope::root(false); - TensorAccExpand::__expand_line_size_method(self.clone(), &mut scope) - } -} - -/// Tensor output representation. -/// -/// You can use the tensor output as if it was a pointer to the actually tensor. -/// -/// # Warning -/// -/// There is no mutability guarantee. -pub struct TensorOutput { - state: *mut GA::State, -} - -/// Expand type for [tensor lhs](TensorLhs). -pub struct TensorLhsExpand { - state: as CubeType>::ExpandType, -} - -/// Expand type for [tensor rhs](TensorRhs). -pub struct TensorRhsExpand { - state: as CubeType>::ExpandType, -} - -/// Expand type for [tensor rhs](TensorRhs). -pub struct TensorAccExpand { - state: as CubeType>::ExpandType, -} - -/// Expand type for [tensor output](TensorOutput). -pub struct TensorOutputExpand { - state: as CubeType>::ExpandType, -} - -#[cube] -impl TensorLhs { - /// Create a [tensor input](TensorInput) from the state and the [ident](TensorInputIdent). - pub fn new(state: &MA::State) -> TensorLhs { - TensorLhs:: { state } - } - - //// Read the tensor at the given coordinate. - pub fn read_window(&self, start: u32, end: u32) -> Slice> { - unsafe { MA::read_window_lhs(&(*self.state), start, end) } - } - - /// Read the tensor at the given coordinate. - pub fn read(&self, coordinate: u32) -> Line { - unsafe { MA::read_lhs(&(*self.state), coordinate) } - } - - /// Get the shape of the tensor at the given axis. - pub fn shape(&self, axis: u32) -> u32 { - unsafe { MA::shape_lhs(&(*self.state), axis) } - } - - /// Get the stride of the tensor at the given axis. - pub fn stride(&self, axis: u32) -> u32 { - unsafe { MA::stride_lhs(&(*self.state), axis) } - } - - /// Get the rank of the tensor. - pub fn rank(&self) -> u32 { - unsafe { MA::rank_lhs(&(*self.state)) } - } - - /// Get the length of the tensor. - #[allow(clippy::len_without_is_empty)] - pub fn len(&self) -> u32 { - unsafe { MA::len_lhs(&(*self.state)) } - } - - /// Get the buffer length of the tensor. - pub fn buffer_len(&self) -> u32 { - unsafe { MA::buffer_len_lhs(&(*self.state)) } - } - - /// Get the buffer length of the tensor. - pub fn as_tensor_map(&self) -> CubeOption> { - unsafe { MA::as_tensor_map_lhs(&(*self.state)) } - } - - /// Get the line size of the tensor. - pub fn line_size(&self) -> comptime_type!(u32) { - unsafe { MA::line_size_lhs(&(*self.state)) } - } -} - -#[cube] -impl TensorRhs { - /// Create a [tensor input](TensorInput) from the state and the [ident](TensorInputIdent). 
- pub fn new(state: &MA::State) -> TensorRhs { - TensorRhs:: { state } - } - - //// Read the tensor at the given coordinate. - pub fn read_window(&self, start: u32, end: u32) -> Slice> { - unsafe { MA::read_window_rhs(&(*self.state), start, end) } - } - - /// Read the tensor at the given coordinate. - pub fn read(&self, coordinate: u32) -> Line { - unsafe { MA::read_rhs(&(*self.state), coordinate) } - } - - /// Get the shape of the tensor at the given axis. - pub fn shape(&self, axis: u32) -> u32 { - unsafe { MA::shape_rhs(&(*self.state), axis) } - } - - /// Get the stride of the tensor at the given axis. - pub fn stride(&self, axis: u32) -> u32 { - unsafe { MA::stride_rhs(&(*self.state), axis) } - } - - /// Get the rank of the tensor. - pub fn rank(&self) -> u32 { - unsafe { MA::rank_rhs(&(*self.state)) } - } - - /// Get the length of the tensor. - #[allow(clippy::len_without_is_empty)] - pub fn len(&self) -> u32 { - unsafe { MA::len_rhs(&(*self.state)) } - } - - /// Get the buffer length of the tensor. - pub fn buffer_len(&self) -> u32 { - unsafe { MA::buffer_len_rhs(&(*self.state)) } - } - - /// Get the buffer length of the tensor. - pub fn as_tensor_map(&self) -> CubeOption> { - unsafe { MA::as_tensor_map_rhs(&(*self.state)) } - } - - /// Get the line size of the tensor. - pub fn line_size(&self) -> comptime_type!(u32) { - unsafe { MA::line_size_rhs(&(*self.state)) } - } -} - -#[cube] -impl TensorAcc { - /// Create a [tensor input](TensorInput) from the state and the [ident](TensorInputIdent). - pub fn new(state: &MA::State) -> TensorAcc { - TensorAcc:: { state } - } - - //// Read the tensor at the given coordinate. - pub fn read_window(&self, start: u32, end: u32) -> Slice> { - unsafe { MA::read_window_acc(&(*self.state), start, end) } - } - - /// Read the tensor at the given coordinate. - pub fn read(&self, coordinate: u32) -> Line { - unsafe { MA::read_acc(&(*self.state), coordinate) } - } - - /// Get the shape of the tensor at the given axis. - pub fn shape(&self, axis: u32) -> u32 { - unsafe { MA::shape_acc(&(*self.state), axis) } - } - - /// Get the stride of the tensor at the given axis. - pub fn stride(&self, axis: u32) -> u32 { - unsafe { MA::stride_acc(&(*self.state), axis) } - } - - /// Get the rank of the tensor. - pub fn rank(&self) -> u32 { - unsafe { MA::rank_acc(&(*self.state)) } - } - - /// Get the length of the tensor. - #[allow(clippy::len_without_is_empty)] - pub fn len(&self) -> u32 { - unsafe { MA::len_acc(&(*self.state)) } - } - - /// Get the buffer length of the tensor. - pub fn buffer_len(&self) -> u32 { - unsafe { MA::buffer_len_acc(&(*self.state)) } - } - - /// Get the buffer length of the tensor. - pub fn as_tensor_map(&self) -> CubeOption> { - unsafe { MA::as_tensor_map_acc(&(*self.state)) } - } - - /// Get the line size of the tensor. - pub fn line_size(&self) -> comptime_type!(u32) { - unsafe { MA::line_size_acc(&(*self.state)) } - } -} - -#[cube] -impl TensorOutput { - /// Create a [tensor output](TensorOutput) from the state. - pub fn new(state: &mut GA::State) -> TensorOutput { - TensorOutput:: { state } - } - - /// Write the value to tensor at the given coordinate. - pub fn write(&self, coordinate: u32, value: Line) { - unsafe { GA::write_out(&mut (*self.state), coordinate, value) } - } - - /// Get the shape of the tensor at the given axis. - pub fn shape(&self, axis: u32) -> u32 { - unsafe { GA::shape_out(&(*self.state), axis) } - } - - /// Get the stride of the tensor at the given axis. 
- pub fn stride(&self, dim: u32) -> u32 { - unsafe { GA::stride_out(&(*self.state), dim) } - } - - /// Get the rank of the tensor. - pub fn rank(&self) -> u32 { - unsafe { GA::rank_out(&(*self.state)) } - } - - /// Get the length of the tensor. - #[allow(clippy::len_without_is_empty)] - pub fn len(&self) -> u32 { - unsafe { GA::len_out(&(*self.state)) } - } - - /// Get the buffer length of the tensor. - pub fn buffer_len(&self) -> u32 { - unsafe { GA::buffer_len_out(&(*self.state)) } - } - - /// Get the buffer length of the tensor. - pub fn line_size(&self) -> comptime_type!(u32) { - unsafe { GA::line_size_out(&(*self.state)) } - } -} - -#[derive(Clone)] -/// Type implementing [MatmulArgs] where all inputs and the output are materialized tensors. -/// -/// Other types might implement [MatmulArgs] for fused matrix multiplication kernels. -pub struct TensorArgs; - -#[derive(CubeLaunch, CubeType)] -/// Input representation for [TensorArgs] implementing [MatmulArgs]. -pub struct TensorInputs { - /// The lhs tensor. - pub lhs: Tensor>, - pub lhs_scale: CubeOption>, - /// The rhs tensor. - pub rhs: Tensor>, - pub rhs_scale: CubeOption>, - /// The tensor for loading the accumulator, if present - pub acc: CubeOption>>, -} - -impl ConcreteInputsFactory - for TensorInputs -{ - fn create<'a, R: Runtime>( - lhs: &'a MatmulInputHandleRef<'a, R>, - rhs: &'a MatmulInputHandleRef<'a, R>, - _selection: &MatmulSelection, - _problem: &MatmulProblem, - line_sizes: &MatmulLineSizes, - ) -> Self::RuntimeArg<'a, R> { - TensorInputsLaunch::new( - lhs.data().as_tensor_arg(line_sizes.lhs), - lhs.scale().map(|it| it.as_tensor_arg(1)).into(), - rhs.data().as_tensor_arg(line_sizes.rhs), - rhs.scale().map(|it| it.as_tensor_arg(1)).into(), - CubeOptionArgs::None, - ) - } -} - -impl ConcreteOutputFactory for Tensor> { - fn create<'a, R: Runtime>( - out: &'a TensorHandleRef<'a, R>, - _selection: &MatmulSelection, - _problem: &MatmulProblem, - line_sizes: &MatmulLineSizes, - ) -> Self::RuntimeArg<'a, R> { - out.as_tensor_arg(line_sizes.out) - } -} - -#[cube] -impl MatmulArgs for TensorArgs { - type Output = Tensor>; - type Input = TensorInputs; - type State = ( - *const Tensor>, - *const Tensor>, - CubeOption<*const Tensor>>, - *mut Tensor>, - CubeOption, - CubeOption, - ); - - fn init_state( - input: &Self::Input, - output: &mut Self::Output, - ) -> Self::State { - let lhs_scale = match &input.lhs_scale { - CubeOption::Some(scale) => CubeOption::new_Some(scale[0]), - CubeOption::None => CubeOption::new_None(), - }; - let rhs_scale = match &input.rhs_scale { - CubeOption::Some(scale) => CubeOption::new_Some(scale[0]), - CubeOption::None => CubeOption::new_None(), - }; - let acc = match &input.acc { - CubeOption::None => CubeOption::new_None(), - CubeOption::Some(acc) => { - let ptr: *const Tensor> = acc; - CubeOption::new_Some(ptr) - } - }; - (&input.lhs, &input.rhs, acc, output, lhs_scale, rhs_scale) - } - - fn has_acc( - state: &Self::State, - ) -> CubeOption<()> { - match state.2 { - CubeOption::None => CubeOption::new_None(), - CubeOption::Some(_) => CubeOption::new_Some(()), - } - } - - fn read_lhs( - state: &Self::State, - coordinate: u32, - ) -> Line { - unsafe { (*state.0)[coordinate] } - } - - fn read_rhs( - state: &Self::State, - coordinate: u32, - ) -> Line { - unsafe { (*state.1)[coordinate] } - } - - fn read_acc( - state: &Self::State, - coordinate: u32, - ) -> Line { - unsafe { (*state.2.unwrap())[coordinate] } - } - - fn read_window_lhs( - state: &Self::State, - start: u32, - end: u32, - ) -> Slice> { - 
unsafe { (*state.0).slice(start, end) } - } - - /// Read the line of the rhs tensor using the state at the given coordinate. - fn read_window_rhs( - state: &Self::State, - start: u32, - end: u32, - ) -> Slice> { - unsafe { (*state.1).slice(start, end) } - } - - fn read_window_acc( - state: &Self::State, - start: u32, - end: u32, - ) -> Slice> { - unsafe { (*state.2.unwrap()).slice(start, end) } - } - - fn as_tensor_map_lhs( - _state: &Self::State, - ) -> CubeOption> { - CubeOption::new_None() - } - - fn as_tensor_map_rhs( + fn view_lhs( _state: &Self::State, - ) -> CubeOption> { - CubeOption::new_None() + ) -> View, Coords3d> { + unexpanded!() } - - fn as_tensor_map_acc( + fn batch_lhs( _state: &Self::State, - ) -> CubeOption> { - CubeOption::new_None() - } - - fn shape_lhs( - state: &Self::State, - dim: u32, - ) -> u32 { - unsafe { (*state.0).shape(dim) } - } - - fn shape_rhs( - state: &Self::State, - dim: u32, + _batch: u32, ) -> u32 { - unsafe { (*state.1).shape(dim) } + unexpanded!() } - - fn shape_acc( - state: &Self::State, - dim: u32, - ) -> u32 { - unsafe { (*state.2.unwrap()).shape(dim) } + fn view_rhs( + _state: &Self::State, + ) -> View, Coords3d> { + unexpanded!() } - - fn shape_out( - state: &Self::State, - dim: u32, + fn batch_rhs( + _state: &Self::State, + _batch: u32, ) -> u32 { - unsafe { (*state.3).shape(dim) } + unexpanded!() } - - fn stride_lhs( - state: &Self::State, - dim: u32, - ) -> u32 { - unsafe { (*state.0).stride(dim) } + fn view_acc( + _state: &Self::State, + ) -> CubeOption, Coords3d>> { + unexpanded!() } - - fn stride_rhs( - state: &Self::State, - dim: u32, + fn batch_acc( + _state: &Self::State, + _batch: u32, ) -> u32 { - unsafe { (*state.1).stride(dim) } + unexpanded!() } - - fn stride_acc( - state: &Self::State, - dim: u32, - ) -> u32 { - unsafe { (*state.2.unwrap()).stride(dim) } + fn view_out( + _state: &mut Self::State, + ) -> View, Coords3d, ReadWrite> { + unexpanded!() } - - fn stride_out( - state: &Self::State, - dim: u32, + fn batch_out( + _state: &Self::State, + _batch: u32, ) -> u32 { - unsafe { (*state.3).stride(dim) } + unexpanded!() } +} - fn write_out( - state: &mut Self::State, - coordinate: u32, - value: Line, - ) { - unsafe { (*state.3)[coordinate] = value } - } +#[derive(Clone, Copy)] +/// Identification of the [tensor input](TensorInput). +pub enum TensorInputIdent { + Lhs, + Rhs, +} - fn rank_lhs(state: &Self::State) -> u32 { - unsafe { (*state.0).rank() } - } +#[derive(Clone)] +/// Type implementing [MatmulArgs] where all inputs and the output are materialized tensors. +/// +/// Other types might implement [MatmulArgs] for fused matrix multiplication kernels. +pub struct TensorArgs; - fn rank_rhs(state: &Self::State) -> u32 { - unsafe { (*state.1).rank() } - } +#[derive(CubeLaunch, CubeType, Clone, Copy)] +/// Input representation for [TensorArgs] implementing [MatmulArgs]. +pub struct TensorInputs { + /// The lhs tensor. + lhs: View, Coords3d>, + lhs_batch: VirtualLayout, + /// The rhs tensor. 
+ rhs: View, Coords3d>, + rhs_batch: VirtualLayout, + /// The tensor for loading the accumulator, if present + acc: CubeOption, Coords3d>>, + acc_batch: CubeOption>, +} - fn rank_acc(state: &Self::State) -> u32 { - unsafe { (*state.2.unwrap()).rank() } - } +impl ConcreteInputsFactory + for TensorInputs +{ + fn create<'a, R: Runtime>( + client: &ComputeClient, + lhs: &'a MatmulInputHandleRef<'a, R>, + rhs: &'a MatmulInputHandleRef<'a, R>, + _selection: &MatmulSelection, + problem: &MatmulProblem, + line_sizes: &MatmulLineSizes, + config: impl BatchConfig, + ) -> Self::RuntimeArg<'a, R> { + let config = config.global_config(); + let view = |handle: &'a MatmulInputHandleRef<'a, R>, ident, line_size| match handle { + MatmulInputHandleRef::Normal(handle) => { + let layout = GlobalLayoutLaunch::from_handle( + handle, + line_size, + config.global_memory_config(ident).into(), + ); + ViewArg::new::(handle.as_array_arg(line_size), layout) + } + MatmulInputHandleRef::Quantized { + data, + scale, + shape, + scheme, + } => { + let (data_layout, scales_layout) = GlobalLayoutLaunch::from_quantized_handle( + client, + data, + scale, + shape, + problem, + **scheme, + line_size, + config.global_memory_config(ident).into(), + ); + let data_view = + ViewArg::new::(data.as_array_arg(line_size), data_layout); + let scales_view = + ViewArg::new::(scale.as_array_arg(1), scales_layout); + ViewArg::new_quantized(data_view, scales_view, **scheme) + } + }; + let batch_layout = |handle: &'a MatmulInputHandleRef<'a, R>| match handle { + MatmulInputHandleRef::Normal(handle) => { + let layout = BatchLayoutLaunch::from_handle(client, handle, problem); + VirtualLayoutLaunch::new::(layout) + } + MatmulInputHandleRef::Quantized { .. } => { + VirtualLayoutLaunch::new::(NoopLayoutLaunch::new()) + } + }; - fn rank_out(state: &Self::State) -> u32 { - unsafe { (*state.3).rank() } + TensorInputsLaunch::new( + view(lhs, MatmulIdent::Lhs, line_sizes.lhs), + batch_layout(lhs), + view(rhs, MatmulIdent::Rhs, line_sizes.rhs), + batch_layout(rhs), + CubeOptionArgs::None, + CubeOptionArgs::None, + ) } +} - fn len_lhs(state: &Self::State) -> u32 { - unsafe { (*state.0).len() } - } +#[derive(CubeType, CubeLaunch, Clone, Copy)] +pub struct TensorOutput { + view: View, Coords3d, ReadWrite>, + batch: VirtualLayout, +} - fn len_rhs(state: &Self::State) -> u32 { - unsafe { (*state.1).len() } +impl ConcreteOutputFactory for TensorOutput { + fn create<'a, R: Runtime>( + client: &ComputeClient, + out: &'a TensorHandleRef<'a, R>, + _selection: &MatmulSelection, + problem: &MatmulProblem, + line_sizes: &MatmulLineSizes, + config: impl BatchConfig, + ) -> Self::RuntimeArg<'a, R> { + let config = config.global_config(); + let layout = GlobalLayoutLaunch::from_handle( + out, + line_sizes.out, + config.global_memory_config(MatmulIdent::Out).into(), + ); + let batch = BatchLayoutLaunch::from_handle(client, out, problem); + let view = ViewArg::new::(out.as_array_arg(line_sizes.out), layout); + TensorOutputLaunch::new(view, VirtualLayoutLaunch::new::(batch)) } +} - fn len_acc(state: &Self::State) -> u32 { - unsafe { (*state.2.unwrap()).len() } - } +#[cube] +impl MatmulArgs for TensorArgs { + type Output = TensorOutput; + type Input = TensorInputs; + type State = + (TensorInputs, TensorOutput); - fn len_out(state: &Self::State) -> u32 { - unsafe { (*state.3).len() } + fn init_state( + input: &Self::Input, + output: &mut Self::Output, + #[comptime] _config: G, + ) -> Self::State { + (*input, *output) } - fn buffer_len_lhs( + fn view_lhs( state: &Self::State, 
- ) -> u32 { - unsafe { (*state.0).buffer_len() } + ) -> View, Coords3d> { + state.0.lhs } - fn buffer_len_rhs( + fn batch_lhs( state: &Self::State, + batch: u32, ) -> u32 { - unsafe { (*state.1).buffer_len() } + state.0.lhs_batch.to_source_pos(batch) } - fn buffer_len_acc( + fn view_rhs( state: &Self::State, - ) -> u32 { - unsafe { (*state.2.unwrap()).buffer_len() } + ) -> View, Coords3d> { + state.0.rhs } - fn buffer_len_out( + fn batch_rhs( state: &Self::State, + batch: u32, ) -> u32 { - unsafe { (*state.3).buffer_len() } + state.0.rhs_batch.to_source_pos(batch) } - fn line_size_lhs( + fn view_acc( state: &Self::State, - ) -> comptime_type!(u32) { - unsafe { (*state.0).line_size() } + ) -> CubeOption, Coords3d>> { + state.0.acc } - fn line_size_rhs( + + fn batch_acc( state: &Self::State, - ) -> comptime_type!(u32) { - unsafe { (*state.1).line_size() } + batch: u32, + ) -> u32 { + match state.0.acc_batch { + CubeOption::Some(layout) => layout.to_source_pos(batch), + CubeOption::None => batch, + } } - #[allow(unused_variables)] - fn line_size_acc( - state: &Self::State, - ) -> comptime_type!(u32) { - intrinsic!(|scope| { - match state.2 { - CubeOptionExpand::None => 1, - CubeOptionExpand::Some(t) => t.__expand_line_size_method(scope), - } - }) + fn view_out( + state: &mut Self::State, + ) -> View, Coords3d, ReadWrite> { + state.1.view } - fn line_size_out( + fn batch_out( state: &Self::State, - ) -> comptime_type!(u32) { - unsafe { (*state.3).line_size() } + batch: u32, + ) -> u32 { + state.1.batch.to_source_pos(batch) } } @@ -1108,26 +315,30 @@ impl MatmulArgs for TensorArgs { /// Other types might implement [MatmulArgs] for fused matrix multiplication kernels. pub struct TensorMapArgs; -#[derive(CubeLaunch, CubeType)] +#[derive(CubeLaunch, CubeType, Clone, Copy)] /// Input representation for [TensorArgs] implementing [MatmulArgs]. pub struct TensorMapInputs { /// The lhs tensor. - pub lhs: TensorMap, + pub lhs: View, Coords3d>, /// The rhs tensor. 
- pub rhs: TensorMap, + pub rhs: View, Coords3d>, /// The accumulator - pub acc: CubeOption>>, + pub acc: CubeOption, Coords3d>>, + /// The accumulator batch layout + pub acc_batch: CubeOption>, } impl ConcreteInputsFactory for TensorMapInputs { fn create<'a, R: Runtime>( + _client: &ComputeClient, lhs_handle: &'a MatmulInputHandleRef<'a, R>, rhs_handle: &'a MatmulInputHandleRef<'a, R>, selection: &MatmulSelection, problem: &MatmulProblem, line_sizes: &MatmulLineSizes, + _config: impl BatchConfig, ) -> Self::RuntimeArg<'a, R> { let lhs = lhs_handle.data(); let rhs = rhs_handle.data(); @@ -1158,37 +369,55 @@ impl ConcreteInputsFactory let lhs_rank = lhs.shape.len(); let mut lhs_shape = vec![ - problem.lhs_batches[0], + problem.lhs_batches.iter().product(), lhs.shape[lhs_rank - 2], lhs.shape[lhs_rank - 1], ]; let mut lhs_strides = if lhs_rank > 2 { lhs.strides[lhs_rank - 3..].to_vec() } else { - vec![1, lhs.strides[lhs_rank - 2], lhs.strides[lhs_rank - 1]] + vec![lhs.strides[0], lhs.strides[1]] }; let rhs_rank = rhs.shape.len(); let mut rhs_shape = vec![ - problem.rhs_batches[0], + problem.rhs_batches.iter().product(), rhs.shape[rhs_rank - 2], rhs.shape[rhs_rank - 1], ]; let mut rhs_strides = if rhs_rank > 2 { rhs.strides[rhs_rank - 3..].to_vec() } else { - vec![1, rhs.strides[rhs_rank - 2], rhs.strides[rhs_rank - 1]] + vec![rhs.strides[0], rhs.strides[1]] }; + let mut lhs_transposed = false; + let mut rhs_transposed = false; + + let lhs_rank = lhs_strides.len(); + let rhs_rank = rhs_strides.len(); + // TMA assumes the last stride is contiguous and won't even take it, so we need to map it // with transposed shape and stride. Tensor metadata still has the normal layout. if matches!(problem.lhs_layout, components::MatrixLayout::ColMajor) { - lhs_shape.swap(lhs_rank - 1, lhs_rank - 2); + lhs_shape.swap(2, 1); lhs_strides.swap(lhs_rank - 1, lhs_rank - 2); + lhs_transposed = true; } if matches!(problem.rhs_layout, components::MatrixLayout::ColMajor) { - rhs_shape.swap(rhs_rank - 1, rhs_rank - 2); + rhs_shape.swap(2, 1); rhs_strides.swap(rhs_rank - 1, rhs_rank - 2); + rhs_transposed = true; + } + + // Insert batch stride after swap so we can easily get the non-contiguous stride + if lhs_rank == 2 { + let stride = lhs_strides[0]; + lhs_strides.insert(0, stride); + } + if rhs_rank == 2 { + let stride = rhs_strides[0]; + rhs_strides.insert(0, stride); } fn prefetch(bytes: usize) -> TensorMapPrefetch { @@ -1221,7 +450,7 @@ impl ConcreteInputsFactory tile_size: stage_size_lhs, }, rank: 3, - shape: lhs_shape, + shape: lhs_shape.clone(), strides: lhs_strides, elem_stride: vec![1, 1, 1], interleave: TensorMapInterleave::None, @@ -1236,7 +465,7 @@ impl ConcreteInputsFactory tile_size: stage_size_rhs, }, rank: 3, - shape: rhs_shape, + shape: rhs_shape.clone(), strides: rhs_strides, elem_stride: vec![1, 1, 1], interleave: TensorMapInterleave::None, @@ -1255,436 +484,99 @@ impl ConcreteInputsFactory metadata: meta_rhs, }; - TensorMapInputsLaunch::new(lhs, rhs, CubeOptionArgs::None) + let view = |buffer, shape: &[usize], transposed| { + let batches = ScalarArg::new(shape[0] as u32); + let (rows, cols) = match transposed { + true => ( + ScalarArg::new(shape[2] as u32), + ScalarArg::new(shape[1] as u32), + ), + false => ( + ScalarArg::new(shape[1] as u32), + ScalarArg::new(shape[2] as u32), + ), + }; + let shape = (batches, rows, cols); + let layout = SimpleTmaGlobalLayoutLaunch::new(transposed, shape); + ViewArg::new_tensor_map::(buffer, layout) + }; + + TensorMapInputsLaunch::new( + view(lhs, &lhs_shape, 
lhs_transposed), + view(rhs, &rhs_shape, rhs_transposed), + CubeOptionArgs::None, + CubeOptionArgs::None, + ) } } #[cube] impl MatmulArgs for TensorMapArgs { type Input = TensorMapInputs; - type Output = Tensor>; - type State = ( - *const TensorMap, - *const TensorMap, - CubeOption<*const Tensor>>, - *mut Tensor>, - ); + type Output = TensorOutput; + type State = + (TensorMapInputs, TensorOutput); - fn init_state( + fn init_state( input: &Self::Input, output: &mut Self::Output, + #[comptime] _config: G, ) -> Self::State { - let acc = match &input.acc { - CubeOption::None => CubeOption::new_None(), - CubeOption::Some(acc) => { - let ptr: *const Tensor> = acc; - CubeOption::new_Some(ptr) - } - }; - (&input.lhs, &input.rhs, acc, output) - } - - fn has_acc( - state: &Self::State, - ) -> CubeOption<()> { - match state.2 { - CubeOption::None => CubeOption::new_None(), - CubeOption::Some(_) => CubeOption::new_Some(()), - } - } - - fn read_lhs( - _state: &Self::State, - _coordinate: u32, - ) -> Line { - unimplemented!("Can't directly read from TensorMap") - } - - fn read_rhs( - _state: &Self::State, - _coordinate: u32, - ) -> Line { - unimplemented!("Can't directly read from TensorMap") - } - - fn read_acc( - state: &Self::State, - coordinate: u32, - ) -> Line { - unsafe { (*state.2.unwrap())[coordinate] } - } - - #[allow(unused)] - fn read_window_lhs( - state: &Self::State, - start: u32, - end: u32, - ) -> Slice> { - unimplemented!("Can't directly read from TensorMap") - } - - /// Read the line of the rhs tensor using the state at the given coordinate. - #[allow(unused)] - fn read_window_rhs( - state: &Self::State, - start: u32, - end: u32, - ) -> Slice> { - unimplemented!("Can't directly read from TensorMap") + (*input, *output) } - fn read_window_acc( + fn view_lhs( state: &Self::State, - start: u32, - end: u32, - ) -> Slice> { - unsafe { (*state.2.unwrap()).slice(start, end) } + ) -> View, Coords3d> { + state.0.lhs } - fn as_tensor_map_lhs( - state: &Self::State, - ) -> CubeOption> { - CubeOption::new_Some(unsafe { *state.0 }) - } - - fn as_tensor_map_rhs( - state: &Self::State, - ) -> CubeOption> { - CubeOption::new_Some(unsafe { *state.1 }) - } - - fn as_tensor_map_acc( + fn batch_lhs( _state: &Self::State, - ) -> CubeOption> { - CubeOption::new_None() - } - - fn shape_lhs( - state: &Self::State, - dim: u32, - ) -> u32 { - unsafe { (*state.0).shape(dim) } - } - - fn shape_rhs( - state: &Self::State, - dim: u32, - ) -> u32 { - unsafe { (*state.1).shape(dim) } - } - - fn shape_acc( - state: &Self::State, - dim: u32, - ) -> u32 { - unsafe { (*state.2.unwrap()).shape(dim) } - } - - fn shape_out( - state: &Self::State, - dim: u32, + batch: u32, ) -> u32 { - unsafe { &*state.3 }.shape(dim) + batch } - fn stride_lhs( + fn view_rhs( state: &Self::State, - dim: u32, - ) -> u32 { - unsafe { &*state.0 }.stride(dim) + ) -> View, Coords3d> { + state.0.rhs } - fn stride_rhs( - state: &Self::State, - dim: u32, + fn batch_rhs( + _state: &Self::State, + batch: u32, ) -> u32 { - unsafe { &*state.1 }.stride(dim) + batch } - fn stride_acc( + fn view_acc( state: &Self::State, - dim: u32, - ) -> u32 { - unsafe { (*state.2.unwrap()).stride(dim) } + ) -> CubeOption, Coords3d>> { + state.0.acc } - fn stride_out( + fn batch_acc( state: &Self::State, - dim: u32, + batch: u32, ) -> u32 { - unsafe { &*state.3 }.stride(dim) + match state.0.acc_batch { + CubeOption::Some(layout) => layout.to_source_pos(batch), + CubeOption::None => batch, + } } - fn write_out( + fn view_out( state: &mut Self::State, - coordinate: u32, - 
value: Line, - ) { - unsafe { (*state.3)[coordinate] = value } - } - - fn rank_lhs(state: &Self::State) -> u32 { - unsafe { (*state.0).rank() } - } - - fn rank_rhs(state: &Self::State) -> u32 { - unsafe { (*state.1).rank() } - } - - fn rank_acc(state: &Self::State) -> u32 { - unsafe { (*state.2.unwrap()).rank() } + ) -> View, Coords3d, ReadWrite> { + state.1.view } - fn rank_out(state: &Self::State) -> u32 { - unsafe { (*state.3).rank() } - } - - fn len_lhs(state: &Self::State) -> u32 { - unsafe { (*state.0).len() } - } - - fn len_rhs(state: &Self::State) -> u32 { - unsafe { (*state.1).len() } - } - - fn len_acc(state: &Self::State) -> u32 { - unsafe { (*state.2.unwrap()).len() } - } - - fn len_out(state: &Self::State) -> u32 { - unsafe { (*state.3).len() } - } - - fn buffer_len_lhs( - state: &Self::State, - ) -> u32 { - unsafe { (*state.0).buffer_len() } - } - - fn buffer_len_rhs( - state: &Self::State, - ) -> u32 { - unsafe { (*state.1).buffer_len() } - } - - fn buffer_len_acc( - state: &Self::State, - ) -> u32 { - unsafe { (*state.2.unwrap()).buffer_len() } - } - - fn buffer_len_out( + fn batch_out( state: &Self::State, + batch: u32, ) -> u32 { - unsafe { (*state.3).buffer_len() } - } - - fn line_size_lhs( - _state: &Self::State, - ) -> comptime_type!(u32) { - 1 - } - fn line_size_rhs( - _state: &Self::State, - ) -> comptime_type!(u32) { - 1 - } - #[allow(unused_variables)] - fn line_size_acc( - state: &Self::State, - ) -> comptime_type!(u32) { - intrinsic!(|scope| { - match state.2 { - CubeOptionExpand::None => 1, - CubeOptionExpand::Some(t) => t.__expand_line_size_method(scope), - } - }) - } - fn line_size_out( - state: &Self::State, - ) -> comptime_type!(u32) { - unsafe { (*state.3).line_size() } - } -} - -mod __lhs { - use super::*; - - impl CubeType - for TensorLhs - { - type ExpandType = TensorLhsExpand; - } - - impl Clone - for TensorLhsExpand - { - fn clone(&self) -> Self { - Self { - state: self.state.clone(), - } - } - } - - impl IntoMut - for TensorLhsExpand - { - fn into_mut(mut self, scope: &mut Scope) -> Self { - self.state = self.state.into_mut(scope); - self - } - } - impl CubeDebug - for TensorLhsExpand - { - fn set_debug_name(&self, scope: &mut Scope, name: &'static str) { - self.state.set_debug_name(scope, name); - } - } - impl Clone - for TensorLhs - { - fn clone(&self) -> Self { - *self - } - } - impl Copy for TensorLhs {} -} - -mod __rhs { - use super::*; - - impl CubeType - for TensorRhs - { - type ExpandType = TensorRhsExpand; - } - - impl Clone - for TensorRhsExpand - { - fn clone(&self) -> Self { - Self { - state: self.state.clone(), - } - } - } - - impl IntoMut - for TensorRhsExpand - { - fn into_mut(mut self, scope: &mut Scope) -> Self { - self.state = self.state.into_mut(scope); - self - } - } - impl CubeDebug - for TensorRhsExpand - { - fn set_debug_name(&self, scope: &mut Scope, name: &'static str) { - self.state.set_debug_name(scope, name); - } - } - impl Clone - for TensorRhs - { - fn clone(&self) -> Self { - *self - } - } - impl Copy for TensorRhs {} -} - -mod __acc { - use super::*; - - impl CubeType - for TensorAcc - { - type ExpandType = TensorAccExpand; - } - - impl Clone - for TensorAccExpand - { - fn clone(&self) -> Self { - Self { - state: self.state.clone(), - } - } - } - - impl IntoMut - for TensorAccExpand - { - fn into_mut(mut self, scope: &mut Scope) -> Self { - self.state = self.state.into_mut(scope); - self - } - } - impl CubeDebug - for TensorAccExpand - { - fn set_debug_name(&self, scope: &mut Scope, name: &'static str) { - 
self.state.set_debug_name(scope, name); - } - } - impl Clone - for TensorAcc - { - fn clone(&self) -> Self { - *self - } - } - impl Copy for TensorAcc {} -} - -mod __output { - use super::*; - - impl CubeType - for TensorOutput - { - type ExpandType = TensorOutputExpand; - } - - impl Clone - for TensorOutput - { - fn clone(&self) -> Self { - *self - } - } - - impl Clone - for TensorOutputExpand - { - fn clone(&self) -> Self { - Self { - state: self.state.clone(), - } - } - } - - impl IntoMut - for TensorOutputExpand - { - fn into_mut(mut self, scope: &mut Scope) -> Self { - self.state = self.state.into_mut(scope); - self - } - } - - impl CubeDebug - for TensorOutputExpand - { - fn set_debug_name(&self, scope: &mut Scope, name: &'static str) { - self.state.set_debug_name(scope, name); - } - } - - impl Copy - for TensorOutput - { + state.1.batch.to_source_pos(batch) } } diff --git a/crates/cubecl-matmul/src/components/global/base.rs b/crates/cubecl-matmul/src/components/global/base.rs index c19c04a6b..cb49b2f02 100644 --- a/crates/cubecl-matmul/src/components/global/base.rs +++ b/crates/cubecl-matmul/src/components/global/base.rs @@ -12,7 +12,7 @@ use crate::components::{LhsG, MatmulIdent, MatmulLineSizes, MatmulSelection, Rhs use crate::components::{global::RoleRuleConfig, stage::StageMemoryConfig}; use cubecl_std::{ CubeOption, - tensor::{layout::Coords2d, r#virtual::VirtualTensor}, + tensor::{View, layout::Coords2d}, }; use std::{fmt::Debug, hash::Hash}; @@ -30,7 +30,7 @@ pub trait GlobalMatmulFamily: Send + Sync + 'static { /// /// This function may return an error if the configuration cannot be supported on the current runtime. fn setup( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, selection: &MatmulSelection, matmul_line_sizes: &MatmulLineSizes, @@ -96,31 +96,19 @@ pub trait GlobalMatmul: 'static + Send + Sync { /// Initialize the global reader for Lhs, starting at row m and column k fn init_lhs_global_reader( - lhs: VirtualTensor>, - batch_offset: u32, - offset: Coords2d, - view_shape: Coords2d, - nth_batch: u32, + lhs: View>, Coords2d>, #[comptime] config: Self::Config, ) -> Self::LhsGlobalReader; /// Initialize the global reader for Rhs, starting at row k and column n fn init_rhs_global_reader( - rhs: VirtualTensor>, - batch_offset: u32, - offset: Coords2d, - view_shape: Coords2d, - nth_batch: u32, + rhs: View>, Coords2d>, #[comptime] config: Self::Config, ) -> Self::RhsGlobalReader; /// Initialize the global reader for Rhs, starting at row k and column n fn init_acc_global_reader( - rhs: CubeOption>>, - batch_offset: u32, - offset: Coords2d, - view_shape: Coords2d, - nth_batch: u32, + acc: CubeOption>, Coords2d>>, #[comptime] config: Self::Config, ) -> Self::AccGlobalReader; @@ -129,11 +117,7 @@ pub trait GlobalMatmul: 'static + Send + Sync { /// Initialize the global writer at row m and column n fn init_global_writer( - out: VirtualTensor, ReadWrite>, - batch_offset: u32, - offset: Coords2d, - view_shape: Coords2d, - nth_batch: u32, + out: View>, Coords2d, ReadWrite>, #[comptime] config: Self::Config, ) -> Self::GlobalWriter; } diff --git a/crates/cubecl-matmul/src/components/global/memory/config.rs b/crates/cubecl-matmul/src/components/global/memory/config.rs index 4789fc3e6..92b214ac2 100644 --- a/crates/cubecl-matmul/src/components/global/memory/config.rs +++ b/crates/cubecl-matmul/src/components/global/memory/config.rs @@ -2,7 +2,7 @@ use std::{fmt::Debug, hash::Hash}; use crate::components::MatrixLayout; -#[derive(Copy, Clone, Debug, Hash, 
PartialEq, Eq)] +#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, Default)] pub struct GlobalMemoryConfig { pub elements_in_tile_row: u32, pub elements_in_tile_col: u32, diff --git a/crates/cubecl-matmul/src/components/global/memory/iterator.rs b/crates/cubecl-matmul/src/components/global/memory/iterator.rs index a92aa9fbf..0db4056a4 100644 --- a/crates/cubecl-matmul/src/components/global/memory/iterator.rs +++ b/crates/cubecl-matmul/src/components/global/memory/iterator.rs @@ -75,4 +75,9 @@ impl GlobalIterator { self.global_view.slice_unchecked(offset, self.view_size) } } + + /// Returns the line size of the global view + pub fn line_size(&self) -> comptime_type!(u32) { + self.global_view.line_size() + } } diff --git a/crates/cubecl-matmul/src/components/global/memory/layout.rs b/crates/cubecl-matmul/src/components/global/memory/layout.rs index 6773e4d62..fe71dd1cc 100644 --- a/crates/cubecl-matmul/src/components/global/memory/layout.rs +++ b/crates/cubecl-matmul/src/components/global/memory/layout.rs @@ -1,54 +1,140 @@ use cubecl::prelude::*; +use cubecl_common::quant::scheme::{QuantLevel, QuantScheme}; use cubecl_core::{self as cubecl}; -use cubecl_std::tensor::{ - layout::{Coords1d, Coords2d, Coords3d, Layout, LayoutExpand}, - r#virtual::VirtualTensor, +use cubecl_std::{ + FastDivmod, FastDivmodArgs, + tensor::layout::{ + Coords1d, Coords2d, Coords3d, Layout, LayoutExpand, VirtualLayout, VirtualLayoutLaunch, + }, }; -use crate::components::{MatrixLayout, global::memory::GlobalMemoryConfig}; +use crate::components::{MatmulProblem, MatrixLayout, global::memory::GlobalMemoryConfig}; /// Global layout that uses the last two dimensions and ignores all others. -#[derive(CubeType, Clone, Copy)] -pub struct SimpleGlobalLayout { +#[derive(CubeType, CubeLaunch, Clone, Copy)] +pub struct SimpleTmaGlobalLayout { + #[cube(comptime)] + transposed: bool, + shape: Coords3d, +} + +#[cube] +impl SimpleTmaGlobalLayout { + /// Creates a new 3D layout over `shape` (batch, rows, cols), swapping the coordinates when the matrix is col-major. + pub fn new(shape: Coords3d, #[comptime] layout: MatrixLayout) -> Self { + let transposed = comptime![matches!(layout, MatrixLayout::ColMajor)]; + SimpleTmaGlobalLayout { shape, transposed } + } +} + +#[cube] +impl Layout for SimpleTmaGlobalLayout { + type Coordinates = Coords3d; + type SourceCoordinates = Coords3d; + + fn to_source_pos(&self, coords: Self::Coordinates) -> Coords3d { + let (batch, row, col) = coords; + // Tensor maps are required to have a stride of 1 on the last dim, so their shape is + // transposed for col-major matrices. Need to compensate by swapping the coordinates. + if comptime![self.transposed] { + (batch, col, row) + } else { + (batch, row, col) + } + } + + fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (Coords3d, bool) { + (self.to_source_pos(coords), self.is_in_bounds(coords)) + } + + fn shape(&self) -> Self::Coordinates { + self.shape + } + + fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool { + // No need to bounds check TMA loads + true.runtime() + } +} + +#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, Default)] +pub struct GlobalLayoutConfig { + pub matrix_layout: MatrixLayout, + pub check_row_bounds: bool, + pub check_col_bounds: bool, +} + +impl From for GlobalLayoutConfig { + fn from(value: GlobalMemoryConfig) -> Self { + GlobalLayoutConfig { + matrix_layout: value.matrix_layout, + check_row_bounds: value.check_row_bounds, + check_col_bounds: value.check_col_bounds, + } + } +} + +/// Global layout that uses the last two dimensions and ignores all others.
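// Editor's note: a minimal host-side sketch (not part of the patch) of the
// coordinate swap performed by `SimpleTmaGlobalLayout::to_source_pos` above.
// TMA tensor maps must have a stride of 1 on the last dimension, so a
// col-major matrix is described with its two inner dims swapped, and the
// layout compensates by swapping `(row, col)`. `transposed` is a comptime
// flag in the kernel; a plain bool stands in here.
fn tma_source_pos(transposed: bool, (batch, row, col): (u32, u32, u32)) -> (u32, u32, u32) {
    if transposed { (batch, col, row) } else { (batch, row, col) }
}

#[test]
fn swaps_coordinates_for_col_major() {
    assert_eq!(tma_source_pos(true, (1, 2, 5)), (1, 5, 2));
    assert_eq!(tma_source_pos(false, (1, 2, 5)), (1, 2, 5));
}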
+#[derive(CubeType, CubeLaunch, Clone)] +pub struct GlobalLayout { + batch_layout: VirtualLayout, rows: u32, + cols: u32, + stride_row: u32, - columns: u32, stride_col: u32, - batch_offset: u32, + + #[cube(comptime)] + line_size: u32, #[cube(comptime)] - config: GlobalMemoryConfig, + packing: u32, + #[cube(comptime)] + config: GlobalLayoutConfig, } #[cube] -impl SimpleGlobalLayout { - /// Creates a new 2D layout starting at `batch_offset`. - pub fn new( - tensor: &VirtualTensor, - batch_offset: u32, - #[comptime] config: GlobalMemoryConfig, +impl GlobalLayout { + /// Create a new batched global layout. `batch_shape` should be based on the output shape. + #[allow(clippy::too_many_arguments)] + pub fn new( + batch_layout: VirtualLayout, + shape_row: u32, + shape_col: u32, + stride_row: u32, + stride_col: u32, + #[comptime] line_size: u32, + #[comptime] packing: u32, + #[comptime] config: GlobalLayoutConfig, ) -> Self { - let rank = tensor.rank(); - - SimpleGlobalLayout { - rows: tensor.shape(rank - 2), - stride_row: tensor.stride(rank - 2), - columns: tensor.shape(rank - 1), - stride_col: tensor.stride(rank - 1), - batch_offset, + GlobalLayout { + batch_layout, + rows: shape_row, + cols: shape_col, + stride_row, + stride_col, + line_size, + packing, config, } } } #[cube] -impl Layout for SimpleGlobalLayout { - type Coordinates = Coords2d; +impl Layout for GlobalLayout { + type Coordinates = Coords3d; type SourceCoordinates = Coords1d; fn to_source_pos(&self, coords: Self::Coordinates) -> u32 { - let line_size = comptime![self.config.global_line_size]; - let (row, col) = coords; - let idx = self.batch_offset + row * self.stride_row + col * self.stride_col; + let line_size = comptime![self.line_size]; + let (batch, row, col) = coords; + let batch_offs = self.batch_layout.to_source_pos(batch); + + let (row, col) = match comptime![self.config.matrix_layout] { + MatrixLayout::RowMajor => (row, col / self.packing), + MatrixLayout::ColMajor => (row / self.packing, col), + }; + + let idx = batch_offs + row * self.stride_row + col * self.stride_col; idx / line_size } @@ -58,68 +144,313 @@ impl Layout for SimpleGlobalLayout { } fn shape(&self) -> Self::Coordinates { - (self.rows, self.columns) + (u32::MAX.runtime(), self.rows, self.cols) } fn is_in_bounds(&self, pos: Self::Coordinates) -> bool { - let (row, col) = pos; + let (_, row, col) = pos; match comptime!((self.config.check_row_bounds, self.config.check_col_bounds)) { - (true, true) => row < self.rows && col < self.columns, + (true, true) => row < self.rows && col < self.cols, (true, false) => row < self.rows, - (false, true) => col < self.columns, + (false, true) => col < self.cols, (false, false) => true, } } } -/// Global layout that uses the last two dimensions and ignores all others. 
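// Editor's note: the address math of `GlobalLayout::to_source_pos` above,
// restated as a plain host-side function for clarity (illustrative names, not
// part of the patch). `packing` is the number of quantized values stored per
// element, so the contiguous axis is divided by it, and the returned index is
// in units of `Line`s, hence the final division by `line_size`.
fn global_source_pos(
    batch_offs: u32, // produced by the nested batch layout
    (row, col): (u32, u32),
    (stride_row, stride_col): (u32, u32),
    row_major: bool,
    packing: u32,
    line_size: u32,
) -> u32 {
    let (row, col) = if row_major {
        (row, col / packing)
    } else {
        (row / packing, col)
    };
    (batch_offs + row * stride_row + col * stride_col) / line_size
}

// For an unquantized row-major matrix with stride_row = 64 and line_size = 4:
// global_source_pos(0, (2, 8), (64, 1), true, 1, 4) == (2 * 64 + 8) / 4 == 34.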
-#[derive(CubeType, Clone, Copy)] -pub struct SimpleTmaGlobalLayout { - nth_batch: u32, - #[cube(comptime)] - transposed: bool, +impl<'a, R: Runtime> GlobalLayoutLaunch<'a, R> { + pub fn from_handle( + handle: &TensorHandleRef<'a, R>, + line_size: u8, + config: GlobalLayoutConfig, + ) -> Self { + let rank = handle.shape.len(); + let rows = handle.shape[rank - 2]; + let cols = handle.shape[rank - 1]; + let stride_row = handle.strides[rank - 2]; + let stride_col = handle.strides[rank - 1]; + + GlobalLayoutLaunch::new( + VirtualLayoutLaunch::new::(NoopLayoutLaunch::new()), + ScalarArg::new(rows as u32), + ScalarArg::new(cols as u32), + ScalarArg::new(stride_row as u32), + ScalarArg::new(stride_col as u32), + line_size as u32, + 1, + config, + ) + } + + pub fn from_handle_batched( + client: &ComputeClient, + handle: &TensorHandleRef<'a, R>, + problem: &MatmulProblem, + line_size: u8, + config: GlobalLayoutConfig, + ) -> Self { + let rank = handle.shape.len(); + let rows = handle.shape[rank - 2]; + let cols = handle.shape[rank - 1]; + let stride_row = handle.strides[rank - 2]; + let stride_col = handle.strides[rank - 1]; + + let batch_layout = BatchLayoutLaunch::from_handle(client, handle, problem); + + GlobalLayoutLaunch::new( + VirtualLayoutLaunch::new::(batch_layout), + ScalarArg::new(rows as u32), + ScalarArg::new(cols as u32), + ScalarArg::new(stride_row as u32), + ScalarArg::new(stride_col as u32), + line_size as u32, + 1, + config, + ) + } + + #[allow(clippy::too_many_arguments)] + pub fn from_quantized_handle( + client: &ComputeClient, + values: &TensorHandleRef<'a, R>, + scales: &TensorHandleRef<'a, R>, + shape: &'a [usize], + problem: &MatmulProblem, + scheme: QuantScheme, + line_size: u8, + config: GlobalLayoutConfig, + ) -> (GlobalLayoutLaunch<'a, R>, GlobalScaleLayoutArgs<'a, R>) { + let rank = values.shape.len(); + let (rows, cols) = (shape[rank - 2], shape[rank - 1]); + let values_layout = { + let (stride_row, stride_col) = (values.strides[rank - 2], values.strides[rank - 1]); + + let batch_layout = BatchLayoutLaunch::from_handle(client, values, problem); + + GlobalLayoutLaunch::new( + VirtualLayoutLaunch::new::(batch_layout), + ScalarArg::new(rows as u32), + ScalarArg::new(cols as u32), + ScalarArg::new(stride_row as u32), + ScalarArg::new(stride_col as u32), + line_size as u32, + scheme.num_quants() as u32, + config, + ) + }; + + let scales_layout = { + let shape = (ScalarArg::new(rows as u32), ScalarArg::new(cols as u32)); + + match scheme.level { + QuantLevel::Tensor => GlobalScaleLayoutArgs::PerTensor { shape }, + QuantLevel::Block(block_size) => { + let [block_row, block_col] = block_size.as_dim(); + // Scales are never vectorized because we require that `block_size >= line_size * num_quants`. + let scales_layout = + GlobalLayoutLaunch::from_handle_batched(client, scales, problem, 1, config); + GlobalScaleLayoutArgs::BlockScaled(BlockScaledLayoutLaunch::new( + shape, + scales_layout, + (block_row as u32, block_col as u32), + )) + } + } + }; + + (values_layout, scales_layout) + } +} + +#[derive(CubeType, CubeLaunch)] +pub struct BatchLayout { + batch_shape: Sequence, + batch_strides: Sequence, } #[cube] -impl SimpleTmaGlobalLayout { - /// Creates a new 2D layout with the batch set to `nth_batch`. 
- pub fn new(nth_batch: u32, #[comptime] layout: MatrixLayout) -> Self { - let transposed = comptime![matches!(layout, MatrixLayout::ColMajor)]; - SimpleTmaGlobalLayout { - nth_batch, - transposed, +impl BatchLayout { + pub fn new(batch_strides: Sequence, batch_shape: Sequence) -> Self { + BatchLayout { + batch_shape, + batch_strides, } } } #[cube] -impl Layout for SimpleTmaGlobalLayout { - type Coordinates = Coords2d; - type SourceCoordinates = Coords3d; +impl Layout for BatchLayout { + type Coordinates = Coords1d; + type SourceCoordinates = Coords1d; - fn to_source_pos(&self, coords: Self::Coordinates) -> Coords3d { - let (row, col) = coords; - // Tensor maps are required to have a stride of 1 on the last dim, so their shape is - // transposed for col-major matrices. Need to compensate by swapping the coordinates. - if comptime![self.transposed] { - (self.nth_batch, col, row) - } else { - (self.nth_batch, row, col) + fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates { + let mut batch = pos; + let mut batch_offs = 0; + let batch_shape = self.batch_shape.rev(); + let batch_strides = self.batch_strides.rev(); + + #[unroll] + for i in 0..batch_shape.len() { + let (rem, local_pos) = batch_shape.index(i).div_mod(batch); + batch = rem; + batch_offs += local_pos * *batch_strides.index(i); } + + batch_offs } - fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (Coords3d, bool) { - (self.to_source_pos(coords), self.is_in_bounds(coords)) + fn shape(&self) -> Self::Coordinates { + u32::MAX.runtime() + } + + fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool { + true.runtime() + } + + fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) { + (self.to_source_pos(pos), self.is_in_bounds(pos)) + } +} + +/// Layout that passes the coordinates through with no checks or modifications.
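// Editor's note: a host-side sketch (not part of the patch) of the flat-index
// decomposition in `BatchLayout::to_source_pos` above. The linear batch index
// is peeled apart last dimension first, and each local coordinate is scaled
// by that dimension's stride. `BatchLayoutLaunch::from_handle` below zeroes
// the stride of size-1 dims, which is how broadcast inputs map every output
// batch onto the same slice. The kernel uses `FastDivmod`; plain `%` and `/`
// stand in here.
fn batch_offset(mut batch: u32, shape: &[u32], strides: &[u32]) -> u32 {
    let mut offset = 0;
    for (dim, stride) in shape.iter().rev().zip(strides.iter().rev()) {
        offset += (batch % dim) * stride;
        batch /= dim;
    }
    offset
}

#[test]
fn decomposes_flat_batch_indices() {
    // shape [2, 3] with contiguous strides [3, 1]: index 4 -> coords (1, 1) -> offset 4
    assert_eq!(batch_offset(4, &[2, 3], &[3, 1]), 4);
    // broadcast first dim (stride 0): only the inner coordinate contributes
    assert_eq!(batch_offset(4, &[2, 3], &[0, 1]), 1);
}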
+#[derive(CubeType, CubeLaunch)] +pub struct NoopLayout {} + +#[cube] +impl NoopLayout { + #[allow(clippy::new_without_default)] + pub fn new() -> Self { + NoopLayout {} + } +} + +#[cube] +impl Layout for NoopLayout { + type Coordinates = Coords1d; + type SourceCoordinates = Coords1d; + + fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates { + pos } fn shape(&self) -> Self::Coordinates { - // No need to bounds check TMA loads - (u32::MAX, u32::MAX).runtime() + u32::MAX.runtime() } fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool { - // No need to bounds check TMA loads true.runtime() } + + fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) { + (self.to_source_pos(pos), self.is_in_bounds(pos)) + } +} + +impl<'a, R: Runtime> BatchLayoutLaunch<'a, R> { + pub fn from_handle( + client: &ComputeClient, + handle: &TensorHandleRef<'a, R>, + problem: &MatmulProblem, + ) -> Self { + let rank = handle.shape.len(); + let batch_shape = problem + .out_batches + .iter() + .map(|shape| FastDivmodArgs::new(client, *shape as u32)) + .collect(); + let batch_strides = handle.strides[..rank - 2] + .iter() + .zip(&handle.shape[..rank - 2]) + .map(|(stride, shape)| if *shape == 1 { 0 } else { *stride }) + .map(|stride| ScalarArg::new(stride as u32)) + .collect(); + BatchLayoutLaunch::new(batch_shape, batch_strides) + } +} + +#[derive(CubeType, CubeLaunch)] +pub enum GlobalScaleLayout { + PerTensor { shape: Coords2d }, + BlockScaled(BlockScaledLayout), +} + +/// Workaround for enums not supporting `comptime`, should fix that in the future +#[derive(CubeType, CubeLaunch)] +pub struct BlockScaledLayout { + shape: Coords2d, + scales_layout: GlobalLayout, + #[cube(comptime)] + block_size: Coords2d, +} + +#[cube] +impl BlockScaledLayout { + pub fn new( + shape: Coords2d, + scales_layout: GlobalLayout, + #[comptime] block_size: Coords2d, + ) -> Self { + BlockScaledLayout { + shape, + scales_layout, + block_size, + } + } +} + +#[cube] +impl Layout for GlobalScaleLayout { + type Coordinates = Coords3d; + type SourceCoordinates = Coords1d; + + fn to_source_pos(&self, coords: Self::Coordinates) -> u32 { + match self { + GlobalScaleLayout::PerTensor { .. } => 0u32.runtime(), + GlobalScaleLayout::BlockScaled(layout) => { + let BlockScaledLayout { + scales_layout, + block_size, + .. + } = layout; + + let (batch, row, col) = coords; + let (block_row, block_col) = block_size; + let (row, col) = (row / block_row, col / block_col); + scales_layout.to_source_pos((batch, row, col)) + } + } + } + + fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (u32, bool) { + (self.to_source_pos(coords), self.is_in_bounds(coords)) + } + + fn shape(&self) -> Self::Coordinates { + match self { + GlobalScaleLayout::PerTensor { shape } => (u32::MAX.runtime(), shape.0, shape.1), + GlobalScaleLayout::BlockScaled(layout) => { + let (row, col) = layout.shape; + (u32::MAX.runtime(), row, col) + } + } + } + + fn is_in_bounds(&self, pos: Self::Coordinates) -> bool { + match self { + GlobalScaleLayout::PerTensor { .. 
} => true.runtime(), + GlobalScaleLayout::BlockScaled(layout) => { + let (_, row, col) = pos; + let l = &layout.scales_layout; + let (rows, cols) = layout.shape; + + match comptime!((l.config.check_row_bounds, l.config.check_col_bounds)) { + (true, true) => row < rows && col < cols, + (true, false) => row < rows, + (false, true) => col < cols, + (false, false) => true, + } + } + } + } } diff --git a/crates/cubecl-matmul/src/components/global/multi_stage/double_buffering/config.rs b/crates/cubecl-matmul/src/components/global/multi_stage/double_buffering/config.rs index 98d0542a5..1010e5d6b 100644 --- a/crates/cubecl-matmul/src/components/global/multi_stage/double_buffering/config.rs +++ b/crates/cubecl-matmul/src/components/global/multi_stage/double_buffering/config.rs @@ -113,7 +113,7 @@ impl DoubleBufferingGlobalConfig { /// - a reader is invalid /// - CubeDim is too big pub fn new( - _client: &ComputeClient, + _client: &ComputeClient, stage_config: S, num_planes: u32, check_m_bounds: bool, diff --git a/crates/cubecl-matmul/src/components/global/multi_stage/double_buffering/matmul.rs b/crates/cubecl-matmul/src/components/global/multi_stage/double_buffering/matmul.rs index 9fe6f4c44..543987ffc 100644 --- a/crates/cubecl-matmul/src/components/global/multi_stage/double_buffering/matmul.rs +++ b/crates/cubecl-matmul/src/components/global/multi_stage/double_buffering/matmul.rs @@ -1,8 +1,8 @@ +use crate::components::global::Specializer; use crate::components::global::multi_stage::double_buffer_execution::{ execute_current_and_read_next, execute_last_and_write_results, read_first, }; use crate::components::global::{GlobalConfig, GlobalWriter}; -use crate::components::global::{Specializer, memory::SimpleGlobalLayout}; use crate::components::{ AccG, global::read::{ @@ -19,7 +19,7 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; use cubecl_std::{ CubeOption, CubeOptionExpand, - tensor::{layout::Coords2d, r#virtual::VirtualTensor}, + tensor::{View, layout::Coords2d}, }; use std::marker::PhantomData; @@ -188,18 +188,12 @@ where } fn init_lhs_global_reader( - lhs: VirtualTensor>, - batch_offset: u32, - offset: Coords2d, - slice_size: Coords2d, - _nth_batch: u32, + lhs: View>, Coords2d>, #[comptime] config: Self::Config, ) -> Self::LhsGlobalReader { - let conf = config.global_memory_config(MatmulIdent::Lhs); let k_step = k_step::(config); - let layout = SimpleGlobalLayout::new(&lhs, batch_offset, conf); SyncPartialStageGlobalReader::::new( - lhs.view(layout).slice_unchecked(offset, slice_size), + lhs, k_step, MatmulIdent::Lhs, config, @@ -207,18 +201,12 @@ where } fn init_rhs_global_reader( - rhs: VirtualTensor>, - batch_offset: u32, - offset: Coords2d, - slice_size: Coords2d, - _nth_batch: u32, + rhs: View>, Coords2d>, #[comptime] config: Self::Config, ) -> Self::RhsGlobalReader { - let conf = config.global_memory_config(MatmulIdent::Rhs); let k_step = k_step::(config); - let layout = SimpleGlobalLayout::new(&rhs, batch_offset, conf); SyncPartialStageGlobalReader::::new( - rhs.view(layout).slice_unchecked(offset, slice_size), + rhs, k_step, MatmulIdent::Rhs, config, @@ -226,11 +214,7 @@ where } fn init_acc_global_reader( - acc: CubeOption>>, - _batch_offset: u32, - _offset: Coords2d, - _slice_size: Coords2d, - _nth_batch: u32, + acc: CubeOption>, Coords2d>>, #[comptime] _config: Self::Config, ) -> Self::AccGlobalReader { match acc { @@ -240,17 +224,11 @@ where } fn init_global_writer( - out: VirtualTensor, ReadWrite>, - batch_offset: u32, - offset: Coords2d, - size: Coords2d, - _nth_batch: 
u32, + out: View>, Coords2d, ReadWrite>, #[comptime] config: Self::Config, ) -> Self::GlobalWriter { let conf = config.global_memory_config(MatmulIdent::Out); - let layout = SimpleGlobalLayout::new(&out, batch_offset, conf); - let view = out.view_mut(layout).slice_mut_unchecked(offset, size); - Self::GlobalWriter::init::(view, conf, config.stage_config()) + Self::GlobalWriter::init::(out, conf, config.stage_config()) } fn init_accumulators(#[comptime] config: Self::Config) -> Self::Accumulators { diff --git a/crates/cubecl-matmul/src/components/global/multi_stage/double_buffering/setup.rs b/crates/cubecl-matmul/src/components/global/multi_stage/double_buffering/setup.rs index 20b0c1f1f..89c990969 100644 --- a/crates/cubecl-matmul/src/components/global/multi_stage/double_buffering/setup.rs +++ b/crates/cubecl-matmul/src/components/global/multi_stage/double_buffering/setup.rs @@ -47,7 +47,7 @@ where type Config = DoubleBufferingGlobalConfig; fn setup( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, selection: &MatmulSelection, line_sizes: &MatmulLineSizes, diff --git a/crates/cubecl-matmul/src/components/global/multi_stage/ordered/config.rs b/crates/cubecl-matmul/src/components/global/multi_stage/ordered/config.rs index 2decb7dd3..f3456758e 100644 --- a/crates/cubecl-matmul/src/components/global/multi_stage/ordered/config.rs +++ b/crates/cubecl-matmul/src/components/global/multi_stage/ordered/config.rs @@ -123,7 +123,7 @@ impl OrderedDoubleBufferingGlobalConfig { /// - There is more than one stage partition in n /// - Lhs is not loaded exclusively by main flow planes pub fn new( - _client: &ComputeClient, + _client: &ComputeClient, stage_config: S, num_planes: u32, check_m_bounds: bool, diff --git a/crates/cubecl-matmul/src/components/global/multi_stage/ordered/matmul.rs b/crates/cubecl-matmul/src/components/global/multi_stage/ordered/matmul.rs index 61e82ab96..bb5d97b0f 100644 --- a/crates/cubecl-matmul/src/components/global/multi_stage/ordered/matmul.rs +++ b/crates/cubecl-matmul/src/components/global/multi_stage/ordered/matmul.rs @@ -1,5 +1,5 @@ +use crate::components::global::Specializer; use crate::components::global::{self, GlobalConfig, GlobalWriter}; -use crate::components::global::{Specializer, memory::SimpleGlobalLayout}; use crate::components::{ AccG, global::read::{ @@ -17,8 +17,8 @@ use crate::components::{ }; use cubecl_core as cubecl; use cubecl_core::prelude::*; -use cubecl_std::tensor::layout::Coords2d; -use cubecl_std::{CubeOption, CubeOptionExpand, tensor::r#virtual::VirtualTensor}; +use cubecl_std::tensor::{View, layout::Coords2d}; +use cubecl_std::{CubeOption, CubeOptionExpand}; use std::marker::PhantomData; use super::OrderedDoubleBufferingGlobalConfig; @@ -190,18 +190,12 @@ where } fn init_lhs_global_reader( - lhs: VirtualTensor>, - batch_offset: u32, - offset: Coords2d, - slice_size: Coords2d, - _nth_batch: u32, + lhs: View>, Coords2d>, #[comptime] config: Self::Config, ) -> Self::LhsGlobalReader { - let conf = config.global_memory_config(MatmulIdent::Lhs); let k_step = lhs_k_step::(config); - let layout = SimpleGlobalLayout::new(&lhs, batch_offset, conf); SyncFullStageGlobalReader::::new( - lhs.view(layout).slice_unchecked(offset, slice_size), + lhs, k_step, MatmulIdent::Lhs, config, @@ -209,18 +203,12 @@ where } fn init_rhs_global_reader( - rhs: VirtualTensor>, - batch_offset: u32, - offset: Coords2d, - slice_size: Coords2d, - _nth_batch: u32, + rhs: View>, Coords2d>, #[comptime] config: Self::Config, ) -> Self::RhsGlobalReader { - let 
conf = config.global_memory_config(MatmulIdent::Rhs); let k_step = rhs_k_step::(config); - let layout = SimpleGlobalLayout::new(&rhs, batch_offset, conf); SyncPartialStageGlobalReader::::new( - rhs.view(layout).slice_unchecked(offset, slice_size), + rhs, k_step, MatmulIdent::Rhs, config, @@ -228,11 +216,7 @@ where } fn init_acc_global_reader( - acc: CubeOption>>, - _batch_offset: u32, - _offset: Coords2d, - _slice_size: Coords2d, - _nth_batch: u32, + acc: CubeOption>, Coords2d>>, #[comptime] _config: Self::Config, ) -> Self::AccGlobalReader { match acc { @@ -242,17 +226,11 @@ where } fn init_global_writer( - out: VirtualTensor, ReadWrite>, - batch_offset: u32, - offset: Coords2d, - size: Coords2d, - _nth_batch: u32, + out: View>, Coords2d, ReadWrite>, #[comptime] config: Self::Config, ) -> Self::GlobalWriter { let conf = config.global_memory_config(MatmulIdent::Out); - let layout = SimpleGlobalLayout::new(&out, batch_offset, conf); - let view = out.view_mut(layout).slice_mut_unchecked(offset, size); - Self::GlobalWriter::init::(view, conf, config.stage_config()) + Self::GlobalWriter::init::(out, conf, config.stage_config()) } fn init_accumulators(#[comptime] config: Self::Config) -> Self::Accumulators { diff --git a/crates/cubecl-matmul/src/components/global/multi_stage/ordered/setup.rs b/crates/cubecl-matmul/src/components/global/multi_stage/ordered/setup.rs index c6ab8c315..dc59c96f4 100644 --- a/crates/cubecl-matmul/src/components/global/multi_stage/ordered/setup.rs +++ b/crates/cubecl-matmul/src/components/global/multi_stage/ordered/setup.rs @@ -54,7 +54,7 @@ where type Config = OrderedDoubleBufferingGlobalConfig; fn setup( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, selection: &MatmulSelection, line_sizes: &MatmulLineSizes, diff --git a/crates/cubecl-matmul/src/components/global/read/reader/sync_full_reader.rs b/crates/cubecl-matmul/src/components/global/read/reader/sync_full_reader.rs index 40339cf20..cb211576d 100644 --- a/crates/cubecl-matmul/src/components/global/read/reader/sync_full_reader.rs +++ b/crates/cubecl-matmul/src/components/global/read/reader/sync_full_reader.rs @@ -34,6 +34,7 @@ pub trait SyncFullLoadingStrategy: /// Returns the job with preliminary calculations done. 
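// Editor's note: `new_job` now receives the line size from the global view
// (`tensor.line_size()` at the call sites below) instead of reading
// `config.global_line_size(ident)`. A sketch of the partitioning math this
// value feeds in the loading strategies, with illustrative names:
fn tasks_per_unit(elements_in_stage: u32, line_size: u32, num_planes: u32, plane_dim: u32) -> u32 {
    let num_stage_lines = elements_in_stage.div_ceil(line_size);
    let unit_count = num_planes * plane_dim;
    num_stage_lines.div_ceil(unit_count)
}
// e.g. a 4096-element stage with line size 4, loaded by 8 planes of 32 units:
// 1024 lines spread over 256 units, i.e. 4 tasks per unit.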
fn new_job( #[comptime] ident: MatmulIdent, + #[comptime] line_size: u32, #[comptime] config: G, ) -> Self::Job; } @@ -75,7 +76,7 @@ impl let global_iter = GlobalIterator::new(tensor, k_step, ident.view_direction(), false); let loading_job = match config.precompute_job() { - true => CubeOption::new_Some(L::new_job::(ident, config)), + true => CubeOption::new_Some(L::new_job::(ident, tensor.line_size(), config)), false => CubeOption::new_None(), }; @@ -106,15 +107,15 @@ impl pub fn load_stage(&mut self, #[comptime] config: G) { let mut loading_job = match self.loading_job { CubeOption::Some(loading_job) => loading_job, - CubeOption::None => L::new_job::(self.ident, config), + CubeOption::None => { + L::new_job::(self.ident, self.global_iter.line_size(), config) + } }; let len = L::Job::task_count(&loading_job); - let mut task_id = comptime![0u32]; - #[allow(clippy::explicit_counter_loop)] #[unroll] - for _ in 0..len { + for task_id in 0..len { L::Job::::execute_task::( &mut loading_job, task_id, @@ -122,7 +123,6 @@ impl &mut self.stage, config, ); - comptime![task_id += 1]; } } } @@ -138,9 +138,10 @@ impl JobExecut #[comptime] _stage_buffer: StageBuffer, #[comptime] config: G, ) -> Self::JobIterator { + let view = this.global_iter.view(); let job = match this.loading_job { CubeOption::Some(loading_job) => loading_job, - CubeOption::None => L::new_job::(this.ident, config), + CubeOption::None => L::new_job::(this.ident, view.line_size(), config), }; let num_tasks = L::Job::task_count(&job); diff --git a/crates/cubecl-matmul/src/components/global/read/reader/sync_partial_reader.rs b/crates/cubecl-matmul/src/components/global/read/reader/sync_partial_reader.rs index ad2afe8cc..877dbb994 100644 --- a/crates/cubecl-matmul/src/components/global/read/reader/sync_partial_reader.rs +++ b/crates/cubecl-matmul/src/components/global/read/reader/sync_partial_reader.rs @@ -35,6 +35,7 @@ pub trait SyncPartialLoadingStrategy: fn new_job( #[comptime] stage_index: u32, #[comptime] ident: MatmulIdent, + #[comptime] line_size: u32, #[comptime] config: G, ) -> Self::Job; } @@ -77,8 +78,8 @@ impl let loading_job = match config.precompute_job() { true => CubeOption::new_Some(( - L::new_job::(0u32, ident, config), - L::new_job::(1u32, ident, config), + L::new_job::(0u32, ident, tensor.line_size(), config), + L::new_job::(1u32, ident, tensor.line_size(), config), )), false => CubeOption::new_None(), }; @@ -113,18 +114,19 @@ impl StageBuffer::B => job.1, }, CubeOption::None => match stage_buffer { - StageBuffer::A => L::new_job::(0u32, self.ident, config), - StageBuffer::B => L::new_job::(1u32, self.ident, config), + StageBuffer::A => { + L::new_job::(0u32, self.ident, self.global_iter.line_size(), config) + } + StageBuffer::B => { + L::new_job::(1u32, self.ident, self.global_iter.line_size(), config) + } }, }; let len = L::Job::task_count(&loading_job); - let mut task_id = comptime![0u32]; - - #[allow(clippy::explicit_counter_loop)] #[unroll] - for _ in 0..len { + for task_id in 0..len { L::Job::::execute_task::( &mut loading_job, task_id, @@ -132,7 +134,6 @@ impl &mut self.stage_memory, config, ); - comptime![task_id += 1]; } } } @@ -148,14 +149,15 @@ impl JobExe #[comptime] stage_buffer: StageBuffer, #[comptime] config: G, ) -> Self::JobIterator { + let view = this.global_iter.view(); let job = match this.loading_job { CubeOption::Some(job) => match stage_buffer { StageBuffer::A => job.0, StageBuffer::B => job.1, }, CubeOption::None => match stage_buffer { - StageBuffer::A => L::new_job::(0u32, this.ident, 
config), - StageBuffer::B => L::new_job::(1u32, this.ident, config), + StageBuffer::A => L::new_job::(0u32, this.ident, view.line_size(), config), + StageBuffer::B => L::new_job::(1u32, this.ident, view.line_size(), config), }, }; diff --git a/crates/cubecl-matmul/src/components/global/read/reader/tma_reader.rs b/crates/cubecl-matmul/src/components/global/read/reader/tma_reader.rs index 15063c46c..34fe1fdb5 100644 --- a/crates/cubecl-matmul/src/components/global/read/reader/tma_reader.rs +++ b/crates/cubecl-matmul/src/components/global/read/reader/tma_reader.rs @@ -62,7 +62,7 @@ impl TilingOrder for TmaTilingOrder { #[derive(CubeType)] /// Loads the entire stage memory using TMA (Tensor Memory Accelerator) pub struct TmaGlobalReader { - global_iter: GlobalIterator, + global_iter: GlobalIterator>, stage: StridedStage, #[cube(comptime)] config: StageMemoryConfig, @@ -72,7 +72,7 @@ pub struct TmaGlobalReader { impl TmaGlobalReader { /// Create a TmaGlobalReader pub fn new( - global_view: View, + global_view: View, Coords2d>, k_step: u32, #[comptime] ident: MatmulIdent, #[comptime] config: StageMemoryConfig, diff --git a/crates/cubecl-matmul/src/components/global/read/strategy/async_full_cooperative.rs b/crates/cubecl-matmul/src/components/global/read/strategy/async_full_cooperative.rs index 7af45ed1e..3df6acd5d 100644 --- a/crates/cubecl-matmul/src/components/global/read/strategy/async_full_cooperative.rs +++ b/crates/cubecl-matmul/src/components/global/read/strategy/async_full_cooperative.rs @@ -5,7 +5,7 @@ use crate::components::{ memory::{GlobalIterator, load_window_in_stage}, read::AsyncFullLoadingStrategy, }, - stage::{StridedStage, StridedTilingLayout}, + stage::{StridedStage, StridedTilingLayout, TilingValidation}, }; use cubecl_core::prelude::*; use cubecl_core::{self as cubecl, prelude::barrier::BarrierLevel}; @@ -20,7 +20,9 @@ use super::{AsyncLoadingJob, LoadingValidation}; pub struct AsyncFullCooperativeLoading {} impl LoadingValidation for AsyncFullCooperativeLoading { - fn check(_config: &C, _ident: MatmulIdent) -> Result<(), InvalidConfigError> { + fn check(config: &C, ident: MatmulIdent) -> Result<(), InvalidConfigError> { + StridedTilingLayout::check(config.global_memory_config(ident))?; + Ok(()) } } diff --git a/crates/cubecl-matmul/src/components/global/read/strategy/async_full_cyclic.rs b/crates/cubecl-matmul/src/components/global/read/strategy/async_full_cyclic.rs index 3167de388..26b55ab07 100644 --- a/crates/cubecl-matmul/src/components/global/read/strategy/async_full_cyclic.rs +++ b/crates/cubecl-matmul/src/components/global/read/strategy/async_full_cyclic.rs @@ -7,7 +7,7 @@ use crate::components::{ memory::{GlobalIterator, load_window_in_tile}, read::AsyncFullLoadingStrategy, }, - stage::{ContiguousTilingLayout, StridedStage, TilingOrder}, + stage::{ContiguousTilingLayout, StridedStage, TilingOrder, TilingValidation}, }; use cubecl_core::prelude::*; use cubecl_core::{self as cubecl, prelude::barrier::BarrierLevel}; @@ -34,6 +34,8 @@ impl LoadingValidation for AsyncFullCyclicLoading { ))); } + ContiguousTilingLayout::::check(config.global_memory_config(ident))?; + Ok(()) } } diff --git a/crates/cubecl-matmul/src/components/global/read/strategy/async_full_maximize_slice_length.rs b/crates/cubecl-matmul/src/components/global/read/strategy/async_full_maximize_slice_length.rs index 94eb4dd85..197dd3dce 100644 --- a/crates/cubecl-matmul/src/components/global/read/strategy/async_full_maximize_slice_length.rs +++ 
b/crates/cubecl-matmul/src/components/global/read/strategy/async_full_maximize_slice_length.rs @@ -5,7 +5,7 @@ use crate::components::{ memory::{GlobalIterator, load_window_in_stage}, read::AsyncFullLoadingStrategy, }, - stage::{StridedStage, StridedTilingLayout}, + stage::{StridedStage, StridedTilingLayout, TilingValidation}, }; use cubecl_core::prelude::*; use cubecl_core::{self as cubecl, prelude::barrier::BarrierLevel}; @@ -18,7 +18,9 @@ use super::{AsyncLoadingJob, LoadingValidation}; pub struct AsyncFullMaximizeSliceLengthLoading {} impl LoadingValidation for AsyncFullMaximizeSliceLengthLoading { - fn check(_config: &C, _ident: MatmulIdent) -> Result<(), InvalidConfigError> { + fn check(config: &C, ident: MatmulIdent) -> Result<(), InvalidConfigError> { + StridedTilingLayout::check(config.global_memory_config(ident))?; + Ok(()) } } diff --git a/crates/cubecl-matmul/src/components/global/read/strategy/async_full_maximize_unit_count.rs b/crates/cubecl-matmul/src/components/global/read/strategy/async_full_maximize_unit_count.rs index dc3a5d6e6..cbec6b1f3 100644 --- a/crates/cubecl-matmul/src/components/global/read/strategy/async_full_maximize_unit_count.rs +++ b/crates/cubecl-matmul/src/components/global/read/strategy/async_full_maximize_unit_count.rs @@ -5,7 +5,7 @@ use crate::components::{ memory::{GlobalIterator, load_window_in_stage}, read::AsyncFullLoadingStrategy, }, - stage::{StageConfig, StridedStage, StridedTilingLayout}, + stage::{StageConfig, StridedStage, StridedTilingLayout, TilingValidation}, }; use cubecl_core::prelude::*; use cubecl_core::{self as cubecl, prelude::barrier::BarrierLevel}; @@ -45,6 +45,8 @@ impl LoadingValidation for AsyncFullMaximizeUnitCountLoading { )); } + StridedTilingLayout::check(config.global_memory_config(ident))?; + Ok(()) } } diff --git a/crates/cubecl-matmul/src/components/global/read/strategy/async_partial_maximize_slice_length.rs b/crates/cubecl-matmul/src/components/global/read/strategy/async_partial_maximize_slice_length.rs index adbfef41b..e121f2db7 100644 --- a/crates/cubecl-matmul/src/components/global/read/strategy/async_partial_maximize_slice_length.rs +++ b/crates/cubecl-matmul/src/components/global/read/strategy/async_partial_maximize_slice_length.rs @@ -5,7 +5,7 @@ use crate::components::{ memory::{GlobalIterator, load_window_in_stage}, read::AsyncPartialLoadingStrategy, }, - stage::{StageConfig, StridedStage, StridedTilingLayout}, + stage::{StageConfig, StridedStage, StridedTilingLayout, TilingValidation}, }; use cubecl_core::prelude::*; use cubecl_core::{self as cubecl, prelude::barrier::BarrierLevel}; @@ -18,7 +18,9 @@ use super::{AsyncLoadingJob, LoadingValidation}; pub struct AsyncPartialMaximizeSliceLengthLoading {} impl LoadingValidation for AsyncPartialMaximizeSliceLengthLoading { - fn check(_config: &C, _ident: MatmulIdent) -> Result<(), InvalidConfigError> { + fn check(config: &C, ident: MatmulIdent) -> Result<(), InvalidConfigError> { + StridedTilingLayout::check(config.global_memory_config(ident))?; + Ok(()) } } diff --git a/crates/cubecl-matmul/src/components/global/read/strategy/sync_full_cyclic.rs b/crates/cubecl-matmul/src/components/global/read/strategy/sync_full_cyclic.rs index 72ad402ea..392e0eb5c 100644 --- a/crates/cubecl-matmul/src/components/global/read/strategy/sync_full_cyclic.rs +++ b/crates/cubecl-matmul/src/components/global/read/strategy/sync_full_cyclic.rs @@ -1,12 +1,12 @@ use std::marker::PhantomData; -use crate::components::global::memory::GlobalIterator; use 
crate::components::global::multi_stage::LoadMaxRoundPlaneCount; use crate::components::global::read::{SyncFullLoadingStrategy, tiled::TiledLayout}; use crate::components::global::{GlobalConfig, RoleRule}; use crate::components::stage::{ContiguousTilingLayout, StridedStage, TilingOrder}; use crate::components::{InvalidConfigError, MatmulIdent}; use crate::components::{MatrixPrecision, TilingScheme}; +use crate::components::{global::memory::GlobalIterator, stage::TilingValidation}; use cubecl_core as cubecl; use cubecl_core::prelude::*; @@ -36,6 +36,8 @@ impl LoadingValidation for SyncFullCyclicLoading { } } + ContiguousTilingLayout::::check(config.global_memory_config(ident))?; + Ok(()) } } @@ -59,10 +61,10 @@ impl SyncFullLoadingStrategy for SyncFullCyclicLoading { fn new_job( #[comptime] ident: MatmulIdent, + #[comptime] line_size: u32, #[comptime] config: G, ) -> Self::Job { let tile_num_elements = config.tiling_scheme().elements_in_tile(ident); - let line_size = config.global_line_size(ident); let num_stage_elements = config.tiling_scheme().elements_in_stage(ident); let num_stage_lines = num_stage_elements.div_ceil(line_size); diff --git a/crates/cubecl-matmul/src/components/global/read/strategy/sync_full_ordered.rs b/crates/cubecl-matmul/src/components/global/read/strategy/sync_full_ordered.rs index 95e5202dd..bdb8ed226 100644 --- a/crates/cubecl-matmul/src/components/global/read/strategy/sync_full_ordered.rs +++ b/crates/cubecl-matmul/src/components/global/read/strategy/sync_full_ordered.rs @@ -1,4 +1,3 @@ -use crate::components::global::RoleRule; use crate::components::global::multi_stage::LoadMaxRoundPlaneCount; use crate::components::global::read::SyncFullLoadingStrategy; use crate::components::stage::OrderedTilingOrder; @@ -6,6 +5,7 @@ use crate::components::{ FormattedConfigError, InvalidConfigError, MatmulIdent, MatrixPrecision, TilingScheme, }; use crate::components::{global::GlobalConfig, stage::ContiguousTilingLayout}; +use crate::components::{global::RoleRule, stage::TilingValidation}; use cubecl_core as cubecl; use cubecl_core::prelude::*; @@ -67,6 +67,8 @@ impl LoadingValidation for SyncFullOrderedLoading { })); } + ContiguousTilingLayout::::check(config.global_memory_config(ident))?; + Ok(()) } } @@ -89,9 +91,9 @@ impl SyncFullLoadingStrategy for SyncFullOrderedLoading { fn new_job( #[comptime] ident: MatmulIdent, + #[comptime] line_size: u32, #[comptime] config: G, ) -> Self::Job { - let line_size = config.global_line_size(ident); let num_planes = config.num_loading_planes(ident); let num_tiles = config.tiling_scheme().tiles_in_stage(ident); let plane_dim = config.plane_dim(); diff --git a/crates/cubecl-matmul/src/components/global/read/strategy/sync_full_strided.rs b/crates/cubecl-matmul/src/components/global/read/strategy/sync_full_strided.rs index 691b51322..f2602a8bd 100644 --- a/crates/cubecl-matmul/src/components/global/read/strategy/sync_full_strided.rs +++ b/crates/cubecl-matmul/src/components/global/read/strategy/sync_full_strided.rs @@ -1,10 +1,10 @@ -use crate::components::global::memory::GlobalIterator; use crate::components::global::multi_stage::LoadMaxRoundPlaneCount; use crate::components::global::read::{SyncFullLoadingStrategy, stage::FullStageLayout}; use crate::components::global::{GlobalConfig, RoleRule}; use crate::components::stage::{StridedStage, StridedTilingLayout}; use crate::components::{InvalidConfigError, MatmulIdent}; use crate::components::{MatrixPrecision, TilingScheme}; +use crate::components::{global::memory::GlobalIterator, 
stage::TilingValidation}; use cubecl_core as cubecl; use cubecl_core::prelude::*; @@ -29,6 +29,8 @@ impl LoadingValidation for SyncFullStridedLoading { )); } + StridedTilingLayout::check(config.global_memory_config(ident))?; + Ok(()) } } @@ -52,9 +54,9 @@ impl SyncFullLoadingStrategy for SyncFullStridedLoading { fn new_job( #[comptime] ident: MatmulIdent, + #[comptime] line_size: u32, #[comptime] config: G, ) -> Self::Job { - let line_size = config.global_line_size(ident); let num_stage_lines = config.tiling_scheme().elements_in_stage(ident) / line_size; let unit_count = config.num_loading_planes(ident) * config.plane_dim(); let num_tasks_per_unit = comptime!(num_stage_lines / unit_count); diff --git a/crates/cubecl-matmul/src/components/global/read/strategy/sync_full_tilewise.rs b/crates/cubecl-matmul/src/components/global/read/strategy/sync_full_tilewise.rs index 9ca5c893b..8276dfbbe 100644 --- a/crates/cubecl-matmul/src/components/global/read/strategy/sync_full_tilewise.rs +++ b/crates/cubecl-matmul/src/components/global/read/strategy/sync_full_tilewise.rs @@ -1,11 +1,11 @@ use std::marker::PhantomData; -use crate::components::global::multi_stage::LoadMaxRoundPlaneCount; use crate::components::global::read::SyncFullLoadingStrategy; use crate::components::global::{RoleRule, read::tiled::TiledLayout}; use crate::components::{ FormattedConfigError, InvalidConfigError, MatmulIdent, MatrixPrecision, TilingScheme, }; +use crate::components::{global::multi_stage::LoadMaxRoundPlaneCount, stage::TilingValidation}; use crate::components::{ global::{GlobalConfig, memory::GlobalIterator}, stage::{ContiguousTilingLayout, StridedStage, TilingOrder}, @@ -69,6 +69,8 @@ impl LoadingValidation for SyncFullTilewiseLoading { })); } + ContiguousTilingLayout::::check(config.global_memory_config(ident))?; + Ok(()) } } @@ -80,9 +82,9 @@ impl SyncFullLoadingStrategy for SyncFullTilewiseLoading { fn new_job( #[comptime] ident: MatmulIdent, + #[comptime] line_size: u32, #[comptime] config: G, ) -> Self::Job { - let line_size = config.global_line_size(ident); let num_planes = config.num_loading_planes(ident); let num_tiles = config.tiling_scheme().tiles_in_stage(ident); let plane_dim = config.plane_dim(); diff --git a/crates/cubecl-matmul/src/components/global/read/strategy/sync_partial_cyclic.rs b/crates/cubecl-matmul/src/components/global/read/strategy/sync_partial_cyclic.rs index 00f6b6d19..fc4553cdb 100644 --- a/crates/cubecl-matmul/src/components/global/read/strategy/sync_partial_cyclic.rs +++ b/crates/cubecl-matmul/src/components/global/read/strategy/sync_partial_cyclic.rs @@ -1,11 +1,11 @@ use std::marker::PhantomData; -use crate::components::global::memory::GlobalIterator; use crate::components::global::multi_stage::LoadMaxRoundPlaneCount; use crate::components::global::read::{SyncPartialLoadingStrategy, tiled::TiledLayout}; use crate::components::global::{GlobalConfig, RoleRule}; use crate::components::stage::{ContiguousTilingLayout, StridedStage, TilingOrder}; use crate::components::{InvalidConfigError, MatmulIdent, MatrixPrecision, TilingScheme}; +use crate::components::{global::memory::GlobalIterator, stage::TilingValidation}; use cubecl_core as cubecl; use cubecl_core::prelude::*; @@ -44,6 +44,8 @@ impl LoadingValidation for SyncPartialCyclicLoading { } } + ContiguousTilingLayout::::check(config.global_memory_config(ident))?; + Ok(()) } } @@ -70,9 +72,9 @@ impl SyncPartialLoadingStrategy for SyncPartialCyclicLoading( #[comptime] stage_index: u32, #[comptime] ident: MatmulIdent, + #[comptime] 
line_size: u32, #[comptime] config: G, ) -> SyncPartialCyclicJob { - let line_size = config.global_line_size(ident); let num_stage_elements = config.tiling_scheme().elements_in_stage(ident); let tile_size = config.tiling_scheme().elements_in_tile(ident); @@ -165,14 +167,17 @@ pub(crate) fn load_and_store_line>, #[comptime] config: G, ) { - let (line_size, tile_size, tile_count_row, tile_count_col) = comptime! { + let layout = TiledLayout::new(comptime!(config.global_memory_config(job.ident))); + let view = global_iter.view().view(layout); + + let (tile_size, tile_count_row, tile_count_col) = comptime! { ( - config.global_line_size(job.ident), config.tiling_scheme().elements_in_tile(job.ident), config.tiling_scheme().tiles_in_stage_row(job.ident), config.tiling_scheme().tiles_in_stage_col(job.ident), ) }; + let line_size = view.line_size(); let tile_index = unit_position / tile_size; let pos_within_tile = unit_position % tile_size; @@ -208,9 +213,6 @@ pub(crate) fn load_and_store_line comptime!(unreachable!()), }; - let layout = TiledLayout::new(comptime!(config.global_memory_config(job.ident))); - let view = global_iter.view().view(layout); - let line_read = view.read_checked((tile, pos_within_tile)); let nth_tile_in_stage = TO::to_nth_tile( diff --git a/crates/cubecl-matmul/src/components/global/read/strategy/sync_partial_tilewise.rs b/crates/cubecl-matmul/src/components/global/read/strategy/sync_partial_tilewise.rs index 344c1d194..8c06a0c85 100644 --- a/crates/cubecl-matmul/src/components/global/read/strategy/sync_partial_tilewise.rs +++ b/crates/cubecl-matmul/src/components/global/read/strategy/sync_partial_tilewise.rs @@ -1,12 +1,12 @@ use std::marker::PhantomData; -use crate::components::global::multi_stage::LoadMaxRoundPlaneCount; use crate::components::global::read::SyncPartialLoadingStrategy; use crate::components::global::{RoleRule, read::tiled::TiledLayout}; use crate::components::stage::TilingOrderEnum; use crate::components::{ FormattedConfigError, InvalidConfigError, MatmulIdent, MatrixPrecision, TilingScheme, }; +use crate::components::{global::multi_stage::LoadMaxRoundPlaneCount, stage::TilingValidation}; use crate::components::{ global::{GlobalConfig, memory::GlobalIterator}, stage::{ContiguousTilingLayout, StridedStage, TilingOrder}, @@ -85,6 +85,8 @@ impl LoadingValidation for SyncPartialTilewiseLoading { MatmulIdent::Out => unreachable!(), } + ContiguousTilingLayout::::check(config.global_memory_config(ident))?; + Ok(()) } } @@ -97,9 +99,9 @@ impl SyncPartialLoadingStrategy for SyncPartialTilewiseLoading< fn new_job( #[comptime] stage_index: u32, #[comptime] ident: MatmulIdent, + #[comptime] line_size: u32, #[comptime] config: G, ) -> SyncPartialTilewiseJob { - let line_size = config.global_line_size(ident); let num_planes = config.num_loading_planes(ident); let num_tiles = config.tiling_scheme().tiles_in_stage(ident); let plane_dim = config.plane_dim(); diff --git a/crates/cubecl-matmul/src/components/global/single_stage/barrier/config.rs b/crates/cubecl-matmul/src/components/global/single_stage/barrier/config.rs index 602fc2d1e..b4876afbd 100644 --- a/crates/cubecl-matmul/src/components/global/single_stage/barrier/config.rs +++ b/crates/cubecl-matmul/src/components/global/single_stage/barrier/config.rs @@ -115,7 +115,7 @@ impl SimpleBarrierConfig { /// - CubeDim is too big /// - Barriers are not available pub fn new( - client: &ComputeClient, + client: &ComputeClient, stage_config: S, num_planes: u32, check_m_bounds: bool, @@ -151,7 +151,7 @@ impl SimpleBarrierConfig 
{ fn check_availability( self, - client: &ComputeClient, + client: &ComputeClient, ) -> Result { if !client.properties().supports_type(SemanticType::Barrier) { return Err(MatmulSetupError::Unavailable( diff --git a/crates/cubecl-matmul/src/components/global/single_stage/barrier/matmul.rs b/crates/cubecl-matmul/src/components/global/single_stage/barrier/matmul.rs index 92197b4f9..88e197c3d 100644 --- a/crates/cubecl-matmul/src/components/global/single_stage/barrier/matmul.rs +++ b/crates/cubecl-matmul/src/components/global/single_stage/barrier/matmul.rs @@ -4,7 +4,6 @@ use crate::components::RhsG; use crate::components::RhsS; use crate::components::global::GlobalConfig; use crate::components::global::GlobalMatmul; -use crate::components::global::memory::SimpleGlobalLayout; use crate::components::global::read::AsyncFullLoadingStrategy; use crate::components::global::read::AsyncFullStageGlobalReader; use crate::components::global::single_stage::barrier::SimpleBarrierConfig; @@ -19,7 +18,7 @@ use crate::components::{MatmulPrecision, global::GlobalWriter}; use barrier::Barrier; use cubecl_core::prelude::*; use cubecl_core::{self as cubecl}; -use cubecl_std::tensor::r#virtual::VirtualTensor; +use cubecl_std::tensor::View; use cubecl_std::{CubeOption, CubeOptionExpand, tensor::layout::Coords2d}; /// Performs matrix multiplication at the global level @@ -134,47 +133,21 @@ where } fn init_lhs_global_reader( - lhs: VirtualTensor>, - batch_offset: u32, - offset: Coords2d, - slice_size: Coords2d, - _nth_batch: u32, + lhs: View>, Coords2d>, #[comptime] config: Self::Config, ) -> Self::LhsGlobalReader { - let conf = config.global_memory_config(MatmulIdent::Lhs); - let layout = SimpleGlobalLayout::new(&lhs, batch_offset, conf); - Self::LhsGlobalReader::new( - lhs.view(layout).slice(offset, slice_size), - config.k_step, - MatmulIdent::Lhs, - config, - ) + Self::LhsGlobalReader::new(lhs, config.k_step, MatmulIdent::Lhs, config) } fn init_rhs_global_reader( - rhs: VirtualTensor>, - batch_offset: u32, - offset: Coords2d, - slice_size: Coords2d, - _nth_batch: u32, + rhs: View>, Coords2d>, #[comptime] config: Self::Config, ) -> Self::RhsGlobalReader { - let conf = config.global_memory_config(MatmulIdent::Rhs); - let layout = SimpleGlobalLayout::new(&rhs, batch_offset, conf); - Self::RhsGlobalReader::new( - rhs.view(layout).slice(offset, slice_size), - config.k_step, - MatmulIdent::Rhs, - config, - ) + Self::RhsGlobalReader::new(rhs, config.k_step, MatmulIdent::Rhs, config) } fn init_acc_global_reader( - acc: CubeOption>>, - _batch_offset: u32, - _offset: Coords2d, - _slice_size: Coords2d, - _nth_batch: u32, + acc: CubeOption>, Coords2d>>, #[comptime] _config: Self::Config, ) -> Self::AccGlobalReader { match acc { @@ -184,17 +157,11 @@ where } fn init_global_writer( - out: VirtualTensor, ReadWrite>, - batch_offset: u32, - offset: Coords2d, - size: Coords2d, - _nth_batch: u32, + out: View>, Coords2d, ReadWrite>, #[comptime] config: Self::Config, ) -> Self::GlobalWriter { let conf = config.global_memory_config(MatmulIdent::Out); - let layout = SimpleGlobalLayout::new(&out, batch_offset, conf); - let view = out.view_mut(layout).slice_mut_unchecked(offset, size); - Self::GlobalWriter::init::(view, conf, config.stage_config()) + Self::GlobalWriter::init::(out, conf, config.stage_config()) } fn init_accumulators(#[comptime] config: Self::Config) -> Self::Accumulators { diff --git a/crates/cubecl-matmul/src/components/global/single_stage/barrier/setup.rs 
b/crates/cubecl-matmul/src/components/global/single_stage/barrier/setup.rs index 196d75fd2..e96a81eeb 100644 --- a/crates/cubecl-matmul/src/components/global/single_stage/barrier/setup.rs +++ b/crates/cubecl-matmul/src/components/global/single_stage/barrier/setup.rs @@ -46,7 +46,7 @@ where type Config = SimpleBarrierConfig; fn setup( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, selection: &MatmulSelection, line_sizes: &MatmulLineSizes, diff --git a/crates/cubecl-matmul/src/components/global/single_stage/simple/config.rs b/crates/cubecl-matmul/src/components/global/single_stage/simple/config.rs index f5d83033b..8856f9b0a 100644 --- a/crates/cubecl-matmul/src/components/global/single_stage/simple/config.rs +++ b/crates/cubecl-matmul/src/components/global/single_stage/simple/config.rs @@ -115,7 +115,7 @@ impl SimpleConfig { /// - CubeDim is too big /// - Barriers are not available pub fn new( - _client: &ComputeClient, + _client: &ComputeClient, stage_config: S, num_planes: u32, check_m_bounds: bool, diff --git a/crates/cubecl-matmul/src/components/global/single_stage/simple/matmul.rs b/crates/cubecl-matmul/src/components/global/single_stage/simple/matmul.rs index 72fa0599e..6ff4ed8fc 100644 --- a/crates/cubecl-matmul/src/components/global/single_stage/simple/matmul.rs +++ b/crates/cubecl-matmul/src/components/global/single_stage/simple/matmul.rs @@ -2,7 +2,6 @@ use crate::components::{ AccG, AccS, LhsG, LhsS, MatmulIdent, MatmulPrecision, RhsG, RhsS, global::{ GlobalMatmul, GlobalWriter, - memory::SimpleGlobalLayout, read::{SyncFullLoadingStrategy, SyncFullStageGlobalReader, ZeroGlobalReader}, single_stage::simple::SimpleConfig, }, @@ -12,7 +11,7 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; use cubecl_std::{ CubeOption, CubeOptionExpand, - tensor::{layout::Coords2d, r#virtual::VirtualTensor}, + tensor::{View, layout::Coords2d}, }; use std::marker::PhantomData; @@ -120,47 +119,21 @@ where } fn init_lhs_global_reader( - lhs: VirtualTensor>, - batch_offset: u32, - offset: Coords2d, - slice_size: Coords2d, - _nth_batch: u32, + lhs: View>, Coords2d>, #[comptime] config: Self::Config, ) -> Self::LhsGlobalReader { - let conf = config.global_memory_config(MatmulIdent::Lhs); - let layout = SimpleGlobalLayout::new(&lhs, batch_offset, conf); - Self::LhsGlobalReader::new( - lhs.view(layout).slice_unchecked(offset, slice_size), - config.k_step, - MatmulIdent::Lhs, - config, - ) + Self::LhsGlobalReader::new(lhs, config.k_step, MatmulIdent::Lhs, config) } fn init_rhs_global_reader( - rhs: VirtualTensor>, - batch_offset: u32, - offset: Coords2d, - slice_size: Coords2d, - _nth_batch: u32, + rhs: View>, Coords2d>, #[comptime] config: Self::Config, ) -> Self::RhsGlobalReader { - let conf = config.global_memory_config(MatmulIdent::Rhs); - let layout = SimpleGlobalLayout::new(&rhs, batch_offset, conf); - Self::RhsGlobalReader::new( - rhs.view(layout).slice_unchecked(offset, slice_size), - config.k_step, - MatmulIdent::Rhs, - config, - ) + Self::RhsGlobalReader::new(rhs, config.k_step, MatmulIdent::Rhs, config) } fn init_acc_global_reader( - acc: CubeOption>>, - _batch_offset: u32, - _offset: Coords2d, - _slice_size: Coords2d, - _nth_batch: u32, + acc: CubeOption>, Coords2d>>, #[comptime] _config: Self::Config, ) -> Self::AccGlobalReader { match acc { @@ -170,17 +143,11 @@ where } fn init_global_writer( - out: VirtualTensor, ReadWrite>, - batch_offset: u32, - offset: Coords2d, - size: Coords2d, - _nth_batch: u32, + out: View>, Coords2d, ReadWrite>, #[comptime] config: 
Self::Config, ) -> Self::GlobalWriter { let conf = config.global_memory_config(MatmulIdent::Out); - let layout = SimpleGlobalLayout::new(&out, batch_offset, conf); - let view = out.view_mut(layout).slice_mut_unchecked(offset, size); - Self::GlobalWriter::init::(view, conf, config.stage_config()) + Self::GlobalWriter::init::(out, conf, config.stage_config()) } fn init_accumulators(#[comptime] config: Self::Config) -> Self::Accumulators { diff --git a/crates/cubecl-matmul/src/components/global/single_stage/simple/setup.rs b/crates/cubecl-matmul/src/components/global/single_stage/simple/setup.rs index a5e2d1577..05ec4ce38 100644 --- a/crates/cubecl-matmul/src/components/global/single_stage/simple/setup.rs +++ b/crates/cubecl-matmul/src/components/global/single_stage/simple/setup.rs @@ -52,7 +52,7 @@ where type Config = SimpleConfig; fn setup( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, selection: &MatmulSelection, line_sizes: &MatmulLineSizes, diff --git a/crates/cubecl-matmul/src/components/global/single_stage/tma/config.rs b/crates/cubecl-matmul/src/components/global/single_stage/tma/config.rs index 4408ac70f..6cf178752 100644 --- a/crates/cubecl-matmul/src/components/global/single_stage/tma/config.rs +++ b/crates/cubecl-matmul/src/components/global/single_stage/tma/config.rs @@ -118,7 +118,7 @@ impl SimpleTmaConfig { /// - CubeDim is too big /// - TMA is not available pub fn new( - client: &ComputeClient, + client: &ComputeClient, stage_config: S, num_planes: u32, check_m_bounds: bool, @@ -154,7 +154,7 @@ impl SimpleTmaConfig { fn check_availability( self, - client: &ComputeClient, + client: &ComputeClient, ) -> Result { let lhs_g_id = TypeId::of::>(); let lhs_s_id = TypeId::of::>(); diff --git a/crates/cubecl-matmul/src/components/global/single_stage/tma/matmul.rs b/crates/cubecl-matmul/src/components/global/single_stage/tma/matmul.rs index a2dac42c7..6d7e33772 100644 --- a/crates/cubecl-matmul/src/components/global/single_stage/tma/matmul.rs +++ b/crates/cubecl-matmul/src/components/global/single_stage/tma/matmul.rs @@ -1,18 +1,18 @@ +use crate::components::LhsG; +use crate::components::global::GlobalMatmul; use crate::components::global::read::TmaGlobalReader; use crate::components::global::read::arrive_tma; use crate::components::global::single_stage::tma::SimpleTmaConfig; -use crate::components::global::{GlobalMatmul, memory::SimpleTmaGlobalLayout}; use crate::components::stage::StageMatmul; use crate::components::{AccG, RhsG}; use crate::components::{AccS, MatmulIdent}; -use crate::components::{LhsG, global::memory::SimpleGlobalLayout}; use crate::components::{LhsS, global::read::TmaStage, stage::FilledStage}; use crate::components::{MatmulPrecision, global::read::ZeroGlobalReader}; use crate::components::{RhsS, global::GlobalWriter}; use barrier::Barrier; use cubecl_core::prelude::{barrier::BarrierLevel, *}; use cubecl_core::{self as cubecl}; -use cubecl_std::tensor::{AsTensorView, AsTensorViewExpand, r#virtual::VirtualTensor}; +use cubecl_std::tensor::View; use cubecl_std::{CubeOption, CubeOptionExpand, tensor::layout::Coords2d}; use std::marker::PhantomData; @@ -109,17 +109,11 @@ where } fn init_lhs_global_reader( - lhs: VirtualTensor>, - _batch_offset: u32, - offset: Coords2d, - slice_size: Coords2d, - nth_batch: u32, + lhs: View>, Coords2d>, #[comptime] config: Self::Config, ) -> Self::LhsGlobalReader { - let layout = SimpleTmaGlobalLayout::new(nth_batch, config.matrix_layout(MatmulIdent::Lhs)); - let lhs = 
lhs.as_tensor_map().unwrap().view_3d(layout); Self::LhsGlobalReader::new( - lhs.slice(offset, slice_size), + lhs, config.k_step, MatmulIdent::Lhs, config.stage_memory_config(MatmulIdent::Lhs), @@ -127,17 +121,11 @@ where } fn init_rhs_global_reader( - rhs: VirtualTensor>, - _batch_offset: u32, - offset: Coords2d, - slice_size: Coords2d, - nth_batch: u32, + rhs: View>, Coords2d>, #[comptime] config: Self::Config, ) -> Self::RhsGlobalReader { - let layout = SimpleTmaGlobalLayout::new(nth_batch, config.matrix_layout(MatmulIdent::Rhs)); - let rhs = rhs.as_tensor_map().unwrap().view_3d(layout); Self::RhsGlobalReader::new( - rhs.slice(offset, slice_size), + rhs, config.k_step, MatmulIdent::Rhs, config.stage_memory_config(MatmulIdent::Rhs), @@ -145,11 +133,7 @@ where } fn init_acc_global_reader( - acc: CubeOption>>, - _batch_offset: u32, - _offset: Coords2d, - _slice_size: Coords2d, - _nth_batch: u32, + acc: CubeOption>, Coords2d>>, #[comptime] _config: Self::Config, ) -> Self::AccGlobalReader { match acc { @@ -159,17 +143,11 @@ where } fn init_global_writer( - out: VirtualTensor, ReadWrite>, - batch_offset: u32, - offset: Coords2d, - size: Coords2d, - _nth_batch: u32, + out: View>, Coords2d, ReadWrite>, #[comptime] config: Self::Config, ) -> Self::GlobalWriter { let conf = config.global_memory_config(MatmulIdent::Out); - let layout = SimpleGlobalLayout::new(&out, batch_offset, conf); - let view = out.view_mut(layout).slice_mut_unchecked(offset, size); - Self::GlobalWriter::init::(view, conf, config.stage_config()) + Self::GlobalWriter::init::(out, conf, config.stage_config()) } fn init_accumulators(#[comptime] config: Self::Config) -> Self::Accumulators { diff --git a/crates/cubecl-matmul/src/components/global/single_stage/tma/setup.rs b/crates/cubecl-matmul/src/components/global/single_stage/tma/setup.rs index 49f587990..58827849b 100644 --- a/crates/cubecl-matmul/src/components/global/single_stage/tma/setup.rs +++ b/crates/cubecl-matmul/src/components/global/single_stage/tma/setup.rs @@ -41,7 +41,7 @@ where type Config = SimpleTmaConfig; fn setup( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, selection: &MatmulSelection, line_sizes: &MatmulLineSizes, diff --git a/crates/cubecl-matmul/src/components/global/write/plane.rs b/crates/cubecl-matmul/src/components/global/write/plane.rs index 53e70621b..73a2cdcc3 100644 --- a/crates/cubecl-matmul/src/components/global/write/plane.rs +++ b/crates/cubecl-matmul/src/components/global/write/plane.rs @@ -111,7 +111,7 @@ pub fn plane_write( #[comptime] config: GlobalMemoryConfig, ) { let tile_size = config.elements_in_tile_row * config.elements_in_tile_col; - let output_line_size = config.global_line_size; + let output_line_size = global.line_size(); let unit_step = comptime![plane_dim * output_line_size]; let num_unit_writes = comptime!(tile_size.div_ceil(unit_step)); diff --git a/crates/cubecl-matmul/src/components/global/write/unit.rs b/crates/cubecl-matmul/src/components/global/write/unit.rs index 4f36cc320..d2b049726 100644 --- a/crates/cubecl-matmul/src/components/global/write/unit.rs +++ b/crates/cubecl-matmul/src/components/global/write/unit.rs @@ -11,6 +11,7 @@ use crate::components::{ read::tiled::{TiledCoords, TiledLayout}, }, stage::{StageConfig, StageMemoryConfig, StagePartitioner, UnitPartitioner}, + tile::StridedTile, }; #[derive(CubeType)] @@ -42,20 +43,26 @@ impl UnitWriter { } fn write(&mut self, tile: Coords2d) { - let smem_tile = &self.stage.unit_tile; - let config = comptime![self.config]; - - let 
tile_size = config.elements_in_tile_row * config.elements_in_tile_col; - let output_line_size = config.global_line_size; - let out_smem_slice = smem_tile.slice.with_line_size(output_line_size); - - let num_lines = tile_size / output_line_size; + unit_write(&mut self.global, &self.stage.unit_tile, tile, self.config) + } +} - for i in 0..num_lines { - let value = out_smem_slice[i]; - self.global - .write_checked((tile, i * output_line_size), Line::cast_from(value)); - } +#[cube] +pub fn unit_write( + global: &mut View, TiledCoords, ReadWrite>, + smem_tile: &StridedTile, + tile_pos: Coords2d, + #[comptime] config: GlobalMemoryConfig, +) { + let tile_size = config.elements_in_tile_row * config.elements_in_tile_col; + let output_line_size = global.line_size(); + let out_smem_slice = smem_tile.slice.with_line_size(output_line_size); + + let num_lines = tile_size / output_line_size; + + for i in 0..num_lines { + let value = out_smem_slice[i]; + global.write_checked((tile_pos, i * output_line_size), Line::cast_from(value)); } } diff --git a/crates/cubecl-matmul/src/components/line_size.rs b/crates/cubecl-matmul/src/components/line_size.rs index ad6e7eb36..200b2dfac 100644 --- a/crates/cubecl-matmul/src/components/line_size.rs +++ b/crates/cubecl-matmul/src/components/line_size.rs @@ -1,9 +1,9 @@ -use cubecl_core::{LineSizeError, Runtime, ir::StorageType, tensor_line_size_parallel}; +use cubecl_core::{LineSizeError, Runtime, tensor_line_size_parallel}; use crate::components::{MatrixLayout, error::MatmulSetupError}; use std::fmt::Debug; -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] /// Line size used for each tensor in global memory accesses. /// Represents the number of elements processed per SIMD load/store. pub struct MatmulLineSizes { @@ -24,11 +24,16 @@ pub struct AvailableLineSizes { } impl AvailableLineSizes { - pub fn from_types( - elem_lhs: &StorageType, - elem_rhs: &StorageType, - elem_out: &StorageType, - ) -> Self { + pub fn from_type_size_tma(elem_out: usize) -> Self { + // TMA requires line size 1 for inputs + AvailableLineSizes { + lhs: vec![1], + rhs: vec![1], + out: R::io_optimized_line_sizes_unchecked(elem_out).collect(), + } + } + + pub fn from_type_sizes(elem_lhs: usize, elem_rhs: usize, elem_out: usize) -> Self { AvailableLineSizes { lhs: R::io_optimized_line_sizes_unchecked(elem_lhs).collect(), rhs: R::io_optimized_line_sizes_unchecked(elem_rhs).collect(), diff --git a/crates/cubecl-matmul/src/components/problem.rs b/crates/cubecl-matmul/src/components/problem.rs index 70f0e2a58..5450d1825 100644 --- a/crates/cubecl-matmul/src/components/problem.rs +++ b/crates/cubecl-matmul/src/components/problem.rs @@ -17,6 +17,8 @@ pub struct MatmulProblem { pub lhs_batches: Vec, /// Batch shape for Rhs tensor pub rhs_batches: Vec, + /// Batch shape for Out tensor + pub out_batches: Vec, /// Memory layout of the Lhs matrix. pub lhs_layout: MatrixLayout, /// Memory layout of the Rhs matrix. 
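The `unit_write` helper extracted above walks a shared-memory tile line by line and writes each line into the global view. A minimal sketch of that indexing arithmetic, using plain slices and assumed sizes (8x8 tile, line size 4) instead of the real CubeCL `View`/`StridedTile` types:

// Illustration only: mirrors the num_lines / offset arithmetic of `unit_write`
// with plain slices. The sizes are assumptions, not values from the patch.
fn unit_write_sketch(global: &mut [f32], smem_tile: &[f32]) {
    let tile_rows = 8usize; // config.elements_in_tile_row (assumed)
    let tile_cols = 8usize; // config.elements_in_tile_col (assumed)
    let line_size = 4usize; // global.line_size() (assumed)

    let tile_size = tile_rows * tile_cols; // 64 elements per tile
    let num_lines = tile_size / line_size; // 64 / 4 = 16 vectorized writes

    for i in 0..num_lines {
        // One `write_checked((tile_pos, i * line_size), line)` in the kernel.
        let src = &smem_tile[i * line_size..(i + 1) * line_size];
        global[i * line_size..(i + 1) * line_size].copy_from_slice(src);
    }
}

Since `plane_write` now also reads the line size from the view rather than from the config, both writers stay correct when the output view is re-lined at runtime.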
@@ -145,10 +147,11 @@ impl From<&MatmulProblem> for MatmulKind { } } -#[derive(CubeType, Copy, Clone, PartialEq, Eq, Hash, Debug)] +#[derive(CubeType, Copy, Clone, PartialEq, Eq, Hash, Debug, Default)] /// Layout of a 2D structure such as a tensor, shared memory or slice, /// used within any matmul kernel level pub enum MatrixLayout { + #[default] RowMajor, ColMajor, } diff --git a/crates/cubecl-matmul/src/components/stage/base.rs b/crates/cubecl-matmul/src/components/stage/base.rs index a613c496a..906f587e3 100644 --- a/crates/cubecl-matmul/src/components/stage/base.rs +++ b/crates/cubecl-matmul/src/components/stage/base.rs @@ -51,7 +51,7 @@ pub trait StageMatmulFamily: Send + Sync + 'static { /// /// This function may return an error if the configuration cannot be supported on the current runtime. fn setup( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, selection: &MatmulSelection, line_sizes: &MatmulLineSizes, diff --git a/crates/cubecl-matmul/src/components/stage/matmul/partition/fragments.rs b/crates/cubecl-matmul/src/components/stage/matmul/partition/fragments.rs index aeeff4f39..585159c30 100644 --- a/crates/cubecl-matmul/src/components/stage/matmul/partition/fragments.rs +++ b/crates/cubecl-matmul/src/components/stage/matmul/partition/fragments.rs @@ -4,7 +4,7 @@ use crate::components::stage::StageConfig; use crate::components::{AccS, stage::Stage, tile::TileMatmul}; use crate::components::{MatmulPrecision, MatrixPrecision}; use cubecl::prelude::*; -use cubecl_core::{self as cubecl, intrinsic}; +use cubecl_core::{self as cubecl}; #[derive(CubeType)] /// Wrapper over a sequence of Tile Matmul accumulators @@ -62,8 +62,8 @@ impl< for m in 0..size_m { #[unroll] for n in 0..size_n { - let acc = self.get_at_mut(unwrap(m), unwrap(n), config); - let tile = R::tile(stage, (m, n)); + let acc = self.get_at_mut(m, n, config); + let tile = R::tile(stage, (m, n).runtime()); TM::load_acc(&tile, acc, config.tile_config()); } } @@ -100,9 +100,3 @@ pub enum RhsTile { Single(Rhs), Double((Rhs, Rhs)), } - -#[cube] -#[allow(unused)] -fn unwrap(i: u32) -> comptime_type!(u32) { - intrinsic!(|_| i.constant().unwrap().as_u32()) -} diff --git a/crates/cubecl-matmul/src/components/stage/matmul/partition/matmul.rs b/crates/cubecl-matmul/src/components/stage/matmul/partition/matmul.rs index 1ec61e997..941cce204 100644 --- a/crates/cubecl-matmul/src/components/stage/matmul/partition/matmul.rs +++ b/crates/cubecl-matmul/src/components/stage/matmul/partition/matmul.rs @@ -146,7 +146,6 @@ where let n_iterations = config.tiling_scheme().tiles_in_stage_partition_n(); let k_iterations = config.tiling_scheme().tiles_in_stage_partition_k(); - let mut k_iter = comptime![0u32]; let mut lhs_load_counter = comptime![0]; let mut rhs_load_counter = comptime![0]; let mut execute_counter = comptime![0]; @@ -154,15 +153,12 @@ where let rhs_load_total = comptime!(n_iterations * k_iterations); let execute_total = comptime!(m_iterations * n_iterations * k_iterations); - #[allow(clippy::explicit_counter_loop)] #[unroll] - for _ in 0..k_iterations { - let mut m_iter = comptime![0u32]; + for k_iter in 0..k_iterations { let k_load_iter = partition_scheduler.map_k(k_iter); - #[allow(clippy::explicit_counter_loop)] #[unroll] - for _ in 0..m_iterations { + for m_iter in 0..m_iterations { let m_load_iter = partition_scheduler.map_m(m_iter); let tile_lhs = StageLhs::tile(lhs_stage, (m_load_iter, k_load_iter)); @@ -180,15 +176,10 @@ where config, ); comptime!(lhs_load_counter += 1); - - comptime![m_iter += 1]; 
} - let mut n_iter = comptime![0u32]; - #[unroll] - #[allow(clippy::explicit_counter_loop)] - for _ in 0..n_iterations { + for n_iter in 0..n_iterations { let n_load_iter = partition_scheduler.map_n(n_iter); let rhs_tile_next = StageRhs::tile(rhs_stage, (k_load_iter, n_load_iter)); @@ -203,11 +194,8 @@ where ); comptime!(rhs_load_counter += 1); - let mut m_iter = comptime![0u32]; - - #[allow(clippy::explicit_counter_loop)] #[unroll] - for _ in 0..m_iterations { + for m_iter in 0..m_iterations { let accumulator = Accumulators::::get_at_mut(acc, m_iter, n_iter, config); TM::execute( @@ -225,14 +213,8 @@ where config, ); comptime!(execute_counter += 1); - - comptime![m_iter += 1]; } - - comptime![n_iter += 1]; } - - comptime![k_iter += 1]; } assert!(lhs_load_counter == lhs_load_total); @@ -261,8 +243,6 @@ where let n_iterations = config.tiling_scheme().tiles_in_stage_partition_n(); let k_iterations = config.tiling_scheme().tiles_in_stage_partition_k(); - let mut k_iter = comptime![0u32]; - let mut lhs_load_counter = comptime![0]; let mut rhs_load_counter = comptime![0]; let mut execute_counter = comptime![0]; @@ -270,15 +250,12 @@ where let rhs_load_total = comptime!(n_iterations * k_iterations); let execute_total = comptime!(m_iterations * n_iterations * k_iterations); - #[allow(clippy::explicit_counter_loop)] #[unroll] - for _ in 0..k_iterations { - let mut m_iter = comptime![0u32]; + for k_iter in 0..k_iterations { let k_load_iter = partition_scheduler.map_k(k_iter); - #[allow(clippy::explicit_counter_loop)] #[unroll] - for _ in 0..m_iterations { + for m_iter in 0..m_iterations { let m_load_iter = partition_scheduler.map_m(m_iter); let tile_lhs = StageLhs::tile(lhs_stage, (m_load_iter, k_load_iter)); @@ -296,8 +273,6 @@ where config, ); comptime!(lhs_load_counter += 1); - - comptime![m_iter += 1]; } let mut n_iter = comptime![0u32]; @@ -337,11 +312,8 @@ where ); comptime!(rhs_load_counter += 1); - let mut m_iter = comptime![0u32]; - - #[allow(clippy::explicit_counter_loop)] #[unroll] - for _ in 0..m_iterations { + for m_iter in 0..m_iterations { let accumulator = Accumulators::::get_at_mut(acc, m_iter, n_iter, config); @@ -360,8 +332,6 @@ where config, ); comptime!(execute_counter += 1); - - comptime![m_iter += 1]; } comptime![n_iter += 1]; @@ -373,11 +343,8 @@ where &mut rhs_fragments.1 }; - let mut m_iter = comptime![0u32]; - - #[allow(clippy::explicit_counter_loop)] #[unroll] - for _ in 0..m_iterations { + for m_iter in 0..m_iterations { let accumulator = Accumulators::::get_at_mut(acc, m_iter, n_iter, config); TM::execute( @@ -395,11 +362,7 @@ where config, ); comptime!(execute_counter += 1); - - comptime![m_iter += 1]; } - - comptime![k_iter += 1]; } assert!(lhs_load_counter == lhs_load_total); diff --git a/crates/cubecl-matmul/src/components/stage/matmul/partitioned_matmul.rs b/crates/cubecl-matmul/src/components/stage/matmul/partitioned_matmul.rs index b01bbdbea..dc8834ba9 100644 --- a/crates/cubecl-matmul/src/components/stage/matmul/partitioned_matmul.rs +++ b/crates/cubecl-matmul/src/components/stage/matmul/partitioned_matmul.rs @@ -173,20 +173,15 @@ where let m_iterations = global_config.tiling_scheme().tiles_in_stage_partition_m(); let n_iterations = global_config.tiling_scheme().tiles_in_stage_partition_n(); - let mut m_iter = comptime![0u32]; - W::on_event(listener, global::WriteEvent::new_Begin()); // Iterate over each tile in the partition #[unroll] - #[allow(clippy::explicit_counter_loop)] - for _ in 0..comptime![m_iterations] { + for m_iter in 0..m_iterations { let 
m_load_iter = partition_scheduler.map_m(m_iter); - let mut n_iter = comptime![0u32]; #[unroll] - #[allow(clippy::explicit_counter_loop)] - for _ in 0..comptime![n_iterations] { + for n_iter in 0..n_iterations { let n_load_iter = partition_scheduler.map_n(n_iter); let tile_accumulator = @@ -199,10 +194,7 @@ where // all tiles in the partition TM::write_results(&mut tile, tile_accumulator, stage_config.tile_config()); W::on_event(listener, global::WriteEvent::new_TileStored(tile_pos)); - - comptime![n_iter += 1]; } - comptime![m_iter += 1]; } W::on_event(listener, global::WriteEvent::new_Finish()); diff --git a/crates/cubecl-matmul/src/components/stage/matmul/plane_partitioned/setup.rs b/crates/cubecl-matmul/src/components/stage/matmul/plane_partitioned/setup.rs index f50343f8a..72625c4c9 100644 --- a/crates/cubecl-matmul/src/components/stage/matmul/plane_partitioned/setup.rs +++ b/crates/cubecl-matmul/src/components/stage/matmul/plane_partitioned/setup.rs @@ -66,7 +66,7 @@ impl< type Config = PlanePartitionedStageConfig; fn setup( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, selection: &MatmulSelection, line_sizes: &MatmulLineSizes, diff --git a/crates/cubecl-matmul/src/components/stage/matmul/unit_partitioned/setup.rs b/crates/cubecl-matmul/src/components/stage/matmul/unit_partitioned/setup.rs index 9e5abf219..bf71bd365 100644 --- a/crates/cubecl-matmul/src/components/stage/matmul/unit_partitioned/setup.rs +++ b/crates/cubecl-matmul/src/components/stage/matmul/unit_partitioned/setup.rs @@ -65,7 +65,7 @@ impl< type Config = UnitPartitionedStageConfig; fn setup( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, selection: &MatmulSelection, line_sizes: &MatmulLineSizes, diff --git a/crates/cubecl-matmul/src/components/stage/memory/layout.rs b/crates/cubecl-matmul/src/components/stage/memory/layout.rs index 7fadc1049..ff5d2f339 100644 --- a/crates/cubecl-matmul/src/components/stage/memory/layout.rs +++ b/crates/cubecl-matmul/src/components/stage/memory/layout.rs @@ -4,8 +4,10 @@ use cubecl_core::prelude::*; use cubecl_core::{self as cubecl}; use cubecl_std::tensor::layout::Coords2d; -use crate::components::stage::StageMemoryConfig; use crate::components::tile::StridedTile; +use crate::components::{ + InvalidConfigError, global::memory::GlobalMemoryConfig, stage::StageMemoryConfig, +}; use crate::components::{MatrixLayout, StageIdent}; use super::StridedStage; @@ -213,7 +215,7 @@ impl TilingOrder for OrderedTilingOrder { #[cube] /// Describes how tiles are arranged in shared memory. -pub trait TilingLayout: 'static + Send + Sync + Clone + Copy { +pub trait TilingLayout: 'static + Send + Sync + Clone + Copy + TilingValidation { /// Returns the tile at shared memory coordinates fn get_tile( stage: &StridedStage, @@ -224,6 +226,10 @@ pub trait TilingLayout: 'static + Send + Sync + Clone + Copy { ) -> StridedTile; } +pub trait TilingValidation { + fn check(config: GlobalMemoryConfig) -> Result<(), InvalidConfigError>; +} + #[derive(Clone, Copy)] /// Each tile is stored contiguously in shared memory. /// Global memory loads may require remapping to match this layout. 
@@ -333,6 +339,19 @@ impl TilingLayout for ContiguousTilingLayout { } } +impl TilingValidation for ContiguousTilingLayout { + fn check(config: GlobalMemoryConfig) -> Result<(), InvalidConfigError> { + let tile_width = match config.matrix_layout { + MatrixLayout::RowMajor => config.elements_in_tile_col, + MatrixLayout::ColMajor => config.elements_in_tile_row, + }; + if config.global_line_size > tile_width { + return Err(Box::new("Invalid line size")); + } + Ok(()) + } +} + #[cube] impl StridedTilingLayout { /// Returns the nth slice of the stage @@ -409,6 +428,19 @@ impl TilingLayout for StridedTilingLayout { } } +impl TilingValidation for StridedTilingLayout { + fn check(config: GlobalMemoryConfig) -> Result<(), InvalidConfigError> { + let stage_width = match config.matrix_layout { + MatrixLayout::RowMajor => config.elements_in_stage_col, + MatrixLayout::ColMajor => config.elements_in_stage_row, + }; + if config.global_line_size > stage_width { + return Err(Box::new("Invalid line size")); + } + Ok(()) + } +} + #[derive(Clone, Copy)] /// Dummy tiling layout that panics if it's used. Can be used when the reader is known to be a /// `FillReader` @@ -426,3 +458,9 @@ impl TilingLayout for NoTilingLayout { panic!("Can't get tile of layoutless tiling!") } } + +impl TilingValidation for NoTilingLayout { + fn check(_config: GlobalMemoryConfig) -> Result<(), InvalidConfigError> { + Ok(()) + } +} diff --git a/crates/cubecl-matmul/src/components/tile/base.rs b/crates/cubecl-matmul/src/components/tile/base.rs index 614244481..5e0d920e2 100644 --- a/crates/cubecl-matmul/src/components/tile/base.rs +++ b/crates/cubecl-matmul/src/components/tile/base.rs @@ -1,5 +1,6 @@ -use cubecl_core::prelude::*; use cubecl_core::{self as cubecl}; +use cubecl_core::{ir::StorageType, prelude::*}; +use cubecl_runtime::MmaConfig; use crate::components::error::MatmulSetupError; use crate::components::{ @@ -47,7 +48,7 @@ pub trait TileMatmulFamily: Send + Sync + 'static { /// /// This function may return an error if the configuration cannot be supported on the current runtime. fn setup( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, selection: &MatmulSelection, matmul_line_sizes: &MatmulLineSizes, @@ -59,6 +60,21 @@ pub trait TileMatmulFamily: Send + Sync + 'static { fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes { available_line_sizes } + + /// Returns whether a tile configuration is supported + fn is_supported(_client: &ComputeClient, _config: MmaConfig) -> bool { + !Self::requires_accelerator() + } + + /// Returns all sizes supported for these types, if any + fn supported_sizes( + _client: &ComputeClient, + _lhs_ty: StorageType, + _rhs_ty: StorageType, + _acc_ty: StorageType, + ) -> Vec { + Vec::new() + } } /// Provides matrix multiplication operations at the tile level. 
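The `TilingValidation` impls added in layout.rs above bound the global line size by the width of one contiguous run: a single tile row or column for `ContiguousTilingLayout`, and the whole stage for `StridedTilingLayout`. A standalone sketch of the contiguous rule, with simplified stand-ins for `GlobalMemoryConfig` and the error type:

// Illustration only: the width rule behind `ContiguousTilingLayout::check`.
enum Layout {
    RowMajor,
    ColMajor,
}

fn check_contiguous(layout: Layout, tile_rows: u32, tile_cols: u32, line_size: u32) -> Result<(), String> {
    // A line is one vectorized access, so it must fit inside a single
    // contiguous run: one tile row (RowMajor) or one tile column (ColMajor).
    let tile_width = match layout {
        Layout::RowMajor => tile_cols,
        Layout::ColMajor => tile_rows,
    };
    if line_size > tile_width {
        return Err(format!("line size {line_size} exceeds tile width {tile_width}"));
    }
    Ok(())
}

For example, a RowMajor stage of 16-element-wide tiles accepts line sizes up to 16 under the contiguous layout, while the strided layout only caps the line size at the full stage width.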
diff --git a/crates/cubecl-matmul/src/components/tile/accelerated/config.rs b/crates/cubecl-matmul/src/components/tile/cmma/config.rs similarity index 94% rename from crates/cubecl-matmul/src/components/tile/accelerated/config.rs rename to crates/cubecl-matmul/src/components/tile/cmma/config.rs index ac16441de..b5857833e 100644 --- a/crates/cubecl-matmul/src/components/tile/accelerated/config.rs +++ b/crates/cubecl-matmul/src/components/tile/cmma/config.rs @@ -9,7 +9,7 @@ use crate::components::{MatrixLayout, StageIdent, TileSize}; #[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] /// Configuration for Accelerated Matmul -pub struct AcceleratedConfig { +pub struct CmmaConfig { tile_size: TileSize, plane_dim: u32, lhs_layout: MatrixLayout, @@ -21,7 +21,7 @@ pub struct AcceleratedConfig { rhs_stage_line_size: u32, } -impl TileConfig for AcceleratedConfig { +impl TileConfig for CmmaConfig { fn plane_dim(&self) -> u32 { self.plane_dim } @@ -58,7 +58,7 @@ impl TileConfig for AcceleratedConfig { } } -impl AcceleratedConfig { +impl CmmaConfig { #[allow(clippy::too_many_arguments)] /// Create a new config for accelerated matmul /// @@ -66,7 +66,7 @@ impl AcceleratedConfig { /// - cmma is unavailable /// - cmma is unavailable for given types pub fn new( - client: &ComputeClient, + client: &ComputeClient, tile_size: TileSize, plane_dim: u32, lhs_layout: MatrixLayout, @@ -93,7 +93,7 @@ impl AcceleratedConfig { fn check_availability( self, - client: &ComputeClient, + client: &ComputeClient, ) -> Result { let lhs = Lhs::as_type_native_unchecked(); let rhs = Rhs::as_type_native_unchecked(); diff --git a/crates/cubecl-matmul/src/components/tile/accelerated/matmul.rs b/crates/cubecl-matmul/src/components/tile/cmma/matmul.rs similarity index 89% rename from crates/cubecl-matmul/src/components/tile/accelerated/matmul.rs rename to crates/cubecl-matmul/src/components/tile/cmma/matmul.rs index bfcb34798..fa06d0f80 100644 --- a/crates/cubecl-matmul/src/components/tile/accelerated/matmul.rs +++ b/crates/cubecl-matmul/src/components/tile/cmma/matmul.rs @@ -1,9 +1,9 @@ use std::marker::PhantomData; -use crate::components::tile::{TileConfig, TileMatmul, accelerated::reader::CmmaFragmentReader}; -use crate::components::tile::{accelerated::writer::CmmaStageWriter, tile_data::StridedTile}; +use crate::components::tile::{TileConfig, TileMatmul, cmma::reader::CmmaFragmentReader}; +use crate::components::tile::{cmma::writer::CmmaStageWriter, tile_data::StridedTile}; use crate::components::tile::{ - accelerated::{config::AcceleratedConfig, reader::CmmaStageReader}, + cmma::{config::CmmaConfig, reader::CmmaStageReader}, io::{Strided, TileKind}, }; use crate::components::{StageIdent, as_cmma_layout}; @@ -12,17 +12,17 @@ use cubecl_core::{cmma, prelude::*}; use cubecl_std::CubeOption; /// Uses one plane to perform a small matmul using accelerated instructions. 
-pub struct AcceleratedMatmul { +pub struct CmmaMatmul { _ty: PhantomData, } #[cube] impl TileMatmul - for AcceleratedMatmul + for CmmaMatmul where CmmaStageReader: CmmaFragmentReader, { - type Config = AcceleratedConfig; + type Config = CmmaConfig; type LhsFragment = cmma::Matrix; type RhsFragment = cmma::Matrix; type AccFragment = cmma::Matrix; diff --git a/crates/cubecl-matmul/src/components/tile/accelerated/mod.rs b/crates/cubecl-matmul/src/components/tile/cmma/mod.rs similarity index 100% rename from crates/cubecl-matmul/src/components/tile/accelerated/mod.rs rename to crates/cubecl-matmul/src/components/tile/cmma/mod.rs diff --git a/crates/cubecl-matmul/src/components/tile/accelerated/reader.rs b/crates/cubecl-matmul/src/components/tile/cmma/reader.rs similarity index 100% rename from crates/cubecl-matmul/src/components/tile/accelerated/reader.rs rename to crates/cubecl-matmul/src/components/tile/cmma/reader.rs diff --git a/crates/cubecl-matmul/src/components/tile/accelerated/setup.rs b/crates/cubecl-matmul/src/components/tile/cmma/setup.rs similarity index 53% rename from crates/cubecl-matmul/src/components/tile/accelerated/setup.rs rename to crates/cubecl-matmul/src/components/tile/cmma/setup.rs index 7f1550f11..961e99b79 100644 --- a/crates/cubecl-matmul/src/components/tile/accelerated/setup.rs +++ b/crates/cubecl-matmul/src/components/tile/cmma/setup.rs @@ -1,25 +1,26 @@ -use crate::components::tile::accelerated::config::AcceleratedConfig; -use crate::components::tile::accelerated::matmul::AcceleratedMatmul; +use crate::components::tile::cmma::matmul::CmmaMatmul; use crate::components::tile::{ TileMatmulFamily, - accelerated::reader::{CmmaFragmentReader, CmmaStageReader}, + cmma::reader::{CmmaFragmentReader, CmmaStageReader}, }; use crate::components::{InvalidConfigError, MatmulLineSizes, MatmulProblem, MatmulSelection}; +use crate::components::{TileSize, tile::cmma::config::CmmaConfig}; use crate::components::{error::MatmulSetupError, tile::io::Strided}; use crate::components::{resource::ComputeResources, tile::io::TileKind}; -use cubecl_core::prelude::*; +use cubecl_core::{ir::StorageType, prelude::*}; +use cubecl_runtime::MmaConfig; -impl TileMatmulFamily for AcceleratedMatmul +impl TileMatmulFamily for CmmaMatmul where CmmaStageReader: CmmaFragmentReader, { - type Matmul = AcceleratedMatmul; + type Matmul = CmmaMatmul; type LhsTile = Strided; type RhsTile = Strided; type AccTile = Tile; type OutTile = Strided; - type Config = AcceleratedConfig; + type Config = CmmaConfig; fn requires_accelerator() -> bool { true @@ -30,12 +31,12 @@ where } fn setup( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, selection: &MatmulSelection, matmul_line_sizes: &MatmulLineSizes, ) -> Result { - AcceleratedConfig::new::( + CmmaConfig::new::( client, selection.tiling_scheme.tile_size, selection.plane_dim, @@ -48,4 +49,24 @@ where matmul_line_sizes.rhs as u32, ) } + + fn is_supported(client: &ComputeClient, config: MmaConfig) -> bool { + client.properties().features.cmma.contains(&config) + } + + fn supported_sizes( + client: &ComputeClient, + lhs_ty: StorageType, + rhs_ty: StorageType, + acc_ty: StorageType, + ) -> Vec { + client + .properties() + .features + .cmma + .iter() + .filter(|it| it.a_type == lhs_ty && it.b_type == rhs_ty && it.cd_type == acc_ty) + .map(|it| (it.m, it.n, it.k).into()) + .collect() + } } diff --git a/crates/cubecl-matmul/src/components/tile/accelerated/writer.rs b/crates/cubecl-matmul/src/components/tile/cmma/writer.rs similarity index 100% 
rename from crates/cubecl-matmul/src/components/tile/accelerated/writer.rs rename to crates/cubecl-matmul/src/components/tile/cmma/writer.rs diff --git a/crates/cubecl-matmul/src/components/tile/mma/config.rs b/crates/cubecl-matmul/src/components/tile/mma/config.rs index 537cf8d34..d97e757f9 100644 --- a/crates/cubecl-matmul/src/components/tile/mma/config.rs +++ b/crates/cubecl-matmul/src/components/tile/mma/config.rs @@ -66,7 +66,7 @@ impl MmaMatmulConfig { /// - cmma is unavailable /// - cmma is unavailable for given types pub fn new( - client: &ComputeClient, + client: &ComputeClient, tile_size: TileSize, plane_dim: u32, lhs_layout: MatrixLayout, @@ -93,7 +93,7 @@ impl MmaMatmulConfig { fn check_availability( self, - client: &ComputeClient, + client: &ComputeClient, ) -> Result { let lhs = Lhs::as_type_native_unchecked(); let rhs = Rhs::as_type_native_unchecked(); diff --git a/crates/cubecl-matmul/src/components/tile/mma/setup.rs b/crates/cubecl-matmul/src/components/tile/mma/setup.rs index bb51a09fd..697267bc9 100644 --- a/crates/cubecl-matmul/src/components/tile/mma/setup.rs +++ b/crates/cubecl-matmul/src/components/tile/mma/setup.rs @@ -1,15 +1,19 @@ -use crate::components::tile::{ - TileMatmulFamily, - mma::{ - MmaMatmul, - config::MmaMatmulConfig, - reader::{MmaFragmentReader, MmaStageReader}, +use crate::components::{InvalidConfigError, MatmulLineSizes, MatmulProblem, MatmulSelection}; +use crate::components::{ + TileSize, + tile::{ + TileMatmulFamily, + mma::{ + MmaMatmul, + config::MmaMatmulConfig, + reader::{MmaFragmentReader, MmaStageReader}, + }, }, }; -use crate::components::{InvalidConfigError, MatmulLineSizes, MatmulProblem, MatmulSelection}; use crate::components::{error::MatmulSetupError, tile::io::Strided}; use crate::components::{resource::ComputeResources, tile::io::TileKind}; -use cubecl_core::prelude::*; +use cubecl_core::{ir::StorageType, prelude::*}; +use cubecl_runtime::MmaConfig; impl TileMatmulFamily for MmaMatmul where @@ -32,7 +36,7 @@ where } fn setup( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, selection: &MatmulSelection, matmul_line_sizes: &MatmulLineSizes, @@ -50,4 +54,24 @@ where matmul_line_sizes.rhs as u32, ) } + + fn is_supported(client: &ComputeClient, config: MmaConfig) -> bool { + client.properties().features.mma.contains(&config) + } + + fn supported_sizes( + client: &ComputeClient, + lhs_ty: StorageType, + rhs_ty: StorageType, + acc_ty: StorageType, + ) -> Vec { + client + .properties() + .features + .mma + .iter() + .filter(|it| it.a_type == lhs_ty && it.b_type == rhs_ty && it.cd_type == acc_ty) + .map(|it| (it.m, it.n, it.k).into()) + .collect() + } } diff --git a/crates/cubecl-matmul/src/components/tile/mod.rs b/crates/cubecl-matmul/src/components/tile/mod.rs index c7aca4636..b68785ff6 100644 --- a/crates/cubecl-matmul/src/components/tile/mod.rs +++ b/crates/cubecl-matmul/src/components/tile/mod.rs @@ -1,7 +1,7 @@ //! Matrix multiplication on register- or shared-memory tiles. //! Optimized for fixed shapes and low-level compute strategies. 
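Both the renamed cmma family and the mma family implement the new `is_supported` and `supported_sizes` queries by filtering the MMA configurations the runtime reports. A rough sketch of how a selector could consume the reported shapes; `MmaShape` and `pick_tile` are stand-ins for illustration, not the cubecl API:

// Illustration only: choose a preferred (m, n) shape if the hardware
// reports it, otherwise fall back to the first supported shape.
#[derive(Clone, Copy, PartialEq)]
struct MmaShape { m: u32, n: u32, k: u32 }

fn pick_tile(supported: &[MmaShape], prefer_m: u32, prefer_n: u32) -> Option<(u32, u32, u32)> {
    supported
        .iter()
        .find(|s| s.m == prefer_m && s.n == prefer_n)
        .or_else(|| supported.first())
        .map(|s| (s.m, s.n, s.k))
}

This mirrors the direction of `find_instruction_size` later in the patch, which probes a few preferred shapes through `is_supported` before falling back to the first entry of `supported_sizes`.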
-pub mod accelerated; +pub mod cmma; pub mod io; pub mod mma; pub mod plane_vec_mat_inner_product; diff --git a/crates/cubecl-matmul/src/components/tile/plane_vec_mat_inner_product/config.rs b/crates/cubecl-matmul/src/components/tile/plane_vec_mat_inner_product/config.rs index aed951517..4aa4bfd63 100644 --- a/crates/cubecl-matmul/src/components/tile/plane_vec_mat_inner_product/config.rs +++ b/crates/cubecl-matmul/src/components/tile/plane_vec_mat_inner_product/config.rs @@ -77,7 +77,7 @@ impl PlaneVecMatInnerProductConfig { /// - Line sizes do not evenly divide tile sizes in the lined axis /// - Types are unavailable pub fn new( - client: &ComputeClient, + client: &ComputeClient, tiling_scheme: TilingScheme, plane_dim: u32, lhs_layout: MatrixLayout, @@ -154,7 +154,7 @@ impl PlaneVecMatInnerProductConfig { fn check_availability( self, - client: &ComputeClient, + client: &ComputeClient, ) -> Result { if !client.properties().features.plane.contains(Plane::Ops) { return Err(MatmulSetupError::Unavailable( diff --git a/crates/cubecl-matmul/src/components/tile/plane_vec_mat_inner_product/setup.rs b/crates/cubecl-matmul/src/components/tile/plane_vec_mat_inner_product/setup.rs index 175ec14ad..6409a8788 100644 --- a/crates/cubecl-matmul/src/components/tile/plane_vec_mat_inner_product/setup.rs +++ b/crates/cubecl-matmul/src/components/tile/plane_vec_mat_inner_product/setup.rs @@ -30,7 +30,7 @@ where } fn setup( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, selection: &MatmulSelection, matmul_line_sizes: &MatmulLineSizes, diff --git a/crates/cubecl-matmul/src/components/tile/register/config.rs b/crates/cubecl-matmul/src/components/tile/register/config.rs index 9c18a6632..ae71fb820 100644 --- a/crates/cubecl-matmul/src/components/tile/register/config.rs +++ b/crates/cubecl-matmul/src/components/tile/register/config.rs @@ -9,6 +9,7 @@ use crate::components::tile::TileConfig; use crate::components::{MatrixLayout, StageIdent, TileSize}; /// Execution mode for the RegisterMatmul +#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] pub enum ProductType { /// Computes the Tile Matmul as m*n inner products of length k. /// @@ -81,7 +82,7 @@ impl RegisterConfig { /// - Line sizes do not evenly divide tile sizes in the lined axis /// - Types are unavailable pub fn new( - client: &ComputeClient, + client: &ComputeClient, tile_size: TileSize, plane_dim: u32, lhs_layout: MatrixLayout, @@ -108,8 +109,25 @@ impl RegisterConfig { } pub fn product_type(&self) -> ProductType { - // TODO: Make it configurable. 
- ProductType::Outer + let lhs_preferred = match self.lhs_layout { + MatrixLayout::RowMajor => ProductType::Inner, + MatrixLayout::ColMajor => ProductType::Outer, + }; + let rhs_preferred = match self.rhs_layout { + MatrixLayout::RowMajor => ProductType::Outer, + MatrixLayout::ColMajor => ProductType::Inner, + }; + + if lhs_preferred == rhs_preferred { + lhs_preferred + } else if self.tile_size.m() == 1 { + rhs_preferred + } else if self.tile_size.n() == 1 { + lhs_preferred + } else { + // No better solution + ProductType::Outer + } } fn validate(self) -> Result { @@ -125,14 +143,14 @@ impl RegisterConfig { MatrixLayout::RowMajor => { if !k.is_multiple_of(lhs) { return Err(MatmulSetupError::InvalidConfig(Box::new(format!( - "Tile shape in lined axis {k:?} should be divisible by line size {lhs:?}" + "Tile shape in lined axis k({k:?}) should be divisible by line size lhs({lhs:?})" )))); } } MatrixLayout::ColMajor => { if !m.is_multiple_of(lhs) { return Err(MatmulSetupError::InvalidConfig(Box::new(format!( - "Tile shape in lined axis {m:?} should be divisible by line size {lhs:?}" + "Tile shape in lined axis m({m:?}) should be divisible by line size lhs({lhs:?})" )))); } } @@ -141,14 +159,14 @@ impl RegisterConfig { MatrixLayout::RowMajor => { if !n.is_multiple_of(rhs) { return Err(MatmulSetupError::InvalidConfig(Box::new(format!( - "Tile shape in lined axis {n:?} should be divisible by line size {rhs:?}" + "Tile shape in lined axis n({n:?}) should be divisible by line size rhs({rhs:?})" )))); } } MatrixLayout::ColMajor => { if !k.is_multiple_of(rhs) { return Err(MatmulSetupError::InvalidConfig(Box::new(format!( - "Tile shape in lined axis {k:?} should be divisible by line size {rhs:?}" + "Tile shape in lined axis k({k:?}) should be divisible by line size rhs({rhs:?})" )))); } } @@ -156,7 +174,7 @@ if !n.is_multiple_of(out) { return Err(MatmulSetupError::InvalidConfig(Box::new(format!( - "Tile shape in lined axis {n:?} should be divisible by line size {out:?}" + "Tile shape in lined axis n({n:?}) should be divisible by line size out({out:?})" )))); } @@ -165,7 +183,7 @@ fn check_availability( self, - client: &ComputeClient, + client: &ComputeClient, ) -> Result { let lhs = Lhs::as_type_native_unchecked(); let rhs = Rhs::as_type_native_unchecked(); diff --git a/crates/cubecl-matmul/src/components/tile/register/matmul.rs b/crates/cubecl-matmul/src/components/tile/register/matmul.rs index d082f26a7..02cf13611 100644 --- a/crates/cubecl-matmul/src/components/tile/register/matmul.rs +++ b/crates/cubecl-matmul/src/components/tile/register/matmul.rs @@ -158,7 +158,7 @@ impl RegisterMatmul { #[unroll(UNROLL)] for line_within_segment in 0..num_lines_per_segment { let line = tile.get_line(segment, line_within_segment); - #[unroll(UNROLL)] + #[unroll] for pos_within_line in 0..line_size { array[segment * segment_size + line_within_segment * line_size @@ -182,7 +182,7 @@ impl RegisterMatmul { #[unroll(UNROLL)] for line_within_segment in 0..num_lines_per_segment { let line = tile.get_line(segment, line_within_segment); - #[unroll(UNROLL)] + #[unroll] for pos_within_line in 0..line_size { array[(line_within_segment * line_size + pos_within_line) * num_segments + segment] = ER::cast_from(line[pos_within_line]); diff --git a/crates/cubecl-matmul/src/components/tile/register/setup.rs b/crates/cubecl-matmul/src/components/tile/register/setup.rs index 7bfbfef8f..cb168a5a3 100644 --- a/crates/cubecl-matmul/src/components/tile/register/setup.rs +++
b/crates/cubecl-matmul/src/components/tile/register/setup.rs @@ -32,7 +32,7 @@ where } fn setup( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, selection: &MatmulSelection, matmul_line_sizes: &MatmulLineSizes, @@ -53,8 +53,5 @@ where fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes { available_line_sizes - .filter_lhs(|ls| *ls <= 4) - .filter_rhs(|ls| *ls <= 4) - .filter_out(|ls| *ls <= 4) } } diff --git a/crates/cubecl-matmul/src/components/tile/register/writer.rs b/crates/cubecl-matmul/src/components/tile/register/writer.rs index f6b95121f..a9a823062 100644 --- a/crates/cubecl-matmul/src/components/tile/register/writer.rs +++ b/crates/cubecl-matmul/src/components/tile/register/writer.rs @@ -21,7 +21,7 @@ impl RegisterStageWriter { #[unroll(UNROLL)] for i in 0..comptime!(config.tile_size.mn() / out_line_size) { let mut line = Line::empty(out_line_size); - #[unroll(UNROLL)] + #[unroll] for j in 0..comptime!(out_line_size) { line[j] = acc[i * out_line_size + j]; } diff --git a/crates/cubecl-matmul/src/kernels/layered/algorithm/base.rs b/crates/cubecl-matmul/src/kernels/layered/algorithm/base.rs index 3c90e1dab..07adeca1f 100644 --- a/crates/cubecl-matmul/src/kernels/layered/algorithm/base.rs +++ b/crates/cubecl-matmul/src/kernels/layered/algorithm/base.rs @@ -17,7 +17,7 @@ pub trait Algorithm { type BatchMatmul: BatchMatmulFamily; fn setup( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, selection: &MatmulSelection, line_sizes: &MatmulLineSizes, @@ -26,7 +26,7 @@ pub trait Algorithm { } fn selection( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, plane_dim: u32, line_sizes: &MatmulLineSizes, @@ -42,7 +42,7 @@ pub trait Algorithm { )) } - fn select_plane_dim(client: &ComputeClient) -> u32 { + fn select_plane_dim(client: &ComputeClient) -> u32 { client.properties().hardware.plane_size_max } } diff --git a/crates/cubecl-matmul/src/kernels/layered/algorithm/double_buffering.rs b/crates/cubecl-matmul/src/kernels/layered/algorithm/double_buffering.rs index 9f96cbc78..f08f67e01 100644 --- a/crates/cubecl-matmul/src/kernels/layered/algorithm/double_buffering.rs +++ b/crates/cubecl-matmul/src/kernels/layered/algorithm/double_buffering.rs @@ -69,7 +69,7 @@ where PartitionedBatchMatmulFamily; fn selection( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, plane_dim: u32, _line_sizes: &MatmulLineSizes, @@ -120,7 +120,7 @@ where PartitionedBatchMatmulFamily; fn selection( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, plane_dim: u32, _line_sizes: &MatmulLineSizes, @@ -170,7 +170,7 @@ where PartitionedBatchMatmulFamily; fn selection( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, plane_dim: u32, _line_sizes: &MatmulLineSizes, diff --git a/crates/cubecl-matmul/src/kernels/layered/algorithm/double_unit.rs b/crates/cubecl-matmul/src/kernels/layered/algorithm/double_unit.rs index fbe4de0f3..99a870c36 100644 --- a/crates/cubecl-matmul/src/kernels/layered/algorithm/double_unit.rs +++ b/crates/cubecl-matmul/src/kernels/layered/algorithm/double_unit.rs @@ -39,10 +39,10 @@ impl Algorithm for DoubleUnitAlgorithm { PartitionedBatchMatmulFamily; fn selection( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, plane_dim: u32, - _line_sizes: &MatmulLineSizes, + line_sizes: &MatmulLineSizes, _elems: MatmulElems, args: &Self::SelectionArgs, ) -> Result { @@ -51,6 +51,7 @@ impl Algorithm 
for DoubleUnitAlgorithm { problem, plane_dim, true, + line_sizes, UnitMatmulSelectionOptions { tile: args.tile_size, ..Default::default() @@ -58,7 +59,7 @@ impl Algorithm for DoubleUnitAlgorithm { )) } - fn select_plane_dim(client: &ComputeClient) -> u32 { + fn select_plane_dim(client: &ComputeClient) -> u32 { client.properties().hardware.plane_size_min } } diff --git a/crates/cubecl-matmul/src/kernels/layered/algorithm/ordered_double_buffering.rs b/crates/cubecl-matmul/src/kernels/layered/algorithm/ordered_double_buffering.rs index f84329445..6c97c9fb3 100644 --- a/crates/cubecl-matmul/src/kernels/layered/algorithm/ordered_double_buffering.rs +++ b/crates/cubecl-matmul/src/kernels/layered/algorithm/ordered_double_buffering.rs @@ -60,7 +60,7 @@ where PartitionedBatchMatmulFamily; fn selection( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, plane_dim: u32, _line_sizes: &MatmulLineSizes, diff --git a/crates/cubecl-matmul/src/kernels/layered/algorithm/simple.rs b/crates/cubecl-matmul/src/kernels/layered/algorithm/simple.rs index 6f7d07593..4654d9328 100644 --- a/crates/cubecl-matmul/src/kernels/layered/algorithm/simple.rs +++ b/crates/cubecl-matmul/src/kernels/layered/algorithm/simple.rs @@ -67,7 +67,7 @@ where PartitionedBatchMatmulFamily; fn selection( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, plane_dim: u32, _line_sizes: &MatmulLineSizes, @@ -93,20 +93,23 @@ where } fn selection_multi_rows( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, plane_dim: u32, elems: MatmulElems, ) -> Result { let supported = |m: u32, n: u32, k: u32| { - client.properties().features.cmma.contains(&MmaConfig { - a_type: elems.lhs_register, - b_type: elems.rhs_register, - cd_type: elems.acc_register, - m, - n, - k, - }) + TMM::is_supported::( + client, + MmaConfig { + a_type: elems.lhs_register, + b_type: elems.rhs_register, + cd_type: elems.acc_register, + m, + n, + k, + }, + ) }; let cube_count_plan = match client.properties().hardware.num_streaming_multiprocessors { Some(num_sms) => CubeCountPlanSelection::Sm { diff --git a/crates/cubecl-matmul/src/kernels/layered/algorithm/simple_barrier.rs b/crates/cubecl-matmul/src/kernels/layered/algorithm/simple_barrier.rs index a98767faa..cae96888e 100644 --- a/crates/cubecl-matmul/src/kernels/layered/algorithm/simple_barrier.rs +++ b/crates/cubecl-matmul/src/kernels/layered/algorithm/simple_barrier.rs @@ -49,7 +49,7 @@ where PartitionedBatchMatmulFamily; fn selection( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, plane_dim: u32, _line_sizes: &MatmulLineSizes, diff --git a/crates/cubecl-matmul/src/kernels/layered/algorithm/simple_tma.rs b/crates/cubecl-matmul/src/kernels/layered/algorithm/simple_tma.rs index 91cb11472..859fd7e65 100644 --- a/crates/cubecl-matmul/src/kernels/layered/algorithm/simple_tma.rs +++ b/crates/cubecl-matmul/src/kernels/layered/algorithm/simple_tma.rs @@ -39,7 +39,7 @@ where PartitionedBatchMatmulFamily; fn selection( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, plane_dim: u32, _line_sizes: &MatmulLineSizes, diff --git a/crates/cubecl-matmul/src/kernels/layered/algorithm/simple_unit.rs b/crates/cubecl-matmul/src/kernels/layered/algorithm/simple_unit.rs index 002bfe1b0..260a35773 100644 --- a/crates/cubecl-matmul/src/kernels/layered/algorithm/simple_unit.rs +++ b/crates/cubecl-matmul/src/kernels/layered/algorithm/simple_unit.rs @@ -55,10 +55,10 @@ where PartitionedBatchMatmulFamily; fn selection( 
- client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, plane_dim: u32, - _line_sizes: &MatmulLineSizes, + line_sizes: &MatmulLineSizes, _elems: MatmulElems, args: &Self::SelectionArgs, ) -> Result { @@ -67,6 +67,7 @@ where problem, plane_dim, false, + line_sizes, UnitMatmulSelectionOptions { tile: args.tile_size, stage: match args.tile_size { @@ -81,7 +82,7 @@ where )) } - fn select_plane_dim(client: &ComputeClient) -> u32 { + fn select_plane_dim(client: &ComputeClient) -> u32 { client.properties().hardware.plane_size_min } } diff --git a/crates/cubecl-matmul/src/kernels/layered/algorithm/vecmat.rs b/crates/cubecl-matmul/src/kernels/layered/algorithm/vecmat.rs index 03a355ad6..3c14ef8f9 100644 --- a/crates/cubecl-matmul/src/kernels/layered/algorithm/vecmat.rs +++ b/crates/cubecl-matmul/src/kernels/layered/algorithm/vecmat.rs @@ -47,7 +47,7 @@ impl Algorithm for SimpleVecMatAlgorithm { PartitionedBatchMatmulFamily; fn selection( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, plane_dim: u32, line_sizes: &MatmulLineSizes, @@ -84,7 +84,7 @@ impl Algorithm for DoubleVecMatAlgorithm { PartitionedBatchMatmulFamily; fn selection( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, plane_dim: u32, line_sizes: &MatmulLineSizes, @@ -101,7 +101,7 @@ impl Algorithm for DoubleVecMatAlgorithm { } fn selection_vecmat( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, tile_size: TileSize, plane_dim: u32, diff --git a/crates/cubecl-matmul/src/kernels/layered/base.rs b/crates/cubecl-matmul/src/kernels/layered/base.rs index 67179348f..34ba0fe73 100644 --- a/crates/cubecl-matmul/src/kernels/layered/base.rs +++ b/crates/cubecl-matmul/src/kernels/layered/base.rs @@ -1,23 +1,25 @@ -use crate::components::global::args::TensorArgs; use crate::components::{ AccG, AccS, batch::{BatchMatmulFamily, CubeCountInputArgs}, + global::args::{TensorArgs, TensorMapArgs}, }; use crate::components::{ AvailableLineSizes, InputRuntimeArg, LhsG, LhsS, MatmulAvailabilityError, MatmulLineSizes, - MatmulPrecision, MatmulProblem, MatmulSelection, MatmulSetupError, MatmulSpec, MatrixLayout, - OutputRuntimeArg, RhsG, RhsS, + MatmulProblem, MatmulSelection, MatmulSetupError, MatmulSpec, MatrixLayout, OutputRuntimeArg, + RhsG, RhsS, }; -use crate::components::{global::args::TensorMapArgs, tile::TileMatmulFamily}; +use crate::components::{ + InputArg, OutputArg, + global::args::{ConcreteInputsFactory, ConcreteOutputFactory}, +}; +use crate::components::{MatmulPrecision, tile::TileMatmulFamily}; use crate::kernels::layered::selector::launch_kernel_concrete; use crate::{MatmulInputHandle, MatmulInputHandleRef}; use core::any::TypeId; +use cubecl_core::prelude::*; use cubecl_core::{Runtime, client::ComputeClient, frontend::TensorHandleRef}; -use cubecl_core::{prelude::*, try_tensor_line_size_parallel}; use cubecl_runtime::TypeUsage; -use cubecl_std::tensor::{ - MatrixBatchLayout, TensorHandle, into_contiguous_pitched, matrix_batch_layout, -}; +use cubecl_std::tensor::{MatrixBatchLayout, TensorHandle, matrix_batch_layout}; use super::Algorithm; @@ -54,7 +56,7 @@ impl Default for Selection { /// Will fail if unavailable #[allow(clippy::result_large_err)] pub fn launch( - client: &ComputeClient, + client: &ComputeClient, lhs: MatmulInputHandle>, rhs: MatmulInputHandle>, out: TensorHandle>, @@ -80,7 +82,7 @@ pub fn launch( /// otherwise it will fall back on a non-cmma implementation #[allow(clippy::result_large_err)] pub fn launch_ref( - client: 
&ComputeClient, + client: &ComputeClient, lhs: &MatmulInputHandleRef<'_, R>, rhs: &MatmulInputHandleRef<'_, R>, out: &TensorHandleRef<'_, R>, @@ -101,58 +103,113 @@ pub fn launch_ref( let lhs_owned; let rhs_owned; let lhs = if lhs_make_contiguous { - lhs_owned = match lhs { - MatmulInputHandleRef::Normal(data) => { - MatmulInputHandle::Normal(into_contiguous_pitched::>(client, data)) - } - MatmulInputHandleRef::Quantized { .. } => unimplemented!(), - }; + lhs_owned = lhs.into_contiguous::>(client); &lhs_owned.as_ref() } else { lhs }; let rhs = if rhs_make_contiguous { - rhs_owned = match rhs { - MatmulInputHandleRef::Normal(data) => { - MatmulInputHandle::Normal(into_contiguous_pitched::>(client, data)) - } - MatmulInputHandleRef::Quantized { .. } => unimplemented!(), - }; + rhs_owned = rhs.into_contiguous::>(client); &rhs_owned.as_ref() } else { rhs }; - launch_inner_ref::( + let line_sizes = AvailableLineSizes::from_type_sizes::( + lhs.data().elem_size, + rhs.data().elem_size, + out.elem_size, + ); + + launch_inner_ref::( client, lhs, rhs, out, (lhs_transposed, rhs_transposed), selection, + line_sizes, + ) +} + +/// Launch a matrix multiplication kernel, with TMA restrictions enabled. +/// TMA doesn't support permuted batches, so checks are slightly different. +/// +/// Cmma will be used if available and enabled, +/// otherwise it will fall back on a non-cmma implementation +#[allow(clippy::result_large_err)] +pub fn launch_ref_tma( + client: &ComputeClient, + lhs: &MatmulInputHandleRef<'_, R>, + rhs: &MatmulInputHandleRef<'_, R>, + out: &TensorHandleRef<'_, R>, + selection: &Selection, +) -> Result<(), MatmulSetupError> { + let check_layout = |tensor: &TensorHandleRef<'_, R>| match matrix_batch_layout(tensor.strides) { + MatrixBatchLayout::Contiguous => (false, false), + MatrixBatchLayout::MildlyPermuted { + transposed, + batch_swap: false, + } => (false, transposed), + _ => (true, false), + }; + + let (lhs_make_contiguous, lhs_transposed) = check_layout(lhs.data()); + let (rhs_make_contiguous, rhs_transposed) = check_layout(rhs.data()); + + let lhs_owned; + let rhs_owned; + let lhs = if lhs_make_contiguous { + lhs_owned = lhs.into_contiguous::>(client); + &lhs_owned.as_ref() + } else { + lhs + }; + let rhs = if rhs_make_contiguous { + rhs_owned = rhs.into_contiguous::>(client); + &rhs_owned.as_ref() + } else { + rhs + }; + + let line_sizes = AvailableLineSizes::from_type_size_tma::(out.elem_size); + + launch_inner_ref::( + client, + lhs, + rhs, + out, + (lhs_transposed, rhs_transposed), + selection, + line_sizes, ) } #[allow(clippy::result_large_err, clippy::too_many_arguments)] -fn launch_inner_ref( - client: &ComputeClient, +fn launch_inner_ref( + client: &ComputeClient, lhs_handle: &MatmulInputHandleRef<'_, R>, rhs_handle: &MatmulInputHandleRef<'_, R>, out: &TensorHandleRef<'_, R>, transposed: (bool, bool), selection: &Selection, -) -> Result<(), MatmulSetupError> { - let lhs = lhs_handle.data(); - let rhs = rhs_handle.data(); - - let rank = lhs.strides.len(); - let lhs_elem = LhsG::::as_type_native().expect("To be a native type"); - let rhs_elem = RhsG::::as_type_native().expect("To be a native type"); - let acc_elem = AccG::::as_type_native().expect("To be a native type"); - - if !LhsG::::supported_uses(client).contains(TypeUsage::Conversion) - || !RhsG::::supported_uses(client).contains(TypeUsage::Conversion) - || !AccG::::supported_uses(client).contains(TypeUsage::Conversion) + line_sizes: AvailableLineSizes, +) -> Result<(), MatmulSetupError> +where + InputArg: 
ConcreteInputsFactory, + OutputArg: ConcreteOutputFactory, +{ + let lhs_shape = lhs_handle.shape(); + let rhs_shape = rhs_handle.shape(); + + let rank = lhs_shape.len(); + let lhs_elem = LhsG::::as_type_native().expect("To be a native type"); + let rhs_elem = RhsG::::as_type_native().expect("To be a native type"); + let acc_elem = AccG::::as_type_native().expect("To be a native type"); + + if !LhsG::::supported_uses(client).contains(TypeUsage::Conversion) + || !RhsG::::supported_uses(client).contains(TypeUsage::Conversion) + || !AccG::::supported_uses(client).contains(TypeUsage::Conversion) { return Err(MatmulSetupError::Unavailable( MatmulAvailabilityError::TypesUnavailable { @@ -163,9 +220,9 @@ fn launch_inner_ref( )); } - let m = lhs.shape[rank - 2] as u32; - let k = lhs.shape[rank - 1] as u32; - let n = rhs.shape[rank - 1] as u32; + let m = lhs_shape[rank - 2] as u32; + let k = lhs_shape[rank - 1] as u32; + let n = rhs_shape[rank - 1] as u32; let lhs_layout = match transposed.0 { true => MatrixLayout::ColMajor, @@ -181,20 +238,32 @@ fn launch_inner_ref( m: m as usize, n: n as usize, k: k as usize, - lhs_batches: lhs.shape[..lhs.shape.len() - 2].to_vec(), - rhs_batches: rhs.shape[..rhs.shape.len() - 2].to_vec(), + lhs_batches: lhs_shape[..lhs_shape.len() - 2].to_vec(), + rhs_batches: rhs_shape[..rhs_shape.len() - 2].to_vec(), + out_batches: out.shape[..out.shape.len() - 2].to_vec(), lhs_layout, rhs_layout, }; - let line_sizes = AvailableLineSizes::from_types::(&lhs_elem, &rhs_elem, &acc_elem); + let lhs = lhs_handle.data(); + let rhs = rhs_handle.data(); + let line_sizes = A::filter_line_sizes(line_sizes); - let line_sizes = line_sizes + let mut line_sizes = line_sizes .filter_lhs_with_tensor(lhs.strides, lhs.shape, problem.lhs_layout) .filter_rhs_with_tensor(rhs.strides, rhs.shape, problem.rhs_layout) .filter_out_with_tensor(out.strides, out.shape) .pick_max()?; + // The large line size resulting from dequantizing ends up slower due to restrictions on + // algorithms. Use this as a quick and dirty fix. + if lhs_handle.scale().is_some() { + line_sizes.lhs = 1; + } + if rhs_handle.scale().is_some() { + line_sizes.rhs = 1; + } + let fix_plane_dim = |plane_dim: u32| { // Sometimes the GPU doesn't support plane instructions and doesn't report the // plane size, but we can still execute algorithms that don't use plane instructions. 
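Line-size selection above now starts from type sizes, filters candidates against each tensor's actual strides and shape, takes the maximum, and finally clamps quantized inputs back to 1. A rough sketch of the filter-then-pick-max step, assuming the simple rule that a candidate must sit on a contiguous innermost axis and divide its length:

// Illustration only: keep candidates compatible with the tensor's innermost
// axis, then take the largest survivor (the assumed filtering rule).
fn pick_line_size(candidates: &[u8], inner_shape: usize, inner_stride: usize) -> Option<u8> {
    candidates
        .iter()
        .copied()
        .filter(|&ls| inner_stride == 1 && inner_shape % ls as usize == 0)
        .max()
}

Resetting `line_sizes.lhs`/`rhs` to 1 after the fact is deliberately coarse, as the comment says: dequantized inputs would otherwise pick a large line size that locks out the faster algorithms.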
@@ -206,14 +275,14 @@ fn launch_inner_ref( let plane_dim = fix_plane_dim(A::select_plane_dim::(client)); - launch_inner_ref_fix_dtype::( + launch_inner_ref_fix_dtype::( client, lhs_handle, rhs_handle, out, problem, line_sizes, plane_dim, selection, ) } #[allow(clippy::result_large_err, clippy::too_many_arguments)] -fn launch_inner_ref_fix_dtype( - client: &ComputeClient, +fn launch_inner_ref_fix_dtype( + client: &ComputeClient, lhs: &MatmulInputHandleRef<'_, R>, rhs: &MatmulInputHandleRef<'_, R>, out: &TensorHandleRef<'_, R>, @@ -221,16 +290,23 @@ fn launch_inner_ref_fix_dtype( line_sizes: MatmulLineSizes, plane_dim: u32, selection: &Selection, -) -> Result<(), MatmulSetupError> { +) -> Result<(), MatmulSetupError> +where + InputArg: ConcreteInputsFactory, + OutputArg: ConcreteOutputFactory, +{ if ::requires_accelerator() && tf32::supported_uses(client).contains(TypeUsage::Conversion) { match ( - TypeId::of::>() == TypeId::of::(), - TypeId::of::>() == TypeId::of::(), + TypeId::of::>() == TypeId::of::(), + TypeId::of::>() == TypeId::of::(), ) { (true, true) => launch_kernel_concrete::< - ((f32, f32, AccG, tf32, tf32, AccS), TensorArgs), + ( + (LhsG, RhsG, AccG, tf32, tf32, AccS), + MS::Args, + ), R, A, >( @@ -238,8 +314,8 @@ fn launch_inner_ref_fix_dtype( ), (true, false) => launch_kernel_concrete::< ( - (f32, RhsG, AccG, tf32, RhsS, AccS), - TensorArgs, + (LhsG, RhsG, AccG, tf32, RhsS, AccS), + MS::Args, ), R, A, @@ -248,136 +324,28 @@ fn launch_inner_ref_fix_dtype( ), (false, true) => launch_kernel_concrete::< ( - (LhsG, f32, AccG, LhsS, tf32, AccS), - TensorArgs, + (LhsG, RhsG, AccG, LhsS, tf32, AccS), + MS::Args, ), R, A, >( client, lhs, rhs, out, problem, line_sizes, plane_dim, selection, ), - (false, false) => launch_kernel_concrete::<(MP, TensorArgs), R, A>( + (false, false) => launch_kernel_concrete::( client, lhs, rhs, out, problem, line_sizes, plane_dim, selection, ), } } else { - launch_kernel_concrete::<(MP, TensorArgs), R, A>( + launch_kernel_concrete::( client, lhs, rhs, out, problem, line_sizes, plane_dim, selection, ) } } -#[allow(clippy::result_large_err, clippy::too_many_arguments)] -pub fn matmul_cmma_tma_ref_no_check( - client: &ComputeClient, - lhs_handle: &MatmulInputHandleRef<'_, R>, - rhs_handle: &MatmulInputHandleRef<'_, R>, - out: &TensorHandleRef<'_, R>, - transposed: (bool, bool), - selection: &Selection, -) -> Result<(), MatmulSetupError> { - let lhs = lhs_handle.data(); - let rhs = rhs_handle.data(); - - let rank = lhs.strides.len(); - let out_elem = AccG::::as_type_native().expect("To be a native type"); - - let m = lhs.shape[rank - 2] as u32; - let k = lhs.shape[rank - 1] as u32; - let n = rhs.shape[rank - 1] as u32; - - let lhs_layout = match transposed.0 { - true => MatrixLayout::ColMajor, - false => MatrixLayout::RowMajor, - }; - let rhs_layout = match transposed.1 { - true => MatrixLayout::ColMajor, - false => MatrixLayout::RowMajor, - }; - - let line_sizes = MatmulLineSizes { - lhs: 1, - rhs: 1, - out: try_tensor_line_size_parallel( - R::io_optimized_line_sizes_unchecked(&out_elem), - out.shape, - out.strides, - rank - 1, - )?, - }; - - let batch_lhs: usize = lhs.shape[..lhs.shape.len() - 2].iter().product(); - let batch_rhs: usize = rhs.shape[..rhs.shape.len() - 2].iter().product(); - - let problem = MatmulProblem { - m: m as usize, - n: n as usize, - k: k as usize, - lhs_batches: [batch_lhs].to_vec(), - rhs_batches: [batch_rhs].to_vec(), - lhs_layout, - rhs_layout, - }; - - let plane_size = client.properties().hardware.plane_size_max; - - let plane_dim 
= match plane_size { - 32 | 64 => plane_size, - _ => { - return Err(MatmulSetupError::Unavailable( - MatmulAvailabilityError::PlaneDimUnsupported { - plane_dim: plane_size, - }, - )); - } - }; - - if tf32::supported_uses(client).contains(TypeUsage::Conversion) { - match ( - TypeId::of::>() == TypeId::of::(), - TypeId::of::>() == TypeId::of::(), - ) { - (true, true) => launch_kernel_concrete::< - ((f32, f32, AccG, tf32, tf32, AccS), TensorMapArgs), - R, - A, - >( - client, lhs_handle, rhs_handle, out, problem, line_sizes, plane_dim, selection, - ), - (true, false) => launch_kernel_concrete::< - ( - (f32, RhsG, AccG, tf32, RhsS, AccS), - TensorMapArgs, - ), - R, - A, - >( - client, lhs_handle, rhs_handle, out, problem, line_sizes, plane_dim, selection, - ), - (false, true) => launch_kernel_concrete::< - ( - (LhsG, f32, AccG, LhsS, tf32, AccS), - TensorMapArgs, - ), - R, - A, - >( - client, lhs_handle, rhs_handle, out, problem, line_sizes, plane_dim, selection, - ), - (false, false) => launch_kernel_concrete::<(MP, TensorMapArgs), R, A>( - client, lhs_handle, rhs_handle, out, problem, line_sizes, plane_dim, selection, - ), - } - } else { - launch_kernel_concrete::<(MP, TensorMapArgs), R, A>( - client, lhs_handle, rhs_handle, out, problem, line_sizes, plane_dim, selection, - ) - } -} - #[allow(clippy::too_many_arguments, clippy::result_large_err)] pub fn launch_with_config<'a, MS: MatmulSpec, R: Runtime, A: Algorithm>( - client: &ComputeClient, + client: &ComputeClient, cube_dim: CubeDim, cube_count: CubeCount, input: InputRuntimeArg<'a, MS, R>, diff --git a/crates/cubecl-matmul/src/kernels/layered/mod.rs b/crates/cubecl-matmul/src/kernels/layered/mod.rs index d39725de4..5f25b96b3 100644 --- a/crates/cubecl-matmul/src/kernels/layered/mod.rs +++ b/crates/cubecl-matmul/src/kernels/layered/mod.rs @@ -3,7 +3,7 @@ mod base; mod selector; pub use algorithm::*; -pub use base::{Selection, launch, launch_ref, launch_with_config, matmul_cmma_tma_ref_no_check}; +pub use base::{Selection, launch, launch_ref, launch_ref_tma, launch_with_config}; pub use selector::{ NUM_SM_APPROX, NUM_TENSOR_CORES_APPROX, TileSizeSelection, find_instruction_size, launch_kernel_concrete, launch_kernel_virtual, diff --git a/crates/cubecl-matmul/src/kernels/layered/selector/plane.rs b/crates/cubecl-matmul/src/kernels/layered/selector/plane.rs index 86d6197e6..58af2ec2c 100644 --- a/crates/cubecl-matmul/src/kernels/layered/selector/plane.rs +++ b/crates/cubecl-matmul/src/kernels/layered/selector/plane.rs @@ -1,6 +1,5 @@ -use cubecl_core::ir::StorageType; use cubecl_core::{Runtime, client::ComputeClient}; -use cubecl_runtime::{DeviceProperties, MmaConfig}; +use cubecl_runtime::MmaConfig; use crate::components::batch::{ CubeCountPlanSelection, GlobalOrderSelection, HypercubeSelection, SmAllocation, @@ -30,7 +29,7 @@ pub struct PlaneMatmulSelectionOptions { } pub fn plane_matmul_selection( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, plane_dim: u32, elems: MatmulElems, @@ -42,18 +41,7 @@ pub fn plane_matmul_selection( )); } - let tile_size = find_instruction_size( - if TMM::requires_accelerator() { - Some(( - client.properties(), - (elems.lhs_register, elems.rhs_register, elems.acc_register), - )) - } else { - None - }, - problem.m, - problem.n, - ); + let tile_size = find_instruction_size::(client, &elems, problem.m, problem.n); if options.tiny_selection_enabled && is_tiny(problem, &tile_size) { return Ok(selection_tiny::(client, problem, tile_size, plane_dim)); @@ -68,6 +56,12 @@ pub fn 
plane_matmul_selection( max_plane_per_cube / (4 * precision_factor) }); + if row_count == 0 { + return Err(MatmulSetupError::Unavailable( + MatmulAvailabilityError::PlaneDimUnsupported { plane_dim }, + )); + } + let (rows_per_plane, stage_size_m, partition_shape_n) = select_size( options.multi_row_strategy, row_count as usize, @@ -158,24 +152,24 @@ fn select_size( /// /// Will use 16x16 for balanced matrices, and 32x8 or 8x32 for degenerated ones. #[allow(clippy::type_complexity)] -pub fn find_instruction_size( - properties: Option<(&DeviceProperties, (StorageType, StorageType, StorageType))>, +pub fn find_instruction_size( + client: &ComputeClient, + elems: &MatmulElems, m: usize, n: usize, ) -> TileSize { let supported = |m: u32, n: u32, k: u32| { - properties - .map(|(p, (a_type, b_type, cd_type))| { - p.features.cmma.contains(&MmaConfig { - a_type, - b_type, - cd_type, - m, - n, - k, - }) - }) - .unwrap_or(true) + TMM::is_supported::( + client, + MmaConfig { + a_type: elems.lhs_register, + b_type: elems.rhs_register, + cd_type: elems.acc_register, + m, + n, + k, + }, + ) }; if m >= 4 * n && supported(32, 8, 16) { @@ -187,12 +181,20 @@ pub fn find_instruction_size( } else if supported(8, 8, 8) { (8, 8, 8).into() } else { - (16, 16, 8).into() + TMM::supported_sizes::( + client, + elems.lhs_register, + elems.rhs_register, + elems.acc_register, + ) + .first() + .copied() + .unwrap_or_else(|| (16, 16, 8).into()) } } fn selection_tiny( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, tile_size: TileSize, plane_dim: u32, diff --git a/crates/cubecl-matmul/src/kernels/layered/selector/select_kernel.rs b/crates/cubecl-matmul/src/kernels/layered/selector/select_kernel.rs index f1b83b2e9..53d2d5c63 100644 --- a/crates/cubecl-matmul/src/kernels/layered/selector/select_kernel.rs +++ b/crates/cubecl-matmul/src/kernels/layered/selector/select_kernel.rs @@ -17,7 +17,7 @@ use cubecl_core::{Runtime, client::ComputeClient}; /// Only works for concrete tensor inputs and output. #[allow(clippy::result_large_err, clippy::too_many_arguments)] pub fn launch_kernel_concrete( - client: &ComputeClient, + client: &ComputeClient, lhs: &MatmulInputHandleRef<'_, R>, rhs: &MatmulInputHandleRef<'_, R>, out: &TensorHandleRef<'_, R>, @@ -32,32 +32,48 @@ where { let elems = MatmulElems::new::(); + let mut view_line_sizes = line_sizes; + + if let MatmulInputHandleRef::Quantized { scheme, .. } = lhs { + view_line_sizes.lhs *= scheme.num_quants() as u8; + } + if let MatmulInputHandleRef::Quantized { scheme, .. } = rhs { + view_line_sizes.rhs *= scheme.num_quants() as u8; + } + let selection = match selection { Selection::Forced(selection) => selection.clone(), Selection::Inferred(args) => { - A::selection::(client, &problem, plane_dim, &line_sizes, elems, args)? + A::selection::(client, &problem, plane_dim, &view_line_sizes, elems, args)? 
} }; - let config = A::setup::(client, &problem, &selection, &line_sizes)?; + let config = A::setup::(client, &problem, &selection, &view_line_sizes)?; let cube_count_plan = config.hypercube_config().cube_count_plan( &problem, client.properties().hardware.max_cube_count.clone(), ); - let line_sizes = config.line_sizes(); - launch_with_config::( client, config.cube_dim(), cube_count_plan.resolve(), as ConcreteInputsFactory>::create( + client, lhs, rhs, &selection, &problem, &line_sizes, + config, + ), + as ConcreteOutputFactory>::create( + client, + out, + &selection, + &problem, + &line_sizes, + config, ), - as ConcreteOutputFactory>::create(out, &selection, &problem, &line_sizes), cube_count_plan.as_args(), config, ) @@ -65,11 +81,11 @@ where /// Select which kernel to launch for the given Algorithm. pub fn launch_kernel_virtual<'a, MS: MatmulSpec, R: Runtime, A: Algorithm>( - client: &ComputeClient, + client: &ComputeClient, input: InputRuntimeArg<'a, MS, R>, output: OutputRuntimeArg<'a, MS, R>, problem: MatmulProblem, - line_sizes: MatmulLineSizes, + view_line_sizes: MatmulLineSizes, plane_dim: u32, selection: &Selection, ) -> Result<(), MatmulSetupError> { @@ -78,10 +94,10 @@ pub fn launch_kernel_virtual<'a, MS: MatmulSpec, R: Runtime, A: Algorithm>( let selection = match selection { Selection::Forced(selection) => selection.clone(), Selection::Inferred(args) => { - A::selection::(client, &problem, plane_dim, &line_sizes, elems, args)? + A::selection::(client, &problem, plane_dim, &view_line_sizes, elems, args)? } }; - let config = A::setup::(client, &problem, &selection, &line_sizes)?; + let config = A::setup::(client, &problem, &selection, &view_line_sizes)?; let cube_count_plan = config.hypercube_config().cube_count_plan( &problem, diff --git a/crates/cubecl-matmul/src/kernels/layered/selector/unit.rs b/crates/cubecl-matmul/src/kernels/layered/selector/unit.rs index 32ed6ba24..01867f265 100644 --- a/crates/cubecl-matmul/src/kernels/layered/selector/unit.rs +++ b/crates/cubecl-matmul/src/kernels/layered/selector/unit.rs @@ -1,10 +1,9 @@ -use cubecl_core::{Runtime, client::ComputeClient}; - use crate::components::{ - MatmulKind, MatmulProblem, MatmulSelection, MatrixLayout, TilingScheme, + MatmulKind, MatmulLineSizes, MatmulProblem, MatmulSelection, MatrixLayout, TilingScheme, batch::{CubeCountPlanSelection, GlobalOrderSelection, HypercubeSelection, SmAllocation}, stage::PartitionBuffering, }; +use cubecl_core::{Runtime, client::ComputeClient}; #[derive(Default, Clone, Copy, Debug)] pub enum TileSizeSelection { @@ -38,37 +37,58 @@ pub struct UnitMatmulSelectionOptions { /// Computes a [MatmulSelection] depending on the problem kind pub fn unit_matmul_selection( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, plane_dim: u32, double_buffering: bool, + line_size: &MatmulLineSizes, options: UnitMatmulSelectionOptions, ) -> MatmulSelection { let kind: MatmulKind = problem.into(); let num_sms = client.properties().hardware.num_streaming_multiprocessors; + let min_tile_size = u8::max(line_size.lhs, line_size.rhs); + let min_tile_size = u8::max(line_size.out, min_tile_size) as u32; + let tile_size = u32::max(min_tile_size, 4); match kind { - MatmulKind::General => { - general_unit_selector(problem, plane_dim, double_buffering, num_sms, options) - } - MatmulKind::MatVec => { - matvec_unit_selector(problem, plane_dim, double_buffering, num_sms, options) - } - MatmulKind::VecMat => vecmat_unit_selector(problem, plane_dim, double_buffering, num_sms), + 
MatmulKind::General => general_unit_selector( + problem, + plane_dim, + double_buffering, + tile_size, + num_sms, + options, + ), + MatmulKind::MatVec => matvec_unit_selector( + problem, + plane_dim, + double_buffering, + tile_size, + num_sms, + options, + ), + MatmulKind::VecMat => vecmat_unit_selector( + problem, + plane_dim, + double_buffering, + tile_size, + num_sms, + options, + ), MatmulKind::ScalarVec => { - scalarvec_unit_selector(problem, plane_dim, double_buffering, num_sms) + scalarvec_unit_selector(problem, plane_dim, double_buffering, tile_size, num_sms) } MatmulKind::VecScalar => { - vecscalar_unit_selector(problem, plane_dim, double_buffering, num_sms) + vecscalar_unit_selector(problem, plane_dim, double_buffering, tile_size, num_sms) } MatmulKind::InnerProduct => { - inner_product_unit_selector(problem, plane_dim, double_buffering, num_sms) + inner_product_unit_selector(problem, plane_dim, double_buffering, tile_size, num_sms) } MatmulKind::OuterProduct => { - outer_product_unit_selector(problem, plane_dim, double_buffering, num_sms) + outer_product_unit_selector(problem, plane_dim, double_buffering, tile_size, num_sms) } MatmulKind::ScalarProduct => { - scalar_product_unit_selector(problem, plane_dim, double_buffering, num_sms) + scalar_product_unit_selector(problem, plane_dim, double_buffering, tile_size, num_sms) } } } @@ -78,6 +98,7 @@ fn general_unit_selector( problem: &MatmulProblem, plane_dim: u32, double_buffering: bool, + tile_size: u32, num_sms: Option, options: UnitMatmulSelectionOptions, ) -> MatmulSelection { @@ -87,7 +108,7 @@ fn general_unit_selector( let (tile_size, mut partition_size) = match (problem.lhs_layout, problem.rhs_layout, options.tile) { (RowMajor, _, TileSizeSelection::MinTileSize) => ( - (1, 4, 4), + (1, tile_size, tile_size), ( scale_partition(options.partition, problem.m, 4, 9), 2, @@ -95,11 +116,11 @@ fn general_unit_selector( ), ), (ColMajor, RowMajor, TileSizeSelection::MinTileSize) => ( - (4, 4, 1), + (tile_size, tile_size, 1), (2, 2, scale_partition(options.partition, problem.k, 3, 10)), ), (ColMajor, ColMajor, _) | (_, _, TileSizeSelection::MaxTileSize) => ( - (4, 4, 4), + (tile_size, tile_size, tile_size), ( scale_partition(options.partition, problem.m, 2, 9), 2, @@ -108,9 +129,16 @@ fn general_unit_selector( ), }; - // It seems to be faster, it's not a requirement of the algo. 
- if double_buffering && partition_size.2 > 2 { - partition_size.2 /= 2; + let mut num_plane = 8; + + if double_buffering { + if partition_size.0 > 2 { + partition_size.0 /= 2; + } + if partition_size.2 > 2 { + partition_size.2 /= 2; + } + num_plane /= 2; } selection( @@ -120,7 +148,7 @@ fn general_unit_selector( plane_dim, StageSelection::WithPlane { plane_dim, - num_plane: 8, + num_plane, }, num_sms, GlobalOrderSelection::SwizzleRow { @@ -136,13 +164,13 @@ fn matvec_unit_selector( problem: &MatmulProblem, plane_dim: u32, _double_buffering: bool, + tile_size: u32, num_sms: Option, - options: UnitMatmulSelectionOptions, + _options: UnitMatmulSelectionOptions, ) -> MatmulSelection { - use MatrixLayout::*; - let (tile_size, partition_size) = match (problem.lhs_layout, problem.rhs_layout, options.tile) { - (RowMajor, _, TileSizeSelection::MinTileSize) => ((1, 1, 4), (1, 1, 4)), - _ => ((4, 1, 4), (1, 1, 4)), + let (tile_size, partition_size) = match (problem.lhs_layout, problem.rhs_layout) { + (MatrixLayout::RowMajor, _) => ((1, 1, tile_size), (1, 1, tile_size * 2)), + _ => ((tile_size, 1, tile_size), (1, 1, 1)), }; selection( @@ -150,44 +178,36 @@ fn matvec_unit_selector( partition_size, PartitionBuffering::Single, plane_dim, - StageSelection::Fixed { m: 8, n: 8 }, + StageSelection::Fixed { + m: plane_dim / 2, + n: 2, + }, num_sms, GlobalOrderSelection::Default, - options.stage, + StageScaling::Disabled, ) } /// (1, K) @ (K, N) → (1, N) fn vecmat_unit_selector( - problem: &MatmulProblem, + _problem: &MatmulProblem, plane_dim: u32, _double_buffering: bool, + tile_size: u32, num_sms: Option, + _options: UnitMatmulSelectionOptions, ) -> MatmulSelection { - use MatrixLayout::*; - let (tile_size, partition_size) = match (problem.lhs_layout, problem.rhs_layout) { - (RowMajor, RowMajor) => ((1, 4, 4), (1, 1, 4)), - (RowMajor, ColMajor) => ( - (1, 4, 4), - (2, 1, scale_partition(Default::default(), problem.k, 3, 7)), - ), - (ColMajor, RowMajor) => ((1, 4, 4), (1, 1, 4)), - (ColMajor, ColMajor) => ( - (1, 4, 4), - ( - 2, - 1, - scale_partition(PartitionScaling::Enabled, problem.k, 3, 7), - ), - ), - }; + let (tile_size, partition_size) = ((1, tile_size, tile_size), (1, 1, 1)); selection( tile_size, partition_size, PartitionBuffering::Single, plane_dim, - StageSelection::Fixed { m: 8, n: 8 }, + StageSelection::Fixed { + m: 2, + n: plane_dim / 2, + }, num_sms, GlobalOrderSelection::Default, StageScaling::Disabled, @@ -199,14 +219,15 @@ fn scalarvec_unit_selector( problem: &MatmulProblem, plane_dim: u32, _double_buffering: bool, + tile_size: u32, num_sms: Option, ) -> MatmulSelection { use MatrixLayout::*; let (tile_size, partition_size) = match (problem.lhs_layout, problem.rhs_layout) { - (RowMajor, RowMajor) => ((1, 4, 4), (1, 2, 1)), - (RowMajor, ColMajor) => ((1, 4, 4), (1, 2, 1)), - (ColMajor, RowMajor) => ((1, 4, 4), (1, 2, 1)), - (ColMajor, ColMajor) => ((1, 4, 4), (2, 2, 1)), + (RowMajor, RowMajor) => ((1, tile_size, tile_size), (1, 2, 1)), + (RowMajor, ColMajor) => ((1, tile_size, tile_size), (1, 2, 1)), + (ColMajor, RowMajor) => ((1, tile_size, tile_size), (1, 2, 1)), + (ColMajor, ColMajor) => ((1, tile_size, tile_size), (2, 2, 1)), }; selection( @@ -214,7 +235,10 @@ fn scalarvec_unit_selector( partition_size, PartitionBuffering::Single, plane_dim, - StageSelection::Fixed { m: 4, n: 8 }, + StageSelection::Fixed { + m: 2, + n: plane_dim / 2, + }, num_sms, GlobalOrderSelection::Default, StageScaling::Disabled, @@ -226,16 +250,20 @@ fn vecscalar_unit_selector( _problem: &MatmulProblem, plane_dim: 
u32, _double_buffering: bool, + tile_size: u32, num_sms: Option, ) -> MatmulSelection { - let (tile_size, partition_size) = ((4, 1, 4), (1, 1, 2)); + let (tile_size, partition_size) = ((tile_size, 1, 1), (1, 1, 1)); selection( tile_size, partition_size, PartitionBuffering::Single, plane_dim, - StageSelection::Fixed { m: 8, n: 4 }, + StageSelection::Fixed { + m: plane_dim / 2, + n: 2, + }, num_sms, GlobalOrderSelection::Default, StageScaling::Disabled, @@ -247,14 +275,15 @@ fn inner_product_unit_selector( problem: &MatmulProblem, plane_dim: u32, _double_buffering: bool, + tile_size: u32, num_sms: Option, ) -> MatmulSelection { use MatrixLayout::*; let (tile_size, partition_size) = match (problem.lhs_layout, problem.rhs_layout) { - (RowMajor, RowMajor) => ((1, 1, 4), (1, 1, 1)), - (RowMajor, ColMajor) => ((1, 1, 4), (1, 1, 1)), - (ColMajor, RowMajor) => ((1, 1, 4), (1, 1, 1)), - (ColMajor, ColMajor) => ((1, 1, 4), (1, 1, 1)), + (RowMajor, RowMajor) => ((1, 1, tile_size), (1, 1, 1)), + (RowMajor, ColMajor) => ((1, 1, tile_size), (1, 1, 1)), + (ColMajor, RowMajor) => ((1, 1, tile_size), (1, 1, 1)), + (ColMajor, ColMajor) => ((1, 1, tile_size), (1, 1, 1)), }; selection( @@ -262,7 +291,7 @@ fn inner_product_unit_selector( partition_size, PartitionBuffering::Single, plane_dim, - StageSelection::Fixed { m: 4, n: 8 }, + StageSelection::Fixed { m: plane_dim, n: 1 }, // TODO: most planes does nothing. num_sms, GlobalOrderSelection::Default, StageScaling::Disabled, @@ -274,9 +303,10 @@ fn outer_product_unit_selector( _problem: &MatmulProblem, plane_dim: u32, _double_buffering: bool, + tile_size: u32, num_sms: Option, ) -> MatmulSelection { - let (tile_size, partition_size) = ((4, 4, 1), (1, 1, 1)); + let (tile_size, partition_size) = ((tile_size, tile_size, 1), (1, 1, 1)); selection( tile_size, @@ -295,6 +325,7 @@ fn scalar_product_unit_selector( _problem: &MatmulProblem, plane_dim: u32, _double_buffering: bool, + _tile_size: u32, num_sms: Option, ) -> MatmulSelection { let (tile_size, partition_size) = ((1, 1, 1), (1, 1, 1)); @@ -306,7 +337,7 @@ fn scalar_product_unit_selector( plane_dim, StageSelection::WithPlane { plane_dim, - num_plane: 8, + num_plane: 1, }, num_sms, GlobalOrderSelection::Default, diff --git a/crates/cubecl-matmul/src/kernels/naive.rs b/crates/cubecl-matmul/src/kernels/naive.rs index 716373e5d..9f383d220 100644 --- a/crates/cubecl-matmul/src/kernels/naive.rs +++ b/crates/cubecl-matmul/src/kernels/naive.rs @@ -5,67 +5,84 @@ use cubecl::prelude::*; use cubecl_core::{ self as cubecl, ir::{ElemType, IntKind, UIntKind}, + tensor_line_size_parallel, }; -use cubecl_std::tensor::{MatrixBatchLayout, TensorHandle, into_contiguous, matrix_batch_layout}; +use cubecl_std::tensor::{ + MatrixBatchLayout, View, launch::ViewArg, layout::Coords3d, matrix_batch_layout, +}; -use crate::components::{MatmulAvailabilityError, MatmulSetupError}; +use crate::{ + MatmulInputHandle, MatmulInputHandleRef, + components::{ + MatmulAvailabilityError, MatmulProblem, MatmulSetupError, MatrixLayout, + global::memory::{GlobalLayout, GlobalLayoutConfig, GlobalLayoutLaunch, GlobalScaleLayout}, + }, +}; + +#[cube] +fn load_unrolled( + view: &View, Coords3d>, + pos: Coords3d, + #[comptime] layout: MatrixLayout, + #[comptime] line_size: u32, +) -> Line { + comptime![assert!(line_size <= view.line_size())]; + let view_line_size = view.line_size(); + if comptime![view.line_size() == line_size] { + view[pos] + } else { + let (b, row, col) = pos; + let mut out = Line::empty(line_size); + #[unroll] + for i in 
range_stepped(0, line_size, view_line_size) { + let pos = match layout { + MatrixLayout::RowMajor => (b, row, col + i), + MatrixLayout::ColMajor => (b, row + i, col), + }; + let value = view[pos]; + #[unroll] + for n in 0..view_line_size { + out[i + n] = value[n]; + } + } + out + } +} #[cube(launch_unchecked)] fn matmul_kernel( - lhs: &Tensor>, - rhs: &Tensor>, + lhs: &View, Coords3d>, + rhs: &View, Coords3d>, out: &mut Tensor, - // number of dimensions not involved in the matmul - #[comptime] num_batches: Option, ) { let rank = out.rank(); - let end = num_batches.unwrap_or_else(|| rank - 2); - let unroll = num_batches.is_some(); - let n_rows = lhs.shape(rank - 2); - let n_cols = rhs.shape(rank - 1); - let mut k = rhs.shape(rank - 2); + let (_, _, k) = lhs.shape(); + let size_m = out.shape(rank - 2); + let size_n = out.shape(rank - 1); - let batch_pos = ABSOLUTE_POS_Z; - let row = CUBE_DIM_X * CUBE_POS_X + UNIT_POS_X; - let col = CUBE_DIM_Y * CUBE_POS_Y + UNIT_POS_Y; + let batch = ABSOLUTE_POS_Z; + let m = ABSOLUTE_POS_X; + let n = ABSOLUTE_POS_Y; - if row >= n_rows || col >= n_cols { + if m >= size_m || n >= size_n { terminate!(); } - let line_size = lhs.line_size(); - - let mut offset_lhs = 0; - let mut offset_rhs = 0; - let offset_out = batch_pos * out.stride(rank - 2) * out.shape(rank - 2); - - #[unroll(unroll)] - for i in 0..end { - let ogwl = offset_out / out.stride(i); - - offset_lhs += ogwl % lhs.shape(i) * lhs.stride(i); - offset_rhs += ogwl % rhs.shape(i) * rhs.stride(i); - } - - offset_lhs /= line_size.runtime(); - offset_rhs /= line_size.runtime(); + let offset_out = batch * out.stride(rank - 2) * out.shape(rank - 2); + let line_size = comptime![Ord::max(lhs.line_size(), rhs.line_size())]; let mut sum = Line::empty(line_size).fill(O::from_int(0)); - k /= line_size.runtime(); - - for i in 0..k { - let lhs_index = row * lhs.stride(rank - 2) / line_size + i + offset_lhs; - let rhs_index = col * rhs.stride(rank - 1) / line_size + i + offset_rhs; + for k in range_stepped(0, k, line_size) { + let lhs = load_unrolled(lhs, (batch, m, k), MatrixLayout::RowMajor, line_size); + let rhs = load_unrolled(rhs, (batch, k, n), MatrixLayout::ColMajor, line_size); - sum += Line::cast_from( - Line::::cast_from(lhs[lhs_index]) * Line::::cast_from(rhs[rhs_index]), - ); + sum += Line::cast_from(Line::::cast_from(lhs) * Line::::cast_from(rhs)); } - let mut out_index = row * out.stride(rank - 2) + col; + let mut out_index = m * out.stride(rank - 2) + n; out_index += offset_out; let unroll_sum = line_size != 1; @@ -86,82 +103,95 @@ fn matmul_kernel( /// Matrix multiplication using memory coalescing algorithm with custom cube dimensions #[allow(clippy::result_large_err)] -pub fn launch_ref( - client: &ComputeClient, - lhs: &TensorHandleRef<'_, R>, - rhs: &TensorHandleRef<'_, R>, +pub fn launch( + client: &ComputeClient, + lhs: MatmulInputHandle, + rhs: MatmulInputHandle, out: &TensorHandleRef<'_, R>, ) -> Result<(), MatmulSetupError> { - let lhs = TensorHandle::::from_ref(lhs); - let rhs = TensorHandle::::from_ref(rhs); - - launch::(client, lhs, rhs, out) + launch_ref::(client, &lhs.as_ref(), &rhs.as_ref(), out) } #[allow(clippy::result_large_err)] -pub fn launch( - client: &ComputeClient, - lhs: TensorHandle, - rhs: TensorHandle, +pub fn launch_ref( + client: &ComputeClient, + lhs: &MatmulInputHandleRef<'_, R>, + rhs: &MatmulInputHandleRef<'_, R>, out: &TensorHandleRef<'_, R>, ) -> Result<(), MatmulSetupError> { let (cube_dim_x, cube_dim_y) = (32, 8); - let ndims = lhs.shape.len(); - let dim1 = ndims - 
1; - let dim2 = ndims - 2; + let rank = lhs.shape().len(); + let dim1 = rank - 1; + let dim2 = rank - 2; - let lhs_layout = matrix_batch_layout(&lhs.strides); - let rhs_layout = matrix_batch_layout(&rhs.strides); + let lhs_layout = matrix_batch_layout(lhs.data().strides); + let rhs_layout = matrix_batch_layout(rhs.data().strides); let lhs = if !matches!(lhs_layout, MatrixBatchLayout::Contiguous) { - into_contiguous::(client, &lhs.as_ref()) + lhs.into_contiguous::(client) } else { - lhs + MatmulInputHandle::from_ref(lhs) }; + let lhs = lhs.as_ref(); + let rhs = MatmulInputHandle::from_ref(rhs); // we swap the dimensions to achieve memory-coalescing: // consecutive elements of a column in the original rhs tensor will now be stored // consecutively in memory, which allows to fetch them with fewer memory instructions - let correct_rhs_layout = |mut rhs: TensorHandle| { - let rhs_original_shape = rhs.shape.to_vec(); - rhs.strides.swap(dim1, dim2); - rhs.shape.swap(dim1, dim2); - - let mut rhs = into_contiguous::(client, &rhs.as_ref()); + let correct_rhs_layout = |mut rhs: MatmulInputHandle| { + rhs.swap_dims(dim1, dim2); - rhs.strides.swap(dim1, dim2); - rhs.shape.swap(dim1, dim2); + let mut rhs = rhs.as_ref().into_contiguous::(client); - (rhs_original_shape, rhs) + rhs.swap_dims(dim1, dim2); + rhs }; - let (rhs_original_shape, rhs) = match rhs_layout { + let rhs = match rhs_layout { MatrixBatchLayout::Contiguous => correct_rhs_layout(rhs), MatrixBatchLayout::MildlyPermuted { transposed, batch_swap, } => { if transposed && !batch_swap { - let rhs_original_shape = rhs.shape.to_vec(); - (rhs_original_shape, rhs) + rhs } else { correct_rhs_layout(rhs) } } MatrixBatchLayout::HighlyPermuted => correct_rhs_layout(rhs), }; - - let cube_count = simple_cube_count( - &lhs.shape, - &rhs_original_shape, - out.shape, - cube_dim_x, - cube_dim_y, - )?; - - let vectorization_factor = match lhs.shape[ndims - 1] % 4 == 0 { - true => 4, - false => 1, + let rhs = rhs.as_ref(); + + let lhs_shape = lhs.shape(); + let rhs_shape = rhs.shape(); + let out_shape = out.shape; + + let cube_count = simple_cube_count(lhs_shape, rhs_shape, out_shape, cube_dim_x, cube_dim_y)?; + + let elem = EI::as_type_native_unchecked(); + let lhs_line_size = tensor_line_size_parallel( + R::io_optimized_line_sizes(&elem), + lhs.data().shape, + lhs.data().strides, + rank - 1, + ); + let rhs_line_size = tensor_line_size_parallel( + R::io_optimized_line_sizes(&elem), + rhs.data().shape, + rhs.data().strides, + rank - 2, + ); + + let problem = MatmulProblem { + m: out_shape[rank - 2], + n: out_shape[rank - 1], + k: lhs_shape[rank - 1], + lhs_batches: lhs_shape[..rank - 2].to_vec(), + rhs_batches: rhs_shape[..rank - 2].to_vec(), + out_batches: out_shape[..rank - 2].to_vec(), + lhs_layout: MatrixLayout::RowMajor, + rhs_layout: MatrixLayout::ColMajor, }; let launch = match EI::as_type_native_unchecked().elem_type() { @@ -173,15 +203,66 @@ pub fn launch( _ => matmul_kernel::launch_unchecked::, }; + fn view<'a, R: Runtime>( + client: &ComputeClient, + handle: &'a MatmulInputHandleRef<'a, R>, + layout: MatrixLayout, + line_size: u8, + problem: &MatmulProblem, + ) -> ViewArg<'a, Coords3d, R> { + // Checks off, other properties are unused + let config = GlobalLayoutConfig { + matrix_layout: layout, + ..Default::default() + }; + match handle { + MatmulInputHandleRef::Normal(handle) => { + let layout = GlobalLayoutLaunch::from_handle_batched( + client, handle, problem, line_size, config, + ); + ViewArg::new::(handle.as_array_arg(line_size), layout) + } + 
MatmulInputHandleRef::Quantized { + data, + scale, + shape, + scheme, + } => { + let (data_layout, scales_layout) = GlobalLayoutLaunch::from_quantized_handle( + client, data, scale, shape, problem, **scheme, line_size, config, + ); + let data_view = + ViewArg::new::(data.as_array_arg(line_size), data_layout); + let scales_view = + ViewArg::new::(scale.as_array_arg(1), scales_layout); + ViewArg::new_quantized(data_view, scales_view, **scheme) + } + } + } + + let lhs_view = view( + client, + &lhs, + MatrixLayout::RowMajor, + lhs_line_size, + &problem, + ); + let rhs_view = view( + client, + &rhs, + MatrixLayout::ColMajor, + rhs_line_size, + &problem, + ); + unsafe { launch( client, cube_count, CubeDim::new(cube_dim_x as u32, cube_dim_y as u32, 1), - lhs.as_arg(vectorization_factor), - rhs.as_arg(vectorization_factor), + lhs_view, + rhs_view, out.as_tensor_arg(1), - Some(ndims as u32 - 2), ); }; diff --git a/crates/cubecl-matmul/src/tests/layered/macros/common/problem/problem_size.rs b/crates/cubecl-matmul/src/tests/layered/macros/common/problem/problem_size.rs index 8537c9fdc..7fe90ea9d 100644 --- a/crates/cubecl-matmul/src/tests/layered/macros/common/problem/problem_size.rs +++ b/crates/cubecl-matmul/src/tests/layered/macros/common/problem/problem_size.rs @@ -16,6 +16,7 @@ macro_rules! testgen_matmul_problem_size { k: 256, lhs_batches: vec![2], rhs_batches: vec![2], + out_batches: vec![2], lhs_layout: $layouts.0, rhs_layout: $layouts.1, } @@ -36,6 +37,7 @@ macro_rules! testgen_matmul_problem_size { k: 100, lhs_batches: vec![2], rhs_batches: vec![2], + out_batches: vec![2], lhs_layout: $layouts.0, rhs_layout: $layouts.1, } @@ -57,6 +59,7 @@ macro_rules! testgen_matmul_problem_size { k: 100, lhs_batches: vec![2], rhs_batches: vec![2], + out_batches: vec![2], lhs_layout: $layouts.0, rhs_layout: $layouts.1, } @@ -78,6 +81,7 @@ macro_rules! testgen_matmul_problem_size { k: 100, lhs_batches: vec![2], rhs_batches: vec![2], + out_batches: vec![2], lhs_layout: $layouts.0, rhs_layout: $layouts.1, } @@ -98,6 +102,7 @@ macro_rules! testgen_matmul_problem_size { k: 17, lhs_batches: vec![2], rhs_batches: vec![2], + out_batches: vec![2], lhs_layout: $layouts.0, rhs_layout: $layouts.1, } @@ -118,6 +123,7 @@ macro_rules! testgen_matmul_problem_size { k: 256, lhs_batches: vec![2], rhs_batches: vec![2], + out_batches: vec![2], lhs_layout: $layouts.0, rhs_layout: $layouts.1, } diff --git a/crates/cubecl-matmul/src/tests/layered/macros/plane_accelerated/mod.rs b/crates/cubecl-matmul/src/tests/layered/macros/plane_accelerated/mod.rs index e3c13ef06..171fe9d59 100644 --- a/crates/cubecl-matmul/src/tests/layered/macros/plane_accelerated/mod.rs +++ b/crates/cubecl-matmul/src/tests/layered/macros/plane_accelerated/mod.rs @@ -8,7 +8,7 @@ macro_rules! testgen_matmul_plane_accelerated { mod matmul_plane_accelerated { use super::*; use cubecl_matmul::components::tile::io::Filled; - type TMM = $crate::components::tile::accelerated::AcceleratedMatmul; + type TMM = $crate::components::tile::cmma::CmmaMatmul; #[cfg(feature = "matmul_tests_plane")] $crate::testgen_matmul_plane_accelerated_algorithm!(); diff --git a/crates/cubecl-matmul/src/tests/layered/macros/tma/mod.rs b/crates/cubecl-matmul/src/tests/layered/macros/tma/mod.rs index 864759828..393672641 100644 --- a/crates/cubecl-matmul/src/tests/layered/macros/tma/mod.rs +++ b/crates/cubecl-matmul/src/tests/layered/macros/tma/mod.rs @@ -8,7 +8,7 @@ macro_rules! 
testgen_matmul_tma { mod matmul_tma { use super::*; use cubecl_matmul::components::tile::io::Filled; - type TMM = $crate::components::tile::accelerated::AcceleratedMatmul; + type TMM = $crate::components::tile::cmma::CmmaMatmul; #[cfg(feature = "matmul_tests_tma")] $crate::testgen_matmul_tma_algorithm!(); diff --git a/crates/cubecl-matmul/src/tests/layered/matmul_test_launcher.rs b/crates/cubecl-matmul/src/tests/layered/matmul_test_launcher.rs index 19d2d135f..e0a60b9ad 100644 --- a/crates/cubecl-matmul/src/tests/layered/matmul_test_launcher.rs +++ b/crates/cubecl-matmul/src/tests/layered/matmul_test_launcher.rs @@ -4,14 +4,20 @@ use cubecl_core::{ }; use cubecl_core::{prelude::*, server::AllocationDescriptor}; -use crate::components::MatrixLayout; -use crate::components::batch::{BatchConfig, BatchMatmulFamily}; -use crate::components::global::args::TensorInputsLaunch; -use crate::components::{AvailableLineSizes, MatmulIdent}; +use crate::components::global::args::{ConcreteOutputFactory, TensorOutput}; use crate::components::{MatmulProblem, MatmulSelection}; +use crate::components::{MatrixLayout, global::args::ConcreteInputsFactory}; +use crate::components::{ + batch::{BatchConfig, BatchMatmulFamily}, + global::args::TensorInputs, +}; use crate::kernels::layered::Algorithm; use crate::tests::test_utils::Sample; use crate::tests::test_utils::TestPrecision; +use crate::{ + MatmulInputHandleRef, + components::{AccG, AvailableLineSizes, MatmulIdent}, +}; #[derive(Debug)] pub struct TensorRawParts { @@ -25,7 +31,7 @@ pub struct TensorRawParts { /// Test the correctness of the specified Matmul on the given device, /// against a naive CPU implementation over the given problem pub fn test_matmul_algorithm( - client: ComputeClient, + client: ComputeClient, problem: MatmulProblem, selection: MatmulSelection, ) where @@ -47,10 +53,10 @@ pub fn test_matmul_algorithm( let rhs = tensor_raw_parts::(&client, &problem, MatmulIdent::Rhs); let out = tensor_raw_parts::(&client, &problem, MatmulIdent::Out); - let line_sizes = AvailableLineSizes::from_types::( - &P::EG::as_type_native_unchecked(), - &P::EG::as_type_native_unchecked(), - &P::EG::as_type_native_unchecked(), + let line_sizes = AvailableLineSizes::from_type_sizes::( + size_of::(), + size_of::(), + size_of::(), ); let line_sizes = A::filter_line_sizes(line_sizes); let line_sizes = line_sizes @@ -91,39 +97,38 @@ pub fn test_matmul_algorithm( client.properties().hardware.max_cube_count.clone(), ); + let elem_size = size_of::(); + let lhs_handle = MatmulInputHandleRef::Normal(unsafe { + TensorHandleRef::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, elem_size) + }); + let rhs_handle = MatmulInputHandleRef::Normal(unsafe { + TensorHandleRef::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, elem_size) + }); + let out_handle = unsafe { + TensorHandleRef::from_raw_parts(&out.handle, &out.strides, &out.shape, elem_size) + }; + unsafe { A::BatchMatmul::launch_unchecked::( &client, config.cube_dim(), cube_count_plan.resolve(), - TensorInputsLaunch::new( - TensorArg::::from_raw_parts::( - &lhs.handle, - &lhs.strides, - &lhs.shape, - line_sizes.lhs, - ), - lhs.scale - .as_ref() - .map(|it| TensorArg::::from_raw_parts::(it, &[1], &[1], 1)) - .into(), - TensorArg::::from_raw_parts::( - &rhs.handle, - &rhs.strides, - &rhs.shape, - line_sizes.rhs, - ), - rhs.scale - .as_ref() - .map(|it| TensorArg::::from_raw_parts::(it, &[1], &[1], 1)) - .into(), - None.into(), + TensorInputs::create( + &client, + &lhs_handle, + &rhs_handle, + &selection, + &problem, + 
&line_sizes, + config, ), - TensorArg::::from_raw_parts::( - &out.handle, - &out.strides, - &out.shape, - line_sizes.out, + TensorOutput::>::create( + &client, + &out_handle, + &selection, + &problem, + &line_sizes, + config, ), cube_count_plan.as_args(), config, @@ -142,7 +147,7 @@ pub fn test_matmul_algorithm( } fn tensor_raw_parts( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, ident: MatmulIdent, ) -> TensorRawParts { diff --git a/crates/cubecl-matmul/src/tests/layered/tma_test_launcher.rs b/crates/cubecl-matmul/src/tests/layered/tma_test_launcher.rs index 51b9e7a7a..f3cac3c78 100644 --- a/crates/cubecl-matmul/src/tests/layered/tma_test_launcher.rs +++ b/crates/cubecl-matmul/src/tests/layered/tma_test_launcher.rs @@ -1,10 +1,6 @@ -use std::marker::PhantomData; - use cubecl_core::prelude::*; use cubecl_core::{CubeElement, server::Allocation}; -use crate::MatmulInputHandleRef; -use crate::components::AvailableLineSizes; use crate::components::MatmulIdent; use crate::components::MatmulProblem; use crate::components::MatmulSelection; @@ -13,16 +9,21 @@ use crate::components::batch::BatchConfig; use crate::components::batch::BatchMatmulFamily; use crate::components::global::args::TensorMapArgs; use crate::components::global::args::{ConcreteInputsFactory, TensorMapInputs}; +use crate::components::{AccG, AvailableLineSizes}; use crate::kernels::layered::Algorithm; use crate::tests::test_utils::Sample; use crate::tests::test_utils::TestPrecision; +use crate::{ + MatmulInputHandleRef, + components::global::args::{ConcreteOutputFactory, TensorOutput}, +}; use super::matmul_test_launcher::{TensorRawParts, tensor_size, transpose}; /// Test the correctness of the specified Matmul on the given device, /// against a naive CPU implementation over the given problem pub fn test_tma_matmul_algorithm( - client: ComputeClient, + client: ComputeClient, problem: MatmulProblem, selection: MatmulSelection, ) where @@ -45,25 +46,20 @@ pub fn test_tma_matmul_algorithm( let out = tensor_raw_parts::(&client, &problem, MatmulIdent::Out); let elem_size = size_of::(); - let lhs_handle = MatmulInputHandleRef::Normal(TensorHandleRef { - handle: &lhs.handle, - strides: &lhs.strides, - shape: &lhs.shape, - elem_size, - runtime: PhantomData, + let lhs_handle = MatmulInputHandleRef::Normal(unsafe { + TensorHandleRef::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, elem_size) }); - let rhs_handle = MatmulInputHandleRef::Normal(TensorHandleRef { - handle: &rhs.handle, - strides: &rhs.strides, - shape: &rhs.shape, - elem_size, - runtime: PhantomData, + let rhs_handle = MatmulInputHandleRef::Normal(unsafe { + TensorHandleRef::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, elem_size) }); + let out_handle = unsafe { + TensorHandleRef::from_raw_parts(&out.handle, &out.strides, &out.shape, elem_size) + }; - let line_sizes = AvailableLineSizes::from_types::( - &P::EG::as_type_native_unchecked(), - &P::EG::as_type_native_unchecked(), - &P::EG::as_type_native_unchecked(), + let line_sizes = AvailableLineSizes::from_type_sizes::( + size_of::(), + size_of::(), + size_of::(), ); let line_sizes = A::filter_line_sizes(line_sizes); let line_sizes = line_sizes @@ -92,26 +88,30 @@ pub fn test_tma_matmul_algorithm( let line_sizes = config.line_sizes(); - let inputs = - TensorMapInputs::create(&lhs_handle, &rhs_handle, &selection, &problem, &line_sizes); - let output = unsafe { - TensorArg::::from_raw_parts::( - &out.handle, - &out.strides, - &out.shape, - line_sizes.out, - ) - }; + let inputs = 
TensorMapInputs::create( + &client, + &lhs_handle, + &rhs_handle, + &selection, + &problem, + &line_sizes, + config, + ); + let output = TensorOutput::>::create( + &client, + &out_handle, + &selection, + &problem, + &line_sizes, + config, + ); let cube_count_plan = config.hypercube_config().cube_count_plan( &problem, client.properties().hardware.max_cube_count.clone(), ); unsafe { - A::BatchMatmul::launch_unchecked::< - ((P::EG, P::EG, P::EG, P::ES, P::ES, P::EA), TensorMapArgs), - R, - >( + A::BatchMatmul::launch_unchecked::<(P::MP, TensorMapArgs), R>( &client, config.cube_dim(), cube_count_plan.resolve(), @@ -134,7 +134,7 @@ pub fn test_tma_matmul_algorithm( } fn tensor_raw_parts( - client: &ComputeClient, + client: &ComputeClient, problem: &MatmulProblem, ident: MatmulIdent, ) -> TensorRawParts { diff --git a/crates/cubecl-matmul/src/tests/naive/macros.rs b/crates/cubecl-matmul/src/tests/naive/macros.rs index 1325af61c..86fd9eac6 100644 --- a/crates/cubecl-matmul/src/tests/naive/macros.rs +++ b/crates/cubecl-matmul/src/tests/naive/macros.rs @@ -21,6 +21,13 @@ macro_rules! testgen_matmul_simple { ) } + #[test] + pub fn test_odd() { + cubecl_matmul::tests::naive::tests::test_odd::( + &Default::default(), + ) + } + #[test] pub fn test_simple_matmul_large() { cubecl_matmul::tests::naive::tests::test_large::( diff --git a/crates/cubecl-matmul/src/tests/naive/tests.rs b/crates/cubecl-matmul/src/tests/naive/tests.rs index bb16e41b2..4a85b661b 100644 --- a/crates/cubecl-matmul/src/tests/naive/tests.rs +++ b/crates/cubecl-matmul/src/tests/naive/tests.rs @@ -3,6 +3,7 @@ use std::fmt::Display; use cubecl_core::{CubeElement, Runtime, prelude::Float}; use crate::{ + MatmulInputHandle, kernels::naive, tests::{ naive::utils::MatmulTestCase, @@ -22,6 +23,17 @@ pub fn test_small(device: test_simple::(case, device); } +pub fn test_odd(device: &R::Device) { + let case = MatmulTestCase { + m: 1, + k: 101, + n: 255, + batch: 1, + }; + + test_simple::(case, device); +} + pub fn test_large(device: &R::Device) { let case = MatmulTestCase { m: 256, @@ -70,7 +82,13 @@ fn test_simple( let expected = case.matmul_cpu::(&lhs, &rhs, &client); let out: TensorHandle = case.empty_out(&client); - naive::launch::(&client, lhs, rhs, &out.as_ref()).unwrap(); + naive::launch::( + &client, + MatmulInputHandle::Normal(lhs), + MatmulInputHandle::Normal(rhs), + &out.as_ref(), + ) + .unwrap(); if let Err(e) = assert_equals_approx::( &client, diff --git a/crates/cubecl-matmul/src/tests/naive/utils.rs b/crates/cubecl-matmul/src/tests/naive/utils.rs index eebe59fcd..f2aebb258 100644 --- a/crates/cubecl-matmul/src/tests/naive/utils.rs +++ b/crates/cubecl-matmul/src/tests/naive/utils.rs @@ -15,7 +15,7 @@ impl MatmulTestCase { &self, lhs: &TensorHandle, rhs: &TensorHandle, - client: &ComputeClient, + client: &ComputeClient, ) -> Vec { let lhs_binding = &client.read_one_tensor(lhs.handle.clone().copy_descriptor( &lhs.shape, @@ -60,28 +60,28 @@ impl MatmulTestCase { pub(crate) fn random_lhs( &self, - client: &ComputeClient, + client: &ComputeClient, ) -> TensorHandle { self.random_tensor(client, vec![self.batch, self.m, self.k]) } pub(crate) fn random_rhs( &self, - client: &ComputeClient, + client: &ComputeClient, ) -> TensorHandle { self.random_tensor(client, vec![self.batch, self.k, self.n]) } pub(crate) fn empty_out( &self, - client: &ComputeClient, + client: &ComputeClient, ) -> TensorHandle { TensorHandle::empty(client, vec![self.batch, self.m, self.n]) } pub(crate) fn random_tensor( &self, - client: &ComputeClient, + client: 
&ComputeClient, shape: Vec, ) -> TensorHandle { F::sample::(client, &shape, 999) diff --git a/crates/cubecl-matmul/src/tests/test_utils.rs b/crates/cubecl-matmul/src/tests/test_utils.rs index 3895285bb..755f06fe2 100644 --- a/crates/cubecl-matmul/src/tests/test_utils.rs +++ b/crates/cubecl-matmul/src/tests/test_utils.rs @@ -27,7 +27,7 @@ pub trait TestPrecision { lhs: &[Self::EG], rhs: &[Self::EG], problem: &MatmulProblem, - client: &ComputeClient, + client: &ComputeClient, out: server::Handle, shape: &[usize], strides: &[usize], @@ -49,7 +49,7 @@ where lhs: &[EG], rhs: &[EG], problem: &MatmulProblem, - client: &ComputeClient, + client: &ComputeClient, out: server::Handle, shape: &[usize], strides: &[usize], @@ -92,7 +92,7 @@ where /// Compares the content of a handle to a given slice of f32. pub(crate) fn assert_equals_approx( - client: &ComputeClient, + client: &ComputeClient, output: server::Handle, shape: &[usize], strides: &[usize], @@ -228,7 +228,7 @@ impl CastInto for i32 { pub trait Sample: Sized + CubePrimitive { fn sample( - client: &ComputeClient, + client: &ComputeClient, shape: &[usize], seed: u64, ) -> TensorHandle; @@ -239,7 +239,7 @@ macro_rules! sample_float { $( impl Sample for $t { - fn sample(client: &ComputeClient, shape: &[usize], seed: u64) -> TensorHandle:: { + fn sample(client: &ComputeClient, shape: &[usize], seed: u64) -> TensorHandle:: { cubecl_random::seed(seed); let output = TensorHandle::::empty(client, shape.to_vec()); @@ -260,7 +260,7 @@ sample_float!(u8); impl Sample for flex32 { fn sample( - client: &ComputeClient, + client: &ComputeClient, shape: &[usize], seed: u64, ) -> TensorHandle { @@ -280,7 +280,7 @@ impl Sample for flex32 { impl Sample for tf32 { fn sample( - client: &ComputeClient, + client: &ComputeClient, shape: &[usize], seed: u64, ) -> TensorHandle { diff --git a/crates/cubecl-matmul/src/tune_key.rs b/crates/cubecl-matmul/src/tune_key.rs index c9298d539..d82737a7a 100644 --- a/crates/cubecl-matmul/src/tune_key.rs +++ b/crates/cubecl-matmul/src/tune_key.rs @@ -15,6 +15,11 @@ pub struct MatmulAutotuneKey { pub analysis: MatmulAutotuneAnalysis, } +/// Maximum factor relevant for strides. Currently set to 2^5, or 32 since that's the maximum align +/// relevant for CUDA (for interleaved tensors). This can be changed if other platforms or features +/// require more. +const MAX_STRIDE_FACTOR: u32 = 5; + #[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)] pub struct MatmulProblemDefinition { #[autotune(anchor)] @@ -24,10 +29,14 @@ pub struct MatmulProblemDefinition { #[autotune(anchor)] pub k: usize, pub lhs_pow2_factor: u8, + /// Power of two that lhs strides are aligned to + pub lhs_stride_factor: u8, pub rhs_pow2_factor: u8, - pub elem_lhs: ElemType, - pub elem_rhs: ElemType, - pub elem_out: ElemType, + /// Power of two that rhs strides are aligned to + pub rhs_stride_factor: u8, + pub elem_lhs: MatmulElemType, + pub elem_rhs: MatmulElemType, + pub elem_out: MatmulElemType, pub matrix_layout_lhs: MatrixBatchLayout, pub matrix_layout_rhs: MatrixBatchLayout, } @@ -67,19 +76,25 @@ pub fn should_tune_double_buffering(fused: bool, key: &MatmulAutotuneKey) -> boo } } +#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)] +pub struct MatmulElemType { + pub elem: ElemType, + pub quantized: bool, +} + impl MatmulAutotuneKey { /// Create the autotune key based on the shape of both lhs and rhs as well as the element type /// used for the calculation. 
#[allow(clippy::too_many_arguments)] pub fn generate( - _client: &ComputeClient, + _client: &ComputeClient, lhs_shape: &[usize], rhs_shape: &[usize], lhs_strides: &[usize], rhs_strides: &[usize], - elem_lhs: ElemType, - elem_rhs: ElemType, - elem_out: ElemType, + elem_lhs: MatmulElemType, + elem_rhs: MatmulElemType, + elem_out: MatmulElemType, ) -> MatmulAutotuneKey { let ndims = lhs_shape.len(); let m = lhs_shape[ndims - 2]; @@ -112,12 +127,33 @@ impl MatmulAutotuneKey { MatrixBatchLayout::HighlyPermuted => 0, }; + let lhs_stride_factor = match matrix_layout_lhs { + MatrixBatchLayout::Contiguous => stride_align(lhs_strides, ndims - 1, elem_lhs.elem), + // TMA can't handle discontiguous batches because they're all combined into one dim + MatrixBatchLayout::MildlyPermuted { + transposed: true, + batch_swap: false, + } => stride_align(lhs_strides, ndims - 2, elem_lhs.elem), + _ => 0, + }; + let rhs_stride_factor = match matrix_layout_rhs { + MatrixBatchLayout::Contiguous => stride_align(rhs_strides, ndims - 1, elem_rhs.elem), + // TMA can't handle discontiguous batches because they're all combined into one dim + MatrixBatchLayout::MildlyPermuted { + transposed: true, + batch_swap: false, + } => stride_align(rhs_strides, ndims - 2, elem_rhs.elem), + _ => 0, + }; + let definition = MatmulProblemDefinition::new( m, n, k, lhs_pow2_factor, + lhs_stride_factor, rhs_pow2_factor, + rhs_stride_factor, elem_lhs, elem_rhs, elem_out, @@ -133,13 +169,21 @@ impl MatmulAutotuneKey { } } +/// Defines the non-contiguous stride alignment in terms of powers of two +fn stride_align(strides: &[usize], exclude_dim: usize, elem: ElemType) -> u8 { + let max = MAX_STRIDE_FACTOR; + let factor = strides + .iter() + .enumerate() + .filter(|(i, _)| *i != exclude_dim) + .map(|(_, it)| (*it * elem.size_bits()) / 8) + .map(|it| it.trailing_zeros()) + .min() + .unwrap_or(max); + factor.min(max) as u8 +} + /// Defines the potential vectorization. 
fn pow2_factor(axis: usize) -> u8 { - for i in (1..4).rev() { - if axis.is_multiple_of(2usize.pow(i as u32)) { - return i; - } - } - - 0 + axis.trailing_zeros().min(4) as u8 } diff --git a/crates/cubecl-opt/Cargo.toml b/crates/cubecl-opt/Cargo.toml index b279b431d..bcd746447 100644 --- a/crates/cubecl-opt/Cargo.toml +++ b/crates/cubecl-opt/Cargo.toml @@ -15,9 +15,9 @@ default = ["std", "cubecl-common/default", "cubecl-ir/default"] std = ["cubecl-common/std"] [dependencies] -cubecl-common = { path = "../cubecl-common", version = "0.7.0", default-features = false } -cubecl-ir = { path = "../cubecl-ir", version = "0.7.0", default-features = false } -cubecl-core = { path = "../cubecl-core", version = "0.7.0", default-features = false } +cubecl-common = { path = "../cubecl-common", version = "0.9.0", default-features = false } +cubecl-ir = { path = "../cubecl-ir", version = "0.9.0", default-features = false } +cubecl-core = { path = "../cubecl-core", version = "0.9.0", default-features = false } float-ord = "0.3" log = "0.4" diff --git a/crates/cubecl-opt/src/analyses/liveness.rs b/crates/cubecl-opt/src/analyses/liveness.rs index 4a43e174a..379b9a948 100644 --- a/crates/cubecl-opt/src/analyses/liveness.rs +++ b/crates/cubecl-opt/src/analyses/liveness.rs @@ -107,7 +107,7 @@ fn calculate_block_sets(opt: &mut Optimizer, block: NodeIndex) -> BlockSets { /// Shared memory liveness analysis and allocation pub mod shared { - use cubecl_ir::{Operation, Type, Variable, VariableKind}; + use cubecl_ir::{Marker, Operation, Type, Variable, VariableKind}; use crate::Uniformity; @@ -331,10 +331,10 @@ pub mod shared { } }); - if let Operation::Free(Variable { + if let Operation::Marker(Marker::Free(Variable { kind: VariableKind::SharedMemory { id, .. }, .. - }) = &op.operation + })) = &op.operation { kill.insert(*id); generated.remove(id); diff --git a/crates/cubecl-opt/src/analyses/uniformity.rs b/crates/cubecl-opt/src/analyses/uniformity.rs index 272bc459f..0f57e5a3b 100644 --- a/crates/cubecl-opt/src/analyses/uniformity.rs +++ b/crates/cubecl-opt/src/analyses/uniformity.rs @@ -72,6 +72,15 @@ impl Uniformity { self.is_var_uniform(op.lhs) || self.is_var_uniform(op.rhs); self.mark_uniformity(out, input_uniform && block_uniform)?; } + // Shuffle operations: if offset/mask/delta is uniform, output is non-uniform + // (each thread gets a different value). If value is uniform, output is uniform. 
+ Plane::Shuffle(op) + | Plane::ShuffleXor(op) + | Plane::ShuffleUp(op) + | Plane::ShuffleDown(op) => { + let input_uniform = self.is_var_uniform(op.lhs); + self.mark_uniformity(out, input_uniform && block_uniform)?; + } }, Operation::Synchronization(sync) => match sync { Synchronization::SyncCube | Synchronization::SyncStorage => { diff --git a/crates/cubecl-opt/src/control_flow.rs b/crates/cubecl-opt/src/control_flow.rs index 80545ee12..ed1df04c2 100644 --- a/crates/cubecl-opt/src/control_flow.rs +++ b/crates/cubecl-opt/src/control_flow.rs @@ -5,7 +5,7 @@ use std::mem::transmute; use crate::{BasicBlock, BlockUse, NodeIndex, Optimizer}; use cubecl_ir::{ Arithmetic, BinaryOperator, Branch, Comparison, ConstantScalarValue, ElemType, If, IfElse, - Instruction, Loop, Operation, RangeLoop, Switch, Type, Variable, VariableKind, + Instruction, Loop, Marker, Operation, RangeLoop, Switch, Type, Variable, VariableKind, }; use petgraph::{Direction, graph::EdgeIndex, visit::EdgeRef}; use stable_vec::StableVec; @@ -372,7 +372,8 @@ impl Optimizer { } fn split_free_inner(&mut self) -> bool { - let is_free = |inst: &Instruction| matches!(inst.operation, Operation::Free(_)); + let is_free = + |inst: &Instruction| matches!(inst.operation, Operation::Marker(Marker::Free(_))); for block in self.node_ids() { let ops = self.block(block).ops.clone(); diff --git a/crates/cubecl-opt/src/gvn/analysis.rs b/crates/cubecl-opt/src/gvn/analysis.rs index 56fd03f45..ef7955374 100644 --- a/crates/cubecl-opt/src/gvn/analysis.rs +++ b/crates/cubecl-opt/src/gvn/analysis.rs @@ -4,7 +4,7 @@ use std::{ ops::Deref, }; -use crate::{NodeIndex, analyses::Analysis}; +use crate::{ControlFlow, NodeIndex, analyses::Analysis}; use smallvec::SmallVec; use crate::{ @@ -168,8 +168,25 @@ impl GvnState { let successors = opt.successors(current); // Since we have no critical edges, if successors > 1 then they must have only one entry, // So no phi nodes. + // + // Loops are a special case because the conservative nature of PRE normally prevents loop + // invariants from being moved out of the loop. Since only side-effect free values are + // numbered, we can safely treat loops as being executed at least once. The worst case is + // some expressions are executed unnecessarily, but for a loop that never runs, performance + // is likely secondary. #[allow(clippy::comparison_chain)] - if successors.len() > 1 { + if let ControlFlow::Loop { body, .. } | ControlFlow::LoopBreak { body, .. } = + opt.block(current).control_flow.borrow().clone() + { + let antic_in_succ = &self.block_sets[&body].antic_in; + let phi_gen = &self.block_sets[&body].phi_gen; + let result = + phi_translate(opt, phi_gen, antic_in_succ, body, current, &mut self.values); + if self.block_sets[¤t].antic_out != result { + changed = true; + } + self.block_sets.get_mut(¤t).unwrap().antic_out = result; + } else if successors.len() > 1 { let potential_out = &self.block_sets[&successors[0]].antic_in; let mut result = LinkedList::new(); let rest = successors[1..] 
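// Illustrative CPU sketch (not part of this patch) of the `load_unrolled`
// helper introduced in the naive matmul kernel above: it assembles one
// logical line of `line_size` elements out of reads that are only
// `view_line_size` wide, stepping through the buffer and unrolling the copy.
// Assumptions: a flat `f32` slice stands in for the kernel's `View` with
// `MatrixLayout`-based indexing, and the logical width is a multiple of the
// native width.
fn load_unrolled_cpu(
    buf: &[f32],
    start: usize,
    line_size: usize,
    view_line_size: usize,
) -> Vec<f32> {
    assert_eq!(line_size % view_line_size, 0);
    let mut out = Vec::with_capacity(line_size);
    for i in (0..line_size).step_by(view_line_size) {
        // One native-width read per step; in the kernel, `start + i` plays
        // the role of the (batch, row, col) position advanced along the row
        // or column depending on the matrix layout.
        out.extend_from_slice(&buf[start + i..start + i + view_line_size]);
    }
    out
}

fn load_unrolled_cpu_demo() {
    let buf: Vec<f32> = (0..8).map(|i| i as f32).collect();
    // A logical line of 4 built from two 2-wide reads at offset 2.
    assert_eq!(load_unrolled_cpu(&buf, 2, 4, 2), vec![2.0, 3.0, 4.0, 5.0]);
}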
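// Standalone sketch of the autotune-key computations changed above: the old
// capped divisibility loop in `pow2_factor` becomes a direct
// `trailing_zeros`, and a per-operand stride alignment factor is recorded.
// Here a plain `elem_bits: usize` stands in for `ElemType::size_bits()`.

fn pow2_factor(axis: usize) -> u8 {
    // Exponent of the largest power of two dividing `axis`, capped at 2^4.
    axis.trailing_zeros().min(4) as u8
}

const MAX_STRIDE_FACTOR: u32 = 5;

fn stride_align(strides: &[usize], exclude_dim: usize, elem_bits: usize) -> u8 {
    // Power of two that every stride except `exclude_dim` is byte-aligned
    // to, capped at 2^5 = 32 bytes (the maximum alignment relevant for CUDA
    // interleaved tensors, per the comment on MAX_STRIDE_FACTOR).
    strides
        .iter()
        .enumerate()
        .filter(|(i, _)| *i != exclude_dim)
        .map(|(_, s)| ((s * elem_bits) / 8).trailing_zeros())
        .min()
        .unwrap_or(MAX_STRIDE_FACTOR)
        .min(MAX_STRIDE_FACTOR) as u8
}

fn tune_key_demo() {
    assert_eq!(pow2_factor(100), 2); // 100 = 4 * 25
    assert_eq!(pow2_factor(512), 4); // capped at 2^4
    // Contiguous f32 [2, 4, 100]: strides [400, 100, 1]. Excluding the
    // innermost dim, byte strides are 1600 and 400 -> aligned to 2^4 bytes.
    assert_eq!(stride_align(&[400, 100, 1], 2, 32), 4);
}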
diff --git a/crates/cubecl-opt/src/gvn/numbering.rs b/crates/cubecl-opt/src/gvn/numbering.rs index 23e2bd9fe..4820884e9 100644 --- a/crates/cubecl-opt/src/gvn/numbering.rs +++ b/crates/cubecl-opt/src/gvn/numbering.rs @@ -128,7 +128,7 @@ impl ValueTable { | Operation::NonSemantic(_) | Operation::Barrier(_) | Operation::Tma(_) - | Operation::Free(_) => Err(None), + | Operation::Marker(_) => Err(None), } } @@ -190,18 +190,7 @@ impl ValueTable { out: Variable, ) -> Result<(Expression, Option), Option> { let (expr, val) = match operator { - Operator::Index(op) | Operator::UncheckedIndex(op) => { - let out_val = value_of_var(&out); - if !op.list.is_immutable() { - Err(out_val)? - } - let item = out.ty; - let lhs = self.lookup_or_add_var(&op.list)?; - let rhs = self.lookup_or_add_var(&op.index)?; - let id = OpCode::Operator(operator.op_code()); - let expr = Instruction::new(id, &[lhs, rhs], item); - (expr.into(), out_val) - } + Operator::Index(_) | Operator::UncheckedIndex(_) => Err(value_of_var(&out))?, Operator::IndexAssign(_) | Operator::UncheckedIndexAssign(_) diff --git a/crates/cubecl-opt/src/instructions.rs b/crates/cubecl-opt/src/instructions.rs index 2e6d8f663..5576d33e2 100644 --- a/crates/cubecl-opt/src/instructions.rs +++ b/crates/cubecl-opt/src/instructions.rs @@ -54,7 +54,7 @@ impl Optimizer { Operation::NonSemantic(non_semantic) => { self.visit_nonsemantic(non_semantic, visit_read) } - Operation::Free(_) => {} + Operation::Marker(_) => {} } } @@ -106,10 +106,11 @@ impl Optimizer { | Arithmetic::Degrees(unary_operator) | Arithmetic::Radians(unary_operator) | Arithmetic::Sqrt(unary_operator) - | Arithmetic::Rsqrt(unary_operator) + | Arithmetic::InverseSqrt(unary_operator) | Arithmetic::Round(unary_operator) | Arithmetic::Floor(unary_operator) | Arithmetic::Ceil(unary_operator) + | Arithmetic::Trunc(unary_operator) | Arithmetic::Erf(unary_operator) | Arithmetic::Recip(unary_operator) | Arithmetic::Neg(unary_operator) @@ -273,7 +274,11 @@ impl Optimizer { fn visit_plane(&mut self, plane: &mut Plane, visit_read: impl FnMut(&mut Self, &mut Variable)) { match plane { Plane::Elect => {} - Plane::Broadcast(binary_operator) => self.visit_binop(binary_operator, visit_read), + Plane::Broadcast(binary_operator) + | Plane::Shuffle(binary_operator) + | Plane::ShuffleXor(binary_operator) + | Plane::ShuffleUp(binary_operator) + | Plane::ShuffleDown(binary_operator) => self.visit_binop(binary_operator, visit_read), Plane::All(unary_operator) | Plane::Any(unary_operator) | Plane::Sum(unary_operator) diff --git a/crates/cubecl-opt/src/passes/constant_prop.rs b/crates/cubecl-opt/src/passes/constant_prop.rs index af135af35..d0d7f5694 100644 --- a/crates/cubecl-opt/src/passes/constant_prop.rs +++ b/crates/cubecl-opt/src/passes/constant_prop.rs @@ -443,20 +443,17 @@ fn try_const_eval_arithmetic(op: &mut Arithmetic) -> Option } } Arithmetic::Sqrt(op) => const_eval_float!(op.input; num::Float::sqrt), - Arithmetic::Rsqrt(op) => { - use ConstantScalarValue::*; - if let Some(input) = op.input.as_const() { - match input { - Float(input, kind) => Some(ConstantScalarValue::Float(1. 
/ input.sqrt(), kind)), - _ => unreachable!(), - } - } else { - None - } + Arithmetic::InverseSqrt(op) => { + let sqrt = const_eval_float!(op.input; num::Float::sqrt)?; + let ConstantScalarValue::Float(val, kind) = sqrt else { + unreachable!() + }; + Some(ConstantScalarValue::Float(1.0 / val, kind)) } Arithmetic::Round(op) => const_eval_float!(op.input; num::Float::round), Arithmetic::Floor(op) => const_eval_float!(op.input; num::Float::floor), Arithmetic::Ceil(op) => const_eval_float!(op.input; num::Float::ceil), + Arithmetic::Trunc(op) => const_eval_float!(op.input; num::Float::trunc), Arithmetic::Recip(op) => const_eval_float!(op.input; num::Float::recip), Arithmetic::Neg(op) => { use ConstantScalarValue::*; diff --git a/crates/cubecl-quant/Cargo.toml b/crates/cubecl-quant/Cargo.toml index 576544d67..fb0cb59cf 100644 --- a/crates/cubecl-quant/Cargo.toml +++ b/crates/cubecl-quant/Cargo.toml @@ -17,12 +17,12 @@ kernels = ["std", "cubecl-core", "cubecl-runtime", "cubecl-std"] std = ["cubecl-core?/std", "cubecl-runtime?/std"] [dependencies] -cubecl-common = { path = "../cubecl-common", version = "0.7.0", default-features = false, features = [ +cubecl-common = { path = "../cubecl-common", version = "0.9.0", default-features = false, features = [ "fp8", ] } -cubecl-core = { path = "../cubecl-core", version = "0.7.0", default-features = false, optional = true } -cubecl-runtime = { path = "../cubecl-runtime", version = "0.7.0", default-features = false, optional = true } -cubecl-std = { path = "../cubecl-std", version = "0.7.0", default-features = false, optional = true } +cubecl-core = { path = "../cubecl-core", version = "0.9.0", default-features = false, optional = true } +cubecl-runtime = { path = "../cubecl-runtime", version = "0.9.0", default-features = false, optional = true } +cubecl-std = { path = "../cubecl-std", version = "0.9.0", default-features = false, optional = true } half.workspace = true serde = { workspace = true } diff --git a/crates/cubecl-quant/src/dequantize.rs b/crates/cubecl-quant/src/dequantize.rs index 653277fdc..26c9ea06f 100644 --- a/crates/cubecl-quant/src/dequantize.rs +++ b/crates/cubecl-quant/src/dequantize.rs @@ -165,7 +165,7 @@ fn dequantize_symmetric_native_kernel( - client: &ComputeClient, + client: &ComputeClient, values: &TensorHandleRef, output: &TensorHandleRef, params: &TensorHandleRef<'_, R>, @@ -238,7 +238,7 @@ pub fn launch_ref( } fn dequantize_packed( - client: &ComputeClient, + client: &ComputeClient, input: &TensorHandleRef, scheme: QuantScheme, scale: &TensorHandleRef<'_, R>, @@ -247,7 +247,7 @@ fn dequantize_packed( let num_elems_input: usize = input.shape.iter().product(); let mut line_size_in = tensor_line_size_parallel( - R::io_optimized_line_sizes_unchecked(&F::as_type_native_unchecked()), + R::io_optimized_line_sizes_unchecked(size_of::()), input.shape, input.strides, input.shape.len() - 1, @@ -288,7 +288,7 @@ fn dequantize_packed( } fn dequantize_native( - client: &ComputeClient, + client: &ComputeClient, input: &TensorHandleRef, scheme: QuantScheme, scale: &TensorHandleRef<'_, R>, @@ -296,7 +296,7 @@ fn dequantize_native( ) { let num_elems: usize = input.shape.iter().product(); let line_size = tensor_line_size_parallel( - R::io_optimized_line_sizes_unchecked(&F::as_type_native_unchecked()), + R::io_optimized_line_sizes_unchecked(size_of::()), input.shape, input.strides, input.shape.len() - 1, diff --git a/crates/cubecl-quant/src/layout/scales.rs b/crates/cubecl-quant/src/layout/scales.rs index e16980973..2f634f390 100644 --- 
a/crates/cubecl-quant/src/layout/scales.rs +++ b/crates/cubecl-quant/src/layout/scales.rs @@ -1,5 +1,5 @@ use cubecl::prelude::*; -use cubecl_core::{self as cubecl, intrinsic}; +use cubecl_core::{self as cubecl}; use cubecl_std::{ FastDivmod, FastDivmodArgs, tensor::{ @@ -149,7 +149,6 @@ impl Layout for BlockScaledLayout { #[unroll] for i in 0..rank { - let i = unwrap(i); let dim = comptime![rank - i - 1]; let block_size_local = comptime![self.block_size[dim as usize] as u32]; let (rem, offs_local) = self.tensor_shape.index(dim).div_mod(offs); @@ -185,7 +184,6 @@ impl BlockScaledLayout { #[unroll] for i in 0..rank { - let i = unwrap(i); let dim = comptime![rank - i - 1]; let block_size_local = comptime![self.block_size[dim as usize] as u32]; let (rem, offs_local) = self.tensor_shape.index(dim).div_mod(offs); @@ -197,12 +195,6 @@ impl BlockScaledLayout { } } -#[allow(unused_variables)] -#[cube] -fn unwrap(v: u32) -> comptime_type!(u32) { - intrinsic!(|_| v.constant().expect("Must be constant").as_u32()) -} - /// [TensorView] with a linear layout inferred from the shape/strides at launch. /// Useful for elementwise kernels. pub type ScalesView = TypedView; @@ -212,7 +204,7 @@ pub type ScalesViewLaunch<'a, R> = TypedViewLaunch<'a, ScalesLayout, R>; /// Create a scales view from the values and scales handle, line size and quantization scheme. /// `values` should be *the quantized tensor*, and will be adjusted by `num_quants`. pub fn scales_view<'a, R: Runtime>( - client: &ComputeClient, + client: &ComputeClient, values: &'a TensorHandleRef<'a, R>, scales: &'a TensorHandleRef<'a, R>, scales_line_size: u8, @@ -227,7 +219,7 @@ pub fn scales_view<'a, R: Runtime>( } pub fn scales_layout<'a, R: Runtime>( - client: &ComputeClient, + client: &ComputeClient, values: &'a TensorHandleRef<'a, R>, scales: &'a TensorHandleRef<'a, R>, scales_line_size: u8, @@ -253,7 +245,7 @@ pub fn scales_layout<'a, R: Runtime>( } fn shape_divmod_quant<'a, R: Runtime>( - client: &ComputeClient, + client: &ComputeClient, shape: &'a [usize], num_quants: usize, ) -> SequenceArg<'a, R, FastDivmod> { diff --git a/crates/cubecl-quant/src/lib.rs b/crates/cubecl-quant/src/lib.rs index 07b875224..b24289780 100644 --- a/crates/cubecl-quant/src/lib.rs +++ b/crates/cubecl-quant/src/lib.rs @@ -11,7 +11,7 @@ pub mod quantize; #[cfg(feature = "kernels")] pub mod layout; -pub mod scheme; +pub use cubecl_common::quant::scheme; #[cfg(feature = "export_tests")] pub mod tests; diff --git a/crates/cubecl-quant/src/quantize.rs b/crates/cubecl-quant/src/quantize.rs index 814101e17..62cff0f2a 100644 --- a/crates/cubecl-quant/src/quantize.rs +++ b/crates/cubecl-quant/src/quantize.rs @@ -163,7 +163,7 @@ fn quantize_symmetric_packed_kernel( #[allow(clippy::result_large_err)] pub fn launch_ref( - client: &ComputeClient, + client: &ComputeClient, input: &TensorHandleRef, output: &TensorHandleRef, scale: &TensorHandleRef<'_, R>, @@ -237,7 +237,7 @@ pub fn launch_ref( } fn quantize_native( - client: &ComputeClient, + client: &ComputeClient, input: &TensorHandleRef, scheme: &QuantScheme, scale: &TensorHandleRef<'_, R>, @@ -246,7 +246,7 @@ fn quantize_native( ) { let num_elems: usize = input.shape.iter().product(); let line_size = tensor_line_size_parallel( - R::io_optimized_line_sizes_unchecked(&F::as_type_native_unchecked()), + R::io_optimized_line_sizes_unchecked(size_of::()), input.shape, input.strides, input.shape.len() - 1, @@ -303,7 +303,7 @@ fn quantize_native( } fn quantize_packed( - client: &ComputeClient, + client: &ComputeClient, input: 
&TensorHandleRef, scheme: &QuantScheme, scale: &TensorHandleRef<'_, R>, diff --git a/crates/cubecl-random/Cargo.toml b/crates/cubecl-random/Cargo.toml index 9515a4635..4e1a39da5 100644 --- a/crates/cubecl-random/Cargo.toml +++ b/crates/cubecl-random/Cargo.toml @@ -18,10 +18,10 @@ export_tests = ["pretty_assertions"] std = ["cubecl-runtime/std", "cubecl-core/std"] [dependencies] -cubecl-core = { path = "../cubecl-core", version = "0.7.0", default-features = false } -cubecl-runtime = { path = "../cubecl-runtime", version = "0.7.0", default-features = false } -cubecl-std = { path = "../cubecl-std", version = "0.7.0", default-features = false } -cubecl-common = { path = "../cubecl-common", version = "0.7.0", default-features = false } +cubecl-core = { path = "../cubecl-core", version = "0.9.0", default-features = false } +cubecl-runtime = { path = "../cubecl-runtime", version = "0.9.0", default-features = false } +cubecl-std = { path = "../cubecl-std", version = "0.9.0", default-features = false } +cubecl-common = { path = "../cubecl-common", version = "0.9.0", default-features = false } num-traits = "0.2.19" pretty_assertions = { workspace = true, optional = true } rand = { workspace = true } diff --git a/crates/cubecl-random/src/base.rs b/crates/cubecl-random/src/base.rs index 99053d0de..cb7a0cb1e 100644 --- a/crates/cubecl-random/src/base.rs +++ b/crates/cubecl-random/src/base.rs @@ -23,7 +23,7 @@ pub fn seed(seed: u64) { /// Pseudo-random generator pub(crate) fn random( - client: &ComputeClient, + client: &ComputeClient, prng: F::Runtime, output: TensorHandleRef<'_, R>, ) { diff --git a/crates/cubecl-random/src/bernoulli.rs b/crates/cubecl-random/src/bernoulli.rs index 9770d658b..be2af9634 100644 --- a/crates/cubecl-random/src/bernoulli.rs +++ b/crates/cubecl-random/src/bernoulli.rs @@ -77,7 +77,7 @@ impl PrngArgs for Bernoulli { /// Pseudo-random generator with bernoulli distribution pub fn random_bernoulli( - client: &ComputeClient, + client: &ComputeClient, probability: f32, out: TensorHandleRef, ) { diff --git a/crates/cubecl-random/src/normal.rs b/crates/cubecl-random/src/normal.rs index 6874ff057..cda30edfc 100644 --- a/crates/cubecl-random/src/normal.rs +++ b/crates/cubecl-random/src/normal.rs @@ -96,7 +96,7 @@ impl PrngArgs for Normal { /// Pseudo-random generator with uniform distribution pub fn random_normal( - client: &ComputeClient, + client: &ComputeClient, mean: E, std: E, out: TensorHandleRef, diff --git a/crates/cubecl-random/src/uniform.rs b/crates/cubecl-random/src/uniform.rs index 710c6d1ac..b154491ba 100644 --- a/crates/cubecl-random/src/uniform.rs +++ b/crates/cubecl-random/src/uniform.rs @@ -83,7 +83,7 @@ impl PrngArgs for Uniform { /// Pseudo-random generator with uniform distribution pub fn random_uniform( - client: &ComputeClient, + client: &ComputeClient, lower_bound: E, upper_bound: E, out: TensorHandleRef, diff --git a/crates/cubecl-reduce/Cargo.toml b/crates/cubecl-reduce/Cargo.toml index 2dc6ae7a1..6f31d678c 100644 --- a/crates/cubecl-reduce/Cargo.toml +++ b/crates/cubecl-reduce/Cargo.toml @@ -20,9 +20,9 @@ export_tests = ["pretty_assertions", "rand"] std = ["cubecl-runtime/std", "cubecl-core/std"] [dependencies] -cubecl-core = { path = "../cubecl-core", version = "0.7.0", default-features = false } -cubecl-runtime = { path = "../cubecl-runtime", version = "0.7.0", default-features = false } -cubecl-std = { path = "../cubecl-std", version = "0.7.0", default-features = false } +cubecl-core = { path = "../cubecl-core", version = "0.9.0", default-features = false 
} +cubecl-runtime = { path = "../cubecl-runtime", version = "0.9.0", default-features = false } +cubecl-std = { path = "../cubecl-std", version = "0.9.0", default-features = false } num-traits = "0.2.19" pretty_assertions = { workspace = true, optional = true } rand = { workspace = true, optional = true } diff --git a/crates/cubecl-reduce/src/config.rs b/crates/cubecl-reduce/src/config.rs index 8261dbf00..7d132eee9 100644 --- a/crates/cubecl-reduce/src/config.rs +++ b/crates/cubecl-reduce/src/config.rs @@ -1,6 +1,5 @@ use cubecl_core::{ - channel::ComputeChannel, prelude::*, server::ComputeServer, tensor_line_size_parallel, - tensor_line_size_perpendicular, + prelude::*, server::ComputeServer, tensor_line_size_parallel, tensor_line_size_perpendicular, }; use cubecl_std::tensor::is_contiguous; @@ -43,7 +42,7 @@ pub struct ReduceConfig { impl ReduceConfig { pub(crate) fn generate( - client: &ComputeClient, + client: &ComputeClient, input: &TensorHandleRef, output: &TensorHandleRef, axis: usize, @@ -86,8 +85,7 @@ impl ReduceConfig { output: &TensorHandleRef, axis: usize, ) -> Self { - let elem = In::as_type_native_unchecked(); - let supported_line_sizes = R::io_optimized_line_sizes_unchecked(&elem); + let supported_line_sizes = R::io_optimized_line_sizes_unchecked(size_of::()); self.line_size_input = match self.line_mode { LineMode::Parallel => { tensor_line_size_parallel(supported_line_sizes, input.shape, input.strides, axis) @@ -167,18 +165,26 @@ impl ReduceConfig { self } - pub fn generate_cube_dim>( + pub fn generate_cube_dim( mut self, - client: &ComputeClient, + client: &ComputeClient, use_planes: bool, ) -> Self { - self.cube_dim = if use_planes { - let plane_dim = client.properties().hardware.plane_size_min; - CubeDim::new_2d(plane_dim, DEFAULT_PLANE_COUNT) + let hw_properties = &client.properties().hardware; + + let plane_dim = if use_planes { + hw_properties.plane_size_min } else { - let plane_dim = client.properties().hardware.plane_size_max; - CubeDim::new_2d(plane_dim, DEFAULT_PLANE_COUNT) + hw_properties.plane_size_max }; + + let plane_count = if plane_dim * DEFAULT_PLANE_COUNT > hw_properties.max_units_per_cube { + hw_properties.max_units_per_cube / plane_dim + } else { + DEFAULT_PLANE_COUNT + }; + + self.cube_dim = CubeDim::new_2d(plane_dim, plane_count); self } diff --git a/crates/cubecl-reduce/src/launch.rs b/crates/cubecl-reduce/src/launch.rs index fcd43b6d1..5b0e2b6a2 100644 --- a/crates/cubecl-reduce/src/launch.rs +++ b/crates/cubecl-reduce/src/launch.rs @@ -15,7 +15,7 @@ use crate::{LineMode, ReduceConfig, ReduceStrategy}; /// See the main entrypoint `reduce` in `lib.rs` for an example how to call this function /// with the appropriate assumptions. pub(crate) fn launch_reduce( - client: &ComputeClient, + client: &ComputeClient, input: TensorHandleRef, output: TensorHandleRef, axis: u32, diff --git a/crates/cubecl-reduce/src/lib.rs b/crates/cubecl-reduce/src/lib.rs index 647563b90..eb3311a2c 100644 --- a/crates/cubecl-reduce/src/lib.rs +++ b/crates/cubecl-reduce/src/lib.rs @@ -38,6 +38,9 @@ pub use launch::{ReduceParams, reduce_kernel, reduce_kernel_virtual}; #[cfg(feature = "export_tests")] pub mod test; +#[cfg(feature = "export_tests")] +pub mod test_shuffle; + use cubecl_core::prelude::*; /// Reduce the given `axis` of the `input` tensor using the instruction `Inst` and write the result into `output`. 
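The `generate_cube_dim` rewrite above stops assuming the default plane count always fits: it derives the plane count from `max_units_per_cube` so that `plane_dim * plane_count` can never exceed the hardware limit. A minimal standalone sketch of that arithmetic (the `DEFAULT_PLANE_COUNT` of 8 and the property values here are illustrative assumptions, not values from the diff):

```rust
// Mirrors the clamping logic in `ReduceConfig::generate_cube_dim`.
// `plane_dim` and `max_units_per_cube` stand in for the values read from
// `client.properties().hardware`.
const DEFAULT_PLANE_COUNT: u32 = 8; // assumed default, defined elsewhere in the crate

fn cube_dim_2d(plane_dim: u32, max_units_per_cube: u32) -> (u32, u32) {
    let plane_count = if plane_dim * DEFAULT_PLANE_COUNT > max_units_per_cube {
        max_units_per_cube / plane_dim
    } else {
        DEFAULT_PLANE_COUNT
    };
    (plane_dim, plane_count)
}

fn main() {
    // 64-wide planes on a device capped at 256 units per cube: 64 * 8 = 512
    // would overflow the cube, so the plane count is clamped to 256 / 64 = 4.
    assert_eq!(cube_dim_2d(64, 256), (64, 4));
    // 32-wide planes fit the default: 32 * 8 = 256 <= 256, so 8 is kept.
    assert_eq!(cube_dim_2d(32, 256), (32, 8));
}
```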
@@ -94,7 +97,7 @@ use cubecl_core::prelude::*;
 /// }
 /// ```
 pub fn reduce<R: Runtime, In: Numeric, Out: Numeric, Inst: ReduceInstruction<In>>(
-    client: &ComputeClient<R::Server, R::Channel>,
+    client: &ComputeClient<R::Server>,
     input: TensorHandleRef<R>,
     output: TensorHandleRef<R>,
     axis: usize,
diff --git a/crates/cubecl-reduce/src/primitives.rs b/crates/cubecl-reduce/src/primitives.rs
index 58bf8918a..66c559136 100644
--- a/crates/cubecl-reduce/src/primitives.rs
+++ b/crates/cubecl-reduce/src/primitives.rs
@@ -398,3 +398,68 @@ pub fn reduce_tree<In: Numeric, Inst: ReduceInstruction<In>>(
     sync_cube();
     Inst::SharedAccumulator::read(accumulator, 0)
 }
+
+/// Warp-level sum reduction using shuffle operations.
+///
+/// All lanes get the sum of all 32 values in the warp using butterfly reduction:
+/// ```ignore
+/// Step 1 (offset=16): Lane 0 ← Lane 0 + Lane 16, Lane 1 ← Lane 1 + Lane 17, ...
+/// Step 2 (offset=8): Lane 0 ← Lane 0 + Lane 8, Lane 1 ← Lane 1 + Lane 9, ...
+/// Step 3 (offset=4): Lane 0 ← Lane 0 + Lane 4, Lane 1 ← Lane 1 + Lane 5, ...
+/// Step 4 (offset=2): Lane 0 ← Lane 0 + Lane 2, Lane 1 ← Lane 1 + Lane 3, ...
+/// Step 5 (offset=1): Lane 0 ← Lane 0 + Lane 1, ...
+/// ```
+///
+/// # Performance
+/// - ~5 cycles per shuffle × 5 steps = ~25 cycles total
+/// - Compare to shared memory: ~110 (write) + ~110 (read) × log2(32) = ~1100+ cycles
+///
+/// # Example
+/// ```ignore
+/// #[cube]
+/// fn warp_sum_example(value: f32) -> f32 {
+///     reduce_sum_shuffle(value) // All lanes get the sum
+/// }
+/// ```
+#[cube]
+pub fn reduce_sum_shuffle<F: Float>(value: F) -> F {
+    // Manually unrolled butterfly reduction
+    let v1 = value + plane_shuffle_xor(value, 16);
+    let v2 = v1 + plane_shuffle_xor(v1, 8);
+    let v3 = v2 + plane_shuffle_xor(v2, 4);
+    let v4 = v3 + plane_shuffle_xor(v3, 2);
+    v4 + plane_shuffle_xor(v4, 1)
+}
+
+/// Warp-level max reduction using shuffle operations.
+/// All lanes get the maximum of all 32 values in the warp.
+#[cube]
+pub fn reduce_max_shuffle<F: Float>(value: F) -> F {
+    let v1 = F::max(value, plane_shuffle_xor(value, 16));
+    let v2 = F::max(v1, plane_shuffle_xor(v1, 8));
+    let v3 = F::max(v2, plane_shuffle_xor(v2, 4));
+    let v4 = F::max(v3, plane_shuffle_xor(v3, 2));
+    F::max(v4, plane_shuffle_xor(v4, 1))
+}
+
+/// Warp-level min reduction using shuffle operations.
+/// All lanes get the minimum of all 32 values in the warp.
+#[cube]
+pub fn reduce_min_shuffle<F: Float>(value: F) -> F {
+    let v1 = F::min(value, plane_shuffle_xor(value, 16));
+    let v2 = F::min(v1, plane_shuffle_xor(v1, 8));
+    let v3 = F::min(v2, plane_shuffle_xor(v2, 4));
+    let v4 = F::min(v3, plane_shuffle_xor(v3, 2));
+    F::min(v4, plane_shuffle_xor(v4, 1))
+}
+
+/// Warp-level product reduction using shuffle operations.
+/// All lanes get the product of all 32 values in the warp.
+#[cube]
+pub fn reduce_prod_shuffle<F: Float>(value: F) -> F {
+    let v1 = value * plane_shuffle_xor(value, 16);
+    let v2 = v1 * plane_shuffle_xor(v1, 8);
+    let v3 = v2 * plane_shuffle_xor(v2, 4);
+    let v4 = v3 * plane_shuffle_xor(v3, 2);
+    v4 * plane_shuffle_xor(v4, 1)
+}
diff --git a/crates/cubecl-reduce/src/shared_sum.rs b/crates/cubecl-reduce/src/shared_sum.rs
index 8a5935090..b12e59e8e 100644
--- a/crates/cubecl-reduce/src/shared_sum.rs
+++ b/crates/cubecl-reduce/src/shared_sum.rs
@@ -53,7 +53,7 @@ use crate::ReduceError;
 /// }
 /// ```
 pub fn shared_sum<R: Runtime, N: Numeric>(
-    client: &ComputeClient<R::Server, R::Channel>,
+    client: &ComputeClient<R::Server>,
     input: TensorHandleRef<R>,
     output: TensorHandleRef<R>,
     cube_count: u32,
@@ -71,8 +71,7 @@ pub fn shared_sum<R: Runtime, N: Numeric>(
     let input_len = input.shape.iter().map(|s| *s as u32).product::<u32>();
 
     // Compute the optimal line size.
- let elem = N::as_type_native_unchecked(); - let line_size = R::io_optimized_line_sizes_unchecked(&elem) + let line_size = R::io_optimized_line_sizes_unchecked(size_of::()) .filter(|line_size| input_len % *line_size as u32 == 0) .max() .unwrap_or(1) as u32; diff --git a/crates/cubecl-reduce/src/strategy.rs b/crates/cubecl-reduce/src/strategy.rs index 98944ee27..8ef3122b0 100644 --- a/crates/cubecl-reduce/src/strategy.rs +++ b/crates/cubecl-reduce/src/strategy.rs @@ -19,7 +19,7 @@ pub struct ReduceStrategy { impl ReduceStrategy { pub fn validate( self, - client: &ComputeClient, + client: &ComputeClient, ) -> Result { if self.use_planes { if !support_plane::(client) { @@ -33,7 +33,7 @@ impl ReduceStrategy { Ok(self) } - pub fn new(client: &ComputeClient, shared: bool) -> Self { + pub fn new(client: &ComputeClient, shared: bool) -> Self { Self { use_planes: support_plane::(client) && precise_plane_dim::(client), shared, @@ -41,11 +41,11 @@ impl ReduceStrategy { } } -fn support_plane(client: &ComputeClient) -> bool { +fn support_plane(client: &ComputeClient) -> bool { client.properties().features.plane.contains(Plane::Ops) } -fn precise_plane_dim(client: &ComputeClient) -> bool { +fn precise_plane_dim(client: &ComputeClient) -> bool { let hw_props = &client.properties().hardware; hw_props.plane_size_min == hw_props.plane_size_max } diff --git a/crates/cubecl-reduce/src/test_shuffle.rs b/crates/cubecl-reduce/src/test_shuffle.rs new file mode 100644 index 000000000..6efe004b2 --- /dev/null +++ b/crates/cubecl-reduce/src/test_shuffle.rs @@ -0,0 +1,243 @@ +#![allow(missing_docs)] + +use cubecl_core as cubecl; +use cubecl_core::prelude::*; + +use crate::primitives::{ + reduce_max_shuffle, reduce_min_shuffle, reduce_prod_shuffle, reduce_sum_shuffle, +}; + +/// Test kernel: Each warp sums its lane IDs using shuffle reduction +/// Expected: All 32 lanes in each warp should get 496 (sum of 0..31) +#[cube(launch)] +fn kernel_warp_sum_lanes(output: &mut Tensor) { + let lane_id = UNIT_POS_PLANE; + let my_value: F = F::cast_from(lane_id); + + // Butterfly reduction - all lanes get the sum + let sum: F = reduce_sum_shuffle::(my_value); + + output[ABSOLUTE_POS] = sum; +} + +/// Test kernel: Find max lane ID in each warp (should be 31) +#[cube(launch)] +fn kernel_warp_max_lanes(output: &mut Tensor) { + let lane_id = UNIT_POS_PLANE; + let my_value: F = F::cast_from(lane_id); + + let max_val: F = reduce_max_shuffle::(my_value); + + output[ABSOLUTE_POS] = max_val; +} + +/// Test kernel: Find min lane ID in each warp (should be 0) +#[cube(launch)] +fn kernel_warp_min_lanes(output: &mut Tensor) { + let lane_id = UNIT_POS_PLANE; + let my_value: F = F::cast_from(lane_id); + + let min_val: F = reduce_min_shuffle::(my_value); + + output[ABSOLUTE_POS] = min_val; +} + +/// Test kernel: Product of small values to avoid overflow +/// Each lane contributes (1.0 + lane_id / 100.0) +#[cube(launch)] +fn kernel_warp_prod(output: &mut Tensor) { + let lane_id = UNIT_POS_PLANE; + let my_value: F = F::new(1.0) + F::cast_from(lane_id) / F::new(100.0); + + let prod: F = reduce_prod_shuffle::(my_value); + + output[ABSOLUTE_POS] = prod; +} + +/// Reduce a 32x32 matrix where each warp reduces its row +#[cube(launch)] +fn kernel_matrix_row_reduce(input: &Tensor, output: &mut Tensor) { + let row = CUBE_POS_Y; + let col = UNIT_POS_PLANE; + + let value: F = input[row * 32 + col]; + let row_sum: F = reduce_sum_shuffle::(value); + + // Only lane 0 writes the result + if col == 0 { + output[row] = row_sum; + } +} + +/// Test warp sum 
reduction +pub fn test_warp_sum(device: &R::Device) { + if !supports_plane_ops::(device) { + return; // Skip if no plane support + } + + let client = R::client(device); + let output_handle = client.create(f32::as_bytes(&vec![0.0f32; 64])); // 2 warps + + unsafe { + kernel_warp_sum_lanes::launch::( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::new(64, 1, 1), // 2 warps of 32 threads + TensorArg::from_raw_parts::(&output_handle, &[1], &[64], 1), + ); + } + + let bytes = client.read_one(output_handle); + let output = f32::from_bytes(&bytes); + + // Sum of 0..31 = 496 + let expected_sum = 496.0f32; + + for (i, &value) in output.iter().enumerate() { + assert!( + (value - expected_sum).abs() < 1e-3, + "Warp sum failed at position {i}: got {value}, expected {expected_sum}" + ); + } +} + +/// Test warp max reduction +pub fn test_warp_max(device: &R::Device) { + if !supports_plane_ops::(device) { + return; + } + + let client = R::client(device); + let output_handle = client.create(f32::as_bytes(&vec![0.0f32; 64])); + + unsafe { + kernel_warp_max_lanes::launch::( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::new(64, 1, 1), + TensorArg::from_raw_parts::(&output_handle, &[1], &[64], 1), + ); + } + + let bytes = client.read_one(output_handle); + let output = f32::from_bytes(&bytes); + + // Max lane ID is 31 + for (i, &value) in output.iter().enumerate() { + assert!( + (value - 31.0).abs() < 1e-3, + "Warp max failed at position {i}: got {value}, expected 31" + ); + } +} + +/// Test warp min reduction +pub fn test_warp_min(device: &R::Device) { + if !supports_plane_ops::(device) { + return; + } + + let client = R::client(device); + let output_handle = client.create(f32::as_bytes(&vec![999.0f32; 64])); + + unsafe { + kernel_warp_min_lanes::launch::( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::new(64, 1, 1), + TensorArg::from_raw_parts::(&output_handle, &[1], &[64], 1), + ); + } + + let bytes = client.read_one(output_handle); + let output = f32::from_bytes(&bytes); + + // Min lane ID is 0 + for (i, &value) in output.iter().enumerate() { + assert!( + value.abs() < 1e-3, + "Warp min failed at position {i}: got {value}, expected 0" + ); + } +} + +/// Test warp product reduction +pub fn test_warp_prod(device: &R::Device) { + if !supports_plane_ops::(device) { + return; + } + + let client = R::client(device); + let output_handle = client.create(f32::as_bytes(&[0.0f32; 32])); + + unsafe { + kernel_warp_prod::launch::( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::new(32, 1, 1), + TensorArg::from_raw_parts::(&output_handle, &[1], &[32], 1), + ); + } + + let bytes = client.read_one(output_handle); + let output = f32::from_bytes(&bytes); + + // Calculate expected product: Π(1 + i/100) for i=0..31 + let mut expected = 1.0f32; + for i in 0..32 { + expected *= 1.0 + (i as f32) / 100.0; + } + + for (i, &value) in output.iter().enumerate() { + let rel_error = ((value - expected) / expected).abs(); + assert!( + rel_error < 0.01, // 1% tolerance + "Warp prod failed at position {i}: got {value}, expected {expected}, rel_error={rel_error}" + ); + } +} + +/// Reduce 32 rows of 32 elements each using warp shuffles +pub fn test_matrix_row_reduce(device: &R::Device) { + if !supports_plane_ops::(device) { + return; + } + + let client = R::client(device); + + // Create a 32x32 matrix where matrix[i][j] = i * 32 + j + let input_data: Vec = (0..1024).map(|x| x as f32).collect(); + let input_handle = client.create(f32::as_bytes(&input_data)); + let output_handle = client.create(f32::as_bytes(&[0.0f32; 
32])); + + unsafe { + kernel_matrix_row_reduce::launch::( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::new(32, 32, 1), // 32x32 = 1024 threads, 32 warps + TensorArg::from_raw_parts::(&input_handle, &[1], &[1024], 1), + TensorArg::from_raw_parts::(&output_handle, &[1], &[32], 1), + ); + } + + let bytes = client.read_one(output_handle); + let output = f32::from_bytes(&bytes); + + // Row i should sum to: Σ(i*32 + j) for j=0..31 = i*32*32 + 496 + for (row, &value) in output.iter().enumerate() { + let expected = (row as f32) * 32.0 * 32.0 + 496.0; + assert!( + (value - expected).abs() < 1e-2, + "Matrix row reduce failed at row {row}: got {value}, expected {expected}" + ); + } +} + +fn supports_plane_ops(device: &R::Device) -> bool { + let client = R::client(device); + client + .properties() + .features + .plane + .contains(cubecl_runtime::Plane::Ops) +} diff --git a/crates/cubecl-runtime/Cargo.toml b/crates/cubecl-runtime/Cargo.toml index 5f97a843f..bf9b94e1e 100644 --- a/crates/cubecl-runtime/Cargo.toml +++ b/crates/cubecl-runtime/Cargo.toml @@ -32,8 +32,8 @@ storage-bytes = [] async-channel = { workspace = true } # Assume std bytemuck = { workspace = true } cfg-if = { workspace = true } -cubecl-common = { path = "../cubecl-common", version = "0.7.0", default-features = false } -cubecl-ir = { path = "../cubecl-ir", version = "0.7.0", default-features = false } +cubecl-common = { path = "../cubecl-common", version = "0.9.0", default-features = false } +cubecl-ir = { path = "../cubecl-ir", version = "0.9.0", default-features = false } derive-new = { workspace = true } dirs = { workspace = true, optional = true } enumset = { workspace = true } @@ -47,7 +47,7 @@ variadics_please = { workspace = true } # Persistent cache deps - has to match the cfg(std_io) cfg. 
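The expected values in the tests above follow directly from the butterfly pattern documented on `reduce_sum_shuffle`: after log2(32) = 5 XOR-shuffle steps every lane holds the full reduction, so each lane sees 0 + 1 + … + 31 = 496. A CPU-side reference model of that data flow, useful for sanity-checking the kernels without a GPU (`lane ^ offset` plays the role of `plane_shuffle_xor`; this is an illustrative sketch, not part of the patch):

```rust
// Models one 32-lane warp: each step combines every lane with its XOR partner.
fn butterfly_sum_reference(lanes: [f32; 32]) -> [f32; 32] {
    let mut v = lanes;
    for offset in [16usize, 8, 4, 2, 1] {
        let prev = v;
        for lane in 0..32 {
            // Same exchange `plane_shuffle_xor(v, offset)` performs on device.
            v[lane] = prev[lane] + prev[lane ^ offset];
        }
    }
    v
}

fn main() {
    // Mirrors `kernel_warp_sum_lanes`: each lane contributes its lane id.
    let lanes: [f32; 32] = core::array::from_fn(|i| i as f32);
    let reduced = butterfly_sum_reference(lanes);
    // Sum of 0..=31 is 496, and every lane ends up holding it.
    assert!(reduced.iter().all(|&v| (v - 496.0).abs() < 1e-3));
}
```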
[target.'cfg(any(target_os = "windows", target_os = "linux", target_os = "macos"))'.dependencies] -cubecl-common = { path = "../cubecl-common", version = "0.7.0", default-features = false, features = [ +cubecl-common = { path = "../cubecl-common", version = "0.9.0", default-features = false, features = [ "cache", "serde", ] } diff --git a/crates/cubecl-runtime/benches/dynamic.rs b/crates/cubecl-runtime/benches/dynamic.rs index 719bc221c..ca44aa697 100644 --- a/crates/cubecl-runtime/benches/dynamic.rs +++ b/crates/cubecl-runtime/benches/dynamic.rs @@ -1,7 +1,10 @@ -use std::collections::LinkedList; +use std::{collections::LinkedList, sync::Arc}; use cubecl_runtime::{ - memory_management::{MemoryConfiguration, MemoryDeviceProperties, MemoryManagement}, + logging::ServerLogger, + memory_management::{ + MemoryConfiguration, MemoryDeviceProperties, MemoryManagement, MemoryManagementOptions, + }, storage::BytesStorage, }; @@ -15,7 +18,14 @@ fn main() { max_page_size: 2048 * MB, alignment: 32, }; - let mut mm = MemoryManagement::from_configuration(storage, &mem_props, config); + let logger = Arc::new(ServerLogger::default()); + let mut mm = MemoryManagement::from_configuration( + storage, + &mem_props, + config, + logger, + MemoryManagementOptions::new("test"), + ); let mut handles = LinkedList::new(); for _ in 0..100 * 2048 { if handles.len() >= 4000 { diff --git a/crates/cubecl-runtime/src/base.rs b/crates/cubecl-runtime/src/base.rs deleted file mode 100644 index cc30ae030..000000000 --- a/crates/cubecl-runtime/src/base.rs +++ /dev/null @@ -1,94 +0,0 @@ -use crate::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer}; -use core::ops::DerefMut; -use hashbrown::HashMap; - -/// The compute type has the responsibility to retrieve the correct compute client based on the -/// given device. -pub struct ComputeRuntime { - clients: spin::Mutex>>>, -} - -impl Default for ComputeRuntime -where - Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug, - Server: ComputeServer, - Channel: ComputeChannel, -{ - fn default() -> Self { - Self::new() - } -} - -impl ComputeRuntime -where - Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug, - Server: ComputeServer, - Channel: ComputeChannel, -{ - /// Create a new compute. - pub const fn new() -> Self { - Self { - clients: spin::Mutex::new(None), - } - } - - /// Get the compute client for the given device. - /// - /// Provide the init function to create a new client if it isn't already initialized. - pub fn client(&self, device: &Device, init: Init) -> ComputeClient - where - Init: Fn() -> ComputeClient, - { - let mut clients = self.clients.lock(); - - if clients.is_none() { - Self::register_inner(device, init(), &mut clients); - } - - match clients.deref_mut() { - Some(clients) => match clients.get(device) { - Some(client) => client.clone(), - None => { - let client = init(); - clients.insert(device.clone(), client.clone()); - client - } - }, - _ => unreachable!(), - } - } - - /// Register the compute client for the given device. - /// - /// # Note - /// - /// This function is mostly useful when the creation of the compute client can't be done - /// synchronously and require special context. - /// - /// # Panics - /// - /// If a client is already registered for the given device. 
- pub fn register(&self, device: &Device, client: ComputeClient) { - let mut clients = self.clients.lock(); - - Self::register_inner(device, client, &mut clients); - } - - fn register_inner( - device: &Device, - client: ComputeClient, - clients: &mut Option>>, - ) { - if clients.is_none() { - *clients = Some(HashMap::new()); - } - - if let Some(clients) = clients { - if clients.contains_key(device) { - panic!("Client already created for device {device:?}"); - } - - clients.insert(device.clone(), client); - } - } -} diff --git a/crates/cubecl-runtime/src/channel/base.rs b/crates/cubecl-runtime/src/channel/base.rs deleted file mode 100644 index 2dfd758bc..000000000 --- a/crates/cubecl-runtime/src/channel/base.rs +++ /dev/null @@ -1,107 +0,0 @@ -use crate::{ - logging::ServerLogger, - memory_management::MemoryAllocationMode, - server::{ - Allocation, AllocationDescriptor, Binding, Bindings, ComputeServer, CopyDescriptor, - CubeCount, IoError, ProfileError, ProfilingToken, - }, - storage::{BindingResource, ComputeStorage}, -}; -use alloc::sync::Arc; -use alloc::vec::Vec; -use cubecl_common::{ - ExecutionMode, bytes::Bytes, future::DynFut, profile::ProfileDuration, stream_id::StreamId, -}; - -/// The ComputeChannel trait links the ComputeClient to the ComputeServer -/// while ensuring thread-safety -pub trait ComputeChannel: Clone + core::fmt::Debug + Send + Sync { - /// Whether the channel supports changing an allocation from one server to another. - const SERVER_COMM_SUPPORTED: bool; - - /// Retrieve the server logger. - fn logger(&self) -> Arc; - - /// Create a new handle given a set of descriptors - fn create( - &self, - descriptors: Vec>, - stream_id: StreamId, - ) -> Result, IoError>; - - /// Given bindings, returns owned resources as bytes - fn read( - &self, - descriptors: Vec>, - stream_id: StreamId, - ) -> DynFut, IoError>>; - - /// Write bytes to each binding - fn write( - &self, - descriptors: Vec<(CopyDescriptor<'_>, &[u8])>, - stream_id: StreamId, - ) -> Result<(), IoError>; - - /// Moves data from the source server to the destination server on the provided stream IDs. - fn copy( - server_src: &Self, - server_dst: &Self, - src: CopyDescriptor<'_>, - stream_id_src: StreamId, - stream_id_dst: StreamId, - ) -> Result; - - /// Wait for the completion of every task in the server. - fn sync(&self, stream_id: StreamId) -> DynFut<()>; - - /// Given a resource handle, return the storage resource. - fn get_resource( - &self, - binding: Binding, - stream_id: StreamId, - ) -> BindingResource<::Resource>; - - /// Executes the `kernel` over the given `bindings`. - /// - /// Optionally returns some debug information about the compilation to be logged. - /// # Safety - /// - /// When executing with mode [ExecutionMode::Unchecked], out-of-bound reads and writes can happen. - unsafe fn execute( - &self, - kernel: Server::Kernel, - count: CubeCount, - bindings: Bindings, - mode: ExecutionMode, - stream_id: StreamId, - ); - - /// Flush outstanding work of the server. - fn flush(&self, stream_id: StreamId); - - /// Get the current memory usage of the server. - fn memory_usage(&self, stream_id: StreamId) -> crate::memory_management::MemoryUsage; - - /// Change the memory allocation mode. - fn allocation_mode(&self, mode: MemoryAllocationMode, stream_id: StreamId); - - /// Ask the server to release memory that it can release. - fn memory_cleanup(&self, stream_id: StreamId); - - /// Start a profile on the server. This allows you to profile kernels. 
- /// - /// This will measure execution time either by measuring the 'full' execution time by synchronizing - /// the execution at the start and the end of the profile, or 'device' time by using device timestamps. - /// This function will handle any required synchronization. - fn start_profile(&self, stream_id: StreamId) -> ProfilingToken; - - /// End the profile and return a [`ProfileDuration`]. - /// - /// You can retrieve the Duration of the client profile asynchronously. This function will handle any required synchronization. - fn end_profile( - &self, - stream_id: StreamId, - token: ProfilingToken, - ) -> Result; -} diff --git a/crates/cubecl-runtime/src/channel/cell.rs b/crates/cubecl-runtime/src/channel/cell.rs deleted file mode 100644 index 68e0a0563..000000000 --- a/crates/cubecl-runtime/src/channel/cell.rs +++ /dev/null @@ -1,172 +0,0 @@ -use super::ComputeChannel; -use crate::server::{ - Binding, Bindings, ComputeServer, CopyDescriptor, CubeCount, ProfileError, ProfilingToken, -}; -use crate::storage::{BindingResource, ComputeStorage}; -use crate::{ - logging::ServerLogger, - server::{Allocation, AllocationDescriptor, IoError}, -}; -use alloc::sync::Arc; -use alloc::vec::Vec; -use cubecl_common::ExecutionMode; -use cubecl_common::bytes::Bytes; -use cubecl_common::future::DynFut; -use cubecl_common::profile::ProfileDuration; -use cubecl_common::stream_id::StreamId; - -/// A channel using a [ref cell](core::cell::RefCell) to access the server with mutability. -/// -/// # Important -/// -/// Only use this channel if you don't use any threading in your application, otherwise it will -/// panic or cause undefined behaviors. -/// -/// This is mosly useful for `no-std` environments where threads aren't supported, otherwise prefer -/// the [mutex](super::MutexComputeChannel) or the [mpsc](super::MpscComputeChannel) channels. -#[derive(Debug)] -pub struct RefCellComputeChannel { - server: Arc>, -} - -impl Clone for RefCellComputeChannel { - fn clone(&self) -> Self { - Self { - server: self.server.clone(), - } - } -} - -impl RefCellComputeChannel -where - Server: ComputeServer, -{ - /// Create a new cell compute channel. 
- pub fn new(server: Server) -> Self { - Self { - server: Arc::new(core::cell::RefCell::new(server)), - } - } -} - -impl ComputeChannel for RefCellComputeChannel -where - Server: ComputeServer + Send, -{ - const SERVER_COMM_SUPPORTED: bool = true; - - fn logger(&self) -> Arc { - todo!(); - } - - fn create( - &self, - descriptors: Vec>, - stream_id: StreamId, - ) -> Result, IoError> { - let mut server = self.server.borrow_mut(); - server.create(descriptors, stream_id) - } - - fn read( - &self, - descriptors: Vec>, - stream_id: StreamId, - ) -> DynFut, IoError>> { - let mut server = self.server.borrow_mut(); - server.read(descriptors, stream_id) - } - - fn write( - &self, - descriptors: Vec<(CopyDescriptor<'_>, &[u8])>, - stream_id: StreamId, - ) -> Result<(), IoError> { - let mut server = self.server.borrow_mut(); - server.write(descriptors, stream_id) - } - - fn sync(&self, stream_id: StreamId) -> DynFut<()> { - let mut server = self.server.borrow_mut(); - server.sync(stream_id) - } - - fn get_resource( - &self, - binding: Binding, - stream_id: StreamId, - ) -> BindingResource<::Resource> { - self.server.borrow_mut().get_resource(binding, stream_id) - } - - unsafe fn execute( - &self, - kernel_description: Server::Kernel, - count: CubeCount, - bindings: Bindings, - kind: ExecutionMode, - stream_id: StreamId, - ) { - unsafe { - self.server - .borrow_mut() - .execute(kernel_description, count, bindings, kind, stream_id) - } - } - - fn flush(&self, stream_id: StreamId) { - self.server.borrow_mut().flush(stream_id) - } - - fn memory_usage(&self, stream_id: StreamId) -> crate::memory_management::MemoryUsage { - self.server.borrow_mut().memory_usage(stream_id) - } - - fn memory_cleanup(&self, stream_id: StreamId) { - self.server.borrow_mut().memory_cleanup(stream_id); - } - - fn start_profile(&self, stream_id: StreamId) -> ProfilingToken { - self.server.borrow_mut().start_profile(stream_id) - } - - fn end_profile( - &self, - stream_id: StreamId, - token: ProfilingToken, - ) -> Result { - self.server.borrow_mut().end_profile(stream_id, token) - } - - fn allocation_mode( - &self, - mode: crate::memory_management::MemoryAllocationMode, - stream_id: StreamId, - ) { - self.server.borrow_mut().allocation_mode(mode, stream_id) - } - - fn copy( - server_src: &Self, - server_dst: &Self, - src: CopyDescriptor<'_>, - stream_id_src: StreamId, - stream_id_dst: StreamId, - ) -> Result { - let mut server_src = server_src.server.borrow_mut(); - let mut server_dst = server_dst.server.borrow_mut(); - - Server::copy( - &mut server_src, - &mut server_dst, - src, - stream_id_src, - stream_id_dst, - ) - } -} - -/// This is unsafe, since no concurrency is supported by the `RefCell` channel. -/// However using this channel should only be done in single threaded environments such as `no-std`. 
-unsafe impl Send for RefCellComputeChannel {} -unsafe impl Sync for RefCellComputeChannel {} diff --git a/crates/cubecl-runtime/src/channel/mod.rs b/crates/cubecl-runtime/src/channel/mod.rs deleted file mode 100644 index 68a1c372a..000000000 --- a/crates/cubecl-runtime/src/channel/mod.rs +++ /dev/null @@ -1,17 +0,0 @@ -mod base; -pub use base::*; - -#[cfg(feature = "channel-mutex")] -mod mutex; -#[cfg(feature = "channel-mutex")] -pub use mutex::*; - -#[cfg(all(feature = "channel-mpsc", not(target_family = "wasm")))] -mod mpsc; -#[cfg(all(feature = "channel-mpsc", not(target_family = "wasm")))] -pub use mpsc::*; - -#[cfg(feature = "channel-cell")] -mod cell; -#[cfg(feature = "channel-cell")] -pub use cell::*; diff --git a/crates/cubecl-runtime/src/channel/mpsc.rs b/crates/cubecl-runtime/src/channel/mpsc.rs deleted file mode 100644 index f21b2404c..000000000 --- a/crates/cubecl-runtime/src/channel/mpsc.rs +++ /dev/null @@ -1,408 +0,0 @@ -use std::sync::Arc; - -use cubecl_common::{ - ExecutionMode, - bytes::Bytes, - future::{DynFut, spawn_detached_fut}, - profile::ProfileDuration, - stream_id::StreamId, -}; - -use super::ComputeChannel; -use crate::{ - logging::ServerLogger, - memory_management::{MemoryAllocationMode, MemoryUsage}, - server::{ - Allocation, AllocationDescriptor, AllocationKind, Binding, Bindings, ComputeServer, - CopyDescriptor, CubeCount, IoError, ProfileError, ProfilingToken, - }, - storage::{BindingResource, ComputeStorage}, -}; - -/// Create a channel using a [multi-producer, single-consumer channel to communicate with -/// the compute server spawn on its own thread. -#[derive(Debug)] -pub struct MpscComputeChannel -where - Server: ComputeServer, -{ - state: Arc>, -} - -#[derive(Debug)] -struct MpscComputeChannelState -where - Server: ComputeServer, -{ - sender: async_channel::Sender>, -} - -type Callback = async_channel::Sender; - -struct AllocationDescriptorOwned { - type_: AllocationKind, - shape: Vec, - elem_size: usize, -} - -impl From> for AllocationDescriptorOwned { - fn from(value: AllocationDescriptor) -> Self { - AllocationDescriptorOwned { - type_: value.kind, - shape: value.shape.to_vec(), - elem_size: value.elem_size, - } - } -} - -impl AllocationDescriptorOwned { - fn as_ref(&self) -> AllocationDescriptor<'_> { - AllocationDescriptor::new(self.type_, &self.shape, self.elem_size) - } -} - -struct CopyDescriptorOwned { - binding: Binding, - shape: Vec, - strides: Vec, - elem_size: usize, -} - -impl From> for CopyDescriptorOwned { - fn from(value: CopyDescriptor<'_>) -> Self { - CopyDescriptorOwned { - binding: value.binding, - shape: value.shape.to_vec(), - strides: value.strides.to_vec(), - elem_size: value.elem_size, - } - } -} - -impl CopyDescriptorOwned { - fn as_ref(&self) -> CopyDescriptor<'_> { - CopyDescriptor::new( - self.binding.clone(), - &self.shape, - &self.strides, - self.elem_size, - ) - } -} - -enum Message -where - Server: ComputeServer, -{ - Create( - Vec, - StreamId, - Callback, IoError>>, - ), - Read( - Vec, - StreamId, - Callback, IoError>>, - ), - Write( - Vec<(CopyDescriptorOwned, Vec)>, - StreamId, - Callback>, - ), - GetResource( - Binding, - StreamId, - Callback::Resource>>, - ), - Logger(Callback>), - ExecuteKernel( - (Server::Kernel, CubeCount, ExecutionMode, StreamId), - Bindings, - ), - Flush(StreamId), - Sync(StreamId, Callback<()>), - MemoryUsage(StreamId, Callback), - MemoryCleanup(StreamId), - AllocationMode(StreamId, MemoryAllocationMode), - StartProfile(StreamId, Callback), - StopMeasure( - Callback>, - StreamId, - 
ProfilingToken, - ), -} - -impl MpscComputeChannel -where - Server: ComputeServer + 'static, -{ - /// Create a new mpsc compute channel. - pub fn new(mut server: Server) -> Self { - let (sender, receiver) = async_channel::unbounded(); - - spawn_detached_fut(async move { - while let Ok(message) = receiver.recv().await { - match message { - Message::Create(descriptors, stream_id, callback) => { - let descriptors = descriptors.iter().map(|it| it.as_ref()).collect(); - let data = server.create(descriptors, stream_id); - callback.send(data).await.unwrap(); - } - Message::Read(descriptors, stream, callback) => { - let descriptors = descriptors.iter().map(|it| it.as_ref()).collect(); - let data = server.read(descriptors, stream).await; - callback.send(data).await.unwrap(); - } - Message::Logger(callback) => { - callback.send(server.logger()).await.unwrap(); - } - Message::Write(descriptors, stream, callback) => { - let descriptors = descriptors - .iter() - .map(|(desc, data)| (desc.as_ref(), data.as_slice())) - .collect(); - let data = server.write(descriptors, stream); - callback.send(data).await.unwrap(); - } - Message::GetResource(binding, stream, callback) => { - let data = server.get_resource(binding, stream); - callback.send(data).await.unwrap(); - } - Message::ExecuteKernel(kernel, bindings) => unsafe { - server.execute(kernel.0, kernel.1, bindings, kernel.2, kernel.3); - }, - Message::Sync(stream, callback) => { - server.sync(stream).await; - callback.send(()).await.unwrap(); - } - Message::Flush(stream) => { - server.flush(stream); - } - Message::MemoryUsage(stream, callback) => { - callback.send(server.memory_usage(stream)).await.unwrap(); - } - Message::MemoryCleanup(stream) => { - server.memory_cleanup(stream); - } - Message::StartProfile(stream_id, callback) => { - let token = server.start_profile(stream_id); - callback.send(token).await.unwrap(); - } - Message::StopMeasure(callback, stream_id, token) => { - callback - .send(server.end_profile(stream_id, token)) - .await - .unwrap(); - } - Message::AllocationMode(stream, mode) => { - server.allocation_mode(mode, stream); - } - }; - } - }); - - Self { - state: Arc::new(MpscComputeChannelState { sender }), - } - } -} - -impl Clone for MpscComputeChannel { - fn clone(&self) -> Self { - Self { - state: self.state.clone(), - } - } -} - -impl ComputeChannel for MpscComputeChannel -where - Server: ComputeServer + 'static, -{ - const SERVER_COMM_SUPPORTED: bool = false; - - fn logger(&self) -> Arc { - let (callback, response) = async_channel::unbounded(); - - self.state - .sender - .send_blocking(Message::Logger(callback)) - .unwrap(); - - handle_response(response.recv_blocking()) - } - - fn create( - &self, - descriptors: Vec>, - stream_id: StreamId, - ) -> Result, IoError> { - let descriptors = descriptors.into_iter().map(|it| it.into()).collect(); - - let (callback, response) = async_channel::unbounded(); - - self.state - .sender - .send_blocking(Message::Create(descriptors, stream_id, callback)) - .unwrap(); - - handle_response(response.recv_blocking()) - } - - fn read( - &self, - descriptors: Vec>, - stream_id: StreamId, - ) -> DynFut, IoError>> { - let sender = self.state.sender.clone(); - let descriptors = descriptors.into_iter().map(|it| it.into()).collect(); - - Box::pin(async move { - let (callback, response) = async_channel::unbounded(); - sender - .send(Message::Read(descriptors, stream_id, callback)) - .await - .unwrap(); - handle_response(response.recv().await) - }) - } - - fn write( - &self, - descriptors: 
Vec<(CopyDescriptor<'_>, &[u8])>, - stream_id: StreamId, - ) -> Result<(), IoError> { - let descriptors = descriptors - .into_iter() - .map(|(desc, data)| (desc.into(), data.to_vec())) - .collect(); - - let (callback, response) = async_channel::unbounded(); - - self.state - .sender - .send_blocking(Message::Write(descriptors, stream_id, callback)) - .unwrap(); - - handle_response(response.recv_blocking()) - } - - fn get_resource( - &self, - binding: Binding, - stream_id: StreamId, - ) -> BindingResource<::Resource> { - let (callback, response) = async_channel::unbounded(); - - self.state - .sender - .send_blocking(Message::GetResource(binding, stream_id, callback)) - .unwrap(); - - handle_response(response.recv_blocking()) - } - - unsafe fn execute( - &self, - kernel: Server::Kernel, - count: CubeCount, - bindings: Bindings, - kind: ExecutionMode, - stream_id: StreamId, - ) { - self.state - .sender - .send_blocking(Message::ExecuteKernel( - (kernel, count, kind, stream_id), - bindings, - )) - .unwrap(); - } - - fn flush(&self, stream_id: StreamId) { - self.state - .sender - .send_blocking(Message::Flush(stream_id)) - .unwrap() - } - - fn sync(&self, stream_id: StreamId) -> DynFut<()> { - let sender = self.state.sender.clone(); - - Box::pin(async move { - let (callback, response) = async_channel::unbounded(); - sender - .send(Message::Sync(stream_id, callback)) - .await - .unwrap(); - handle_response(response.recv().await) - }) - } - - fn memory_usage(&self, stream_id: StreamId) -> crate::memory_management::MemoryUsage { - let (callback, response) = async_channel::unbounded(); - self.state - .sender - .send_blocking(Message::MemoryUsage(stream_id, callback)) - .unwrap(); - handle_response(response.recv_blocking()) - } - - fn memory_cleanup(&self, stream_id: StreamId) { - self.state - .sender - .send_blocking(Message::MemoryCleanup(stream_id)) - .unwrap() - } - - fn start_profile(&self, stream_id: StreamId) -> ProfilingToken { - let (callback, response) = async_channel::unbounded(); - - self.state - .sender - .send_blocking(Message::StartProfile(stream_id, callback)) - .unwrap(); - - handle_response(response.recv_blocking()) - } - - fn end_profile( - &self, - stream_id: StreamId, - token: ProfilingToken, - ) -> Result { - let (callback, response) = async_channel::unbounded(); - self.state - .sender - .send_blocking(Message::StopMeasure(callback, stream_id, token)) - .unwrap(); - handle_response(response.recv_blocking()) - } - - fn allocation_mode( - &self, - mode: crate::memory_management::MemoryAllocationMode, - stream_id: StreamId, - ) { - self.state - .sender - .send_blocking(Message::AllocationMode(stream_id, mode)) - .unwrap() - } - fn copy( - _server_src: &Self, - _server_dst: &Self, - _src: CopyDescriptor<'_>, - _stream_id_src: StreamId, - _stream_id_dst: StreamId, - ) -> Result { - panic!("MPSC doesn't support changing the server") - } -} - -fn handle_response(response: Result) -> Response { - match response { - Ok(val) => val, - Err(err) => panic!("Can't connect to the server correctly {err:?}"), - } -} diff --git a/crates/cubecl-runtime/src/channel/mutex.rs b/crates/cubecl-runtime/src/channel/mutex.rs deleted file mode 100644 index 1ac037ae3..000000000 --- a/crates/cubecl-runtime/src/channel/mutex.rs +++ /dev/null @@ -1,157 +0,0 @@ -use super::ComputeChannel; -use crate::memory_management::MemoryAllocationMode; -use crate::server::{ - Binding, Bindings, ComputeServer, CopyDescriptor, CubeCount, ProfileError, ProfilingToken, -}; -use crate::storage::{BindingResource, 
ComputeStorage}; -use crate::{ - logging::ServerLogger, - server::{Allocation, AllocationDescriptor, IoError}, -}; -use alloc::sync::Arc; -use alloc::vec::Vec; -use cubecl_common::ExecutionMode; -use cubecl_common::bytes::Bytes; -use cubecl_common::future::DynFut; -use cubecl_common::profile::ProfileDuration; -use cubecl_common::stream_id::StreamId; -use spin::Mutex; - -/// The MutexComputeChannel ensures thread-safety by locking the server -/// on every operation -#[derive(Debug)] -pub struct MutexComputeChannel { - server: Arc>, -} - -impl Clone for MutexComputeChannel { - fn clone(&self) -> Self { - Self { - server: self.server.clone(), - } - } -} -impl MutexComputeChannel -where - Server: ComputeServer, -{ - /// Create a new mutex compute channel. - pub fn new(server: Server) -> Self { - Self { - server: Arc::new(Mutex::new(server)), - } - } -} - -impl ComputeChannel for MutexComputeChannel -where - Server: ComputeServer, -{ - const SERVER_COMM_SUPPORTED: bool = true; - - fn logger(&self) -> Arc { - self.server.lock().logger() - } - fn create( - &self, - descriptors: Vec>, - stream_id: StreamId, - ) -> Result, IoError> { - let mut server = self.server.lock(); - server.create(descriptors, stream_id) - } - - fn read( - &self, - descriptors: Vec>, - stream_id: StreamId, - ) -> DynFut, IoError>> { - let mut server = self.server.lock(); - server.read(descriptors, stream_id) - } - - fn write( - &self, - descriptors: Vec<(CopyDescriptor<'_>, &[u8])>, - stream_id: StreamId, - ) -> Result<(), IoError> { - let mut server = self.server.lock(); - server.write(descriptors, stream_id) - } - - fn copy( - server_src: &Self, - server_dst: &Self, - src: CopyDescriptor<'_>, - stream_id_src: StreamId, - stream_id_dst: StreamId, - ) -> Result { - let mut server_src = server_src.server.lock(); - let mut server_dst = server_dst.server.lock(); - - Server::copy( - &mut server_src, - &mut server_dst, - src, - stream_id_src, - stream_id_dst, - ) - } - - fn sync(&self, stream_id: StreamId) -> DynFut<()> { - let mut server = self.server.lock(); - server.sync(stream_id) - } - - fn get_resource( - &self, - binding: Binding, - stream_id: StreamId, - ) -> BindingResource<::Resource> { - self.server.lock().get_resource(binding, stream_id) - } - - unsafe fn execute( - &self, - kernel: Server::Kernel, - count: CubeCount, - handles: Bindings, - kind: ExecutionMode, - stream_id: StreamId, - ) { - unsafe { - self.server - .lock() - .execute(kernel, count, handles, kind, stream_id) - } - } - - fn flush(&self, stream_id: StreamId) { - self.server.lock().flush(stream_id); - } - - fn memory_usage(&self, stream_id: StreamId) -> crate::memory_management::MemoryUsage { - self.server.lock().memory_usage(stream_id) - } - - fn memory_cleanup(&self, stream_id: StreamId) { - self.server.lock().memory_cleanup(stream_id); - } - - fn start_profile(&self, stream_id: StreamId) -> ProfilingToken { - self.server.lock().start_profile(stream_id) - } - - fn end_profile( - &self, - stream_id: StreamId, - token: ProfilingToken, - ) -> Result { - self.server.lock().end_profile(stream_id, token) - } - - fn allocation_mode(&self, mode: MemoryAllocationMode, stream_id: StreamId) { - let mut server = self.server.lock(); - server.allocation_mode(mode, stream_id) - } -} diff --git a/crates/cubecl-runtime/src/client.rs b/crates/cubecl-runtime/src/client.rs index ae2e7cabc..8d90f5012 100644 --- a/crates/cubecl-runtime/src/client.rs +++ b/crates/cubecl-runtime/src/client.rs @@ -1,13 +1,12 @@ use crate::{ DeviceProperties, - channel::ComputeChannel, 
config::{TypeNameFormatLevel, type_name_format}, kernel::KernelMetadata, - logging::{ProfileLevel, ServerLogger}, + logging::ProfileLevel, memory_management::{MemoryAllocationMode, MemoryUsage}, server::{ Allocation, AllocationDescriptor, AllocationKind, Binding, Bindings, ComputeServer, - CopyDescriptor, CubeCount, Handle, IoError, ProfileError, + CopyDescriptor, CubeCount, Handle, IoError, ProfileError, ServerUtilities, }, storage::{BindingResource, ComputeStorage}, }; @@ -15,7 +14,14 @@ use alloc::format; use alloc::sync::Arc; use alloc::vec; use alloc::vec::Vec; -use cubecl_common::{ExecutionMode, bytes::Bytes, profile::ProfileDuration}; +use core::ops::DerefMut; +use cubecl_common::{ + ExecutionMode, + bytes::Bytes, + device::{Device, DeviceContext}, + future::DynFut, + profile::ProfileDuration, +}; #[allow(unused)] use cubecl_common::profile::TimingMethod; @@ -23,85 +29,56 @@ use cubecl_common::stream_id::StreamId; /// The ComputeClient is the entry point to require tasks from the ComputeServer. /// It should be obtained for a specific device via the Compute struct. -pub struct ComputeClient { - channel: Channel, - state: Arc>, +pub struct ComputeClient { + context: DeviceContext, + utilities: Arc>, stream_id: Option, } -#[derive(new)] -struct ComputeClientState { - #[cfg(feature = "profile-tracy")] - epoch_time: web_time::Instant, - - #[cfg(feature = "profile-tracy")] - gpu_client: tracy_client::GpuContext, - - properties: DeviceProperties, - info: Server::Info, - logger: Arc, - - #[cfg(multi_threading)] - current_profiling: spin::RwLock>, -} - -impl Clone for ComputeClient +impl Clone for ComputeClient where S: ComputeServer, - C: ComputeChannel, { fn clone(&self) -> Self { Self { - channel: self.channel.clone(), - state: self.state.clone(), + context: self.context.clone(), + utilities: self.utilities.clone(), stream_id: self.stream_id, } } } -impl ComputeClient +impl ComputeClient where Server: ComputeServer, - Channel: ComputeChannel, { /// Get the info of the current backend. pub fn info(&self) -> &Server::Info { - &self.state.info + &self.utilities.info } - /// Create a new client. - pub fn new(channel: Channel, properties: DeviceProperties, info: Server::Info) -> Self { - let logger = channel.logger(); + /// Create a new client with a new server. + pub fn init(device: &D, server: Server) -> Self { + let utilities = server.utilities(); - // Start a tracy client if needed. - #[cfg(feature = "profile-tracy")] - let client = tracy_client::Client::start(); - - let state = ComputeClientState { - properties, - logger, - #[cfg(multi_threading)] - current_profiling: spin::RwLock::new(None), - // Create the GPU client if needed. - #[cfg(feature = "profile-tracy")] - gpu_client: client - .clone() - .new_gpu_context( - Some(&format!("{info:?}")), - // In the future should ask the server what makes sense here. 'Invalid' atm is a generic stand-in (Tracy doesn't have CUDA/RocM atm anyway). - tracy_client::GpuContextType::Invalid, - 0, // Timestamps are manually aligned to this epoch so start at 0. - 1.0, // Timestamps are manually converted to be nanoseconds so period is 1. - ) - .unwrap(), - #[cfg(feature = "profile-tracy")] - epoch_time: web_time::Instant::now(), - info, - }; + let context = DeviceContext::::insert(device, server) + .expect("Can't create a new client on an already registered server"); Self { - channel, - state: Arc::new(state), + context, + utilities, + stream_id: None, + } + } + + /// Load the client for the given device. 
+ pub fn load(device: &D) -> Self { + let context = DeviceContext::::locate(device); + let utilities = context.lock().utilities(); + + Self { + context, + utilities, stream_id: None, } } @@ -122,15 +99,16 @@ where self.stream_id = Some(stream_id); } - async fn do_read(&self, descriptors: Vec>) -> Result, IoError> { - self.profile_guard(); - + fn do_read(&self, descriptors: Vec>) -> DynFut, IoError>> { let stream_id = self.stream_id(); - self.channel.read(descriptors, stream_id).await + let mut state = self.context.lock(); + let fut = state.read(descriptors, stream_id); + core::mem::drop(state); + fut } /// Given bindings, returns owned resources as bytes. - pub async fn read_async(&self, handles: Vec) -> Vec { + pub fn read_async(&self, handles: Vec) -> impl Future> + Send { let strides = [1]; let shapes = handles .iter() @@ -146,7 +124,9 @@ where .map(|(binding, shape)| CopyDescriptor::new(binding, shape, &strides, 1)) .collect(); - self.do_read(descriptors).await.unwrap() + let fut = self.do_read(descriptors); + + async move { fut.await.unwrap() } } /// Given bindings, returns owned resources as bytes. @@ -167,8 +147,13 @@ where } /// Given bindings, returns owned resources as bytes. - pub async fn read_tensor_async(&self, descriptors: Vec>) -> Vec { - self.do_read(descriptors).await.unwrap() + pub fn read_tensor_async( + &self, + descriptors: Vec>, + ) -> impl Future> + Send { + let fut = self.do_read(descriptors); + + async move { fut.await.unwrap() } } /// Given bindings, returns owned resources as bytes. @@ -189,8 +174,13 @@ where /// Given a binding, returns owned resource as bytes. /// See [ComputeClient::read_tensor] - pub async fn read_one_tensor_async(&self, descriptor: CopyDescriptor<'_>) -> Bytes { - self.read_tensor_async(vec![descriptor]).await.remove(0) + pub fn read_one_tensor_async( + &self, + descriptor: CopyDescriptor<'_>, + ) -> impl Future + Send { + let fut = self.read_tensor_async(vec![descriptor]); + + async { fut.await.remove(0) } } /// Given a binding, returns owned resource as bytes. @@ -207,10 +197,8 @@ where &self, binding: Binding, ) -> BindingResource<::Resource> { - self.profile_guard(); - let stream_id = self.stream_id(); - self.channel.get_resource(binding, stream_id) + self.context.lock().get_resource(binding, stream_id) } fn do_create( @@ -218,9 +206,8 @@ where descriptors: Vec>, data: Vec<&[u8]>, ) -> Result, IoError> { - self.profile_guard(); - - let allocations = self.channel.create(descriptors.clone(), self.stream_id())?; + let mut state = self.context.lock(); + let allocations = state.create(descriptors.clone(), self.stream_id())?; let descriptors = descriptors .into_iter() .zip(allocations.iter()) @@ -238,7 +225,7 @@ where }) .collect(); let stream_id = self.stream_id(); - self.channel.write(descriptors, stream_id)?; + state.write(descriptors, stream_id)?; Ok(allocations) } @@ -301,9 +288,8 @@ where &self, descriptors: Vec>, ) -> Result, IoError> { - self.profile_guard(); - - self.channel.create(descriptors, self.stream_id()) + let mut state = self.context.lock(); + state.create(descriptors, self.stream_id()) } /// Reserves `size` bytes in the storage, and returns a handle over them. 
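A detail worth noting in the rewritten `do_read` above (the same shape is used for `sync` below): the device lock is held only long enough to enqueue the operation and obtain a future, then explicitly dropped before the future is returned, so awaiting the result never blocks other threads on the mutex. A self-contained sketch of that lock-scoping pattern with simplified stand-in types (this is not the real server API; the executor comes from the `futures` crate):

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};

type DynFut<T> = Pin<Box<dyn Future<Output = T> + Send>>;

struct Server {
    reads_enqueued: u32,
}

impl Server {
    // Enqueue a read and hand back a future; completion is deferred.
    fn read(&mut self) -> DynFut<u32> {
        self.reads_enqueued += 1;
        let value = self.reads_enqueued;
        Box::pin(async move { value })
    }
}

struct Client {
    context: Arc<Mutex<Server>>,
}

impl Client {
    fn do_read(&self) -> DynFut<u32> {
        let mut state = self.context.lock().unwrap();
        let fut = state.read();
        // Drop the guard *before* returning the future: whoever awaits it
        // does not hold the server lock while waiting.
        drop(state);
        fut
    }
}

fn main() {
    let client = Client {
        context: Arc::new(Mutex::new(Server { reads_enqueued: 0 })),
    };
    let value = futures::executor::block_on(client.do_read());
    assert_eq!(value, 1);
}
```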
@@ -331,7 +317,7 @@ where let shape = [src.size() as usize]; let src_descriptor = src.copy_descriptor(&shape, &[1], 1); - if Channel::SERVER_COMM_SUPPORTED && Server::SERVER_COMM_ENABLED { + if Server::SERVER_COMM_ENABLED { self.to_client_tensor(src_descriptor, dst_server) } else { let alloc_desc = AllocationDescriptor::new( @@ -351,10 +337,13 @@ where src_descriptor: CopyDescriptor<'_>, dst_server: &Self, ) -> Allocation { - if Channel::SERVER_COMM_SUPPORTED && Server::SERVER_COMM_ENABLED { - Channel::copy( - &self.channel, - &dst_server.channel, + if Server::SERVER_COMM_ENABLED { + let mut server_src = self.context.lock(); + let mut server_dst = dst_server.context.lock(); + + Server::copy( + server_src.deref_mut(), + server_dst.deref_mut(), src_descriptor, self.stream_id(), dst_server.stream_id(), @@ -379,22 +368,18 @@ where mode: ExecutionMode, stream_id: StreamId, ) { - let level = self.state.logger.profile_level(); + let level = self.utilities.logger.profile_level(); match level { None | Some(ProfileLevel::ExecutionOnly) => { - self.profile_guard(); - + let mut state = self.context.lock(); let name = kernel.name(); - unsafe { - self.channel - .execute(kernel, count, bindings, mode, stream_id) - }; + unsafe { state.execute(kernel, count, bindings, mode, stream_id) }; if matches!(level, Some(ProfileLevel::ExecutionOnly)) { let info = type_name_format(name, TypeNameFormatLevel::Balanced); - self.state.logger.register_execution(info); + self.utilities.logger.register_execution(info); } } Some(level) => { @@ -403,8 +388,8 @@ where let profile = self .profile( || unsafe { - self.channel - .execute(kernel, count.clone(), bindings, mode, stream_id) + let mut state = self.context.lock(); + state.execute(kernel, count.clone(), bindings, mode, stream_id) }, name, ) @@ -415,7 +400,7 @@ where } _ => type_name_format(name, TypeNameFormatLevel::Balanced), }; - self.state.logger.register_profiled(info, profile); + self.utilities.logger.register_profiled(info, profile); } } } @@ -463,38 +448,36 @@ where /// Flush all outstanding commands. pub fn flush(&self) { - self.profile_guard(); - let stream_id = self.stream_id(); - self.channel.flush(stream_id); + self.context.lock().flush(stream_id); } /// Wait for the completion of every task in the server. - pub async fn sync(&self) { - self.profile_guard(); - + pub fn sync(&self) -> DynFut<()> { let stream_id = self.stream_id(); - self.channel.sync(stream_id).await; - self.state.logger.profile_summary(); + let mut state = self.context.lock(); + let fut = state.sync(stream_id); + core::mem::drop(state); + self.utilities.logger.profile_summary(); + + fut } /// Get the features supported by the compute server. pub fn properties(&self) -> &DeviceProperties { - &self.state.properties + &self.utilities.properties } /// # Warning /// /// For private use only. pub fn properties_mut(&mut self) -> Option<&mut DeviceProperties> { - Arc::get_mut(&mut self.state).map(|state| &mut state.properties) + Arc::get_mut(&mut self.utilities).map(|state| &mut state.properties) } /// Get the current memory usage of this client. pub fn memory_usage(&self) -> MemoryUsage { - self.profile_guard(); - - self.channel.memory_usage(self.stream_id()) + self.context.lock().memory_usage(self.stream_id()) } /// Change the memory allocation mode. @@ -503,38 +486,33 @@ where /// /// This function isn't thread safe and might create memory leaks. 
pub unsafe fn allocation_mode(&self, mode: MemoryAllocationMode) { - self.profile_guard(); - - self.channel.allocation_mode(mode, self.stream_id()) + self.context.lock().allocation_mode(mode, self.stream_id()) } - /// Use a static memory strategy to execute the provided function. + /// Use a persistent memory strategy to execute the provided function. /// /// # Notes /// - /// Using that memory strategy is beneficial for weights loading and similar workflows. - /// However make sure to call [Self::memory_cleanup()] if you want to free the allocated - /// memory. - pub fn memory_static_allocation Output>( + /// - Using that memory strategy is beneficial for stating model parameters and similar workflows. + /// - You can call [Self::memory_cleanup()] if you want to free persistent memory. + pub fn memory_persistent_allocation Output>( &self, input: Input, func: Func, ) -> Output { - // We use the same profiling lock to make sure no other task is currently using the current - // device. Meaning that the current static memory strategy will only be used for the - // provided function. + let device_guard = self.context.lock_device(); - #[cfg(multi_threading)] - let stream_id = self.profile_acquire(); + self.context + .lock() + .allocation_mode(MemoryAllocationMode::Persistent, self.stream_id()); - self.channel - .allocation_mode(MemoryAllocationMode::Static, self.stream_id()); let output = func(input); - self.channel + + self.context + .lock() .allocation_mode(MemoryAllocationMode::Auto, self.stream_id()); - #[cfg(multi_threading)] - self.profile_release(stream_id); + core::mem::drop(device_guard); output } @@ -544,9 +522,7 @@ where /// Nb: Results will vary on what the memory allocator deems beneficial, /// so it's not guaranteed any memory is freed. pub fn memory_cleanup(&self) { - self.profile_guard(); - - self.channel.memory_cleanup(self.stream_id()) + self.context.lock().memory_cleanup(self.stream_id()) } /// Measure the execution time of some inner operations. 
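The rename from `memory_static_allocation` to `memory_persistent_allocation` above keeps the same bracket shape: switch the allocation mode to `Persistent`, run the closure, switch back to `Auto`, all while holding the device guard so no other stream allocates in the wrong mode. A standalone sketch of that bracket with a mock client (illustrative only; the real method lives on `ComputeClient` and also takes the device lock):

```rust
use std::cell::Cell;

#[derive(Clone, Copy, Debug, PartialEq)]
enum MemoryAllocationMode {
    Auto,
    Persistent,
}

struct MockClient {
    mode: Cell<MemoryAllocationMode>,
}

impl MockClient {
    fn allocation_mode(&self, mode: MemoryAllocationMode) {
        self.mode.set(mode);
    }

    // Everything allocated while `func` runs lands in the persistent pool;
    // the mode reverts to `Auto` afterwards, and `memory_cleanup()` can later
    // release what was pinned.
    fn memory_persistent_allocation<I, O, F: Fn(I) -> O>(&self, input: I, func: F) -> O {
        self.allocation_mode(MemoryAllocationMode::Persistent);
        let output = func(input);
        self.allocation_mode(MemoryAllocationMode::Auto);
        output
    }
}

fn main() {
    let client = MockClient {
        mode: Cell::new(MemoryAllocationMode::Auto),
    };
    let weights = client.memory_persistent_allocation(3usize, |n| {
        assert_eq!(client.mode.get(), MemoryAllocationMode::Persistent);
        vec![0u8; n] // stands in for uploading model parameters
    });
    assert_eq!(weights.len(), 3);
    assert_eq!(client.mode.get(), MemoryAllocationMode::Auto);
}
```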
@@ -571,8 +547,7 @@ where 0, ); - #[cfg(multi_threading)] - let stream_id = self.profile_acquire(); + let device_guard = self.context.lock_device(); #[cfg(feature = "profile-tracy")] let gpu_span = if self.state.properties.timing_method == TimingMethod::Device { @@ -586,12 +561,11 @@ where None }; - let token = self.channel.start_profile(self.stream_id()); + let token = self.context.lock().start_profile(self.stream_id()); let out = func(); - #[allow(unused_mut)] - let mut result = self.channel.end_profile(self.stream_id(), token); + let result = self.context.lock().end_profile(self.stream_id(), token); core::mem::drop(out); @@ -614,9 +588,7 @@ where ) }); } - - #[cfg(multi_threading)] - self.profile_release(stream_id); + core::mem::drop(device_guard); result } @@ -634,12 +606,13 @@ where // Allocate destination let alloc = dst_server - .channel + .context + .lock() .create(vec![alloc_descriptor], self.stream_id()) .unwrap() .remove(0); - let read = self.channel.read(vec![src_descriptor], stream_id); + let read = self.context.lock().read(vec![src_descriptor], stream_id); let data = cubecl_common::future::block_on(read).unwrap(); let desc_descriptor = CopyDescriptor { @@ -650,102 +623,11 @@ where }; dst_server - .channel + .context + .lock() .write(vec![(desc_descriptor, &data[0])], stream_id) .unwrap(); alloc } - - #[cfg(not(multi_threading))] - fn profile_guard(&self) {} - - #[cfg(multi_threading)] - fn profile_guard(&self) { - let current = self.state.current_profiling.read(); - - if let Some(current_stream_id) = current.as_ref() { - let stream_id = self.stream_id(); - - if current_stream_id == &stream_id { - return; - } - - core::mem::drop(current); - - loop { - std::thread::sleep(core::time::Duration::from_millis(10)); - - let current = self.state.current_profiling.read(); - match current.as_ref() { - Some(current_stream_id) => { - if current_stream_id == &stream_id { - return; - } - } - None => { - return; - } - } - } - } - } - - #[cfg(multi_threading)] - fn profile_acquire(&self) -> Option { - let stream_id = self.stream_id(); - let mut current = self.state.current_profiling.write(); - - match current.as_mut() { - Some(current_stream_id) => { - if current_stream_id == &stream_id { - return None; - } - - core::mem::drop(current); - - loop { - std::thread::sleep(core::time::Duration::from_millis(10)); - - let mut current = self.state.current_profiling.write(); - - match current.as_mut() { - Some(current_stream_id) => { - if current_stream_id == &stream_id { - return None; - } - } - None => { - *current = Some(stream_id); - return Some(stream_id); - } - } - } - } - None => { - *current = Some(stream_id); - Some(stream_id) - } - } - } - - #[cfg(multi_threading)] - fn profile_release(&self, stream_id: Option) { - let stream_id = match stream_id { - Some(val) => val, - None => return, // No releasing - }; - let mut current = self.state.current_profiling.write(); - - match current.as_mut() { - Some(current_stream_id) => { - if current_stream_id != &stream_id { - panic!("Can't release a different profiling guard."); - } else { - *current = None; - } - } - None => panic!("Can't release an empty profiling guard"), - } - } } diff --git a/crates/cubecl-runtime/src/config/base.rs b/crates/cubecl-runtime/src/config/base.rs index 848926c2c..e9dd3c68c 100644 --- a/crates/cubecl-runtime/src/config/base.rs +++ b/crates/cubecl-runtime/src/config/base.rs @@ -1,3 +1,4 @@ +use crate::config::memory::MemoryConfig; use crate::config::streaming::StreamingConfig; use super::{autotune::AutotuneConfig, 
compilation::CompilationConfig, profiling::ProfilingConfig};
@@ -26,6 +27,10 @@ pub struct GlobalConfig {
     /// Configuration for streaming settings.
     #[serde(default)]
     pub streaming: StreamingConfig,
+
+    /// Configuration for memory settings.
+    #[serde(default)]
+    pub memory: MemoryConfig,
 }
 
 impl GlobalConfig {
diff --git a/crates/cubecl-runtime/src/config/logger.rs b/crates/cubecl-runtime/src/config/logger.rs
index 5d8d20774..e382396d8 100644
--- a/crates/cubecl-runtime/src/config/logger.rs
+++ b/crates/cubecl-runtime/src/config/logger.rs
@@ -1,7 +1,7 @@
 use super::GlobalConfig;
 use crate::config::{
-    autotune::AutotuneLogLevel, compilation::CompilationLogLevel, profiling::ProfilingLogLevel,
-    streaming::StreamingLogLevel,
+    autotune::AutotuneLogLevel, compilation::CompilationLogLevel, memory::MemoryLogLevel,
+    profiling::ProfilingLogLevel, streaming::StreamingLogLevel,
 };
 use alloc::{string::ToString, sync::Arc, vec::Vec};
 use core::fmt::Display;
@@ -118,6 +118,9 @@ pub struct Logger {
     /// Indices of loggers used for streaming logging.
     streaming_index: Vec<usize>,
 
+    /// Indices of loggers used for memory logging.
+    memory_index: Vec<usize>,
+
     /// Global configuration for logging settings.
     pub config: Arc<GlobalConfig>,
 }
@@ -142,6 +145,7 @@ impl Logger {
         let mut profiling_index = Vec::new();
         let mut autotune_index = Vec::new();
         let mut streaming_index = Vec::new();
+        let mut memory_index = Vec::new();
 
         #[derive(Hash, PartialEq, Eq)]
         enum LoggerId {
@@ -281,12 +285,25 @@ impl Logger {
             )
         }
 
+        if let MemoryLogLevel::Disabled = config.memory.logger.level {
+        } else {
+            register_logger(
+                &config.memory.logger,
+                config.memory.logger.append,
+                config.memory.logger.log,
+                &mut memory_index,
+                &mut loggers,
+                &mut logger2index,
+            )
+        }
+
         Self {
             loggers,
             compilation_index,
             profiling_index,
             autotune_index,
             streaming_index,
+            memory_index,
             config,
         }
     }
@@ -305,6 +322,20 @@ impl Logger {
         }
     }
 
+    /// Logs a message for memory, directing it to all configured memory loggers.
+    pub fn log_memory<S: Display>(&mut self, msg: &S) {
+        let length = self.memory_index.len();
+        if length > 1 {
+            let msg = msg.to_string();
+            for i in 0..length {
+                let index = self.memory_index[i];
+                self.log(&msg, index)
+            }
+        } else if let Some(index) = self.memory_index.first() {
+            self.log(&msg, *index)
+        }
+    }
+
     /// Logs a message for compilation, directing it to all configured compilation loggers.
     pub fn log_compilation<S: Display>(&mut self, msg: &S) {
         let length = self.compilation_index.len();
diff --git a/crates/cubecl-runtime/src/config/memory.rs b/crates/cubecl-runtime/src/config/memory.rs
new file mode 100644
index 000000000..f1d4ed833
--- /dev/null
+++ b/crates/cubecl-runtime/src/config/memory.rs
@@ -0,0 +1,48 @@
+use super::logger::{LogLevel, LoggerConfig};
+
+/// Configuration for memory settings in CubeCL.
+#[derive(Clone, Debug, serde::Serialize, serde::Deserialize, Default)]
+pub struct MemoryConfig {
+    /// Logger configuration for memory-related logs, using specific log levels.
+    #[serde(default)]
+    pub logger: LoggerConfig<MemoryLogLevel>,
+    /// Configuration for persistent memory pools.
+    #[serde(default)]
+    pub persistent_memory: PersistentMemory,
+}
+
+/// Configuration options for persistent memory pools in CubeCL runtimes.
+#[derive(Clone, Debug, serde::Serialize, serde::Deserialize, Default)]
+pub enum PersistentMemory {
+    /// Persistent memory is enabled but used only when explicitly specified.
+    #[default]
+    #[serde(rename = "enabled")]
+    Enabled,
+    /// Persistent memory is disabled, allowing dynamic allocations.
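+    /// When disabled, requests for persistent allocations simply fall back to the dynamic pools.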
+ #[serde(rename = "disabled")] + Disabled, + /// Persistent memory is enforced, preventing dynamic allocations. + /// + /// # Warning + /// + /// Enforcing persistent memory may cause out-of-memory errors if tensors of varying sizes are used. + #[serde(rename = "enforced")] + Enforced, +} + +/// Log levels for memory-related events in CubeCL. +#[derive(Default, Clone, Copy, Debug, serde::Serialize, serde::Deserialize)] +pub enum MemoryLogLevel { + /// No memory-related logging. + #[default] + #[serde(rename = "disabled")] + Disabled, + /// Logs basic memory events, such as creating memory pages and manually cleaning memory. + #[serde(rename = "basic")] + Basic, + /// Logs detailed memory information. + #[serde(rename = "full")] + Full, +} + +impl LogLevel for MemoryLogLevel {} diff --git a/crates/cubecl-runtime/src/config/mod.rs b/crates/cubecl-runtime/src/config/mod.rs index d76346149..f16171816 100644 --- a/crates/cubecl-runtime/src/config/mod.rs +++ b/crates/cubecl-runtime/src/config/mod.rs @@ -5,6 +5,8 @@ pub mod autotune; pub mod cache; /// Compilation config module. pub mod compilation; +/// Memory config module. +pub mod memory; /// Profiling config module. pub mod profiling; /// Streaming config module. diff --git a/crates/cubecl-runtime/src/lib.rs b/crates/cubecl-runtime/src/lib.rs index 2553658ef..4c2d4dfe8 100644 --- a/crates/cubecl-runtime/src/lib.rs +++ b/crates/cubecl-runtime/src/lib.rs @@ -18,8 +18,6 @@ pub mod kernel; #[cfg(feature = "std")] pub mod stream; -/// Compute channel module. -pub mod channel; /// Compute client module. pub mod client; @@ -40,8 +38,6 @@ mod feature_set; /// Runtime features and associated types pub mod features; -mod base; -pub use base::*; pub use cubecl_common::benchmark; pub use feature_set::*; diff --git a/crates/cubecl-runtime/src/logging/server.rs b/crates/cubecl-runtime/src/logging/server.rs index 4d41d7e08..51f016d27 100644 --- a/crates/cubecl-runtime/src/logging/server.rs +++ b/crates/cubecl-runtime/src/logging/server.rs @@ -1,5 +1,6 @@ use core::fmt::Display; +use crate::config::memory::MemoryLogLevel; use crate::config::streaming::StreamingLogLevel; use crate::config::{Logger, compilation::CompilationLogLevel, profiling::ProfilingLogLevel}; use alloc::format; @@ -15,6 +16,7 @@ enum LogMessage { Execution(String), Compilation(String), Streaming(String), + Memory(String), Profile(String, ProfileDuration), ProfileSummary, } @@ -26,6 +28,7 @@ pub struct ServerLogger { log_compile_info: bool, log_streaming: StreamingLogLevel, log_channel: Option>, + log_memory: MemoryLogLevel, } impl Default for ServerLogger { @@ -38,10 +41,11 @@ impl Default for ServerLogger { ) && matches!( logger.config.profiling.logger.level, ProfilingLogLevel::Disabled - ) && matches!( - logger.config.streaming.logger.level, - StreamingLogLevel::Disabled - ); + ) && matches!(logger.config.memory.logger.level, MemoryLogLevel::Disabled) + && matches!( + logger.config.streaming.logger.level, + StreamingLogLevel::Disabled + ); if disabled { return Self { @@ -49,6 +53,7 @@ impl Default for ServerLogger { log_compile_info: false, log_streaming: StreamingLogLevel::Disabled, log_channel: None, + log_memory: MemoryLogLevel::Disabled, }; } let profile_level = match logger.config.profiling.logger.level { @@ -65,6 +70,7 @@ impl Default for ServerLogger { CompilationLogLevel::Full => true, }; let log_streaming = logger.config.streaming.logger.level; + let log_memory = logger.config.memory.logger.level; let (send, rec) = async_channel::unbounded(); @@ -81,6 +87,7 @@ impl Default for 
ServerLogger {
             profile_level,
             log_compile_info,
             log_streaming,
+            log_memory,
             log_channel: Some(send),
         }
     }
@@ -124,6 +131,20 @@ impl ServerLogger {
         }
     }
 
+    /// Log the argument to the logger when the memory logger is activated.
+    pub fn log_memory<I: FnOnce() -> String, C: FnOnce(MemoryLogLevel) -> bool>(
+        &self,
+        cond: C,
+        format: I,
+    ) {
+        if let Some(channel) = &self.log_channel
+            && cond(self.log_memory)
+        {
+            // Channel will never be full, don't care if it's closed.
+            let _ = channel.try_send(LogMessage::Memory(format()));
+        }
+    }
+
     /// Register a profiled task without timing.
     pub fn register_execution(&self, name: impl Display) {
         if let Some(channel) = &self.log_channel
@@ -171,6 +192,9 @@ impl AsyncLogger {
             LogMessage::Streaming(msg) => {
                 self.logger.log_streaming(&msg);
             }
+            LogMessage::Memory(msg) => {
+                self.logger.log_memory(&msg);
+            }
             LogMessage::Profile(name, profile) => {
                 let duration = profile.resolve().await.duration();
                 self.profiled.update(&name, duration);
diff --git a/crates/cubecl-runtime/src/memory_management/base.rs b/crates/cubecl-runtime/src/memory_management/base.rs
index a78a31cea..e54758672 100644
--- a/crates/cubecl-runtime/src/memory_management/base.rs
+++ b/crates/cubecl-runtime/src/memory_management/base.rs
@@ -1,5 +1,5 @@
 #[cfg(not(feature = "std"))]
-use alloc::{format, string::String};
+use alloc::string::{String, ToString};
 
 /// Amount of memory in use by this allocator
 /// and statistics on how much memory is reserved and
@@ -7,6 +7,9 @@ use alloc::{format, string::String};
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub struct MemoryUsage {
     /// The number of allocations currently active.
+    ///
+    /// This is not the number of times a new memory page is actually allocated,
+    /// but rather the number of active slices.
     pub number_allocs: u64,
     /// The number of bytes that are currently actually in use.
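+    ///
+    /// Unlike `bytes_reserved`, this only counts slices that are currently occupied.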
/// @@ -36,26 +39,37 @@ impl MemoryUsage { } } -fn bytes_format(bytes: u64) -> String { - let unit = 1000; +#[derive(new)] +pub(crate) struct BytesFormat { + bytes: u64, +} + +impl core::fmt::Display for BytesFormat { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + let unit = 1000; - if bytes < unit { - format!("{bytes} B") - } else { - let size = bytes as f64; - let exp = match size.log(1000.0).floor() as usize { - 0 => 1, - e => e, - }; - let unit_prefix = "KMGTPEZY".as_bytes(); - format!( - "{:.2} {}B", - (size / unit.pow(exp as u32) as f64), - unit_prefix[exp - 1] as char, - ) + if self.bytes < unit { + f.write_fmt(format_args!("{} B", self.bytes)) + } else { + let size = self.bytes as f64; + let exp = match size.log(1000.0).floor() as usize { + 0 => 1, + e => e, + }; + let unit_prefix = "KMGTPEZY".as_bytes(); + f.write_fmt(format_args!( + "{:.2} {}B", + (size / unit.pow(exp as u32) as f64), + unit_prefix[exp - 1] as char, + )) + } } } +fn bytes_format(bytes: u64) -> String { + BytesFormat::new(bytes).to_string() +} + impl core::fmt::Display for MemoryUsage { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { // In the future it'd be nice if MemoryUsage also held some stats about say, diff --git a/crates/cubecl-runtime/src/memory_management/memory_manage.rs b/crates/cubecl-runtime/src/memory_management/memory_manage.rs index 177ccdf64..d3639b52d 100644 --- a/crates/cubecl-runtime/src/memory_management/memory_manage.rs +++ b/crates/cubecl-runtime/src/memory_management/memory_manage.rs @@ -1,15 +1,24 @@ use super::{ MemoryConfiguration, MemoryDeviceProperties, MemoryPoolOptions, MemoryUsage, PoolType, - memory_pool::{ExclusiveMemoryPool, MemoryPool, SlicedPool, StaticPool}, + memory_pool::{ExclusiveMemoryPool, MemoryPool, PersistentPool, SlicedPool}, }; use crate::{ + config::{ + GlobalConfig, + memory::{MemoryLogLevel, PersistentMemory}, + }, + logging::ServerLogger, + memory_management::BytesFormat, server::IoError, storage::{ComputeStorage, StorageHandle}, }; +use alloc::format; +use alloc::string::{String, ToString}; #[cfg(not(exclusive_memory_only))] use alloc::vec; use alloc::vec::Vec; +use cubecl_common::stub::Arc; pub use super::memory_pool::{SliceBinding, handle::*}; @@ -22,6 +31,13 @@ enum DynamicPool { } impl MemoryPool for DynamicPool { + fn accept(&self, size: u64) -> bool { + match self { + DynamicPool::Sliced(pool) => pool.accept(size), + DynamicPool::Exclusive(pool) => pool.accept(size), + } + } + fn get(&self, binding: &SliceBinding) -> Option<&StorageHandle> { match self { DynamicPool::Sliced(m) => m.get(binding), @@ -54,13 +70,6 @@ impl MemoryPool for DynamicPool { } } - fn max_alloc_size(&self) -> u64 { - match self { - DynamicPool::Sliced(m) => m.max_alloc_size(), - DynamicPool::Exclusive(m) => m.max_alloc_size(), - } - } - fn cleanup( &mut self, storage: &mut Storage, @@ -74,24 +83,27 @@ impl MemoryPool for DynamicPool { } } -#[derive(Default, Clone, Copy)] +#[derive(Default, Clone, Copy, Debug)] /// The mode of allocation used. pub enum MemoryAllocationMode { /// Use the automatic memory management strategy for allocation. #[default] Auto, - /// Use a static memory management strategy, meaning that all allocations are for data that is - /// never going to be freed. - Static, + /// Use a persistent memory management strategy, meaning that all allocations are for data that is + /// likely never going to be freed. 
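+    /// Persistent allocations are only reclaimed by an explicit memory cleanup; the periodic
+    /// deallocation pass leaves them alone.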
+    Persistent,
 }
 
 /// Reserves and keeps track of chunks of memory in the storage, and slices upon these chunks.
 pub struct MemoryManagement<Storage> {
-    static_pool: StaticPool,
+    name: String,
+    persistent: PersistentPool,
     pools: Vec<DynamicPool>,
     storage: Storage,
     alloc_reserve_count: u64,
     mode: MemoryAllocationMode,
+    config: PersistentMemory,
+    logger: Arc<ServerLogger>,
 }
 
 fn generate_bucket_sizes(
@@ -122,12 +134,49 @@ fn generate_bucket_sizes(
 const DEALLOC_SCALE_MB: u64 = 1024 * 1024 * 1024;
 const BASE_DEALLOC_PERIOD: u64 = 5000;
 
+/// The options for creating a new [MemoryManagement] instance.
+#[derive(Debug)]
+pub struct MemoryManagementOptions {
+    /// The name of the memory management instance.
+    name: String,
+    /// The [MemoryAllocationOption] used by this instance.
+    memory: MemoryAllocationOption,
+}
+
+impl MemoryManagementOptions {
+    /// Creates a new [MemoryManagementOptions].
+    pub fn new<S: Into<String>>(name: S) -> Self {
+        Self {
+            name: name.into(),
+            memory: MemoryAllocationOption::FromConfig,
+        }
+    }
+
+    /// Forces the [MemoryAllocationMode] during execution to always be the provided one.
+    pub fn mode(mut self, mode: MemoryAllocationMode) -> Self {
+        self.memory = MemoryAllocationOption::Provided(mode);
+        self
+    }
+}
+
+#[derive(Default, Debug)]
+/// Determines which [MemoryAllocationMode] is used during allocations.
+enum MemoryAllocationOption {
+    #[default]
+    /// Uses the [GlobalConfig] to determine the mode of allocation.
+    FromConfig,
+    /// Use the provided [MemoryAllocationMode].
+    Provided(MemoryAllocationMode),
+}
+
 impl<Storage: ComputeStorage> MemoryManagement<Storage> {
     /// Creates the options from device limits.
     pub fn from_configuration(
         storage: Storage,
         properties: &MemoryDeviceProperties,
         config: MemoryConfiguration,
+        logger: Arc<ServerLogger>,
+        options: MemoryManagementOptions,
     ) -> Self {
         let pool_options = match config {
             #[cfg(not(exclusive_memory_only))]
@@ -222,9 +271,16 @@ impl<Storage: ComputeStorage> MemoryManagement<Storage> {
             MemoryConfiguration::Custom { pool_options } => pool_options,
         };
 
-        for pool in pool_options.iter() {
-            log::trace!("Using memory pool: \n {pool:?}");
-        }
+        logger.log_memory(
+            |level| !matches!(level, MemoryLogLevel::Disabled),
+            || {
+                let mut msg = String::new();
+                for pool in pool_options.iter() {
+                    msg += &format!("[{}] Using memory pool: \n {pool:?}", options.name);
+                }
+                msg
+            },
+        );
 
         let pools: Vec<_> = pool_options
             .iter()
@@ -247,23 +303,57 @@ impl<Storage: ComputeStorage> MemoryManagement<Storage> {
             })
             .collect();
 
+        let config = GlobalConfig::get().memory.persistent_memory.clone();
+
+        let mode = match options.memory {
+            MemoryAllocationOption::Provided(mode) => mode,
+            MemoryAllocationOption::FromConfig => match config {
+                PersistentMemory::Enabled => MemoryAllocationMode::Auto,
+                PersistentMemory::Disabled => MemoryAllocationMode::Auto,
+                PersistentMemory::Enforced => MemoryAllocationMode::Persistent,
+            },
+        };
+
         Self {
-            static_pool: StaticPool::new(properties.max_page_size),
+            name: options.name,
+            persistent: PersistentPool::new(properties.max_page_size, properties.alignment),
             pools,
             storage,
             alloc_reserve_count: 0,
-            mode: MemoryAllocationMode::Auto,
+            mode,
+            config,
+            logger,
         }
     }
 
     /// Change the mode of allocation.
     pub fn mode(&mut self, mode: MemoryAllocationMode) {
+        // We override the mode based on the cubecl config.
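+        // `Enabled` respects the caller's request, while `Disabled` and `Enforced`
+        // pin the mode chosen at creation and ignore the request entirely.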
+        let mode = match self.config {
+            PersistentMemory::Enabled => mode,
+            PersistentMemory::Disabled | PersistentMemory::Enforced => return,
+        };
+
+        self.logger.log_memory(
+            |level| !matches!(level, MemoryLogLevel::Disabled),
+            || {
+                format!(
+                    "[{}] Setting memory allocation mode: from {:?} => {mode:?}",
+                    self.name, self.mode
+                )
+            },
+        );
         self.mode = mode;
     }
 
     /// Cleanup allocations in pools that are deemed unnecessary.
     pub fn cleanup(&mut self, explicit: bool) {
-        self.static_pool
+        self.logger.log_memory(
+            |level| !matches!(level, MemoryLogLevel::Disabled) && explicit,
+            || "Manual memory cleanup ...".to_string(),
+        );
+
+        self.persistent
             .cleanup(&mut self.storage, self.alloc_reserve_count, explicit);
 
         for pool in self.pools.iter_mut() {
@@ -273,7 +363,7 @@ impl<Storage: ComputeStorage> MemoryManagement<Storage> {
 
     /// Returns the storage from the specified binding
     pub fn get(&mut self, binding: SliceBinding) -> Option<StorageHandle> {
-        if let Some(val) = self.static_pool.get(&binding) {
+        if let Some(val) = self.persistent.get(&binding) {
             return Some(val.clone());
         }
 
@@ -304,26 +394,73 @@ impl<Storage: ComputeStorage> MemoryManagement<Storage> {
 
     /// Finds a spot in memory for a resource with the given size in bytes, and returns a handle to it
    pub fn reserve(&mut self, size: u64) -> Result<SliceHandle, IoError> {
-        if let MemoryAllocationMode::Static = self.mode {
-            return self.static_pool.alloc(&mut self.storage, size);
-        }
-
         // If this happens every nanosecond, counts overflows after 585 years, so not worth thinking too
         // hard about overflow here.
         self.alloc_reserve_count += 1;
 
+        if let Some(val) = self.persistent.try_reserve(size) {
+            self.logger.log_memory(
+                |level| matches!(level, MemoryLogLevel::Full),
+                || {
+                    format!(
+                        "[{}] Reserved memory {size} using persistent memory",
+                        self.name
+                    )
+                },
+            );
+            return Ok(val);
+        }
+
+        if matches!(self.mode, MemoryAllocationMode::Persistent) || self.persistent.has_size(size) {
+            let allocated = self.persistent.alloc(&mut self.storage, size);
+
+            self.logger.log_memory(
+                |level| !matches!(level, MemoryLogLevel::Disabled),
+                || {
+                    format!(
+                        "[{}] Allocated a new memory page using persistent memory, \n{}",
+                        self.name, self,
+                    )
+                },
+            );
+            return allocated;
+        }
+
+        self.logger.log_memory(
+            |level| matches!(level, MemoryLogLevel::Full),
+            || {
+                format!(
+                    "[{}] Reserved memory {} using dynamic pool",
+                    self.name,
+                    BytesFormat::new(size)
+                )
+            },
+        );
+
         // Find first pool that fits this allocation
         let pool = self
             .pools
             .iter_mut()
-            .find(|p| p.max_alloc_size() >= size)
+            .find(|p| p.accept(size))
             .ok_or(IoError::BufferTooBig(size as usize))?;
 
         if let Some(slice) = pool.try_reserve(size) {
             return Ok(slice);
         }
 
-        pool.alloc(&mut self.storage, size)
+        let allocated = pool.alloc(&mut self.storage, size);
+
+        self.logger.log_memory(
+            |level| matches!(level, MemoryLogLevel::Full),
+            || {
+                format!(
+                    "[{}] Allocated a new memory page, current usage: \n{}",
+                    self.name, self
+                )
+            },
+        );
+
+        allocated
     }
 
     /// Fetch the storage used by the memory manager.
@@ -351,7 +488,7 @@ impl<Storage: ComputeStorage> MemoryManagement<Storage> {
             },
             |m1, m2| m1.combine(m2),
         );
-        memory_usage.combine(self.static_pool.get_memory_usage())
+        memory_usage.combine(self.persistent.get_memory_usage())
     }
 
     /// Print out a report of the current memory usage.
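// A sketch of how a runtime can wire the new options in; `storage` and
// `properties` are placeholders for the runtime's actual handles, and
// `MemoryConfiguration::default()` assumes the default pool configuration:
//
//     let memory = MemoryManagement::from_configuration(
//         storage,
//         &properties,
//         MemoryConfiguration::default(),
//         Arc::new(ServerLogger::default()),
//         MemoryManagementOptions::new("cuda:0").mode(MemoryAllocationMode::Persistent),
//     );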
@@ -360,6 +497,25 @@ impl<Storage: ComputeStorage> MemoryManagement<Storage> {
         log::info!("{}", self.memory_usage());
     }
 }
+impl<Storage: ComputeStorage> core::fmt::Display for MemoryManagement<Storage> {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        f.write_str("\n# MemoryManagement\n\n")?;
+        f.write_fmt(format_args!(" - name: {:?}\n", self.name))?;
+        f.write_fmt(format_args!("\n## Persistent\n\n{}", self.persistent))?;
+        f.write_str("\n## Dynamic\n\n")?;
+
+        for pool in self.pools.iter() {
+            match pool {
+                DynamicPool::Sliced(pool) => f.write_fmt(format_args!("{pool}\n"))?,
+                DynamicPool::Exclusive(pool) => f.write_fmt(format_args!("{pool}\n"))?,
+            }
+        }
+        let memory_usage = self.memory_usage();
+        f.write_fmt(format_args!("\n## Summary\n\n{memory_usage}"))?;
+
+        Ok(())
+    }
+}
 
 impl<Storage: ComputeStorage> core::fmt::Debug for MemoryManagement<Storage> {
     fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
@@ -383,6 +539,13 @@ mod tests {
         alignment: 32,
     };
 
+    fn options() -> MemoryManagementOptions {
+        MemoryManagementOptions {
+            name: "test".into(),
+            memory: MemoryAllocationOption::FromConfig,
+        }
+    }
+
     // Test pools with slices.
     #[test]
     #[cfg(not(exclusive_memory_only))]
@@ -391,6 +554,8 @@
             BytesStorage::default(),
             &DUMMY_MEM_PROPS,
             MemoryConfiguration::SubSlices,
+            Arc::new(ServerLogger::default()),
+            options(),
         );
         let handle = memory_management.reserve(10).unwrap();
         let other_ref = handle.clone();
@@ -416,6 +581,8 @@
                     dealloc_period: None,
                 }],
             },
+            Arc::new(ServerLogger::default()),
+            options(),
         );
         let handle = memory_management.reserve(100);
         let usage = memory_management.memory_usage();
@@ -446,6 +613,8 @@
                     dealloc_period: None,
                 }],
             },
+            Arc::new(ServerLogger::default()),
+            options(),
         );
 
         let alloc_size = 512;
@@ -475,6 +644,8 @@
                     dealloc_period: None,
                 }],
             },
+            Arc::new(ServerLogger::default()),
+            options(),
         );
 
         let alloc_size = 512;
@@ -504,6 +675,8 @@
                     dealloc_period: None,
                 }],
             },
+            Arc::new(ServerLogger::default()),
+            options(),
         );
 
         let alloc_size = 768;
@@ -534,6 +707,8 @@
                     dealloc_period: None,
                 }],
             },
+            Arc::new(ServerLogger::default()),
+            options(),
         );
         let alloc_size = 40;
         let _handle = memory_management.reserve(alloc_size);
@@ -566,6 +741,8 @@
             MemoryConfiguration::Custom {
                 pool_options: pools,
             },
+            Arc::new(ServerLogger::default()),
+            options(),
         );
 
         // Allocate one thing on each page.
        let alloc_sizes = [50, 150, 250, 350];
@@ -588,6 +765,8 @@
                 alignment: 32,
             },
             MemoryConfiguration::SubSlices,
+            Arc::new(ServerLogger::default()),
+            options(),
         );
         // Allocate a bunch
         let handles: Vec<_> = (0..5)
@@ -617,6 +796,8 @@
                 alignment: 32,
             },
             MemoryConfiguration::SubSlices,
+            Arc::new(ServerLogger::default()),
+            options(),
         );
         // Allocate a mix of small and large chunks
         let sizes = [50, 1000, 100, 5000, 200, 10000, 300];
@@ -648,6 +829,8 @@
                 alignment: 32,
             }),
             MemoryConfiguration::ExclusivePages,
+            Arc::new(ServerLogger::default()),
+            options(),
         );
         let handle = memory_management.reserve(10).unwrap();
         let other_ref = handle.clone();
@@ -669,6 +852,8 @@
                     dealloc_period: None,
                 }],
             },
+            Arc::new(ServerLogger::default()),
+            options(),
         );
 
         let alloc_size = 512;
@@ -695,6 +880,8 @@
                     dealloc_period: None,
                 }],
             },
+            Arc::new(ServerLogger::default()),
+            options(),
         );
 
         let alloc_size = 512;
@@ -721,6 +908,8 @@
                     dealloc_period: None,
                 }],
             },
+            Arc::new(ServerLogger::default()),
+            options(),
         );
 
         let alloc_size = 768;
@@ -748,6 +937,8 @@
                     dealloc_period: None,
                 }],
             },
+            Arc::new(ServerLogger::default()),
+            options(),
         );
         let alloc_size = 40;
         let _handle = memory_management.reserve(alloc_size);
@@ -778,6 +969,8 @@
             MemoryConfiguration::Custom {
                 pool_options: pools,
             },
+            Arc::new(ServerLogger::default()),
+            options(),
         );
         // Allocate one thing on each page.
         let alloc_sizes = [50, 150, 250, 350];
@@ -796,6 +989,8 @@
                 alignment: 32,
             },
             MemoryConfiguration::ExclusivePages,
+            Arc::new(ServerLogger::default()),
+            options(),
         );
         // Allocate a bunch
         let handles: Vec<_> = (0..5)
diff --git a/crates/cubecl-runtime/src/memory_management/memory_pool/base.rs b/crates/cubecl-runtime/src/memory_management/memory_pool/base.rs
index 7e9d8bf4c..363dd30ce 100644
--- a/crates/cubecl-runtime/src/memory_management/memory_pool/base.rs
+++ b/crates/cubecl-runtime/src/memory_management/memory_pool/base.rs
@@ -5,7 +5,56 @@ use crate::{
     storage::{ComputeStorage, StorageHandle},
 };
 
+/// Declares how memory is allocated in a reusable pool.
+pub trait MemoryPool {
+    /// Whether the memory pool accepts the given size.
+    fn accept(&self, size: u64) -> bool;
+
+    /// Retrieves the [storage handle](StorageHandle) using the [slice binding](SliceBinding).
+    fn get(&self, binding: &SliceBinding) -> Option<&StorageHandle>;
+
+    /// Try to reserve a memory slice of the given size.
+    ///
+    /// # Notes
+    ///
+    /// It is not guaranteed that `try_reserve` re-applies the accept check, so it is a good
+    /// idea to call [MemoryPool::accept()] before using `try_reserve`.
+    ///
+    /// # Returns
+    ///
+    /// A [slice handle](SliceHandle) if the current memory pool has enough memory, otherwise it
+    /// returns [None]. You can then call [MemoryPool::alloc()] to increase the amount of
+    /// memory the pool has.
+    fn try_reserve(&mut self, size: u64) -> Option<SliceHandle>;
+
+    /// Increases the amount of memory the pool has and returns a [slice handle](SliceHandle)
+    /// corresponding to the requested size.
+    ///
+    /// # Notes
+    ///
+    /// The function uses a [ComputeStorage] to perform the allocation. It might return an error
+    /// if the allocation fails or if the requested size is bigger than the memory pool is
+    /// configured to handle.
+    fn alloc<Storage: ComputeStorage>(
+        &mut self,
+        storage: &mut Storage,
+        size: u64,
+    ) -> Result<SliceHandle, IoError>;
+
+    /// Computes the [MemoryUsage] for this pool.
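+    /// (Free slices typically count toward `bytes_reserved` but not `bytes_in_use`.)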
+    fn get_memory_usage(&self) -> MemoryUsage;
+
+    /// Cleanup the memory pool, maybe freeing some memory using the [ComputeStorage].
+    fn cleanup<Storage: ComputeStorage>(
+        &mut self,
+        storage: &mut Storage,
+        alloc_nr: u64,
+        explicit: bool,
+    );
+}
+
 #[derive(new, Debug)]
+/// Slice of data with its associated storage.
 pub(crate) struct Slice {
     pub storage: StorageHandle,
     pub handle: SliceHandle,
@@ -13,47 +62,28 @@ pub(crate) struct Slice {
 }
 
 impl Slice {
+    /// If the slice is free to be reused.
     pub(crate) fn is_free(&self) -> bool {
         self.handle.is_free()
     }
 
+    /// The total size of the slice including padding.
     pub(crate) fn effective_size(&self) -> u64 {
         self.storage.size() + self.padding
     }
 
+    /// The id of the slice.
     pub(crate) fn id(&self) -> SliceId {
         *self.handle.id()
     }
 }
 
-pub(crate) fn calculate_padding(size: u64, buffer_alignment: u64) -> u64 {
-    let remainder = size % buffer_alignment;
+/// Calculates the padding required to store the given size in a buffer given the memory alignment.
+pub(crate) fn calculate_padding(size: u64, memory_alignment: u64) -> u64 {
+    let remainder = size % memory_alignment;
     if remainder != 0 {
-        buffer_alignment - remainder
+        memory_alignment - remainder
     } else {
         0
     }
 }
-
-pub trait MemoryPool {
-    fn max_alloc_size(&self) -> u64;
-
-    fn get(&self, binding: &SliceBinding) -> Option<&StorageHandle>;
-
-    fn try_reserve(&mut self, size: u64) -> Option<SliceHandle>;
-
-    fn alloc<Storage: ComputeStorage>(
-        &mut self,
-        storage: &mut Storage,
-        size: u64,
-    ) -> Result<SliceHandle, IoError>;
-
-    fn get_memory_usage(&self) -> MemoryUsage;
-
-    fn cleanup<Storage: ComputeStorage>(
-        &mut self,
-        storage: &mut Storage,
-        alloc_nr: u64,
-        explicit: bool,
-    );
-}
diff --git a/crates/cubecl-runtime/src/memory_management/memory_pool/exclusive_pool.rs b/crates/cubecl-runtime/src/memory_management/memory_pool/exclusive_pool.rs
index 583c704cd..e62dc26ad 100644
--- a/crates/cubecl-runtime/src/memory_management/memory_pool/exclusive_pool.rs
+++ b/crates/cubecl-runtime/src/memory_management/memory_pool/exclusive_pool.rs
@@ -1,5 +1,5 @@
 use crate::{
-    memory_management::MemoryUsage,
+    memory_management::{BytesFormat, MemoryUsage},
     server::IoError,
     storage::{ComputeStorage, StorageHandle, StorageUtilization},
 };
@@ -22,6 +22,28 @@ pub struct ExclusiveMemoryPool {
     cur_avg_size: f64,
 }
 
+impl core::fmt::Display for ExclusiveMemoryPool {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        f.write_fmt(format_args!(
+            " - Exclusive Pool max_alloc_size={}\n",
+            BytesFormat::new(self.max_alloc_size)
+        ))?;
+
+        for page in self.pages.iter() {
+            let is_free = page.slice.is_free();
+            let size = BytesFormat::new(page.slice.effective_size());
+
+            f.write_fmt(format_args!("   - Page {size} is_free={is_free}\n"))?;
+        }
+
+        if !self.pages.is_empty() {
+            f.write_fmt(format_args!("\n{}\n", self.get_memory_usage()))?;
+        }
+
+        Ok(())
+    }
+}
+
 const SIZE_AVG_DECAY: f64 = 0.01;
 
 // How many times to find the allocation 'free' before deallocating it.
@@ -92,6 +114,9 @@ impl ExclusiveMemoryPool {
 }
 
 impl MemoryPool for ExclusiveMemoryPool {
+    fn accept(&self, size: u64) -> bool {
+        self.max_alloc_size >= size
+    }
 
     /// Returns the resource from the storage, for the specified handle.
     fn get(&self, binding: &SliceBinding) -> Option<&StorageHandle> {
         let binding_id = *binding.id();
@@ -152,10 +177,6 @@ impl MemoryPool for ExclusiveMemoryPool {
         }
     }
 
-    fn max_alloc_size(&self) -> u64 {
-        self.max_alloc_size
-    }
-
     fn cleanup<Storage: ComputeStorage>(
         &mut self,
         storage: &mut Storage,
diff --git a/crates/cubecl-runtime/src/memory_management/memory_pool/index.rs b/crates/cubecl-runtime/src/memory_management/memory_pool/index.rs
deleted file mode 100644
index d922a1163..000000000
--- a/crates/cubecl-runtime/src/memory_management/memory_pool/index.rs
+++ /dev/null
@@ -1,65 +0,0 @@
-use alloc::collections::BTreeMap;
-use alloc::vec;
-use alloc::vec::Vec;
-use core::hash::Hash;
-use hashbrown::HashMap;
-
-/// Data Structure that helps to search items by size efficiently.
-pub struct SearchIndex<T> {
-    items_per_size: BTreeMap<u64, Vec<T>>,
-    sizes_per_item: HashMap<T, u64>,
-}
-
-impl<T: PartialEq + Eq + Hash + Clone> SearchIndex<T> {
-    /// Create a new item search index.
-    pub fn new() -> Self {
-        Self {
-            items_per_size: BTreeMap::new(),
-            sizes_per_item: HashMap::new(),
-        }
-    }
-
-    /// Insert a new sized item into the search index.
-    pub fn insert(&mut self, item: T, size: u64) {
-        self.remove(&item);
-
-        if let Some(values) = self.items_per_size.get_mut(&size) {
-            values.push(item.clone())
-        } else {
-            self.items_per_size.insert(size, vec![item.clone()]);
-        }
-        self.sizes_per_item.insert(item, size);
-    }
-
-    /// Find the item by size range.
-    #[allow(unused)]
-    pub fn find_by_size(
-        &self,
-        range: core::ops::Range<u64>,
-    ) -> impl DoubleEndedIterator<Item = &T> {
-        self.items_per_size.range(range).flat_map(|a| a.1)
-    }
-
-    /// Remove an item from the index.
-    pub fn remove(&mut self, item: &T) {
-        let size = match self.sizes_per_item.remove(item) {
-            Some(size) => size,
-            None => return,
-        };
-
-        if let Some(values) = self.items_per_size.get_mut(&size) {
-            let mut removed_index = None;
-
-            for (i, v) in values.iter().enumerate() {
-                if v == item {
-                    removed_index = Some(i);
-                    break;
-                }
-            }
-
-            if let Some(index) = removed_index {
-                values.remove(index);
-            }
-        }
-    }
-}
diff --git a/crates/cubecl-runtime/src/memory_management/memory_pool/memory_page.rs b/crates/cubecl-runtime/src/memory_management/memory_pool/memory_page.rs
new file mode 100644
index 000000000..1c282a73f
--- /dev/null
+++ b/crates/cubecl-runtime/src/memory_management/memory_pool/memory_page.rs
@@ -0,0 +1,684 @@
+use crate::{
+    memory_management::{
+        BytesFormat, MemoryUsage, SliceHandle, SliceId,
+        memory_pool::{Slice, calculate_padding},
+    },
+    storage::{StorageHandle, StorageUtilization},
+};
+use alloc::format;
+use alloc::string::String;
+use alloc::vec::Vec;
+use core::fmt::{Debug, Display};
+use hashbrown::HashMap;
+
+/// A memory page is responsible for reserving [slices](Slice) of data based on a fixed [storage buffer](StorageHandle).
+pub struct MemoryPage {
+    storage: StorageHandle,
+    slices: Vec<Slice>,
+    slices_map: HashMap<SliceId, usize>,
+    /// This is a vector used temporarily to store the updated slices.
+    ///
+    /// It avoids allocating a new vector all the time.
+    slices_tmp: Vec<Slice>,
+    /// Memory alignment.
+    alignment: u64,
+}
+
+impl MemoryPage {
+    /// Creates a new memory page with the given storage and memory alignment.
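+    ///
+    /// The new page starts out as a single free slice covering the entire storage buffer.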
+    pub fn new(storage: StorageHandle, alignment: u64) -> Self {
+        let mut this = MemoryPage {
+            storage: storage.clone(),
+            slices: Vec::new(),
+            slices_map: HashMap::new(),
+            slices_tmp: Vec::new(),
+            alignment,
+        };
+
+        let page = Slice {
+            handle: SliceHandle::new(),
+            storage,
+            padding: 0,
+        };
+        let id = *page.handle.id();
+        let index = 0;
+        this.slices.push(page);
+        this.slices_map.insert(id, index);
+
+        this
+    }
+
+    /// Gets the [memory usage](MemoryUsage) of the current memory page.
+    pub fn memory_usage(&self) -> MemoryUsage {
+        let mut usage = MemoryUsage {
+            number_allocs: 0,
+            bytes_in_use: 0,
+            bytes_padding: 0,
+            bytes_reserved: 0,
+        };
+
+        for slice in self.slices.iter() {
+            usage.bytes_reserved += slice.effective_size();
+
+            if !slice.handle.is_free() {
+                usage.number_allocs += 1;
+                usage.bytes_in_use += slice.storage.size();
+                usage.bytes_padding += slice.padding;
+            }
+        }
+
+        usage
+    }
+
+    /// Gets the [summary](MemoryPageSummary) of the current memory page.
+    ///
+    /// # Arguments
+    ///
+    /// - `memory_blocks`: whether the memory block details are included in the summary.
+    pub fn summary(&self, memory_blocks: bool) -> MemoryPageSummary {
+        let mut summary = MemoryPageSummary::default();
+
+        for slice in self.slices.iter() {
+            let is_free = slice.handle.is_free();
+            if is_free {
+                summary.amount_free += slice.effective_size();
+                summary.num_free += 1;
+            } else {
+                summary.amount_full += slice.effective_size();
+                summary.num_full += 1;
+            }
+            if memory_blocks {
+                summary.blocks.push(MemoryBlock {
+                    is_free,
+                    size: slice.effective_size(),
+                });
+            }
+        }
+        summary.amount_total = self.storage.size();
+        summary.num_total = self.slices.len();
+
+        summary
+    }
+
+    /// Reserves a slice of the given size if there is enough space in the page.
+    ///
+    /// # Notes
+    ///
+    /// If the current memory page is fragmented, meaning multiple adjacent free slices exist,
+    /// you can call the [Self::coalesce()] function to merge those.
+    pub fn try_reserve(&mut self, size: u64) -> Option<SliceHandle> {
+        let padding = calculate_padding(size, self.alignment);
+        let effective_size = size + padding;
+
+        for (index, slice) in self.slices.iter_mut().enumerate() {
+            let can_use_slice =
+                slice.storage.utilization.size >= effective_size && slice.handle.is_free();
+            if !can_use_slice {
+                continue;
+            }
+
+            let can_be_split = slice.storage.utilization.size > effective_size;
+            let handle = slice.handle.clone();
+
+            let storage_old = slice.storage.clone();
+
+            // Updates the current storage utilization.
+            slice.storage.utilization.size = size;
+            slice.padding = padding;
+
+            if can_be_split {
+                let new_slice = Slice {
+                    handle: SliceHandle::new(),
+                    storage: storage_old.offset_start(effective_size),
+                    padding: 0,
+                };
+                self.add_new_slice(index, size, new_slice);
+            }
+
+            return Some(handle);
+        }
+
+        None
+    }
+
+    /// Gets the [storage handle](StorageHandle) with the correct offset and size using the slice
+    /// binding.
+    ///
+    /// If the handle isn't returned, it means the binding isn't present in the given page.
+    pub fn get(&self, binding: &super::SliceBinding) -> Option<&StorageHandle> {
+        let index = self.slices_map.get(binding.id())?;
+        self.slices.get(*index).map(|slice| &slice.storage)
+    }
+
+    /// Recompute the memory page metadata to make sure adjacent slices are merged together into a
+    /// single slice.
+    ///
+    /// This is necessary to allow bigger slices to be reserved on the current page.
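+    ///
+    /// # Example
+    ///
+    /// ```ignore
+    /// // Hypothetical page holding two adjacent free 16 MB slices:
+    /// assert!(page.try_reserve(32 * MB).is_none());
+    /// page.coalesce();
+    /// assert!(page.try_reserve(32 * MB).is_some());
+    /// ```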
+    pub fn coalesce(&mut self) {
+        let mut job = self.memory_job();
+        let mut tasks = job.tasks.drain(..);
+
+        let mut task = match tasks.next() {
+            Some(task) => Some(task),
+            None => return,
+        };
+
+        self.slices_map.clear();
+
+        let mut offset = 0;
+        let mut size = 0;
+        let mut index_current = 0;
+
+        for (index, slice) in self.slices.drain(..).enumerate() {
+            let status = match &mut task {
+                Some(task) => task.on_coalesce(index),
+                None => MemoryTaskStatus::Ignoring,
+            };
+
+            match status {
+                MemoryTaskStatus::StartMerging => {
+                    offset = slice.storage.utilization.offset;
+                    size = slice.effective_size();
+                }
+                MemoryTaskStatus::Merging => {
+                    size += slice.effective_size();
+                }
+                MemoryTaskStatus::Ignoring => {
+                    let id = *slice.handle.id();
+                    self.slices_tmp.push(slice);
+                    self.slices_map.insert(id, index_current);
+                    index_current += 1;
+                }
+                MemoryTaskStatus::Completed => {
+                    size += slice.effective_size();
+
+                    let mut storage = self.storage.clone();
+                    storage.utilization = StorageUtilization { offset, size };
+                    let page = Slice {
+                        handle: SliceHandle::new(),
+                        storage,
+                        padding: 0,
+                    };
+                    let id = *page.handle.id();
+                    self.slices_tmp.push(page);
+                    self.slices_map.insert(id, index_current);
+                    index_current += 1;
+                    task = tasks.next();
+                }
+            };
+        }
+
+        core::mem::swap(&mut self.slices, &mut self.slices_tmp);
+    }
+
+    fn add_new_slice(
+        &mut self,
+        index_previous: usize,
+        reserved_size_previous: u64,
+        new_slice: Slice,
+    ) {
+        self.slices_map.clear();
+
+        let new_id = *new_slice.handle.id();
+        let mut new_slice = Some(new_slice);
+
+        let mut index_current = 0;
+        for mut slice in self.slices.drain(..) {
+            if index_current == index_previous {
+                slice.storage.utilization.size = reserved_size_previous;
+                let id = *slice.handle.id();
+                self.slices_tmp.push(slice);
+                self.slices_map.insert(id, index_current);
+                index_current += 1;
+
+                // New slice
+                self.slices_tmp.push(new_slice.take().unwrap());
+                self.slices_map.insert(new_id, index_current);
+                index_current += 1;
+            } else {
+                let id = *slice.handle.id();
+                self.slices_tmp.push(slice);
+                self.slices_map.insert(id, index_current);
+                index_current += 1;
+            }
+        }
+
+        core::mem::swap(&mut self.slices, &mut self.slices_tmp);
+    }
+
+    fn memory_job(&self) -> MemoryJob {
+        let mut job = MemoryJob::default();
+        let mut task = MemoryTask::default();
+
+        for (index, slice) in self.slices.iter().enumerate() {
+            if slice.handle.is_free() {
+                task.size += slice.effective_size();
+                task.tag_coalesce(index);
+            } else {
+                task = job.add(task);
+            }
+        }
+        job.add(task);
+
+        job
+    }
+}
+
+#[derive(Debug, PartialEq, Eq)]
+struct MemoryBlock {
+    is_free: bool,
+    size: u64,
+}
+
+#[derive(Default, PartialEq, Eq)]
+pub struct MemoryPageSummary {
+    blocks: Vec<MemoryBlock>,
+    pub amount_free: u64,
+    pub amount_full: u64,
+    pub amount_total: u64,
+    pub num_free: usize,
+    pub num_full: usize,
+    pub num_total: usize,
+}
+
+impl Display for MemoryBlock {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        match self.is_free {
+            true => f.write_fmt(format_args!("Free ({})", BytesFormat::new(self.size))),
+            false => f.write_fmt(format_args!("Reserved ({})", BytesFormat::new(self.size))),
+        }
+    }
+}
+impl Display for MemoryPageSummary {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        f.write_fmt(format_args!("{self:?}"))
+    }
+}
+
+impl Debug for MemoryPageSummary {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        f.write_str("\n==== Memory Page Summary ====\n")?;
+        f.write_str("[Info]\n")?;
+
+        for (tag, num, amount) in [
+            ("Free ", self.num_free, self.amount_free),
+            ("Full ", self.num_full, self.amount_full),
+            ("Total", self.num_total, self.amount_total),
+        ] {
+            f.write_fmt(format_args!(
+                " - {tag}: {} slices ({})\n",
+                num,
+                BytesFormat::new(amount),
+            ))?;
+        }
+
+        f.write_str("\n[Blocks]\n")?;
+        let mut blocks = String::new();
+        for (i, b) in self.blocks.iter().enumerate() {
+            if i == 0 {
+                blocks += "|";
+            }
+            blocks += format!(" {b} |").as_str();
+        }
+        let size = blocks.len();
+        for _ in 0..size {
+            f.write_str("-")?;
+        }
+        f.write_str("\n")?;
+        f.write_str(&blocks)?;
+        f.write_str("\n")?;
+        for _ in 0..size {
+            f.write_str("-")?;
+        }
+
+        f.write_str("\n=============================")?;
+        f.write_str("\n")
+    }
+}
+
+impl Display for MemoryPage {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        f.write_fmt(format_args!("{}", self.summary(true)))
+    }
+}
+
+#[derive(Default, Debug, PartialEq, Eq)]
+struct MemoryJob {
+    tasks: Vec<MemoryTask>,
+}
+
+#[derive(Default, Debug, PartialEq, Eq)]
+/// The goal of the memory task is to gather contiguous slice indices that can be merged into a single slice.
+struct MemoryTask {
+    /// The first slice index to be merged.
+    start_index: usize,
+    /// The number of slices to be merged.
+    count: usize,
+    /// Which slice index is being merged right now.
+    cursor: usize,
+    /// The total size in bytes in the resulting merged slice.
+    size: u64,
+}
+
+impl MemoryTask {
+    /// Tells the task that the given slice index will be coalesced.
+    fn tag_coalesce(&mut self, index: usize) {
+        if self.count == 0 {
+            self.start_index = index;
+        }
+
+        debug_assert!(
+            self.start_index + self.count == index,
+            "Only contiguous indices can be coalesced in a single task"
+        );
+
+        self.count += 1;
+    }
+
+    /// Tells the task that the given slice index is being coalesced.
+    fn on_coalesce(&mut self, index: usize) -> MemoryTaskStatus {
+        let index_current = self.start_index + self.cursor;
+
+        if index_current == index {
+            self.cursor += 1;
+            if self.cursor == 1 {
+                return MemoryTaskStatus::StartMerging;
+            }
+
+            if self.cursor == self.count {
+                return MemoryTaskStatus::Completed;
+            } else {
+                return MemoryTaskStatus::Merging;
+            }
+        }
+
+        MemoryTaskStatus::Ignoring
+    }
+}
+
+impl MemoryJob {
+    fn add(&mut self, mut task: MemoryTask) -> MemoryTask {
+        // A single index can't be merged with anything.
+ if task.count < 2 { + return MemoryTask::default(); + } + + let mut returned = MemoryTask::default(); + core::mem::swap(&mut task, &mut returned); + self.tasks.push(returned); + task + } +} + +#[derive(Debug)] +enum MemoryTaskStatus { + Merging, + StartMerging, + Ignoring, + Completed, +} + +#[cfg(test)] +#[allow(clippy::bool_assert_comparison, clippy::identity_op)] +mod tests { + use crate::storage::{StorageId, StorageUtilization}; + + use super::*; + + const MB: u64 = 1024 * 1024; + + #[test] + fn test_memory_page() { + let mut page = new_memory_page(32 * MB); + let slice = page + .try_reserve(16 * MB) + .expect("Enough space to allocate a new slice"); + + assert_eq!(slice.is_free(), false); + assert_eq!(slice.can_mut(), true); + + let storage = page + .get(&slice.binding()) + .expect("To find the correct storage"); + + assert_eq!( + storage.utilization, + StorageUtilization { + offset: 0, + size: 16 * MB + }, + "Utilization to be correct" + ); + + let summary = page.summary(true); + + assert_eq!( + summary, + MemoryPageSummary { + blocks: vec![ + MemoryBlock { + is_free: true, + size: 16 * MB + }, + MemoryBlock { + is_free: true, + size: 16 * MB + } + ], + amount_free: 32 * MB, + amount_full: 0, + amount_total: 32 * MB, + num_free: 2, + num_full: 0, + num_total: 2 + }, + "Summary is correct before coalesce", + ); + page.coalesce(); + let summary = page.summary(true); + + assert_eq!( + summary, + MemoryPageSummary { + blocks: vec![MemoryBlock { + is_free: true, + size: 32 * MB + },], + amount_free: 32 * MB, + amount_full: 0, + amount_total: 32 * MB, + num_free: 1, + num_full: 0, + num_total: 1 + }, + "Summary is correct after coalesce", + ); + } + + #[test] + fn test_memory_job() { + let mut page = new_memory_page(32 * MB); + let slice = page + .try_reserve(16 * MB) + .expect("Enough space to allocate a new slice"); + + core::mem::drop(slice); + let job = page.memory_job(); + + assert_eq!( + job, + MemoryJob { + tasks: vec![MemoryTask { + start_index: 0, + count: 2, + cursor: 0, + size: 32 * MB, + }] + } + ); + } + + #[test] + fn test_scenario() { + let mut page = new_memory_page(32 * MB); + + let slice_1 = page + .try_reserve(4 * MB) + .expect("Enough space to allocate a new slice"); + let slice_2 = page + .try_reserve(15 * MB) + .expect("Enough space to allocate a new slice"); + let slice_3 = page + .try_reserve(8 * MB) + .expect("Enough space to allocate a new slice"); + let slice_4 = page + .try_reserve(4 * MB) + .expect("Enough space to allocate a new slice"); + + assert_eq!( + page.summary(true), + MemoryPageSummary { + blocks: vec![ + MemoryBlock { + is_free: false, + size: 4 * MB + }, + MemoryBlock { + is_free: false, + size: 15 * MB + }, + MemoryBlock { + is_free: false, + size: 8 * MB + }, + MemoryBlock { + is_free: false, + size: 4 * MB + }, + MemoryBlock { + is_free: true, + size: 1 * MB + } + ], + amount_free: 1 * MB, + amount_full: 31 * MB, + amount_total: 32 * MB, + num_free: 1, + num_full: 4, + num_total: 5 + }, + ); + + let slice_5 = page.try_reserve(8 * MB); + assert!(slice_5.is_none(), "No more place"); + + core::mem::drop(slice_2); + let slice_5 = page.try_reserve(9 * MB); + assert!(slice_5.is_some(), "Now we have more place"); + + let slice_6 = page.try_reserve(9 * MB); + assert!(slice_6.is_none(), "No more place"); + + core::mem::drop(slice_3); + let slice_6 = page.try_reserve(9 * MB); + assert!(slice_6.is_none(), "No more place"); + + page.coalesce(); + + assert_eq!( + page.summary(true), + MemoryPageSummary { + blocks: vec![ + MemoryBlock { + is_free: false, + 
size: 4 * MB
+                    },
+                    MemoryBlock {
+                        is_free: false,
+                        size: 9 * MB
+                    },
+                    MemoryBlock {
+                        is_free: true,
+                        size: 14 * MB
+                    },
+                    MemoryBlock {
+                        is_free: false,
+                        size: 4 * MB
+                    },
+                    MemoryBlock {
+                        is_free: true,
+                        size: 1 * MB
+                    }
+                ],
+                amount_free: 15 * MB,
+                amount_full: 17 * MB,
+                amount_total: 32 * MB,
+                num_free: 2,
+                num_full: 3,
+                num_total: 5
+            },
+        );
+
+        assert_eq!(
+            page.get(&slice_4.clone().binding()).unwrap().utilization,
+            StorageUtilization {
+                offset: 27 * MB,
+                size: 4 * MB
+            },
+            "Utilization to be correct"
+        );
+
+        let slice_6 = page.try_reserve(9 * MB);
+        assert!(slice_6.is_some(), "Now we have more place");
+        core::mem::drop(slice_1);
+        core::mem::drop(slice_4);
+
+        page.coalesce();
+
+        assert_eq!(
+            page.get(&slice_6.clone().unwrap().binding())
+                .unwrap()
+                .utilization,
+            StorageUtilization {
+                offset: 13 * MB,
+                size: 9 * MB
+            },
+            "Utilization to be correct"
+        );
+
+        assert_eq!(
+            page.summary(true),
+            MemoryPageSummary {
+                blocks: vec![
+                    MemoryBlock {
+                        is_free: true,
+                        size: 4 * MB
+                    },
+                    MemoryBlock {
+                        is_free: false,
+                        size: 9 * MB
+                    },
+                    MemoryBlock {
+                        is_free: false,
+                        size: 9 * MB
+                    },
+                    MemoryBlock {
+                        is_free: true,
+                        size: 10 * MB
+                    }
+                ],
+                amount_free: 14 * MB,
+                amount_full: 18 * MB,
+                amount_total: 32 * MB,
+                num_free: 2,
+                num_full: 2,
+                num_total: 4
+            },
+        );
+    }
+
+    fn new_memory_page(size: u64) -> MemoryPage {
+        let storage = StorageHandle::new(StorageId::new(), StorageUtilization { offset: 0, size });
+
+        MemoryPage::new(storage, 4)
+    }
+}
diff --git a/crates/cubecl-runtime/src/memory_management/memory_pool/mod.rs b/crates/cubecl-runtime/src/memory_management/memory_pool/mod.rs
index c8042883c..b9bf1dcb1 100644
--- a/crates/cubecl-runtime/src/memory_management/memory_pool/mod.rs
+++ b/crates/cubecl-runtime/src/memory_management/memory_pool/mod.rs
@@ -1,16 +1,14 @@
-mod index;
-mod ring;
-
 mod base;
 mod exclusive_pool;
 pub(crate) mod handle;
+mod memory_page;
+mod persistent_pool;
 mod sliced_pool;
-mod static_pool;
 
 pub(crate) use base::*;
 pub(crate) use exclusive_pool::*;
-pub(crate) use ring::*;
+pub(crate) use memory_page::*;
+pub(crate) use persistent_pool::*;
 pub(crate) use sliced_pool::*;
-pub(crate) use static_pool::*;
 
 pub use handle::*;
diff --git a/crates/cubecl-runtime/src/memory_management/memory_pool/persistent_pool.rs b/crates/cubecl-runtime/src/memory_management/memory_pool/persistent_pool.rs
new file mode 100644
index 000000000..174ada512
--- /dev/null
+++ b/crates/cubecl-runtime/src/memory_management/memory_pool/persistent_pool.rs
@@ -0,0 +1,204 @@
+use super::{MemoryPool, Slice, SliceHandle, SliceId, calculate_padding};
+use crate::memory_management::BytesFormat;
+use crate::{memory_management::MemoryUsage, server::IoError};
+use alloc::vec;
+use alloc::vec::Vec;
+use hashbrown::HashMap;
+
+pub struct PersistentPool {
+    slices: HashMap<SliceId, Slice>,
+    sizes: HashMap<u64, Vec<SliceId>>,
+    alignment: u64,
+    max_alloc_size: u64,
+}
+
+impl core::fmt::Display for PersistentPool {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        for (size, ids) in self.sizes.iter() {
+            let mut num_free = 0;
+            let mut num_full = 0;
+            let total = ids.len();
+
+            for id in ids {
+                let slice = self.slices.get(id).unwrap();
+                let is_free = slice.is_free();
+                if is_free {
+                    num_free += 1;
+                } else {
+                    num_full += 1;
+                }
+            }
+
+            f.write_fmt(format_args!(
+                " - Slices {} => {num_free} free - {num_full} full - {total} total\n",
+                BytesFormat::new(*size)
+            ))?;
+        }
+
+        if !self.sizes.is_empty() {
+            f.write_fmt(format_args!("\n{}\n", self.get_memory_usage()))?;
+        }
+
+        Ok(())
+    }
+}
+
+impl PersistentPool {
+    pub fn new(max_alloc_size: u64, alignment: u64) -> Self {
+        Self {
+            slices: HashMap::new(),
+            sizes: HashMap::new(),
+            max_alloc_size,
+            alignment,
+        }
+    }
+
+    pub fn has_size(&mut self, size: u64) -> bool {
+        let padding = calculate_padding(size, self.alignment);
+        let size_reserve = size + padding;
+        self.sizes.contains_key(&size_reserve)
+    }
+}
+
+impl MemoryPool for PersistentPool {
+    fn accept(&self, size: u64) -> bool {
+        self.max_alloc_size >= size
+    }
+
+    fn get(&self, binding: &super::SliceBinding) -> Option<&crate::storage::StorageHandle> {
+        self.slices.get(binding.id()).map(|slice| &slice.storage)
+    }
+
+    fn try_reserve(&mut self, size: u64) -> Option<SliceHandle> {
+        let padding = calculate_padding(size, self.alignment);
+        let size_reserve = size + padding;
+
+        if let Some(vals) = self.sizes.get_mut(&size_reserve) {
+            for id in vals {
+                let slice = self.slices.get(id).unwrap();
+
+                if slice.is_free() {
+                    return Some(slice.handle.clone());
+                }
+            }
+        }
+
+        None
+    }
+
+    fn alloc<Storage: ComputeStorage>(
+        &mut self,
+        storage: &mut Storage,
+        size: u64,
+    ) -> Result<SliceHandle, IoError> {
+        let padding = calculate_padding(size, self.alignment);
+        let size_alloc = size + padding;
+
+        let storage_handle = storage.alloc(size_alloc)?;
+        let slice_handle = SliceHandle::new();
+        let slice = Slice::new(storage_handle, slice_handle.clone(), padding);
+
+        let slice_id = slice.id();
+
+        match self.sizes.get_mut(&size) {
+            Some(vals) => {
+                vals.push(slice_id);
+            }
+            None => {
+                self.sizes.insert(size, vec![slice_id]);
+            }
+        }
+
+        self.slices.insert(slice_id, slice);
+
+        Ok(slice_handle)
+    }
+
+    fn get_memory_usage(&self) -> MemoryUsage {
+        let used_slices: Vec<_> = self
+            .slices
+            .values()
+            .filter(|slice| !slice.is_free())
+            .collect();
+
+        MemoryUsage {
+            number_allocs: used_slices.len() as u64,
+            bytes_in_use: used_slices.iter().map(|slice| slice.storage.size()).sum(),
+            bytes_padding: used_slices.iter().map(|slice| slice.padding).sum(),
+            bytes_reserved: self.slices.values().map(|slice| slice.storage.size()).sum(),
+        }
+    }
+
+    fn cleanup<Storage: ComputeStorage>(
+        &mut self,
+        storage: &mut Storage,
+        _alloc_nr: u64,
+        explicit: bool,
+    ) {
+        if explicit {
+            let mut removed = Vec::new();
+            self.slices.retain(|id, slice| {
+                if slice.is_free() {
+                    storage.dealloc(slice.storage.id);
+                    removed.push((*id, slice.effective_size()));
+                    false
+                } else {
+                    true
+                }
+            });
+
+            for (id, size) in removed {
+                let ids = self.sizes.get_mut(&size).expect("The size should match");
+                ids.retain(|id_| *id_ != id);
+            }
+
+            storage.flush();
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::storage::BytesStorage;
+
+    use super::*;
+
+    #[test]
+    fn persistent_pool() {
+        let mut storage = BytesStorage::default();
+        let mut pool = PersistentPool::new(1024 * 1024, 4);
+
+        let result = pool.try_reserve(1024);
+        assert!(result.is_none(), "No alloc yet");
+
+        let alloc1 = pool.alloc(&mut storage, 1024);
+        let result = pool.try_reserve(1024);
+        assert!(result.is_none(), "No free slice yet, handle1 is alive");
+
+        core::mem::drop(alloc1);
+        let result = pool.try_reserve(1024);
+        assert!(result.is_some(), "Handle1 is free to be reused.");
+        core::mem::drop(result);
+
+        let result = pool.try_reserve(1025);
+        assert!(result.is_none(), "Not the same size.");
+
+        let alloc2 = pool.alloc(&mut storage, 1024);
+        let usage = pool.get_memory_usage();
+        assert_eq!(usage.bytes_in_use, 1024);
+        assert_eq!(usage.bytes_reserved, 2048);
+
+        let result = pool.try_reserve(1024);
+        let usage = pool.get_memory_usage();
+        assert!(result.is_some(), "Handle1 is free
to be reused."); + assert_eq!(usage.bytes_in_use, 2048); + assert_eq!(usage.bytes_reserved, 2048); + + core::mem::drop(alloc2); + core::mem::drop(result); + + let usage = pool.get_memory_usage(); + assert_eq!(usage.bytes_in_use, 0); + assert_eq!(usage.bytes_reserved, 2048); + } +} diff --git a/crates/cubecl-runtime/src/memory_management/memory_pool/ring.rs b/crates/cubecl-runtime/src/memory_management/memory_pool/ring.rs deleted file mode 100644 index d78d4a837..000000000 --- a/crates/cubecl-runtime/src/memory_management/memory_pool/ring.rs +++ /dev/null @@ -1,305 +0,0 @@ -use alloc::vec::Vec; -use hashbrown::HashMap; - -use crate::storage::StorageId; - -use super::{MemoryPage, Slice, SliceId}; - -#[derive(Debug)] -pub struct RingBuffer { - queue: Vec, - chunk_positions: HashMap, - cursor_slice: u64, - cursor_chunk: usize, - buffer_alignment: u64, -} - -impl RingBuffer { - pub fn new(buffer_alignment: u64) -> Self { - Self { - queue: Vec::new(), - chunk_positions: HashMap::new(), - cursor_slice: 0, - cursor_chunk: 0, - buffer_alignment, - } - } - - pub fn push_page(&mut self, storage_id: StorageId) { - self.queue.push(storage_id); - self.chunk_positions - .insert(storage_id, self.queue.len() - 1); - } - - pub fn find_free_slice( - &mut self, - size: u64, - pages: &mut HashMap, - slices: &mut HashMap, - ) -> Option { - let max_second = self.cursor_chunk; - let result = self.find_free_slice_in_all_chunks(size, pages, slices, self.queue.len()); - - if result.is_some() { - return result; - } - - self.cursor_chunk = 0; - self.cursor_slice = 0; - self.find_free_slice_in_all_chunks(size, pages, slices, max_second) - } - - fn find_free_slice_in_chunk( - &mut self, - size: u64, - page: &mut MemoryPage, - slices: &mut HashMap, - mut slice_index: u64, - ) -> Option<(u64, SliceId)> { - while let Some(slice_id) = page.find_slice(slice_index) { - //mutable borrow scope - { - let slice = slices.get_mut(&slice_id).unwrap(); - - let is_big_enough = slice.effective_size() >= size; - let is_free = slice.is_free(); - - if is_big_enough && is_free { - if slice.effective_size() > size - && let Some(new_slice) = slice.split(size, self.buffer_alignment) - { - let new_slice_id = new_slice.id(); - page.insert_slice(slice.next_slice_position(), new_slice_id); - slices.insert(new_slice_id, new_slice); - } - return Some((slice_index, slice_id)); - } - } - { - let slice = slices.get_mut(&slice_id).unwrap(); - let is_free = slice.is_free(); - if is_free && page.merge_with_next_slice(slice_index, slices) { - continue; - } - } - - if let Some(slice) = slices.get(&slice_id) { - slice_index = slice.next_slice_position(); - } else { - panic!("current slice_id should still be valid after potential merge"); - } - } - - None - } - - fn find_free_slice_in_all_chunks( - &mut self, - size: u64, - pages: &mut HashMap, - slices: &mut HashMap, - max_cursor_position: usize, - ) -> Option { - let start = self.cursor_chunk; - let end = usize::min(self.queue.len(), max_cursor_position); - let mut slice_index = self.cursor_slice; - - for chunk_index in start..end { - if chunk_index > start { - slice_index = 0; - } - - if let Some(id) = self.queue.get(chunk_index) { - let chunk = pages.get_mut(id).unwrap(); - let result = self.find_free_slice_in_chunk(size, chunk, slices, slice_index); - - if let Some((_cursor_slice, slice)) = result { - let slice = slices.get(&slice).unwrap(); - self.cursor_slice = slice.next_slice_position(); - self.cursor_chunk = chunk_index; - return Some(slice.id()); - } - } - self.cursor_chunk = chunk_index; - 
self.cursor_slice = 0; - } - - None - } -} - -#[cfg(test)] -mod tests { - use crate::{ - memory_management::memory_pool::{MemoryPage, SliceHandle}, - storage::StorageHandle, - }; - - use super::*; - - #[test] - fn simple_1() { - let mut ring = RingBuffer::new(1); - - let (storage_id, slice_ids, mut slices, chunk) = new_chunk(&[100, 200]); - - ring.push_page(storage_id); - let mut chunks = HashMap::from([(storage_id, chunk)]); - - let slice = ring.find_free_slice(50, &mut chunks, &mut slices).unwrap(); - - assert_eq!(slice, slice_ids[0]); - assert_eq!(slices.get(&slice).unwrap().effective_size(), 50); - assert_eq!(slices.len(), 3); - assert_eq!(chunks.values().last().unwrap().slices.len(), 3); - } - - #[test] - fn simple_2() { - let mut ring = RingBuffer::new(1); - - let (storage_id, slice_ids, mut slices, chunk) = new_chunk(&[100, 200]); - - ring.push_page(storage_id); - let mut chunks = HashMap::from([(storage_id, chunk)]); - - let slice = ring.find_free_slice(150, &mut chunks, &mut slices).unwrap(); - - assert_eq!(slice, slice_ids[0]); - assert_eq!(slices.get(&slice).unwrap().effective_size(), 150); - assert_eq!(slices.len(), 2); - assert_eq!(chunks.values().last().unwrap().slices.len(), 2); - } - - #[test] - fn multiple_chunks() { - let mut ring = RingBuffer::new(1); - - let (storage_id_1, mut slice_ids, mut slices, chunk_1) = new_chunk(&[100, 200]); - let (storage_id_2, slice_ids_2, slices_2, chunk_2) = new_chunk(&[200, 200]); - - ring.push_page(storage_id_1); - ring.push_page(storage_id_2); - - let mut chunks = HashMap::from([(storage_id_1, chunk_1), (storage_id_2, chunk_2)]); - - slice_ids.extend(slice_ids_2); - slices.extend(slices_2); - - // Clone references to control what slice is free: - let _slice_1 = slices.get(&slice_ids[1]).unwrap().handle.clone(); - let _slice_3 = slices.get(&slice_ids[3]).unwrap().handle.clone(); - - let slice = ring.find_free_slice(200, &mut chunks, &mut slices).unwrap(); - - assert_eq!(slice, slice_ids[2]); - - let slice = ring.find_free_slice(100, &mut chunks, &mut slices).unwrap(); - - assert_eq!(slice, slice_ids[0]); - } - - #[test] - fn find_free_slice_with_exact_fit() { - let mut ring = RingBuffer::new(1); - - let (storage_id, slice_ids, mut slices, chunk) = new_chunk(&[100, 200]); - - ring.push_page(storage_id); - let mut chunks = HashMap::from([(storage_id, chunk)]); - - // Clone reference to control what slice is free: - let _slice_1 = slices.get(&slice_ids[0]).unwrap().handle.clone(); - - let slice = ring.find_free_slice(200, &mut chunks, &mut slices).unwrap(); - - assert_eq!(slice, slice_ids[1]); - assert_eq!(slices.get(&slice).unwrap().effective_size(), 200); - assert_eq!(slices.len(), 2); - assert_eq!(chunks.values().last().unwrap().slices.len(), 2); - } - - #[test] - fn find_free_slice_with_merging() { - let mut ring = RingBuffer::new(1); - - let (storage_id, slice_ids, mut slices, chunk) = new_chunk(&[100, 50, 100]); - - ring.push_page(storage_id); - let mut chunks = HashMap::from([(storage_id, chunk)]); - - let slice = ring.find_free_slice(250, &mut chunks, &mut slices).unwrap(); - - assert_eq!(slice, slice_ids[0]); - assert_eq!(slices.get(&slice).unwrap().effective_size(), 250); - assert_eq!(slices.len(), 1); - assert_eq!(chunks.values().last().unwrap().slices.len(), 1); - } - - #[test] - fn find_free_slice_with_multiple_chunks_and_merging() { - let mut ring = RingBuffer::new(1); - - let (storage_id_1, mut slice_ids, mut slices, page_1) = new_chunk(&[50, 50]); - let (storage_id_2, slice_ids_2, slices_2, page_2) = new_chunk(&[100, 50]); - 
slice_ids.extend(slice_ids_2); - slices.extend(slices_2); - - ring.push_page(storage_id_1); - ring.push_page(storage_id_2); - - let mut pages = HashMap::from([(storage_id_1, page_1), (storage_id_2, page_2)]); - - let slice = ring.find_free_slice(150, &mut pages, &mut slices).unwrap(); - - assert_eq!(slices.get(&slice).unwrap().effective_size(), 150); - assert_eq!(slices.len(), 2); - assert_eq!(pages.values().last().unwrap().slices.len(), 1); - } - - fn new_chunk( - slice_sizes: &[u64], - ) -> (StorageId, Vec, HashMap, MemoryPage) { - let offsets: Vec<_> = slice_sizes - .iter() - .scan(0, |state, size| { - let offset = *state; - *state += *size; - Some(offset) - }) - .collect(); - - let storage_id = StorageId::new(); - - let slices: Vec<_> = offsets - .iter() - .zip(slice_sizes) - .map(|(&offset, &size)| Slice { - storage: StorageHandle { - id: storage_id, - utilization: crate::storage::StorageUtilization { offset, size }, - }, - handle: SliceHandle::new(), - padding: 0, - }) - .collect(); - - let mem_page = MemoryPage { - slices: slices - .iter() - .zip(offsets) - .map(|(slice, offset)| (offset, slice.id())) - .collect(), - }; - - ( - storage_id, - slices.iter().map(|slice| slice.id()).collect(), - slices - .into_iter() - .map(|slice| (slice.id(), slice)) - .collect(), - mem_page, - ) - } -} diff --git a/crates/cubecl-runtime/src/memory_management/memory_pool/sliced_pool.rs b/crates/cubecl-runtime/src/memory_management/memory_pool/sliced_pool.rs index a1f83c53b..83b1b25c7 100644 --- a/crates/cubecl-runtime/src/memory_management/memory_pool/sliced_pool.rs +++ b/crates/cubecl-runtime/src/memory_management/memory_pool/sliced_pool.rs @@ -1,277 +1,151 @@ -use super::index::SearchIndex; -use super::{MemoryPool, RingBuffer, Slice, SliceBinding, SliceHandle, SliceId}; -use crate::memory_management::MemoryUsage; -use crate::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization}; -use crate::{memory_management::memory_pool::calculate_padding, server::IoError}; +use crate::{ + memory_management::{ + BytesFormat, MemoryUsage, + memory_pool::{MemoryPage, MemoryPool}, + }, + storage::StorageId, +}; use alloc::vec::Vec; +use core::fmt::Display; use hashbrown::HashMap; -/// A memory pool that allocates buffers in a range of sizes and reuses them to minimize allocations. -/// -/// - Each 'page' allocation will contain a number of sub slices. -/// - The pool uses a ring buffer to efficiently manage and reuse pages. 
-pub(crate) struct SlicedPool {
+pub struct SlicedPool {
     pages: HashMap<StorageId, MemoryPage>,
-    slices: HashMap<SliceId, Slice>,
-    storage_index: SearchIndex<StorageId>,
-    ring: RingBuffer,
-    recently_added_pages: Vec<StorageId>,
-    recently_allocated_size: u64,
     page_size: u64,
+    alignment: u64,
     max_alloc_size: u64,
-    alignment: u64,
 }
 
-// TODO: consider using generic trait and decouple from Slice
-#[derive(new, Debug)]
-pub(crate) struct MemoryPage {
-    pub(crate) slices: HashMap<u64, SliceId>,
-}
-
-impl MemoryPage {
-    /// merge slice at first_slice_address with the next slice (if there is one and if it's free)
-    /// return a boolean representing if a merge happened
-    pub(crate) fn merge_with_next_slice(
-        &mut self,
-        first_slice_address: u64,
-        slices: &mut HashMap<SliceId, Slice>,
-    ) -> bool {
-        let first_slice_id = self.find_slice(first_slice_address).expect(
-            "merge_with_next_slice shouldn't be called with a nonexistent first_slice address",
-        );
-
-        let next_slice_address =
-            first_slice_address + slices.get(&first_slice_id).unwrap().effective_size();
-
-        if let Some(next_slice_id) = self.find_slice(next_slice_address) {
-            let (next_slice_eff_size, next_slice_is_free) = {
-                let next_slice = slices.get(&next_slice_id).unwrap();
-                (next_slice.effective_size(), next_slice.is_free())
-            };
-            if next_slice_is_free {
-                let first_slice = slices.get_mut(&first_slice_id).unwrap();
-                let first_slice_eff_size = first_slice.effective_size();
-                let first_slice_offset = first_slice.storage.offset();
-
-                let merged_size = first_slice_eff_size + next_slice_eff_size;
-                first_slice.storage.utilization = StorageUtilization {
-                    size: merged_size,
-                    offset: first_slice_offset,
-                };
-                first_slice.padding = 0;
-
-                // Cleanup of the extra slice
-                self.slices.remove(&next_slice_address);
-                slices.remove(&next_slice_id);
-                return true;
-            }
-            return false;
+impl SlicedPool {
+    pub fn new(page_size: u64, max_slice_size: u64, alignment: u64) -> Self {
+        Self {
+            pages: HashMap::new(),
+            page_size,
+            alignment,
+            max_alloc_size: max_slice_size,
         }
-        false
-    }
-
-    pub(crate) fn find_slice(&self, address: u64) -> Option<SliceId> {
-        let slice_id = self.slices.get(&address);
-        slice_id.copied()
-    }
-
-    pub(crate) fn insert_slice(&mut self, address: u64, slice: SliceId) {
-        self.slices.insert(address, slice);
     }
 }
 
 impl MemoryPool for SlicedPool {
-    fn max_alloc_size(&self) -> u64 {
-        self.max_alloc_size
+    fn accept(&self, size: u64) -> bool {
+        self.max_alloc_size >= size
+            ||
+        // Also accept sizes close to the page size, since a dedicated page then
+        // doesn't create much fragmentation from unused space.
+        match self.page_size.checked_sub(size) {
+            // At most 20% of the page may go unused, e.g. a 1 MiB page accepts
+            // any size above 0.8 MiB here.
+            Some(diff) => diff * 5 < self.page_size,
+            None => false,
+        }
     }
 
-    /// Returns the resource from the storage, for the specified handle.
-    fn get(&self, binding: &SliceBinding) -> Option<&StorageHandle> {
-        self.slices.get(binding.id()).map(|s| &s.storage)
-    }
+    fn get(&self, binding: &super::SliceBinding) -> Option<&crate::storage::StorageHandle> {
+        for (_, page) in self.pages.iter() {
+            if let Some(handle) = page.get(binding) {
+                return Some(handle);
+            }
+        }
 
-    /// Reserves memory of the specified size using the reserve algorithm, and returns
-    /// a handle to the reserved memory.
-    ///
-    /// Also cleans up, merging free slices together if permitted by the merging strategy.
-    fn try_reserve(&mut self, size: u64) -> Option<SliceHandle> {
-        let padding = calculate_padding(size, self.alignment);
-        let effective_size = size + padding;
-        let slice_id =
-            self.ring
-                .find_free_slice(effective_size, &mut self.pages, &mut self.slices)?;
+        None
+    }
 
-        let slice = self.slices.get_mut(&slice_id).unwrap();
-        let old_slice_size = slice.effective_size();
-        let offset = slice.storage.utilization.offset;
-        slice.storage.utilization = StorageUtilization { offset, size };
-        let new_padding = old_slice_size - size;
-        slice.padding = new_padding;
-        assert_eq!(
-            slice.effective_size(),
-            old_slice_size,
-            "new and old slice should have the same size"
-        );
+    fn try_reserve(&mut self, size: u64) -> Option<SliceHandle> {
+        for (_, page) in self.pages.iter_mut() {
+            page.coalesce();
+            if let Some(handle) = page.try_reserve(size) {
+                return Some(handle);
+            }
+        }
 
-        Some(slice.handle.clone())
+        None
     }
 
-    fn alloc<Storage: ComputeStorage>(
+    fn alloc<Storage: ComputeStorage>(
         &mut self,
         storage: &mut Storage,
         size: u64,
-    ) -> Result<SliceHandle, IoError> {
-        let storage_id = self.create_page(storage)?;
-        self.recently_added_pages.push(storage_id);
-        self.recently_allocated_size += self.page_size;
-
-        let slice = self.create_slice(0, size, storage_id);
-
-        let effective_size = slice.effective_size();
-
-        let extra_slice = if effective_size < self.page_size {
-            Some(self.create_slice(effective_size, self.page_size - effective_size, storage_id))
-        } else {
-            None
-        };
-
-        let handle_slice = slice.handle.clone();
-        let storage_id = slice.storage.id;
-        let slice_id = slice.id();
-        let slice_offset = slice.storage.offset();
-
-        self.slices.insert(slice_id, slice);
-        let page = self.pages.get_mut(&storage_id).unwrap();
-        page.slices.insert(slice_offset, slice_id);
-
-        if let Some(extra_slice) = extra_slice {
-            let extra_slice_id = extra_slice.id();
-            let extra_slice_offset = extra_slice.storage.offset();
-            self.slices.insert(extra_slice_id, extra_slice);
-            page.slices.insert(extra_slice_offset, extra_slice_id);
-        }
+    ) -> Result<SliceHandle, IoError> {
+        let storage = storage.alloc(self.page_size)?;
+        let storage_id = storage.id;
+        let mut page = MemoryPage::new(storage, self.alignment);
+        let returned = page.try_reserve(size);
+        self.pages.insert(storage_id, page);
 
-        Ok(handle_slice)
+        Ok(returned.expect("effective_size to be smaller than page_size"))
     }
 
     fn get_memory_usage(&self) -> MemoryUsage {
-        let used_slices: Vec<_> = self
-            .slices
-            .values()
-            .filter(|slice| !slice.is_free())
-            .collect();
+        let mut usage = MemoryUsage {
+            number_allocs: 0,
+            bytes_in_use: 0,
+            bytes_padding: 0,
+            bytes_reserved: 0,
+        };
 
-        MemoryUsage {
-            number_allocs: used_slices.len() as u64,
-            bytes_in_use: used_slices.iter().map(|s| s.storage.size()).sum(),
-            bytes_padding: used_slices.iter().map(|s| s.padding).sum(),
-            bytes_reserved: (self.pages.len() as u64) * self.page_size,
+        for (_, page) in self.pages.iter() {
+            let current = page.memory_usage();
+            usage = usage.combine(current);
         }
+
+        usage
     }
 
-    fn cleanup<Storage: ComputeStorage>(
+    fn cleanup<Storage: ComputeStorage>(
         &mut self,
-        _storage: &mut Storage,
+        storage: &mut Storage,
         _alloc_nr: u64,
-        _explicit: bool,
+        explicit: bool,
     ) {
-        // This pool doesn't do any shrinking currently.
-    }
-}
-
-impl SlicedPool {
-    pub(crate) fn new(page_size: u64, max_alloc_size: u64, alignment: u64) -> Self {
-        // Pages should be allocated to be aligned.
- assert_eq!(page_size % alignment, 0); - Self { - pages: HashMap::new(), - slices: HashMap::new(), - storage_index: SearchIndex::new(), - ring: RingBuffer::new(alignment), - recently_added_pages: Vec::new(), - recently_allocated_size: 0, - alignment, - page_size, - max_alloc_size, + if !explicit { + return; } - } - - /// Creates a slice of size `size` upon the given page with the given offset. - fn create_slice(&self, offset: u64, size: u64, storage_id: StorageId) -> Slice { - assert_eq!( - offset % self.alignment, - 0, - "slice with offset {offset} needs to be a multiple of {}", - self.alignment - ); - let handle = SliceHandle::new(); - - let storage = StorageHandle { - id: storage_id, - utilization: StorageUtilization { offset, size }, - }; - - let padding = calculate_padding(size, self.alignment); - - Slice::new(storage, handle, padding) - } - - /// Creates a page of given size by allocating on the storage. - fn create_page( - &mut self, - storage: &mut Storage, - ) -> Result { - let storage = storage.alloc(self.page_size)?; - - let id = storage.id; - self.ring.push_page(id); + let mut to_clean = Vec::new(); - self.pages.insert(id, MemoryPage::new(HashMap::new())); - self.storage_index.insert(id, self.page_size); + for (id, page) in self.pages.iter_mut() { + page.coalesce(); + let summary = page.summary(false); + if summary.amount_free == summary.amount_total { + to_clean.push(*id); + } + } - Ok(id) + for id in to_clean { + self.pages.remove(&id); + storage.dealloc(id); + } } } -impl Slice { - pub(crate) fn split(&mut self, offset_slice: u64, buffer_alignment: u64) -> Option { - let size_new = self.effective_size() - offset_slice; - let offset_new = self.storage.offset() + offset_slice; - let old_size = self.effective_size(); - - let storage_new = StorageHandle { - id: self.storage.id, - utilization: StorageUtilization { - offset: offset_new, - size: size_new, - }, - }; - - self.storage.utilization = StorageUtilization { - offset: self.storage.offset(), - size: offset_slice, - }; - - if !offset_new.is_multiple_of(buffer_alignment) { - panic!("slice with offset {offset_new} needs to be a multiple of {buffer_alignment}"); +impl Display for SlicedPool { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + if self.pages.is_empty() { + return Ok(()); } - let handle = SliceHandle::new(); - if size_new < buffer_alignment { - self.padding = old_size - offset_slice; - assert_eq!(self.effective_size(), old_size); - return None; + + f.write_fmt(format_args!( + " - Sliced Pool page_size={} max_alloc_size={}\n", + BytesFormat::new(self.page_size), + BytesFormat::new(self.max_alloc_size) + ))?; + + for (id, page) in self.pages.iter() { + let summary = page.summary(false); + f.write_fmt(format_args!( + " - Page {id} num_slices={} =>", + summary.num_total + ))?; + + let size_free = BytesFormat::new(summary.amount_free); + let size_full = BytesFormat::new(summary.amount_full); + let size_total = BytesFormat::new(summary.amount_total); + + f.write_fmt(format_args!( + " {size_free} free - {size_full} full - {size_total} total\n" + ))?; } - assert!( - size_new >= buffer_alignment, - "Size new > {buffer_alignment}" - ); - self.padding = 0; - let padding = calculate_padding(size_new - buffer_alignment, buffer_alignment); - Some(Slice::new(storage_new, handle, padding)) - } + f.write_fmt(format_args!("\n{}\n", self.get_memory_usage()))?; - pub(crate) fn next_slice_position(&self) -> u64 { - self.storage.offset() + self.effective_size() + Ok(()) } } diff --git 
a/crates/cubecl-runtime/src/memory_management/memory_pool/static_pool.rs b/crates/cubecl-runtime/src/memory_management/memory_pool/static_pool.rs deleted file mode 100644 index 41b722c41..000000000 --- a/crates/cubecl-runtime/src/memory_management/memory_pool/static_pool.rs +++ /dev/null @@ -1,85 +0,0 @@ -use crate::{memory_management::MemoryUsage, server::IoError}; -use alloc::vec::Vec; -use hashbrown::HashMap; - -use super::{MemoryPool, Slice, SliceHandle, SliceId, calculate_padding}; - -pub struct StaticPool { - slices: HashMap, - max_alloc_size: u64, -} - -impl StaticPool { - pub fn new(max_alloc_size: u64) -> Self { - Self { - slices: HashMap::new(), - max_alloc_size, - } - } -} - -impl MemoryPool for StaticPool { - fn max_alloc_size(&self) -> u64 { - self.max_alloc_size - } - - fn get(&self, binding: &super::SliceBinding) -> Option<&crate::storage::StorageHandle> { - self.slices.get(binding.id()).map(|slice| &slice.storage) - } - - fn try_reserve(&mut self, _size: u64) -> Option { - None - } - - fn alloc( - &mut self, - storage: &mut Storage, - size: u64, - ) -> Result { - let padding = calculate_padding(size, storage.alignment() as u64); - let size_alloc = size + padding; - - let storage_handle = storage.alloc(size_alloc)?; - let slice_handle = SliceHandle::new(); - let slice = Slice::new(storage_handle, slice_handle.clone(), padding); - - self.slices.insert(slice.id(), slice); - - Ok(slice_handle) - } - - fn get_memory_usage(&self) -> MemoryUsage { - let used_slices: Vec<_> = self - .slices - .values() - .filter(|slice| !slice.is_free()) - .collect(); - - MemoryUsage { - number_allocs: used_slices.len() as u64, - bytes_in_use: used_slices.iter().map(|slice| slice.storage.size()).sum(), - bytes_padding: used_slices.iter().map(|slice| slice.padding).sum(), - bytes_reserved: self.slices.values().map(|slice| slice.storage.size()).sum(), - } - } - - fn cleanup( - &mut self, - storage: &mut Storage, - _alloc_nr: u64, - explicit: bool, - ) { - if explicit { - self.slices.retain(|_, slice| { - if slice.is_free() { - storage.dealloc(slice.storage.id); - false - } else { - true - } - }); - - storage.flush(); - } - } -} diff --git a/crates/cubecl-runtime/src/server.rs b/crates/cubecl-runtime/src/server.rs index 43e057c48..1a0196e79 100644 --- a/crates/cubecl-runtime/src/server.rs +++ b/crates/cubecl-runtime/src/server.rs @@ -1,4 +1,5 @@ use crate::{ + DeviceProperties, kernel::KernelMetadata, logging::ServerLogger, memory_management::{ @@ -15,7 +16,8 @@ use alloc::vec; use alloc::vec::Vec; use core::fmt::Debug; use cubecl_common::{ - ExecutionMode, bytes::Bytes, future::DynFut, profile::ProfileDuration, stream_id::StreamId, + ExecutionMode, bytes::Bytes, device, future::DynFut, profile::ProfileDuration, + stream_id::StreamId, }; use cubecl_ir::StorageType; use thiserror::Error; @@ -29,11 +31,58 @@ pub enum ProfileError { NotRegistered, } +#[derive(Debug)] +/// Contains many different types that are useful for server implementations and compute clients. +pub struct ServerUtilities { + /// The time when `profile-tracy` is activated. + #[cfg(feature = "profile-tracy")] + pub epoch_time: web_time::Instant, + /// The GPU client when `profile-tracy` is activated. + #[cfg(feature = "profile-tracy")] + pub gpu_client: tracy_client::GpuContext, + /// Information shared between all servers. + pub properties: DeviceProperties, + /// Information specific to the current server. + pub info: Server::Info, + /// The logger based on global cubecl configs. 
+    pub logger: Arc<ServerLogger>,
+}
+
+impl<S: ComputeServer> ServerUtilities<S> {
+    /// Creates a new server utilities instance.
+    pub fn new(properties: DeviceProperties, logger: Arc<ServerLogger>, info: S::Info) -> Self {
+        // Start a tracy client if needed.
+        #[cfg(feature = "profile-tracy")]
+        let client = tracy_client::Client::start();
+
+        Self {
+            properties,
+            logger,
+            // Create the GPU client if needed.
+            #[cfg(feature = "profile-tracy")]
+            gpu_client: client
+                .clone()
+                .new_gpu_context(
+                    Some(&format!("{info:?}")),
+                    // In the future we should ask the server what makes sense here; 'Invalid' is
+                    // a generic stand-in for now (Tracy doesn't have CUDA/ROCm contexts anyway).
+                    tracy_client::GpuContextType::Invalid,
+                    0,   // Timestamps are manually aligned to this epoch, so start at 0.
+                    1.0, // Timestamps are manually converted to nanoseconds, so the period is 1.
+                )
+                .unwrap(),
+            #[cfg(feature = "profile-tracy")]
+            epoch_time: web_time::Instant::now(),
+            info,
+        }
+    }
+}
+
 /// The compute server is responsible for handling resources and computations over resources.
 ///
 /// Everything in the server is mutable, therefore it should be solely accessed through the
 /// [compute channel](crate::channel::ComputeChannel) for thread safety.
-pub trait ComputeServer: Send + core::fmt::Debug + ServerCommunication
+pub trait ComputeServer:
+    Send + core::fmt::Debug + ServerCommunication + device::DeviceState + 'static
 where
     Self: Sized,
 {
@@ -54,6 +103,9 @@ where
     /// Retrieve the server logger.
     fn logger(&self) -> Arc<ServerLogger>;
 
+    /// Retrieve the server utilities.
+    fn utilities(&self) -> Arc<ServerUtilities<Self>>;
+
     /// Utility to create a new buffer and immediately copy contiguous data into it.
     fn create_with_data(&mut self, data: &[u8], stream_id: StreamId) -> Result<Handle, IoError> {
         let alloc = self
diff --git a/crates/cubecl-runtime/src/storage/base.rs b/crates/cubecl-runtime/src/storage/base.rs
index 7d80a2aae..fbf7fc74d 100644
--- a/crates/cubecl-runtime/src/storage/base.rs
+++ b/crates/cubecl-runtime/src/storage/base.rs
@@ -8,8 +8,14 @@ use crate::{
 // This ID is used to map a handle to its actual data.
 storage_id_type!(StorageId);
 
+impl core::fmt::Display for StorageId {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        f.write_fmt(format_args!("StorageId({})", self.value))
+    }
+}
+
 /// Defines if data uses a full memory chunk or a slice of it.
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, PartialEq, Eq)]
 pub struct StorageUtilization {
     /// The offset in bytes from the chunk start.
     pub offset: u64,
diff --git a/crates/cubecl-runtime/src/stream/scheduler.rs b/crates/cubecl-runtime/src/stream/scheduler.rs
index a75ff7fdd..38b131ae4 100644
--- a/crates/cubecl-runtime/src/stream/scheduler.rs
+++ b/crates/cubecl-runtime/src/stream/scheduler.rs
@@ -18,6 +18,8 @@ pub trait SchedulerStreamBackend {
     /// Enqueues a task onto a given stream for execution.
     fn enqueue(task: Self::Task, stream: &mut Self::Stream);
+    /// Flushes the inner stream queue to ensure ordering between different streams.
+    fn flush(stream: &mut Self::Stream);
     /// Returns a mutable reference to the stream factory.
     fn factory(&mut self) -> &mut Self::Factory;
 }
@@ -218,11 +220,28 @@ impl<B: SchedulerStreamBackend> SchedulerMultiStream<B> {
             // Enqueue each task on the stream.
             B::enqueue(task, &mut stream.stream);
         }
+
+        // Make sure the tasks are ordered on the compute queue.
+        B::flush(&mut stream.stream);
     }
 }
 
-    /// Executes schedules in an interleaved manner, alternating tasks across streams.
+    /// Executes schedules in an interleaved manner, alternating tasks from different streams.
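+    /// For example, with streams `[s0, s1, s2]`: `s1` and `s2` are flushed first, every
+    /// schedule's tasks are then enqueued round-robin onto `s0`, and `s0` is flushed last.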
+    ///
+    /// We choose the first stream as the one executing the tasks, ensuring proper ordering by
+    /// flushing all other streams first and flushing the execution stream at the end.
+    /// This way, we ensure that most tasks are actually interleaved on the real compute queue
+    /// shared across all streams.
     fn execute_schedules_interleave(&mut self, mut schedules: Vec>) {
+        // Make sure the tasks are ordered on the compute queue.
+        for schedule in schedules.iter_mut().skip(1) {
+            let stream = unsafe { self.pool.get_mut_index(schedule.stream_index) };
+            B::flush(&mut stream.stream);
+        }
+
+        let execution_index = schedules.first().expect("At least one stream").stream_index;
+        let stream = unsafe { self.pool.get_mut_index(execution_index) };
+
         // Find the maximum number of tasks across all schedules.
         let num_tasks_max = schedules
             .iter()
@@ -235,12 +254,13 @@ impl<B: SchedulerStreamBackend> SchedulerMultiStream<B> {
             for schedule in schedules.iter_mut() {
                 // If there are tasks remaining in the schedule, enqueue the next one.
                 if let Some(task) = schedule.tasks.next() {
-                    // Note: `unsafe` usage assumes valid index.
-                    let stream = unsafe { self.pool.get_mut_index(schedule.stream_index) };
                     B::enqueue(task, &mut stream.stream);
                 }
             }
         }
+
+        // Make sure all tasks are registered on the queue.
+        B::flush(&mut stream.stream);
     }
 }
diff --git a/crates/cubecl-runtime/src/tune/base.rs b/crates/cubecl-runtime/src/tune/base.rs
index 25867a194..d162a2445 100644
--- a/crates/cubecl-runtime/src/tune/base.rs
+++ b/crates/cubecl-runtime/src/tune/base.rs
@@ -24,7 +24,12 @@ impl Tunable {
     }
 
     /// Tag the current tunable as part of the given [group](TuneGroup).
-    pub fn group<F: Fn(&K) -> u8 + 'static>(mut self, group: &TuneGroup<K>, priority: F) -> Self {
+    /// `group` is a tuning group with a corresponding priority function.
+    /// `priority` is the intra-group priority, applied after the group priority to further sort entries.
+    ///
+    /// Groups are tuned in order of priority, and then each entry in the group is tuned based on the
+    /// intra-group priority. Negative priorities ensure the entry is never tuned for this key.
+    pub fn group<F: Fn(&K) -> i8 + 'static>(mut self, group: &TuneGroup<K>, priority: F) -> Self {
         self.groups.push((group.clone(), Arc::new(priority)));
         self
     }
@@ -54,7 +59,7 @@ impl<K> Clone for TuneGroup<K> {
 
 impl<K> TuneGroup<K> {
     /// Create a new group based on a priority function.
-    pub fn new<F: Fn(&K) -> u8 + 'static>(f: F) -> Self {
+    pub fn new<F: Fn(&K) -> i8 + 'static>(f: F) -> Self {
         let id = GROUP_COUNTER.fetch_add(1, Ordering::Relaxed);
 
         Self {
@@ -67,27 +72,27 @@ impl<K> TuneGroup<K> {
 
 #[derive(Debug)]
 /// A group plan dictates which [tunables](Tunable) should be executed, and in what order.
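+/// For example, with one group at priority 2 and another at priority 1, all entries of
+/// the priority-2 group are tried first (higher intra-group priorities first), followed
+/// by the priority-1 group; any tunable whose priority function returns a negative value
+/// is skipped for that key.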
 pub(crate) struct TunePlan {
-    priorities: Vec<u8>,
+    priorities: Vec<i8>,
     no_groups: Vec<usize>,
-    groups: HashMap<u8, GroupPlan>,
+    groups: HashMap<i8, GroupPlan>,
 }
 
 #[derive(Default, Debug)]
 struct GroupPlan {
-    priorities: Vec<u8>,
-    indices: HashMap<u8, Vec<usize>>,
+    priorities: Vec<i8>,
+    indices: HashMap<i8, Vec<usize>>,
 }
 
 struct Cleanup {
-    groups: Vec<u8>,
-    tunables: Vec<(u8, u8)>,
+    groups: Vec<i8>,
+    tunables: Vec<(i8, i8)>,
 }
 
 impl TunePlan {
     pub fn new(key: &K, tunables: &[Tunable]) -> Self {
-        let mut priorities = Vec::<u8>::new();
+        let mut priorities = Vec::<i8>::new();
         let mut no_groups = Vec::new();
-        let mut groups = HashMap::<u8, GroupPlan>::new();
+        let mut groups = HashMap::<i8, GroupPlan>::new();
 
         for (index, tunable) in tunables.iter().enumerate() {
             if tunable.groups.is_empty() {
@@ -144,13 +149,16 @@ impl TunePlan {
 
         let priority = self.priorities.last();
         let priority = match priority {
-            Some(val) => val,
+            Some(val) => *val,
             None => return indices,
         };
 
-        let (mut group_indices, cleanup) = self.group_plan_next(*priority);
+        let (mut group_indices, cleanup) = self.group_plan_next(priority);
         self.cleanup(cleanup);
-        indices.append(&mut group_indices);
+
+        if priority >= 0 {
+            indices.append(&mut group_indices);
+        }
 
         indices
     }
@@ -182,10 +190,10 @@ impl TunePlan {
         }
     }
 
-    fn group_plan_next(&mut self, priority: u8) -> (Vec<usize>, Cleanup) {
+    fn group_plan_next(&mut self, priority: i8) -> (Vec<usize>, Cleanup) {
         let plan = self.groups.get_mut(&priority).expect("To be filled");
         let within_group_prio = plan.priorities.pop().unwrap();
-        let next_indices = plan.indices.remove(&within_group_prio).unwrap();
+        let mut next_indices = plan.indices.remove(&within_group_prio).unwrap();
 
         let mut cleanup_groups = Vec::new();
         let mut cleanup_tunables = Vec::new();
@@ -213,6 +221,11 @@ impl TunePlan {
             }
         }
 
+        if within_group_prio < 0 {
+            // Discard algorithms with a negative priority.
+            next_indices.clear();
+        }
+
         (
             next_indices,
             Cleanup {
@@ -223,7 +236,7 @@ impl TunePlan {
     }
 }
 
-type PriorityFunc<K> = Arc<dyn Fn(&K) -> u8>;
+type PriorityFunc<K> = Arc<dyn Fn(&K) -> i8>;
 
 static GROUP_COUNTER: AtomicU32 = AtomicU32::new(0);
diff --git a/crates/cubecl-runtime/src/tune/local.rs b/crates/cubecl-runtime/src/tune/local.rs
index df733bfbf..684036af2 100644
--- a/crates/cubecl-runtime/src/tune/local.rs
+++ b/crates/cubecl-runtime/src/tune/local.rs
@@ -1,7 +1,5 @@
 use super::{AutotuneKey, AutotuneOutput, TunableSet, Tuner};
-use crate::{
-    channel::ComputeChannel, client::ComputeClient, server::ComputeServer, tune::TuneCacheResult,
-};
+use crate::{client::ComputeClient, server::ComputeServer, tune::TuneCacheResult};
 use alloc::boxed::Box;
 use alloc::sync::Arc;
 use core::{
@@ -114,16 +112,15 @@ where
     }
 
     /// Execute the best operation in the provided [tunable set](TunableSet)
-    pub fn execute<S, C, In, Out>(
+    pub fn execute<S, In, Out>(
         &self,
         id: &ID,
-        client: &ComputeClient<S, C>,
+        client: &ComputeClient<S>,
         operations: Arc<TunableSet<AK, In, Out>>,
         inputs: In,
     ) -> Out
     where
         S: ComputeServer + 'static,
-        C: ComputeChannel<S> + 'static,
         In: Clone + Send + 'static,
         Out: AutotuneOutput,
     {
diff --git a/crates/cubecl-runtime/src/tune/tune_benchmark.rs b/crates/cubecl-runtime/src/tune/tune_benchmark.rs
index fc2c7189e..33b8a7526 100644
--- a/crates/cubecl-runtime/src/tune/tune_benchmark.rs
+++ b/crates/cubecl-runtime/src/tune/tune_benchmark.rs
@@ -3,7 +3,6 @@ use alloc::sync::Arc;
 use alloc::vec::Vec;
 use cubecl_common::profile::{ProfileDuration, TimingMethod};
 
-use crate::channel::ComputeChannel;
 use crate::client::ComputeClient;
 use crate::server::ComputeServer;
 
@@ -11,10 +10,10 @@ use super::{AutotuneError, TuneFn};
 
 /// A benchmark that runs on server handles
 #[derive(new)]
-pub struct TuneBenchmark<S: ComputeServer, C: ComputeChannel<S>, In, Out> {
+pub struct TuneBenchmark<S: ComputeServer, In, Out> {
operation: Arc>, inputs: In, - client: ComputeClient, + client: ComputeClient, } /// The trait to be implemented by an autotune output. @@ -32,12 +31,8 @@ impl AutotuneOutput for () { } } -impl< - S: ComputeServer + 'static, - C: ComputeChannel + 'static, - In: Clone + Send + 'static, - Out: AutotuneOutput, -> TuneBenchmark +impl + TuneBenchmark { /// Benchmark how long this operation takes for a number of samples. /// diff --git a/crates/cubecl-runtime/src/tune/tuner.rs b/crates/cubecl-runtime/src/tune/tuner.rs index f68e8e6ba..a41a999eb 100644 --- a/crates/cubecl-runtime/src/tune/tuner.rs +++ b/crates/cubecl-runtime/src/tune/tuner.rs @@ -11,7 +11,6 @@ use core::time::Duration; use alloc::string::{String, ToString}; use cubecl_common::benchmark::{BenchmarkComputations, BenchmarkDurations}; -use crate::channel::ComputeChannel; use crate::client::ComputeClient; use crate::config::{Logger, autotune::AutotuneLogLevel}; use crate::server::ComputeServer; @@ -191,7 +190,6 @@ impl Tuner { /// Execute benchmarks to find out what the fastest operation is. pub fn prepare_autotune< S: ComputeServer + 'static, - C: ComputeChannel + 'static, In: Clone + Send + 'static, Out: AutotuneOutput, >( @@ -199,7 +197,7 @@ impl Tuner { key: K, inputs: &In, tunables: &TunableSet, - client: &ComputeClient, + client: &ComputeClient, ) -> Box { log::info!("Tuning {key}"); @@ -285,10 +283,9 @@ impl Tuner { In: Clone + Send + 'static, Out: AutotuneOutput, S: ComputeServer + 'static, - C: ComputeChannel + 'static, >( key: K, - client: &ComputeClient, + client: &ComputeClient, mut plan: TunePlan, autotunables: Vec + 'static>>, test_inputs: In, @@ -331,9 +328,8 @@ impl Tuner { In: Clone + Send + 'static, Out: AutotuneOutput, S: ComputeServer + 'static, - C: ComputeChannel + 'static, >( - client: &ComputeClient, + client: &ComputeClient, plan: &mut TunePlan, autotunables: Vec + 'static>>, test_inputs: &In, diff --git a/crates/cubecl-runtime/tests/dummy/compute.rs b/crates/cubecl-runtime/tests/dummy/compute.rs index a8be6527b..1cf8221d1 100644 --- a/crates/cubecl-runtime/tests/dummy/compute.rs +++ b/crates/cubecl-runtime/tests/dummy/compute.rs @@ -1,61 +1,60 @@ +use std::sync::Arc; + use super::DummyServer; -use cubecl_common::CubeDim; -use cubecl_common::profile::TimingMethod; +use cubecl_common::device::{Device, DeviceState}; use cubecl_runtime::client::ComputeClient; +use cubecl_runtime::logging::ServerLogger; use cubecl_runtime::memory_management::{ - MemoryConfiguration, MemoryDeviceProperties, MemoryManagement, + MemoryConfiguration, MemoryDeviceProperties, MemoryManagement, MemoryManagementOptions, }; -use cubecl_runtime::server::CubeCount; use cubecl_runtime::storage::BytesStorage; -use cubecl_runtime::{ComputeRuntime, DeviceProperties}; -use cubecl_runtime::{channel::MutexComputeChannel, memory_management::HardwareProperties}; /// The dummy device. 
-#[derive(Clone, Debug, Hash, PartialEq, Eq)] +#[derive(Clone, Debug, Hash, PartialEq, Eq, Default)] pub struct DummyDevice; -pub type DummyChannel = MutexComputeChannel; -pub type DummyClient = ComputeClient; +impl Device for DummyDevice { + fn from_id(_device_id: cubecl_common::device::DeviceId) -> Self { + Self + } + + fn to_id(&self) -> cubecl_common::device::DeviceId { + cubecl_common::device::DeviceId { + type_id: 0, + index_id: 0, + } + } + + fn device_count(_type_id: u16) -> usize { + 1 + } +} + +pub type DummyClient = ComputeClient; -static RUNTIME: ComputeRuntime = ComputeRuntime::new(); +impl DeviceState for DummyServer { + fn init(_device_id: cubecl_common::device::DeviceId) -> Self { + init_server() + } +} -pub fn init_client() -> ComputeClient> { +fn init_server() -> DummyServer { let storage = BytesStorage::default(); let mem_properties = MemoryDeviceProperties { max_page_size: 1024 * 1024 * 512, alignment: 32, }; - let topology = HardwareProperties { - plane_size_min: 32, - plane_size_max: 32, - max_bindings: 32, - max_shared_memory_size: 48000, - max_cube_count: CubeCount::new_3d(u16::MAX as u32, u16::MAX as u32, u16::MAX as u32), - max_units_per_cube: 1024, - max_cube_dim: CubeDim::new_3d(1024, 1024, 64), - num_streaming_multiprocessors: None, - num_tensor_cores: None, - min_tensor_cores_dim: None, - }; + let memory_management = MemoryManagement::from_configuration( storage, &mem_properties, MemoryConfiguration::default(), + Arc::new(ServerLogger::default()), + MemoryManagementOptions::new("Main CPU Memory"), ); - let server = DummyServer::new(memory_management); - let channel = MutexComputeChannel::new(server); - ComputeClient::new( - channel, - DeviceProperties::new( - Default::default(), - mem_properties, - topology, - TimingMethod::System, - ), - (), - ) + DummyServer::new(memory_management, mem_properties) } pub fn test_client(device: &DummyDevice) -> DummyClient { - RUNTIME.client(device, init_client) + ComputeClient::load(device) } diff --git a/crates/cubecl-runtime/tests/dummy/server.rs b/crates/cubecl-runtime/tests/dummy/server.rs index 7870ebb6a..da18eea1d 100644 --- a/crates/cubecl-runtime/tests/dummy/server.rs +++ b/crates/cubecl-runtime/tests/dummy/server.rs @@ -1,13 +1,14 @@ -use cubecl_common::ExecutionMode; use cubecl_common::bytes::Bytes; use cubecl_common::future::DynFut; use cubecl_common::profile::ProfileDuration; use cubecl_common::stream_id::StreamId; +use cubecl_common::{CubeDim, ExecutionMode}; use cubecl_runtime::logging::ServerLogger; use cubecl_runtime::server::{ - Bindings, CopyDescriptor, ProfileError, ProfilingToken, ServerCommunication, + Bindings, CopyDescriptor, ProfileError, ProfilingToken, ServerCommunication, ServerUtilities, }; use cubecl_runtime::timestamp_profiler::TimestampProfiler; +use cubecl_runtime::{DeviceProperties, Features}; use cubecl_runtime::{id::KernelId, server::IoError}; use cubecl_runtime::{ kernel::KernelMetadata, @@ -16,7 +17,9 @@ use cubecl_runtime::{ use std::sync::Arc; use super::DummyKernel; -use cubecl_runtime::memory_management::{MemoryAllocationMode, MemoryUsage}; +use cubecl_runtime::memory_management::{ + HardwareProperties, MemoryAllocationMode, MemoryDeviceProperties, MemoryUsage, +}; use cubecl_runtime::server::CubeCount; use cubecl_runtime::storage::{BindingResource, BytesResource, ComputeStorage}; use cubecl_runtime::{ @@ -31,7 +34,7 @@ use cubecl_runtime::{ pub struct DummyServer { memory_management: MemoryManagement, timestamps: TimestampProfiler, - logger: Arc, + utilities: Arc>, } #[derive(Debug, 
Clone)] @@ -71,7 +74,11 @@ impl ComputeServer for DummyServer { type Info = (); fn logger(&self) -> Arc { - self.logger.clone() + self.utilities.logger.clone() + } + + fn utilities(&self) -> Arc> { + self.utilities.clone() } fn create( @@ -225,10 +232,31 @@ impl ComputeServer for DummyServer { } impl DummyServer { - pub fn new(memory_management: MemoryManagement) -> Self { + pub fn new( + memory_management: MemoryManagement, + mem_props: MemoryDeviceProperties, + ) -> Self { + let hardware = HardwareProperties { + plane_size_min: 32, + plane_size_max: 32, + max_bindings: 32, + max_shared_memory_size: 48000, + max_cube_count: CubeCount::new_3d(u16::MAX as u32, u16::MAX as u32, u16::MAX as u32), + max_units_per_cube: 1024, + max_cube_dim: CubeDim::new_3d(1024, 1024, 64), + num_streaming_multiprocessors: None, + num_tensor_cores: None, + min_tensor_cores_dim: None, + }; + let features = Features::default(); + let timing_method = cubecl_common::profile::TimingMethod::System; + let props = DeviceProperties::new(features, mem_props, hardware, timing_method); + let logger = Arc::new(ServerLogger::default()); + + let utilities = Arc::new(ServerUtilities::new(props, logger, ())); Self { - logger: Arc::new(ServerLogger::default()), memory_management, + utilities, timestamps: TimestampProfiler::default(), } } diff --git a/crates/cubecl-runtime/tests/dummy/tune/autotune_operations.rs b/crates/cubecl-runtime/tests/dummy/tune/autotune_operations.rs index f93c69802..c6f8b9e62 100644 --- a/crates/cubecl-runtime/tests/dummy/tune/autotune_operations.rs +++ b/crates/cubecl-runtime/tests/dummy/tune/autotune_operations.rs @@ -5,14 +5,14 @@ use cubecl_runtime::{ }; use derive_new::new; -use crate::dummy::{DummyChannel, DummyServer, KernelTask}; +use crate::dummy::{DummyServer, KernelTask}; #[derive(new, Clone)] /// Extended kernel that accounts for additional parameters, i.e. needed /// information that does not count as an input/output. 
pub struct OneKernelAutotuneOperation { kernel: KernelTask, - client: ComputeClient, + client: ComputeClient, } impl TuneFn for OneKernelAutotuneOperation { diff --git a/crates/cubecl-spirv/Cargo.toml b/crates/cubecl-spirv/Cargo.toml index 03655e823..f8488b8c6 100644 --- a/crates/cubecl-spirv/Cargo.toml +++ b/crates/cubecl-spirv/Cargo.toml @@ -21,9 +21,9 @@ std = ["cubecl-common/std", "cubecl-core/std", "cubecl-runtime/std"] [dependencies] bitflags = { workspace = true } -cubecl-common = { path = "../cubecl-common", version = "0.7.0", default-features = false } -cubecl-core = { path = "../cubecl-core", version = "0.7.0" } -cubecl-runtime = { path = "../cubecl-runtime", version = "0.7.0", default-features = false, features = [ +cubecl-common = { path = "../cubecl-common", version = "0.9.0", default-features = false } +cubecl-core = { path = "../cubecl-core", version = "0.9.0" } +cubecl-runtime = { path = "../cubecl-runtime", version = "0.9.0", default-features = false, features = [ "channel-mutex", ] } half = { workspace = true } @@ -33,9 +33,9 @@ hashbrown = { workspace = true } rspirv = { workspace = true } # Optimizer -cubecl-opt = { path = "../cubecl-opt", version = "0.7.0" } +cubecl-opt = { path = "../cubecl-opt", version = "0.9.0" } [dev-dependencies] -cubecl-random = { path = "../cubecl-random", version = "0.7.0", features = [ +cubecl-random = { path = "../cubecl-random", version = "0.9.0", features = [ "export_tests", ] } diff --git a/crates/cubecl-spirv/src/arithmetic.rs b/crates/cubecl-spirv/src/arithmetic.rs index 42d18cdee..eda20d204 100644 --- a/crates/cubecl-spirv/src/arithmetic.rs +++ b/crates/cubecl-spirv/src/arithmetic.rs @@ -3,7 +3,7 @@ use crate::{ item::{Elem, Item}, variable::ConstVal, }; -use cubecl_core::ir::{self as core, Arithmetic}; +use cubecl_core::ir::{self as core, Arithmetic, InstructionModes}; use rspirv::spirv::{Capability, Decoration, FPEncoding}; impl SpirvCompiler { @@ -11,6 +11,7 @@ impl SpirvCompiler { &mut self, op: Arithmetic, out: Option, + modes: InstructionModes, uniform: bool, ) { let out = out.unwrap(); @@ -19,9 +20,13 @@ impl SpirvCompiler { self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| { match out_ty.elem() { Elem::Int(_, _) => b.i_add(ty, Some(out), lhs, rhs).unwrap(), - Elem::Float(..) => b.f_add(ty, Some(out), lhs, rhs).unwrap(), + Elem::Float(..) => { + b.declare_math_mode(modes, out); + b.f_add(ty, Some(out), lhs, rhs).unwrap() + } Elem::Relaxed => { b.decorate(out, Decoration::RelaxedPrecision, []); + b.declare_math_mode(modes, out); b.f_add(ty, Some(out), lhs, rhs).unwrap() } _ => unreachable!(), @@ -35,9 +40,13 @@ impl SpirvCompiler { self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| { match out_ty.elem() { Elem::Int(_, _) => b.i_sub(ty, Some(out), lhs, rhs).unwrap(), - Elem::Float(..) => b.f_sub(ty, Some(out), lhs, rhs).unwrap(), + Elem::Float(..) => { + b.declare_math_mode(modes, out); + b.f_sub(ty, Some(out), lhs, rhs).unwrap() + } Elem::Relaxed => { b.decorate(out, Decoration::RelaxedPrecision, []); + b.declare_math_mode(modes, out); b.f_sub(ty, Some(out), lhs, rhs).unwrap() } _ => unreachable!(), @@ -51,9 +60,13 @@ impl SpirvCompiler { self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| { match out_ty.elem() { Elem::Int(_, _) => b.i_mul(ty, Some(out), lhs, rhs).unwrap(), - Elem::Float(..) => b.f_mul(ty, Some(out), lhs, rhs).unwrap(), + Elem::Float(..) 
=> { + b.declare_math_mode(modes, out); + b.f_mul(ty, Some(out), lhs, rhs).unwrap() + } Elem::Relaxed => { b.decorate(out, Decoration::RelaxedPrecision, []); + b.declare_math_mode(modes, out); b.f_mul(ty, Some(out), lhs, rhs).unwrap() } _ => unreachable!(), @@ -76,9 +89,13 @@ impl SpirvCompiler { match out_ty.elem() { Elem::Int(_, false) => b.u_div(ty, Some(out), lhs, rhs).unwrap(), Elem::Int(_, true) => b.s_div(ty, Some(out), lhs, rhs).unwrap(), - Elem::Float(..) => b.f_div(ty, Some(out), lhs, rhs).unwrap(), + Elem::Float(..) => { + b.declare_math_mode(modes, out); + b.f_div(ty, Some(out), lhs, rhs).unwrap() + } Elem::Relaxed => { b.decorate(out, Decoration::RelaxedPrecision, []); + b.declare_math_mode(modes, out); b.f_div(ty, Some(out), lhs, rhs).unwrap() } _ => unreachable!(), @@ -90,9 +107,13 @@ impl SpirvCompiler { match out_ty.elem() { Elem::Int(_, false) => b.u_mod(ty, Some(out), lhs, rhs).unwrap(), Elem::Int(_, true) => b.s_mod(ty, Some(out), lhs, rhs).unwrap(), - Elem::Float(..) => b.f_mod(ty, Some(out), lhs, rhs).unwrap(), + Elem::Float(..) => { + b.declare_math_mode(modes, out); + b.f_mod(ty, Some(out), lhs, rhs).unwrap() + } Elem::Relaxed => { b.decorate(out, Decoration::RelaxedPrecision, []); + b.declare_math_mode(modes, out); b.f_mod(ty, Some(out), lhs, rhs).unwrap() } _ => unreachable!(), @@ -104,9 +125,13 @@ impl SpirvCompiler { match out_ty.elem() { Elem::Int(_, false) => b.u_mod(ty, Some(out), lhs, rhs).unwrap(), Elem::Int(_, true) => b.s_rem(ty, Some(out), lhs, rhs).unwrap(), - Elem::Float(..) => b.f_rem(ty, Some(out), lhs, rhs).unwrap(), + Elem::Float(..) => { + b.declare_math_mode(modes, out); + b.f_rem(ty, Some(out), lhs, rhs).unwrap() + } Elem::Relaxed => { b.decorate(out, Decoration::RelaxedPrecision, []); + b.declare_math_mode(modes, out); b.f_rem(ty, Some(out), lhs, rhs).unwrap() } _ => unreachable!(), @@ -118,9 +143,13 @@ impl SpirvCompiler { self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| { match out_ty.elem() { Elem::Int(_, _) => b.i_mul(ty, Some(out), lhs, rhs).unwrap(), - Elem::Float(..) => b.f_mul(ty, Some(out), lhs, rhs).unwrap(), + Elem::Float(..) => { + b.declare_math_mode(modes, out); + b.f_mul(ty, Some(out), lhs, rhs).unwrap() + } Elem::Relaxed => { b.decorate(out, Decoration::RelaxedPrecision, []); + b.declare_math_mode(modes, out); b.f_mul(ty, Some(out), lhs, rhs).unwrap() } _ => unreachable!(), @@ -193,7 +222,9 @@ impl SpirvCompiler { let mul = self.f_mul(ty, None, a_id, b_id).unwrap(); self.mark_uniformity(mul, uniform); + self.declare_math_mode(modes, mul); self.f_add(ty, Some(out_id), mul, c_id).unwrap(); + self.declare_math_mode(modes, out_id); if relaxed { self.decorate(mul, Decoration::RelaxedPrecision, []); self.decorate(out_id, Decoration::RelaxedPrecision, []); @@ -203,6 +234,7 @@ impl SpirvCompiler { Arithmetic::Recip(op) => { self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { let one = b.static_cast(ConstVal::Bit32(1), &Elem::Int(32, false), &out_ty); + b.declare_math_mode(modes, out); b.f_div(ty, Some(out), one, input).unwrap(); }); } @@ -210,9 +242,13 @@ impl SpirvCompiler { self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { match out_ty.elem() { Elem::Int(_, true) => b.s_negate(ty, Some(out), input).unwrap(), - Elem::Float(..) => b.f_negate(ty, Some(out), input).unwrap(), + Elem::Float(..) 
=> { + b.declare_math_mode(modes, out); + b.f_negate(ty, Some(out), input).unwrap() + } Elem::Relaxed => { b.decorate(out, Decoration::RelaxedPrecision, []); + b.declare_math_mode(modes, out); b.f_negate(ty, Some(out), input).unwrap() } _ => unreachable!(), @@ -226,6 +262,7 @@ impl SpirvCompiler { // Extension functions Arithmetic::Normalize(op) => { self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| { + b.declare_math_mode(modes, out); T::normalize(b, ty, input, out); if matches!(out_ty.elem(), Elem::Relaxed) { b.decorate(out, Decoration::RelaxedPrecision, []); @@ -234,6 +271,7 @@ impl SpirvCompiler { } Arithmetic::Magnitude(op) => { self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| { + b.declare_math_mode(modes, out); T::magnitude(b, ty, input, out); if matches!(out_ty.elem(), Elem::Relaxed) { b.decorate(out, Decoration::RelaxedPrecision, []); @@ -244,9 +282,13 @@ impl SpirvCompiler { self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { match out_ty.elem() { Elem::Int(_, _) => T::s_abs(b, ty, input, out), - Elem::Float(..) => T::f_abs(b, ty, input, out), + Elem::Float(..) => { + b.declare_math_mode(modes, out); + T::f_abs(b, ty, input, out) + } Elem::Relaxed => { b.decorate(out, Decoration::RelaxedPrecision, []); + b.declare_math_mode(modes, out); T::f_abs(b, ty, input, out) } _ => unreachable!(), @@ -255,6 +297,7 @@ impl SpirvCompiler { } Arithmetic::Exp(op) => { self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + b.declare_math_mode(modes, out); T::exp(b, ty, input, out); if matches!(out_ty.elem(), Elem::Relaxed) { b.decorate(out, Decoration::RelaxedPrecision, []); @@ -263,6 +306,7 @@ impl SpirvCompiler { } Arithmetic::Log(op) => { self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + b.declare_math_mode(modes, out); T::log(b, ty, input, out); if matches!(out_ty.elem(), Elem::Relaxed) { b.decorate(out, Decoration::RelaxedPrecision, []); @@ -275,7 +319,10 @@ impl SpirvCompiler { let relaxed = matches!(out_ty.elem(), Elem::Relaxed); let add = match out_ty.elem() { Elem::Int(_, _) => b.i_add(ty, None, input, one).unwrap(), - Elem::Float(..) | Elem::Relaxed => b.f_add(ty, None, input, one).unwrap(), + Elem::Float(..) 
| Elem::Relaxed => { + b.declare_math_mode(modes, out); + b.f_add(ty, None, input, one).unwrap() + } _ => unreachable!(), }; b.mark_uniformity(add, uniform); @@ -283,11 +330,13 @@ impl SpirvCompiler { b.decorate(add, Decoration::RelaxedPrecision, []); b.decorate(out, Decoration::RelaxedPrecision, []); } + b.declare_math_mode(modes, out); T::log(b, ty, add, out) }); } Arithmetic::Cos(op) => { self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + b.declare_math_mode(modes, out); T::cos(b, ty, input, out); if matches!(out_ty.elem(), Elem::Relaxed) { b.decorate(out, Decoration::RelaxedPrecision, []); @@ -296,6 +345,7 @@ impl SpirvCompiler { } Arithmetic::Sin(op) => { self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + b.declare_math_mode(modes, out); T::sin(b, ty, input, out); if matches!(out_ty.elem(), Elem::Relaxed) { b.decorate(out, Decoration::RelaxedPrecision, []); @@ -304,6 +354,7 @@ impl SpirvCompiler { } Arithmetic::Tan(op) => { self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + b.declare_math_mode(modes, out); T::tan(b, ty, input, out); if matches!(out_ty.elem(), Elem::Relaxed) { b.decorate(out, Decoration::RelaxedPrecision, []); @@ -312,6 +363,7 @@ impl SpirvCompiler { } Arithmetic::Tanh(op) => { self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + b.declare_math_mode(modes, out); T::tanh(b, ty, input, out); if matches!(out_ty.elem(), Elem::Relaxed) { b.decorate(out, Decoration::RelaxedPrecision, []); @@ -320,6 +372,7 @@ impl SpirvCompiler { } Arithmetic::Sinh(op) => { self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + b.declare_math_mode(modes, out); T::sinh(b, ty, input, out); if matches!(out_ty.elem(), Elem::Relaxed) { b.decorate(out, Decoration::RelaxedPrecision, []); @@ -328,6 +381,7 @@ impl SpirvCompiler { } Arithmetic::Cosh(op) => { self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + b.declare_math_mode(modes, out); T::cosh(b, ty, input, out); if matches!(out_ty.elem(), Elem::Relaxed) { b.decorate(out, Decoration::RelaxedPrecision, []); @@ -336,6 +390,7 @@ impl SpirvCompiler { } Arithmetic::ArcCos(op) => { self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + b.declare_math_mode(modes, out); T::acos(b, ty, input, out); if matches!(out_ty.elem(), Elem::Relaxed) { b.decorate(out, Decoration::RelaxedPrecision, []); @@ -344,6 +399,7 @@ impl SpirvCompiler { } Arithmetic::ArcSin(op) => { self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + b.declare_math_mode(modes, out); T::asin(b, ty, input, out); if matches!(out_ty.elem(), Elem::Relaxed) { b.decorate(out, Decoration::RelaxedPrecision, []); @@ -352,6 +408,7 @@ impl SpirvCompiler { } Arithmetic::ArcTan(op) => { self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + b.declare_math_mode(modes, out); T::atan(b, ty, input, out); if matches!(out_ty.elem(), Elem::Relaxed) { b.decorate(out, Decoration::RelaxedPrecision, []); @@ -360,6 +417,7 @@ impl SpirvCompiler { } Arithmetic::ArcSinh(op) => { self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + b.declare_math_mode(modes, out); T::asinh(b, ty, input, out); if matches!(out_ty.elem(), Elem::Relaxed) { b.decorate(out, Decoration::RelaxedPrecision, []); @@ -368,6 +426,7 @@ impl SpirvCompiler { } Arithmetic::ArcCosh(op) => { self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + b.declare_math_mode(modes, out); T::acosh(b, ty, 
input, out); if matches!(out_ty.elem(), Elem::Relaxed) { b.decorate(out, Decoration::RelaxedPrecision, []); @@ -376,6 +435,7 @@ impl SpirvCompiler { } Arithmetic::ArcTanh(op) => { self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + b.declare_math_mode(modes, out); T::atanh(b, ty, input, out); if matches!(out_ty.elem(), Elem::Relaxed) { b.decorate(out, Decoration::RelaxedPrecision, []); @@ -384,6 +444,7 @@ impl SpirvCompiler { } Arithmetic::Degrees(op) => { self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + b.declare_math_mode(modes, out); T::degrees(b, ty, input, out); if matches!(out_ty.elem(), Elem::Relaxed) { b.decorate(out, Decoration::RelaxedPrecision, []); @@ -392,6 +453,7 @@ impl SpirvCompiler { } Arithmetic::Radians(op) => { self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + b.declare_math_mode(modes, out); T::radians(b, ty, input, out); if matches!(out_ty.elem(), Elem::Relaxed) { b.decorate(out, Decoration::RelaxedPrecision, []); @@ -400,6 +462,7 @@ impl SpirvCompiler { } Arithmetic::ArcTan2(op) => { self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| { + b.declare_math_mode(modes, out); T::atan2(b, ty, lhs, rhs, out); if matches!(out_ty.elem(), Elem::Relaxed) { b.decorate(out, Decoration::RelaxedPrecision, []); @@ -419,19 +482,29 @@ impl SpirvCompiler { let one = out_ty.const_u32(b, 1); let two = out_ty.const_u32(b, 2); let modulo = b.f_rem(ty, None, rhs, two).unwrap(); + b.declare_math_mode(modes, modulo); let is_zero = b.f_ord_equal(bool, None, modulo, zero).unwrap(); + b.declare_math_mode(modes, is_zero); let abs = b.id(); + b.declare_math_mode(modes, abs); T::f_abs(b, ty, lhs, abs); let even = b.id(); + b.declare_math_mode(modes, even); T::pow(b, ty, abs, rhs, even); let cond2_0 = b.f_ord_equal(bool, None, modulo, one).unwrap(); + b.declare_math_mode(modes, cond2_0); let cond2_1 = b.f_ord_less_than(bool, None, lhs, zero).unwrap(); + b.declare_math_mode(modes, cond2_1); let cond2 = b.logical_and(bool, None, cond2_0, cond2_1).unwrap(); let neg_lhs = b.f_negate(ty, None, lhs).unwrap(); + b.declare_math_mode(modes, neg_lhs); let pow2 = b.id(); + b.declare_math_mode(modes, pow2); T::pow(b, ty, neg_lhs, rhs, pow2); let pow2_neg = b.f_negate(ty, None, pow2).unwrap(); + b.declare_math_mode(modes, pow2_neg); let default = b.id(); + b.declare_math_mode(modes, default); T::pow(b, ty, lhs, rhs, default); let ids = [ modulo, is_zero, abs, even, cond2_0, cond2_1, neg_lhs, pow2, pow2_neg, @@ -450,18 +523,17 @@ impl SpirvCompiler { } Arithmetic::Sqrt(op) => { self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + b.declare_math_mode(modes, out); T::sqrt(b, ty, input, out); if matches!(out_ty.elem(), Elem::Relaxed) { b.decorate(out, Decoration::RelaxedPrecision, []); } }) } - Arithmetic::Rsqrt(op) => { + Arithmetic::InverseSqrt(op) => { self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { - let sqrt = b.id(); - T::sqrt(b, ty, input, sqrt); - let one = out_ty.const_u32(b, 1); - b.f_div(ty, Some(out), one, sqrt).unwrap(); + b.declare_math_mode(modes, out); + T::inverse_sqrt(b, ty, input, out); if matches!(out_ty.elem(), Elem::Relaxed) { b.decorate(out, Decoration::RelaxedPrecision, []); } @@ -477,6 +549,7 @@ impl SpirvCompiler { } Arithmetic::Floor(op) => { self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + b.declare_math_mode(modes, out); T::floor(b, ty, input, out); if matches!(out_ty.elem(), Elem::Relaxed) { b.decorate(out, 
Decoration::RelaxedPrecision, []); @@ -485,12 +558,22 @@ impl SpirvCompiler { } Arithmetic::Ceil(op) => { self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + b.declare_math_mode(modes, out); T::ceil(b, ty, input, out); if matches!(out_ty.elem(), Elem::Relaxed) { b.decorate(out, Decoration::RelaxedPrecision, []); } }) } + Arithmetic::Trunc(op) => { + self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + b.declare_math_mode(modes, out); + T::trunc(b, ty, input, out); + if matches!(out_ty.elem(), Elem::Relaxed) { + b.decorate(out, Decoration::RelaxedPrecision, []); + } + }) + } Arithmetic::Clamp(op) => { let input = self.compile_variable(op.input); let min = self.compile_variable(op.min_value); @@ -509,9 +592,13 @@ impl SpirvCompiler { match out_ty.elem() { Elem::Int(_, false) => T::u_clamp(self, ty, input, min, max, out_id), Elem::Int(_, true) => T::s_clamp(self, ty, input, min, max, out_id), - Elem::Float(..) => T::f_clamp(self, ty, input, min, max, out_id), + Elem::Float(..) => { + self.declare_math_mode(modes, out_id); + T::f_clamp(self, ty, input, min, max, out_id) + } Elem::Relaxed => { self.decorate(out_id, Decoration::RelaxedPrecision, []); + self.declare_math_mode(modes, out_id); T::f_clamp(self, ty, input, min, max, out_id) } _ => unreachable!(), @@ -526,9 +613,13 @@ impl SpirvCompiler { |b, out_ty, ty, lhs, rhs, out| match out_ty.elem() { Elem::Int(_, false) => T::u_max(b, ty, lhs, rhs, out), Elem::Int(_, true) => T::s_max(b, ty, lhs, rhs, out), - Elem::Float(..) => T::f_max(b, ty, lhs, rhs, out), + Elem::Float(..) => { + b.declare_math_mode(modes, out); + T::f_max(b, ty, lhs, rhs, out) + } Elem::Relaxed => { b.decorate(out, Decoration::RelaxedPrecision, []); + b.declare_math_mode(modes, out); T::f_max(b, ty, lhs, rhs, out) } _ => unreachable!(), @@ -541,9 +632,13 @@ impl SpirvCompiler { |b, out_ty, ty, lhs, rhs, out| match out_ty.elem() { Elem::Int(_, false) => T::u_min(b, ty, lhs, rhs, out), Elem::Int(_, true) => T::s_min(b, ty, lhs, rhs, out), - Elem::Float(..) => T::f_min(b, ty, lhs, rhs, out), + Elem::Float(..) 
=> { + b.declare_math_mode(modes, out); + T::f_min(b, ty, lhs, rhs, out) + } Elem::Relaxed => { b.decorate(out, Decoration::RelaxedPrecision, []); + b.declare_math_mode(modes, out); T::f_min(b, ty, lhs, rhs, out) } _ => unreachable!(), diff --git a/crates/cubecl-spirv/src/atomic.rs b/crates/cubecl-spirv/src/atomic.rs index f7f7d8cc2..2b0a41675 100644 --- a/crates/cubecl-spirv/src/atomic.rs +++ b/crates/cubecl-spirv/src/atomic.rs @@ -1,10 +1,15 @@ -use cubecl_core::ir::{AtomicOp, Variable}; +use cubecl_core::ir::{AtomicOp, InstructionModes, Variable}; use rspirv::spirv::{Capability, MemorySemantics, Scope}; use crate::{SpirvCompiler, SpirvTarget, item::Elem}; impl SpirvCompiler { - pub fn compile_atomic(&mut self, atomic: AtomicOp, out: Option) { + pub fn compile_atomic( + &mut self, + atomic: AtomicOp, + out: Option, + modes: InstructionModes, + ) { let out = out.unwrap(); match atomic { AtomicOp::Load(op) => { @@ -151,6 +156,7 @@ impl SpirvCompiler { _ => unreachable!(), }; let negated = self.f_negate(ty, None, rhs_id).unwrap(); + self.declare_math_mode(modes, negated); self.atomic_f_add_ext(ty, Some(out_id), lhs_id, memory, semantics, negated) .unwrap() } diff --git a/crates/cubecl-spirv/src/bitwise.rs b/crates/cubecl-spirv/src/bitwise.rs index 34f973b59..43ed3e34d 100644 --- a/crates/cubecl-spirv/src/bitwise.rs +++ b/crates/cubecl-spirv/src/bitwise.rs @@ -5,7 +5,7 @@ use cubecl_core::{ use cubecl_core::{comptime, ir as core, prelude::*}; use cubecl_core::{cube, ir::Bitwise}; -use crate::{SpirvCompiler, SpirvTarget}; +use crate::{SpirvCompiler, SpirvTarget, item::Elem}; impl SpirvCompiler { pub fn compile_bitwise(&mut self, op: Bitwise, out: Option, uniform: bool) { @@ -42,8 +42,14 @@ impl SpirvCompiler { }) } Bitwise::ShiftRight(op) => { - self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| { - b.shift_right_logical(ty, Some(out), lhs, rhs).unwrap(); + self.compile_binary_op(op, out, uniform, |b, item, ty, lhs, rhs, out| { + match item.elem() { + // Match behaviour on most compilers + Elem::Int(_, true) => { + b.shift_right_arithmetic(ty, Some(out), lhs, rhs).unwrap() + } + _ => b.shift_right_logical(ty, Some(out), lhs, rhs).unwrap(), + }; }) } diff --git a/crates/cubecl-spirv/src/compiler.rs b/crates/cubecl-spirv/src/compiler.rs index d085520f2..9466530f2 100644 --- a/crates/cubecl-spirv/src/compiler.rs +++ b/crates/cubecl-spirv/src/compiler.rs @@ -1,6 +1,7 @@ use cubecl_common::ExecutionMode; use cubecl_core::{ - Metadata, WgpuCompilationOptions, ir as core, + Metadata, WgpuCompilationOptions, + ir::{self as core, InstructionModes}, post_processing::{ checked_io::CheckedIoProcessor, saturating::SaturatingArithmeticProcessor, unroll::UnrollProcessor, @@ -23,7 +24,7 @@ use std::{ use cubecl_core::{Compiler, compute::KernelDefinition}; use rspirv::{ dr::{Builder, InsertPoint, Instruction, Module, Operand}, - spirv::{self, BuiltIn, Capability, Decoration, FPFastMathMode, Op, StorageClass, Word}, + spirv::{BuiltIn, Capability, Decoration, FPFastMathMode, Op, StorageClass, Word}, }; use crate::{ @@ -43,7 +44,6 @@ pub struct SpirvCompiler { pub mode: ExecutionMode, pub debug_symbols: bool, - pub fp_math_mode: FPFastMathMode, global_invocation_id: Word, num_workgroups: Word, pub setup_block: usize, @@ -54,7 +54,6 @@ pub struct SpirvCompiler { pub visited: HashSet, pub capabilities: HashSet, - pub float_controls: bool, pub state: LookupTables, pub ext_meta_pos: Vec, pub metadata: Metadata, @@ -78,12 +77,9 @@ impl Clone for SpirvCompiler { uniformity: self.uniformity.clone(), 
shared_liveness: self.shared_liveness.clone(), current_block: self.current_block, - capabilities: self.capabilities.clone(), - float_controls: self.float_controls, state: self.state.clone(), debug_symbols: self.debug_symbols, - fp_math_mode: self.fp_math_mode, visited: self.visited.clone(), metadata: self.metadata.clone(), debug_info: self.debug_info.clone(), @@ -109,7 +105,6 @@ impl Default for SpirvCompiler { global_invocation_id: Default::default(), num_workgroups: Default::default(), capabilities: Default::default(), - float_controls: Default::default(), state: Default::default(), setup_block: Default::default(), opt: Default::default(), @@ -117,7 +112,6 @@ impl Default for SpirvCompiler { shared_liveness: Default::default(), current_block: Default::default(), debug_symbols: debug_symbols_activated(), - fp_math_mode: FPFastMathMode::NONE, visited: Default::default(), metadata: Default::default(), debug_info: Default::default(), @@ -212,15 +206,6 @@ impl SpirvCompiler { let options = kernel.options.clone(); self.debug_symbols = debug_symbols_activated() || options.debug_symbols; - self.fp_math_mode = match self.compilation_options.supports_fp_fast_math { - true => convert_math_mode(options.fp_math_mode), - false => FPFastMathMode::NONE, - }; - self.float_controls = self.fp_math_mode != FPFastMathMode::NONE; - - if self.float_controls { - self.capabilities.insert(Capability::FloatControls2); - } self.set_version(1, 6); @@ -461,18 +446,17 @@ impl SpirvCompiler { } } - pub fn declare_float_execution_modes(&mut self, main: Word) { - let mode = self.const_u32(self.fp_math_mode.bits()); - - let types = self.builder.module_ref().types_global_values.clone(); - let scalars = types - .iter() - .filter(|inst| inst.class.opcode == Op::TypeFloat) - .map(|it| it.result_id.expect("OpTypeFloat always has result ID")) - .collect::>(); - for ty in scalars { - self.execution_mode(main, spirv::ExecutionMode::FPFastMathDefault, [ty, mode]); + pub fn declare_math_mode(&mut self, modes: InstructionModes, out_id: Word) { + if !self.compilation_options.supports_fp_fast_math || modes.fp_math_mode.is_empty() { + return; } + let mode = convert_math_mode(modes.fp_math_mode); + self.capabilities.insert(Capability::FloatControls2); + self.decorate( + out_id, + Decoration::FPFastMathMode, + [Operand::FPFastMathMode(mode)], + ); } pub fn is_uniform_block(&self) -> bool { @@ -481,7 +465,7 @@ impl SpirvCompiler { } } -fn convert_math_mode(math_mode: EnumSet) -> FPFastMathMode { +pub(crate) fn convert_math_mode(math_mode: EnumSet) -> FPFastMathMode { let mut flags = FPFastMathMode::NONE; for mode in math_mode.iter() { @@ -490,12 +474,12 @@ fn convert_math_mode(math_mode: EnumSet) -> FPFastMathMode { FastMath::NotInf => flags |= FPFastMathMode::NOT_INF, FastMath::UnsignedZero => flags |= FPFastMathMode::NSZ, FastMath::AllowReciprocal => flags |= FPFastMathMode::ALLOW_RECIP, - FastMath::AllowContraction => flags |= FPFastMathMode::from_bits_retain(0x10000), - FastMath::AllowReassociation => flags |= FPFastMathMode::from_bits_retain(0x20000), + FastMath::AllowContraction => flags |= FPFastMathMode::ALLOW_CONTRACT, + FastMath::AllowReassociation => flags |= FPFastMathMode::ALLOW_REASSOC, FastMath::AllowTransform => { - flags |= FPFastMathMode::from_bits_retain(0x10000) - | FPFastMathMode::from_bits_retain(0x20000) - | FPFastMathMode::from_bits_retain(0x40000) + flags |= FPFastMathMode::ALLOW_CONTRACT + | FPFastMathMode::ALLOW_REASSOC + | FPFastMathMode::ALLOW_TRANSFORM } _ => {} } diff --git 
a/crates/cubecl-spirv/src/extensions.rs b/crates/cubecl-spirv/src/extensions.rs index c2093b0fd..c8dda6b17 100644 --- a/crates/cubecl-spirv/src/extensions.rs +++ b/crates/cubecl-spirv/src/extensions.rs @@ -10,6 +10,7 @@ pub trait TargetExtensions { fn s_abs(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn floor(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn ceil(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); + fn trunc(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn sin(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn cos(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn tan(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); @@ -29,6 +30,7 @@ pub trait TargetExtensions { fn exp(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn log(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn sqrt(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); + fn inverse_sqrt(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn f_min(b: &mut SpirvCompiler, ty: Word, lhs: Word, rhs: Word, out: Word); fn u_min(b: &mut SpirvCompiler, ty: Word, lhs: Word, rhs: Word, out: Word); fn s_min(b: &mut SpirvCompiler, ty: Word, lhs: Word, rhs: Word, out: Word); @@ -46,6 +48,7 @@ pub trait TargetExtensions { } pub mod glcompute { + use super::*; impl TargetExtensions for GLCompute { @@ -69,6 +72,10 @@ pub mod glcompute { b.gl_ceil_id(ty, Some(out), input).unwrap(); } + fn trunc(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) { + b.gl_trunc_id(ty, Some(out), input).unwrap(); + } + fn sin(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) { b.gl_sin_id(ty, Some(out), input).unwrap(); } @@ -145,6 +152,10 @@ pub mod glcompute { b.gl_sqrt_id(ty, Some(out), input).unwrap(); } + fn inverse_sqrt(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) { + b.gl_inverse_sqrt_id(ty, Some(out), input).unwrap(); + } + fn f_min(b: &mut SpirvCompiler, ty: Word, lhs: Word, rhs: Word, out: Word) { b.gl_f_min_id(ty, Some(out), lhs, rhs).unwrap(); } diff --git a/crates/cubecl-spirv/src/instruction.rs b/crates/cubecl-spirv/src/instruction.rs index 33f5761de..6e8583ac2 100644 --- a/crates/cubecl-spirv/src/instruction.rs +++ b/crates/cubecl-spirv/src/instruction.rs @@ -1,5 +1,6 @@ use cubecl_core::ir::{ - self as core, BinaryOperator, Comparison, Instruction, Operation, Operator, UnaryOperator, + self as core, BinaryOperator, Comparison, Instruction, InstructionModes, Operation, Operator, + UnaryOperator, }; use rspirv::spirv::{Capability, Decoration, Word}; @@ -29,11 +30,15 @@ impl SpirvCompiler { self.mark_uniformity(out_id, uniform); self.write(&out, out_id); } - Operation::Arithmetic(operator) => self.compile_arithmetic(operator, inst.out, uniform), - Operation::Comparison(operator) => self.compile_cmp(operator, inst.out, uniform), + Operation::Arithmetic(operator) => { + self.compile_arithmetic(operator, inst.out, inst.modes, uniform) + } + Operation::Comparison(operator) => { + self.compile_cmp(operator, inst.out, inst.modes, uniform) + } Operation::Bitwise(operator) => self.compile_bitwise(operator, inst.out, uniform), Operation::Operator(operator) => self.compile_operator(operator, inst.out, uniform), - Operation::Atomic(atomic) => self.compile_atomic(atomic, inst.out), + Operation::Atomic(atomic) => self.compile_atomic(atomic, inst.out, inst.modes), Operation::Branch(_) => unreachable!("Branches shouldn't exist in optimized IR"), Operation::Metadata(meta) => self.compile_meta(meta, inst.out, uniform), 
Operation::Plane(plane) => self.compile_plane(plane, inst.out, uniform), @@ -42,11 +47,17 @@ impl SpirvCompiler { Operation::NonSemantic(debug) => self.compile_debug(debug), Operation::Barrier(_) => panic!("Barrier not supported in SPIR-V"), Operation::Tma(_) => panic!("TMA not supported in SPIR-V"), - Operation::Free(_) => {} + Operation::Marker(_) => {} } } - pub fn compile_cmp(&mut self, op: Comparison, out: Option, uniform: bool) { + pub fn compile_cmp( + &mut self, + op: Comparison, + out: Option, + modes: InstructionModes, + uniform: bool, + ) { let out = out.unwrap(); match op { Comparison::Equal(op) => { @@ -54,9 +65,13 @@ impl SpirvCompiler { match lhs_ty.elem() { Elem::Bool => b.logical_equal(ty, Some(out), lhs, rhs), Elem::Int(_, _) => b.i_equal(ty, Some(out), lhs, rhs), - Elem::Float(..) => b.f_ord_equal(ty, Some(out), lhs, rhs), + Elem::Float(..) => { + b.declare_math_mode(modes, out); + b.f_ord_equal(ty, Some(out), lhs, rhs) + } Elem::Relaxed => { b.decorate(out, Decoration::RelaxedPrecision, []); + b.declare_math_mode(modes, out); b.f_ord_equal(ty, Some(out), lhs, rhs) } Elem::Void => unreachable!(), @@ -69,9 +84,13 @@ impl SpirvCompiler { match lhs_ty.elem() { Elem::Bool => b.logical_not_equal(ty, Some(out), lhs, rhs), Elem::Int(_, _) => b.i_not_equal(ty, Some(out), lhs, rhs), - Elem::Float(..) => b.f_ord_not_equal(ty, Some(out), lhs, rhs), + Elem::Float(..) => { + b.declare_math_mode(modes, out); + b.f_ord_not_equal(ty, Some(out), lhs, rhs) + } Elem::Relaxed => { b.decorate(out, Decoration::RelaxedPrecision, []); + b.declare_math_mode(modes, out); b.f_ord_not_equal(ty, Some(out), lhs, rhs) } Elem::Void => unreachable!(), @@ -84,9 +103,13 @@ impl SpirvCompiler { match lhs_ty.elem() { Elem::Int(_, false) => b.u_less_than(ty, Some(out), lhs, rhs), Elem::Int(_, true) => b.s_less_than(ty, Some(out), lhs, rhs), - Elem::Float(..) => b.f_ord_less_than(ty, Some(out), lhs, rhs), + Elem::Float(..) => { + b.declare_math_mode(modes, out); + b.f_ord_less_than(ty, Some(out), lhs, rhs) + } Elem::Relaxed => { b.decorate(out, Decoration::RelaxedPrecision, []); + b.declare_math_mode(modes, out); b.f_ord_less_than(ty, Some(out), lhs, rhs) } _ => unreachable!(), @@ -99,9 +122,13 @@ impl SpirvCompiler { match lhs_ty.elem() { Elem::Int(_, false) => b.u_less_than_equal(ty, Some(out), lhs, rhs), Elem::Int(_, true) => b.s_less_than_equal(ty, Some(out), lhs, rhs), - Elem::Float(..) => b.f_ord_less_than_equal(ty, Some(out), lhs, rhs), + Elem::Float(..) => { + b.declare_math_mode(modes, out); + b.f_ord_less_than_equal(ty, Some(out), lhs, rhs) + } Elem::Relaxed => { b.decorate(out, Decoration::RelaxedPrecision, []); + b.declare_math_mode(modes, out); b.f_ord_less_than_equal(ty, Some(out), lhs, rhs) } _ => unreachable!(), @@ -114,9 +141,13 @@ impl SpirvCompiler { match lhs_ty.elem() { Elem::Int(_, false) => b.u_greater_than(ty, Some(out), lhs, rhs), Elem::Int(_, true) => b.s_greater_than(ty, Some(out), lhs, rhs), - Elem::Float(..) => b.f_ord_greater_than(ty, Some(out), lhs, rhs), + Elem::Float(..) => { + b.declare_math_mode(modes, out); + b.f_ord_greater_than(ty, Some(out), lhs, rhs) + } Elem::Relaxed => { b.decorate(out, Decoration::RelaxedPrecision, []); + b.declare_math_mode(modes, out); b.f_ord_greater_than(ty, Some(out), lhs, rhs) } _ => unreachable!(), @@ -129,9 +160,13 @@ impl SpirvCompiler { match lhs_ty.elem() { Elem::Int(_, false) => b.u_greater_than_equal(ty, Some(out), lhs, rhs), Elem::Int(_, true) => b.s_greater_than_equal(ty, Some(out), lhs, rhs), - Elem::Float(..) 
=> b.f_ord_greater_than_equal(ty, Some(out), lhs, rhs), + Elem::Float(..) => { + b.declare_math_mode(modes, out); + b.f_ord_greater_than_equal(ty, Some(out), lhs, rhs) + } Elem::Relaxed => { b.decorate(out, Decoration::RelaxedPrecision, []); + b.declare_math_mode(modes, out); b.f_ord_greater_than_equal(ty, Some(out), lhs, rhs) } _ => unreachable!(), diff --git a/crates/cubecl-spirv/src/subgroup.rs b/crates/cubecl-spirv/src/subgroup.rs index 182988a66..07e965b13 100644 --- a/crates/cubecl-spirv/src/subgroup.rs +++ b/crates/cubecl-spirv/src/subgroup.rs @@ -183,6 +183,34 @@ impl SpirvCompiler { .unwrap(); }); } + Plane::Shuffle(op) => { + self.capabilities.insert(Capability::GroupNonUniformShuffle); + self.compile_binary_op_no_cast(op, out, uniform, |b, _, ty, lhs, rhs, out| { + b.group_non_uniform_shuffle(ty, Some(out), subgroup, lhs, rhs) + .unwrap(); + }); + } + Plane::ShuffleXor(op) => { + self.capabilities.insert(Capability::GroupNonUniformShuffle); + self.compile_binary_op_no_cast(op, out, uniform, |b, _, ty, lhs, rhs, out| { + b.group_non_uniform_shuffle_xor(ty, Some(out), subgroup, lhs, rhs) + .unwrap(); + }); + } + Plane::ShuffleUp(op) => { + self.capabilities.insert(Capability::GroupNonUniformShuffle); + self.compile_binary_op_no_cast(op, out, uniform, |b, _, ty, lhs, rhs, out| { + b.group_non_uniform_shuffle_up(ty, Some(out), subgroup, lhs, rhs) + .unwrap(); + }); + } + Plane::ShuffleDown(op) => { + self.capabilities.insert(Capability::GroupNonUniformShuffle); + self.compile_binary_op_no_cast(op, out, uniform, |b, _, ty, lhs, rhs, out| { + b.group_non_uniform_shuffle_down(ty, Some(out), subgroup, lhs, rhs) + .unwrap(); + }); + } } } diff --git a/crates/cubecl-spirv/src/target.rs b/crates/cubecl-spirv/src/target.rs index c927d471a..170ee9a15 100644 --- a/crates/cubecl-spirv/src/target.rs +++ b/crates/cubecl-spirv/src/target.rs @@ -109,7 +109,7 @@ impl SpirvTarget for GLCompute { b.extension("SPV_EXT_float8"); } - if b.float_controls { + if caps.contains(&Capability::FloatControls2) { b.extension("SPV_KHR_float_controls2"); } @@ -125,10 +125,6 @@ impl SpirvTarget for GLCompute { interface, ); b.execution_mode(main, spirv::ExecutionMode::LocalSize, cube_dims); - - if b.float_controls { - b.declare_float_execution_modes(main); - } } fn generate_binding( diff --git a/crates/cubecl-std/Cargo.toml b/crates/cubecl-std/Cargo.toml index 9886117e4..0e3100a42 100644 --- a/crates/cubecl-std/Cargo.toml +++ b/crates/cubecl-std/Cargo.toml @@ -19,8 +19,9 @@ export_tests = [] [dependencies] -cubecl-core = { path = "../cubecl-core", version = "0.7.0", default-features = false } -cubecl-runtime = { path = "../cubecl-runtime", version = "0.7.0", default-features = false } +cubecl-common = { path = "../cubecl-common", version = "0.9.0", default-features = false } +cubecl-core = { path = "../cubecl-core", version = "0.9.0", default-features = false } +cubecl-runtime = { path = "../cubecl-runtime", version = "0.9.0", default-features = false } half.workspace = true paste = { workspace = true } serde = { workspace = true } diff --git a/crates/cubecl-std/src/fast_math.rs b/crates/cubecl-std/src/fast_math.rs index 61f97b846..6d2d55948 100644 --- a/crates/cubecl-std/src/fast_math.rs +++ b/crates/cubecl-std/src/fast_math.rs @@ -29,7 +29,7 @@ impl Clone for FastDivmodArgs<'_, R> { impl Copy for FastDivmodArgs<'_, R> {} impl FastDivmodArgs<'_, R> { - pub fn new(client: &ComputeClient, divisor: u32) -> Self { + pub fn new(client: &ComputeClient, divisor: u32) -> Self { debug_assert!(divisor != 0); if 
!u64::supported_uses(client).contains(TypeUsage::Arithmetic) { diff --git a/crates/cubecl-std/src/lib.rs b/crates/cubecl-std/src/lib.rs index 604be5075..dd38ccff0 100644 --- a/crates/cubecl-std/src/lib.rs +++ b/crates/cubecl-std/src/lib.rs @@ -12,6 +12,8 @@ pub use trigonometry::*; mod option; pub use option::*; +/// Quantization functionality required in views +pub mod quant; pub mod tensor; #[cfg(feature = "export_tests")] diff --git a/crates/cubecl-std/src/quant/base.rs b/crates/cubecl-std/src/quant/base.rs new file mode 100644 index 000000000..f8a24cfaf --- /dev/null +++ b/crates/cubecl-std/src/quant/base.rs @@ -0,0 +1,9 @@ +use cubecl_core::prelude::CubePrimitive; + +/// Run an arbitrary function with the quantization types from the scheme. +/// Useful when concrete types aren't available. +pub trait RunWithQuantType { + type Output; + + fn execute(self) -> Self::Output; +} diff --git a/crates/cubecl-std/src/quant/dequantize.rs b/crates/cubecl-std/src/quant/dequantize.rs new file mode 100644 index 000000000..eb69dae48 --- /dev/null +++ b/crates/cubecl-std/src/quant/dequantize.rs @@ -0,0 +1,100 @@ +use cubecl::prelude::*; +use cubecl_common::quant::scheme::*; +use cubecl_common::{e2m1x2, e4m3, e5m2}; +use cubecl_core as cubecl; + +/// Dequantize a line of values, where `line_size * num_quants` is a power of two. +/// Unaligned values can't be dequantized in place. +#[cube] +pub fn dequantize_aligned( + value: Line, + scale: S, + #[comptime] scheme: QuantScheme, +) -> Line { + let q_values = match scheme.store { + QuantStore::Native => Line::::cast_from(value), + QuantStore::U32 => unpack_cast_u32::(Line::cast_from(value), scheme), + }; + let scale = Line::::cast_from(scale); + + match scheme.mode { + QuantMode::Symmetric => q_values * scale, + } +} + +/// Unpack a set of values from u32, and convert to the specified floating point format. +#[cube] +pub fn unpack_cast_u32(value: Line, #[comptime] scheme: QuantScheme) -> Line { + let num_quants = comptime![scheme.num_quants() as u32]; + let native_packing = comptime![scheme.native_packing() as u32]; + let out_line_size = comptime![value.line_size() * num_quants]; + let size_bits = comptime![scheme.size_bits_value() as u32]; + let mask = comptime![packing_mask(scheme)]; + + let mut out = Line::::empty(out_line_size); + + #[unroll] + for line_idx in 0..value.line_size() { + let packed_val = value[line_idx]; + let out_offset = comptime![line_idx * num_quants]; + #[unroll] + for packed_idx in range_stepped(0, num_quants, native_packing) { + let shift = packed_idx * size_bits; + let value = (packed_val >> shift) & mask; + + let float_value = cast_masked::(value, scheme); + + #[unroll] + for native_idx in 0..native_packing { + let out_offset = comptime![out_offset + packed_idx + native_idx]; + out[out_offset] = float_value[native_idx]; + } + } + } + + out +} + +/// The mask required for each packed value, taking into account the native packing required for +/// `e2m1`. +fn packing_mask(scheme: QuantScheme) -> u32 { + let bits = match scheme.value { + QuantValue::E2M1 => 8, // Packed conversion + other => other.size_bits(), + }; + (1u32 << bits) - 1 +} + +/// Cast a masked-out value in the low `n` bits of a `u32` to the specified float type. +/// Applies sign conversion for integer quantization before casting to the float type, +/// while minifloats are simply truncated to `u8`, reinterpreted and then cast. +/// For `e2m1`, casting is done on the packed `e2m1x2` representation. 
+/// +/// # Returns +/// Two floating point numbers for `e2m1`, one for all other formats. +#[cube] +fn cast_masked(value: u32, #[comptime] scheme: QuantScheme) -> Line { + match scheme.value { + // For minifloat we can assume if they're supported then u8 is supported + QuantValue::E5M2 => Line::::cast_from(e5m2::reinterpret(value as u8)), + QuantValue::E4M3 => Line::::cast_from(e4m3::reinterpret(value as u8)), + QuantValue::E2M1 => Line::::cast_from(e2m1x2::reinterpret(value as u8)), + QuantValue::Q8F + | QuantValue::Q4F + | QuantValue::Q2F + | QuantValue::Q8S + | QuantValue::Q4S + | QuantValue::Q2S => { + let size_quant = comptime!(scheme.size_bits_value() as u32); + let sign_bit = comptime!(1u32 << (size_quant - 1)); + let two_pow_n = comptime!(1 << size_quant); + + // Branchless two's complement conversion + // If raw >= 2^(n-1), then result = raw - 2^n + let raw_i32 = value as i32; + let is_negative = (value >= sign_bit) as i32; // 1 if negative, 0 if positive + let signed_value = raw_i32 - (is_negative * two_pow_n); + Line::::cast_from(signed_value) + } + } +} diff --git a/crates/cubecl-std/src/quant/mod.rs b/crates/cubecl-std/src/quant/mod.rs new file mode 100644 index 000000000..a89e8a115 --- /dev/null +++ b/crates/cubecl-std/src/quant/mod.rs @@ -0,0 +1,6 @@ +mod base; +mod dequantize; +pub mod view; + +pub use base::*; +pub use dequantize::*; diff --git a/crates/cubecl-std/src/quant/view.rs b/crates/cubecl-std/src/quant/view.rs new file mode 100644 index 000000000..6438c162b --- /dev/null +++ b/crates/cubecl-std/src/quant/view.rs @@ -0,0 +1,321 @@ +use std::marker::PhantomData; + +use super::*; +use crate::{ + CubeOption, CubeOptionExpand, + tensor::{ + View, ViewExpand, ViewOperations, ViewOperationsExpand, launch::ViewCompilationArg, + layout::Coordinates, + }, +}; +use cubecl::prelude::*; +use cubecl_common::{ + e2m1x2, e4m3, e5m2, + quant::scheme::{QuantParam, QuantScheme, QuantStore, QuantValue}, + ue8m0, +}; +use cubecl_core::{ + self as cubecl, + ir::{ElemType, FloatKind, StorageType}, + prelude::barrier::BarrierExpand, + unexpanded, +}; +use half::{bf16, f16}; + +/// View that dequantizes after loads. Scales layout should take values coordinates and map them +/// to the corresponding scale. +/// +/// # Warning +/// Assumes only one scale maps to a single load. Adjust line size of values or block size to ensure +/// this. +/// Must ensure `block_size.is_multiple_of(line_size * scheme.num_quants())`. 
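+///
+/// # Example
+/// A minimal sketch, inside a `#[cube]` function, of wrapping existing views; the concrete
+/// `values` and `scales` views and the `scheme` are assumed to come from the caller, and
+/// generic parameters are elided:
+/// ```rust,ignore
+/// let quantized = QuantizedView::new(values, scales, scheme);
+/// // Reads through the resulting view dequantize on the fly.
+/// let view = quantized.view();
+/// ```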
+#[expect(dead_code, reason = "only used in expand")] +#[derive(CubeType, CubeLaunch, Clone, Copy)] +pub struct QuantizedView { + values: View, C>, + scales: View, + #[cube(comptime)] + scheme: QuantScheme, + #[cube(comptime)] + _ty: PhantomData, +} + +#[cube] +impl + QuantizedView +{ + pub fn new( + values: View, C>, + scales: View, + #[comptime] scheme: QuantScheme, + ) -> Self { + QuantizedView:: { + values, + scales, + scheme, + _ty: PhantomData, + } + } +} + +impl + QuantizedView +{ + pub fn view(self) -> View, C> { + unexpanded!() + } + + pub fn __expand_view( + scope: &mut Scope, + this: QuantizedViewExpand, + ) -> ViewExpand, C, ReadOnly> { + this.__expand_view_method(scope) + } +} + +impl + QuantizedViewExpand +{ + pub fn new( + values: ViewExpand, C>, + scales: ViewExpand, + scheme: QuantScheme, + ) -> Self { + QuantizedViewExpand:: { + values, + scales, + scheme, + _ty: PhantomData, + } + } + + pub fn __expand_view_method(self, _scope: &mut Scope) -> ViewExpand, C, ReadOnly> { + ViewExpand::new(self) + } +} + +impl Lined + for QuantizedView +{ +} +impl LinedExpand + for QuantizedViewExpand +{ + fn line_size(&self) -> u32 { + self.values.line_size() * self.scheme.num_quants() as u32 + } +} + +impl + ViewOperations, C> for QuantizedView +{ +} + +impl + ViewOperationsExpand, C> for QuantizedViewExpand +{ + fn __expand_read_method( + &self, + scope: &mut Scope, + pos: ::ExpandType, + ) -> ExpandElementTyped> { + let value = self.values.clone().__expand_read_method(scope, pos.clone()); + let scale = self.scales.clone().__expand_read_method(scope, pos); + + dequantize_aligned::expand::(scope, value, scale, self.scheme) + } + + fn __expand_read_checked_method( + &self, + scope: &mut Scope, + pos: ::ExpandType, + ) -> ExpandElementTyped> { + let value = self + .values + .clone() + .__expand_read_checked_method(scope, pos.clone()); + let scale = self + .scales + .clone() + .__expand_read_checked_method(scope, pos.clone()); + + dequantize_aligned::expand::(scope, value, scale, self.scheme) + } + + fn __expand_read_masked_method( + &self, + scope: &mut Scope, + pos: ::ExpandType, + mask_value: ExpandElementTyped>, + ) -> ExpandElementTyped> { + let value = self + .values + .clone() + .__expand_read_checked_method(scope, pos.clone()); + let scale = self + .scales + .clone() + .__expand_read_checked_method(scope, pos.clone()); + let in_bounds = self.__expand_is_in_bounds_method(scope, pos); + + let value = dequantize_aligned::expand::(scope, value, scale, self.scheme); + select::expand::>(scope, in_bounds, value, mask_value) + } + + fn __expand_read_unchecked_method( + &self, + scope: &mut Scope, + pos: ::ExpandType, + ) -> ExpandElementTyped> { + let value = self + .values + .clone() + .__expand_read_unchecked_method(scope, pos.clone()); + let scale = self + .scales + .clone() + .__expand_read_unchecked_method(scope, pos); + + dequantize_aligned::expand::(scope, value, scale, self.scheme) + } + + fn __expand_to_linear_slice_method( + &self, + _scope: &mut Scope, + _pos: ::ExpandType, + _end: ::ExpandType, + ) -> SliceExpand, ReadOnly> { + panic!("Can't create raw slice for quantized view") + } + + fn __expand_as_tensor_map_method( + &self, + scope: &mut Scope, + ) -> CubeOptionExpand>> { + CubeOption::__expand_new_None(scope) + } + + fn __expand_shape_method(&self, scope: &mut Scope) -> ::ExpandType { + self.values.clone().__expand_shape_method(scope) + } + + fn __expand_is_in_bounds_method( + &self, + scope: &mut Scope, + pos: C::ExpandType, + ) -> ExpandElementTyped { + 
self.values.clone().__expand_is_in_bounds_method(scope, pos) + } + + fn __expand_tensor_map_load_method( + &self, + _scope: &mut Scope, + _barrier: BarrierExpand, + _shared_memory: SliceExpand, ReadWrite>, + _pos: C::ExpandType, + ) { + panic!("Can't use tensor map functions on quantized view"); + } +} + +struct ExpandDynamic<'a, E: Numeric, C: Coordinates + 'static> { + values: &'a ViewCompilationArg, + scales: &'a ViewCompilationArg, + scheme: QuantScheme, + builder: &'a mut KernelBuilder, + _ty: PhantomData, +} + +impl<'a, E: Numeric, C: Coordinates + 'static> RunWithQuantType for ExpandDynamic<'a, E, C> { + type Output = ViewExpand, C>; + + fn execute(self) -> Self::Output { + let values = View::, C>::expand(self.values, self.builder); + let scales = View::::expand(self.scales, self.builder); + let view = QuantizedViewExpand::new(values, scales, self.scheme); + ViewExpand::new(view) + } +} + +/// Run a function with the quantization storage type and scale. Useful when concrete types are +/// required but aren't available, and only the dynamic scheme is known. +pub fn run_with_quant_type(func: F, scheme: QuantScheme) -> F::Output { + fn run_with_q( + func: F, + scheme: QuantScheme, + ) -> F::Output { + match scheme.param { + QuantParam::F32 => func.execute::(), + QuantParam::F16 => func.execute::(), + QuantParam::BF16 => func.execute::(), + QuantParam::UE8M0 => func.execute::(), + QuantParam::UE4M3 => func.execute::(), + } + } + + let run_q = match scheme.store { + QuantStore::Native => match scheme.value { + QuantValue::Q8F => run_with_q::, + QuantValue::Q8S => run_with_q::, + QuantValue::E5M2 => run_with_q::, + QuantValue::E4M3 => run_with_q::, + QuantValue::E2M1 => run_with_q::, + QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => { + panic!("Sub-byte quantization can't be native") + } + }, + QuantStore::U32 => run_with_q::, + }; + run_q(func, scheme) +} + +/// Dynamically expand based on the quantization scheme. Ugly, but it's the only way to fully hide +/// the quantization from the kernel using the view. 
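+///
+/// # Example
+/// A sketch of the intended call site (cf. `LaunchArg::expand` in `launch.rs` later in this
+/// patch), with the compilation args and builder already in scope:
+/// ```rust,ignore
+/// // Produces a `ViewExpand` that dequantizes on every read.
+/// let view = expand_dynamic(values, scales, *scheme, builder);
+/// ```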
+pub(crate) fn expand_dynamic( + values: &ViewCompilationArg, + scales: &ViewCompilationArg, + scheme: QuantScheme, + builder: &mut KernelBuilder, +) -> ViewExpand { + use core::mem::transmute as t; + + // To specify tighter trait bounds + fn expand_dynamic_f( + values: &ViewCompilationArg, + scales: &ViewCompilationArg, + scheme: QuantScheme, + builder: &mut KernelBuilder, + ) -> ViewExpand, C> { + let func = ExpandDynamic { + values, + scales, + scheme, + builder, + _ty: PhantomData::, + }; + run_with_quant_type(func, scheme) + } + + #[allow(clippy::missing_transmute_annotations)] + unsafe { + match E::as_type(&builder.scope) { + StorageType::Scalar(ElemType::Float(ty)) => match ty { + FloatKind::F16 => t(expand_dynamic_f::(values, scales, scheme, builder)), + FloatKind::BF16 => t(expand_dynamic_f::(values, scales, scheme, builder)), + FloatKind::Flex32 => t(expand_dynamic_f::( + values, scales, scheme, builder, + )), + FloatKind::F32 => t(expand_dynamic_f::(values, scales, scheme, builder)), + FloatKind::TF32 => t(expand_dynamic_f::(values, scales, scheme, builder)), + FloatKind::F64 => t(expand_dynamic_f::(values, scales, scheme, builder)), + FloatKind::E2M1 + | FloatKind::E2M3 + | FloatKind::E3M2 + | FloatKind::E4M3 + | FloatKind::E5M2 + | FloatKind::UE8M0 => unreachable!("Minifloats don't implement `Float` ops"), + }, + _ => unreachable!("Quantized view should only be used with floats"), + } + } +} diff --git a/crates/cubecl-std/src/tensor/contiguous.rs b/crates/cubecl-std/src/tensor/contiguous.rs index ad65689a7..3cce91211 100644 --- a/crates/cubecl-std/src/tensor/contiguous.rs +++ b/crates/cubecl-std/src/tensor/contiguous.rs @@ -1,5 +1,5 @@ use crate::{ - FastDivmod, + FastDivmod, FastDivmodArgs, tensor::layout::{ Layout, LayoutExpand, linear::{LinearLayout, LinearLayoutArgs, LinearView, linear_view}, @@ -149,9 +149,97 @@ fn into_contiguous_kernel_pack( } } +/// Fetch all the values contained in a given position, unpack them, then repack them into their +/// new position. 
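+///
+/// # Example
+/// Illustrative bit arithmetic for a `u32` word holding eight 4-bit values
+/// (`packing == 8`, so `bits_per_elem == 4` and `mask == 0xF`; the numbers are made up):
+/// ```
+/// let word = 0x7654_3210u32;
+/// let k = 3u32; // packing offset within the word
+/// assert_eq!((word >> (k * 4)) & 0xF, 0x3);
+/// ```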
+#[cube] +fn index_packed( + tensor: &Tensor, + pos: u32, + in_shape: &Sequence, + #[comptime] packed_dim: u32, + #[comptime] packing: u32, + #[comptime] rank: u32, +) -> N { + let bits_per_elem = comptime![N::elem_size_bits() / packing]; + let mask = comptime![(1u32 << bits_per_elem) - 1]; + let mask = N::cast_from(mask); + + let elem_pos = pos * packing; + + let mut out = N::new(0); + for n in 0..packing { + let mut remainder = elem_pos + n; + let mut offset = 0; + let mut packing_offset = 0; + + #[unroll] + for i in 0..rank { + let dim = comptime![rank - i - 1]; + let (rem, mut local_pos) = in_shape.index(dim).div_mod(remainder); + remainder = rem; + if comptime![dim == packed_dim] { + packing_offset = local_pos % packing; + local_pos /= packing; + } + offset += local_pos * tensor.stride(dim); + } + let packed_val = tensor[offset]; + let shift_in = packing_offset * bits_per_elem; + let shift_out = n * bits_per_elem; + let value = (packed_val >> N::cast_from(shift_in)) & mask; + + out |= value << N::cast_from(shift_out); + } + out +} + +#[cube(launch)] +fn into_contiguous_kernel_packed( + input: &Tensor, + output: &mut Tensor>, + out_layout: LinearLayout, + in_shape: Sequence, + #[comptime] packed_dim: u32, + #[comptime] packing: u32, + #[comptime] rank: u32, + #[comptime] elems_per_thread: u32, +) { + let line_size = output.line_size(); + let lines_per_thread = comptime![elems_per_thread / line_size]; + + let offset_output = ABSOLUTE_POS * lines_per_thread; + let offset_input = offset_output * line_size; + + if offset_output >= output.len() { + terminate!() + } + + let mut registers = Array::>::vectorized(lines_per_thread, line_size); + + #[unroll] + for i in 0..lines_per_thread { + let offset = i * line_size; + let mut reg = Line::::empty(line_size); + #[unroll] + for k in 0..line_size { + let offset_input = offset_input + offset + k; + + reg[k] = index_packed(input, offset_input, &in_shape, packed_dim, packing, rank); + } + registers[i] = reg; + } + + let offset_output = out_layout.to_source_pos(offset_output); + + #[unroll] + for i in 0..lines_per_thread { + output[offset_output + i] = registers[i]; + } +} + /// Make a jit tensor contiguous. pub fn into_contiguous( - client: &ComputeClient, + client: &ComputeClient, input: &TensorHandleRef<'_, R>, ) -> TensorHandle { let num_elems: usize = input.shape.iter().product(); @@ -167,7 +255,7 @@ pub fn into_contiguous( /// Make a jit tensor contiguous, using the pitched allocator if available. /// See [create_tensor](cubecl_runtime::client::ComputeClient::create_tensor). pub fn into_contiguous_pitched( - client: &ComputeClient, + client: &ComputeClient, input: &TensorHandleRef<'_, R>, ) -> TensorHandle { if input.shape.len() <= 1 { @@ -181,9 +269,39 @@ pub fn into_contiguous_pitched( output } +/// Make a jit tensor contiguous, using the pitched allocator if available. +/// See [create_tensor](cubecl_runtime::client::ComputeClient::create_tensor). +/// Handles unpacking and repacking packed tensors (i.e. quantized values). +/// `shape` refers to the actual (unpacked) shape of the tensor, while `packing` specifies the +/// number of elements in each storage element. +/// +/// # Warning +/// This assumes `u32` or `u8` packing. 
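+///
+/// # Example
+/// A sketch for a 4-bit quantized tensor packed into `u32` words (eight values per word);
+/// `client` and `input` are assumed to exist and the storage type to be `u32`:
+/// ```rust,ignore
+/// // `shape` is the unpacked shape; the output's last dim becomes 16usize.div_ceil(8) = 2 words.
+/// let contiguous = into_contiguous_packed(&client, &input, &[2, 16], 8);
+/// ```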
+pub fn into_contiguous_packed( + client: &ComputeClient, + input: &TensorHandleRef<'_, R>, + shape: &[usize], + packing: u32, +) -> TensorHandle { + let rank = shape.len(); + if rank <= 1 { + return into_contiguous(client, input); + } + + let mut out_shape = shape.to_vec(); + out_shape[rank - 1] = out_shape[rank - 1].div_ceil(packing as usize); + let output = TensorHandle::::empty(client, out_shape); + + // Should reinterpret as u8 if possible at some point, but requires modifying shape/strides so + // keep it simple for now + into_contiguous_packed_ref::(client, input, &output.as_ref(), shape, packing); + + output +} + /// Make a jit tensor contiguous. pub fn into_contiguous_ref( - client: &ComputeClient, + client: &ComputeClient, input: &TensorHandleRef<'_, R>, output: &TensorHandleRef<'_, R>, ) { @@ -191,13 +309,13 @@ pub fn into_contiguous_ref( // Vectorization is only enabled when the last dimension is contiguous. let rank = input.strides.len(); - let vectorization_factor = tensor_line_size_parallel( + let line_size = tensor_line_size_parallel( R::supported_line_sizes().iter().cloned(), input.shape, input.strides, rank - 1, ); - let num_vecs = num_elems / vectorization_factor as usize; + let num_vecs = num_elems / line_size as usize; let num_sm = client .properties() .hardware @@ -211,19 +329,18 @@ pub fn into_contiguous_ref( 8.. => 8, }; - let mut num_elems_per_unit = vectorization_factor as u32 * elems_per_unit; + let mut num_elems_per_unit = line_size as u32 * elems_per_unit; let last_dim = output.shape[rank - 1]; - let is_padded = rank > 1 && last_dim != output.strides[rank - 2]; // If tensor is strided, elems_per_unit must be compatible with last dim - while is_padded && !last_dim.is_multiple_of(num_elems_per_unit as usize) { + while !last_dim.is_multiple_of(num_elems_per_unit as usize) { elems_per_unit /= 2; num_elems_per_unit /= 2; } - let out_vec = if vectorization_factor > 1 { - vectorization_factor + let out_vec = if line_size > 1 { + line_size } else { *R::supported_line_sizes() .iter() @@ -232,14 +349,14 @@ pub fn into_contiguous_ref( .unwrap_or(&1) }; - let input = linear_view(client, input, vectorization_factor); + let input = linear_view(client, input, line_size); let out_layout = LinearLayoutArgs::from_handle(client, output, out_vec); let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(num_elems.div_ceil(num_elems_per_unit as usize), cube_dim); - let launch = if vectorization_factor != out_vec && out_vec > 1 { + let launch = if line_size != out_vec && out_vec > 1 { into_contiguous_kernel_pack::launch:: } else { into_contiguous_kernel::launch:: @@ -256,6 +373,82 @@ pub fn into_contiguous_ref( ); } +/// Make a jit tensor contiguous. +pub fn into_contiguous_packed_ref( + client: &ComputeClient, + input: &TensorHandleRef<'_, R>, + output: &TensorHandleRef<'_, R>, + shape: &[usize], + packing: u32, +) { + let num_elems: usize = input.shape.iter().product(); + + // Vectorization is only enabled when the last dimension is contiguous. + let rank = input.strides.len(); + let line_size = tensor_line_size_parallel( + R::io_optimized_line_sizes(&E::as_type_native_unchecked()), + output.shape, + output.strides, + rank - 1, + ); + let num_vecs = num_elems / line_size as usize; + let num_sm = client + .properties() + .hardware + .num_streaming_multiprocessors + .unwrap_or(NUM_SM_APPROX); + let simul_vecs = num_sm * CubeDim::default().num_elems(); + let mut elems_per_unit = match num_vecs as u32 / simul_vecs { + 0..2 => 1, + 2..4 => 2, + 4..8 => 4, + 8.. 
=> 8, + }; + + let mut num_elems_per_unit = line_size as u32 * elems_per_unit; + + let last_dim = output.shape[rank - 1]; + let packed_dim = input + .strides + .iter() + .enumerate() + .rev() + .find(|(_, s)| **s == 1) + .expect("At least one stride should be 1") + .0; + + // If tensor is strided, elems_per_unit must be compatible with last dim + while !last_dim.is_multiple_of(num_elems_per_unit as usize) { + elems_per_unit /= 2; + num_elems_per_unit /= 2; + } + + let out_layout = LinearLayoutArgs::from_handle(client, output, line_size); + + let cube_dim = CubeDim::default(); + let cube_count = + calculate_cube_count_elemwise(num_elems.div_ceil(num_elems_per_unit as usize), cube_dim); + + let in_shape = shape + .iter() + .map(|s| FastDivmodArgs::new(client, *s as u32)) + .collect(); + + into_contiguous_kernel_packed::launch::( + client, + cube_count, + cube_dim, + input.as_tensor_arg(1), + output.as_tensor_arg(line_size), + out_layout, + in_shape, + packed_dim as u32, + packing, + rank as u32, + elems_per_unit, + ); +} + /// Checks if the tensor associated with the given shape and strides is contiguous. pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool { if shape.is_empty() { diff --git a/crates/cubecl-std/src/tensor/handle.rs b/crates/cubecl-std/src/tensor/handle.rs index 33cd8527f..a7d34b307 100644 --- a/crates/cubecl-std/src/tensor/handle.rs +++ b/crates/cubecl-std/src/tensor/handle.rs @@ -66,7 +66,7 @@ where } } - pub fn empty(client: &ComputeClient, shape: Vec) -> Self { + pub fn empty(client: &ComputeClient, shape: Vec) -> Self { let elem_size = E::size().expect("To be a native type"); let Allocation { handle, strides } = client.empty_tensor(&shape, elem_size); @@ -153,7 +153,7 @@ where R: Runtime, E: Numeric, { - pub fn zeros(client: &ComputeClient, shape: Vec) -> Self { + pub fn zeros(client: &ComputeClient, shape: Vec) -> Self { let num_elements: usize = shape.iter().product(); let rank = shape.len(); let output = Self::empty(client, shape); diff --git a/crates/cubecl-std/src/tensor/identity.rs b/crates/cubecl-std/src/tensor/identity.rs index 6c9b375a7..2516d5a46 100644 --- a/crates/cubecl-std/src/tensor/identity.rs +++ b/crates/cubecl-std/src/tensor/identity.rs @@ -31,7 +31,7 @@ fn identity_kernel(output: &mut Tensor>, gap: u32) { /// Ensure output is a [`TensorHandle`] containing a square matrix. /// output will contain the identity matrix. pub fn launch( - client: &ComputeClient, + client: &ComputeClient, output: &TensorHandle, ) { launch_ref::(client, &output.as_ref()); @@ -41,7 +41,7 @@ pub fn launch( /// Ensure output is a [`TensorHandleRef`] containing a square matrix. /// output will contain the identity matrix. pub fn launch_ref( - client: &ComputeClient, + client: &ComputeClient, output: &TensorHandleRef, ) { assert_eq!(2, output.shape.len(), "input should be a matrix"); diff --git a/crates/cubecl-std/src/tensor/layout/as_dyn.rs b/crates/cubecl-std/src/tensor/layout/as_dyn.rs new file mode 100644 index 000000000..65eda91d8 --- /dev/null +++ b/crates/cubecl-std/src/tensor/layout/as_dyn.rs @@ -0,0 +1,98 @@ +use cubecl::prelude::*; +use cubecl_core::{self as cubecl, unexpanded}; +use variadics_please::all_tuples; + +use crate::tensor::layout::*; + +/// Coordinates that can be converted to a dynamic sequence of signed coordinates. +/// Can be used to convert any set of coordinates to a comptime-sized sequence for use with TMA. +#[cube] +pub trait IntoDyn: Coordinates + LaunchArg { + fn into_dyn(self) -> Sequence { + unexpanded!() + } +} + +macro_rules! 
as_ty { + ($T: ident, $dummy: ident) => { + $T + }; +} + +macro_rules! impl_tuple { + ($ty: ident, $($t: ident),*) => { + impl IntoDyn for ($(as_ty!($ty, $t)),*) {} + + impl IntoDynExpand for ($(ExpandElementTyped),*) { + fn __expand_into_dyn_method(self, scope: &mut Scope) -> SequenceExpand { + let mut seq = Sequence::__expand_new(scope); + let ($($t),*) = self; + let ($($t),*) = ($(i32::__expand_cast_from(scope, $t)),*); + $(seq.__expand_push_method(scope, $t);)* + seq + } + } + }; +} + +macro_rules! impl_tuples { + ($($t: ident),*) => { + impl_tuple!(u32, $($t),*); + impl_tuple!(i32, $($t),*); + }; +} + +all_tuples!(impl_tuples, 2, 12, t); + +#[cube] +impl IntoDyn for Sequence { + fn into_dyn(self) -> Sequence { + self + } +} + +#[cube] +impl IntoDyn for Sequence { + fn into_dyn(self) -> Sequence { + let mut seq = Sequence::new(); + for x in self { + seq.push(i32::cast_from(x)); + } + seq + } +} + +#[derive(CubeType, CubeLaunch)] +pub struct IntoDynLayout + LaunchArg> { + layout: L, +} + +impl + LaunchArg> IntoDynLayout { + pub fn new(layout: L) -> Self { + IntoDynLayout { layout } + } +} + +#[cube] +impl + LaunchArg> Layout for IntoDynLayout { + type Coordinates = L::Coordinates; + type SourceCoordinates = Sequence; + + fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates { + let pos = self.layout.to_source_pos(pos); + pos.into_dyn() + } + + fn is_in_bounds(&self, pos: Self::Coordinates) -> bool { + self.layout.is_in_bounds(pos) + } + + fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) { + let (pos, in_bounds) = self.layout.to_source_pos_checked(pos); + (pos.into_dyn(), in_bounds) + } + + fn shape(&self) -> Self::Coordinates { + self.layout.shape() + } +} diff --git a/crates/cubecl-std/src/tensor/layout/linear.rs b/crates/cubecl-std/src/tensor/layout/linear.rs index 80088ad12..87f4f4962 100644 --- a/crates/cubecl-std/src/tensor/layout/linear.rs +++ b/crates/cubecl-std/src/tensor/layout/linear.rs @@ -3,7 +3,7 @@ use cubecl_core::{self as cubecl, unexpanded}; use crate::tensor::{ View, is_contiguous, is_contiguous_pitched, - launch::ViewLaunch, + launch::ViewArg, layout::{ Coords1d, Layout, LayoutExpand, VirtualLayoutOperationsExpand, permuted::{PermutedLayout, PermutedLayoutLaunch}, @@ -51,7 +51,7 @@ impl LinearLayoutExpand { impl<'a, R: Runtime> LinearLayoutArgs<'a, R> { /// Construct a linear layout from shapes, strides and line size of the tensor pub fn from_shape_strides( - client: &ComputeClient, + client: &ComputeClient, shape: &[usize], strides: &[usize], line_size: u8, @@ -71,7 +71,7 @@ impl<'a, R: Runtime> LinearLayoutArgs<'a, R> { /// Construct a possibly broadcast linear layout from shapes/strides and a reference shape pub fn from_shape_strides_with_reference( - client: &ComputeClient, + client: &ComputeClient, shape: &[usize], reference_shape: &[usize], strides: &[usize], @@ -93,7 +93,7 @@ impl<'a, R: Runtime> LinearLayoutArgs<'a, R> { /// Construct a linear layout from a tensor handle pub fn from_handle( - client: &ComputeClient, + client: &ComputeClient, handle: &TensorHandleRef<'a, R>, line_size: u8, ) -> Self { @@ -102,7 +102,7 @@ impl<'a, R: Runtime> LinearLayoutArgs<'a, R> { /// Construct a possibly broadcast linear layout from a tensor handle and reference handle pub fn from_handle_with_reference( - client: &ComputeClient, + client: &ComputeClient, handle: &TensorHandleRef<'a, R>, reference: &TensorHandleRef<'a, R>, line_size: u8, @@ -143,11 +143,11 @@ impl Layout for LinearLayout { /// Useful for 
elementwise kernels. pub type LinearView = View; /// Launch type for [LinearTensorView]. -pub type LinearViewLaunch<'a, R> = ViewLaunch<'a, Coords1d, R>; +pub type LinearViewLaunch<'a, R> = ViewArg<'a, Coords1d, R>; /// Create a linear tensor view from a handle and line size pub fn linear_view<'a, R: Runtime>( - client: &ComputeClient, + client: &ComputeClient, handle: &'a TensorHandleRef<'a, R>, line_size: u8, ) -> LinearViewLaunch<'a, R> { @@ -161,7 +161,7 @@ pub fn linear_view<'a, R: Runtime>( /// Create a possibly broadcast linear tensor view from a handle, reference handle and line size pub fn linear_view_with_reference<'a, R: Runtime>( - client: &ComputeClient, + client: &ComputeClient, handle: &'a TensorHandleRef<'a, R>, reference: &'a TensorHandleRef<'a, R>, line_size: u8, @@ -175,7 +175,7 @@ pub fn linear_view_with_reference<'a, R: Runtime>( } pub fn linear_view_alias<'a, R: Runtime>( - client: &ComputeClient, + client: &ComputeClient, handle: &'a TensorHandleRef<'a, R>, line_size: u8, pos: usize, diff --git a/crates/cubecl-std/src/tensor/layout/mod.rs b/crates/cubecl-std/src/tensor/layout/mod.rs index 31da4e033..da74037b7 100644 --- a/crates/cubecl-std/src/tensor/layout/mod.rs +++ b/crates/cubecl-std/src/tensor/layout/mod.rs @@ -6,6 +6,7 @@ pub use base::*; pub use coordinates::*; pub use r#virtual::*; +pub mod as_dyn; pub mod chain; pub mod linear; pub mod permuted; diff --git a/crates/cubecl-std/src/tensor/layout/permuted.rs b/crates/cubecl-std/src/tensor/layout/permuted.rs index 6826f3a1a..e5eb52940 100644 --- a/crates/cubecl-std/src/tensor/layout/permuted.rs +++ b/crates/cubecl-std/src/tensor/layout/permuted.rs @@ -24,25 +24,21 @@ impl<'a, R: Runtime> PermutedLayoutLaunch<'a, R> { /// Create a new permuted layout for a possibly broadcast tensor, with a reference shape to be /// broadcast to. pub fn from_shape_strides( - client: &ComputeClient, + client: &ComputeClient, shape: &[usize], strides: &[usize], line_size: u8, ) -> Self { let len = shape.iter().product::() / line_size as usize; - let shape = SequenceArg { - values: shape - .iter() - .map(|it| FastDivmodArgs::new(client, *it as u32)) - .collect(), - }; - let strides = SequenceArg { - values: strides - .iter() - .map(|it| ScalarArg::new(*it as u32)) - .collect(), - }; + let shape = shape + .iter() + .map(|it| FastDivmodArgs::new(client, *it as u32)) + .collect(); + let strides = strides + .iter() + .map(|it| ScalarArg::new(*it as u32)) + .collect(); Self::new(shape, strides, ScalarArg::new(len as u32), line_size as u32) } @@ -50,7 +46,7 @@ impl<'a, R: Runtime> PermutedLayoutLaunch<'a, R> { /// Create a new permuted layout for a possibly broadcast tensor, with a reference shape to be /// broadcast to. 
pub fn from_shapes_strides_ref( - client: &ComputeClient, + client: &ComputeClient, shape: &[usize], reference_shape: &[usize], strides: &[usize], @@ -78,7 +74,7 @@ impl<'a, R: Runtime> PermutedLayoutLaunch<'a, R> { } pub fn from_handles_ref( - client: &ComputeClient, + client: &ComputeClient, handle: &TensorHandleRef<'_, R>, reference_handle: &TensorHandleRef<'_, R>, line_size: u8, @@ -93,7 +89,7 @@ impl<'a, R: Runtime> PermutedLayoutLaunch<'a, R> { } pub fn from_handle( - client: &ComputeClient, + client: &ComputeClient, handle: &TensorHandleRef<'_, R>, line_size: u8, ) -> Self { diff --git a/crates/cubecl-std/src/tensor/layout/strided.rs b/crates/cubecl-std/src/tensor/layout/strided.rs index c3ced3288..7ad33bf95 100644 --- a/crates/cubecl-std/src/tensor/layout/strided.rs +++ b/crates/cubecl-std/src/tensor/layout/strided.rs @@ -20,7 +20,7 @@ pub struct StridedLayout { impl<'a, R: Runtime> StridedLayoutLaunch<'a, R> { pub fn from_shape_strides( - client: &ComputeClient, + client: &ComputeClient, shape: &[usize], strides: &[usize], line_size: u8, @@ -36,7 +36,7 @@ impl<'a, R: Runtime> StridedLayoutLaunch<'a, R> { } pub fn from_handle( - client: &ComputeClient, + client: &ComputeClient, handle: &TensorHandleRef<'_, R>, line_size: u8, ) -> Self { diff --git a/crates/cubecl-std/src/tensor/layout/virtual.rs b/crates/cubecl-std/src/tensor/layout/virtual.rs index 21fb317b0..00826ef06 100644 --- a/crates/cubecl-std/src/tensor/layout/virtual.rs +++ b/crates/cubecl-std/src/tensor/layout/virtual.rs @@ -229,7 +229,7 @@ mod launch { ) -> Self { // Hash ahead of time so we don't need to store the actual data, which would be far // more complex - let state = foldhash::fast::RandomState::default(); + let state = foldhash::fast::FixedState::default(); let hash = state.hash_one(arg); Self { type_name: core::any::type_name::().to_string(), diff --git a/crates/cubecl-std/src/tensor/view/base.rs b/crates/cubecl-std/src/tensor/view/base.rs index 7124128e4..b9c826f10 100644 --- a/crates/cubecl-std/src/tensor/view/base.rs +++ b/crates/cubecl-std/src/tensor/view/base.rs @@ -93,11 +93,7 @@ impl View { view: V::ExpandType, layout: VirtualLayoutExpand, ) -> ViewExpand { - let virt = VirtualView::::__expand_new(scope, view, layout); - ViewExpand:: { - inner: ViewType::Read(Arc::new(virt)), - _io: PhantomData, - } + ViewExpand::new(VirtualView::::__expand_new(scope, view, layout)) } } @@ -126,6 +122,20 @@ impl ViewExpand ) -> ViewExpand { View::__expand_new::, C>(scope, self, layout) } + + pub fn new + 'static>(view: V) -> Self { + ViewExpand { + inner: ViewType::Read(Arc::new(view)), + _io: PhantomData, + } + } + + pub fn new_mut + 'static>(view: V) -> Self { + ViewExpand { + inner: ViewType::ReadWrite(Arc::new(view)), + _io: PhantomData, + } + } } impl View { @@ -174,11 +184,9 @@ impl View { view: V::ExpandType, layout: VirtualLayoutExpand, ) -> ViewExpand { - let virt = VirtualViewMut::::__expand_new(scope, view, layout); - ViewExpand:: { - inner: ViewType::ReadWrite(Arc::new(virt)), - _io: PhantomData, - } + ViewExpand::new_mut(VirtualViewMut::::__expand_new( + scope, view, layout, + )) } } @@ -338,6 +346,15 @@ impl View View { + unexpanded!() + } + pub fn __expand_slice( scope: &mut Scope, this: ViewExpand, @@ -346,32 +363,51 @@ impl View ViewExpand { this.__expand_slice_method(scope, pos, size) } -} -#[cube] -impl View { - /// Create a slice starting from `pos`, with `size`. - /// The layout handles translation into concrete indices. - /// #Safety - /// Size is not checked and may exceed bounds! 
- pub fn slice_unchecked(&self, pos: C, size: C) -> View { - let layout = SliceLayout::new(pos, size, false); - self.view(layout) + pub fn __expand_slice_unchecked( + scope: &mut Scope, + this: ViewExpand, + pos: C::ExpandType, + size: C::ExpandType, + ) -> ViewExpand { + this.__expand_slice_unchecked_method(scope, pos, size) } } +#[cube] +impl View {} + impl ViewExpand { pub fn __expand_slice_method( &self, scope: &mut Scope, pos: C::ExpandType, size: C::ExpandType, + ) -> ViewExpand { + self.slice(scope, pos, size, true) + } + + pub fn __expand_slice_unchecked_method( + &self, + scope: &mut Scope, + pos: C::ExpandType, + size: C::ExpandType, + ) -> ViewExpand { + self.slice(scope, pos, size, false) + } + + fn slice( + &self, + scope: &mut Scope, + pos: C::ExpandType, + size: C::ExpandType, + checked: bool, ) -> ViewExpand { let shape = self.__expand_shape_method(scope); let pos = C::__expand_min(scope, pos, shape.clone()); let max_size = C::__expand_sub(scope, shape, pos.clone()); let size = C::__expand_min(scope, size, max_size); - let layout = SliceLayout::__expand_new(scope, pos, size, true); + let layout = SliceLayout::__expand_new(scope, pos, size, checked); self.clone().__expand_view_method(scope, layout.into()) } } @@ -456,6 +492,16 @@ impl View { unexpanded!() } + /// Create a mutable slice starting from `pos`, with `size`. + /// The layout handles translation into concrete indices. + /// Size and pos will be clamped to the current layout size. + /// + /// # Safety + /// Access is always unchecked. + pub fn slice_mut_unchecked(&self, _pos: C, _size: C) -> View { + unexpanded!() + } + pub fn __expand_slice_mut( scope: &mut Scope, this: ViewExpand, @@ -464,33 +510,51 @@ impl View { ) -> ViewExpand { this.__expand_slice_mut_method(scope, pos, size) } -} -#[cube] -impl View { - /// Create a mutable slice starting from `pos`, with `size`. - /// The layout handles translation into concrete indices. 
- /// - /// # Safety - /// Size is unchecked and may exceed bounds - pub fn slice_mut_unchecked(&self, pos: C, size: C) -> View { - let layout = SliceLayout::new(pos, size, false); - self.view_mut(layout) + pub fn __expand_slice_mut_unchecked( + scope: &mut Scope, + this: ViewExpand, + pos: C::ExpandType, + size: C::ExpandType, + ) -> ViewExpand { + this.__expand_slice_mut_unchecked_method(scope, pos, size) } } +#[cube] +impl View {} + impl ViewExpand { pub fn __expand_slice_mut_method( &self, scope: &mut Scope, pos: C::ExpandType, size: C::ExpandType, + ) -> ViewExpand { + self.slice_mut(scope, pos, size, true) + } + + pub fn __expand_slice_mut_unchecked_method( + &self, + scope: &mut Scope, + pos: C::ExpandType, + size: C::ExpandType, + ) -> ViewExpand { + self.slice_mut(scope, pos, size, false) + } + + fn slice_mut( + &self, + scope: &mut Scope, + pos: C::ExpandType, + size: C::ExpandType, + checked: bool, ) -> ViewExpand { let shape = self.__expand_shape_method(scope); let pos = C::__expand_min(scope, pos, shape.clone()); let max_size = C::__expand_sub(scope, shape, pos.clone()); let size = C::__expand_min(scope, size, max_size); - let layout = SliceLayout::__expand_new(scope, pos, size, true); + let layout = SliceLayout::__expand_new(scope, pos, size, checked); self.clone().__expand_view_mut_method(scope, layout.into()) } } diff --git a/crates/cubecl-std/src/tensor/view/launch.rs b/crates/cubecl-std/src/tensor/view/launch.rs index 9b1ed1460..9e9aaac2b 100644 --- a/crates/cubecl-std/src/tensor/view/launch.rs +++ b/crates/cubecl-std/src/tensor/view/launch.rs @@ -230,96 +230,272 @@ impl< } mod dynamic { - use crate::tensor::layout::{VirtualLayout, VirtualLayoutCompilationArg, VirtualLayoutLaunch}; + use cubecl_common::quant::scheme::QuantScheme; + + use crate::{ + quant, + tensor::layout::{ + VirtualLayout, VirtualLayoutCompilationArg, VirtualLayoutLaunch, + as_dyn::{IntoDyn, IntoDynLayout, IntoDynLayoutLaunch}, + }, + }; use super::*; - pub struct ViewLaunch<'a, C: Coordinates, R: Runtime> { - _phantom_runtime: PhantomData, - _phantom_a: PhantomData<&'a ()>, - buffer: ArrayArg<'a, R>, - layout: VirtualLayoutLaunch<'a, C, Coords1d, R>, + pub enum ViewArg<'a, C: Coordinates, R: Runtime> { + Array(ArrayArg<'a, R>, VirtualLayoutLaunch<'a, C, Coords1d, R>), + TensorMap( + TensorMapArg<'a, R>, + VirtualLayoutLaunch<'a, C, Sequence, R>, + ), + Quantized { + values: Box>, + scales: Box>, + scheme: QuantScheme, + }, } - impl<'a, C: Coordinates, R: Runtime> ViewLaunch<'a, C, R> { + impl<'a, C: Coordinates, R: Runtime> ViewArg<'a, C, R> { pub fn new + LaunchArg>( buffer: ArrayArg<'a, R>, layout: L::RuntimeArg<'a, R>, ) -> Self { - Self { - _phantom_runtime: core::marker::PhantomData, - _phantom_a: core::marker::PhantomData, - buffer, - layout: VirtualLayoutLaunch::new::(layout), + ViewArg::Array(buffer, VirtualLayoutLaunch::new::(layout)) + } + + pub fn new_tensor_map< + L: Layout + LaunchArg, + >( + buffer: TensorMapArg<'a, R>, + layout: L::RuntimeArg<'a, R>, + ) -> Self { + let layout = IntoDynLayoutLaunch::new(layout); + ViewArg::TensorMap(buffer, VirtualLayoutLaunch::new::>(layout)) + } + + /// Create a new view arg that dequantizes on read. + /// The scales layout should take values indices and map them to the corresponding scale. 
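+        ///
+        /// # Example
+        /// A sketch of wiring up a dequantizing view at launch, assuming `values_view` and
+        /// `scales_view` were built with [`ViewArg::new`] and `scheme` describes the tensor:
+        /// ```rust,ignore
+        /// let lhs = ViewArg::new_quantized(values_view, scales_view, scheme);
+        /// ```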
+ pub fn new_quantized(values: Self, scales: Self, scheme: QuantScheme) -> Self { + Self::Quantized { + values: Box::new(values), + scales: Box::new(scales), + scheme, } } } - impl<'a, C: Coordinates, R: Runtime> ArgSettings for ViewLaunch<'a, C, R> { + impl<'a, C: Coordinates, R: Runtime> ArgSettings for ViewArg<'a, C, R> { fn register(&self, launcher: &mut KernelLauncher) { - self.buffer.register(launcher); - self.layout.register(launcher); + match self { + ViewArg::Array(buffer, layout) => { + buffer.register(launcher); + layout.register(launcher); + } + ViewArg::TensorMap(buffer, layout) => { + buffer.register(launcher); + layout.register(launcher); + } + ViewArg::Quantized { values, scales, .. } => { + values.register(launcher); + scales.register(launcher); + } + } } } #[derive(Clone)] - pub struct ViewCompilationArg { - buffer: ArrayCompilationArg, - layout: VirtualLayoutCompilationArg, + pub enum ViewCompilationArg { + Array { + buffer: ArrayCompilationArg, + layout: VirtualLayoutCompilationArg, + }, + TensorMap { + buffer: TensorMapCompilationArg, + layout: VirtualLayoutCompilationArg>, + }, + Quantized { + values: Box>, + scales: Box>, + scheme: QuantScheme, + }, } impl CompilationArg for ViewCompilationArg {} impl Eq for ViewCompilationArg {} impl PartialEq for ViewCompilationArg { fn eq(&self, other: &Self) -> bool { - self.buffer == other.buffer && self.layout == other.layout + match (self, other) { + ( + ViewCompilationArg::Array { buffer, layout }, + ViewCompilationArg::Array { + buffer: buffer_other, + layout: layout_other, + }, + ) => buffer == buffer_other && layout == layout_other, + ( + ViewCompilationArg::TensorMap { buffer, layout }, + ViewCompilationArg::TensorMap { + buffer: buffer_other, + layout: layout_other, + }, + ) => buffer == buffer_other && layout == layout_other, + ( + ViewCompilationArg::Quantized { + values, + scales, + scheme, + }, + ViewCompilationArg::Quantized { + values: values_other, + scales: scales_other, + scheme: scheme_other, + }, + ) => values == values_other && scales == scales_other && scheme == scheme_other, + _ => false, + } } } impl core::hash::Hash for ViewCompilationArg { fn hash(&self, ra_expand_state: &mut H) { - self.buffer.hash(ra_expand_state); - self.layout.hash(ra_expand_state); + match self { + ViewCompilationArg::Array { buffer, layout } => { + buffer.hash(ra_expand_state); + layout.hash(ra_expand_state); + } + ViewCompilationArg::TensorMap { buffer, layout } => { + buffer.hash(ra_expand_state); + layout.hash(ra_expand_state); + } + ViewCompilationArg::Quantized { + values, + scales, + scheme, + } => { + values.hash(ra_expand_state); + scales.hash(ra_expand_state); + scheme.hash(ra_expand_state); + } + } } } impl core::fmt::Debug for ViewCompilationArg { fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { - f.debug_struct("ViewCompilationArg") - .field("buffer", &self.buffer) - .field("layout", &self.layout) - .finish() + match self { + ViewCompilationArg::Array { buffer, layout } => f + .debug_struct("ArrayView") + .field("buffer", &buffer) + .field("layout", &layout) + .finish(), + ViewCompilationArg::TensorMap { buffer, layout } => f + .debug_struct("TensorMapView") + .field("buffer", &buffer) + .field("layout", &layout) + .finish(), + ViewCompilationArg::Quantized { + values, + scales, + scheme, + } => f + .debug_struct("QuantizedView") + .field("values", &values) + .field("scales", &scales) + .field("scheme", &scheme) + .finish(), + } } } impl LaunchArg for View { - type RuntimeArg<'a, R: Runtime> = 
ViewLaunch<'a, C, R>; + type RuntimeArg<'a, R: Runtime> = ViewArg<'a, C, R>; type CompilationArg = ViewCompilationArg; fn compilation_arg<'a, R: Runtime>( runtime_arg: &Self::RuntimeArg<'a, R>, ) -> Self::CompilationArg { - let buffer = Array::::compilation_arg(&runtime_arg.buffer); - let layout = VirtualLayout::::compilation_arg(&runtime_arg.layout); - ViewCompilationArg { buffer, layout } + match runtime_arg { + ViewArg::Array(buffer, layout) => { + let buffer = Array::::compilation_arg(buffer); + let layout = VirtualLayout::::compilation_arg(layout); + ViewCompilationArg::Array { buffer, layout } + } + ViewArg::TensorMap(buffer, layout) => { + let buffer = TensorMap::::compilation_arg(buffer); + let layout = VirtualLayout::>::compilation_arg(layout); + ViewCompilationArg::TensorMap { buffer, layout } + } + ViewArg::Quantized { + values, + scales, + scheme, + } => { + // Type isn't real, but doesn't matter for compilation arg + let values = View::::compilation_arg(values); + let scales = View::::compilation_arg(scales); + ViewCompilationArg::Quantized { + values: Box::new(values), + scales: Box::new(scales), + scheme: *scheme, + } + } + } } fn expand( arg: &Self::CompilationArg, builder: &mut KernelBuilder, ) -> ::ExpandType { - let buffer = Array::::expand(&arg.buffer, builder); - let layout = VirtualLayout::::expand(&arg.layout, builder); - let view = VirtualViewMutExpand::>::new(buffer, layout); - ViewExpand:: { - inner: ViewType::ReadWrite(Arc::new(view)), - _io: PhantomData, + match arg { + ViewCompilationArg::Array { buffer, layout } => { + let buffer = Array::::expand(buffer, builder); + let layout = VirtualLayout::::expand(layout, builder); + let view = + VirtualViewMutExpand::>::new(buffer, layout); + ViewExpand:: { + inner: ViewType::ReadWrite(Arc::new(view)), + _io: PhantomData, + } + } + ViewCompilationArg::TensorMap { buffer, layout } => { + let buffer = TensorMap::::expand(buffer, builder); + let layout = VirtualLayout::>::expand(layout, builder); + let view = VirtualViewMutExpand::, TensorMap>::new( + buffer, layout, + ); + ViewExpand:: { + inner: ViewType::ReadWrite(Arc::new(view)), + _io: PhantomData, + } + } + ViewCompilationArg::Quantized { + values, + scales, + scheme, + } => quant::view::expand_dynamic(values, scales, *scheme, builder), } } fn expand_output( arg: &Self::CompilationArg, builder: &mut KernelBuilder, ) -> ::ExpandType { - let buffer = Array::::expand_output(&arg.buffer, builder); - let layout = VirtualLayout::::expand_output(&arg.layout, builder); - let view = VirtualViewMutExpand::>::new(buffer, layout); - ViewExpand:: { - inner: ViewType::ReadWrite(Arc::new(view)), - _io: PhantomData, + match arg { + ViewCompilationArg::Array { buffer, layout } => { + let buffer = Array::::expand_output(buffer, builder); + let layout = VirtualLayout::::expand_output(layout, builder); + let view = + VirtualViewMutExpand::>::new(buffer, layout); + ViewExpand:: { + inner: ViewType::ReadWrite(Arc::new(view)), + _io: PhantomData, + } + } + ViewCompilationArg::TensorMap { buffer, layout } => { + let buffer = TensorMap::::expand_output(buffer, builder); + let layout = VirtualLayout::>::expand_output(layout, builder); + let view = VirtualViewMutExpand::, TensorMap>::new( + buffer, layout, + ); + ViewExpand:: { + inner: ViewType::ReadWrite(Arc::new(view)), + _io: PhantomData, + } + } + ViewCompilationArg::Quantized { .. 
} => panic!("Quantized views must be readonly"), } } } diff --git a/crates/cubecl-std/src/tensor/view/operations/virtual_tensor.rs b/crates/cubecl-std/src/tensor/view/operations/virtual_tensor.rs index b375f828f..2a126f3b0 100644 --- a/crates/cubecl-std/src/tensor/view/operations/virtual_tensor.rs +++ b/crates/cubecl-std/src/tensor/view/operations/virtual_tensor.rs @@ -24,11 +24,8 @@ impl ViewOperationsExpand, Coords1d> for VirtualT scope: &mut Scope, pos: ExpandElementTyped, ) -> as CubeType>::ExpandType { - let len = self.clone().__expand_buffer_len_method(scope); - let in_bounds = lt::expand(scope, pos.clone(), len); - let slice = self.clone().__expand_to_slice_method(scope); let zero = Line::__expand_cast_from(scope, 0.into()); - read_masked::expand::>(scope, in_bounds, slice, pos, zero) + self.__expand_read_masked_method(scope, pos, zero) } fn __expand_read_masked_method( @@ -37,8 +34,7 @@ impl ViewOperationsExpand, Coords1d> for VirtualT pos: ExpandElementTyped, mask_value: as CubeType>::ExpandType, ) -> as CubeType>::ExpandType { - let len = self.__expand_len_method(scope); - let in_bounds = lt::expand(scope, pos.clone(), len); + let in_bounds = self.__expand_is_in_bounds_method(scope, pos.clone()); let slice = self.clone().__expand_to_slice_method(scope); read_masked::expand::>(scope, in_bounds, slice, pos, mask_value) } diff --git a/crates/cubecl-std/src/tests/mod.rs b/crates/cubecl-std/src/tests/mod.rs index bcd799585..e31c4e0bf 100644 --- a/crates/cubecl-std/src/tests/mod.rs +++ b/crates/cubecl-std/src/tests/mod.rs @@ -1,6 +1,7 @@ pub mod reinterpret_slice; pub mod tensor; pub mod trigonometry; +pub mod view; #[macro_export] macro_rules! testgen { diff --git a/crates/cubecl-std/src/tests/reinterpret_slice.rs b/crates/cubecl-std/src/tests/reinterpret_slice.rs index 90a207992..ce244a677 100644 --- a/crates/cubecl-std/src/tests/reinterpret_slice.rs +++ b/crates/cubecl-std/src/tests/reinterpret_slice.rs @@ -11,10 +11,7 @@ fn kernel_read_global(input: &Array>, output: &mut Array) { output[UNIT_POS] = list.read(UNIT_POS); } -pub fn run_test_read_global( - client: ComputeClient, - line_size: usize, -) { +pub fn run_test_read_global(client: ComputeClient, line_size: usize) { if !client.properties().features.dynamic_line_size { return; // can't run test } @@ -47,10 +44,7 @@ fn kernel_write_global(output: &mut Array>, input: &Array) { list.write(UNIT_POS, input[UNIT_POS]); } -pub fn run_test_write_global( - client: ComputeClient, - line_size: usize, -) { +pub fn run_test_write_global(client: ComputeClient, line_size: usize) { if !client.properties().features.dynamic_line_size { return; // can't run test } @@ -92,7 +86,7 @@ fn kernel_read_shared_memory(output: &mut Array) { output[UNIT_POS] = list.read(UNIT_POS); } -pub fn run_test_read_shared_memory(client: ComputeClient) { +pub fn run_test_read_shared_memory(client: ComputeClient) { if !client.properties().features.dynamic_line_size { return; // can't run test } @@ -125,7 +119,7 @@ fn kernel_write_shared_memory(output: &mut Array>, input: &Array) output[2 * UNIT_POS + 1] = mem[2 * UNIT_POS + 1]; } -pub fn run_test_write_shared_memory(client: ComputeClient) { +pub fn run_test_write_shared_memory(client: ComputeClient) { if !client.properties().features.dynamic_line_size { return; // can't run test } diff --git a/crates/cubecl-std/src/tests/trigonometry.rs b/crates/cubecl-std/src/tests/trigonometry.rs index 97c0e3af1..9090e461b 100644 --- a/crates/cubecl-std/src/tests/trigonometry.rs +++ b/crates/cubecl-std/src/tests/trigonometry.rs @@ 
-11,7 +11,7 @@ fn kernel_to_degrees(input: &Array, output: &mut Array) { } } -pub fn test_to_degrees(client: ComputeClient) { +pub fn test_to_degrees(client: ComputeClient) { let input_data = vec![0.0, PI / 6.0, PI / 4.0, PI / 2.0, PI, TAU]; let expected = vec![0.0, 30.0, 45.0, 90.0, 180.0, 360.0]; @@ -49,7 +49,7 @@ fn kernel_to_radians(input: &Array, output: &mut Array) { } } -pub fn test_to_radians(client: ComputeClient) { +pub fn test_to_radians(client: ComputeClient) { let input_data = vec![0.0, 30.0, 45.0, 90.0, 180.0, 360.0]; let expected = vec![0.0, PI / 6.0, PI / 4.0, PI / 2.0, PI, TAU]; @@ -87,7 +87,7 @@ fn kernel_hypot(x: &Array, y: &Array, output: &mut Array) { } } -pub fn test_hypot(client: ComputeClient) { +pub fn test_hypot(client: ComputeClient) { let x_data = vec![3.0, 0.0, 1.0, 5.0, 0.0]; let y_data = vec![4.0, 1.0, 1.0, 12.0, 0.0]; let expected = vec![5.0, 1.0, 1.4142135623730951, 13.0, 0.0]; diff --git a/crates/cubecl-std/src/tests/view/mod.rs b/crates/cubecl-std/src/tests/view/mod.rs new file mode 100644 index 000000000..72eebb62d --- /dev/null +++ b/crates/cubecl-std/src/tests/view/mod.rs @@ -0,0 +1 @@ +pub mod quantized; diff --git a/crates/cubecl-std/src/tests/view/quantized.rs b/crates/cubecl-std/src/tests/view/quantized.rs new file mode 100644 index 000000000..2371783c2 --- /dev/null +++ b/crates/cubecl-std/src/tests/view/quantized.rs @@ -0,0 +1,214 @@ +use cubecl::prelude::*; +use cubecl_common::{ + e2m1, e2m1x2, + quant::scheme::{QuantScheme, QuantValue}, +}; +use cubecl_core::{self as cubecl}; + +use crate::tensor::{ + View, + launch::ViewArg, + layout::{ + plain::{PlainLayout, PlainLayoutLaunch}, + *, + }, +}; + +#[derive(CubeType, CubeLaunch)] +struct TestPerTensorScaleLayout { + length: u32, +} + +#[cube] +impl Layout for TestPerTensorScaleLayout { + type Coordinates = Coords1d; + type SourceCoordinates = Coords1d; + + fn to_source_pos(&self, _pos: Self::Coordinates) -> Self::SourceCoordinates { + 0u32.runtime() + } + + fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) { + (self.to_source_pos(pos), true.runtime()) + } + + fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool { + true.runtime() + } + + fn shape(&self) -> Self::Coordinates { + self.length + } +} + +#[cube(launch_unchecked)] +pub fn kernel_quantized_view(lhs: View, Coords1d>, output: &mut Array>) { + if UNIT_POS < lhs.shape() { + output[UNIT_POS] = lhs[UNIT_POS]; + } +} + +#[allow(clippy::needless_range_loop)] +pub fn test_quantized_per_tensor_int( + client: ComputeClient, + line_size_values: u8, +) { + let line_size_float = 8 * line_size_values; + let values_lines = 2 / line_size_values as u32; + + let scheme = QuantScheme::default().with_value(QuantValue::Q4F); + let float_data = (-8..=7) + .map(|it| F::new(it as f32 * 3.4)) + .collect::>(); + + let output = client.empty(16 * size_of::()); + let values = client.create(u32::as_bytes(&[0xFEDCBA98, 0x76543210])); + let scales = client.create(f32::as_bytes(&[3.4])); + + let float_values = client.create(F::as_bytes(&float_data)); + let float_output = client.empty(16 * size_of::()); + + let values_layout = PlainLayoutLaunch::::new(ScalarArg::new(values_lines)); + let scales_layout = TestPerTensorScaleLayoutLaunch::::new(ScalarArg::new(16)); + let float_layout = PlainLayoutLaunch::::new(ScalarArg::new(values_lines)); + + let values_view = ViewArg::new::( + unsafe { ArrayArg::from_raw_parts::(&values, 2, line_size_values) }, + values_layout, + ); + let scales_view = ViewArg::new::( + unsafe { 
ArrayArg::from_raw_parts::(&scales, 1, 1) }, + scales_layout, + ); + let quantized_view = ViewArg::new_quantized(values_view, scales_view, scheme); + let float_view = ViewArg::new::( + unsafe { ArrayArg::from_raw_parts::(&float_values, 16, line_size_float) }, + float_layout, + ); + + unsafe { + kernel_quantized_view::launch_unchecked::( + &client, + CubeCount::new_single(), + CubeDim::new_1d(2), + quantized_view, + ArrayArg::from_raw_parts::(&output, 16, line_size_float), + ); + kernel_quantized_view::launch_unchecked::( + &client, + CubeCount::new_single(), + CubeDim::new_1d(2), + float_view, + ArrayArg::from_raw_parts::(&float_output, 16, line_size_float), + ); + } + + let actual = client.read_one(output); + let actual_float = client.read_one(float_output); + let actual = F::from_bytes(&actual); + let actual_float = F::from_bytes(&actual_float); + + assert_eq!(&actual, &float_data); + assert_eq!(&actual_float, &float_data); +} + +#[allow(clippy::needless_range_loop)] +pub fn test_quantized_per_tensor_fp4( + client: ComputeClient, + line_size_values: u8, +) { + if !client.properties().supports_type(e2m1x2::cube_type()) { + return; + } + + let line_size_float = 8 * line_size_values; + let values_lines = 2 / line_size_values as u32; + + let scheme = QuantScheme::default().with_value(QuantValue::E2M1); + let float_data = (0..16) + .map(e2m1::from_bits) + .map(|it| F::new(it.to_f32() * 3.4)) + .collect::>(); + + let output = client.empty(16 * size_of::()); + let values = client.create(u32::as_bytes(&[0x76543210, 0xFEDCBA98])); + let scales = client.create(f32::as_bytes(&[3.4])); + + let float_values = client.create(F::as_bytes(&float_data)); + let float_output = client.empty(16 * size_of::()); + + let values_layout = PlainLayoutLaunch::::new(ScalarArg::new(values_lines)); + let scales_layout = TestPerTensorScaleLayoutLaunch::::new(ScalarArg::new(16)); + let float_layout = PlainLayoutLaunch::::new(ScalarArg::new(values_lines)); + + let values_view = ViewArg::new::( + unsafe { ArrayArg::from_raw_parts::(&values, 2, line_size_values) }, + values_layout, + ); + let scales_view = ViewArg::new::( + unsafe { ArrayArg::from_raw_parts::(&scales, 1, 1) }, + scales_layout, + ); + let quantized_view = ViewArg::new_quantized(values_view, scales_view, scheme); + let float_view = ViewArg::new::( + unsafe { ArrayArg::from_raw_parts::(&float_values, 16, line_size_float) }, + float_layout, + ); + + unsafe { + kernel_quantized_view::launch_unchecked::( + &client, + CubeCount::new_single(), + CubeDim::new_1d(2), + quantized_view, + ArrayArg::from_raw_parts::(&output, 16, line_size_float), + ); + kernel_quantized_view::launch_unchecked::( + &client, + CubeCount::new_single(), + CubeDim::new_1d(2), + float_view, + ArrayArg::from_raw_parts::(&float_output, 16, line_size_float), + ); + } + + let actual = client.read_one(output); + let actual_float = client.read_one(float_output); + let actual = F::from_bytes(&actual); + let actual_float = F::from_bytes(&actual_float); + + assert_eq!(&actual, &float_data); + assert_eq!(&actual_float, &float_data); +} + +#[allow(missing_docs)] +#[macro_export] +macro_rules! 
testgen_quantized_view { + ($ty: ty) => { + use super::*; + + #[test] + fn test_quantized_view_per_tensor_int() { + let client = TestRuntime::client(&Default::default()); + cubecl_std::tests::view::quantized::test_quantized_per_tensor_int::( + client.clone(), + 1, + ); + cubecl_std::tests::view::quantized::test_quantized_per_tensor_int::( + client, 2, + ); + } + + #[test] + fn test_quantized_view_per_tensor_fp4() { + let client = TestRuntime::client(&Default::default()); + cubecl_std::tests::view::quantized::test_quantized_per_tensor_fp4::( + client.clone(), + 1, + ); + cubecl_std::tests::view::quantized::test_quantized_per_tensor_fp4::( + client, 2, + ); + } + }; +} diff --git a/crates/cubecl-wgpu/Cargo.toml b/crates/cubecl-wgpu/Cargo.toml index 044afea82..0c86da648 100644 --- a/crates/cubecl-wgpu/Cargo.toml +++ b/crates/cubecl-wgpu/Cargo.toml @@ -86,9 +86,9 @@ matmul_tests_unit = ["cubecl-matmul/matmul_tests_unit"] matmul_tests_vecmat = ["cubecl-matmul/matmul_tests_vecmat"] [dependencies] -cubecl-common = { path = "../cubecl-common", version = "0.7.0", default-features = false } -cubecl-core = { path = "../cubecl-core", version = "0.7.0", default-features = false } -cubecl-runtime = { path = "../cubecl-runtime", version = "0.7.0", default-features = false, features = [ +cubecl-common = { path = "../cubecl-common", version = "0.9.0", default-features = false } +cubecl-core = { path = "../cubecl-core", version = "0.9.0", default-features = false } +cubecl-runtime = { path = "../cubecl-runtime", version = "0.9.0", default-features = false, features = [ "channel-mutex", ] } derive_more = { workspace = true } @@ -98,11 +98,11 @@ tracy-client = { workspace = true, optional = true } # SPIR-V ash = { workspace = true, optional = true } -cubecl-spirv = { path = "../cubecl-spirv", version = "0.7.0", optional = true } +cubecl-spirv = { path = "../cubecl-spirv", version = "0.9.0", optional = true } tracel-ash = { workspace = true, optional = true } # Metal -cubecl-cpp = { path = "../cubecl-cpp", version = "0.7.0", features = [ +cubecl-cpp = { path = "../cubecl-cpp", version = "0.9.0", features = [ "metal", ], optional = true } @@ -135,29 +135,29 @@ wgpu = { version = "26.0.0", features = [ # wgpu = { path = "../../../../tracel/wgpu/wgpu", features = ["vulkan-portability", "fragile-send-sync-non-atomic-wasm"]} [dev-dependencies] -cubecl-attention = { path = "../cubecl-attention", version = "0.7.0", features = [ +cubecl-attention = { path = "../cubecl-attention", version = "0.9.0", features = [ "export_tests", ] } -cubecl-convolution = { path = "../cubecl-convolution", version = "0.7.0", features = [ +cubecl-convolution = { path = "../cubecl-convolution", version = "0.9.0", features = [ "export_tests", ] } -cubecl-core = { path = "../cubecl-core", version = "0.7.0", features = [ +cubecl-core = { path = "../cubecl-core", version = "0.9.0", features = [ "export_tests", ] } -cubecl-matmul = { path = "../cubecl-matmul", version = "0.7.0", features = [ +cubecl-matmul = { path = "../cubecl-matmul", version = "0.9.0", features = [ "export_tests", ] } -cubecl-quant = { path = "../cubecl-quant", version = "0.7.0", features = [ +cubecl-quant = { path = "../cubecl-quant", version = "0.9.0", features = [ "export_tests", "kernels", ] } -cubecl-random = { path = "../cubecl-random", version = "0.7.0", features = [ +cubecl-random = { path = "../cubecl-random", version = "0.9.0", features = [ "export_tests", ] } -cubecl-reduce = { path = "../cubecl-reduce", version = "0.7.0", features = [ +cubecl-reduce = { path = 
"../cubecl-reduce", version = "0.9.0", features = [ "export_tests", ] } -cubecl-std = { path = "../cubecl-std", version = "0.7.0", features = [ +cubecl-std = { path = "../cubecl-std", version = "0.9.0", features = [ "export_tests", ] } half = { workspace = true } diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs index 492fc5e85..acb7ac8b8 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs @@ -469,7 +469,7 @@ impl WgslCompiler { panic!("Barrier isn't supported on wgpu.") } cube::Operation::Tma(_) => panic!("TMA isn't supported on wgpu."), - cube::Operation::Free(_) => {} + cube::Operation::Marker(_) => {} } } @@ -498,15 +498,18 @@ impl WgslCompiler { input: self.compile_variable(op.input), out: self.compile_variable(out), }, + cube::Plane::Broadcast(op) => Subgroup::Broadcast { lhs: self.compile_variable(op.lhs), rhs: self.compile_variable(op.rhs), out: self.compile_variable(out), }, + cube::Plane::Sum(op) => Subgroup::Sum { input: self.compile_variable(op.input), out: self.compile_variable(out), }, + cube::Plane::ExclusiveSum(op) => Subgroup::ExclusiveSum { input: self.compile_variable(op.input), out: self.compile_variable(out), @@ -535,6 +538,26 @@ impl WgslCompiler { input: self.compile_variable(op.input), out: self.compile_variable(out), }, + cube::Plane::Shuffle(op) => Subgroup::Shuffle { + lhs: self.compile_variable(op.lhs), + rhs: self.compile_variable(op.rhs), + out: self.compile_variable(out), + }, + cube::Plane::ShuffleXor(op) => Subgroup::ShuffleXor { + lhs: self.compile_variable(op.lhs), + rhs: self.compile_variable(op.rhs), + out: self.compile_variable(out), + }, + cube::Plane::ShuffleUp(op) => Subgroup::ShuffleUp { + lhs: self.compile_variable(op.lhs), + rhs: self.compile_variable(op.rhs), + out: self.compile_variable(out), + }, + cube::Plane::ShuffleDown(op) => Subgroup::ShuffleDown { + lhs: self.compile_variable(op.lhs), + rhs: self.compile_variable(op.rhs), + out: self.compile_variable(out), + }, }; instructions.push(wgsl::Instruction::Subgroup(op)); @@ -824,10 +847,12 @@ impl WgslCompiler { input: self.compile_variable(op.input), out: self.compile_variable(out), }), - cube::Arithmetic::Rsqrt(op) => instructions.push(wgsl::Instruction::Rsqrt { - input: self.compile_variable(op.input), - out: self.compile_variable(out), - }), + cube::Arithmetic::InverseSqrt(op) => { + instructions.push(wgsl::Instruction::InverseSqrt { + input: self.compile_variable(op.input), + out: self.compile_variable(out), + }) + } cube::Arithmetic::Round(op) => instructions.push(wgsl::Instruction::Round { input: self.compile_variable(op.input), out: self.compile_variable(out), @@ -840,6 +865,10 @@ impl WgslCompiler { input: self.compile_variable(op.input), out: self.compile_variable(out), }), + cube::Arithmetic::Trunc(op) => instructions.push(wgsl::Instruction::Trunc { + input: self.compile_variable(op.input), + out: self.compile_variable(out), + }), cube::Arithmetic::Erf(op) => { let mut scope = scope.child(); expand_erf(&mut scope, op.input, out); diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/extension.rs b/crates/cubecl-wgpu/src/compiler/wgsl/extension.rs index c46f980ca..4ccfa92a3 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/extension.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/extension.rs @@ -121,12 +121,16 @@ pub fn call_powf( rhs: &Variable, out: &Variable, ) -> core::fmt::Result { - let (rhs, base_name) = if should_use_scalar_powf(rhs) { + let (lhs, rhs, 
base_name) = if should_use_scalar_powf(rhs) {
         let rhs = rhs.fmt_cast_to(Item::Scalar(lhs.elem()));
-        (rhs, POWF_SCALAR)
+        let lhs = lhs.to_string();
+        (lhs, rhs, POWF_SCALAR)
     } else {
-        let rhs = rhs.fmt_cast_to(lhs.item());
-        (rhs, POWF)
+        // When vectorized, we make sure the function inputs share the same vectorization factor as
+        // the output.
+        let rhs = rhs.fmt_cast_to(out.item());
+        let lhs = lhs.fmt_cast_to(out.item());
+        (lhs, rhs, POWF)
     };
     let function_name = construct_vectorized_name(base_name, out.item());
diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs
index 350e74702..b5fe4a393 100644
--- a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs
+++ b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs
@@ -179,7 +179,7 @@ pub enum Instruction {
         input: Variable,
         out: Variable,
     },
-    Rsqrt {
+    InverseSqrt {
         input: Variable,
         out: Variable,
     },
@@ -318,6 +318,10 @@ pub enum Instruction {
         input: Variable,
         out: Variable,
     },
+    Trunc {
+        input: Variable,
+        out: Variable,
+    },
     Remainder {
         lhs: Variable,
         rhs: Variable,
@@ -652,9 +656,9 @@ impl Display for Instruction {
                 let out = out.fmt_left();
                 writeln!(f, "{out} = sqrt({input});")
             }
-            Instruction::Rsqrt { input, out } => {
+            Instruction::InverseSqrt { input, out } => {
                 let out = out.fmt_left();
-                writeln!(f, "{out} = rsqrt({input});")
+                writeln!(f, "{out} = inverseSqrt({input});")
             }
             Instruction::Log1p { input, out } => {
                 let out = out.fmt_left();
@@ -959,6 +963,10 @@ for (var {i}: {i_ty} = {start}; {i} {cmp} {end}; {increment}) {{
                 let out = out.fmt_left();
                 writeln!(f, "{out} = ceil({input});")
             }
+            Instruction::Trunc { input, out } => {
+                let out = out.fmt_left();
+                writeln!(f, "{out} = trunc({input});")
+            }
             Instruction::Subgroup(op) => write!(f, "{op}"),
             Instruction::Bitcast { input, out } => {
                 let elem = out.item();
diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/subgroup.rs b/crates/cubecl-wgpu/src/compiler/wgsl/subgroup.rs
index d934b9d14..29dcd27db 100644
--- a/crates/cubecl-wgpu/src/compiler/wgsl/subgroup.rs
+++ b/crates/cubecl-wgpu/src/compiler/wgsl/subgroup.rs
@@ -56,6 +56,26 @@ pub enum Subgroup {
         input: Variable,
         out: Variable,
     },
+    Shuffle {
+        lhs: Variable,
+        rhs: Variable,
+        out: Variable,
+    },
+    ShuffleXor {
+        lhs: Variable,
+        rhs: Variable,
+        out: Variable,
+    },
+    ShuffleUp {
+        lhs: Variable,
+        rhs: Variable,
+        out: Variable,
+    },
+    ShuffleDown {
+        lhs: Variable,
+        rhs: Variable,
+        out: Variable,
+    },
 }

 impl Display for Subgroup {
@@ -159,6 +179,22 @@ impl Display for Subgroup {
                 let out = out.fmt_left();
                 writeln!(f, "{out} = subgroupMax({input});")
             }
+            Subgroup::Shuffle { lhs, rhs, out } => {
+                let out = out.fmt_left();
+                writeln!(f, "{out} = subgroupShuffle({lhs}, {rhs});")
+            }
+            Subgroup::ShuffleXor { lhs, rhs, out } => {
+                let out = out.fmt_left();
+                writeln!(f, "{out} = subgroupShuffleXor({lhs}, {rhs});")
+            }
+            Subgroup::ShuffleUp { lhs, rhs, out } => {
+                let out = out.fmt_left();
+                writeln!(f, "{out} = subgroupShuffleUp({lhs}, {rhs});")
+            }
+            Subgroup::ShuffleDown { lhs, rhs, out } => {
+                let out = out.fmt_left();
+                writeln!(f, "{out} = subgroupShuffleDown({lhs}, {rhs});")
+            }
         }
     }
 }
diff --git a/crates/cubecl-wgpu/src/compute/mem_manager.rs b/crates/cubecl-wgpu/src/compute/mem_manager.rs
index d6e246ca1..06ff2140f 100644
--- a/crates/cubecl-wgpu/src/compute/mem_manager.rs
+++ b/crates/cubecl-wgpu/src/compute/mem_manager.rs
@@ -1,12 +1,14 @@
 use crate::{WgpuResource, WgpuStorage};
-use cubecl_common::stream_id::StreamId;
+use
cubecl_common::{stream_id::StreamId, stub::Arc}; use cubecl_core::{ MemoryConfiguration, server::{Binding, Handle, IoError}, }; use cubecl_runtime::{ + logging::ServerLogger, memory_management::{ - MemoryDeviceProperties, MemoryHandle, MemoryManagement, SliceBinding, SliceHandle, + MemoryAllocationMode, MemoryDeviceProperties, MemoryHandle, MemoryManagement, + MemoryManagementOptions, SliceBinding, SliceHandle, }, storage::ComputeStorage, }; @@ -25,6 +27,7 @@ impl WgpuMemManager { device: wgpu::Device, memory_properties: MemoryDeviceProperties, memory_config: MemoryConfiguration, + logger: Arc, ) -> Self { // Allocate storage & memory management for the main memory buffers. Any calls // to empty() or create() with a small enough size will be allocated from this @@ -40,6 +43,8 @@ impl WgpuMemManager { ), &memory_properties, memory_config, + logger.clone(), + MemoryManagementOptions::new("Main GPU Memory"), ); let memory_staging = MemoryManagement::from_configuration( @@ -52,6 +57,8 @@ impl WgpuMemManager { // Unfortunately, we can't reuse a different part of a buffer for different reads, so we // can't have a single binding with multiple slices allocated. MemoryConfiguration::ExclusivePages, + logger.clone(), + MemoryManagementOptions::new("Staging CPU Memory").mode(MemoryAllocationMode::Auto), ); // TODO: In the future this should not need STORAGE, if cube writes out all @@ -64,6 +71,8 @@ impl WgpuMemManager { ), &memory_properties, MemoryConfiguration::ExclusivePages, + logger, + MemoryManagementOptions::new("Uniform GPU Memory").mode(MemoryAllocationMode::Auto), ); Self { @@ -137,7 +146,7 @@ impl WgpuMemManager { self.memory_pool.cleanup(explicit); } - pub(crate) fn mode(&mut self, mode: cubecl_runtime::memory_management::MemoryAllocationMode) { + pub(crate) fn mode(&mut self, mode: MemoryAllocationMode) { self.memory_pool.mode(mode); } diff --git a/crates/cubecl-wgpu/src/compute/schedule.rs b/crates/cubecl-wgpu/src/compute/schedule.rs index 08c07ecde..50498ef7a 100644 --- a/crates/cubecl-wgpu/src/compute/schedule.rs +++ b/crates/cubecl-wgpu/src/compute/schedule.rs @@ -7,6 +7,7 @@ use cubecl_core::{ server::{MetadataBinding, ScalarBinding}, }; use cubecl_runtime::{ + logging::ServerLogger, memory_management::MemoryDeviceProperties, stream::{StreamFactory, scheduler::SchedulerStreamBackend}, }; @@ -60,6 +61,7 @@ pub struct WgpuStreamFactory { memory_config: MemoryConfiguration, timing_method: TimingMethod, tasks_max: usize, + logger: Arc, } impl StreamFactory for WgpuStreamFactory { @@ -73,6 +75,7 @@ impl StreamFactory for WgpuStreamFactory { self.memory_config.clone(), self.timing_method, self.tasks_max, + self.logger.clone(), ) } } @@ -86,6 +89,7 @@ impl ScheduledWgpuBackend { memory_config: MemoryConfiguration, timing_method: TimingMethod, tasks_max: usize, + logger: Arc, ) -> Self { Self { factory: WgpuStreamFactory { @@ -95,6 +99,7 @@ impl ScheduledWgpuBackend { memory_config, timing_method, tasks_max, + logger, }, } } @@ -130,6 +135,10 @@ impl SchedulerStreamBackend for ScheduledWgpuBackend { stream.enqueue_task(task); } + fn flush(stream: &mut Self::Stream) { + stream.flush(); + } + fn factory(&mut self) -> &mut Self::Factory { &mut self.factory } diff --git a/crates/cubecl-wgpu/src/compute/server.rs b/crates/cubecl-wgpu/src/compute/server.rs index 1ec880bb2..115f61c62 100644 --- a/crates/cubecl-wgpu/src/compute/server.rs +++ b/crates/cubecl-wgpu/src/compute/server.rs @@ -6,7 +6,7 @@ use cubecl_common::bytes::Bytes; use cubecl_common::profile::{ProfileDuration, TimingMethod}; use 
cubecl_common::stream_id::StreamId; use cubecl_core::future::DynFut; -use cubecl_core::server::{ProfileError, ProfilingToken, ServerCommunication}; +use cubecl_core::server::{ProfileError, ProfilingToken, ServerCommunication, ServerUtilities}; use cubecl_core::{ MemoryConfiguration, WgpuCompilationOptions, prelude::*, @@ -36,6 +36,7 @@ pub struct WgpuServer { scheduler: SchedulerMultiStream, pub compilation_options: WgpuCompilationOptions, pub(crate) backend: wgpu::Backend, + pub(crate) utilities: Arc>, } impl ServerCommunication for WgpuServer { @@ -54,7 +55,7 @@ impl WgpuServer { tasks_max: usize, backend: wgpu::Backend, timing_method: TimingMethod, - logger: Arc, + utilities: ServerUtilities, ) -> Self { let backend_scheduler = ScheduledWgpuBackend::new( device.clone(), @@ -63,6 +64,7 @@ impl WgpuServer { memory_config, timing_method, tasks_max, + utilities.logger.clone(), ); let config = GlobalConfig::get(); @@ -73,7 +75,7 @@ impl WgpuServer { device, pipelines: HashMap::new(), scheduler: SchedulerMultiStream::new( - logger, + utilities.logger.clone(), backend_scheduler, SchedulerMultiStreamOptions { max_streams, @@ -82,6 +84,7 @@ impl WgpuServer { }, ), backend, + utilities: Arc::new(utilities), } } @@ -164,6 +167,10 @@ impl ComputeServer for WgpuServer { self.scheduler.logger.clone() } + fn utilities(&self) -> Arc> { + self.utilities.clone() + } + fn create( &mut self, descriptors: Vec>, diff --git a/crates/cubecl-wgpu/src/compute/stream.rs b/crates/cubecl-wgpu/src/compute/stream.rs index bde0edd7f..e0e0511e5 100644 --- a/crates/cubecl-wgpu/src/compute/stream.rs +++ b/crates/cubecl-wgpu/src/compute/stream.rs @@ -11,7 +11,8 @@ use cubecl_core::{ server::{Handle, IoError, ProfileError, ProfilingToken}, }; use cubecl_runtime::{ - memory_management::MemoryDeviceProperties, timestamp_profiler::TimestampProfiler, + logging::ServerLogger, memory_management::MemoryDeviceProperties, + timestamp_profiler::TimestampProfiler, }; use std::{future::Future, num::NonZero, pin::Pin, sync::Arc}; use wgpu::ComputePipeline; @@ -45,6 +46,7 @@ impl WgpuStream { memory_config: MemoryConfiguration, timing_method: TimingMethod, tasks_max: usize, + logger: Arc, ) -> Self { let timings = if timing_method == TimingMethod::Device { Timings::Device(QueryProfiler::new(&queue, &device)) @@ -62,7 +64,8 @@ impl WgpuStream { let poll = WgpuPoll::new(device.clone()); #[allow(unused_mut)] - let mut mem_manage = WgpuMemManager::new(device.clone(), memory_properties, memory_config); + let mut mem_manage = + WgpuMemManager::new(device.clone(), memory_properties, memory_config, logger); Self { mem_manage, diff --git a/crates/cubecl-wgpu/src/lib.rs b/crates/cubecl-wgpu/src/lib.rs index 608888530..2cf966bdc 100644 --- a/crates/cubecl-wgpu/src/lib.rs +++ b/crates/cubecl-wgpu/src/lib.rs @@ -33,6 +33,7 @@ mod tests { cubecl_core::testgen_all!(); cubecl_std::testgen!(); cubecl_std::testgen_tensor_identity!([flex32, f32, u32]); + cubecl_std::testgen_quantized_view!(f32); cubecl_matmul::testgen_matmul_simple!([flex32, f32]); cubecl_matmul::testgen_matmul_plane_vecmat!(); cubecl_matmul::testgen_matmul_unit!(); @@ -53,6 +54,7 @@ mod tests_spirv { cubecl_core::testgen_all!(f32: [f16, flex32, f32], i32: [i8, i16, i32, i64], u32: [u8, u16, u32, u64]); cubecl_std::testgen!(); cubecl_std::testgen_tensor_identity!([f16, flex32, f32, u32]); + cubecl_std::testgen_quantized_view!(f16); cubecl_convolution::testgen_conv2d_accelerated!([f16: f16]); cubecl_matmul::testgen_matmul_simple!([f32]); cubecl_matmul::testgen_matmul_plane_accelerated!(); 
@@ -73,6 +75,7 @@ mod tests_msl { cubecl_core::testgen_all!(f32: [f16, f32], i32: [i16, i32], u32: [u16, u32]); cubecl_std::testgen!(); cubecl_std::testgen_tensor_identity!([f16, flex32, f32, u32]); + cubecl_std::testgen_quantized_view!(f16); cubecl_convolution::testgen_conv2d_accelerated!([f16: f16]); cubecl_matmul::testgen_matmul_simple!([f16, f32]); cubecl_matmul::testgen_matmul_plane_accelerated!(); diff --git a/crates/cubecl-wgpu/src/runtime.rs b/crates/cubecl-wgpu/src/runtime.rs index 86f11f299..da4d6aebf 100644 --- a/crates/cubecl-wgpu/src/runtime.rs +++ b/crates/cubecl-wgpu/src/runtime.rs @@ -2,17 +2,17 @@ use crate::{ AutoCompiler, AutoGraphicsApi, GraphicsApi, WgpuDevice, backend, compute::WgpuServer, contiguous_strides, }; +use cubecl_common::device::{Device, DeviceState}; use cubecl_common::{future, profile::TimingMethod}; - +use cubecl_core::server::ServerUtilities; use cubecl_core::{CubeCount, CubeDim, Runtime, ir::TargetProperties}; pub use cubecl_runtime::memory_management::MemoryConfiguration; use cubecl_runtime::memory_management::MemoryDeviceProperties; +use cubecl_runtime::{DeviceProperties, memory_management::HardwareProperties}; use cubecl_runtime::{ - ComputeRuntime, channel, client::ComputeClient, logging::{ProfileLevel, ServerLogger}, }; -use cubecl_runtime::{DeviceProperties, memory_management::HardwareProperties}; use wgpu::{InstanceFlags, RequestAdapterOptions}; /// Runtime that uses the [wgpu] crate with the wgsl compiler. This is used in the Wgpu backend. @@ -21,29 +21,24 @@ use wgpu::{InstanceFlags, RequestAdapterOptions}; #[derive(Debug)] pub struct WgpuRuntime; -type Server = WgpuServer; -type Channel = channel::MutexComputeChannel; -// type Channel = channel::MpscComputeChannel; - -/// The compute instance is shared across all [wgpu runtimes](WgpuRuntime). 
-static RUNTIME: ComputeRuntime = ComputeRuntime::new(); +impl DeviceState for WgpuServer { + fn init(device_id: cubecl_common::device::DeviceId) -> Self { + let device = WgpuDevice::from_id(device_id); + let setup = future::block_on(create_setup_for_device(&device, AutoGraphicsApi::backend())); + create_server(setup, RuntimeOptions::default()) + } +} impl Runtime for WgpuRuntime { type Compiler = AutoCompiler; type Server = WgpuServer; - - type Channel = Channel; type Device = WgpuDevice; - fn client(device: &Self::Device) -> ComputeClient { - RUNTIME.client(device, move || { - let setup = - future::block_on(create_setup_for_device(device, AutoGraphicsApi::backend())); - create_client_on_setup(setup, RuntimeOptions::default()) - }) + fn client(device: &Self::Device) -> ComputeClient { + ComputeClient::load(device) } - fn name(client: &ComputeClient) -> &'static str { + fn name(client: &ComputeClient) -> &'static str { match client.info() { wgpu::Backend::Vulkan => { #[cfg(feature = "spirv")] @@ -74,6 +69,10 @@ impl Runtime for WgpuRuntime { } } + fn max_global_line_size() -> u8 { + 4 + } + fn max_cube_count() -> (u32, u32, u32) { let max_dim = u16::MAX as u32; (max_dim, max_dim, max_dim) @@ -167,8 +166,8 @@ pub fn init_device(setup: WgpuSetup, options: RuntimeOptions) -> WgpuDevice { } let device_id = WgpuDevice::Existing(device_id); - let client = create_client_on_setup(setup, options); - RUNTIME.register(&device_id, client); + let server = create_server(setup, options); + let _ = ComputeClient::init(&device_id, server); device_id } @@ -194,15 +193,12 @@ pub async fn init_setup_async( ) -> WgpuSetup { let setup = create_setup_for_device(device, G::backend()).await; let return_setup = setup.clone(); - let client = create_client_on_setup(setup, options); - RUNTIME.register(device, client); + let server = create_server(setup, options); + let _ = ComputeClient::init(device, server); return_setup } -pub(crate) fn create_client_on_setup( - setup: WgpuSetup, - options: RuntimeOptions, -) -> ComputeClient { +pub(crate) fn create_server(setup: WgpuSetup, options: RuntimeOptions) -> WgpuServer { let limits = setup.device.limits(); let mut adapter_limits = setup.adapter.limits(); @@ -282,7 +278,9 @@ pub(crate) fn create_client_on_setup( backend::register_features(&setup.adapter, &mut device_props, &mut compilation_options); - let server = WgpuServer::new( + let logger = alloc::sync::Arc::new(ServerLogger::default()); + + WgpuServer::new( mem_props, options.memory_config, compilation_options, @@ -291,22 +289,8 @@ pub(crate) fn create_client_on_setup( options.tasks_max, setup.backend, time_measurement, - alloc::sync::Arc::new(ServerLogger::default()), - ); - let channel = Channel::new(server); - - #[cfg(not(all(target_os = "macos", feature = "msl")))] - if features.contains(wgpu::Features::SHADER_FLOAT32_ATOMIC) { - use cubecl_core::ir::{ElemType, FloatKind, StorageType}; - use cubecl_runtime::TypeUsage; - - device_props.register_type_usage( - StorageType::Atomic(ElemType::Float(FloatKind::F32)), - TypeUsage::AtomicLoadStore | TypeUsage::AtomicAdd, - ); - } - - ComputeClient::new(channel, device_props, setup.backend) + ServerUtilities::new(device_props, logger, setup.backend), + ) } /// Select the wgpu device and queue based on the provided [device](WgpuDevice) and diff --git a/crates/cubecl/Cargo.toml b/crates/cubecl/Cargo.toml index 4ac37b1d3..7be289a0e 100644 --- a/crates/cubecl/Cargo.toml +++ b/crates/cubecl/Cargo.toml @@ -47,17 +47,17 @@ wgpu-spirv = ["wgpu", "cubecl-wgpu/spirv"] spirv-dump = 
["cubecl-wgpu/spirv-dump"] [dependencies] -cubecl-convolution = { path = "../cubecl-convolution", version = "0.7.0", default-features = false, optional = true } -cubecl-core = { path = "../cubecl-core", version = "0.7.0", default-features = false } -cubecl-cpu = { path = "../cubecl-cpu", version = "0.7.0", default-features = false, optional = true } -cubecl-cuda = { path = "../cubecl-cuda", version = "0.7.0", default-features = false, optional = true } -cubecl-hip = { path = "../cubecl-hip", version = "0.7.0", default-features = false, optional = true } -cubecl-matmul = { path = "../cubecl-matmul", version = "0.7.0", default-features = false, optional = true } -cubecl-random = { path = "../cubecl-random", version = "0.7.0", default-features = false, optional = true } -cubecl-reduce = { path = "../cubecl-reduce", version = "0.7.0", default-features = false, optional = true } -cubecl-runtime = { path = "../cubecl-runtime", version = "0.7.0", default-features = false } -cubecl-std = { path = "../cubecl-std", version = "0.7.0", optional = true } -cubecl-wgpu = { path = "../cubecl-wgpu", version = "0.7.0", default-features = false, optional = true } +cubecl-convolution = { path = "../cubecl-convolution", version = "0.9.0", default-features = false, optional = true } +cubecl-core = { path = "../cubecl-core", version = "0.9.0", default-features = false } +cubecl-cpu = { path = "../cubecl-cpu", version = "0.9.0", default-features = false, optional = true } +cubecl-cuda = { path = "../cubecl-cuda", version = "0.9.0", default-features = false, optional = true } +cubecl-hip = { path = "../cubecl-hip", version = "0.9.0", default-features = false, optional = true } +cubecl-matmul = { path = "../cubecl-matmul", version = "0.9.0", default-features = false, optional = true } +cubecl-random = { path = "../cubecl-random", version = "0.9.0", default-features = false, optional = true } +cubecl-reduce = { path = "../cubecl-reduce", version = "0.9.0", default-features = false, optional = true } +cubecl-runtime = { path = "../cubecl-runtime", version = "0.9.0", default-features = false } +cubecl-std = { path = "../cubecl-std", version = "0.9.0", optional = true } +cubecl-wgpu = { path = "../cubecl-wgpu", version = "0.9.0", default-features = false, optional = true } half = { workspace = true } [[bench]] diff --git a/crates/cubecl/benches/conv2d.rs b/crates/cubecl/benches/conv2d.rs index 65435879e..455e36c3f 100644 --- a/crates/cubecl/benches/conv2d.rs +++ b/crates/cubecl/benches/conv2d.rs @@ -105,7 +105,7 @@ pub struct Conv2dBench { bias_shape: usize, args: ConvolutionArgs<2>, device: R::Device, - client: ComputeClient, + client: ComputeClient, _phantom: PhantomData, } diff --git a/crates/cubecl/benches/matmul.rs b/crates/cubecl/benches/matmul.rs index e40fb458b..e596bd916 100644 --- a/crates/cubecl/benches/matmul.rs +++ b/crates/cubecl/benches/matmul.rs @@ -1,5 +1,6 @@ use core::marker::PhantomData; use cubecl::prelude::*; +use cubecl_matmul::AsyncReadingStrategy; use cubecl_matmul::components::batch::HypercubeSelection; use cubecl_matmul::components::stage::PartitionBuffering; use cubecl_matmul::components::{ @@ -11,17 +12,10 @@ use cubecl_matmul::kernels::layered::double_unit::DoubleUnitSelectionArgs; use cubecl_matmul::kernels::layered::ordered_double_buffering::OrderedSelectionArgs; use cubecl_matmul::kernels::layered::simple::SimpleArgs; use cubecl_matmul::kernels::layered::simple_unit::SimpleUnitSelectionArgs; -use cubecl_matmul::kernels::layered::{ - MatmulSelection, MultiRowStrategy, Selection, 
TileSizeSelection, closest_factor_pair, -}; use cubecl_matmul::kernels::layered::{Selection, TileSizeSelection}; -use cubecl_matmul::{self as matmul}; use cubecl_matmul::{ self as matmul, MatmulInputHandle, SyncPartialReadingStrategy, SyncReadingStrategy, }; -use cubecl_matmul::{self as matmul, SyncPartialReadingStrategy, SyncReadingStrategy}; -use cubecl_matmul::{AsyncReadingStrategy, components::MatmulPrecision}; -use cubecl_matmul::{SyncPartialReadingStrategy, SyncReadingStrategy}; use std::collections::BTreeMap; use std::time::Duration; @@ -98,8 +92,8 @@ impl Benchmark for MatmulBench { matmul_elems.rhs_global, matmul_elems.rhs_stage, matmul_elems.rhs_register, - matmul_elems.acc, - matmul_elems.out, + matmul_elems.acc_register, + matmul_elems.acc_global, self.strategy ) .to_lowercase() @@ -127,7 +121,7 @@ struct MatmulBench { tr: bool, strategy: matmul::Strategy, device: R::Device, - client: ComputeClient, + client: ComputeClient, _mp: PhantomData, } @@ -144,14 +138,14 @@ fn entry(m: usize, n: usize, k: usize) -> (usize, usize, usize, usize) { #[allow(dead_code)] fn run(device: R::Device, strategy: matmul::Strategy) { - for tl in [false] { - for tr in [true] { + for tl in [true, false] { + for tr in [true, false] { for (b, m, n, k) in [ // entry(8192, 8192, 8192), // entry(6144, 6144, 6144), // entry(4096, 4096, 4096), // entry(2048, 2048, 2048), - entry(1024, 1024, 1024), + // (2, 1024, 1024, 1024), // entry(512, 512, 512), // entry(64, 1024, 64), // entry(32, 1024, 32), @@ -165,7 +159,11 @@ fn run(device: R::Device, strategy: matmul::Str // (16, 1, 2048, 8192), // (16, 1, 4096, 4096), // (16, 1, 512, 4096), + // (2, 8192, 8192, 1), // Outer + // (2, 8192, 1, 8192), // MatVec + (2, 1, 8192, 8192), // VecMat ] { + println!("-------------------"); let _ = run_one::(device.clone(), strategy.clone(), (b, m, n, k), (tl, tr)); } } @@ -334,6 +332,7 @@ fn run_algos_unit() { tile_size: TileSizeSelection::MinTileSize, })), ); + println!("Double Unit Max"); run::( Default::default(), diff --git a/crates/cubecl/benches/memcpy_async.rs b/crates/cubecl/benches/memcpy_async.rs index 19b1276e5..d303a4dc3 100644 --- a/crates/cubecl/benches/memcpy_async.rs +++ b/crates/cubecl/benches/memcpy_async.rs @@ -563,7 +563,7 @@ enum ComputeTaskEnum { fn launch_ref( strategy: CopyStrategyEnum, - client: &ComputeClient, + client: &ComputeClient, input: &TensorHandleRef, output: &TensorHandleRef, smem_size: u32, @@ -778,7 +778,7 @@ struct MemcpyAsyncBench { strategy: CopyStrategyEnum, double_buffering: bool, device: R::Device, - client: ComputeClient, + client: ComputeClient, _e: PhantomData, } diff --git a/crates/cubecl/benches/unary.rs b/crates/cubecl/benches/unary.rs index 505e1728e..acd72b861 100644 --- a/crates/cubecl/benches/unary.rs +++ b/crates/cubecl/benches/unary.rs @@ -82,7 +82,7 @@ struct UnaryBench { shape: Vec, vectorization: u8, device: R::Device, - client: ComputeClient, + client: ComputeClient, _e: PhantomData, } diff --git a/cubecl-book/src/SUMMARY.md b/cubecl-book/src/SUMMARY.md index 4535c5be5..ed4cd43a7 100644 --- a/cubecl-book/src/SUMMARY.md +++ b/cubecl-book/src/SUMMARY.md @@ -19,6 +19,4 @@ - [Struct](./language-support/struct.md) - [Advanced Usage](./advanced-usage/summary.md) - [Configuration](./advanced-usage/config.md) -- [Algorithm reference](./algorithms/summary.md) - - [Quantized matrix multiplication](./algorithms/quantized_matmul.md) - - [Pseudo Random Number Generator]() + - [Math Optimizations](./advanced-usage/math_optimizations.md) diff --git 
a/cubecl-book/src/advanced-usage/math_optimizations.md b/cubecl-book/src/advanced-usage/math_optimizations.md
new file mode 100644
index 000000000..4784dc36c
--- /dev/null
+++ b/cubecl-book/src/advanced-usage/math_optimizations.md
@@ -0,0 +1,112 @@
+# Math Optimizations
+
+## Fast Math Options
+
+Floating point operations must obey many rules imposed by the specification, especially around
+special values (`Inf`/`NaN`) and signed zero, whose handling is rarely needed in practice. CubeCL
+allows marking functions with loosened restrictions to accelerate math operations, while trading
+off some special handling or precision.
+
+The effect is backend-dependent, but uses a unified API of flags specifying acceptable
+optimizations. These `FastMath` flags are set per-function, so they can be limited to
+performance-critical sections of the code.
+
+**Example:**
+
+```rust
+/// Only the inverse square root has reduced precision/no special handling. Everything else is full
+/// precision.
+#[cube(launch_unchecked)]
+fn run_on_array(input: &Array, alpha: F, epsilon: F, output: &mut Array) {
+    if ABSOLUTE_POS < input.len() {
+        output[ABSOLUTE_POS] = alpha * fast_rsqrt::(input[ABSOLUTE_POS]) + epsilon;
+    }
+}
+
+#[cube(fast_math = FastMath::all())]
+fn fast_rsqrt(x: F) -> F {
+    F::inverse_sqrt(x)
+}
+```
+
+### Backend Implementation
+
+#### WGPU with Vulkan Compiler
+
+Vulkan supports each flag as a modifier for all floating point operations. The compiler applies all
+enabled flags, but the implementation is driver-specific.
+
+#### CUDA/HIP
+
+These targets only expose specific intrinsics. These intrinsics are used when all their required
+flags are present. Only `f32` is supported for these intrinsics; other float types are not affected
+by math flags on CUDA/HIP. Note that some of these are guesswork, because CUDA lacks documentation
+on special value handling.
+
+| CubeCL Function   | Intrinsic           | Required Flags                                                  |
+| ----------------- | ------------------- | --------------------------------------------------------------- |
+| `a / b`           | `__fdividef(a, b)`  | `AllowReciprocal \| ReducedPrecision \| UnsignedZero \| NotInf` |
+| `exp(a)`          | `__expf(a)`         | `ReducedPrecision \| NotNaN \| NotInf`                          |
+| `log(a)`          | `__logf(a)`         | `ReducedPrecision \| NotNaN \| NotInf`                          |
+| `sin(a)`          | `__sinf(a)`         | `ReducedPrecision \| NotNaN \| NotInf`                          |
+| `cos(a)`          | `__cosf(a)`         | `ReducedPrecision \| NotNaN \| NotInf`                          |
+| `tanh(a)`         | `__tanhf(a)`        | `ReducedPrecision \| NotNaN \| NotInf`                          |
+| `powf(a, b)`      | `__powf(a, b)`      | `ReducedPrecision \| NotNaN \| NotInf`                          |
+| `sqrt(a)`         | `__fsqrt_rn(a)`     | `ReducedPrecision \| NotNaN \| NotInf`                          |
+| `inverse_sqrt(a)` | `__frsqrt_rn(a)`    | `ReducedPrecision \| NotNaN \| NotInf`                          |
+| `recip(a)`        | `__frcp_rn(a)`      | `AllowReciprocal \| ReducedPrecision \| UnsignedZero \| NotInf` |
+| `normalize(a)`    | n/a (`__frsqrt_rn`) | `ReducedPrecision \| NotNaN \| NotInf`                          |
+| `magnitude(a)`    | n/a (`__fsqrt_rn`)  | `ReducedPrecision \| NotNaN \| NotInf`                          |
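+
+To make this gating concrete, here is a minimal sketch of how a backend could pick between the
+fast intrinsic and the regular call. This is not CubeCL's actual compiler code; the flag constants
+and the helper are illustrative stand-ins for the real `FastMath` bitflags:
+
+```rust
+// Hypothetical flag bits standing in for the real `FastMath` flags.
+const REDUCED_PRECISION: u32 = 1 << 0;
+const NOT_NAN: u32 = 1 << 1;
+const NOT_INF: u32 = 1 << 2;
+
+/// The fast intrinsic is only emitted when every flag it requires is enabled.
+fn exp_function_name(enabled: u32) -> &'static str {
+    let required = REDUCED_PRECISION | NOT_NAN | NOT_INF;
+    if enabled & required == required {
+        "__expf" // reduced-precision intrinsic from the table above
+    } else {
+        "expf" // standard full-precision call
+    }
+}
+
+fn main() {
+    assert_eq!(exp_function_name(REDUCED_PRECISION | NOT_NAN | NOT_INF), "__expf");
+    assert_eq!(exp_function_name(REDUCED_PRECISION), "expf");
+}
+```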
+
+#### Other Backends
+
+Other backends currently don't support any of these optimizations.
+
+## FastDivmod
+
+A very common operation, especially on GPUs, is applying integer division and modulo with a uniform,
+but not constant, divisor (i.e. width). For example:
+
+```rust
+#[cube(launch)]
+pub fn some_2d_kernel(output: &mut Array, width: u32) {
+    let y = ABSOLUTE_POS / width;
+    let x = ABSOLUTE_POS % width;
+    //...
+}
+
+// ...
+some_2d_kernel::launch::(
+    &client,
+    // ...,
+    ScalarArg::new(matrix.width as u32),
+);
+```
+
+However, integer division is quite slow, so this might have an impact on runtime. To mitigate the
+cost, you can use `FastDivmod` to pre-calculate the factors for division using 64-bit
+[Barrett Reduction](https://en.wikipedia.org/wiki/Barrett_reduction), and pass those instead of the
+divisor.
+This is faster even if you only use division or modulo, and _much_ faster if you use both.
+
+**Example:**
+
+```rust
+#[cube(launch)]
+pub fn some_2d_kernel(output: &mut Array, width: FastDivmod) {
+    let (y, x) = width.div_mod(ABSOLUTE_POS);
+    //...
+}
+
+some_2d_kernel::launch::(
+    &client,
+    // ...,
+    FastDivmodArgs::new(&client, matrix.width as u32),
+);
+```
+
+### Backend Support
+
+This is implemented using efficient extended multiplication on CUDA (`__umulhi`) and Vulkan
+(`OpUMulExtended`), and using manual casts and shifts on targets that support `u64`. Targets without
+either (possibly `WebGPU`) fall back to normal division.
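+
+As a rough illustration of what gets precomputed, here is a plain-Rust sketch. It is not the
+actual CubeCL implementation (the type name is made up, and it widens to 128 bits for clarity
+where the real kernels use the narrower extended multiplies described above), but it shows the
+multiply-and-shift trick that replaces the hardware divide:
+
+```rust
+/// Illustrative host-side stand-in for precomputed fast-division state.
+struct FastDivmodSketch {
+    divisor: u64,
+    multiplier: u128, // floor(2^64 / divisor) + 1
+}
+
+impl FastDivmodSketch {
+    fn new(divisor: u32) -> Self {
+        assert!(divisor != 0);
+        Self {
+            divisor: divisor as u64,
+            multiplier: (1u128 << 64) / divisor as u128 + 1,
+        }
+    }
+
+    /// One widening multiply and a shift replace the divide; the remainder
+    /// then costs only a multiply and a subtract.
+    fn div_mod(&self, n: u32) -> (u32, u32) {
+        let q = ((n as u128 * self.multiplier) >> 64) as u64;
+        let r = n as u64 - q * self.divisor;
+        (q as u32, r as u32)
+    }
+}
+
+fn main() {
+    let width = FastDivmodSketch::new(640);
+    assert_eq!(width.div_mod(123_456), (192, 576)); // 123456 = 192 * 640 + 576
+}
+```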
diff --git a/cubecl-book/src/algorithms/quantized_matmul.md b/cubecl-book/src/algorithms/quantized_matmul.md
deleted file mode 100644
index 812b11e14..000000000
--- a/cubecl-book/src/algorithms/quantized_matmul.md
+++ /dev/null
@@ -1,189 +0,0 @@
-# Quantized matrix multiplication
-
-To make matrix multiplication faster,
-we replace floating-point arithmetic using `f32`
-with integer arithmetic using a mix of `u8`, `u16` and `i32`.
-The benefits are twofold.
-First,
-we replace `Tensor` with `Tensor` to reduce memory cost by a factor of 4.
-This leads to faster read and write operations into global memory.
-Second,
-integer operations are often faster than their floating-point counterparts.
-
-In this section,
-we start by presenting a more mathematical overview of the algorithm,
-before discussing implementation.
-
-## Mathematical formulation
-
-### Scalar quantization
-
-A real number \\(a\\) can be approximated by an integer \\(q\\) using the formula
-\\[
-    a \approx s(q - z).
-\\]
-In this equation \\(s\\) is a scaling factor and is also a real number,
-while \\(z\\) is called the zero-offset and is an integer.
-In theory,
-with this approximation,
-we can represent exactly all real numbers that are integral multiples of \\(s\\).
-All other real numbers are rounded up to the closest representable value.
-However, in practice, the range of \\(q\\) is limited by its representation (e.g. `u8`, `i32`).
-Hence, the zero-offset \\(z\\) allows us to slide the interval of representable numbers toward
-an interval we are interested in a particular application.
-Also, by using the same type for \\(q\\) and \\(z\\),
-we assure that 0 is exactly representable.
-
-The multiplication of two real numbers is equivalent to
-\\[
-    a b = s_a s_b (q_a - z_a) (q_b - z_b).
-\\]
-However,
-we are more interested in the quantized version \\(q_c\\) of \\(c = ab \\).
-Given we want to approximate \\(c\\) with scaling \\(s_c\\) and zero-offset \\(z_c\\),
-we have
-\\[
-    q_c =
-    z_c + \frac{s_a s_b}{s_c} (q_a - z_a) (q_b - z_b).
-\\]
-Except for the factor \\( (s_a s_b) / s_c \\), the above equation involves only integer arithmetic.
-However,
-we can always find two integers \\(u, v\\) such that
-\\[
-    \frac uv \approx \frac{s_a s_b}{s_c}
-\\]
-is a satisfying approximation.
-This leads to the final approximation for quantized multiplication
-\\[
-    q_c \approx z_c + \frac uv (q_a - z_a)(q_b - z_b)
-\\]
-requiring only integer arithmetic.
-
-### Matrix quantization
-
-The same idea holds for matrix multiplication.
-To distinguish matrices from scalars,
-we use capital letters for the former and lower letters for the latter.
-
-A real matrix \\( A \\) is approximated by an integer matrix \\( Q \\) using
-\\[
-    A \approx s (Q - z N).
-\\]
-Here \\( N \\) is a matrix of ones the same size as \\( A \\).
-For two matrices \\(A \\) and \\( B \\) with respective shape \\(m \times k\\)
-and \\(k \times n\\) and their product \\( C \\) of shape \\( m \times n \\),
-we have, similar to the scalar case that
-\\[
-    Q_c \approx z_c N_c + \frac uv (Q_a - z_a N_a)(Q_b - z_b N_b).
-\\]
-
-## Implementation
-
-As an example,
-we describe how to implement the quantized matrix multiplication
-where the elements of \\(Q_a\\), \\(Q_b\\) and \\(Q_c\\) and the zero-offsets are represented as `u8`.
-
-To compute \\(Q_a - z_a N_a \\),
-we first convert the values to `i16` before performing the subtraction.
-Then, we can compute the product \\((Q_a - z_a N_a)(Q_b - z_b N_b)\\)
-by converting the values to `i32` before multiplying.
-Of course,
-in practice, we perform all these conversions on-the-fly to avoid wastefully allocating new matrices.
-
-Now, suppose that \\(x\\) is a single element in the resulting matrix and \\(y\\)
-is the element with the same position in \\(Q_c\\).
-We still need to compute the following
-\\[
-    y = z_c + \frac uv \cdot x.
-\\]
-The tricky part here is the product.
-First,
-we impose that \\( v \\) is a power of 2 so that dividing by \\( v \\)
-is equivalent to right-shifting the product \\( u x \\).
-Then, we need to find the best values \\( u \\) and \\( v \\)
-for the scaling factor \\( \sigma = \frac{s_a s_b}{s_c} \\).
-The trick is to cleverly multiply \\( \sigma \\) by 1, to get a form that allows us to work with powers of 2:
-\\[
-    \sigma = \frac{2^{31 - f}}{2^{31 - f}} \sigma
-\\]
-where \\(2^f\\) is the smallest power of 2 larger than \\(\sigma\\).
-For example, if \\(\sigma = 0.3\\), then \\(f = -1\\) as \\(2^{-1} = 0.5 > 0.3 \\)
-and \\(2^{-2} = 0.25 < 0.3\\).
-From this, we deduce we that we can use \\(u = 2^{31 - f} \sigma\\) rounded to the
-nearest `i64` value and \\(v = 2^{31 - f}\\).
-This gives us a 31-bit approximation for multiplying by \\(\sigma\\), which is the best
-we can achieve when the other multiplicand is an `i32`.
-Indeed, we need to keep one bit for the sign.
-To properly round the product,
-one can add \\(\frac v 2\\) to the product before right shifting.
-
-A naive implementation of the above algorithm looks like the following.
-```rust
-fn scaling_ratio(sigma: f32) -> (i64, u32) {
-    let log = x.log2().ceil() as i32;
-    let u = (x * 2.0_f32.powi(31 - log)).round() as i64;
-    let v_shift = (31 - log) as u32;
-    (u, v_shift)
-}
-
-fn approx_mul(x: i32, u: i64, v_shift: u32) -> i32 {
-    let prod = (x as i64) * u;
-    let rounding: i64 = 1 << (v_shift - 1);
-    let prod_with_rounding = prod + self.rounding;
-    (prod_with_rounding >> self.shift) as i32
-}
-
-fn clamp_to_u8(x: i32) -> u8 {
-    if x < 0 {
-        0
-    } else if x > u8::MAX as i32 {
-        u8::Max
-    } else {
-        x as u8
-    }
-}
-
-struct Matrix {
-    scaling: f32,
-    zero_offset: u8,
-    // ... other fields to store the matrix elements.
-}
-
-impl Matrix {
-    fn quantized_mul(&self, other: &Self, output: &mut Self) -> Self {
-        // assume the shapes of the matrices match.
- - let sigma = self.scaling * other.scaling / output.scaling; - let (u, v_shift) = scaling_ratio(sigma); - - for row in 0..self.row_count() { - for col in 0..other.col_count() { - let mut sum: i32 = 0; - for middle in 0..self.col_count() { - let a = self.get(row, middle) as i16 - self.zero_offset as i16; - let b = other.get(middle, col) as i16 - other.zero_offset as i16; - sum += (a as i32) * (b as i32); - } - sum = approx_mul(sum, u, v_shift); - - output.update(row, col, clamp_to_u8(sum + output.zero_offset as i32)) - } - } - } - - // return the value at (row, col) - fn get(&self, row: usize, col: usize) -> u8 { /* ... */ } - - // replace the value at (row, col) with the given value. - fn update(&mut self, row: usize, col: usize, value: u8) { /* ... */ } - - // return the number of rows of the matrix. - fn row_count(&self) -> usize { /* ... */ } - - // return the number of columns of the matrix. - fn col_count(&self) -> usize { /* ... */ } -} -``` -Of course, -in CubeCL, we stride to provide the fastest implementation for GPU devices. -As such, the example emphasizes the correct type casting to demonstrate how this is achieved in CubeCL. diff --git a/cubecl-book/src/algorithms/summary.md b/cubecl-book/src/algorithms/summary.md deleted file mode 100644 index 23e0f8ee3..000000000 --- a/cubecl-book/src/algorithms/summary.md +++ /dev/null @@ -1,10 +0,0 @@ -# Algorithm reference - -In this section, -we introduce different algorithms provided by CubeCL. -This is a best effort list and we focus first on nontrivial algorithms -deserving more explanations than what is reasonable to put in the API documentation. -This section is also a bit more technical compared to the others, as it serves two purposes. -First, it is a reference for users interested in the lower-level details of CubeCL. -Second, it is a reference for the developers who want to update the implementation as -the algorithms often get obfuscated by optimization details. diff --git a/cubecl-book/src/core-features/features.md b/cubecl-book/src/core-features/features.md index ad7806c7b..e306253da 100644 --- a/cubecl-book/src/core-features/features.md +++ b/cubecl-book/src/core-features/features.md @@ -83,5 +83,10 @@ Async tensor loading using the TMA accelerator available on Blackwell cards. Plane-level cooperative matrix multiply-add operations, with built-in block scaling. Available on Blackwell cards. +--- + [^1]: fp8/fp6/fp4 types are supported only for conversion and MMA + + + [^2]: bf16 is only supported for conversion, CMMA and, on some platforms, dot product diff --git a/cubecl-book/src/getting-started/installation.md b/cubecl-book/src/getting-started/installation.md index 32a19f244..8c2e055cd 100644 --- a/cubecl-book/src/getting-started/installation.md +++ b/cubecl-book/src/getting-started/installation.md @@ -1,22 +1,27 @@ # Installation -Installing CubeCL is straightforward. It’s available as a Rust crate, and you can add it to your project by updating your Cargo.toml: + +Installing CubeCL is straightforward. It’s available as a Rust crate, and you can add it to your +project by updating your Cargo.toml: ```toml [dependencies] cubecl = { - version = "0.7.0", # Replace with the latest version from crates.io + version = "0.9.0", # Replace with the latest version from crates.io features = ["wgpu"] # Enable desired runtime features (e.g., wgpu, cuda, hip) } ``` -The more challenging aspect is ensuring that you have the necessary drivers to run the selected runtime. 
+The more challenging aspect is ensuring that you have the necessary drivers to run the selected +runtime. -CubeCL supports multiple GPU runtimes, each requiring specific drivers or frameworks. Enable the appropriate feature flag in Cargo.toml and ensure the corresponding drivers are installed. +CubeCL supports multiple GPU and CPU runtimes, each requiring specific drivers or frameworks. Enable +the appropriate feature flag in Cargo.toml and ensure the corresponding drivers are installed. -| Platform | Runtime | Supported OS | Requirements | Installation/Notes | Feature Flag | -|----------|----------|-----------------------------|----------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------| -| WebGPU | wgpu | Linux, Windows, macOS, wasm | Vulkan drivers (typically pre-installed on modern OSes) | On linux install the vulkan driver. | wgpu | -| CUDA | CUDA | Linux, Windows | NVIDIA CUDA drivers and toolkit | Download and install from the NVIDIA CUDA Downloads page. Verify installation with nvidia-smi. | cuda | -| ROCm | HIP | Linux, Windows | AMD ROCm framework | Linux: Follow the ROCm Linux Quick Start. Windows: See the ROCm Windows Installation Guide. | hip | -| Metal | wgpu | macOS | Apple device with Metal support (macOS 10.13 or later) | No additional drivers needed; Metal is built into macOS. | wgpu-msl | -| Vulkan | wgpu | Linux, Windows | Vulkan drivers | On linux install via package manager, on windows it is typically included with GPU drivers (NVIDIA/AMD). | wgpu-spirv | +| Platform | Device Type | Runtime | Supported OS | Requirements | Installation/Notes | Feature Flag | +| -------- | ----------- | -------- | --------------------------- | ------------------------------------------------------- | -------------------------------------------------------------------------------------------------------- | ------------ | +| WebGPU | GPU | wgpu | Linux, Windows, macOS, wasm | Vulkan drivers (typically pre-installed on modern OSes) | On linux install the vulkan driver. | wgpu | +| CUDA | GPU | CUDA | Linux, Windows | NVIDIA CUDA drivers and toolkit | Download and install from the NVIDIA CUDA Downloads page. Verify installation with nvidia-smi. | cuda | +| ROCm | GPU | HIP | Linux, Windows | AMD ROCm framework | Linux: Follow the ROCm Linux Quick Start. Windows: See the ROCm Windows Installation Guide. | hip | +| Metal | GPU | wgpu | macOS | Apple device with Metal support (macOS 10.13 or later) | No additional drivers needed; Metal is built into macOS. | wgpu-msl | +| Vulkan | GPU | wgpu | Linux, Windows | Vulkan drivers | On linux install via package manager, on windows it is typically included with GPU drivers (NVIDIA/AMD). 
+| LLVM | CPU | Rust Std | LLVM-supported targets | LLVM bundle | Automatically installed when compiling | cpu |
diff --git a/cubecl-book/src/getting-started/src/bin/v2-gpu.rs b/cubecl-book/src/getting-started/src/bin/v2-gpu.rs
index ef3444938..2ef9d8e1a 100644
--- a/cubecl-book/src/getting-started/src/bin/v2-gpu.rs
+++ b/cubecl-book/src/getting-started/src/bin/v2-gpu.rs
@@ -16,7 +16,7 @@ pub fn launch(device: &R::Device) {
     let client = R::client(device);
 
     let input = GpuTensor::::arange(vec![3, 3], &client);
-    let output = GpuTensor::::empty(vec![3, 3], &client);
+    let output = GpuTensor::::empty(vec![3], &client);
 
     unsafe {
         reduce_matrix::launch_unchecked::(
diff --git a/cubecl-book/src/getting-started/src/bin/v3-gpu.rs b/cubecl-book/src/getting-started/src/bin/v3-gpu.rs
index 546b73e10..3e2a96e36 100644
--- a/cubecl-book/src/getting-started/src/bin/v3-gpu.rs
+++ b/cubecl-book/src/getting-started/src/bin/v3-gpu.rs
@@ -6,7 +6,7 @@ use cubecl_example::gpu_tensor::GpuTensor; // Change to the path of your own mod
 
 pub struct ReductionBench {
     input_shape: Vec,
-    client: ComputeClient,
+    client: ComputeClient,
     _f: PhantomData,
 }
 
diff --git a/cubecl-book/src/getting-started/src/bin/v4-gpu.rs b/cubecl-book/src/getting-started/src/bin/v4-gpu.rs
index 83d99d677..d5b8741de 100644
--- a/cubecl-book/src/getting-started/src/bin/v4-gpu.rs
+++ b/cubecl-book/src/getting-started/src/bin/v4-gpu.rs
@@ -6,7 +6,7 @@ use cubecl_example::gpu_tensor::GpuTensor; // Change to the path of your own mod
 
 pub struct ReductionBench {
     input_shape: Vec,
-    client: ComputeClient,
+    client: ComputeClient,
     _f: PhantomData,
 }
 
diff --git a/cubecl-book/src/getting-started/src/bin/v5-gpu.rs b/cubecl-book/src/getting-started/src/bin/v5-gpu.rs
index 315e9d899..71f6dd7e4 100644
--- a/cubecl-book/src/getting-started/src/bin/v5-gpu.rs
+++ b/cubecl-book/src/getting-started/src/bin/v5-gpu.rs
@@ -6,7 +6,7 @@ use cubecl_example::gpu_tensor::GpuTensor; // Change to the path of your own mod
 
 pub struct ReductionBench {
     input_shape: Vec,
-    client: ComputeClient,
+    client: ComputeClient,
     _f: PhantomData,
 }
 
diff --git a/cubecl-book/src/getting-started/src/bin/v6-gpu.rs b/cubecl-book/src/getting-started/src/bin/v6-gpu.rs
index a5891e05d..6553ac6b0 100644
--- a/cubecl-book/src/getting-started/src/bin/v6-gpu.rs
+++ b/cubecl-book/src/getting-started/src/bin/v6-gpu.rs
@@ -6,7 +6,7 @@ use cubecl_example::gpu_tensor::GpuTensor; // Change to the path of your own mod
 
 pub struct ReductionBench {
     input_shape: Vec,
-    client: ComputeClient,
+    client: ComputeClient,
     _f: PhantomData,
 }
 
diff --git a/cubecl-book/src/getting-started/src/bin/v7-gpu.rs b/cubecl-book/src/getting-started/src/bin/v7-gpu.rs
index 87ce09e40..b4977c3f3 100644
--- a/cubecl-book/src/getting-started/src/bin/v7-gpu.rs
+++ b/cubecl-book/src/getting-started/src/bin/v7-gpu.rs
@@ -6,7 +6,7 @@ use cubecl_example::gpu_tensor::GpuTensor; // Change to the path of your own mod
 
 pub struct ReductionBench {
     input_shape: Vec,
-    client: ComputeClient,
+    client: ComputeClient,
     _f: PhantomData,
 }
 
diff --git a/cubecl-book/src/getting-started/src/gpu_tensor.rs b/cubecl-book/src/getting-started/src/gpu_tensor.rs
index 24f36be79..198cee6d6 100644
--- a/cubecl-book/src/getting-started/src/gpu_tensor.rs
+++ b/cubecl-book/src/getting-started/src/gpu_tensor.rs
@@ -26,7 +26,7 @@ impl Clone for GpuTensor {
 
 impl GpuTensor {
     /// Create a GpuTensor with a shape filled by number in order
-    pub fn arange(shape: Vec, client: &ComputeClient) -> Self {
+    pub fn arange(shape: Vec, client: &ComputeClient) -> Self {
         let size = shape.iter().product();
         let data: Vec = (0..size).map(|i| F::from_int(i as i64)).collect();
         let data = client.create(F::as_bytes(&data));
@@ -42,7 +42,7 @@
     }
 
     /// Create an empty GpuTensor with a shape
-    pub fn empty(shape: Vec, client: &ComputeClient) -> Self {
+    pub fn empty(shape: Vec, client: &ComputeClient) -> Self {
         let size = shape.iter().product::() * core::mem::size_of::();
         let data = client.empty(size);
 
@@ -62,7 +62,7 @@
     }
 
     /// Return the data from the client
-    pub fn read(self, client: &ComputeClient) -> Vec {
+    pub fn read(self, client: &ComputeClient) -> Vec {
         let bytes = client.read_one(self.data.binding());
         F::from_bytes(&bytes).to_vec()
     }
diff --git a/examples/device_sharing/Cargo.toml b/examples/device_sharing/Cargo.toml
index ce5943822..621d8ebb1 100644
--- a/examples/device_sharing/Cargo.toml
+++ b/examples/device_sharing/Cargo.toml
@@ -12,7 +12,7 @@ default = []
 wgpu = ["cubecl/wgpu"]
 
 [dependencies]
-cubecl = { path = "../../crates/cubecl", version = "0.7.0" }
+cubecl = { path = "../../crates/cubecl", version = "0.9.0" }
 half = { workspace = true }
 sum_things = { path = "../sum_things" }
diff --git a/examples/fusing/Cargo.toml b/examples/fusing/Cargo.toml
index 984067ad2..61a1b7674 100644
--- a/examples/fusing/Cargo.toml
+++ b/examples/fusing/Cargo.toml
@@ -12,5 +12,5 @@ wgpu = ["cubecl/wgpu"]
 cuda = ["cubecl/cuda"]
 
 [dependencies]
-cubecl = { path = "../../crates/cubecl", version = "0.7.0", default-features = false }
+cubecl = { path = "../../crates/cubecl", version = "0.9.0", default-features = false }
 half = { workspace = true }
diff --git a/examples/gelu/Cargo.toml b/examples/gelu/Cargo.toml
index 311af8465..faceb913b 100644
--- a/examples/gelu/Cargo.toml
+++ b/examples/gelu/Cargo.toml
@@ -13,5 +13,5 @@ cuda = ["cubecl/cuda"]
 cpu = ["cubecl/cpu"]
 
 [dependencies]
-cubecl = { path = "../../crates/cubecl", version = "0.7.0" }
+cubecl = { path = "../../crates/cubecl", version = "0.9.0" }
 half = { workspace = true }
diff --git a/examples/normalization/Cargo.toml b/examples/normalization/Cargo.toml
index 625a6243e..538e0fcb7 100644
--- a/examples/normalization/Cargo.toml
+++ b/examples/normalization/Cargo.toml
@@ -12,5 +12,5 @@ wgpu = ["cubecl/wgpu"]
 cuda = ["cubecl/cuda"]
 
 [dependencies]
-cubecl = { path = "../../crates/cubecl", version = "0.7.0" }
+cubecl = { path = "../../crates/cubecl", version = "0.9.0" }
 half = { workspace = true }
diff --git a/examples/sum_things/Cargo.toml b/examples/sum_things/Cargo.toml
index aa99b4bee..3c8de4e80 100644
--- a/examples/sum_things/Cargo.toml
+++ b/examples/sum_things/Cargo.toml
@@ -12,5 +12,5 @@ wgpu = ["cubecl/wgpu"]
 cuda = ["cubecl/cuda"]
 
 [dependencies]
-cubecl = { path = "../../crates/cubecl", version = "0.7.0" }
+cubecl = { path = "../../crates/cubecl", version = "0.9.0" }
 half = { workspace = true }
diff --git a/examples/sum_things/src/lib.rs b/examples/sum_things/src/lib.rs
index 6691cf863..cafa8c4e5 100644
--- a/examples/sum_things/src/lib.rs
+++ b/examples/sum_things/src/lib.rs
@@ -102,7 +102,7 @@ impl CreateSeries for SumThenMul {
 }
 
 fn launch_basic(
-    client: &ComputeClient,
+    client: &ComputeClient,
     input: &Handle,
     output: &Handle,
     len: usize,
@@ -120,7 +120,7 @@
 }
 
 fn launch_subgroup(
-    client: &ComputeClient,
+    client: &ComputeClient,
     input: &Handle,
     output: &Handle,
     len: usize,
@@ -139,7 +139,7 @@
 }
 
 fn launch_trait(
-    client: &ComputeClient,
+    client: &ComputeClient,
     input: &Handle,
     output: &Handle,
     len: usize,
@@ -157,7 +157,7 @@ fn launch_trait(
 }
 
 fn launch_series(
-    client: &ComputeClient,
+    client: &ComputeClient,
     input: &Handle,
     output: &Handle,
     len: usize,
diff --git a/xtask/Cargo.toml b/xtask/Cargo.toml
index 555b45bb1..f5a05d6ad 100644
--- a/xtask/Cargo.toml
+++ b/xtask/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "xtask"
-version = "1.4.0"
+version = "2.1.11"
 edition.workspace = true
 
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
diff --git a/xtask/src/commands/profile.rs b/xtask/src/commands/profile.rs
index 2de4c25a9..5813ce9a8 100644
--- a/xtask/src/commands/profile.rs
+++ b/xtask/src/commands/profile.rs
@@ -68,7 +68,7 @@ impl Profile {
                 &options.bench,
                 "--release",
                 "--features",
-                "cuda",
+                "cuda,random",
             ],
             None,
             None,

From b21738a2cc695e4c43185e9777af72771831e8a3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tobias=20H=C3=B6lzer?=
Date: Thu, 6 Nov 2025 14:59:04 +0100
Subject: [PATCH 22/23] Fix duplicate inv_sqrt test

---
 crates/cubecl-core/src/runtime_tests/unary.rs | 1 -
 1 file changed, 1 deletion(-)

diff --git a/crates/cubecl-core/src/runtime_tests/unary.rs b/crates/cubecl-core/src/runtime_tests/unary.rs
index 3d0d55a91..f71504920 100644
--- a/crates/cubecl-core/src/runtime_tests/unary.rs
+++ b/crates/cubecl-core/src/runtime_tests/unary.rs
@@ -873,7 +873,6 @@ macro_rules! testgen_unary {
                 add_test!(test_degrees);
                 add_test!(test_radians);
                 add_test!(test_normalize);
-                add_test!(test_inverse_sqrt);
                 add_test!(test_magnitude);
                 add_test!(test_sqrt);
                 add_test!(test_inverse_sqrt);

From 72b5330b1808f8c47861dcd89925f852729fe0e4 Mon Sep 17 00:00:00 2001
From: Jorge Perez Burgos
Date: Thu, 6 Nov 2025 21:05:36 +0100
Subject: [PATCH 23/23] Correct order of dialect conversion registration.

---
 crates/cubecl-cpu/src/compiler/module.rs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/crates/cubecl-cpu/src/compiler/module.rs b/crates/cubecl-cpu/src/compiler/module.rs
index e76f60c87..571baf7aa 100644
--- a/crates/cubecl-cpu/src/compiler/module.rs
+++ b/crates/cubecl-cpu/src/compiler/module.rs
@@ -70,11 +70,11 @@ impl<'a> Module<'a> {
         pass_manager.add_pass(pass::conversion::create_index_to_llvm());
         pass_manager.add_pass(pass::conversion::create_scf_to_control_flow());
         pass_manager.add_pass(pass::conversion::create_control_flow_to_llvm());
+        pass_manager.add_pass(pass::conversion::create_math_to_llvm());
+        pass_manager.add_pass(pass::conversion::create_math_to_libm());
         pass_manager.add_pass(pass::conversion::create_vector_to_llvm());
         pass_manager.add_pass(pass::conversion::create_arith_to_llvm());
         pass_manager.add_pass(pass::conversion::create_func_to_llvm());
-        pass_manager.add_pass(pass::conversion::create_math_to_llvm());
-        pass_manager.add_pass(pass::conversion::create_math_to_libm());
         pass_manager.add_pass(pass::transform::create_inliner());
         pass_manager.add_pass(pass::conversion::create_reconcile_unrealized_casts());
         pass_manager.add_pass(pass::transform::create_sccp());
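
Why the reordering in PATCH 23/23 matters: in MLIR, the math-to-libm conversion lowers remaining `math` dialect ops by emitting `func.call` ops to libm symbols, while math-to-llvm only handles the ops that map to LLVM intrinsics. A conversion that produces `func` ops must therefore be registered before the pass that lowers the `func` dialect; with the old order, the calls emitted by `create_math_to_libm` appeared after `create_func_to_llvm` had already run and were never lowered to the LLVM dialect. Below is a minimal sketch of that constraint, reusing only the pass constructors visible in the hunk above (melior-style bindings; the surrounding `pass_manager` setup is assumed, not taken verbatim from module.rs):

```rust
// Sketch: ordering constraint for the MLIR conversion pipeline.
// Assumes `pass_manager` is an MLIR PassManager and `pass` is the
// melior-style pass module already used in module.rs.

// 1. Lower `math` ops first: intrinsics where available...
pass_manager.add_pass(pass::conversion::create_math_to_llvm());
// ...and the remainder as func.call ops to libm symbols.
pass_manager.add_pass(pass::conversion::create_math_to_libm());

// 2. Only then lower vector/arith/func. create_func_to_llvm() also
//    converts the libm call sites and declarations emitted in step 1;
//    run it earlier and those ops would survive into LLVM translation.
pass_manager.add_pass(pass::conversion::create_vector_to_llvm());
pass_manager.add_pass(pass::conversion::create_arith_to_llvm());
pass_manager.add_pass(pass::conversion::create_func_to_llvm());

// 3. Clean up any type-conversion seams left between the passes.
pass_manager.add_pass(pass::conversion::create_reconcile_unrealized_casts());
```

The general rule of thumb: a pass that emits ops of dialect X must be scheduled before the pass that converts dialect X out of existence, which is exactly what the commit restores for `math` relative to `func` and `arith`.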