Skip to content

Commit

Permalink
add storage_texture option to as_bind_group macro (#9943)
Browse files Browse the repository at this point in the history
# Objective

- Add the ability to describe storage texture bindings when deriving
`AsBindGroup`.
- This is especially valuable for the compute story of bevy which
deserves some extra love imo.

## Solution

- This add the ability to annotate struct fields with a
`#[storage_texture(0)]` annotation.
- Instead of adding specific option parsing for all the image formats
and access modes, I simply accept a token stream and defer checking to
see if the option is valid to the compiler. This still results in useful
and friendly errors and is free to maintain and always compatible with
wgpu changes.

---

## Changelog

- The `#[storage_texture(..)]` annotation is now accepted for fields of
`Handle<Image>` in structs that derive `AsBindGroup`.
- The game_of_life compute shader example has been updated to use
`AsBindGroup` together with `[storage_texture(..)]` to obtain the
`BindGroupLayout`.

## Migration Guide
  • Loading branch information
HugoPeters1024 authored Jan 21, 2024
1 parent 0fa14c8 commit 8afb3ce
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 13 deletions.
119 changes: 119 additions & 0 deletions crates/bevy_render/macros/src/as_bind_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use syn::{

const UNIFORM_ATTRIBUTE_NAME: Symbol = Symbol("uniform");
const TEXTURE_ATTRIBUTE_NAME: Symbol = Symbol("texture");
const STORAGE_TEXTURE_ATTRIBUTE_NAME: Symbol = Symbol("storage_texture");
const SAMPLER_ATTRIBUTE_NAME: Symbol = Symbol("sampler");
const STORAGE_ATTRIBUTE_NAME: Symbol = Symbol("storage");
const BIND_GROUP_DATA_ATTRIBUTE_NAME: Symbol = Symbol("bind_group_data");
Expand All @@ -19,6 +20,7 @@ const BIND_GROUP_DATA_ATTRIBUTE_NAME: Symbol = Symbol("bind_group_data");
enum BindingType {
Uniform,
Texture,
StorageTexture,
Sampler,
Storage,
}
Expand Down Expand Up @@ -133,6 +135,8 @@ pub fn derive_as_bind_group(ast: syn::DeriveInput) -> Result<TokenStream> {
BindingType::Uniform
} else if attr_ident == TEXTURE_ATTRIBUTE_NAME {
BindingType::Texture
} else if attr_ident == STORAGE_TEXTURE_ATTRIBUTE_NAME {
BindingType::StorageTexture
} else if attr_ident == SAMPLER_ATTRIBUTE_NAME {
BindingType::Sampler
} else if attr_ident == STORAGE_ATTRIBUTE_NAME {
Expand Down Expand Up @@ -255,6 +259,45 @@ pub fn derive_as_bind_group(ast: syn::DeriveInput) -> Result<TokenStream> {
}
});
}
BindingType::StorageTexture => {
let StorageTextureAttrs {
dimension,
image_format,
access,
visibility,
} = get_storage_texture_binding_attr(nested_meta_items)?;

let visibility =
visibility.hygienic_quote(&quote! { #render_path::render_resource });

let fallback_image = get_fallback_image(&render_path, dimension);

binding_impls.push(quote! {
( #binding_index,
#render_path::render_resource::OwnedBindingResource::TextureView({
let handle: Option<&#asset_path::Handle<#render_path::texture::Image>> = (&self.#field_name).into();
if let Some(handle) = handle {
images.get(handle).ok_or_else(|| #render_path::render_resource::AsBindGroupError::RetryNextUpdate)?.texture_view.clone()
} else {
#fallback_image.texture_view.clone()
}
})
)
});

binding_layouts.push(quote! {
#render_path::render_resource::BindGroupLayoutEntry {
binding: #binding_index,
visibility: #visibility,
ty: #render_path::render_resource::BindingType::StorageTexture {
access: #render_path::render_resource::StorageTextureAccess::#access,
format: #render_path::render_resource::TextureFormat::#image_format,
view_dimension: #render_path::render_resource::#dimension,
},
count: None,
}
});
}
BindingType::Texture => {
let TextureAttrs {
dimension,
Expand Down Expand Up @@ -585,6 +628,10 @@ impl ShaderStageVisibility {
fn vertex_fragment() -> Self {
Self::Flags(VisibilityFlags::vertex_fragment())
}

fn compute() -> Self {
Self::Flags(VisibilityFlags::compute())
}
}

impl VisibilityFlags {
Expand All @@ -595,6 +642,13 @@ impl VisibilityFlags {
..Default::default()
}
}

fn compute() -> Self {
Self {
compute: true,
..Default::default()
}
}
}

impl ShaderStageVisibility {
Expand Down Expand Up @@ -741,7 +795,72 @@ impl Default for TextureAttrs {
}
}

struct StorageTextureAttrs {
dimension: BindingTextureDimension,
// Parsing of the image_format parameter is deferred to the type checker,
// which will error if the format is not member of the TextureFormat enum.
image_format: proc_macro2::TokenStream,
// Parsing of the access parameter is deferred to the type checker,
// which will error if the access is not member of the StorageTextureAccess enum.
access: proc_macro2::TokenStream,
visibility: ShaderStageVisibility,
}

impl Default for StorageTextureAttrs {
fn default() -> Self {
Self {
dimension: Default::default(),
image_format: quote! { Rgba8Unorm },
access: quote! { ReadWrite },
visibility: ShaderStageVisibility::compute(),
}
}
}

fn get_storage_texture_binding_attr(metas: Vec<Meta>) -> Result<StorageTextureAttrs> {
let mut storage_texture_attrs = StorageTextureAttrs::default();

for meta in metas {
use syn::Meta::{List, NameValue};
match meta {
// Parse #[storage_texture(0, dimension = "...")].
NameValue(m) if m.path == DIMENSION => {
let value = get_lit_str(DIMENSION, &m.value)?;
storage_texture_attrs.dimension = get_texture_dimension_value(value)?;
}
// Parse #[storage_texture(0, format = ...))].
NameValue(m) if m.path == IMAGE_FORMAT => {
storage_texture_attrs.image_format = m.value.into_token_stream();
}
// Parse #[storage_texture(0, access = ...))].
NameValue(m) if m.path == ACCESS => {
storage_texture_attrs.access = m.value.into_token_stream();
}
// Parse #[storage_texture(0, visibility(...))].
List(m) if m.path == VISIBILITY => {
storage_texture_attrs.visibility = get_visibility_flag_value(&m)?;
}
NameValue(m) => {
return Err(Error::new_spanned(
m.path,
"Not a valid name. Available attributes: `dimension`, `image_format`, `access`.",
));
}
_ => {
return Err(Error::new_spanned(
meta,
"Not a name value pair: `foo = \"...\"`",
));
}
}
}

Ok(storage_texture_attrs)
}

const DIMENSION: Symbol = Symbol("dimension");
const IMAGE_FORMAT: Symbol = Symbol("image_format");
const ACCESS: Symbol = Symbol("access");
const SAMPLE_TYPE: Symbol = Symbol("sample_type");
const FILTERABLE: Symbol = Symbol("filterable");
const MULTISAMPLED: Symbol = Symbol("multisampled");
Expand Down
2 changes: 1 addition & 1 deletion crates/bevy_render/macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ pub fn derive_extract_component(input: TokenStream) -> TokenStream {

#[proc_macro_derive(
AsBindGroup,
attributes(uniform, texture, sampler, bind_group_data, storage)
attributes(uniform, storage_texture, texture, sampler, bind_group_data, storage)
)]
pub fn derive_as_bind_group(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
Expand Down
16 changes: 16 additions & 0 deletions crates/bevy_render/src/render_resource/bind_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ impl Deref for BindGroup {
/// values: Vec<f32>,
/// #[storage(4, read_only, buffer)]
/// buffer: Buffer,
/// #[storage_texture(5)]
/// storage_texture: Handle<Image>,
/// }
/// ```
///
Expand All @@ -97,6 +99,7 @@ impl Deref for BindGroup {
/// @group(2) @binding(1) var color_texture: texture_2d<f32>;
/// @group(2) @binding(2) var color_sampler: sampler;
/// @group(2) @binding(3) var<storage> values: array<f32>;
/// @group(2) @binding(5) var storage_texture: texture_storage_2d<rgba8unorm, read_write>;
/// ```
/// Note that the "group" index is determined by the usage context. It is not defined in [`AsBindGroup`]. For example, in Bevy material bind groups
/// are generally bound to group 2.
Expand All @@ -123,6 +126,19 @@ impl Deref for BindGroup {
/// | `multisampled` = ... | `true`, `false` | `false` |
/// | `visibility(...)` | `all`, `none`, or a list-combination of `vertex`, `fragment`, `compute` | `vertex`, `fragment` |
///
/// * `storage_texture(BINDING_INDEX, arguments)`
/// * This field's [`Handle<Image>`](bevy_asset::Handle) will be used to look up the matching [`Texture`](crate::render_resource::Texture)
/// GPU resource, which will be bound as a storage texture in shaders. The field will be assumed to implement [`Into<Option<Handle<Image>>>`]. In practice,
/// most fields should be a [`Handle<Image>`](bevy_asset::Handle) or [`Option<Handle<Image>>`]. If the value of an [`Option<Handle<Image>>`] is
/// [`None`], the [`FallbackImage`] resource will be used instead.
///
/// | Arguments | Values | Default |
/// |------------------------|--------------------------------------------------------------------------------------------|---------------|
/// | `dimension` = "..." | `"1d"`, `"2d"`, `"2d_array"`, `"3d"`, `"cube"`, `"cube_array"` | `"2d"` |
/// | `image_format` = ... | any member of [`TextureFormat`](crate::render_resource::TextureFormat) | `Rgba8Unorm` |
/// | `access` = ... | any member of [`StorageTextureAccess`](crate::render_resource::StorageTextureAccess) | `ReadWrite` |
/// | `visibility(...)` | `all`, `none`, or a list-combination of `vertex`, `fragment`, `compute` | `compute` |
///
/// * `sampler(BINDING_INDEX, arguments)`
/// * This field's [`Handle<Image>`](bevy_asset::Handle) will be used to look up the matching [`Sampler`] GPU
/// resource, which will be bound as a sampler in shaders. The field will be assumed to implement [`Into<Option<Handle<Image>>>`]. In practice,
Expand Down
22 changes: 10 additions & 12 deletions examples/shader/compute_shader_game_of_life.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use bevy::{
render_asset::RenderAssetPersistencePolicy,
render_asset::RenderAssets,
render_graph::{self, RenderGraph},
render_resource::{binding_types::texture_storage_2d, *},
render_resource::*,
renderer::{RenderContext, RenderDevice},
Render, RenderApp, RenderSet,
},
Expand Down Expand Up @@ -65,7 +65,7 @@ fn setup(mut commands: Commands, mut images: ResMut<Assets<Image>>) {
});
commands.spawn(Camera2dBundle::default());

commands.insert_resource(GameOfLifeImage(image));
commands.insert_resource(GameOfLifeImage { texture: image });
}

pub struct GameOfLifeComputePlugin;
Expand Down Expand Up @@ -95,8 +95,11 @@ impl Plugin for GameOfLifeComputePlugin {
}
}

#[derive(Resource, Clone, Deref, ExtractResource)]
struct GameOfLifeImage(Handle<Image>);
#[derive(Resource, Clone, Deref, ExtractResource, AsBindGroup)]
struct GameOfLifeImage {
#[storage_texture(0, image_format = Rgba8Unorm, access = ReadWrite)]
texture: Handle<Image>,
}

#[derive(Resource)]
struct GameOfLifeImageBindGroup(BindGroup);
Expand All @@ -108,7 +111,7 @@ fn prepare_bind_group(
game_of_life_image: Res<GameOfLifeImage>,
render_device: Res<RenderDevice>,
) {
let view = gpu_images.get(&game_of_life_image.0).unwrap();
let view = gpu_images.get(&game_of_life_image.texture).unwrap();
let bind_group = render_device.create_bind_group(
None,
&pipeline.texture_bind_group_layout,
Expand All @@ -126,13 +129,8 @@ pub struct GameOfLifePipeline {

impl FromWorld for GameOfLifePipeline {
fn from_world(world: &mut World) -> Self {
let texture_bind_group_layout = world.resource::<RenderDevice>().create_bind_group_layout(
None,
&BindGroupLayoutEntries::single(
ShaderStages::COMPUTE,
texture_storage_2d(TextureFormat::Rgba8Unorm, StorageTextureAccess::ReadWrite),
),
);
let render_device = world.resource::<RenderDevice>();
let texture_bind_group_layout = GameOfLifeImage::bind_group_layout(render_device);
let shader = world
.resource::<AssetServer>()
.load("shaders/game_of_life.wgsl");
Expand Down

0 comments on commit 8afb3ce

Please sign in to comment.