Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure safety of indirect dispatch #5714

Merged
merged 2 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ By @bradwerth [#6216](https://github.com/gfx-rs/wgpu/pull/6216).
- Call `flush_mapped_ranges` when unmapping write-mapped buffers. By @teoxoy in [#6089](https://github.com/gfx-rs/wgpu/pull/6089).
- When mapping buffers for reading, mark buffers as initialized only when they have `MAP_WRITE` usage. By @teoxoy in [#6178](https://github.com/gfx-rs/wgpu/pull/6178).
- Add a separate pipeline constants error. By @teoxoy in [#6094](https://github.com/gfx-rs/wgpu/pull/6094).
- Ensure safety of indirect dispatch by injecting a compute shader that validates the content of the indirect buffer. By @teoxoy in [#5714](https://github.com/gfx-rs/wgpu/pull/5714)

#### GLES / OpenGL

Expand Down
241 changes: 241 additions & 0 deletions tests/tests/dispatch_workgroups_indirect.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
use wgpu_test::{gpu_test, FailureCase, GpuTestConfiguration, TestParameters, TestingContext};

/// Make sure that the num_workgroups builtin works properly (it requires a workaround on D3D12).
#[gpu_test]
static NUM_WORKGROUPS_BUILTIN: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(
TestParameters::default()
.features(wgpu::Features::PUSH_CONSTANTS)
.downlevel_flags(
wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION,
)
.limits(wgpu::Limits {
max_push_constant_size: 4,
..wgpu::Limits::downlevel_defaults()
})
.expect_fail(FailureCase::backend(wgt::Backends::DX12)),
)
.run_async(|ctx| async move {
let num_workgroups = [1, 2, 3];
let res = run_test(&ctx, &num_workgroups, false).await;
assert_eq!(res, num_workgroups);
});

/// Make sure that we discard (don't run) the dispatch if its size exceeds the device limit.
#[gpu_test]
static DISCARD_DISPATCH: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(
TestParameters::default()
.features(wgpu::Features::PUSH_CONSTANTS)
.downlevel_flags(
wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION,
)
.limits(wgpu::Limits {
max_compute_workgroups_per_dimension: 10,
max_push_constant_size: 4,
..wgpu::Limits::downlevel_defaults()
})
.expect_fail(FailureCase::backend(wgt::Backends::DX12)),
)
.run_async(|ctx| async move {
let max = ctx.device.limits().max_compute_workgroups_per_dimension;

let res = run_test(&ctx, &[max, max, max], false).await;
assert_eq!(res, [max; 3]);

let res = run_test(&ctx, &[max + 1, 1, 1], false).await;
assert_eq!(res, [0; 3]);
jimblandy marked this conversation as resolved.
Show resolved Hide resolved

let res = run_test(&ctx, &[1, max + 1, 1], false).await;
assert_eq!(res, [0; 3]);

let res = run_test(&ctx, &[1, 1, max + 1], false).await;
assert_eq!(res, [0; 3]);
});

/// Make sure that resetting the bind groups set by the validation code works properly.
#[gpu_test]
static RESET_BIND_GROUPS: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(
TestParameters::default()
.features(wgpu::Features::PUSH_CONSTANTS)
.downlevel_flags(
wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION,
)
.limits(wgpu::Limits {
max_push_constant_size: 4,
..wgpu::Limits::downlevel_defaults()
}),
)
.run_async(|ctx| async move {
ctx.device.push_error_scope(wgpu::ErrorFilter::Validation);

let _ = run_test(&ctx, &[0, 0, 0], true).await;

let error = pollster::block_on(ctx.device.pop_error_scope());
ErichDonGubler marked this conversation as resolved.
Show resolved Hide resolved
assert!(error.map_or(false, |error| {
format!("{error}").contains("The current set ComputePipeline with '' label expects a BindGroup to be set at index 0")
}));
});

async fn run_test(
ctx: &TestingContext,
num_workgroups: &[u32; 3],
forget_to_set_bind_group: bool,
) -> [u32; 3] {
const SHADER_SRC: &str = "
struct TestOffsetPc {
inner: u32,
}
// `test_offset.inner` should always be 0; we test that resetting the push constant set by the validation code works properly.
var<push_constant> test_offset: TestOffsetPc;
@group(0) @binding(0)
var<storage, read_write> out: array<u32, 3>;
@compute @workgroup_size(1)
fn main(@builtin(num_workgroups) num_workgroups: vec3u, @builtin(workgroup_id) workgroup_id: vec3u) {
if (all(workgroup_id == vec3u())) {
ErichDonGubler marked this conversation as resolved.
Show resolved Hide resolved
out[0] = num_workgroups.x + test_offset.inner;
out[1] = num_workgroups.y + test_offset.inner;
out[2] = num_workgroups.z + test_offset.inner;
}
}
";

let module = ctx
.device
.create_shader_module(wgpu::ShaderModuleDescriptor {
label: None,
source: wgpu::ShaderSource::Wgsl(SHADER_SRC.into()),
});

let bgl = ctx
.device
.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: None,
entries: &[wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgt::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
}],
});

let layout = ctx
.device
.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: None,
bind_group_layouts: &[&bgl],
push_constant_ranges: &[wgt::PushConstantRange {
stages: wgt::ShaderStages::COMPUTE,
range: 0..4,
}],
});

let pipeline = ctx
.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: None,
layout: Some(&layout),
module: &module,
entry_point: Some("main"),
compilation_options: Default::default(),
cache: None,
});

