From 1484b28ab85ca0fa207449fcecab5d348650c261 Mon Sep 17 00:00:00 2001 From: Erich Gubler Date: Mon, 29 Jan 2024 21:41:55 -0500 Subject: [PATCH] WIP: feat(shader)!: make `ProgrammableStage::entry_point` optional TODO: break out shader stage conversion (`shader_stage_from_stage_bit`) --- deno_webgpu/pipeline.rs | 12 +-- examples/src/boids/mod.rs | 6 +- examples/src/bunnymark/mod.rs | 4 +- examples/src/conservative_raster/mod.rs | 16 ++-- examples/src/cube/mod.rs | 8 +- examples/src/hello_compute/mod.rs | 2 +- examples/src/hello_synchronization/mod.rs | 4 +- examples/src/hello_triangle/mod.rs | 4 +- examples/src/hello_workgroups/mod.rs | 2 +- examples/src/mipmap/mod.rs | 8 +- examples/src/msaa_line/mod.rs | 4 +- examples/src/render_to_texture/mod.rs | 4 +- examples/src/repeated_compute/mod.rs | 2 +- examples/src/shadow/mod.rs | 8 +- examples/src/skybox/mod.rs | 8 +- examples/src/srgb_blend/mod.rs | 4 +- examples/src/stencil_triangles/mod.rs | 8 +- examples/src/storage_texture/mod.rs | 2 +- examples/src/texture_arrays/mod.rs | 4 +- examples/src/timestamp_queries/mod.rs | 6 +- examples/src/uniform_values/mod.rs | 4 +- examples/src/water/mod.rs | 8 +- player/tests/data/bind-group.ron | 2 +- .../tests/data/pipeline-statistics-query.ron | 2 +- player/tests/data/quad.ron | 4 +- player/tests/data/zero-init-buffer.ron | 2 +- .../tests/data/zero-init-texture-binding.ron | 2 +- tests/src/image.rs | 2 +- tests/tests/bgra8unorm_storage.rs | 2 +- tests/tests/bind_group_layout_dedup.rs | 10 +-- tests/tests/device.rs | 4 +- tests/tests/mem_leaks.rs | 4 +- tests/tests/nv12_texture/mod.rs | 4 +- tests/tests/occlusion_query/mod.rs | 2 +- tests/tests/partially_bounded_arrays/mod.rs | 2 +- tests/tests/pipeline.rs | 2 +- tests/tests/push_constants.rs | 2 +- tests/tests/regression/issue_3349.rs | 4 +- tests/tests/regression/issue_3457.rs | 8 +- tests/tests/scissor_tests/mod.rs | 4 +- tests/tests/shader/mod.rs | 2 +- tests/tests/shader/zero_init_workgroup_mem.rs | 4 +- tests/tests/shader_primitive_index/mod.rs | 4 +- tests/tests/shader_view_format/mod.rs | 4 +- tests/tests/vertex_indices/mod.rs | 6 +- wgpu-core/src/device/resource.rs | 52 +++++++++--- wgpu-core/src/pipeline.rs | 16 +++- wgpu-core/src/validation.rs | 83 ++++++++++++++++--- wgpu/src/backend/webgpu.rs | 16 ++-- wgpu/src/backend/wgpu_core.rs | 6 +- wgpu/src/lib.rs | 27 ++++-- 51 files changed, 263 insertions(+), 147 deletions(-) diff --git a/deno_webgpu/pipeline.rs b/deno_webgpu/pipeline.rs index 9175fe2075d..9306a15bb38 100644 --- a/deno_webgpu/pipeline.rs +++ b/deno_webgpu/pipeline.rs @@ -74,7 +74,7 @@ pub enum GPUPipelineLayoutOrGPUAutoLayoutMode { #[serde(rename_all = "camelCase")] pub struct GpuProgrammableStage { module: ResourceId, - entry_point: String, + entry_point: Option, // constants: HashMap } @@ -110,7 +110,7 @@ pub fn op_webgpu_create_compute_pipeline( layout: pipeline_layout, stage: wgpu_core::pipeline::ProgrammableStageDescriptor { module: compute_shader_module_resource.1, - entry_point: Cow::from(compute.entry_point), + entry_point: compute.entry_point.map(Cow::from), // TODO(lucacasonato): support args.compute.constants }, }; @@ -278,7 +278,7 @@ impl<'a> From for wgpu_core::pipeline::VertexBufferLayout #[serde(rename_all = "camelCase")] struct GpuVertexState { module: ResourceId, - entry_point: String, + entry_point: Option, buffers: Vec>, } @@ -305,7 +305,7 @@ impl From for wgpu_types::MultisampleState { struct GpuFragmentState { targets: Vec>, module: u32, - entry_point: String, + entry_point: Option, // TODO(lucacasonato): constants } @@ -355,7 +355,7 @@ pub fn op_webgpu_create_render_pipeline( Some(wgpu_core::pipeline::FragmentState { stage: wgpu_core::pipeline::ProgrammableStageDescriptor { module: fragment_shader_module_resource.1, - entry_point: Cow::from(fragment.entry_point), + entry_point: fragment.entry_point.map(Cow::from), }, targets: Cow::from(fragment.targets), }) @@ -377,7 +377,7 @@ pub fn op_webgpu_create_render_pipeline( vertex: wgpu_core::pipeline::VertexState { stage: wgpu_core::pipeline::ProgrammableStageDescriptor { module: vertex_shader_module_resource.1, - entry_point: Cow::Owned(args.vertex.entry_point), + entry_point: args.vertex.entry_point.map(Cow::from), }, buffers: Cow::Owned(vertex_buffers), }, diff --git a/examples/src/boids/mod.rs b/examples/src/boids/mod.rs index b608394134a..1c524c3b2cb 100644 --- a/examples/src/boids/mod.rs +++ b/examples/src/boids/mod.rs @@ -131,7 +131,7 @@ impl crate::framework::Example for Example { layout: Some(&render_pipeline_layout), vertex: wgpu::VertexState { module: &draw_shader, - entry_point: "main_vs", + entry_point: Some("main_vs"), buffers: &[ wgpu::VertexBufferLayout { array_stride: 4 * 4, @@ -147,7 +147,7 @@ impl crate::framework::Example for Example { }, fragment: Some(wgpu::FragmentState { module: &draw_shader, - entry_point: "main_fs", + entry_point: Some("main_fs"), targets: &[Some(config.view_formats[0].into())], }), primitive: wgpu::PrimitiveState::default(), @@ -162,7 +162,7 @@ impl crate::framework::Example for Example { label: Some("Compute pipeline"), layout: Some(&compute_pipeline_layout), module: &compute_shader, - entry_point: "main", + entry_point: Some("main"), }); // buffer for the three 2d triangle vertices of each instance diff --git a/examples/src/bunnymark/mod.rs b/examples/src/bunnymark/mod.rs index c29da351ee3..272d6ccca9c 100644 --- a/examples/src/bunnymark/mod.rs +++ b/examples/src/bunnymark/mod.rs @@ -202,12 +202,12 @@ impl crate::framework::Example for Example { layout: Some(&pipeline_layout), vertex: wgpu::VertexState { module: &shader, - entry_point: "vs_main", + entry_point: Some("vs_main"), buffers: &[], }, fragment: Some(wgpu::FragmentState { module: &shader, - entry_point: "fs_main", + entry_point: Some("fs_main"), targets: &[Some(wgpu::ColorTargetState { format: config.view_formats[0], blend: Some(wgpu::BlendState::ALPHA_BLENDING), diff --git a/examples/src/conservative_raster/mod.rs b/examples/src/conservative_raster/mod.rs index ce2054caa0b..cef2ab22a79 100644 --- a/examples/src/conservative_raster/mod.rs +++ b/examples/src/conservative_raster/mod.rs @@ -96,12 +96,12 @@ impl crate::framework::Example for Example { layout: Some(&pipeline_layout_empty), vertex: wgpu::VertexState { module: &shader_triangle_and_lines, - entry_point: "vs_main", + entry_point: Some("vs_main"), buffers: &[], }, fragment: Some(wgpu::FragmentState { module: &shader_triangle_and_lines, - entry_point: "fs_main_red", + entry_point: Some("fs_main_red"), targets: &[Some(RENDER_TARGET_FORMAT.into())], }), primitive: wgpu::PrimitiveState { @@ -119,12 +119,12 @@ impl crate::framework::Example for Example { layout: Some(&pipeline_layout_empty), vertex: wgpu::VertexState { module: &shader_triangle_and_lines, - entry_point: "vs_main", + entry_point: Some("vs_main"), buffers: &[], }, fragment: Some(wgpu::FragmentState { module: &shader_triangle_and_lines, - entry_point: "fs_main_blue", + entry_point: Some("fs_main_blue"), targets: &[Some(RENDER_TARGET_FORMAT.into())], }), primitive: wgpu::PrimitiveState::default(), @@ -143,12 +143,12 @@ impl crate::framework::Example for Example { layout: Some(&pipeline_layout_empty), vertex: wgpu::VertexState { module: &shader_triangle_and_lines, - entry_point: "vs_main", + entry_point: Some("vs_main"), buffers: &[], }, fragment: Some(wgpu::FragmentState { module: &shader_triangle_and_lines, - entry_point: "fs_main_white", + entry_point: Some("fs_main_white"), targets: &[Some(config.view_formats[0].into())], }), primitive: wgpu::PrimitiveState { @@ -204,12 +204,12 @@ impl crate::framework::Example for Example { layout: Some(&pipeline_layout), vertex: wgpu::VertexState { module: &shader, - entry_point: "vs_main", + entry_point: Some("vs_main"), buffers: &[], }, fragment: Some(wgpu::FragmentState { module: &shader, - entry_point: "fs_main", + entry_point: Some("fs_main"), targets: &[Some(config.view_formats[0].into())], }), primitive: wgpu::PrimitiveState::default(), diff --git a/examples/src/cube/mod.rs b/examples/src/cube/mod.rs index d21aafe5de8..f2789d30d50 100644 --- a/examples/src/cube/mod.rs +++ b/examples/src/cube/mod.rs @@ -243,12 +243,12 @@ impl crate::framework::Example for Example { layout: Some(&pipeline_layout), vertex: wgpu::VertexState { module: &shader, - entry_point: "vs_main", + entry_point: Some("vs_main"), buffers: &vertex_buffers, }, fragment: Some(wgpu::FragmentState { module: &shader, - entry_point: "fs_main", + entry_point: Some("fs_main"), targets: &[Some(config.view_formats[0].into())], }), primitive: wgpu::PrimitiveState { @@ -269,12 +269,12 @@ impl crate::framework::Example for Example { layout: Some(&pipeline_layout), vertex: wgpu::VertexState { module: &shader, - entry_point: "vs_main", + entry_point: Some("vs_main"), buffers: &vertex_buffers, }, fragment: Some(wgpu::FragmentState { module: &shader, - entry_point: "fs_wire", + entry_point: Some("fs_wire"), targets: &[Some(wgpu::ColorTargetState { format: config.view_formats[0], blend: Some(wgpu::BlendState { diff --git a/examples/src/hello_compute/mod.rs b/examples/src/hello_compute/mod.rs index ef452bf0234..d01a94880ef 100644 --- a/examples/src/hello_compute/mod.rs +++ b/examples/src/hello_compute/mod.rs @@ -108,7 +108,7 @@ async fn execute_gpu_inner( label: None, layout: None, module: &cs_module, - entry_point: "main", + entry_point: Some("main"), }); // Instantiates the bind group, once again specifying the binding of buffers. diff --git a/examples/src/hello_synchronization/mod.rs b/examples/src/hello_synchronization/mod.rs index c2a6fe8b26c..31e28639842 100644 --- a/examples/src/hello_synchronization/mod.rs +++ b/examples/src/hello_synchronization/mod.rs @@ -102,13 +102,13 @@ async fn execute( label: None, layout: Some(&pipeline_layout), module: &shaders_module, - entry_point: "patient_main", + entry_point: Some("patient_main"), }); let hasty_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { label: None, layout: Some(&pipeline_layout), module: &shaders_module, - entry_point: "hasty_main", + entry_point: Some("hasty_main"), }); //---------------------------------------------------------- diff --git a/examples/src/hello_triangle/mod.rs b/examples/src/hello_triangle/mod.rs index faa1db8f8b2..d6910234606 100644 --- a/examples/src/hello_triangle/mod.rs +++ b/examples/src/hello_triangle/mod.rs @@ -58,12 +58,12 @@ async fn run(event_loop: EventLoop<()>, window: Window) { layout: Some(&pipeline_layout), vertex: wgpu::VertexState { module: &shader, - entry_point: "vs_main", + entry_point: Some("vs_main"), buffers: &[], }, fragment: Some(wgpu::FragmentState { module: &shader, - entry_point: "fs_main", + entry_point: Some("fs_main"), targets: &[Some(swapchain_format.into())], }), primitive: wgpu::PrimitiveState::default(), diff --git a/examples/src/hello_workgroups/mod.rs b/examples/src/hello_workgroups/mod.rs index 3e5795048fa..d05064ee7c8 100644 --- a/examples/src/hello_workgroups/mod.rs +++ b/examples/src/hello_workgroups/mod.rs @@ -109,7 +109,7 @@ async fn run() { label: None, layout: Some(&pipeline_layout), module: &shader, - entry_point: "main", + entry_point: Some("main"), }); //---------------------------------------------------------- diff --git a/examples/src/mipmap/mod.rs b/examples/src/mipmap/mod.rs index a4600753e95..c1b7cd69307 100644 --- a/examples/src/mipmap/mod.rs +++ b/examples/src/mipmap/mod.rs @@ -92,12 +92,12 @@ impl Example { layout: None, vertex: wgpu::VertexState { module: &shader, - entry_point: "vs_main", + entry_point: Some("vs_main"), buffers: &[], }, fragment: Some(wgpu::FragmentState { module: &shader, - entry_point: "fs_main", + entry_point: Some("fs_main"), targets: &[Some(TEXTURE_FORMAT.into())], }), primitive: wgpu::PrimitiveState { @@ -289,12 +289,12 @@ impl crate::framework::Example for Example { layout: None, vertex: wgpu::VertexState { module: &shader, - entry_point: "vs_main", + entry_point: Some("vs_main"), buffers: &[], }, fragment: Some(wgpu::FragmentState { module: &shader, - entry_point: "fs_main", + entry_point: Some("fs_main"), targets: &[Some(config.view_formats[0].into())], }), primitive: wgpu::PrimitiveState { diff --git a/examples/src/msaa_line/mod.rs b/examples/src/msaa_line/mod.rs index 595bcbf17a3..9ceea43e18d 100644 --- a/examples/src/msaa_line/mod.rs +++ b/examples/src/msaa_line/mod.rs @@ -53,7 +53,7 @@ impl Example { layout: Some(pipeline_layout), vertex: wgpu::VertexState { module: shader, - entry_point: "vs_main", + entry_point: Some("vs_main"), buffers: &[wgpu::VertexBufferLayout { array_stride: std::mem::size_of::() as wgpu::BufferAddress, step_mode: wgpu::VertexStepMode::Vertex, @@ -62,7 +62,7 @@ impl Example { }, fragment: Some(wgpu::FragmentState { module: shader, - entry_point: "fs_main", + entry_point: Some("fs_main"), targets: &[Some(config.view_formats[0].into())], }), primitive: wgpu::PrimitiveState { diff --git a/examples/src/render_to_texture/mod.rs b/examples/src/render_to_texture/mod.rs index 75d22dfe87f..60887e0380d 100644 --- a/examples/src/render_to_texture/mod.rs +++ b/examples/src/render_to_texture/mod.rs @@ -58,12 +58,12 @@ async fn run(_path: Option) { layout: None, vertex: wgpu::VertexState { module: &shader, - entry_point: "vs_main", + entry_point: Some("vs_main"), buffers: &[], }, fragment: Some(wgpu::FragmentState { module: &shader, - entry_point: "fs_main", + entry_point: Some("fs_main"), targets: &[Some(wgpu::TextureFormat::Rgba8UnormSrgb.into())], }), primitive: wgpu::PrimitiveState::default(), diff --git a/examples/src/repeated_compute/mod.rs b/examples/src/repeated_compute/mod.rs index faed2467bda..4ac867f2f85 100644 --- a/examples/src/repeated_compute/mod.rs +++ b/examples/src/repeated_compute/mod.rs @@ -244,7 +244,7 @@ impl WgpuContext { label: None, layout: Some(&pipeline_layout), module: &shader, - entry_point: "main", + entry_point: Some("main"), }); WgpuContext { diff --git a/examples/src/shadow/mod.rs b/examples/src/shadow/mod.rs index 485d0d78d60..ed060bf06fc 100644 --- a/examples/src/shadow/mod.rs +++ b/examples/src/shadow/mod.rs @@ -499,7 +499,7 @@ impl crate::framework::Example for Example { layout: Some(&pipeline_layout), vertex: wgpu::VertexState { module: &shader, - entry_point: "vs_bake", + entry_point: Some("vs_bake"), buffers: &[vb_desc.clone()], }, fragment: None, @@ -631,16 +631,16 @@ impl crate::framework::Example for Example { layout: Some(&pipeline_layout), vertex: wgpu::VertexState { module: &shader, - entry_point: "vs_main", + entry_point: Some("vs_main"), buffers: &[vb_desc], }, fragment: Some(wgpu::FragmentState { module: &shader, - entry_point: if supports_storage_resources { + entry_point: Some(if supports_storage_resources { "fs_main" } else { "fs_main_without_storage" - }, + }), targets: &[Some(config.view_formats[0].into())], }), primitive: wgpu::PrimitiveState { diff --git a/examples/src/skybox/mod.rs b/examples/src/skybox/mod.rs index bdb5e661420..1b3fecf6ce6 100644 --- a/examples/src/skybox/mod.rs +++ b/examples/src/skybox/mod.rs @@ -198,12 +198,12 @@ impl crate::framework::Example for Example { layout: Some(&pipeline_layout), vertex: wgpu::VertexState { module: &shader, - entry_point: "vs_sky", + entry_point: Some("vs_sky"), buffers: &[], }, fragment: Some(wgpu::FragmentState { module: &shader, - entry_point: "fs_sky", + entry_point: Some("fs_sky"), targets: &[Some(config.view_formats[0].into())], }), primitive: wgpu::PrimitiveState { @@ -225,7 +225,7 @@ impl crate::framework::Example for Example { layout: Some(&pipeline_layout), vertex: wgpu::VertexState { module: &shader, - entry_point: "vs_entity", + entry_point: Some("vs_entity"), buffers: &[wgpu::VertexBufferLayout { array_stride: std::mem::size_of::() as wgpu::BufferAddress, step_mode: wgpu::VertexStepMode::Vertex, @@ -234,7 +234,7 @@ impl crate::framework::Example for Example { }, fragment: Some(wgpu::FragmentState { module: &shader, - entry_point: "fs_entity", + entry_point: Some("fs_entity"), targets: &[Some(config.view_formats[0].into())], }), primitive: wgpu::PrimitiveState { diff --git a/examples/src/srgb_blend/mod.rs b/examples/src/srgb_blend/mod.rs index d4021e6c5f9..94087385b8a 100644 --- a/examples/src/srgb_blend/mod.rs +++ b/examples/src/srgb_blend/mod.rs @@ -130,12 +130,12 @@ impl crate::framework::Example for Example { layout: Some(&pipeline_layout), vertex: wgpu::VertexState { module: &shader, - entry_point: "vs_main", + entry_point: Some("vs_main"), buffers: &vertex_buffers, }, fragment: Some(wgpu::FragmentState { module: &shader, - entry_point: "fs_main", + entry_point: Some("fs_main"), targets: &[Some(wgpu::ColorTargetState { format: config.view_formats[0], blend: Some(wgpu::BlendState::ALPHA_BLENDING), diff --git a/examples/src/stencil_triangles/mod.rs b/examples/src/stencil_triangles/mod.rs index bf645d3a34c..6c5f01b0120 100644 --- a/examples/src/stencil_triangles/mod.rs +++ b/examples/src/stencil_triangles/mod.rs @@ -73,12 +73,12 @@ impl crate::framework::Example for Example { layout: Some(&pipeline_layout), vertex: wgpu::VertexState { module: &shader, - entry_point: "vs_main", + entry_point: Some("vs_main"), buffers: &vertex_buffers, }, fragment: Some(wgpu::FragmentState { module: &shader, - entry_point: "fs_main", + entry_point: Some("fs_main"), targets: &[Some(wgpu::ColorTargetState { format: config.view_formats[0], blend: None, @@ -111,12 +111,12 @@ impl crate::framework::Example for Example { layout: Some(&pipeline_layout), vertex: wgpu::VertexState { module: &shader, - entry_point: "vs_main", + entry_point: Some("vs_main"), buffers: &vertex_buffers, }, fragment: Some(wgpu::FragmentState { module: &shader, - entry_point: "fs_main", + entry_point: Some("fs_main"), targets: &[Some(config.view_formats[0].into())], }), primitive: Default::default(), diff --git a/examples/src/storage_texture/mod.rs b/examples/src/storage_texture/mod.rs index d389ce139e1..c265bfb560d 100644 --- a/examples/src/storage_texture/mod.rs +++ b/examples/src/storage_texture/mod.rs @@ -99,7 +99,7 @@ async fn run(_path: Option) { label: None, layout: Some(&pipeline_layout), module: &shader, - entry_point: "main", + entry_point: Some("main"), }); log::info!("Wgpu context set up."); diff --git a/examples/src/texture_arrays/mod.rs b/examples/src/texture_arrays/mod.rs index ccad7599931..cdf0d107af7 100644 --- a/examples/src/texture_arrays/mod.rs +++ b/examples/src/texture_arrays/mod.rs @@ -320,7 +320,7 @@ impl crate::framework::Example for Example { layout: Some(&pipeline_layout), vertex: wgpu::VertexState { module: &base_shader_module, - entry_point: "vert_main", + entry_point: Some("vert_main"), buffers: &[wgpu::VertexBufferLayout { array_stride: vertex_size as wgpu::BufferAddress, step_mode: wgpu::VertexStepMode::Vertex, @@ -329,7 +329,7 @@ impl crate::framework::Example for Example { }, fragment: Some(wgpu::FragmentState { module: fragment_shader_module, - entry_point: fragment_entry_point, + entry_point: Some(fragment_entry_point), targets: &[Some(config.view_formats[0].into())], }), primitive: wgpu::PrimitiveState { diff --git a/examples/src/timestamp_queries/mod.rs b/examples/src/timestamp_queries/mod.rs index beccac73b2b..137ca30692b 100644 --- a/examples/src/timestamp_queries/mod.rs +++ b/examples/src/timestamp_queries/mod.rs @@ -287,7 +287,7 @@ fn compute_pass( label: None, layout: None, module, - entry_point: "main_cs", + entry_point: Some("main_cs"), }); let bind_group_layout = compute_pipeline.get_bind_group_layout(0); let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { @@ -341,12 +341,12 @@ fn render_pass( layout: Some(&pipeline_layout), vertex: wgpu::VertexState { module, - entry_point: "vs_main", + entry_point: Some("vs_main"), buffers: &[], }, fragment: Some(wgpu::FragmentState { module, - entry_point: "fs_main", + entry_point: Some("fs_main"), targets: &[Some(format.into())], }), primitive: wgpu::PrimitiveState::default(), diff --git a/examples/src/uniform_values/mod.rs b/examples/src/uniform_values/mod.rs index 4a31ddc0693..001bda357cd 100644 --- a/examples/src/uniform_values/mod.rs +++ b/examples/src/uniform_values/mod.rs @@ -178,12 +178,12 @@ impl WgpuContext { layout: Some(&pipeline_layout), vertex: wgpu::VertexState { module: &shader, - entry_point: "vs_main", + entry_point: Some("vs_main"), buffers: &[], }, fragment: Some(wgpu::FragmentState { module: &shader, - entry_point: "fs_main", + entry_point: Some("fs_main"), targets: &[Some(swapchain_format.into())], }), primitive: wgpu::PrimitiveState::default(), diff --git a/examples/src/water/mod.rs b/examples/src/water/mod.rs index e1a80b1e4f3..bb07b99eae2 100644 --- a/examples/src/water/mod.rs +++ b/examples/src/water/mod.rs @@ -511,7 +511,7 @@ impl crate::framework::Example for Example { // Vertex shader and input buffers vertex: wgpu::VertexState { module: &water_module, - entry_point: "vs_main", + entry_point: Some("vs_main"), // Layout of our vertices. This should match the structs // which are uploaded to the GPU. This should also be // ensured by tagging on either a `#[repr(C)]` onto a @@ -526,7 +526,7 @@ impl crate::framework::Example for Example { // Fragment shader and output targets fragment: Some(wgpu::FragmentState { module: &water_module, - entry_point: "fs_main", + entry_point: Some("fs_main"), // Describes how the colour will be interpolated // and assigned to the output attachment. targets: &[Some(wgpu::ColorTargetState { @@ -580,7 +580,7 @@ impl crate::framework::Example for Example { layout: Some(&terrain_pipeline_layout), vertex: wgpu::VertexState { module: &terrain_module, - entry_point: "vs_main", + entry_point: Some("vs_main"), buffers: &[wgpu::VertexBufferLayout { array_stride: terrain_vertex_size as wgpu::BufferAddress, step_mode: wgpu::VertexStepMode::Vertex, @@ -589,7 +589,7 @@ impl crate::framework::Example for Example { }, fragment: Some(wgpu::FragmentState { module: &terrain_module, - entry_point: "fs_main", + entry_point: Some("fs_main"), targets: &[Some(config.view_formats[0].into())], }), primitive: wgpu::PrimitiveState { diff --git a/player/tests/data/bind-group.ron b/player/tests/data/bind-group.ron index eacc36eb662..7a45b6dbc64 100644 --- a/player/tests/data/bind-group.ron +++ b/player/tests/data/bind-group.ron @@ -56,7 +56,7 @@ layout: Some(Id(0, 1, Empty)), stage: ( module: Id(0, 1, Empty), - entry_point: "main", + entry_point: Some("main"), ), ), ), diff --git a/player/tests/data/pipeline-statistics-query.ron b/player/tests/data/pipeline-statistics-query.ron index 2565ee73767..9206a51cb90 100644 --- a/player/tests/data/pipeline-statistics-query.ron +++ b/player/tests/data/pipeline-statistics-query.ron @@ -29,7 +29,7 @@ layout: Some(Id(0, 1, Empty)), stage: ( module: Id(0, 1, Empty), - entry_point: "main", + entry_point: Some("main"), ), ), ), diff --git a/player/tests/data/quad.ron b/player/tests/data/quad.ron index 68bf8ee97ef..db8e6b747ad 100644 --- a/player/tests/data/quad.ron +++ b/player/tests/data/quad.ron @@ -57,14 +57,14 @@ vertex: ( stage: ( module: Id(0, 1, Empty), - entry_point: "vs_main", + entry_point: Some("vs_main"), ), buffers: [], ), fragment: Some(( stage: ( module: Id(0, 1, Empty), - entry_point: "fs_main", + entry_point: Some("fs_main"), ), targets: [ Some(( diff --git a/player/tests/data/zero-init-buffer.ron b/player/tests/data/zero-init-buffer.ron index 73692d10ee7..ab02ec21d4e 100644 --- a/player/tests/data/zero-init-buffer.ron +++ b/player/tests/data/zero-init-buffer.ron @@ -133,7 +133,7 @@ layout: Some(Id(0, 1, Empty)), stage: ( module: Id(0, 1, Empty), - entry_point: "main", + entry_point: Some("main"), ), ), ), diff --git a/player/tests/data/zero-init-texture-binding.ron b/player/tests/data/zero-init-texture-binding.ron index 7dcfa4e2e63..986b3b079e8 100644 --- a/player/tests/data/zero-init-texture-binding.ron +++ b/player/tests/data/zero-init-texture-binding.ron @@ -134,7 +134,7 @@ layout: Some(Id(0, 1, Empty)), stage: ( module: Id(0, 1, Empty), - entry_point: "main", + entry_point: Some("main"), ), ), ), diff --git a/tests/src/image.rs b/tests/src/image.rs index e1b9b072011..83b11b45297 100644 --- a/tests/src/image.rs +++ b/tests/src/image.rs @@ -368,7 +368,7 @@ fn copy_via_compute( label: Some("pipeline read"), layout: Some(&pll), module: &sm, - entry_point: "copy_texture_to_buffer", + entry_point: Some("copy_texture_to_buffer"), }); { diff --git a/tests/tests/bgra8unorm_storage.rs b/tests/tests/bgra8unorm_storage.rs index b1ca3b83624..54fd489598c 100644 --- a/tests/tests/bgra8unorm_storage.rs +++ b/tests/tests/bgra8unorm_storage.rs @@ -95,7 +95,7 @@ static BGRA8_UNORM_STORAGE: GpuTestConfiguration = GpuTestConfiguration::new() let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { label: None, layout: Some(&pl), - entry_point: "main", + entry_point: Some("main"), module: &module, }); diff --git a/tests/tests/bind_group_layout_dedup.rs b/tests/tests/bind_group_layout_dedup.rs index 8da284b41ba..a3f8f43b398 100644 --- a/tests/tests/bind_group_layout_dedup.rs +++ b/tests/tests/bind_group_layout_dedup.rs @@ -92,7 +92,7 @@ async fn bgl_dedupe(ctx: TestingContext) { label: None, layout: Some(&pipeline_layout), module: &module, - entry_point: "no_resources", + entry_point: Some("no_resources"), }; let pipeline = ctx.device.create_compute_pipeline(&desc); @@ -217,7 +217,7 @@ fn bgl_dedupe_with_dropped_user_handle(ctx: TestingContext) { label: None, layout: Some(&pipeline_layout), module: &module, - entry_point: "no_resources", + entry_point: Some("no_resources"), }); let mut encoder = ctx.device.create_command_encoder(&Default::default()); @@ -262,7 +262,7 @@ fn bgl_dedupe_derived(ctx: TestingContext) { label: None, layout: None, module: &module, - entry_point: "resources", + entry_point: Some("resources"), }); // We create two bind groups, pulling the bind_group_layout from the pipeline each time. @@ -340,7 +340,7 @@ fn separate_programs_have_incompatible_derived_bgls(ctx: TestingContext) { label: None, layout: None, module: &module, - entry_point: "resources", + entry_point: Some("resources"), }; // Create two pipelines, creating a BG from the second. let pipeline1 = ctx.device.create_compute_pipeline(&desc); @@ -407,7 +407,7 @@ fn derived_bgls_incompatible_with_regular_bgls(ctx: TestingContext) { label: None, layout: None, module: &module, - entry_point: "resources", + entry_point: Some("resources"), }); // Create a matching BGL diff --git a/tests/tests/device.rs b/tests/tests/device.rs index 5d3a1022342..cd99af29575 100644 --- a/tests/tests/device.rs +++ b/tests/tests/device.rs @@ -447,7 +447,7 @@ static DEVICE_DESTROY_THEN_MORE: GpuTestConfiguration = GpuTestConfiguration::ne layout: None, vertex: wgpu::VertexState { module: &shader_module, - entry_point: "", + entry_point: Some(""), buffers: &[], }, primitive: wgpu::PrimitiveState::default(), @@ -465,7 +465,7 @@ static DEVICE_DESTROY_THEN_MORE: GpuTestConfiguration = GpuTestConfiguration::ne label: None, layout: None, module: &shader_module, - entry_point: "", + entry_point: Some(""), }); }); diff --git a/tests/tests/mem_leaks.rs b/tests/tests/mem_leaks.rs index 91a329d20cc..8e0ad8efdfd 100644 --- a/tests/tests/mem_leaks.rs +++ b/tests/tests/mem_leaks.rs @@ -95,14 +95,14 @@ async fn draw_test_with_reports( layout: Some(&ppl), vertex: wgpu::VertexState { buffers: &[], - entry_point: "vs_main_builtin", + entry_point: Some("vs_main_builtin"), module: &shader, }, primitive: wgpu::PrimitiveState::default(), depth_stencil: None, multisample: wgpu::MultisampleState::default(), fragment: Some(wgpu::FragmentState { - entry_point: "fs_main", + entry_point: Some("fs_main"), module: &shader, targets: &[Some(wgpu::ColorTargetState { format: wgpu::TextureFormat::Rgba8Unorm, diff --git a/tests/tests/nv12_texture/mod.rs b/tests/tests/nv12_texture/mod.rs index aba12e02b65..69a73029049 100644 --- a/tests/tests/nv12_texture/mod.rs +++ b/tests/tests/nv12_texture/mod.rs @@ -23,12 +23,12 @@ static NV12_TEXTURE_CREATION_SAMPLING: GpuTestConfiguration = GpuTestConfigurati layout: None, vertex: wgpu::VertexState { module: &shader, - entry_point: "vs_main", + entry_point: Some("vs_main"), buffers: &[], }, fragment: Some(wgpu::FragmentState { module: &shader, - entry_point: "fs_main", + entry_point: Some("fs_main"), targets: &[Some(target_format.into())], }), primitive: wgpu::PrimitiveState { diff --git a/tests/tests/occlusion_query/mod.rs b/tests/tests/occlusion_query/mod.rs index 0c3f3072a54..e0c644aa258 100644 --- a/tests/tests/occlusion_query/mod.rs +++ b/tests/tests/occlusion_query/mod.rs @@ -36,7 +36,7 @@ static OCCLUSION_QUERY: GpuTestConfiguration = GpuTestConfiguration::new() layout: None, vertex: wgpu::VertexState { module: &shader, - entry_point: "vs_main", + entry_point: Some("vs_main"), buffers: &[], }, fragment: None, diff --git a/tests/tests/partially_bounded_arrays/mod.rs b/tests/tests/partially_bounded_arrays/mod.rs index a1350718f56..9032ad27a29 100644 --- a/tests/tests/partially_bounded_arrays/mod.rs +++ b/tests/tests/partially_bounded_arrays/mod.rs @@ -68,7 +68,7 @@ static PARTIALLY_BOUNDED_ARRAY: GpuTestConfiguration = GpuTestConfiguration::new label: None, layout: Some(&pipeline_layout), module: &cs_module, - entry_point: "main", + entry_point: Some("main"), }); let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { diff --git a/tests/tests/pipeline.rs b/tests/tests/pipeline.rs index 2350f106639..0f13ed854a9 100644 --- a/tests/tests/pipeline.rs +++ b/tests/tests/pipeline.rs @@ -27,7 +27,7 @@ static PIPELINE_DEFAULT_LAYOUT_BAD_MODULE: GpuTestConfiguration = GpuTestConfigu label: Some("mandelbrot compute pipeline"), layout: None, module: &module, - entry_point: "doesn't exist", + entry_point: Some("doesn't exist"), }); pipeline.get_bind_group_layout(0); diff --git a/tests/tests/push_constants.rs b/tests/tests/push_constants.rs index 0e16b3df650..7fd8faab90d 100644 --- a/tests/tests/push_constants.rs +++ b/tests/tests/push_constants.rs @@ -102,7 +102,7 @@ async fn partial_update_test(ctx: TestingContext) { label: Some("pipeline"), layout: Some(&pipeline_layout), module: &sm, - entry_point: "main", + entry_point: Some("main"), }); let mut encoder = ctx diff --git a/tests/tests/regression/issue_3349.rs b/tests/tests/regression/issue_3349.rs index 2d94d569208..1e8bcdd7b64 100644 --- a/tests/tests/regression/issue_3349.rs +++ b/tests/tests/regression/issue_3349.rs @@ -101,12 +101,12 @@ async fn multi_stage_data_binding_test(ctx: TestingContext) { layout: Some(&pll), vertex: wgpu::VertexState { module: &vs_sm, - entry_point: "vs_main", + entry_point: Some("vs_main"), buffers: &[], }, fragment: Some(wgpu::FragmentState { module: &fs_sm, - entry_point: "fs_main", + entry_point: Some("fs_main"), targets: &[Some(wgpu::ColorTargetState { format: wgpu::TextureFormat::Rgba8Unorm, blend: None, diff --git a/tests/tests/regression/issue_3457.rs b/tests/tests/regression/issue_3457.rs index 12ace62e88f..aae4972e2be 100644 --- a/tests/tests/regression/issue_3457.rs +++ b/tests/tests/regression/issue_3457.rs @@ -51,7 +51,7 @@ static PASS_RESET_VERTEX_BUFFER: GpuTestConfiguration = layout: Some(&pipeline_layout), vertex: VertexState { module: &module, - entry_point: "double_buffer_vert", + entry_point: Some("double_buffer_vert"), buffers: &[ VertexBufferLayout { array_stride: 16, @@ -70,7 +70,7 @@ static PASS_RESET_VERTEX_BUFFER: GpuTestConfiguration = multisample: MultisampleState::default(), fragment: Some(FragmentState { module: &module, - entry_point: "double_buffer_frag", + entry_point: Some("double_buffer_frag"), targets: &[Some(ColorTargetState { format: TextureFormat::Rgba8Unorm, blend: None, @@ -87,7 +87,7 @@ static PASS_RESET_VERTEX_BUFFER: GpuTestConfiguration = layout: Some(&pipeline_layout), vertex: VertexState { module: &module, - entry_point: "single_buffer_vert", + entry_point: Some("single_buffer_vert"), buffers: &[VertexBufferLayout { array_stride: 16, step_mode: VertexStepMode::Vertex, @@ -99,7 +99,7 @@ static PASS_RESET_VERTEX_BUFFER: GpuTestConfiguration = multisample: MultisampleState::default(), fragment: Some(FragmentState { module: &module, - entry_point: "single_buffer_frag", + entry_point: Some("single_buffer_frag"), targets: &[Some(ColorTargetState { format: TextureFormat::Rgba8Unorm, blend: None, diff --git a/tests/tests/scissor_tests/mod.rs b/tests/tests/scissor_tests/mod.rs index 11b72ba7a48..7b102961317 100644 --- a/tests/tests/scissor_tests/mod.rs +++ b/tests/tests/scissor_tests/mod.rs @@ -42,7 +42,7 @@ async fn scissor_test_impl( label: Some("Pipeline"), layout: None, vertex: wgpu::VertexState { - entry_point: "vs_main", + entry_point: Some("vs_main"), module: &shader, buffers: &[], }, @@ -50,7 +50,7 @@ async fn scissor_test_impl( depth_stencil: None, multisample: wgpu::MultisampleState::default(), fragment: Some(wgpu::FragmentState { - entry_point: "fs_main", + entry_point: Some("fs_main"), module: &shader, targets: &[Some(wgpu::ColorTargetState { format: wgpu::TextureFormat::Rgba8Unorm, diff --git a/tests/tests/shader/mod.rs b/tests/tests/shader/mod.rs index 1a981971f75..55a9832b0e1 100644 --- a/tests/tests/shader/mod.rs +++ b/tests/tests/shader/mod.rs @@ -306,7 +306,7 @@ async fn shader_input_output_test( label: Some(&format!("pipeline {test_name}")), layout: Some(&pll), module: &sm, - entry_point: "cs_main", + entry_point: Some("cs_main"), }); // -- Initializing data -- diff --git a/tests/tests/shader/zero_init_workgroup_mem.rs b/tests/tests/shader/zero_init_workgroup_mem.rs index 6774f1aac1f..2eab2b4a39a 100644 --- a/tests/tests/shader/zero_init_workgroup_mem.rs +++ b/tests/tests/shader/zero_init_workgroup_mem.rs @@ -86,7 +86,7 @@ static ZERO_INIT_WORKGROUP_MEMORY: GpuTestConfiguration = GpuTestConfiguration:: label: Some("pipeline read"), layout: Some(&pll), module: &sm, - entry_point: "read", + entry_point: Some("read"), }); let pipeline_write = ctx @@ -95,7 +95,7 @@ static ZERO_INIT_WORKGROUP_MEMORY: GpuTestConfiguration = GpuTestConfiguration:: label: Some("pipeline write"), layout: None, module: &sm, - entry_point: "write", + entry_point: Some("write"), }); // -- Initializing data -- diff --git a/tests/tests/shader_primitive_index/mod.rs b/tests/tests/shader_primitive_index/mod.rs index 096df9c0f71..cab8bb52d91 100644 --- a/tests/tests/shader_primitive_index/mod.rs +++ b/tests/tests/shader_primitive_index/mod.rs @@ -129,14 +129,14 @@ async fn pulling_common( shader_location: 0, }], }], - entry_point: "vs_main", + entry_point: Some("vs_main"), module: &shader, }, primitive: wgpu::PrimitiveState::default(), depth_stencil: None, multisample: wgpu::MultisampleState::default(), fragment: Some(wgpu::FragmentState { - entry_point: "fs_main", + entry_point: Some("fs_main"), module: &shader, targets: &[Some(wgpu::ColorTargetState { format: wgpu::TextureFormat::Rgba8Unorm, diff --git a/tests/tests/shader_view_format/mod.rs b/tests/tests/shader_view_format/mod.rs index 842388763ba..6861b2800e1 100644 --- a/tests/tests/shader_view_format/mod.rs +++ b/tests/tests/shader_view_format/mod.rs @@ -92,12 +92,12 @@ async fn reinterpret( layout: None, vertex: wgpu::VertexState { module: shader, - entry_point: "vs_main", + entry_point: Some("vs_main"), buffers: &[], }, fragment: Some(wgpu::FragmentState { module: shader, - entry_point: "fs_main", + entry_point: Some("fs_main"), targets: &[Some(src_format.into())], }), primitive: wgpu::PrimitiveState { diff --git a/tests/tests/vertex_indices/mod.rs b/tests/tests/vertex_indices/mod.rs index 745eeff8c39..85a5f562d31 100644 --- a/tests/tests/vertex_indices/mod.rs +++ b/tests/tests/vertex_indices/mod.rs @@ -265,14 +265,14 @@ async fn vertex_index_common(ctx: TestingContext) { layout: Some(&ppl), vertex: wgpu::VertexState { buffers: &[], - entry_point: "vs_main_builtin", + entry_point: Some("vs_main_builtin"), module: &shader, }, primitive: wgpu::PrimitiveState::default(), depth_stencil: None, multisample: wgpu::MultisampleState::default(), fragment: Some(wgpu::FragmentState { - entry_point: "fs_main", + entry_point: Some("fs_main"), module: &shader, targets: &[Some(wgpu::ColorTargetState { format: wgpu::TextureFormat::Rgba8Unorm, @@ -283,7 +283,7 @@ async fn vertex_index_common(ctx: TestingContext) { multiview: None, }; let builtin_pipeline = ctx.device.create_render_pipeline(&pipeline_desc); - pipeline_desc.vertex.entry_point = "vs_main_buffers"; + pipeline_desc.vertex.entry_point = Some("vs_main_buffers"); pipeline_desc.vertex.buffers = &[ wgpu::VertexBufferLayout { array_stride: 4, diff --git a/wgpu-core/src/device/resource.rs b/wgpu-core/src/device/resource.rs index b2c85a056a5..bd525a7db4a 100644 --- a/wgpu-core/src/device/resource.rs +++ b/wgpu-core/src/device/resource.rs @@ -2656,14 +2656,21 @@ impl Device { let mut shader_binding_sizes = FastHashMap::default(); let io = validation::StageIo::default(); + let final_entry_point_name; + { let stage = wgt::ShaderStages::COMPUTE; + final_entry_point_name = shader_module.finalize_entry_point_name( + stage, + desc.stage.entry_point.as_ref().map(|ep| ep.as_ref()), + )?; + if let Some(ref interface) = shader_module.interface { let _ = interface.check_stage( &mut binding_layout_source, &mut shader_binding_sizes, - &desc.stage.entry_point, + &final_entry_point_name, stage, io, None, @@ -2691,7 +2698,7 @@ impl Device { label: desc.label.to_hal(self.instance_flags), layout: pipeline_layout.raw(), stage: hal::ProgrammableStage { - entry_point: desc.stage.entry_point.as_ref(), + entry_point: final_entry_point_name.as_ref(), module: shader_module.raw(), }, }; @@ -3051,6 +3058,7 @@ impl Device { }; let vertex_shader_module; + let vertex_entry_point_name; let vertex_stage = { let stage_desc = &desc.vertex.stage; let stage = wgt::ShaderStages::VERTEX; @@ -3065,27 +3073,37 @@ impl Device { return Err(DeviceError::WrongDevice.into()); } + let stage_err = |error| pipeline::CreateRenderPipelineError::Stage { stage, error }; + + vertex_entry_point_name = vertex_shader_module + .finalize_entry_point_name( + stage, + stage_desc.entry_point.as_ref().map(|ep| ep.as_ref()), + ) + .map_err(stage_err)?; + if let Some(ref interface) = vertex_shader_module.interface { io = interface .check_stage( &mut binding_layout_source, &mut shader_binding_sizes, - &stage_desc.entry_point, + &vertex_entry_point_name, stage, io, desc.depth_stencil.as_ref().map(|d| d.depth_compare), ) - .map_err(|error| pipeline::CreateRenderPipelineError::Stage { stage, error })?; + .map_err(stage_err)?; validated_stages |= stage; } hal::ProgrammableStage { module: vertex_shader_module.raw(), - entry_point: stage_desc.entry_point.as_ref(), + entry_point: &vertex_entry_point_name, } }; let mut fragment_shader_module = None; + let fragment_entry_point_name; let fragment_stage = match desc.fragment { Some(ref fragment_state) => { let stage = wgt::ShaderStages::FRAGMENT; @@ -3099,28 +3117,38 @@ impl Device { })?, ); + let stage_err = |error| pipeline::CreateRenderPipelineError::Stage { stage, error }; + + fragment_entry_point_name = shader_module + .finalize_entry_point_name( + stage, + fragment_state + .stage + .entry_point + .as_ref() + .map(|ep| ep.as_ref()), + ) + .map_err(stage_err)?; + if validated_stages == wgt::ShaderStages::VERTEX { if let Some(ref interface) = shader_module.interface { io = interface .check_stage( &mut binding_layout_source, &mut shader_binding_sizes, - &fragment_state.stage.entry_point, + &fragment_entry_point_name, stage, io, desc.depth_stencil.as_ref().map(|d| d.depth_compare), ) - .map_err(|error| pipeline::CreateRenderPipelineError::Stage { - stage, - error, - })?; + .map_err(stage_err)?; validated_stages |= stage; } } if let Some(ref interface) = shader_module.interface { shader_expects_dual_source_blending = interface - .fragment_uses_dual_source_blending(&fragment_state.stage.entry_point) + .fragment_uses_dual_source_blending(&fragment_entry_point_name) .map_err(|error| pipeline::CreateRenderPipelineError::Stage { stage, error, @@ -3129,7 +3157,7 @@ impl Device { Some(hal::ProgrammableStage { module: shader_module.raw(), - entry_point: fragment_state.stage.entry_point.as_ref(), + entry_point: &fragment_entry_point_name, }) } None => None, diff --git a/wgpu-core/src/pipeline.rs b/wgpu-core/src/pipeline.rs index acc1b24b0c1..1a4b2d8685d 100644 --- a/wgpu-core/src/pipeline.rs +++ b/wgpu-core/src/pipeline.rs @@ -92,6 +92,20 @@ impl ShaderModule { pub(crate) fn raw(&self) -> &A::ShaderModule { self.raw.as_ref().unwrap() } + + pub(crate) fn finalize_entry_point_name( + &self, + stage_bit: wgt::ShaderStages, + entry_point: Option<&str>, + ) -> Result { + match &self.interface { + Some(interface) => interface.finalize_entry_point_name(stage_bit, entry_point), + None => entry_point + .map(|ep| ep.to_string()) + .map(validation::FinalizedEntryPointName::new) + .ok_or(validation::StageError::NoEntryPointFound), + } + } } #[derive(Clone, Debug)] @@ -215,7 +229,7 @@ pub struct ProgrammableStageDescriptor<'a> { pub module: ShaderModuleId, /// The name of the entry point in the compiled shader. There must be a function with this name /// in the shader. - pub entry_point: Cow<'a, str>, + pub entry_point: Option>, } /// Number of implicit bind groups derived at pipeline creation. diff --git a/wgpu-core/src/validation.rs b/wgpu-core/src/validation.rs index a0947ae83ff..c80553f6b9f 100644 --- a/wgpu-core/src/validation.rs +++ b/wgpu-core/src/validation.rs @@ -1,6 +1,6 @@ use crate::{device::bgl, FastHashMap, FastHashSet}; use arrayvec::ArrayVec; -use std::{collections::hash_map::Entry, fmt}; +use std::{collections::hash_map::Entry, fmt, ops::Deref}; use thiserror::Error; use wgt::{BindGroupLayoutEntry, BindingType}; @@ -271,6 +271,16 @@ pub enum StageError { }, #[error("Location[{location}] is provided by the previous stage output but is not consumed as input by this stage.")] InputNotConsumed { location: wgt::ShaderLocation }, + #[error( + "Unable to select an entry point: no entry point was found in the provided shader module" + )] + NoEntryPointFound, + #[error( + "Unable to select an entry point: \ + multiple entry points were found in the provided shader module, \ + but no entry point was specified" + )] + MultipleEntryPointsFound, } fn map_storage_format_to_naga(format: wgt::TextureFormat) -> Option { @@ -953,28 +963,57 @@ impl Interface { } } + pub fn finalize_entry_point_name( + &self, + stage_bit: wgt::ShaderStages, + entry_point_name: Option<&str>, + ) -> Result { + let stage = Self::shader_stage_from_stage_bit(stage_bit); + Ok(FinalizedEntryPointName::new( + entry_point_name + .map(|ep| ep.to_string()) + .map(Ok) + .unwrap_or_else(|| { + let mut entry_points = self + .entry_points + .keys() + .filter_map(|(ep_stage, name)| (ep_stage == &stage).then_some(name)); + let first = entry_points.next().ok_or(StageError::NoEntryPointFound)?; + entry_points + .next() + .ok_or(StageError::MultipleEntryPointsFound)?; + Ok(first.clone()) + })?, + )) + } + + pub(crate) fn shader_stage_from_stage_bit(stage_bit: wgt::ShaderStages) -> naga::ShaderStage { + match stage_bit { + wgt::ShaderStages::VERTEX => naga::ShaderStage::Vertex, + wgt::ShaderStages::FRAGMENT => naga::ShaderStage::Fragment, + wgt::ShaderStages::COMPUTE => naga::ShaderStage::Compute, + _ => unreachable!(), + } + } + pub fn check_stage( &self, layouts: &mut BindingLayoutSource<'_>, shader_binding_sizes: &mut FastHashMap, - entry_point_name: &str, + entry_point_name: &FinalizedEntryPointName, stage_bit: wgt::ShaderStages, inputs: StageIo, compare_function: Option, ) -> Result { // Since a shader module can have multiple entry points with the same name, // we need to look for one with the right execution model. - let shader_stage = match stage_bit { - wgt::ShaderStages::VERTEX => naga::ShaderStage::Vertex, - wgt::ShaderStages::FRAGMENT => naga::ShaderStage::Fragment, - wgt::ShaderStages::COMPUTE => naga::ShaderStage::Compute, - _ => unreachable!(), - }; + let shader_stage = Self::shader_stage_from_stage_bit(stage_bit); let pair = (shader_stage, entry_point_name.to_string()); - let entry_point = self - .entry_points - .get(&pair) - .ok_or(StageError::MissingEntryPoint(pair.1))?; + let entry_point = match self.entry_points.get(&pair) { + Some(some) => some, + None => return Err(StageError::MissingEntryPoint(pair.1)), + }; + let (_stage, entry_point_name) = pair; // check resources visibility for &handle in entry_point.resources.iter() { @@ -1246,3 +1285,23 @@ impl Interface { .map(|ep| ep.dual_source_blending) } } + +/// An entry point name finalized with [`Interface::finalize_entry_point_name`]. +#[derive(Debug)] +pub struct FinalizedEntryPointName { + inner: String, +} + +impl Deref for FinalizedEntryPointName { + type Target = str; + + fn deref(&self) -> &Self::Target { + self.inner.as_ref() + } +} + +impl FinalizedEntryPointName { + pub(crate) fn new(inner: String) -> Self { + Self { inner } + } +} diff --git a/wgpu/src/backend/webgpu.rs b/wgpu/src/backend/webgpu.rs index 5eb3638c382..91d97830c42 100644 --- a/wgpu/src/backend/webgpu.rs +++ b/wgpu/src/backend/webgpu.rs @@ -1688,7 +1688,9 @@ impl crate::context::Context for ContextWebGpu { let module: &::ShaderModuleData = downcast_ref(desc.vertex.module.data.as_ref()); let mut mapped_vertex_state = web_sys::GpuVertexState::new(&module.0); - mapped_vertex_state.entry_point(desc.vertex.entry_point); + if let Some(ep) = desc.vertex.entry_point { + mapped_vertex_state.entry_point(ep); + } let buffers = desc .vertex @@ -1762,8 +1764,10 @@ impl crate::context::Context for ContextWebGpu { .collect::(); let module: &::ShaderModuleData = downcast_ref(frag.module.data.as_ref()); - let mapped_fragment_desc = web_sys::GpuFragmentState::new(&module.0, &targets); - mapped_fragment_desc.entry_point(frag.entry_point); + let mut mapped_fragment_desc = web_sys::GpuFragmentState::new(&module.0, &targets); + if let Some(ep) = frag.entry_point { + mapped_fragment_desc.entry_point(ep); + } mapped_desc.fragment(&mapped_fragment_desc); } @@ -1787,8 +1791,10 @@ impl crate::context::Context for ContextWebGpu { ) -> (Self::ComputePipelineId, Self::ComputePipelineData) { let shader_module: &::ShaderModuleData = downcast_ref(desc.module.data.as_ref()); - let mapped_compute_stage = web_sys::GpuProgrammableStage::new(&shader_module.0); - mapped_compute_stage.entry_point(desc.entry_point); + let mut mapped_compute_stage = web_sys::GpuProgrammableStage::new(&shader_module.0); + if let Some(ep) = desc.entry_point { + mapped_compute_stage.entry_point(ep); + } let auto_layout = wasm_bindgen::JsValue::from(web_sys::GpuAutoLayoutMode::Auto); let mut mapped_desc = web_sys::GpuComputePipelineDescriptor::new( &match desc.layout { diff --git a/wgpu/src/backend/wgpu_core.rs b/wgpu/src/backend/wgpu_core.rs index b5d230f3e3d..b2ff94db808 100644 --- a/wgpu/src/backend/wgpu_core.rs +++ b/wgpu/src/backend/wgpu_core.rs @@ -1102,7 +1102,7 @@ impl crate::Context for ContextWgpuCore { vertex: pipe::VertexState { stage: pipe::ProgrammableStageDescriptor { module: desc.vertex.module.id.into(), - entry_point: Borrowed(desc.vertex.entry_point), + entry_point: desc.vertex.entry_point.map(Borrowed), }, buffers: Borrowed(&vertex_buffers), }, @@ -1112,7 +1112,7 @@ impl crate::Context for ContextWgpuCore { fragment: desc.fragment.as_ref().map(|frag| pipe::FragmentState { stage: pipe::ProgrammableStageDescriptor { module: frag.module.id.into(), - entry_point: Borrowed(frag.entry_point), + entry_point: frag.entry_point.map(Borrowed), }, targets: Borrowed(frag.targets), }), @@ -1160,7 +1160,7 @@ impl crate::Context for ContextWgpuCore { layout: desc.layout.map(|l| l.id.into()), stage: pipe::ProgrammableStageDescriptor { module: desc.module.id.into(), - entry_point: Borrowed(desc.entry_point), + entry_point: desc.entry_point.map(Borrowed), }, }; diff --git a/wgpu/src/lib.rs b/wgpu/src/lib.rs index 03c65f7a829..75950ecd28e 100644 --- a/wgpu/src/lib.rs +++ b/wgpu/src/lib.rs @@ -1512,9 +1512,12 @@ static_assertions::assert_impl_all!(VertexBufferLayout<'_>: Send, Sync); pub struct VertexState<'a> { /// The compiled shader module for this stage. pub module: &'a ShaderModule, - /// The name of the entry point in the compiled shader. There must be a function with this name - /// in the shader. - pub entry_point: &'a str, + /// The name of the entry point in the compiled shader to use. + /// + /// If [`Some`], there must be a function with this name and no return value in the shader. + /// Otherwise, there must be exactly one `@compute` entry point in the provided `module`, which + /// will be used for the pipeline. + pub entry_point: Option<&'a str>, /// The format of any vertex buffers used with this pipeline. pub buffers: &'a [VertexBufferLayout<'a>], } @@ -1531,9 +1534,12 @@ static_assertions::assert_impl_all!(VertexState<'_>: Send, Sync); pub struct FragmentState<'a> { /// The compiled shader module for this stage. pub module: &'a ShaderModule, - /// The name of the entry point in the compiled shader. There must be a function with this name - /// in the shader. - pub entry_point: &'a str, + /// The name of the entry point in the compiled shader to use. + /// + /// If [`Some`], there must be a function with this name and no return value in the shader. + /// Otherwise, there must be exactly one `@compute` entry point in the provided `module`, which + /// will be used for the pipeline. + pub entry_point: Option<&'a str>, /// The color state of the render targets. pub targets: &'a [Option], } @@ -1620,9 +1626,12 @@ pub struct ComputePipelineDescriptor<'a> { pub layout: Option<&'a PipelineLayout>, /// The compiled shader module for this stage. pub module: &'a ShaderModule, - /// The name of the entry point in the compiled shader. There must be a function with this name - /// and no return value in the shader. - pub entry_point: &'a str, + /// The name of the entry point in the compiled shader to use. + /// + /// If [`Some`], there must be a function with this name and no return value in the shader. + /// Otherwise, there must be exactly one `@compute` entry point in the provided `module`, which + /// will be used for the pipeline. + pub entry_point: Option<&'a str>, } #[cfg(send_sync)] static_assertions::assert_impl_all!(ComputePipelineDescriptor<'_>: Send, Sync);