Skip to content

Commit

Permalink
Merge pull request #132 from SludgePhD/update-wgpu-0-14
Browse files Browse the repository at this point in the history
Update to wgpu 0.14
  • Loading branch information
haixuanTao authored Nov 15, 2022
2 parents aa92e55 + fc3a606 commit e3106ed
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 83 deletions.
40 changes: 14 additions & 26 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions wonnx-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ protobuf = { version = "2.27.1", features = ["with-bytes"] }
structopt = { version = "0.3.26", features = [ "paw" ] }
thiserror = "1.0.31"
tract-onnx = { version = "0.16.7", optional = true }
wgpu = "0.13.1"
wgpu = "0.14.0"
wonnx = { version = "0.3.0" }
wonnx-preprocessing = { version = "0.3.0" }
human_bytes = "0.3.1"

[dev-dependencies]
assert_cmd = "2.0.4"
assert_cmd = "2.0.4"
4 changes: 2 additions & 2 deletions wonnx-preprocessing/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ protobuf = { version = "2.27.1", features = ["with-bytes"] }
thiserror = "1.0.31"
tokenizers = "0.11.3"
tract-onnx = { version = "^0.17.0", optional = true }
wgpu = "0.13.1"
wgpu = "0.14.0"
wonnx = { version = "0.3.0" }
serde_json = "^1.0"

[dev-dependencies]
env_logger = "0.9.0"
env_logger = "0.9.0"
2 changes: 1 addition & 1 deletion wonnx/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ exclude = [
]

[dependencies]
wgpu = "0.13.1"
wgpu = "0.14.0"
bytemuck = "1.9.1"
protobuf = { version = "2.27.1", features = ["with-bytes"] }
log = "0.4.17"
Expand Down
67 changes: 15 additions & 52 deletions wonnx/src/gpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{
borrow::Cow,
collections::{HashMap, HashSet},
convert::TryInto,
sync::Arc,
sync::{mpsc, Arc},
};

use bytemuck::NoUninit;
Expand Down Expand Up @@ -407,16 +407,7 @@ fn buffer_with_bytes(
raw_data: &[u8],
) -> Result<Buffer, GpuError> {
let buffer_usage = match readable {
true => {
// On wgpu we can MAP_READ a buffer that is also used as STORAGE, but WebGPU (on at least Chrome)
// disallows this. Therefore we need to do an additional copy into a MAP_READ buffer when reading back a
// STORAGE buffer when on WebGPU.
if cfg!(target_arch = "wasm32") {
BufferUsages::STORAGE | BufferUsages::COPY_SRC
} else {
BufferUsages::STORAGE | BufferUsages::MAP_READ
}
}
true => BufferUsages::STORAGE | BufferUsages::COPY_SRC,
false => BufferUsages::STORAGE,
};

Expand Down Expand Up @@ -478,14 +469,7 @@ impl<'model> OperatorDefinition<'model> {
);

let buffer_usage = if outputs_readable {
// On wgpu we can MAP_READ a buffer that is also used as STORAGE, but WebGPU (on at least Chrome)
// disallows this. Therefore we need to do an additional copy into a MAP_READ buffer when reading back a
// STORAGE buffer when on WebGPU.
if cfg!(target_arch = "wasm32") {
BufferUsages::STORAGE | BufferUsages::COPY_SRC
} else {
BufferUsages::STORAGE | BufferUsages::MAP_READ
}
BufferUsages::STORAGE | BufferUsages::COPY_SRC
} else {
BufferUsages::STORAGE
};
Expand Down Expand Up @@ -664,40 +648,19 @@ impl GpuTensor {
let buffer_slice = self.buffer.slice(..);
let shape = self.shape.clone();

// On wgpu we can MAP_READ a buffer that is also used as STORAGE, but WebGPU (on at least Chrome)
// disallows this. Therefore we need to do an additional copy into a MAP_READ buffer when reading back a
// STORAGE buffer when on WebGPU.
#[cfg(target_arch = "wasm32")]
{
let (tx, rx) = futures_intrusive::channel::shared::oneshot_channel();

wgpu::util::DownloadBuffer::read_buffer(device, queue, &buffer_slice, move |buffer| {
// Called on download completed
tx.send(match buffer {
Ok(bytes) => Ok(Self::read_bytes_to_vec(&bytes, shape)),
Err(error) => Err(GpuError::BufferAsyncError(error)),
})
.unwrap();
});
device.poll(wgpu::Maintain::Wait);
// The callback will have been called by now due to poll(Wait)
rx.receive().await.unwrap()
}
let (tx, rx) = mpsc::sync_channel(1);

#[cfg(not(target_arch = "wasm32"))]
{
let output_data = {
let _ = queue; // Need this because otherwise compiler complains we are not using the queue parameter
buffer_slice.map_async(wgpu::MapMode::Read, |_| {});
device.poll(wgpu::Maintain::Wait);
buffer_slice.get_mapped_range()
};

let result = Self::read_bytes_to_vec(&output_data[..], shape);
drop(output_data);
self.buffer.unmap();
Ok(result)
}
wgpu::util::DownloadBuffer::read_buffer(device, queue, &buffer_slice, move |buffer| {
// Called on download completed
tx.send(match buffer {
Ok(bytes) => Ok(Self::read_bytes_to_vec(&bytes, shape)),
Err(error) => Err(GpuError::BufferAsyncError(error)),
})
.unwrap();
});
device.poll(wgpu::Maintain::Wait);
// The callback will have been called by now due to poll(Wait)
rx.recv().unwrap()
}

fn read_bytes_to_vec<A>(output_data: &[A], shape: Shape) -> OutputTensor
Expand Down

0 comments on commit e3106ed

Please sign in to comment.