Skip to content

Type-erased Specializers #20192

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 23 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 13 additions & 109 deletions crates/bevy_render/macros/src/specializer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ const SPECIALIZE_ALL_IDENT: &str = "all";
const KEY_ATTR_IDENT: &str = "key";
const KEY_DEFAULT_IDENT: &str = "default";

const BASE_DESCRIPTOR_ATTR_IDENT: &str = "base_descriptor";

enum SpecializeImplTargets {
All,
Specific(Vec<Path>),
Expand Down Expand Up @@ -87,7 +85,6 @@ struct FieldInfo {
ty: Type,
member: Member,
key: Key,
use_base_descriptor: bool,
}

impl FieldInfo {
Expand Down Expand Up @@ -117,15 +114,6 @@ impl FieldInfo {
parse_quote!(#ty: #specialize_path::Specializer<#target_path>)
}
}

fn get_base_descriptor_predicate(
&self,
specialize_path: &Path,
target_path: &Path,
) -> WherePredicate {
let ty = &self.ty;
parse_quote!(#ty: #specialize_path::GetBaseDescriptor<#target_path>)
}
}

fn get_field_info(
Expand All @@ -151,12 +139,8 @@ fn get_field_info(

let mut use_key_field = true;
let mut key = Key::Index(key_index);
let mut use_base_descriptor = false;
for attr in &field.attrs {
match &attr.meta {
Meta::Path(path) if path.is_ident(&BASE_DESCRIPTOR_ATTR_IDENT) => {
use_base_descriptor = true;
}
Meta::List(MetaList { path, tokens, .. }) if path.is_ident(&KEY_ATTR_IDENT) => {
let owned_tokens = tokens.clone().into();
let Ok(parsed_key) = syn::parse::<Key>(owned_tokens) else {
Expand Down Expand Up @@ -190,7 +174,6 @@ fn get_field_info(
ty: field_ty,
member: field_member,
key,
use_base_descriptor,
});
}

Expand Down Expand Up @@ -261,41 +244,18 @@ pub fn impl_specializer(input: TokenStream) -> TokenStream {
})
.collect();

let base_descriptor_fields = field_info
.iter()
.filter(|field| field.use_base_descriptor)
.collect::<Vec<_>>();

if base_descriptor_fields.len() > 1 {
return syn::Error::new(
Span::call_site(),
"Too many #[base_descriptor] attributes found. It must be present on exactly one field",
)
.into_compile_error()
.into();
}

let base_descriptor_field = base_descriptor_fields.first().copied();

match targets {
SpecializeImplTargets::All => {
let specialize_impl = impl_specialize_all(
&specialize_path,
&ecs_path,
&ast,
&field_info,
&key_patterns,
&key_tuple_idents,
);
let get_base_descriptor_impl = base_descriptor_field
.map(|field_info| impl_get_base_descriptor_all(&specialize_path, &ast, field_info))
.unwrap_or_default();
[specialize_impl, get_base_descriptor_impl]
.into_iter()
.collect()
}
SpecializeImplTargets::Specific(targets) => {
let specialize_impls = targets.iter().map(|target| {
SpecializeImplTargets::All => impl_specialize_all(
&specialize_path,
&ecs_path,
&ast,
&field_info,
&key_patterns,
&key_tuple_idents,
),
SpecializeImplTargets::Specific(targets) => targets
.iter()
.map(|target| {
impl_specialize_specific(
&specialize_path,
&ecs_path,
Expand All @@ -305,14 +265,8 @@ pub fn impl_specializer(input: TokenStream) -> TokenStream {
&key_patterns,
&key_tuple_idents,
)
});
let get_base_descriptor_impls = targets.iter().filter_map(|target| {
base_descriptor_field.map(|field_info| {
impl_get_base_descriptor_specific(&specialize_path, &ast, field_info, target)
})
});
specialize_impls.chain(get_base_descriptor_impls).collect()
}
})
.collect(),
}
}

Expand Down Expand Up @@ -406,56 +360,6 @@ fn impl_specialize_specific(
})
}

fn impl_get_base_descriptor_specific(
specialize_path: &Path,
ast: &DeriveInput,
base_descriptor_field_info: &FieldInfo,
target_path: &Path,
) -> TokenStream {
let struct_name = &ast.ident;
let (impl_generics, type_generics, where_clause) = &ast.generics.split_for_impl();
let field_ty = &base_descriptor_field_info.ty;
let field_member = &base_descriptor_field_info.member;
TokenStream::from(quote!(
impl #impl_generics #specialize_path::GetBaseDescriptor<#target_path> for #struct_name #type_generics #where_clause {
fn get_base_descriptor(&self) -> <#target_path as #specialize_path::Specializable>::Descriptor {
<#field_ty as #specialize_path::GetBaseDescriptor<#target_path>>::get_base_descriptor(&self.#field_member)
}
}
))
}

fn impl_get_base_descriptor_all(
specialize_path: &Path,
ast: &DeriveInput,
base_descriptor_field_info: &FieldInfo,
) -> TokenStream {
let target_path = Path::from(format_ident!("T"));
let struct_name = &ast.ident;
let mut generics = ast.generics.clone();
generics.params.insert(
0,
parse_quote!(#target_path: #specialize_path::Specializable),
);

let where_clause = generics.make_where_clause();
where_clause.predicates.push(
base_descriptor_field_info.get_base_descriptor_predicate(specialize_path, &target_path),
);

let (_, type_generics, _) = ast.generics.split_for_impl();
let (impl_generics, _, where_clause) = &generics.split_for_impl();
let field_ty = &base_descriptor_field_info.ty;
let field_member = &base_descriptor_field_info.member;
TokenStream::from(quote! {
impl #impl_generics #specialize_path::GetBaseDescriptor<#target_path> for #struct_name #type_generics #where_clause {
fn get_base_descriptor(&self) -> <#target_path as #specialize_path::Specializable>::Descriptor {
<#field_ty as #specialize_path::GetBaseDescriptor<#target_path>>::get_base_descriptor(&self.#field_member)
}
}
})
}

