Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement submission indexes #2700

Merged
merged 5 commits into from
Jun 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deno_webgpu/src/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ pub async fn op_webgpu_buffer_get_map_async(
{
let state = state.borrow();
let instance = state.borrow::<super::Instance>();
gfx_select!(device => instance.device_poll(device, false)).unwrap();
gfx_select!(device => instance.device_poll(device, wgpu_types::Maintain::Wait)).unwrap();
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
Expand Down
4 changes: 2 additions & 2 deletions player/src/bin/play.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ fn main() {
}

gfx_select!(device => global.device_stop_capture(device));
gfx_select!(device => global.device_poll(device, true)).unwrap();
gfx_select!(device => global.device_poll(device, wgt::Maintain::Wait)).unwrap();
}
#[cfg(feature = "winit")]
{
Expand Down Expand Up @@ -181,7 +181,7 @@ fn main() {
},
Event::LoopDestroyed => {
log::info!("Closing");
gfx_select!(device => global.device_poll(device, true)).unwrap();
gfx_select!(device => global.device_poll(device, wgt::Maintain::Wait)).unwrap();
}
_ => {}
}
Expand Down
2 changes: 1 addition & 1 deletion player/tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ impl Test<'_> {
}

println!("\t\t\tWaiting...");
wgc::gfx_select!(device => global.device_poll(device, true)).unwrap();
wgc::gfx_select!(device => global.device_poll(device, wgt::Maintain::Wait)).unwrap();

