From 8afb3ceb89a7b97bf5a46737abbe98306da14c8f Mon Sep 17 00:00:00 2001 From: HugoPeters1024 Date: Sun, 21 Jan 2024 19:47:13 +0100 Subject: [PATCH] add `storage_texture` option to as_bind_group macro (#9943) # Objective - Add the ability to describe storage texture bindings when deriving `AsBindGroup`. - This is especially valuable for the compute story of bevy which deserves some extra love imo. ## Solution - This add the ability to annotate struct fields with a `#[storage_texture(0)]` annotation. - Instead of adding specific option parsing for all the image formats and access modes, I simply accept a token stream and defer checking to see if the option is valid to the compiler. This still results in useful and friendly errors and is free to maintain and always compatible with wgpu changes. --- ## Changelog - The `#[storage_texture(..)]` annotation is now accepted for fields of `Handle` in structs that derive `AsBindGroup`. - The game_of_life compute shader example has been updated to use `AsBindGroup` together with `[storage_texture(..)]` to obtain the `BindGroupLayout`. ## Migration Guide --- .../bevy_render/macros/src/as_bind_group.rs | 119 ++++++++++++++++++ crates/bevy_render/macros/src/lib.rs | 2 +- .../src/render_resource/bind_group.rs | 16 +++ .../shader/compute_shader_game_of_life.rs | 22 ++-- 4 files changed, 146 insertions(+), 13 deletions(-) diff --git a/crates/bevy_render/macros/src/as_bind_group.rs b/crates/bevy_render/macros/src/as_bind_group.rs index 9e750fd57ed77..b1c7a12e0b0bc 100644 --- a/crates/bevy_render/macros/src/as_bind_group.rs +++ b/crates/bevy_render/macros/src/as_bind_group.rs @@ -11,6 +11,7 @@ use syn::{ const UNIFORM_ATTRIBUTE_NAME: Symbol = Symbol("uniform"); const TEXTURE_ATTRIBUTE_NAME: Symbol = Symbol("texture"); +const STORAGE_TEXTURE_ATTRIBUTE_NAME: Symbol = Symbol("storage_texture"); const SAMPLER_ATTRIBUTE_NAME: Symbol = Symbol("sampler"); const STORAGE_ATTRIBUTE_NAME: Symbol = Symbol("storage"); const BIND_GROUP_DATA_ATTRIBUTE_NAME: Symbol = Symbol("bind_group_data"); @@ -19,6 +20,7 @@ const BIND_GROUP_DATA_ATTRIBUTE_NAME: Symbol = Symbol("bind_group_data"); enum BindingType { Uniform, Texture, + StorageTexture, Sampler, Storage, } @@ -133,6 +135,8 @@ pub fn derive_as_bind_group(ast: syn::DeriveInput) -> Result { BindingType::Uniform } else if attr_ident == TEXTURE_ATTRIBUTE_NAME { BindingType::Texture + } else if attr_ident == STORAGE_TEXTURE_ATTRIBUTE_NAME { + BindingType::StorageTexture } else if attr_ident == SAMPLER_ATTRIBUTE_NAME { BindingType::Sampler } else if attr_ident == STORAGE_ATTRIBUTE_NAME { @@ -255,6 +259,45 @@ pub fn derive_as_bind_group(ast: syn::DeriveInput) -> Result { } }); } + BindingType::StorageTexture => { + let StorageTextureAttrs { + dimension, + image_format, + access, + visibility, + } = get_storage_texture_binding_attr(nested_meta_items)?; + + let visibility = + visibility.hygienic_quote("e! { #render_path::render_resource }); + + let fallback_image = get_fallback_image(&render_path, dimension); + + binding_impls.push(quote! { + ( #binding_index, + #render_path::render_resource::OwnedBindingResource::TextureView({ + let handle: Option<&#asset_path::Handle<#render_path::texture::Image>> = (&self.#field_name).into(); + if let Some(handle) = handle { + images.get(handle).ok_or_else(|| #render_path::render_resource::AsBindGroupError::RetryNextUpdate)?.texture_view.clone() + } else { + #fallback_image.texture_view.clone() + } + }) + ) + }); + + binding_layouts.push(quote! { + #render_path::render_resource::BindGroupLayoutEntry { + binding: #binding_index, + visibility: #visibility, + ty: #render_path::render_resource::BindingType::StorageTexture { + access: #render_path::render_resource::StorageTextureAccess::#access, + format: #render_path::render_resource::TextureFormat::#image_format, + view_dimension: #render_path::render_resource::#dimension, + }, + count: None, + } + }); + } BindingType::Texture => { let TextureAttrs { dimension, @@ -585,6 +628,10 @@ impl ShaderStageVisibility { fn vertex_fragment() -> Self { Self::Flags(VisibilityFlags::vertex_fragment()) } + + fn compute() -> Self { + Self::Flags(VisibilityFlags::compute()) + } } impl VisibilityFlags { @@ -595,6 +642,13 @@ impl VisibilityFlags { ..Default::default() } } + + fn compute() -> Self { + Self { + compute: true, + ..Default::default() + } + } } impl ShaderStageVisibility { @@ -741,7 +795,72 @@ impl Default for TextureAttrs { } } +struct StorageTextureAttrs { + dimension: BindingTextureDimension, + // Parsing of the image_format parameter is deferred to the type checker, + // which will error if the format is not member of the TextureFormat enum. + image_format: proc_macro2::TokenStream, + // Parsing of the access parameter is deferred to the type checker, + // which will error if the access is not member of the StorageTextureAccess enum. + access: proc_macro2::TokenStream, + visibility: ShaderStageVisibility, +} + +impl Default for StorageTextureAttrs { + fn default() -> Self { + Self { + dimension: Default::default(), + image_format: quote! { Rgba8Unorm }, + access: quote! { ReadWrite }, + visibility: ShaderStageVisibility::compute(), + } + } +} + +fn get_storage_texture_binding_attr(metas: Vec) -> Result { + let mut storage_texture_attrs = StorageTextureAttrs::default(); + + for meta in metas { + use syn::Meta::{List, NameValue}; + match meta { + // Parse #[storage_texture(0, dimension = "...")]. + NameValue(m) if m.path == DIMENSION => { + let value = get_lit_str(DIMENSION, &m.value)?; + storage_texture_attrs.dimension = get_texture_dimension_value(value)?; + } + // Parse #[storage_texture(0, format = ...))]. + NameValue(m) if m.path == IMAGE_FORMAT => { + storage_texture_attrs.image_format = m.value.into_token_stream(); + } + // Parse #[storage_texture(0, access = ...))]. + NameValue(m) if m.path == ACCESS => { + storage_texture_attrs.access = m.value.into_token_stream(); + } + // Parse #[storage_texture(0, visibility(...))]. + List(m) if m.path == VISIBILITY => { + storage_texture_attrs.visibility = get_visibility_flag_value(&m)?; + } + NameValue(m) => { + return Err(Error::new_spanned( + m.path, + "Not a valid name. Available attributes: `dimension`, `image_format`, `access`.", + )); + } + _ => { + return Err(Error::new_spanned( + meta, + "Not a name value pair: `foo = \"...\"`", + )); + } + } + } + + Ok(storage_texture_attrs) +} + const DIMENSION: Symbol = Symbol("dimension"); +const IMAGE_FORMAT: Symbol = Symbol("image_format"); +const ACCESS: Symbol = Symbol("access"); const SAMPLE_TYPE: Symbol = Symbol("sample_type"); const FILTERABLE: Symbol = Symbol("filterable"); const MULTISAMPLED: Symbol = Symbol("multisampled"); diff --git a/crates/bevy_render/macros/src/lib.rs b/crates/bevy_render/macros/src/lib.rs index 89eec6b220c9a..97126ba830bf4 100644 --- a/crates/bevy_render/macros/src/lib.rs +++ b/crates/bevy_render/macros/src/lib.rs @@ -51,7 +51,7 @@ pub fn derive_extract_component(input: TokenStream) -> TokenStream { #[proc_macro_derive( AsBindGroup, - attributes(uniform, texture, sampler, bind_group_data, storage) + attributes(uniform, storage_texture, texture, sampler, bind_group_data, storage) )] pub fn derive_as_bind_group(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); diff --git a/crates/bevy_render/src/render_resource/bind_group.rs b/crates/bevy_render/src/render_resource/bind_group.rs index fc4f5c0d7d608..03e98abbd8818 100644 --- a/crates/bevy_render/src/render_resource/bind_group.rs +++ b/crates/bevy_render/src/render_resource/bind_group.rs @@ -87,6 +87,8 @@ impl Deref for BindGroup { /// values: Vec, /// #[storage(4, read_only, buffer)] /// buffer: Buffer, +/// #[storage_texture(5)] +/// storage_texture: Handle, /// } /// ``` /// @@ -97,6 +99,7 @@ impl Deref for BindGroup { /// @group(2) @binding(1) var color_texture: texture_2d; /// @group(2) @binding(2) var color_sampler: sampler; /// @group(2) @binding(3) var values: array; +/// @group(2) @binding(5) var storage_texture: texture_storage_2d; /// ``` /// Note that the "group" index is determined by the usage context. It is not defined in [`AsBindGroup`]. For example, in Bevy material bind groups /// are generally bound to group 2. @@ -123,6 +126,19 @@ impl Deref for BindGroup { /// | `multisampled` = ... | `true`, `false` | `false` | /// | `visibility(...)` | `all`, `none`, or a list-combination of `vertex`, `fragment`, `compute` | `vertex`, `fragment` | /// +/// * `storage_texture(BINDING_INDEX, arguments)` +/// * This field's [`Handle`](bevy_asset::Handle) will be used to look up the matching [`Texture`](crate::render_resource::Texture) +/// GPU resource, which will be bound as a storage texture in shaders. The field will be assumed to implement [`Into>>`]. In practice, +/// most fields should be a [`Handle`](bevy_asset::Handle) or [`Option>`]. If the value of an [`Option>`] is +/// [`None`], the [`FallbackImage`] resource will be used instead. +/// +/// | Arguments | Values | Default | +/// |------------------------|--------------------------------------------------------------------------------------------|---------------| +/// | `dimension` = "..." | `"1d"`, `"2d"`, `"2d_array"`, `"3d"`, `"cube"`, `"cube_array"` | `"2d"` | +/// | `image_format` = ... | any member of [`TextureFormat`](crate::render_resource::TextureFormat) | `Rgba8Unorm` | +/// | `access` = ... | any member of [`StorageTextureAccess`](crate::render_resource::StorageTextureAccess) | `ReadWrite` | +/// | `visibility(...)` | `all`, `none`, or a list-combination of `vertex`, `fragment`, `compute` | `compute` | +/// /// * `sampler(BINDING_INDEX, arguments)` /// * This field's [`Handle`](bevy_asset::Handle) will be used to look up the matching [`Sampler`] GPU /// resource, which will be bound as a sampler in shaders. The field will be assumed to implement [`Into>>`]. In practice, diff --git a/examples/shader/compute_shader_game_of_life.rs b/examples/shader/compute_shader_game_of_life.rs index df16d67b178a4..8248e4e9b14f3 100644 --- a/examples/shader/compute_shader_game_of_life.rs +++ b/examples/shader/compute_shader_game_of_life.rs @@ -10,7 +10,7 @@ use bevy::{ render_asset::RenderAssetPersistencePolicy, render_asset::RenderAssets, render_graph::{self, RenderGraph}, - render_resource::{binding_types::texture_storage_2d, *}, + render_resource::*, renderer::{RenderContext, RenderDevice}, Render, RenderApp, RenderSet, }, @@ -65,7 +65,7 @@ fn setup(mut commands: Commands, mut images: ResMut>) { }); commands.spawn(Camera2dBundle::default()); - commands.insert_resource(GameOfLifeImage(image)); + commands.insert_resource(GameOfLifeImage { texture: image }); } pub struct GameOfLifeComputePlugin; @@ -95,8 +95,11 @@ impl Plugin for GameOfLifeComputePlugin { } } -#[derive(Resource, Clone, Deref, ExtractResource)] -struct GameOfLifeImage(Handle); +#[derive(Resource, Clone, Deref, ExtractResource, AsBindGroup)] +struct GameOfLifeImage { + #[storage_texture(0, image_format = Rgba8Unorm, access = ReadWrite)] + texture: Handle, +} #[derive(Resource)] struct GameOfLifeImageBindGroup(BindGroup); @@ -108,7 +111,7 @@ fn prepare_bind_group( game_of_life_image: Res, render_device: Res, ) { - let view = gpu_images.get(&game_of_life_image.0).unwrap(); + let view = gpu_images.get(&game_of_life_image.texture).unwrap(); let bind_group = render_device.create_bind_group( None, &pipeline.texture_bind_group_layout, @@ -126,13 +129,8 @@ pub struct GameOfLifePipeline { impl FromWorld for GameOfLifePipeline { fn from_world(world: &mut World) -> Self { - let texture_bind_group_layout = world.resource::().create_bind_group_layout( - None, - &BindGroupLayoutEntries::single( - ShaderStages::COMPUTE, - texture_storage_2d(TextureFormat::Rgba8Unorm, StorageTextureAccess::ReadWrite), - ), - ); + let render_device = world.resource::(); + let texture_bind_group_layout = GameOfLifeImage::bind_group_layout(render_device); let shader = world .resource::() .load("shaders/game_of_life.wgsl");