pub fn impl_specializer_key(input: TokenStream) -> TokenStream {
let bevy_render_path: Path = crate::bevy_render_path();
let specialize_path = {
Expand Down
3 changes: 3 additions & 0 deletions crates/bevy_render/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ use render_asset::{
extract_render_asset_bytes_per_frame, reset_render_asset_bytes_per_frame,
RenderAssetBytesPerFrame, RenderAssetBytesPerFrameLimiter,
};
use render_resource::init_empty_bind_group_layout;
use renderer::{RenderAdapter, RenderDevice, RenderQueue};
use settings::RenderResources;
use sync_world::{
Expand Down Expand Up @@ -465,6 +466,8 @@ impl Plugin for RenderPlugin {
Render,
reset_render_asset_bytes_per_frame.in_set(RenderSystems::Cleanup),
);

render_app.add_systems(RenderStartup, init_empty_bind_group_layout);
}

app.register_type::<alpha::AlphaMode>()
Expand Down
20 changes: 19 additions & 1 deletion crates/bevy_render/src/render_resource/bind_group_layout.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use crate::define_atomic_id;
use crate::{define_atomic_id, renderer::RenderDevice};
use bevy_ecs::system::Res;
use bevy_platform::sync::OnceLock;
use bevy_utils::WgpuWrapper;
use core::ops::Deref;

Expand Down Expand Up @@ -62,3 +64,19 @@ impl Deref for BindGroupLayout {
&self.value
}
}

static EMPTY_BIND_GROUP_LAYOUT: OnceLock<BindGroupLayout> = OnceLock::new();

pub(crate) fn init_empty_bind_group_layout(render_device: Res<RenderDevice>) {
let layout = render_device.create_bind_group_layout(Some("empty_bind_group_layout"), &[]);
EMPTY_BIND_GROUP_LAYOUT
.set(layout)
.expect("init_empty_bind_group_layout was called more than once");
}

pub fn empty_bind_group_layout() -> BindGroupLayout {
EMPTY_BIND_GROUP_LAYOUT
.get()
.expect("init_empty_bind_group_layout was not called")
.clone()
}
32 changes: 31 additions & 1 deletion crates/bevy_render/src/render_resource/pipeline.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::ShaderDefVal;
use super::{empty_bind_group_layout, ShaderDefVal};
use crate::mesh::VertexBufferLayout;
use crate::{
define_atomic_id,
Expand All @@ -7,7 +7,9 @@ use crate::{
use alloc::borrow::Cow;
use bevy_asset::Handle;
use bevy_utils::WgpuWrapper;
use core::iter;
use core::ops::Deref;
use thiserror::Error;
use wgpu::{
ColorTargetState, DepthStencilState, MultisampleState, PrimitiveState, PushConstantRange,
};
Expand Down Expand Up @@ -112,6 +114,20 @@ pub struct RenderPipelineDescriptor {
pub zero_initialize_workgroup_memory: bool,
}

#[derive(Copy, Clone, Debug, Error)]
#[error("RenderPipelineDescriptor has no FragmentState configured")]
pub struct NoFragmentStateError;

impl RenderPipelineDescriptor {
pub fn fragment_mut(&mut self) -> Result<&mut FragmentState, NoFragmentStateError> {
self.fragment.as_mut().ok_or(NoFragmentStateError)
}

pub fn set_layout(&mut self, index: usize, layout: BindGroupLayout) {
filling_set_at(&mut self.layout, index, empty_bind_group_layout(), layout);
}
}

#[derive(Clone, Debug, Eq, PartialEq, Default)]
pub struct VertexState {
/// The compiled shader module for this stage.
Expand All @@ -137,6 +153,12 @@ pub struct FragmentState {
pub targets: Vec<Option<ColorTargetState>>,
}

impl FragmentState {
pub fn set_target(&mut self, index: usize, target: ColorTargetState) {
filling_set_at(&mut self.targets, index, None, Some(target));
}
}

/// Describes a compute pipeline.
#[derive(Clone, Debug, PartialEq, Eq, Default)]
pub struct ComputePipelineDescriptor {
Expand All @@ -153,3 +175,11 @@ pub struct ComputePipelineDescriptor {
/// If this is false, reading from workgroup variables before writing to them will result in garbage values.
pub zero_initialize_workgroup_memory: bool,
}

// utility function to set a value at the specified index, extending with
// a filler value if the index is out of bounds.
fn filling_set_at<T: Clone>(vec: &mut Vec<T>, index: usize, filler: T, value: T) {
let num_to_fill = (index + 1).saturating_sub(vec.len());
vec.extend(iter::repeat_n(filler, num_to_fill));
vec[index] = value;
}
Loading
Loading