for expect in self.expectations {
println!("\t\t\tChecking {}", expect.name);
Expand Down
2 changes: 2 additions & 0 deletions wgpu-core/src/device/life.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,8 @@ struct ActiveSubmission<A: hal::Api> {
pub enum WaitIdleError {
#[error(transparent)]
Device(#[from] DeviceError),
#[error("Tried to wait using a submission index from the wrong device. Submission index is from device {0:?}. Called poll on device {1:?}.")]
WrongSubmissionIndex(id::QueueId, id::DeviceId),
#[error("GPU got stuck :(")]
StuckGpu,
}
Expand Down
40 changes: 32 additions & 8 deletions wgpu-core/src/device/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,9 @@ impl<A: HalApi> Device<A> {

/// Check this device for completed commands.
///
/// The `maintain` argument tells how the maintence function should behave, either
/// blocking or just polling the current state of the gpu.
///
/// Return a pair `(closures, queue_empty)`, where:
///
/// - `closures` is a list of actions to take: mapping buffers, notifying the user
Expand All @@ -439,7 +442,7 @@ impl<A: HalApi> Device<A> {
fn maintain<'this, 'token: 'this, G: GlobalIdentityHandlerFactory>(
&'this self,
hub: &Hub<A, G>,
force_wait: bool,
maintain: wgt::Maintain<queue::WrappedSubmissionIndex>,
token: &mut Token<'token, Self>,
) -> Result<(UserClosures, bool), WaitIdleError> {
profiling::scope!("maintain", "Device");
Expand All @@ -463,14 +466,21 @@ impl<A: HalApi> Device<A> {
);
life_tracker.triage_mapped(hub, token);

let last_done_index = if force_wait {
let current_index = self.active_submission_index;
let last_done_index = if maintain.is_wait() {
let index_to_wait_for = match maintain {
wgt::Maintain::WaitForSubmissionIndex(submission_index) => {
// We don't need to check to see if the queue id matches
// as we already checked this from inside the poll call.
submission_index.index
}
_ => self.active_submission_index,
};
unsafe {
self.raw
.wait(&self.fence, current_index, CLEANUP_WAIT_MS)
.wait(&self.fence, index_to_wait_for, CLEANUP_WAIT_MS)
.map_err(DeviceError::from)?
};
current_index
index_to_wait_for
} else {
unsafe {
self.raw
Expand Down Expand Up @@ -4957,16 +4967,25 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
pub fn device_poll<A: HalApi>(
&self,
device_id: id::DeviceId,
force_wait: bool,
maintain: wgt::Maintain<queue::WrappedSubmissionIndex>,
) -> Result<bool, WaitIdleError> {
let (closures, queue_empty) = {
if let wgt::Maintain::WaitForSubmissionIndex(submission_index) = maintain {
if submission_index.queue_id != device_id {
return Err(WaitIdleError::WrongSubmissionIndex(
submission_index.queue_id,
device_id,
));
}
}

let hub = A::hub(self);
let mut token = Token::root();
let (device_guard, mut token) = hub.devices.read(&mut token);
device_guard
.get(device_id)
.map_err(|_| DeviceError::Invalid)?
.maintain(hub, force_wait, &mut token)?
.maintain(hub, maintain, &mut token)?
};

closures.fire();
Expand Down Expand Up @@ -4994,7 +5013,12 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
let (device_guard, mut token) = hub.devices.read(&mut token);

for (id, device) in device_guard.iter(A::VARIANT) {
let (cbs, queue_empty) = device.maintain(hub, force_wait, &mut token)?;
let maintain = if force_wait {
wgt::Maintain::Wait
} else {
wgt::Maintain::Poll
};
let (cbs, queue_empty) = device.maintain(hub, maintain, &mut token)?;
all_queue_empty = all_queue_empty && queue_empty;

// If the device's own `RefCount` clone is the only one left, and
Expand Down
23 changes: 17 additions & 6 deletions wgpu-core/src/device/queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::{
id,
init_tracker::{has_copy_partial_init_tracker_coverage, TextureInitRange},
resource::{BufferAccessError, BufferMapState, TextureInner},
track, FastHashSet,
track, FastHashSet, SubmissionIndex,
};

use hal::{CommandEncoder as _, Device as _, Queue as _};
Expand Down Expand Up @@ -79,6 +79,13 @@ impl SubmittedWorkDoneClosure {
}
}

#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct WrappedSubmissionIndex {
pub queue_id: id::QueueId,
pub index: SubmissionIndex,
}

struct StagingData<A: hal::Api> {
buffer: A::Buffer,
}
Expand Down Expand Up @@ -620,10 +627,10 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
&self,
queue_id: id::QueueId,
command_buffer_ids: &[id::CommandBufferId],
) -> Result<(), QueueSubmitError> {
) -> Result<WrappedSubmissionIndex, QueueSubmitError> {
profiling::scope!("submit", "Queue");

let callbacks = {
let (submit_index, callbacks) = {
let hub = A::hub(self);
let mut token = Token::root();

Expand Down Expand Up @@ -958,23 +965,27 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {

// This will schedule destruction of all resources that are no longer needed
// by the user but used in the command stream, among other things.
let (closures, _) = match device.maintain(hub, false, &mut token) {
let (closures, _) = match device.maintain(hub, wgt::Maintain::Wait, &mut token) {
Ok(closures) => closures,
Err(WaitIdleError::Device(err)) => return Err(QueueSubmitError::Queue(err)),
Err(WaitIdleError::StuckGpu) => return Err(QueueSubmitError::StuckGpu),
Err(WaitIdleError::WrongSubmissionIndex(..)) => unreachable!(),
};

device.pending_writes.temp_resources = pending_write_resources;
device.temp_suspected.clear();
device.lock_life(&mut token).post_submit();

closures
(submit_index, closures)
};

// the closures should execute with nothing locked!
callbacks.fire();

Ok(())
Ok(WrappedSubmissionIndex {
queue_id,
index: submit_index,
})
}

pub fn queue_get_timestamp_period<A: HalApi>(
Expand Down
1 change: 1 addition & 0 deletions wgpu-hal/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ pub trait Device<A: Api>: Send + Sync {
unsafe fn create_fence(&self) -> Result<A::Fence, DeviceError>;
unsafe fn destroy_fence(&self, fence: A::Fence);
unsafe fn get_fence_value(&self, fence: &A::Fence) -> Result<FenceValue, DeviceError>;
/// Calling wait with a lower value than the current fence value will immediately return.
unsafe fn wait(
&self,
fence: &A::Fence,
Expand Down
37 changes: 37 additions & 0 deletions wgpu-types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2287,6 +2287,43 @@ impl Default for ColorWrites {
}
}

/// Passed to `Device::poll` to control how and if it should block.
#[derive(Clone)]
pub enum Maintain<T> {
/// On native backends, block until the given submission has
/// completed execution, and any callbacks have been invoked.
///
/// On the web, this has no effect. Callbacks are invoked from the
/// window event loop.
WaitForSubmissionIndex(T),
/// Same as WaitForSubmissionIndex but waits for the most recent submission.
Wait,
/// Check the device for a single time without blocking.
Poll,
}

impl<T> Maintain<T> {
/// This maintain represents a wait of some kind.
pub fn is_wait(&self) -> bool {
match *self {
Self::WaitForSubmissionIndex(..) | Self::Wait => true,
Self::Poll => false,
}
}

/// Map on the wait index type.
pub fn map_index<U, F>(self, func: F) -> Maintain<U>
where
F: FnOnce(T) -> U,
{
match self {
Self::WaitForSubmissionIndex(i) => Maintain::WaitForSubmissionIndex(func(i)),
Self::Wait => Maintain::Wait,
Self::Poll => Maintain::Poll,
}
}
}

/// State of the stencil operation (fixed-pipeline stage).
///
/// For use in [`DepthStencilState`].
Expand Down
25 changes: 18 additions & 7 deletions wgpu/examples/capture/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::env;
use std::fs::File;
use std::io::Write;
use std::mem::size_of;
use wgpu::{Buffer, Device};
use wgpu::{Buffer, Device, SubmissionIndex};

async fn run(png_output_path: &str) {
let args: Vec<_> = env::args().collect();
Expand All @@ -20,14 +20,22 @@ async fn run(png_output_path: &str) {
return;
}
};
let (device, buffer, buffer_dimensions) = create_red_image_with_dimensions(width, height).await;
create_png(png_output_path, device, buffer, &buffer_dimensions).await;
let (device, buffer, buffer_dimensions, submission_index) =
create_red_image_with_dimensions(width, height).await;
create_png(
png_output_path,
device,
buffer,
&buffer_dimensions,
submission_index,
)
.await;
}

async fn create_red_image_with_dimensions(
width: usize,
height: usize,
) -> (Device, Buffer, BufferDimensions) {
) -> (Device, Buffer, BufferDimensions, SubmissionIndex) {
let adapter = wgpu::Instance::new(
wgpu::util::backend_bits_from_env().unwrap_or_else(wgpu::Backends::all),
)
Expand Down Expand Up @@ -114,15 +122,16 @@ async fn create_red_image_with_dimensions(
encoder.finish()
};

queue.submit(Some(command_buffer));
(device, output_buffer, buffer_dimensions)
let index = queue.submit(Some(command_buffer));
(device, output_buffer, buffer_dimensions, index)
}

async fn create_png(
png_output_path: &str,
device: Device,
output_buffer: Buffer,
buffer_dimensions: &BufferDimensions,
submission_index: SubmissionIndex,
) {
// Note that we're not calling `.await` here.
let buffer_slice = output_buffer.slice(..);
Expand All @@ -133,7 +142,9 @@ async fn create_png(
// Poll the device in a blocking manner so that our future resolves.
// In an actual application, `device.poll(...)` should
// be called in an event loop or on another thread.
device.poll(wgpu::Maintain::Wait);
//
// We pass our submission index so we don't need to wait for any other possible submissions.
device.poll(wgpu::Maintain::WaitForSubmissionIndex(submission_index));
// If a file system is available, write the buffer as a PNG
let has_file_system_available = cfg!(not(target_arch = "wasm32"));
if !has_file_system_available {
Expand Down
14 changes: 7 additions & 7 deletions wgpu/src/backend/direct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,7 @@ impl crate::Context for Context {
type SurfaceId = Surface;

type SurfaceOutputDetail = SurfaceOutputDetail;
type SubmissionIndex = wgc::device::queue::WrappedSubmissionIndex;

type RequestAdapterFuture = Ready<Option<Self::AdapterId>>;
#[allow(clippy::type_complexity)]
Expand Down Expand Up @@ -1569,7 +1570,8 @@ impl crate::Context for Context {

#[cfg(any(not(target_arch = "wasm32"), feature = "emscripten"))]
{
match wgc::gfx_select!(device.id => global.device_poll(device.id, true)) {
match wgc::gfx_select!(device.id => global.device_poll(device.id, wgt::Maintain::Wait))
{
Ok(_) => (),
Err(err) => self.handle_error_fatal(err, "Device::drop"),
}
Expand All @@ -1580,12 +1582,10 @@ impl crate::Context for Context {

fn device_poll(&self, device: &Self::DeviceId, maintain: crate::Maintain) -> bool {
let global = &self.0;
let maintain_inner = maintain.map_index(|i| i.0);
match wgc::gfx_select!(device.id => global.device_poll(
device.id,
match maintain {
crate::Maintain::Poll => false,
crate::Maintain::Wait => true,
}
maintain_inner
)) {
Ok(queue_empty) => queue_empty,
Err(err) => self.handle_error_fatal(err, "Device::poll"),
Expand Down Expand Up @@ -2179,12 +2179,12 @@ impl crate::Context for Context {
&self,
queue: &Self::QueueId,
command_buffers: I,
) {
) -> Self::SubmissionIndex {
let temp_command_buffers = command_buffers.collect::<SmallVec<[_; 4]>>();

let global = &self.0;
match wgc::gfx_select!(*queue => global.queue_submit(*queue, &temp_command_buffers)) {
Ok(()) => (),
Ok(index) => index,
Err(err) => self.handle_error_fatal(err, "Queue::submit"),
}
}
Expand Down
5 changes: 4 additions & 1 deletion wgpu/src/backend/web.rs
Original file line number Diff line number Diff line change
Expand Up @@ -995,6 +995,7 @@ impl crate::Context for Context {
type SurfaceId = Sendable<web_sys::GpuCanvasContext>;

type SurfaceOutputDetail = SurfaceOutputDetail;
type SubmissionIndex = ();

type RequestAdapterFuture = MakeSendFuture<
wasm_bindgen_futures::JsFuture,
Expand Down Expand Up @@ -2213,10 +2214,12 @@ impl crate::Context for Context {
&self,
queue: &Self::QueueId,
command_buffers: I,
) {
) -> Self::SubmissionIndex {
let temp_command_buffers = command_buffers.map(|i| i.0).collect::<js_sys::Array>();

queue.0.submit(&temp_command_buffers);

// SubmissionIndex is (), so just let this function end
}

fn queue_get_timestamp_period(&self, _queue: &Self::QueueId) -> f32 {
Expand Down
Loading