let out_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: 12,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});

let readback_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: 12,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
});

let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: None,
layout: &pipeline.get_bind_group_layout(0),
entries: &[wgpu::BindGroupEntry {
binding: 0,
resource: out_buffer.as_entire_binding(),
}],
});

let mut res = None;

for (indirect_offset, indirect_buffer_size) in [
// internal src buffer binding size will be buffer.size
(0, 12),
(4, 4 + 12),
(4, 8 + 12),
(256 * 2 - 4 - 12, 256 * 2 - 4),
// internal src buffer binding size will be 256 * 2 + x
(0, 256 * 2 * 2 + 4),
(256, 256 * 2 * 2 + 8),
(256 + 4, 256 * 2 * 2 + 12),
(256 * 2 + 16, 256 * 2 * 2 + 16),
(256 * 2 * 2, 256 * 2 * 2 + 32),
(256 + 12, 256 * 2 * 2 + 64),
] {
let indirect_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: indirect_buffer_size,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::INDIRECT,
mapped_at_creation: false,
});

ctx.queue.write_buffer(
&indirect_buffer,
indirect_offset,
bytemuck::bytes_of(num_workgroups),
);

let mut encoder = ctx.device.create_command_encoder(&Default::default());
{
let mut compute_pass = encoder.begin_compute_pass(&Default::default());
compute_pass.set_pipeline(&pipeline);
compute_pass.set_push_constants(0, &[0, 0, 0, 0]);
if !forget_to_set_bind_group {
compute_pass.set_bind_group(0, Some(&bind_group), &[]);
}
compute_pass.dispatch_workgroups_indirect(&indirect_buffer, indirect_offset);
}

encoder.copy_buffer_to_buffer(&out_buffer, 0, &readback_buffer, 0, 12);

ctx.queue.submit(Some(encoder.finish()));

readback_buffer
.slice(..)
.map_async(wgpu::MapMode::Read, |_| {});

ctx.async_poll(wgpu::Maintain::wait())
.await
.panic_on_timeout();

let view = readback_buffer.slice(..).get_mapped_range();

let current_res = *bytemuck::from_bytes(&view);
drop(view);
readback_buffer.unmap();

if let Some(past_res) = res {
assert_eq!(past_res, current_res);
} else {
res = Some(current_res);
}
}

res.unwrap()
}
1 change: 1 addition & 0 deletions tests/tests/root.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ mod clear_texture;
mod compute_pass_ownership;
mod create_surface_error;
mod device;
mod dispatch_workgroups_indirect;
mod encoder;
mod external_texture;
mod float32_filterable;
Expand Down
4 changes: 4 additions & 0 deletions wgpu-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ renderdoc = ["hal/renderdoc"]
## to the validation carried out at public APIs in all builds.
strict_asserts = ["wgt/strict_asserts"]

## Validates indirect draw/dispatch calls. This will also enable naga's
## WGSL frontend since we use a WGSL compute shader to do the validation.
indirect-validation = ["naga/wgsl-in"]

## Enables serialization via `serde` on common wgpu types.
serde = ["dep:serde", "wgt/serde", "arrayvec/serde"]

Expand Down
20 changes: 16 additions & 4 deletions wgpu-core/src/command/bind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,13 +200,17 @@ mod compat {
entries: (0..hal::MAX_BIND_GROUPS).map(|_| Entry::empty()).collect(),
}
}
fn make_range(&self, start_index: usize) -> Range<usize> {

pub fn num_valid_entries(&self) -> usize {
// find first incompatible entry
let end = self
.entries
self.entries
.iter()
.position(|e| e.is_incompatible())
.unwrap_or(self.entries.len());
.unwrap_or(self.entries.len())
}

fn make_range(&self, start_index: usize) -> Range<usize> {
let end = self.num_valid_entries();
start_index..end.max(start_index)
}

Expand Down Expand Up @@ -406,6 +410,14 @@ impl Binder {
.map(move |index| payloads[index].group.as_ref().unwrap())
}

#[cfg(feature = "indirect-validation")]
pub(super) fn list_valid<'a>(&'a self) -> impl Iterator<Item = (usize, &'a EntryPayload)> + '_ {
self.payloads
.iter()
.take(self.manager.num_valid_entries())
.enumerate()
}

pub(super) fn check_compatibility<T: Labeled>(
&self,
pipeline: &T,
Expand Down
Loading