Skip to content

Commit 991856b

Browse files
WIP: fix(naga): properly impl. auto. type conv. for select
1 parent 78d05f5 commit 991856b

File tree

6 files changed

+117
-45
lines changed

6 files changed

+117
-45
lines changed

naga/src/front/wgsl/lower/mod.rs

+50-4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use alloc::{
44
string::{String, ToString},
55
vec::Vec,
66
};
7+
use arrayvec::ArrayVec;
78
use core::num::NonZeroU32;
89

910
use crate::common::wgsl::{TryToWgsl, TypeContext};
@@ -2478,13 +2479,58 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
24782479
} else {
24792480
match function.name {
24802481
"select" => {
2481-
let mut args = ctx.prepare_args(arguments, 3, span);
2482+
const NUM_ARGS: usize = 3;
2483+
2484+
// TODO: dedupe with `math_function_helper`
24822485

2483-
let reject = self.expression(args.next()?, ctx)?;
2484-
let accept = self.expression(args.next()?, ctx)?;
2485-
let condition = self.expression(args.next()?, ctx)?;
2486+
let mut lowered_arguments = ArrayVec::<_, NUM_ARGS>::new();
2487+
let mut args = ctx.prepare_args(arguments, NUM_ARGS as u32, span);
2488+
2489+
for _ in 0..lowered_arguments.capacity() {
2490+
let lowered = self.expression_for_abstract(args.next()?, ctx)?;
2491+
ctx.grow_types(lowered)?;
2492+
lowered_arguments.push(lowered);
2493+
}
24862494

24872495
args.finish()?;
2496+
let mut lowered_arguments = lowered_arguments.into_inner().unwrap();
2497+
2498+
let fun_overloads = crate::proc::select();
2499+
2500+
#[derive(Debug, Clone, Copy)]
2501+
struct Select;
2502+
2503+
impl TryToWgsl for Select {
2504+
fn try_to_wgsl(self) -> Option<&'static str> {
2505+
Some("select")
2506+
}
2507+
2508+
const DESCRIPTION: &'static str = "`select` built-in";
2509+
}
2510+
let rule = self.resolve_overloads(
2511+
span,
2512+
Select,
2513+
fun_overloads,
2514+
&lowered_arguments,
2515+
ctx,
2516+
)?;
2517+
2518+
self.apply_automatic_conversions_for_call(
2519+
&rule,
2520+
&mut lowered_arguments,
2521+
ctx,
2522+
)?;
2523+
2524+
// If this function returns a predeclared type, register it
2525+
// in `Module::special_types`. The typifier will expect to
2526+
// be able to find it there.
2527+
if let crate::proc::Conclusion::Predeclared(predeclared) =
2528+
rule.conclusion
2529+
{
2530+
ctx.module.generate_predeclared_type(predeclared);
2531+
}
2532+
2533+
let [reject, accept, condition] = lowered_arguments;
24882534

24892535
crate::Expression::Select {
24902536
reject,

naga/src/proc/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ pub use emitter::Emitter;
1919
pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
2020
pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};
2121
pub use namer::{EntryPointIndex, NameKey, Namer};
22-
pub use overloads::{Conclusion, MissingSpecialType, OverloadSet, Rule};
22+
pub use overloads::{select, Conclusion, MissingSpecialType, OverloadSet, Rule};
2323
pub use terminator::ensure_block_returns;
2424
use thiserror::Error;
2525
pub use type_methods::min_max_float_representable_by;

naga/src/proc/overloads/mathfunction.rs

+18-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use crate::proc::overloads::list::List;
55
use crate::proc::overloads::regular::regular;
66
use crate::proc::overloads::utils::{
77
concrete_int_scalars, float_scalars, float_scalars_unimplemented_abstract, list, pairs, rule,
8-
scalar_or_vecn, triples, vector_sizes,
8+
scalar_or_vecn, scalars, triples, vector_sizes,
99
};
1010
use crate::proc::overloads::OverloadSet;
1111

@@ -187,6 +187,23 @@ fn transpose() -> List {
187187
)
188188
}
189189

190+
pub fn select() -> impl OverloadSet {
191+
let bool_arg = |input| match input {
192+
ir::TypeInner::Scalar(_) => ir::TypeInner::Scalar(ir::Scalar::BOOL),
193+
ir::TypeInner::Vector { size, scalar: _ } => ir::TypeInner::Vector {
194+
size,
195+
scalar: ir::Scalar::BOOL,
196+
},
197+
_ => unreachable!(),
198+
};
199+
list(scalars().flat_map(|scalar| {
200+
scalar_or_vecn(scalar).map(|input| {
201+
let bool_arg = bool_arg(input.clone());
202+
rule([input.clone(), input.clone(), bool_arg], input)
203+
})
204+
}))
205+
}
206+
190207
fn extract_bits() -> List {
191208
list(concrete_int_scalars().flat_map(|scalar| {
192209
scalar_or_vecn(scalar).map(|input| {

naga/src/proc/overloads/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -235,3 +235,5 @@ pub trait OverloadSet: Clone {
235235
/// Return an object that can be formatted with [`core::fmt::Debug`].
236236
fn for_debug(&self, types: &crate::UniqueArena<ir::Type>) -> impl fmt::Debug;
237237
}
238+
239+
pub use mathfunction::select;

naga/src/proc/overloads/utils.rs

+19
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,25 @@ pub fn float_scalars() -> impl Iterator<Item = ir::Scalar> + Clone {
3434
.into_iter()
3535
}
3636

37+
/// Produce all [`ir::Scalar`]s.
38+
///
39+
/// Note that `*32` and `F16` must appear before other sizes; this is how we
40+
/// represent conversion rank.
41+
pub fn scalars() -> impl Iterator<Item = ir::Scalar> + Clone {
42+
[
43+
ir::Scalar::ABSTRACT_INT,
44+
ir::Scalar::ABSTRACT_FLOAT,
45+
ir::Scalar::I32,
46+
ir::Scalar::U32,
47+
ir::Scalar::F32,
48+
ir::Scalar::F16,
49+
ir::Scalar::I64,
50+
ir::Scalar::U64,
51+
ir::Scalar::F64,
52+
]
53+
.into_iter()
54+
}
55+
3756
/// Produce all the floating-point [`ir::Scalar`]s, but omit
3857
/// abstract types, for #7405.
3958
pub fn float_scalars_unimplemented_abstract() -> impl Iterator<Item = ir::Scalar> + Clone {

naga/src/valid/expression.rs

+27-39
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ pub enum ExpressionError {
129129
WrongArgumentCount(crate::MathFunction),
130130
#[error("Argument [{1}] to {0:?} as expression {2:?} has an invalid type.")]
131131
InvalidArgumentType(crate::MathFunction, u32, Handle<crate::Expression>),
132+
// TODO: dedupe with above
133+
#[error("Argument [{0}] to `select` as expression {1:?} has an invalid type.")]
134+
InvalidArgumentTypeSelect(u32, Handle<crate::Expression>),
132135
#[error(
133136
"workgroupUniformLoad result type can't be {0:?}. It can only be a constructible type."
134137
)]
@@ -926,47 +929,32 @@ impl super::Validator {
926929
accept,
927930
reject,
928931
} => {
929-
let accept_inner = &resolver[accept];
930-
let reject_inner = &resolver[reject];
931-
let condition_ty = &resolver[condition];
932-
let condition_good = match *condition_ty {
933-
Ti::Scalar(Sc {
934-
kind: Sk::Bool,
935-
width: _,
936-
}) => {
937-
// When `condition` is a single boolean, `accept` and
938-
// `reject` can be vectors or scalars.
939-
match *accept_inner {
940-
Ti::Scalar { .. } | Ti::Vector { .. } => true,
941-
_ => false,
942-
}
932+
// TODO: dedupe with math functions
933+
934+
let mut overloads = crate::proc::select();
935+
log::debug!(
936+
"initial overloads for `select`: {:#?}",
937+
overloads.for_debug(&module.types)
938+
);
939+
940+
for (i, (expr, ty)) in [reject, accept, condition]
941+
.iter()
942+
.copied()
943+
.map(|arg| (arg, &resolver[arg]))
944+
.enumerate()
945+
{
946+
overloads = overloads.arg(i, ty, &module.types);
947+
log::debug!(
948+
"overloads after arg {i}: {:#?}",
949+
overloads.for_debug(&module.types)
950+
);
951+
952+
if overloads.is_empty() {
953+
log::debug!("all overloads eliminated");
954+
return Err(ExpressionError::InvalidArgumentTypeSelect(i as u32, expr));
943955
}
944-
Ti::Vector {
945-
size,
946-
scalar:
947-
Sc {
948-
kind: Sk::Bool,
949-
width: _,
950-
},
951-
} => match *accept_inner {
952-
Ti::Vector {
953-
size: other_size, ..
954-
} => size == other_size,
955-
_ => false,
956-
},
957-
_ => false,
958-
};
959-
if accept_inner != reject_inner {
960-
return Err(ExpressionError::SelectValuesTypeMismatch {
961-
accept: accept_inner.clone(),
962-
reject: reject_inner.clone(),
963-
});
964-
}
965-
if !condition_good {
966-
return Err(ExpressionError::SelectConditionNotABool {
967-
actual: condition_ty.clone(),
968-
});
969956
}
957+
970958
ShaderStages::all()
971959
}
972960
E::Derivative { expr, .. } => {

0 commit comments

Comments
 (0)