Skip to content

Commit 848def1

Browse files
committed
Potentially optimize dot4{I,U}8Packed on Metal
This might allow the Metal compiler to emit faster code, but that's not confirmed. See <gpuweb/gpuweb#2677 (comment)>.
1 parent f44bd46 commit 848def1

File tree

2 files changed

+106
-33
lines changed

2 files changed

+106
-33
lines changed

naga/src/back/msl/writer.rs

+94-29
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,9 @@ const fn scalar_is_int(scalar: crate::Scalar) -> bool {
121121
/// Prefix for cached clamped level-of-detail values for `ImageLoad` expressions.
122122
const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";
123123

124+
/// Prefix for the names of variables that hold expressions reinterpreted via `as_type<T>(...)`.
125+
const REINTERPRET_PREFIX: &str = "reinterpreted_";
126+
124127
/// Wrapper for identifier names for clamped level-of-detail values
125128
///
126129
/// Values of this type implement [`core::fmt::Display`], formatting as
@@ -156,6 +159,30 @@ impl Display for ArraySizeMember {
156159
}
157160
}
158161

162+
/// Wrapper for identifier names of variables reinterpreted using `as_type<target_type>(orig)`.
163+
///
164+
/// Implements [`core::fmt::Display`], formatting as a name derived from
165+
/// `target_type` and the variable name of `orig`.
166+
#[derive(Clone, Copy)]
167+
struct Reinterpreted<'a> {
168+
target_type: &'a str,
169+
orig: Handle<crate::Expression>,
170+
}
171+
172+
impl<'a> Reinterpreted<'a> {
173+
const fn new(target_type: &'a str, orig: Handle<crate::Expression>) -> Self {
174+
Self { target_type, orig }
175+
}
176+
}
177+
178+
impl Display for Reinterpreted<'_> {
179+
fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
180+
f.write_str(REINTERPRET_PREFIX)?;
181+
f.write_str(self.target_type)?;
182+
self.orig.write_prefixed(f, "_e")
183+
}
184+
}
185+
159186
struct TypeContext<'a> {
160187
handle: Handle<crate::Type>,
161188
gctx: proc::GlobalCtx<'a>,
@@ -1470,14 +1497,14 @@ impl<W: Write> Writer<W> {
14701497

14711498
/// Emit code for the arithmetic expression of the dot product.
14721499
///
1473-
/// The argument `extractor` is a function that accepts a `Writer`, a handle to a vector,
1474-
/// and an index. writes out the expression for the component at that index.
1475-
fn put_dot_product(
1500+
/// The argument `extractor` is a function that accepts a `Writer`, a vector, and
1501+
/// an index. It writes out the expression for the vector component at that index.
1502+
fn put_dot_product<T: Copy>(
14761503
&mut self,
1477-
arg: Handle<crate::Expression>,
1478-
arg1: Handle<crate::Expression>,
1504+
arg: T,
1505+
arg1: T,
14791506
size: usize,
1480-
extractor: impl Fn(&mut Self, Handle<crate::Expression>, usize) -> BackendResult,
1507+
extractor: impl Fn(&mut Self, T, usize) -> BackendResult,
14811508
) -> BackendResult {
14821509
// Write parentheses around the dot product expression to prevent operators
14831510
// with different precedences from applying earlier.
@@ -2206,24 +2233,22 @@ impl<W: Write> Writer<W> {
22062233
),
22072234
},
22082235
fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
2209-
let conversion = match fun {
2210-
Mf::Dot4I8Packed => "int",
2211-
Mf::Dot4U8Packed => "",
2236+
// The two function arguments were already reinterpreted as packed (signed
2237+
// or unsigned) chars in `Self::put_block`.
2238+
let packed_type = match fun {
2239+
Mf::Dot4I8Packed => "packed_char4",
2240+
Mf::Dot4U8Packed => "packed_uchar4",
22122241
_ => unreachable!(),
22132242
};
22142243

22152244
return self.put_dot_product(
2216-
arg,
2217-
arg1.unwrap(),
2245+
Reinterpreted::new(packed_type, arg),
2246+
Reinterpreted::new(packed_type, arg1.unwrap()),
22182247
4,
22192248
|writer, arg, index| {
2220-
write!(writer.out, "({}(", conversion)?;
2221-
writer.put_expression(arg, context, true)?;
2222-
if index == 3 {
2223-
write!(writer.out, ") >> 24)")?;
2224-
} else {
2225-
write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?;
2226-
}
2249+
// MSL implicitly promotes these (signed or unsigned) chars to
2250+
// `int` or `uint` in the multiplication, so no overflow can occur.
2251+
write!(writer.out, "{arg}[{index}]")?;
22272252
Ok(())
22282253
},
22292254
);
@@ -3362,17 +3387,57 @@ impl<W: Write> Writer<W> {
33623387
match *statement {
33633388
crate::Statement::Emit(ref range) => {
33643389
for handle in range.clone() {
3365-
// `ImageLoad` expressions covered by the `Restrict` bounds check policy
3366-
// may need to cache a clamped version of their level-of-detail argument.
3367-
if let crate::Expression::ImageLoad {
3368-
image,
3369-
level: mip_level,
3370-
..
3371-
} = context.expression.function.expressions[handle]
3372-
{
3373-
self.put_cache_restricted_level(
3374-
handle, image, mip_level, level, context,
3375-
)?;
3390+
use crate::MathFunction as Mf;
3391+
3392+
match context.expression.function.expressions[handle] {
3393+
// `ImageLoad` expressions covered by the `Restrict` bounds check policy
3394+
// may need to cache a clamped version of their level-of-detail argument.
3395+
crate::Expression::ImageLoad {
3396+
image,
3397+
level: mip_level,
3398+
..
3399+
} => {
3400+
self.put_cache_restricted_level(
3401+
handle, image, mip_level, level, context,
3402+
)?;
3403+
}
3404+
3405+
// If we are going to write a `Dot4I8Packed` or `Dot4U8Packed` then we
3406+
// introduce two intermediate variables that recast the two arguments as
3407+
// packed (signed or unsigned) chars. The actual dot product is
3408+
// implemented in `Self::put_expression`, and it uses both of these
3409+
// intermediate variables multiple times.
3410+
crate::Expression::Math {
3411+
fun: fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed),
3412+
arg,
3413+
arg1,
3414+
..
3415+
} => {
3416+
let arg1 = arg1.unwrap();
3417+
let packed_type = match fun {
3418+
Mf::Dot4I8Packed => "packed_char4",
3419+
Mf::Dot4U8Packed => "packed_uchar4",
3420+
_ => unreachable!(),
3421+
};
3422+
3423+
write!(
3424+
self.out,
3425+
"{level}{packed_type} {0} = as_type<{packed_type}>(",
3426+
Reinterpreted::new(packed_type, arg)
3427+
)?;
3428+
self.put_expression(arg, &context.expression, true)?;
3429+
writeln!(self.out, ");")?;
3430+
3431+
write!(
3432+
self.out,
3433+
"{level}{packed_type} {0} = as_type<{packed_type}>(",
3434+
Reinterpreted::new(packed_type, arg1)
3435+
)?;
3436+
self.put_expression(arg1, &context.expression, true)?;
3437+
writeln!(self.out, ");")?;
3438+
}
3439+
3440+
_ => (),
33763441
}
33773442

33783443
let ptr_class = context.expression.resolve_type(handle).pointer_space();

naga/tests/out/msl/wgsl-functions.msl

+12-4
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,22 @@ int test_integer_dot_product(
2929

3030
uint test_packed_integer_dot_product(
3131
) {
32-
int c_5_ = ( + (int(1u) << 24 >> 24) * (int(2u) << 24 >> 24) + (int(1u) << 16 >> 24) * (int(2u) << 16 >> 24) + (int(1u) << 8 >> 24) * (int(2u) << 8 >> 24) + (int(1u) >> 24) * (int(2u) >> 24));
33-
uint c_6_ = ( + ((3u) << 24 >> 24) * ((4u) << 24 >> 24) + ((3u) << 16 >> 24) * ((4u) << 16 >> 24) + ((3u) << 8 >> 24) * ((4u) << 8 >> 24) + ((3u) >> 24) * ((4u) >> 24));
32+
packed_char4 reinterpreted_packed_char4_e0 = as_type<packed_char4>(1u);
33+
packed_char4 reinterpreted_packed_char4_e1 = as_type<packed_char4>(2u);
34+
int c_5_ = ( + reinterpreted_packed_char4_e0[0] * reinterpreted_packed_char4_e1[0] + reinterpreted_packed_char4_e0[1] * reinterpreted_packed_char4_e1[1] + reinterpreted_packed_char4_e0[2] * reinterpreted_packed_char4_e1[2] + reinterpreted_packed_char4_e0[3] * reinterpreted_packed_char4_e1[3]);
35+
packed_uchar4 reinterpreted_packed_uchar4_e3 = as_type<packed_uchar4>(3u);
36+
packed_uchar4 reinterpreted_packed_uchar4_e4 = as_type<packed_uchar4>(4u);
37+
uint c_6_ = ( + reinterpreted_packed_uchar4_e3[0] * reinterpreted_packed_uchar4_e4[0] + reinterpreted_packed_uchar4_e3[1] * reinterpreted_packed_uchar4_e4[1] + reinterpreted_packed_uchar4_e3[2] * reinterpreted_packed_uchar4_e4[2] + reinterpreted_packed_uchar4_e3[3] * reinterpreted_packed_uchar4_e4[3]);
3438
uint _e7 = 5u + c_6_;
3539
uint _e9 = 6u + c_6_;
36-
int c_7_ = ( + (int(_e7) << 24 >> 24) * (int(_e9) << 24 >> 24) + (int(_e7) << 16 >> 24) * (int(_e9) << 16 >> 24) + (int(_e7) << 8 >> 24) * (int(_e9) << 8 >> 24) + (int(_e7) >> 24) * (int(_e9) >> 24));
40+
packed_char4 reinterpreted_packed_char4_e7 = as_type<packed_char4>(_e7);
41+
packed_char4 reinterpreted_packed_char4_e9 = as_type<packed_char4>(_e9);
42+
int c_7_ = ( + reinterpreted_packed_char4_e7[0] * reinterpreted_packed_char4_e9[0] + reinterpreted_packed_char4_e7[1] * reinterpreted_packed_char4_e9[1] + reinterpreted_packed_char4_e7[2] * reinterpreted_packed_char4_e9[2] + reinterpreted_packed_char4_e7[3] * reinterpreted_packed_char4_e9[3]);
3743
uint _e12 = 7u + c_6_;
3844
uint _e14 = 8u + c_6_;
39-
uint c_8_ = ( + ((_e12) << 24 >> 24) * ((_e14) << 24 >> 24) + ((_e12) << 16 >> 24) * ((_e14) << 16 >> 24) + ((_e12) << 8 >> 24) * ((_e14) << 8 >> 24) + ((_e12) >> 24) * ((_e14) >> 24));
45+
packed_uchar4 reinterpreted_packed_uchar4_e12 = as_type<packed_uchar4>(_e12);
46+
packed_uchar4 reinterpreted_packed_uchar4_e14 = as_type<packed_uchar4>(_e14);
47+
uint c_8_ = ( + reinterpreted_packed_uchar4_e12[0] * reinterpreted_packed_uchar4_e14[0] + reinterpreted_packed_uchar4_e12[1] * reinterpreted_packed_uchar4_e14[1] + reinterpreted_packed_uchar4_e12[2] * reinterpreted_packed_uchar4_e14[2] + reinterpreted_packed_uchar4_e12[3] * reinterpreted_packed_uchar4_e14[3]);
4048
return c_8_;
4149
}
4250

0 commit comments

Comments
 (0)