From ab03bc64a5dceda5e094235ceabf582f62b67330 Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Fri, 17 May 2024 17:36:17 +0200 Subject: [PATCH] ensure safety of indirect dispatch by injecting a compute shader that validates the content of the indirect buffer also adds missing indirect buffer offset validation --- deno_webgpu/binding.rs | 2 + deno_webgpu/shader.rs | 1 + tests/tests/dispatch_workgroups_indirect.rs | 197 +++++++++ tests/tests/root.rs | 1 + wgpu-core/Cargo.toml | 4 + wgpu-core/src/binding_model.rs | 7 + wgpu-core/src/command/bind.rs | 14 +- wgpu-core/src/command/compute.rs | 31 ++ wgpu-core/src/command/compute_command.rs | 16 + wgpu-core/src/command/mod.rs | 13 +- wgpu-core/src/command/render.rs | 16 + wgpu-core/src/device/global.rs | 47 ++- wgpu-core/src/device/resource.rs | 51 ++- wgpu-core/src/indirect_validation.rs | 423 ++++++++++++++++++++ wgpu-core/src/instance.rs | 34 +- wgpu-core/src/lib.rs | 2 + wgpu-core/src/pipeline.rs | 8 +- wgpu-core/src/resource.rs | 5 + wgpu/Cargo.toml | 6 + wgpu/src/backend/wgpu_core.rs | 5 + 20 files changed, 862 insertions(+), 21 deletions(-) create mode 100644 tests/tests/dispatch_workgroups_indirect.rs create mode 100644 wgpu-core/src/indirect_validation.rs diff --git a/deno_webgpu/binding.rs b/deno_webgpu/binding.rs index 0efeb6716a9..e5f6d6c613a 100644 --- a/deno_webgpu/binding.rs +++ b/deno_webgpu/binding.rs @@ -224,6 +224,7 @@ pub fn op_webgpu_create_pipeline_layout( label: Some(label), bind_group_layouts: Cow::from(bind_group_layouts), push_constant_ranges: Default::default(), + ignore_push_constant_check: false, }; gfx_put!(device => instance.device_create_pipeline_layout( @@ -288,6 +289,7 @@ pub fn op_webgpu_create_bind_group( buffer_id: buffer_resource.1, offset: entry.offset.unwrap_or(0), size: std::num::NonZeroU64::new(entry.size.unwrap_or(0)), + allow_indirect_as_storage: false, }, ) } diff --git a/deno_webgpu/shader.rs b/deno_webgpu/shader.rs index 17cde43936d..2a5abfa9b62 100644 --- a/deno_webgpu/shader.rs +++ b/deno_webgpu/shader.rs @@ -43,6 +43,7 @@ pub fn op_webgpu_create_shader_module( let descriptor = wgpu_core::pipeline::ShaderModuleDescriptor { label: Some(label), shader_bound_checks: wgpu_types::ShaderBoundChecks::default(), + ignore_push_constant_check: false, }; gfx_put!(device => instance.device_create_shader_module( diff --git a/tests/tests/dispatch_workgroups_indirect.rs b/tests/tests/dispatch_workgroups_indirect.rs new file mode 100644 index 00000000000..f08f39b15aa --- /dev/null +++ b/tests/tests/dispatch_workgroups_indirect.rs @@ -0,0 +1,197 @@ +use wgpu_test::{gpu_test, FailureCase, GpuTestConfiguration, TestParameters, TestingContext}; + +/// Make sure that the num_workgroups builtin works properly (it requires a workaround on D3D12). +#[gpu_test] +static NUM_WORKGROUPS_BUILTIN: GpuTestConfiguration = GpuTestConfiguration::new() + .parameters( + TestParameters::default() + .downlevel_flags( + wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION, + ) + .limits(wgpu::Limits::downlevel_defaults()) + .expect_fail(FailureCase::backend(wgt::Backends::DX12)), + ) + .run_async(|ctx| async move { + let num_workgroups = [1, 2, 3]; + let res = run_test(&ctx, &num_workgroups, false).await; + assert_eq!(res, num_workgroups); + }); + +/// Make sure that we discard (don't run) the dispatch if its size exceeds the device limit. +#[gpu_test] +static DISCARD_DISPATCH: GpuTestConfiguration = GpuTestConfiguration::new() + .parameters( + TestParameters::default() + .downlevel_flags( + wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION, + ) + .limits(wgpu::Limits { + max_compute_workgroups_per_dimension: 10, + ..wgpu::Limits::downlevel_defaults() + }), + ) + .run_async(|ctx| async move { + let max = ctx.device.limits().max_compute_workgroups_per_dimension; + + let res = run_test(&ctx, &[max, max, max], false).await; + assert_eq!(res, [max; 3]); + + let res = run_test(&ctx, &[max + 1, 1, 1], false).await; + assert_eq!(res, [0; 3]); + + let res = run_test(&ctx, &[1, max + 1, 1], false).await; + assert_eq!(res, [0; 3]); + + let res = run_test(&ctx, &[1, 1, max + 1], false).await; + assert_eq!(res, [0; 3]); + }); + +/// Make sure that unsetting the bind group set by the validation code works properly. +#[gpu_test] +static UNSET_INTERNAL_BIND_GROUP: GpuTestConfiguration = GpuTestConfiguration::new() + .parameters( + TestParameters::default() + .downlevel_flags( + wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION, + ) + .limits(wgpu::Limits::downlevel_defaults()), + ) + .run_async(|ctx| async move { + ctx.device.push_error_scope(wgpu::ErrorFilter::Validation); + + let _ = run_test(&ctx, &[0, 0, 0], true).await; + + let error = pollster::block_on(ctx.device.pop_error_scope()); + assert!(error.map_or(false, |error| format!("{error}") + .contains("Expected bind group is missing"))); + }); + +async fn run_test( + ctx: &TestingContext, + num_workgroups: &[u32; 3], + forget_to_set_bind_group: bool, +) -> [u32; 3] { + const SHADER_SRC: &str = " + @group(0) @binding(0) + var out: array; + + @compute @workgroup_size(1) + fn main(@builtin(num_workgroups) num_workgroups: vec3, @builtin(workgroup_id) workgroup_id: vec3) { + if (all(workgroup_id == vec3())) { + out[0] = num_workgroups.x; + out[1] = num_workgroups.y; + out[2] = num_workgroups.z; + } + } + "; + + let module = ctx + .device + .create_shader_module(wgpu::ShaderModuleDescriptor { + label: None, + source: wgpu::ShaderSource::Wgsl(SHADER_SRC.into()), + }); + + let pipeline = ctx + .device + .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: None, + layout: None, + module: &module, + entry_point: "main", + compilation_options: Default::default(), + cache: None, + }); + + let out_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor { + label: None, + size: 12, + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC, + mapped_at_creation: false, + }); + + let readback_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor { + label: None, + size: 12, + usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ, + mapped_at_creation: false, + }); + + let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor { + label: None, + layout: &pipeline.get_bind_group_layout(0), + entries: &[wgpu::BindGroupEntry { + binding: 0, + resource: out_buffer.as_entire_binding(), + }], + }); + + let mut res = None; + + for (indirect_offset, indirect_buffer_size) in [ + // internal src buffer binding size will be buffer.size + (0, 12), + (4, 4 + 12), + (4, 8 + 12), + (256 * 2 - 4 - 12, 256 * 2 - 4), + // internal src buffer binding size will be 256 * 2 + x + (0, 256 * 2 * 2 + 4), + (256, 256 * 2 * 2 + 8), + (256 + 4, 256 * 2 * 2 + 12), + (256 * 2 + 16, 256 * 2 * 2 + 16), + (256 * 2 * 2, 256 * 2 * 2 + 32), + (256 + 12, 256 * 2 * 2 + 64), + ] { + let indirect_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor { + label: None, + size: indirect_buffer_size, + usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::INDIRECT, + mapped_at_creation: false, + }); + + ctx.queue.write_buffer( + &indirect_buffer, + indirect_offset, + bytemuck::bytes_of(num_workgroups), + ); + + let mut encoder = ctx + .device + .create_command_encoder(&wgpu::CommandEncoderDescriptor::default()); + { + let mut compute_pass = + encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default()); + compute_pass.set_pipeline(&pipeline); + if !forget_to_set_bind_group { + compute_pass.set_bind_group(0, &bind_group, &[]); + } + compute_pass.dispatch_workgroups_indirect(&indirect_buffer, indirect_offset); + } + + encoder.copy_buffer_to_buffer(&out_buffer, 0, &readback_buffer, 0, 12); + + ctx.queue.submit(Some(encoder.finish())); + + readback_buffer + .slice(..) + .map_async(wgpu::MapMode::Read, |_| {}); + + ctx.async_poll(wgpu::Maintain::wait()) + .await + .panic_on_timeout(); + + let view = readback_buffer.slice(..).get_mapped_range(); + + let current_res = *bytemuck::from_bytes(&view); + drop(view); + readback_buffer.unmap(); + + if let Some(past_res) = res { + assert_eq!(past_res, current_res); + } else { + res = Some(current_res); + } + } + + res.unwrap() +} diff --git a/tests/tests/root.rs b/tests/tests/root.rs index 1cb5b56c7c0..d5f55132699 100644 --- a/tests/tests/root.rs +++ b/tests/tests/root.rs @@ -14,6 +14,7 @@ mod clear_texture; mod compute_pass_ownership; mod create_surface_error; mod device; +mod dispatch_workgroups_indirect; mod encoder; mod external_texture; mod float32_filterable; diff --git a/wgpu-core/Cargo.toml b/wgpu-core/Cargo.toml index f8c28b8793f..24da632a8dd 100644 --- a/wgpu-core/Cargo.toml +++ b/wgpu-core/Cargo.toml @@ -46,6 +46,10 @@ renderdoc = ["hal/renderdoc"] ## to the validation carried out at public APIs in all builds. strict_asserts = ["wgt/strict_asserts"] +## Validates indirect draw/dispatch calls. This will also enable naga's +## WGSL frontend since we use a WGSL compute shader to do the validation. +indirect-validation = ["naga/wgsl-in"] + ## Enables serialization via `serde` on common wgpu types. serde = ["dep:serde", "wgt/serde", "arrayvec/serde"] diff --git a/wgpu-core/src/binding_model.rs b/wgpu-core/src/binding_model.rs index 732c152dcfd..36ecfd5bf5f 100644 --- a/wgpu-core/src/binding_model.rs +++ b/wgpu-core/src/binding_model.rs @@ -617,6 +617,10 @@ pub struct PipelineLayoutDescriptor<'a> { /// [`Features::PUSH_CONSTANTS`](wgt::Features::PUSH_CONSTANTS) feature must /// be enabled. pub push_constant_ranges: Cow<'a, [wgt::PushConstantRange]>, + /// This is an internal flag used by indirect validation. + /// It allows usage of push constants without having the + /// [`Features::PUSH_CONSTANTS`](wgt::Features::PUSH_CONSTANTS) feature enabled. + pub ignore_push_constant_check: bool, } #[derive(Debug)] @@ -758,6 +762,9 @@ pub struct BufferBinding { pub buffer_id: BufferId, pub offset: wgt::BufferAddress, pub size: Option, + /// This is an internal flag used by indirect validation. + /// It allows indirect buffers to be bound as storage buffers. + pub allow_indirect_as_storage: bool, } // Note: Duplicated in `wgpu-rs` as `BindingResource` diff --git a/wgpu-core/src/command/bind.rs b/wgpu-core/src/command/bind.rs index c643611a967..75d566c5963 100644 --- a/wgpu-core/src/command/bind.rs +++ b/wgpu-core/src/command/bind.rs @@ -131,7 +131,7 @@ mod compat { diff.push(format!("Expected {expected_bgl_type} bind group layout, got {assigned_bgl_type}")) } } else { - diff.push("Assigned bind group layout not found (internal error)".to_owned()); + diff.push("Expected bind group is missing".to_owned()); } } else { diff.push("Expected bind group layout not found (internal error)".to_owned()); @@ -191,6 +191,10 @@ mod compat { self.make_range(index) } + pub fn unassign(&mut self, index: usize) { + self.entries[index].assigned = None; + } + pub fn list_active(&self) -> impl Iterator + '_ { self.entries .iter() @@ -358,6 +362,14 @@ impl Binder { &self.payloads[bind_range] } + pub(super) fn unassign_group(&mut self, index: usize) { + log::trace!("\tBinding [{}] = null", index); + + self.payloads[index].reset(); + + self.manager.unassign(index); + } + pub(super) fn list_active<'a>(&'a self) -> impl Iterator>> + '_ { let payloads = &self.payloads; self.manager diff --git a/wgpu-core/src/command/compute.rs b/wgpu-core/src/command/compute.rs index acbff0a0304..24cd716a007 100644 --- a/wgpu-core/src/command/compute.rs +++ b/wgpu-core/src/command/compute.rs @@ -182,6 +182,8 @@ pub enum ComputePassErrorInner { InvalidQuerySet(id::QuerySetId), #[error("Indirect buffer {0:?} is invalid or destroyed")] InvalidIndirectBuffer(id::BufferId), + #[error("Indirect buffer offset {0:?} is not a multiple of 4")] + UnalignedIndirectBufferOffset(BufferAddress), #[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")] IndirectBufferOverrun { offset: u64, @@ -473,6 +475,16 @@ impl Global { .map_pass_err(pass_scope); } + #[cfg(feature = "indirect-validation")] + let mut base = base; + #[cfg(feature = "indirect-validation")] + device + .indirect_validation + .get() + .unwrap() + .inject_dispatch_indirect_validation(device, &mut base) + .map_pass_err(pass_scope)?; + let mut cmd_buf_data = cmd_buf.data.lock(); let cmd_buf_data = cmd_buf_data.as_mut().unwrap(); @@ -654,6 +666,20 @@ impl Global { } } } + ArcComputeCommand::UnsetBindGroup { index } => { + let scope = PassErrorScope::UnsetBindGroup(index); + + let max_bind_groups = cmd_buf.limits.max_bind_groups; + if index >= max_bind_groups { + return Err(ComputePassErrorInner::BindGroupIndexOutOfRange { + index, + max: max_bind_groups, + }) + .map_pass_err(scope); + } + + state.binder.unassign_group(index as usize); + } ArcComputeCommand::SetPipeline(pipeline) => { let pipeline_id = pipeline.as_info().id(); let scope = PassErrorScope::SetPipelineCompute(pipeline_id); @@ -811,6 +837,11 @@ impl Global { check_buffer_usage(buffer_id, buffer.usage, wgt::BufferUsages::INDIRECT) .map_pass_err(scope)?; + if offset % 4 != 0 { + return Err(ComputePassErrorInner::UnalignedIndirectBufferOffset(offset)) + .map_pass_err(scope); + } + let end_offset = offset + mem::size_of::() as u64; if end_offset > buffer.size { return Err(ComputePassErrorInner::IndirectBufferOverrun { diff --git a/wgpu-core/src/command/compute_command.rs b/wgpu-core/src/command/compute_command.rs index 49fdbbec24d..fd3b531628e 100644 --- a/wgpu-core/src/command/compute_command.rs +++ b/wgpu-core/src/command/compute_command.rs @@ -19,6 +19,10 @@ pub enum ComputeCommand { bind_group_id: id::BindGroupId, }, + UnsetBindGroup { + index: u32, + }, + SetPipeline(id::ComputePipelineId), /// Set a range of push constants to values stored in `push_constant_data`. @@ -103,6 +107,10 @@ impl ComputeCommand { })?, }, + ComputeCommand::UnsetBindGroup { index } => { + ArcComputeCommand::UnsetBindGroup { index } + } + ComputeCommand::SetPipeline(pipeline_id) => ArcComputeCommand::SetPipeline( pipelines_guard .get_owned(pipeline_id) @@ -194,6 +202,10 @@ pub enum ArcComputeCommand { bind_group: Arc>, }, + UnsetBindGroup { + index: u32, + }, + SetPipeline(Arc>), /// Set a range of push constants to values stored in `push_constant_data`. @@ -261,6 +273,10 @@ impl From<&ArcComputeCommand> for ComputeCommand { bind_group_id: bind_group.as_info().id(), }, + ArcComputeCommand::UnsetBindGroup { index } => { + ComputeCommand::UnsetBindGroup { index: *index } + } + ArcComputeCommand::SetPipeline(pipeline) => { ComputeCommand::SetPipeline(pipeline.as_info().id()) } diff --git a/wgpu-core/src/command/mod.rs b/wgpu-core/src/command/mod.rs index 874e207a278..9c2e5d35d7f 100644 --- a/wgpu-core/src/command/mod.rs +++ b/wgpu-core/src/command/mod.rs @@ -15,8 +15,15 @@ use std::sync::Arc; pub(crate) use self::clear::clear_texture; pub use self::{ - bundle::*, clear::ClearError, compute::*, compute_command::ComputeCommand, draw::*, - dyn_compute_pass::DynComputePass, query::*, render::*, transfer::*, + bundle::*, + clear::ClearError, + compute::*, + compute_command::{ArcComputeCommand, ComputeCommand}, + draw::*, + dyn_compute_pass::DynComputePass, + query::*, + render::*, + transfer::*, }; pub(crate) use allocator::CommandAllocator; @@ -892,6 +899,8 @@ pub enum PassErrorScope { Pass(Option), #[error("In a set_bind_group command")] SetBindGroup(id::BindGroupId), + #[error("In a unset_bind_group command, slot: {0}")] + UnsetBindGroup(u32), #[error("In a set_pipeline command")] SetPipelineRender(id::RenderPipelineId), #[error("In a set_pipeline command")] diff --git a/wgpu-core/src/command/render.rs b/wgpu-core/src/command/render.rs index defd6a608ba..71b1beeb0e8 100644 --- a/wgpu-core/src/command/render.rs +++ b/wgpu-core/src/command/render.rs @@ -616,6 +616,8 @@ pub enum RenderPassErrorInner { MissingFeatures(#[from] MissingFeatures), #[error(transparent)] MissingDownlevelFlags(#[from] MissingDownlevelFlags), + #[error("Indirect buffer offset {0:?} is not a multiple of 4")] + UnalignedIndirectBufferOffset(BufferAddress), #[error("Indirect draw uses bytes {offset}..{end_offset} {} which overruns indirect buffer of size {buffer_size}", count.map_or_else(String::new, |v| format!("(using count {v})")))] IndirectBufferOverrun { @@ -2050,6 +2052,13 @@ impl Global { let actual_count = count.map_or(1, |c| c.get()); + if offset % 4 != 0 { + return Err(RenderPassErrorInner::UnalignedIndirectBufferOffset( + offset, + )) + .map_pass_err(scope); + } + let end_offset = offset + stride as u64 * actual_count as u64; if end_offset > indirect_buffer.size { return Err(RenderPassErrorInner::IndirectBufferOverrun { @@ -2141,6 +2150,13 @@ impl Global { .ok_or(RenderCommandError::DestroyedBuffer(count_buffer_id)) .map_pass_err(scope)?; + if offset % 4 != 0 { + return Err(RenderPassErrorInner::UnalignedIndirectBufferOffset( + offset, + )) + .map_pass_err(scope); + } + let end_offset = offset + stride * max_count as u64; if end_offset > indirect_buffer.size { return Err(RenderPassErrorInner::IndirectBufferOverrun { diff --git a/wgpu-core/src/device/global.rs b/wgpu-core/src/device/global.rs index a5c51b269f7..5c460e3f817 100644 --- a/wgpu-core/src/device/global.rs +++ b/wgpu-core/src/device/global.rs @@ -262,6 +262,51 @@ impl Global { let (id, resource) = fid.assign(Arc::new(buffer)); api_log!("Device::create_buffer({desc:?}) -> {id:?}"); + #[cfg(feature = "indirect-validation")] + if desc.usage.contains(wgt::BufferUsages::INDIRECT) { + // We create an indirect buffer in IndirectValidation's constructor, + // device.indirect_validation won't yet be set. + if let Some(indirect_validation) = device.indirect_validation.get() { + let binding_size = + crate::indirect_validation::IndirectValidation::calculate_src_buffer_binding_size(&device, &resource); + let (bg_id, error) = self.device_create_bind_group::( + device_id, + &crate::binding_model::BindGroupDescriptor { + label: None, + layout: indirect_validation.src_bind_group_layout, + entries: std::borrow::Cow::Borrowed(&[ + crate::binding_model::BindGroupEntry { + binding: 0, + resource: crate::binding_model::BindingResource::Buffer( + crate::binding_model::BufferBinding { + buffer_id: id, + offset: 0, + size: Some( + std::num::NonZeroU64::new(binding_size).unwrap(), + ), + allow_indirect_as_storage: true, + }, + ), + }, + ]), + }, + None, + ); + if let Some(error) = error { + let buffer = hub.buffers.write().replace_with_error(id).unwrap(); + device + .lock_life() + .schedule_resource_destruction(queue::TempResource::Buffer(buffer), !0); + return ( + id, + Some(CreateBufferError::IndirectValidationBindGroup(error)), + ); + } + let bg = hub.bind_groups.write().remove(bg_id).unwrap(); + resource.indirect_validation_bind_group.set(bg).unwrap(); + } + } + device .trackers .lock() @@ -1215,7 +1260,7 @@ impl Global { #[cfg(feature = "trace")] if let Some(ref mut trace) = *device.trace.lock() { let data = match source { - #[cfg(feature = "wgsl")] + #[cfg(any(feature = "wgsl", feature = "indirect-validation"))] pipeline::ShaderModuleSource::Wgsl(ref code) => { trace.make_binary("wgsl", code.as_bytes()) } diff --git a/wgpu-core/src/device/resource.rs b/wgpu-core/src/device/resource.rs index f9242848c87..2ca2cc495c3 100644 --- a/wgpu-core/src/device/resource.rs +++ b/wgpu-core/src/device/resource.rs @@ -96,6 +96,8 @@ pub struct Device { pub(crate) queue: OnceCell>>, queue_to_drop: OnceCell, pub(crate) zero_buffer: Option, + #[cfg(feature = "indirect-validation")] + pub(crate) indirect_validation: OnceCell>, pub(crate) info: ResourceInfo>, pub(crate) command_allocator: command::CommandAllocator, @@ -270,6 +272,8 @@ impl Device { queue: OnceCell::new(), queue_to_drop: OnceCell::new(), zero_buffer: Some(zero_buffer), + #[cfg(feature = "indirect-validation")] + indirect_validation: OnceCell::new(), info: ResourceInfo::new("", None), command_allocator, active_submission_index: AtomicU64::new(0), @@ -588,6 +592,12 @@ impl Device { return Err(resource::CreateBufferError::InvalidUsage(desc.usage)); } + if desc.usage.contains(wgt::BufferUsages::INDIRECT) { + // We are going to be reading from it, internally; + // when validating the content of the buffer + usage |= hal::BufferUses::STORAGE_READ | hal::BufferUses::STORAGE_READ_WRITE; + } + if !self .features .contains(wgt::Features::MAPPABLE_PRIMARY_BUFFERS) @@ -659,6 +669,8 @@ impl Device { Some(self.tracker_indices.buffers.clone()), ), bind_groups: Mutex::new(rank::BUFFER_BIND_GROUPS, Vec::new()), + #[cfg(feature = "indirect-validation")] + indirect_validation_bind_group: OnceCell::new(), }) } @@ -719,6 +731,8 @@ impl Device { Some(self.tracker_indices.buffers.clone()), ), bind_groups: Mutex::new(rank::BUFFER_BIND_GROUPS, Vec::new()), + #[cfg(feature = "indirect-validation")] + indirect_validation_bind_group: OnceCell::new(), } } @@ -1416,7 +1430,7 @@ impl Device { source: pipeline::ShaderModuleSource<'a>, ) -> Result, pipeline::CreateShaderModuleError> { let (module, source) = match source { - #[cfg(feature = "wgsl")] + #[cfg(any(feature = "wgsl", feature = "indirect-validation"))] pipeline::ShaderModuleSource::Wgsl(code) => { profiling::scope!("naga::front::wgsl::parse_str"); let module = naga::front::wgsl::parse_str(&code).map_err(|inner| { @@ -1486,7 +1500,12 @@ impl Device { }; let info = create_validator( - self.features, + self.features + | if desc.ignore_push_constant_check { + wgt::Features::PUSH_CONSTANTS + } else { + wgt::Features::empty() + }, self.downlevel.flags, naga::valid::ValidationFlags::all(), ) @@ -1866,7 +1885,7 @@ impl Device { } }; - let (pub_usage, internal_use, range_limit) = match binding_ty { + let (mut pub_usage, internal_use, range_limit) = match binding_ty { wgt::BufferBindingType::Uniform => ( wgt::BufferUsages::UNIFORM, hal::BufferUses::UNIFORM, @@ -1902,7 +1921,14 @@ impl Device { return Err(DeviceError::WrongDevice.into()); } + // Allow indirect buffers to be bound as storage buffers so that + // we can validate their content. Note that we already pass this + // usage to hal at buffer creation. + if buffer.usage.contains(wgt::BufferUsages::INDIRECT) && bb.allow_indirect_as_storage { + pub_usage = pub_usage.difference(wgt::BufferUsages::STORAGE); + } check_buffer_usage(bb.buffer_id, buffer.usage, pub_usage)?; + let raw_buffer = buffer .raw .get(snatch_guard) @@ -2485,7 +2511,7 @@ impl Device { }); } - if !desc.push_constant_ranges.is_empty() { + if !desc.ignore_push_constant_check && !desc.push_constant_ranges.is_empty() { self.require_features(wgt::Features::PUSH_CONSTANTS)?; } @@ -2500,13 +2526,15 @@ impl Device { } used_stages |= pc.stages; - let device_max_pc_size = self.limits.max_push_constant_size; - if device_max_pc_size < pc.range.end { - return Err(Error::PushConstantRangeTooLarge { - index, - range: pc.range.clone(), - max: device_max_pc_size, - }); + if !desc.ignore_push_constant_check { + let device_max_pc_size = self.limits.max_push_constant_size; + if device_max_pc_size < pc.range.end { + return Err(Error::PushConstantRangeTooLarge { + index, + range: pc.range.clone(), + max: device_max_pc_size, + }); + } } if pc.range.start % wgt::PUSH_CONSTANT_ALIGNMENT != 0 { @@ -2617,6 +2645,7 @@ impl Device { label: None, bind_group_layouts: Cow::Borrowed(&ids.group_ids[..group_count]), push_constant_ranges: Cow::Borrowed(&[]), //TODO? + ignore_push_constant_check: false, }; let layout = self.create_pipeline_layout(&layout_desc, bgl_registry)?; pipeline_layout_registry.force_replace(ids.root_id, layout); diff --git a/wgpu-core/src/indirect_validation.rs b/wgpu-core/src/indirect_validation.rs new file mode 100644 index 00000000000..7d97d111bea --- /dev/null +++ b/wgpu-core/src/indirect_validation.rs @@ -0,0 +1,423 @@ +use std::{mem, sync::Arc}; + +use thiserror::Error; +use wgt::{PushConstantRange, ShaderStages}; + +use crate::{ + binding_model::{ + BindGroup, CreateBindGroupError, CreateBindGroupLayoutError, CreatePipelineLayoutError, + }, + command::{ArcComputeCommand, BasePass, ComputePassErrorInner}, + device::{Device, DeviceError}, + global::Global, + hal_api::HalApi, + id, + pipeline::{ComputePipeline, CreateComputePipelineError, CreateShaderModuleError}, + resource::{Buffer, CreateBufferError, Resource}, + validation::check_buffer_usage, +}; + +#[derive(Clone, Debug, Error)] +#[non_exhaustive] +pub enum CreateDispatchIndirectValidationPipelineError { + #[error(transparent)] + Device(#[from] DeviceError), + #[error(transparent)] + ShaderModule(#[from] CreateShaderModuleError), + #[error(transparent)] + BindGroupLayout(#[from] CreateBindGroupLayoutError), + #[error(transparent)] + PipelineLayout(#[from] CreatePipelineLayoutError), + #[error(transparent)] + ComputePipeline(#[from] CreateComputePipelineError), + #[error(transparent)] + Buffer(#[from] CreateBufferError), + #[error(transparent)] + BindGroup(#[from] CreateBindGroupError), + #[error("invalid id")] + InvalidId, +} + +#[derive(Debug)] +pub struct IndirectValidation { + pub pipeline: Arc>, + pub dst_buffer: Arc>, + pub dst_bind_group: Arc>, + pub src_bind_group_layout: id::BindGroupLayoutId, +} + +impl IndirectValidation { + pub fn new( + global: &Global, + device: Arc>, + ) -> Result { + let max_compute_workgroups_per_dimension = + device.limits.max_compute_workgroups_per_dimension; + + let src = format!(" + @group(0) @binding(0) + var dst: array; + @group(1) @binding(0) + var src: array; + struct OffsetPc {{ + inner: u32, + }} + var offset: OffsetPc; + + @compute @workgroup_size(1) + fn main() {{ + let src = vec3(src[offset.inner], src[offset.inner + 1], src[offset.inner + 2]); + let res = select(src, vec3(), src > vec3({max_compute_workgroups_per_dimension}u)); + dst[0] = res.x; + dst[1] = res.y; + dst[2] = res.z; + }} + "); + + let device_id = device.info.id(); + + let (module, error) = global.device_create_shader_module::( + device_id, + &crate::pipeline::ShaderModuleDescriptor { + label: None, + shader_bound_checks: wgt::ShaderBoundChecks::default(), + ignore_push_constant_check: true, + }, + crate::pipeline::ShaderModuleSource::Wgsl(std::borrow::Cow::Owned(src)), + None, + ); + if let Some(error) = error { + return Err(error.into()); + } + + let (dst_bind_group_layout, error) = global.device_create_bind_group_layout::( + device_id, + &crate::binding_model::BindGroupLayoutDescriptor { + label: None, + entries: std::borrow::Cow::Borrowed(&[wgt::BindGroupLayoutEntry { + binding: 0, + visibility: wgt::ShaderStages::COMPUTE, + ty: wgt::BindingType::Buffer { + ty: wgt::BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()), + }, + count: None, + }]), + }, + None, + ); + if let Some(error) = error { + return Err(error.into()); + } + + let (src_bind_group_layout, error) = global.device_create_bind_group_layout::( + device_id, + &crate::binding_model::BindGroupLayoutDescriptor { + label: None, + entries: std::borrow::Cow::Borrowed(&[wgt::BindGroupLayoutEntry { + binding: 0, + visibility: wgt::ShaderStages::COMPUTE, + ty: wgt::BindingType::Buffer { + ty: wgt::BufferBindingType::Storage { read_only: true }, + has_dynamic_offset: true, + min_binding_size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()), + }, + count: None, + }]), + }, + None, + ); + if let Some(error) = error { + return Err(error.into()); + } + + let (layout, error) = global.device_create_pipeline_layout::( + device_id, + &crate::binding_model::PipelineLayoutDescriptor { + label: None, + bind_group_layouts: std::borrow::Cow::Borrowed(&[ + dst_bind_group_layout, + src_bind_group_layout, + ]), + push_constant_ranges: std::borrow::Cow::Borrowed(&[PushConstantRange { + stages: ShaderStages::COMPUTE, + range: 0..4, + }]), + ignore_push_constant_check: true, + }, + None, + ); + if let Some(error) = error { + return Err(error.into()); + } + + let (pipeline, error) = global.device_create_compute_pipeline::( + device_id, + &crate::pipeline::ComputePipelineDescriptor { + label: None, + layout: Some(layout), + stage: crate::pipeline::ProgrammableStageDescriptor { + module, + entry_point: Some(std::borrow::Cow::Borrowed("main")), + constants: Default::default(), + zero_initialize_workgroup_memory: true, + vertex_pulling_transform: true, + }, + cache: None, + }, + None, + None, + ); + if let Some(error) = error { + return Err(error.into()); + } + + let (dst_buffer_id, error) = global.device_create_buffer::( + device_id, + &crate::resource::BufferDescriptor { + label: None, + size: 4 * 3, + usage: wgt::BufferUsages::INDIRECT | wgt::BufferUsages::STORAGE, + mapped_at_creation: false, + }, + None, + ); + if let Some(error) = error { + return Err(error.into()); + } + + let (dst_bind_group_id, error) = global.device_create_bind_group::( + device_id, + &crate::binding_model::BindGroupDescriptor { + label: None, + layout: dst_bind_group_layout, + entries: std::borrow::Cow::Borrowed(&[crate::binding_model::BindGroupEntry { + binding: 0, + resource: crate::binding_model::BindingResource::Buffer( + crate::binding_model::BufferBinding { + buffer_id: dst_buffer_id, + offset: 0, + size: Some(std::num::NonZeroU64::new(4 * 3).unwrap()), + allow_indirect_as_storage: false, + }, + ), + }]), + }, + None, + ); + if let Some(error) = error { + return Err(error.into()); + } + + let hub = A::hub(global); + + let pipeline = hub + .compute_pipelines + .write() + .remove(pipeline) + .ok_or(CreateDispatchIndirectValidationPipelineError::InvalidId)?; + + let dst_buffer = hub + .buffers + .write() + .remove(dst_buffer_id) + .ok_or(CreateDispatchIndirectValidationPipelineError::InvalidId)?; + + let dst_bind_group = hub + .bind_groups + .write() + .remove(dst_bind_group_id) + .ok_or(CreateDispatchIndirectValidationPipelineError::InvalidId)?; + + Ok(Self { + pipeline, + dst_buffer, + dst_bind_group, + src_bind_group_layout, + }) + } + + pub(crate) fn calculate_src_buffer_binding_size(device: &Device, buffer: &Buffer) -> u64 { + let alignment = device.limits.min_storage_buffer_offset_alignment as u64; + + // We need to choose a binding size that can address all possible sets of 12 contiguous bytes in the buffer taking + // into account that the dynamic offset needs to be a multiple of `min_storage_buffer_offset_alignment`. + + // Given the know variables: `offset`, `buffer_size`, `alignment` and the rule `offset + 12 <= buffer_size`. + + // Let `chunks = floor(buffer_size / alignment)`. + // Let `chunk` be the interval `[0, chunks]`. + // Let `offset = alignment * chunk + r` where `r` is the interval [0, alignment - 4]. + // Let `binding` be the interval `[offset, offset + 12]`. + // Let `aligned_offset = alignment * chunk`. + // Let `aligned_binding` be the interval `[aligned_offset, aligned_offset + r + 12]`. + // Let `aligned_binding_size = r + 12 = [12, alignment + 8]`. + // Let `min_aligned_binding_size = alignment + 8`. + + // `min_aligned_binding_size` is the minimum binding size required to address all 12 contiguous bytes in the buffer + // but the last aligned_offset + min_aligned_binding_size might overflow the buffer. In order to avoid this we must + // pick a larger `binding_size` that satisfies: `last_aligned_offset + binding_size = buffer_size` and + // `binding_size >= min_aligned_binding_size`. + + // Let `buffer_size = alignment * chunks + sr` where `sr` is the interval [0, alignment - 4]. + // Let `last_aligned_offset = alignment * (chunks - u)` where `u` is the interval [0, chunks]. + // => `binding_size = buffer_size - last_aligned_offset` + // => `binding_size = alignment * chunks + sr - alignment * (chunks - u)` + // => `binding_size = alignment * chunks + sr - alignment * chunks + alignment * u` + // => `binding_size = sr + alignment * u` + // => `min_aligned_binding_size <= sr + alignment * u` + // => `alignment + 8 <= sr + alignment * u` + // => `u` must be at least 2 + // => `binding_size = sr + alignment * 2` + + let binding_size = 2 * alignment + (buffer.size % alignment); + binding_size.min(buffer.size) + } + + pub fn inject_dispatch_indirect_validation( + &self, + device: &Device, + base: &mut BasePass>, + ) -> Result<(), ComputePassErrorInner> { + if !base + .commands + .iter() + .any(|cmd| matches!(cmd, ArcComputeCommand::DispatchIndirect { .. })) + { + return Ok(()); + } + + profiling::scope!("CommandEncoder::inject_dispatch_indirect_validation"); + + let mut new_commands = Vec::with_capacity(base.commands.len()); + let mut current_pipeline = None; + let mut current_first_2_bind_groups = [None, None]; + + for command in base.commands.drain(..) { + match command { + ArcComputeCommand::SetBindGroup { + index, + num_dynamic_offsets, + ref bind_group, + } => { + if index == 0 || index == 1 { + current_first_2_bind_groups[index as usize] = + Some((num_dynamic_offsets, bind_group.clone())); + } + new_commands.push(command); + } + ArcComputeCommand::SetPipeline(ref pipeline) => { + current_pipeline = Some(pipeline.clone()); + new_commands.push(command); + } + ArcComputeCommand::DispatchIndirect { ref buffer, offset } => { + // if there is no pipeline set, don't inject the validation commands as we will error anyway + if let Some(original_pipeline) = current_pipeline.clone() { + // validate some buffer properties that won't be validated later + check_buffer_usage( + buffer.as_info().id(), + buffer.usage, + wgt::BufferUsages::INDIRECT, + )?; + + if offset % 4 != 0 { + return Err(ComputePassErrorInner::UnalignedIndirectBufferOffset( + offset, + )); + } + + let end_offset = + offset + mem::size_of::() as u64; + if end_offset > buffer.size { + return Err(ComputePassErrorInner::IndirectBufferOverrun { + offset, + end_offset, + buffer_size: buffer.size, + }); + } + + // The offset we receive is only required to be aligned to 4 bytes. + // + // Binding offsets and dynamic offsets are required to be aligned to + // min_storage_buffer_offset_alignment (256 bytes by default). + // + // So, we work around this limitation by calculating an aligned offset + // and pass the remainder through a push constant. + // + // We could bind the whole buffer and only have to pass the offset + // through a push constant but we might run into the + // max_storage_buffer_binding_size limit. + // + // See the inner docs of `calculate_src_buffer_binding_size` to + // see how we get the appropriate `binding_size`. + let alignment = device.limits.min_storage_buffer_offset_alignment as u64; + let binding_size = Self::calculate_src_buffer_binding_size(device, buffer); + let aligned_offset = offset - offset % alignment; + // This works because `binding_size` is either `buffer.size` or `alignment * 2 + buffer.size % alignment`. + let max_aligned_offset = buffer.size - binding_size; + let aligned_offset = aligned_offset.min(max_aligned_offset); + let offset_remainder = offset - aligned_offset; + + new_commands.push(ArcComputeCommand::SetPipeline(self.pipeline.clone())); + + base.dynamic_offsets.push(aligned_offset as u32); + + let values_offset = base.push_constant_data.len() as u32; + base.push_constant_data.push(offset_remainder as u32 / 4); + + new_commands.push(ArcComputeCommand::SetPushConstant { + offset: 0, + size_bytes: 4, + values_offset, + }); + new_commands.push(ArcComputeCommand::SetBindGroup { + index: 0, + num_dynamic_offsets: 0, + bind_group: self.dst_bind_group.clone(), + }); + new_commands.push(ArcComputeCommand::SetBindGroup { + index: 1, + num_dynamic_offsets: 1, + bind_group: buffer + .indirect_validation_bind_group + .get() + .unwrap() + .clone(), + }); + new_commands.push(ArcComputeCommand::Dispatch([1, 1, 1])); + + new_commands.push(ArcComputeCommand::SetPipeline(original_pipeline)); + for (index, current_bind_group) in + current_first_2_bind_groups.iter().enumerate() + { + if let Some((num_dynamic_offsets, bind_group)) = + current_bind_group.clone() + { + new_commands.push(ArcComputeCommand::SetBindGroup { + index: index as u32, + num_dynamic_offsets, + bind_group, + }); + } else { + new_commands.push(ArcComputeCommand::UnsetBindGroup { + index: index as u32, + }); + } + } + new_commands.push(ArcComputeCommand::DispatchIndirect { + buffer: self.dst_buffer.clone(), + offset: 0, + }); + } else { + new_commands.push(command) + } + } + command => new_commands.push(command), + } + } + base.commands = new_commands; + + Ok(()) + } +} diff --git a/wgpu-core/src/instance.rs b/wgpu-core/src/instance.rs index 5d21ed0398f..86965e52211 100644 --- a/wgpu-core/src/instance.rs +++ b/wgpu-core/src/instance.rs @@ -1109,10 +1109,9 @@ impl Global { Ok((device, queue)) => (device, queue), Err(e) => break e, }; - let (device_id, _) = device_fid.assign(Arc::new(device)); + let (device_id, device) = device_fid.assign(Arc::new(device)); resource_log!("Created Device {:?}", device_id); - let device = hub.devices.get(device_id).unwrap(); queue.device = Some(device.clone()); let (queue_id, queue) = queue_fid.assign(Arc::new(queue)); @@ -1120,6 +1119,20 @@ impl Global { device.set_queue(queue); + #[cfg(feature = "indirect-validation")] + match crate::indirect_validation::IndirectValidation::new(self, device.clone()) { + Ok(indirect_validation) => { + device.indirect_validation.set(indirect_validation).unwrap(); + } + Err(_) => { + hub.devices + .force_replace_with_error(device_id, desc.label.borrow_or_default()); + hub.queues + .force_replace_with_error(queue_id, desc.label.borrow_or_default()); + return (device_id, queue_id, Some(RequestDeviceError::Internal)); + } + }; + return (device_id, queue_id, None); }; @@ -1161,10 +1174,9 @@ impl Global { Ok(device) => device, Err(e) => break e, }; - let (device_id, _) = devices_fid.assign(Arc::new(device)); + let (device_id, device) = devices_fid.assign(Arc::new(device)); resource_log!("Created Device {:?}", device_id); - let device = hub.devices.get(device_id).unwrap(); queue.device = Some(device.clone()); let (queue_id, queue) = queues_fid.assign(Arc::new(queue)); @@ -1172,6 +1184,20 @@ impl Global { device.set_queue(queue); + #[cfg(feature = "indirect-validation")] + match crate::indirect_validation::IndirectValidation::new(self, device.clone()) { + Ok(indirect_validation) => { + device.indirect_validation.set(indirect_validation).unwrap(); + } + Err(_) => { + hub.devices + .force_replace_with_error(device_id, desc.label.borrow_or_default()); + hub.queues + .force_replace_with_error(queue_id, desc.label.borrow_or_default()); + return (device_id, queue_id, Some(RequestDeviceError::Internal)); + } + }; + return (device_id, queue_id, None); }; diff --git a/wgpu-core/src/lib.rs b/wgpu-core/src/lib.rs index ebf80091c3f..ad85b5d0cdd 100644 --- a/wgpu-core/src/lib.rs +++ b/wgpu-core/src/lib.rs @@ -61,6 +61,8 @@ mod hash_utils; pub mod hub; pub mod id; pub mod identity; +#[cfg(feature = "indirect-validation")] +mod indirect_validation; mod init_tracker; pub mod instance; mod lock; diff --git a/wgpu-core/src/pipeline.rs b/wgpu-core/src/pipeline.rs index f3e7dbacb27..ab3f5045610 100644 --- a/wgpu-core/src/pipeline.rs +++ b/wgpu-core/src/pipeline.rs @@ -26,7 +26,7 @@ pub(crate) struct LateSizedBufferGroup { #[allow(clippy::large_enum_variant)] pub enum ShaderModuleSource<'a> { - #[cfg(feature = "wgsl")] + #[cfg(any(feature = "wgsl", feature = "indirect-validation"))] Wgsl(Cow<'a, str>), #[cfg(feature = "glsl")] Glsl(Cow<'a, str>, naga::front::glsl::Options), @@ -45,6 +45,10 @@ pub struct ShaderModuleDescriptor<'a> { pub label: Label<'a>, #[cfg_attr(feature = "serde", serde(default))] pub shader_bound_checks: wgt::ShaderBoundChecks, + /// This is an internal flag used by indirect validation. + /// It allows usage of push constants without having the + /// [`Features::PUSH_CONSTANTS`](wgt::Features::PUSH_CONSTANTS) feature enabled. + pub ignore_push_constant_check: bool, } #[derive(Debug)] @@ -113,7 +117,7 @@ impl ShaderModule { #[derive(Clone, Debug, Error)] #[non_exhaustive] pub enum CreateShaderModuleError { - #[cfg(feature = "wgsl")] + #[cfg(any(feature = "wgsl", feature = "indirect-validation"))] #[error(transparent)] Parsing(#[from] ShaderError), #[cfg(feature = "glsl")] diff --git a/wgpu-core/src/resource.rs b/wgpu-core/src/resource.rs index 9ae275615ad..d60485c8c4e 100644 --- a/wgpu-core/src/resource.rs +++ b/wgpu-core/src/resource.rs @@ -22,6 +22,7 @@ use crate::{ }; use hal::CommandEncoder; +use once_cell::sync::OnceCell; use smallvec::SmallVec; use thiserror::Error; use wgt::WasmNotSendSync; @@ -400,6 +401,8 @@ pub struct Buffer { pub(crate) info: ResourceInfo>, pub(crate) map_state: Mutex>, pub(crate) bind_groups: Mutex>>>, + #[cfg(feature = "indirect-validation")] + pub(crate) indirect_validation_bind_group: OnceCell>>, } impl Drop for Buffer { @@ -611,6 +614,8 @@ pub enum CreateBufferError { MaxBufferSize { requested: u64, maximum: u64 }, #[error(transparent)] MissingDownlevelFlags(#[from] MissingDownlevelFlags), + #[error("Failed to create bind group for indirect buffer validation: {0}")] + IndirectValidationBindGroup(#[from] crate::binding_model::CreateBindGroupError), } impl Resource for Buffer { diff --git a/wgpu/Cargo.toml b/wgpu/Cargo.toml index 81927f0a632..05c287dcd60 100644 --- a/wgpu/Cargo.toml +++ b/wgpu/Cargo.toml @@ -124,6 +124,12 @@ features = ["raw-window-handle"] workspace = true features = ["raw-window-handle"] +# If we are not targeting WebGL, enable indirect-validation. +# WebGL doesn't support indirect execution so this is not needed. +[target.'cfg(not(target_arch = "wasm32"))'.dependencies.wgc] +workspace = true +features = ["indirect-validation"] + # Enable `wgc` by default on macOS and iOS to allow the `metal` crate feature to # enable the Metal backend while being no-op on other targets. [target.'cfg(any(target_os = "macos", target_os = "ios"))'.dependencies.wgc] diff --git a/wgpu/src/backend/wgpu_core.rs b/wgpu/src/backend/wgpu_core.rs index e00bd4a3848..a230d06937e 100644 --- a/wgpu/src/backend/wgpu_core.rs +++ b/wgpu/src/backend/wgpu_core.rs @@ -892,6 +892,7 @@ impl crate::Context for ContextWgpuCore { let descriptor = wgc::pipeline::ShaderModuleDescriptor { label: desc.label.map(Borrowed), shader_bound_checks, + ignore_push_constant_check: false, }; let source = match desc.source { #[cfg(feature = "spirv")] @@ -950,6 +951,7 @@ impl crate::Context for ContextWgpuCore { // Doesn't matter the value since spirv shaders aren't mutated to include // runtime checks shader_bound_checks: unsafe { wgt::ShaderBoundChecks::unchecked() }, + ignore_push_constant_check: false, }; let (id, error) = wgc::gfx_select!( device => self.0.device_create_shader_module_spirv(*device, &descriptor, Borrowed(&desc.source), None) @@ -1033,6 +1035,7 @@ impl crate::Context for ContextWgpuCore { buffer_id: binding.buffer.id.into(), offset: binding.offset, size: binding.size, + allow_indirect_as_storage: false, })); } } @@ -1053,6 +1056,7 @@ impl crate::Context for ContextWgpuCore { buffer_id: buffer.id.into(), offset, size, + allow_indirect_as_storage: false, }), BindingResource::BufferArray(array) => { let slice = &remaining_arrayed_buffer_bindings[..array.len()]; @@ -1132,6 +1136,7 @@ impl crate::Context for ContextWgpuCore { label: desc.label.map(Borrowed), bind_group_layouts: Borrowed(&temp_layouts), push_constant_ranges: Borrowed(desc.push_constant_ranges), + ignore_push_constant_check: false, }; let (id, error) = wgc::gfx_select!(device => self.0.device_create_pipeline_layout(