Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion crates/cubecl-core/src/frontend/container/line/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use num_traits::{NumCast, ToPrimitive};

use crate::{
self as cubecl,
prelude::{IsInf, IsNan, Powi, SaturatingAdd, SaturatingSub},
prelude::{IsInf, IsNan, Powi, SaturatingAdd, SaturatingSub, Trunc},
};
use crate::{
frontend::{
Expand Down Expand Up @@ -255,6 +255,7 @@ impl<P: CubePrimitive + Remainder> Remainder for Line<P> {}
impl<P: CubePrimitive + Round> Round for Line<P> {}
impl<P: CubePrimitive + Floor> Floor for Line<P> {}
impl<P: CubePrimitive + Ceil> Ceil for Line<P> {}
impl<P: CubePrimitive + Trunc> Trunc for Line<P> {}
impl<P: CubePrimitive + ReverseBits> ReverseBits for Line<P> {}
impl<P: CubePrimitive + BitwiseNot> BitwiseNot for Line<P> {}
impl<P: CubePrimitive + SaturatingAdd> SaturatingAdd for Line<P> {}
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-core/src/frontend/element/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ pub trait Float:
+ Round
+ Floor
+ Ceil
+ Trunc
+ Erf
+ Recip
+ Magnitude
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-core/src/frontend/element/float/typemap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ impl<const POS: u8> Sqrt for ElemExpand<POS> {}
impl<const POS: u8> Round for ElemExpand<POS> {}
impl<const POS: u8> Floor for ElemExpand<POS> {}
impl<const POS: u8> Ceil for ElemExpand<POS> {}
impl<const POS: u8> Trunc for ElemExpand<POS> {}
impl<const POS: u8> IsNan for ElemExpand<POS> {}
impl<const POS: u8> IsInf for ElemExpand<POS> {}

Expand Down
12 changes: 12 additions & 0 deletions crates/cubecl-core/src/frontend/operation/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,18 @@ impl_unary_func!(
f32,
f64
);
impl_unary_func!(
Trunc,
trunc,
__expand_trunc,
Arithmetic::Trunc,
f16,
bf16,
flex32,
tf32,
f32,
f64
);
impl_unary_func!(
Erf,
erf,
Expand Down
24 changes: 24 additions & 0 deletions crates/cubecl-core/src/runtime_tests/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,29 @@ test_unary_impl!(
]
);

test_unary_impl!(
test_trunc,
F,
F::trunc,
[{
input_vectorization: 1,
out_vectorization: 1,
input: as_type![F: -1.2, -1., -0., 0.],
expected: as_type![F: -1., -1., -0., 0.]
},
{
input_vectorization: 2,
out_vectorization: 2,
input: as_type![F: f32::NAN, 1., 1.2, 1.9],
expected: as_type![F: f32::NAN, 1., 1., 1.0]
},{
input_vectorization: 4,
out_vectorization: 4,
input: as_type![F: -0.9, 0.2, f32::NAN, 1.99],
expected: as_type![F: -0., 0., f32::NAN, 1.]
}]
);

test_unary_impl_fixed!(
test_is_nan,
F,
Expand Down Expand Up @@ -479,6 +502,7 @@ macro_rules! testgen_unary {
add_test!(test_normalize);
add_test!(test_magnitude);
add_test!(test_abs);
add_test!(test_trunc);
add_test!(test_is_nan);
add_test!(test_is_inf);
}
Expand Down
3 changes: 3 additions & 0 deletions crates/cubecl-cpp/src/shared/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,9 @@ impl<D: Dialect> CppCompiler<D> {
gpu::Arithmetic::Ceil(op) => {
instructions.push(Instruction::Ceil(self.compile_unary(op, out)))
}
gpu::Arithmetic::Trunc(op) => {
instructions.push(Instruction::Trunc(self.compile_unary(op, out)))
}
gpu::Arithmetic::Remainder(op) => {
instructions.push(Instruction::Remainder(self.compile_binary(op, out)))
}
Expand Down
2 changes: 2 additions & 0 deletions crates/cubecl-cpp/src/shared/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ pub enum Instruction<D: Dialect> {
},
Round(UnaryInstruction<D>),
Ceil(UnaryInstruction<D>),
Trunc(UnaryInstruction<D>),
Floor(UnaryInstruction<D>),
Warp(WarpInstruction<D>),
Wmma(WmmaInstruction<D>),
Expand Down Expand Up @@ -538,6 +539,7 @@ for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{
Instruction::ThreadFence => f.write_str("__threadfence();\n"),
Instruction::Round(it) => Round::format(f, &it.input, &it.out),
Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out),
Instruction::Trunc(it) => Trunc::format(f, &it.input, &it.out),
Instruction::Floor(it) => Floor::format(f, &it.input, &it.out),
Instruction::SliceLength { input, out } => {
let out = out.fmt_left();
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-cpp/src/shared/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ function!(Sin, "sin");
function!(Sqrt, "sqrt");
function!(Exp, "exp");
function!(Ceil, "ceil");
function!(Trunc, "trunc");
function!(Floor, "floor");
function!(Round, "rint");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,15 @@ impl<'a> Visitor<'a> {
));
self.insert_variable(out, result);
}
Arithmetic::Trunc(trunc) => {
let value = self.get_variable(trunc.input);
let result = self.append_operation_with_result(llvm_ods::intr_trunc(
self.context,
value,
self.location,
));
self.insert_variable(out, result);
}
Arithmetic::Clamp(clamp) => {
let value = self.get_variable(clamp.input);
let mut min = self.get_variable(clamp.min_value);
Expand Down
2 changes: 2 additions & 0 deletions crates/cubecl-ir/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ pub enum Arithmetic {
Round(UnaryOperator),
Floor(UnaryOperator),
Ceil(UnaryOperator),
Trunc(UnaryOperator),
Erf(UnaryOperator),
Recip(UnaryOperator),
Clamp(ClampOperator),
Expand Down Expand Up @@ -73,6 +74,7 @@ impl Display for Arithmetic {
Arithmetic::Round(op) => write!(f, "{}.round()", op.input),
Arithmetic::Floor(op) => write!(f, "{}.floor()", op.input),
Arithmetic::Ceil(op) => write!(f, "{}.ceil()", op.input),
Arithmetic::Trunc(op) => write!(f, "{}.trunc()", op.input),
Arithmetic::Erf(op) => write!(f, "{}.erf()", op.input),
Arithmetic::Recip(op) => write!(f, "{}.recip()", op.input),
Arithmetic::Clamp(op) => {
Expand Down
3 changes: 3 additions & 0 deletions crates/cubecl-ir/src/processing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ impl ScopeProcessing {
Arithmetic::Ceil(op) => {
sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
}
Arithmetic::Trunc(op) => {
sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
}
Arithmetic::Erf(op) => {
sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
}
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-opt/src/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ impl Optimizer {
| Arithmetic::Round(unary_operator)
| Arithmetic::Floor(unary_operator)
| Arithmetic::Ceil(unary_operator)
| Arithmetic::Trunc(unary_operator)
| Arithmetic::Erf(unary_operator)
| Arithmetic::Recip(unary_operator)
| Arithmetic::Neg(unary_operator)
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-opt/src/passes/constant_prop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ fn try_const_eval_arithmetic(op: &mut Arithmetic) -> Option<ConstantScalarValue>
Arithmetic::Round(op) => const_eval_float!(op.input; num::Float::round),
Arithmetic::Floor(op) => const_eval_float!(op.input; num::Float::floor),
Arithmetic::Ceil(op) => const_eval_float!(op.input; num::Float::ceil),
Arithmetic::Trunc(op) => const_eval_float!(op.input; num::Float::trunc),
Arithmetic::Recip(op) => const_eval_float!(op.input; num::Float::recip),
Arithmetic::Neg(op) => {
use ConstantScalarValue::*;
Expand Down
8 changes: 8 additions & 0 deletions crates/cubecl-spirv/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,14 @@ impl<T: SpirvTarget> SpirvCompiler<T> {
}
})
}
Arithmetic::Trunc(op) => {
self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
T::trunc(b, ty, input, out);
if matches!(out_ty.elem(), Elem::Relaxed) {
b.decorate(out, Decoration::RelaxedPrecision, []);
}
})
}
Arithmetic::Clamp(op) => {
let input = self.compile_variable(op.input);
let min = self.compile_variable(op.min_value);
Expand Down
5 changes: 5 additions & 0 deletions crates/cubecl-spirv/src/extensions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub trait TargetExtensions<T: SpirvTarget> {
fn s_abs(b: &mut SpirvCompiler<T>, ty: Word, input: Word, out: Word);
fn floor(b: &mut SpirvCompiler<T>, ty: Word, input: Word, out: Word);
fn ceil(b: &mut SpirvCompiler<T>, ty: Word, input: Word, out: Word);
fn trunc(b: &mut SpirvCompiler<T>, ty: Word, input: Word, out: Word);
fn sin(b: &mut SpirvCompiler<T>, ty: Word, input: Word, out: Word);
fn cos(b: &mut SpirvCompiler<T>, ty: Word, input: Word, out: Word);
fn tanh(b: &mut SpirvCompiler<T>, ty: Word, input: Word, out: Word);
Expand Down Expand Up @@ -57,6 +58,10 @@ pub mod glcompute {
b.gl_ceil_id(ty, Some(out), input).unwrap();
}

fn trunc(b: &mut SpirvCompiler<T>, ty: Word, input: Word, out: Word) {
b.gl_trunc_id(ty, Some(out), input).unwrap();
}

fn sin(b: &mut SpirvCompiler<T>, ty: Word, input: Word, out: Word) {
b.gl_sin_id(ty, Some(out), input).unwrap();
}
Expand Down
4 changes: 4 additions & 0 deletions crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,10 @@ impl WgslCompiler {
input: self.compile_variable(op.input),
out: self.compile_variable(out),
}),
cube::Arithmetic::Trunc(op) => instructions.push(wgsl::Instruction::Trunc {
input: self.compile_variable(op.input),
out: self.compile_variable(out),
}),
cube::Arithmetic::Erf(op) => {
let mut scope = scope.child();
expand_erf(&mut scope, op.input, out);
Expand Down
8 changes: 8 additions & 0 deletions crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,10 @@ pub enum Instruction {
input: Variable,
out: Variable,
},
Trunc {
input: Variable,
out: Variable,
},
Remainder {
lhs: Variable,
rhs: Variable,
Expand Down Expand Up @@ -854,6 +858,10 @@ for (var {i}: {i_ty} = {start}; {i} {cmp} {end}; {increment}) {{
let out = out.fmt_left();
writeln!(f, "{out} = ceil({input});")
}
Instruction::Trunc { input, out } => {
let out = out.fmt_left();
writeln!(f, "{out} = trunc({input});")
}
Instruction::Subgroup(op) => write!(f, "{op}"),
Instruction::Bitcast { input, out } => {
let elem = out.item();
Expand Down