diff --git a/crates/cubecl-core/src/frontend/container/line/ops.rs b/crates/cubecl-core/src/frontend/container/line/ops.rs index cac015871..982fbd05f 100644 --- a/crates/cubecl-core/src/frontend/container/line/ops.rs +++ b/crates/cubecl-core/src/frontend/container/line/ops.rs @@ -4,7 +4,7 @@ use num_traits::{NumCast, ToPrimitive}; use crate::{ self as cubecl, - prelude::{IsInf, IsNan, Powi, SaturatingAdd, SaturatingSub}, + prelude::{IsInf, IsNan, Powi, SaturatingAdd, SaturatingSub, Trunc}, }; use crate::{ frontend::{ @@ -255,6 +255,7 @@ impl Remainder for Line

{} impl Round for Line

{} impl Floor for Line

{} impl Ceil for Line

{} +impl Trunc for Line

{} impl ReverseBits for Line

{} impl BitwiseNot for Line

{} impl SaturatingAdd for Line

{} diff --git a/crates/cubecl-core/src/frontend/element/float.rs b/crates/cubecl-core/src/frontend/element/float.rs index fe322960e..8e77c8b4d 100644 --- a/crates/cubecl-core/src/frontend/element/float.rs +++ b/crates/cubecl-core/src/frontend/element/float.rs @@ -32,6 +32,7 @@ pub trait Float: + Round + Floor + Ceil + + Trunc + Erf + Recip + Magnitude diff --git a/crates/cubecl-core/src/frontend/element/float/typemap.rs b/crates/cubecl-core/src/frontend/element/float/typemap.rs index 89564a468..dea516774 100644 --- a/crates/cubecl-core/src/frontend/element/float/typemap.rs +++ b/crates/cubecl-core/src/frontend/element/float/typemap.rs @@ -254,6 +254,7 @@ impl Sqrt for ElemExpand {} impl Round for ElemExpand {} impl Floor for ElemExpand {} impl Ceil for ElemExpand {} +impl Trunc for ElemExpand {} impl IsNan for ElemExpand {} impl IsInf for ElemExpand {} diff --git a/crates/cubecl-core/src/frontend/operation/unary.rs b/crates/cubecl-core/src/frontend/operation/unary.rs index 382e64698..affb5f8eb 100644 --- a/crates/cubecl-core/src/frontend/operation/unary.rs +++ b/crates/cubecl-core/src/frontend/operation/unary.rs @@ -235,6 +235,18 @@ impl_unary_func!( f32, f64 ); +impl_unary_func!( + Trunc, + trunc, + __expand_trunc, + Arithmetic::Trunc, + f16, + bf16, + flex32, + tf32, + f32, + f64 +); impl_unary_func!( Erf, erf, diff --git a/crates/cubecl-core/src/runtime_tests/unary.rs b/crates/cubecl-core/src/runtime_tests/unary.rs index 714ee5ddd..89c0faee8 100644 --- a/crates/cubecl-core/src/runtime_tests/unary.rs +++ b/crates/cubecl-core/src/runtime_tests/unary.rs @@ -305,6 +305,29 @@ test_unary_impl!( ] ); +test_unary_impl!( + test_trunc, + F, + F::trunc, + [{ + input_vectorization: 1, + out_vectorization: 1, + input: as_type![F: -1.2, -1., -0., 0.], + expected: as_type![F: -1., -1., -0., 0.] + }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![F: f32::NAN, 1., 1.2, 1.9], + expected: as_type![F: f32::NAN, 1., 1., 1.0] + },{ + input_vectorization: 4, + out_vectorization: 4, + input: as_type![F: -0.9, 0.2, f32::NAN, 1.99], + expected: as_type![F: -0., 0., f32::NAN, 1.] + }] +); + test_unary_impl_fixed!( test_is_nan, F, @@ -479,6 +502,7 @@ macro_rules! testgen_unary { add_test!(test_normalize); add_test!(test_magnitude); add_test!(test_abs); + add_test!(test_trunc); add_test!(test_is_nan); add_test!(test_is_inf); } diff --git a/crates/cubecl-cpp/src/shared/base.rs b/crates/cubecl-cpp/src/shared/base.rs index cbd7a759f..e9989cac8 100644 --- a/crates/cubecl-cpp/src/shared/base.rs +++ b/crates/cubecl-cpp/src/shared/base.rs @@ -1012,6 +1012,9 @@ impl CppCompiler { gpu::Arithmetic::Ceil(op) => { instructions.push(Instruction::Ceil(self.compile_unary(op, out))) } + gpu::Arithmetic::Trunc(op) => { + instructions.push(Instruction::Trunc(self.compile_unary(op, out))) + } gpu::Arithmetic::Remainder(op) => { instructions.push(Instruction::Remainder(self.compile_binary(op, out))) } diff --git a/crates/cubecl-cpp/src/shared/instruction.rs b/crates/cubecl-cpp/src/shared/instruction.rs index 70d450757..baddedd9f 100644 --- a/crates/cubecl-cpp/src/shared/instruction.rs +++ b/crates/cubecl-cpp/src/shared/instruction.rs @@ -200,6 +200,7 @@ pub enum Instruction { }, Round(UnaryInstruction), Ceil(UnaryInstruction), + Trunc(UnaryInstruction), Floor(UnaryInstruction), Warp(WarpInstruction), Wmma(WmmaInstruction), @@ -538,6 +539,7 @@ for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{ Instruction::ThreadFence => f.write_str("__threadfence();\n"), Instruction::Round(it) => Round::format(f, &it.input, &it.out), Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out), + Instruction::Trunc(it) => Trunc::format(f, &it.input, &it.out), Instruction::Floor(it) => Floor::format(f, &it.input, &it.out), Instruction::SliceLength { input, out } => { let out = out.fmt_left(); diff --git a/crates/cubecl-cpp/src/shared/unary.rs b/crates/cubecl-cpp/src/shared/unary.rs index b23fb6490..7711a2d2c 100644 --- a/crates/cubecl-cpp/src/shared/unary.rs +++ b/crates/cubecl-cpp/src/shared/unary.rs @@ -154,6 +154,7 @@ function!(Sin, "sin"); function!(Sqrt, "sqrt"); function!(Exp, "exp"); function!(Ceil, "ceil"); +function!(Trunc, "trunc"); function!(Floor, "floor"); function!(Round, "rint"); diff --git a/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs b/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs index 647bed3ca..596ba0a56 100644 --- a/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs +++ b/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs @@ -40,6 +40,15 @@ impl<'a> Visitor<'a> { )); self.insert_variable(out, result); } + Arithmetic::Trunc(trunc) => { + let value = self.get_variable(trunc.input); + let result = self.append_operation_with_result(llvm_ods::intr_trunc( + self.context, + value, + self.location, + )); + self.insert_variable(out, result); + } Arithmetic::Clamp(clamp) => { let value = self.get_variable(clamp.input); let mut min = self.get_variable(clamp.min_value); diff --git a/crates/cubecl-ir/src/arithmetic.rs b/crates/cubecl-ir/src/arithmetic.rs index 67f0abfc9..e6197bfeb 100644 --- a/crates/cubecl-ir/src/arithmetic.rs +++ b/crates/cubecl-ir/src/arithmetic.rs @@ -32,6 +32,7 @@ pub enum Arithmetic { Round(UnaryOperator), Floor(UnaryOperator), Ceil(UnaryOperator), + Trunc(UnaryOperator), Erf(UnaryOperator), Recip(UnaryOperator), Clamp(ClampOperator), @@ -73,6 +74,7 @@ impl Display for Arithmetic { Arithmetic::Round(op) => write!(f, "{}.round()", op.input), Arithmetic::Floor(op) => write!(f, "{}.floor()", op.input), Arithmetic::Ceil(op) => write!(f, "{}.ceil()", op.input), + Arithmetic::Trunc(op) => write!(f, "{}.trunc()", op.input), Arithmetic::Erf(op) => write!(f, "{}.erf()", op.input), Arithmetic::Recip(op) => write!(f, "{}.recip()", op.input), Arithmetic::Clamp(op) => { diff --git a/crates/cubecl-ir/src/processing.rs b/crates/cubecl-ir/src/processing.rs index 8f2968c29..15ef8d207 100644 --- a/crates/cubecl-ir/src/processing.rs +++ b/crates/cubecl-ir/src/processing.rs @@ -131,6 +131,9 @@ impl ScopeProcessing { Arithmetic::Ceil(op) => { sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); } + Arithmetic::Trunc(op) => { + sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); + } Arithmetic::Erf(op) => { sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); } diff --git a/crates/cubecl-opt/src/instructions.rs b/crates/cubecl-opt/src/instructions.rs index 38dfe129d..ec7cbf566 100644 --- a/crates/cubecl-opt/src/instructions.rs +++ b/crates/cubecl-opt/src/instructions.rs @@ -97,6 +97,7 @@ impl Optimizer { | Arithmetic::Round(unary_operator) | Arithmetic::Floor(unary_operator) | Arithmetic::Ceil(unary_operator) + | Arithmetic::Trunc(unary_operator) | Arithmetic::Erf(unary_operator) | Arithmetic::Recip(unary_operator) | Arithmetic::Neg(unary_operator) diff --git a/crates/cubecl-opt/src/passes/constant_prop.rs b/crates/cubecl-opt/src/passes/constant_prop.rs index 9c8fba0a6..c3867f8ec 100644 --- a/crates/cubecl-opt/src/passes/constant_prop.rs +++ b/crates/cubecl-opt/src/passes/constant_prop.rs @@ -421,6 +421,7 @@ fn try_const_eval_arithmetic(op: &mut Arithmetic) -> Option Arithmetic::Round(op) => const_eval_float!(op.input; num::Float::round), Arithmetic::Floor(op) => const_eval_float!(op.input; num::Float::floor), Arithmetic::Ceil(op) => const_eval_float!(op.input; num::Float::ceil), + Arithmetic::Trunc(op) => const_eval_float!(op.input; num::Float::trunc), Arithmetic::Recip(op) => const_eval_float!(op.input; num::Float::recip), Arithmetic::Neg(op) => { use ConstantScalarValue::*; diff --git a/crates/cubecl-spirv/src/arithmetic.rs b/crates/cubecl-spirv/src/arithmetic.rs index 642351e6d..533cb7d19 100644 --- a/crates/cubecl-spirv/src/arithmetic.rs +++ b/crates/cubecl-spirv/src/arithmetic.rs @@ -384,6 +384,14 @@ impl SpirvCompiler { } }) } + Arithmetic::Trunc(op) => { + self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + T::trunc(b, ty, input, out); + if matches!(out_ty.elem(), Elem::Relaxed) { + b.decorate(out, Decoration::RelaxedPrecision, []); + } + }) + } Arithmetic::Clamp(op) => { let input = self.compile_variable(op.input); let min = self.compile_variable(op.min_value); diff --git a/crates/cubecl-spirv/src/extensions.rs b/crates/cubecl-spirv/src/extensions.rs index 303d69119..2330ecdb4 100644 --- a/crates/cubecl-spirv/src/extensions.rs +++ b/crates/cubecl-spirv/src/extensions.rs @@ -10,6 +10,7 @@ pub trait TargetExtensions { fn s_abs(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn floor(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn ceil(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); + fn trunc(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn sin(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn cos(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn tanh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); @@ -57,6 +58,10 @@ pub mod glcompute { b.gl_ceil_id(ty, Some(out), input).unwrap(); } + fn trunc(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) { + b.gl_trunc_id(ty, Some(out), input).unwrap(); + } + fn sin(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) { b.gl_sin_id(ty, Some(out), input).unwrap(); } diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs index 19c57f9c1..c6ee0061f 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs @@ -787,6 +787,10 @@ impl WgslCompiler { input: self.compile_variable(op.input), out: self.compile_variable(out), }), + cube::Arithmetic::Trunc(op) => instructions.push(wgsl::Instruction::Trunc { + input: self.compile_variable(op.input), + out: self.compile_variable(out), + }), cube::Arithmetic::Erf(op) => { let mut scope = scope.child(); expand_erf(&mut scope, op.input, out); diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs index 2d252d167..da050752e 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs @@ -265,6 +265,10 @@ pub enum Instruction { input: Variable, out: Variable, }, + Trunc { + input: Variable, + out: Variable, + }, Remainder { lhs: Variable, rhs: Variable, @@ -854,6 +858,10 @@ for (var {i}: {i_ty} = {start}; {i} {cmp} {end}; {increment}) {{ let out = out.fmt_left(); writeln!(f, "{out} = ceil({input});") } + Instruction::Trunc { input, out } => { + let out = out.fmt_left(); + writeln!(f, "{out} = trunc({input});") + } Instruction::Subgroup(op) => write!(f, "{op}"), Instruction::Bitcast { input, out } => { let elem = out.item();