Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update build configurations to webgpu EP #22047

Open
wants to merge 35 commits into
base: fs-eire/webgpu-ep
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
75322e9
Use Dawn libs directly to minimize binary size.
skottmckay Sep 10, 2024
e8ed35f
Fix Windows build
skottmckay Sep 10, 2024
7ec73ba
Merge remote-tracking branch 'origin/fs-eire/webgpu-ep' into skottmck…
fs-eire Sep 11, 2024
f4cbc76
Update patch with iOS build fixes.
skottmckay Sep 11, 2024
7d9ad9a
Merge branch 'skottmckay/MiscWebGPUUpdates' of https://github.com/mic…
skottmckay Sep 11, 2024
e1e75b8
WGSL writer is only needed when Vulkan is being used
skottmckay Sep 11, 2024
9333307
Merge remote-tracking branch 'origin/fs-eire/webgpu-ep' into skottmck…
skottmckay Sep 13, 2024
afd202a
Fix build errors
skottmckay Sep 13, 2024
bd25d1c
Fix transpose.cc build error.
skottmckay Sep 13, 2024
3f9be82
Go back to WGSL writer being required on all builds
skottmckay Sep 13, 2024
788e129
Refine external libraries to add dependencies
skottmckay Sep 15, 2024
edf5056
Try again
skottmckay Sep 16, 2024
0ad21bf
De-alias existing deps
skottmckay Sep 16, 2024
ee1d958
Fix C++20 errors
skottmckay Sep 16, 2024
3cb0c6a
Fix c++20 errors
skottmckay Sep 16, 2024
7b85dda
Update some apple infra
skottmckay Sep 16, 2024
1da2bce
Enable webgpu in some configs to test via CI
skottmckay Sep 16, 2024
6d14e28
Merge branch 'skottmckay/MiscWebGPUUpdates' of https://github.com/mic…
skottmckay Sep 16, 2024
259aa5d
Update one more CI
skottmckay Sep 16, 2024
64ccd2d
Fix some build and test issues
skottmckay Sep 17, 2024
ce23c21
Expand check on whether std::chrono::operator<< can be used to cover …
skottmckay Sep 18, 2024
51db660
Fix condition. Leave in some pragmas for debugging build failures sho…
skottmckay Sep 18, 2024
b712ebc
Add dummy header
skottmckay Sep 18, 2024
45f3bbf
Update apple uitest apps to run webgpu tests
skottmckay Sep 18, 2024
9b888af
Disable WebGPU in mac-catalyst build. APIs used by Dawn are not avail…
skottmckay Sep 18, 2024
210a760
Fix AppendExecutionProvider call
skottmckay Sep 18, 2024
d366d44
Merge
skottmckay Sep 23, 2024
909ac3d
Merge remote-tracking branch 'origin/fs-eire/webgpu-ep' into skottmck…
skottmckay Sep 23, 2024
edb9980
reduce some diffs
skottmckay Sep 23, 2024
9ea28ac
Merge remote-tracking branch 'origin/fs-eire/webgpu-ep' into skottmck…
fs-eire Sep 24, 2024
698e6ae
Enable in Android build for automated testing.
skottmckay Sep 24, 2024
7d4946d
Merge branch 'skottmckay/MiscWebGPUUpdates' of https://github.com/mic…
skottmckay Sep 24, 2024
b9e98a7
Fix typo in #define
skottmckay Sep 24, 2024
4b55f23
Fix some macos warnings.
skottmckay Sep 24, 2024
af7c39e
Fix ATanH on Metal
skottmckay Sep 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 39 additions & 4 deletions cmake/external/onnxruntime_external_deps.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -590,12 +590,47 @@ if (onnxruntime_USE_WEBGPU)
dawn
URL ${DEP_URL_dawn}
URL_HASH SHA1=${DEP_SHA1_dawn}
PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn.patch
)
set(DAWN_FETCH_DEPENDENCIES ON)
set(DAWN_ENABLE_INSTALL ON)
set(TINT_BUILD_TESTS OFF)
set(DAWN_USE_BUILT_DXC ON)

# use dawn::native_objects and dawn::dawn_proc instead of the monolithic dawn::webgpu_dawn to minimize binary size
set(DAWN_BUILD_MONOLITHIC_LIBRARY OFF CACHE BOOL "" FORCE)
set(DAWN_BUILD_SAMPLES OFF CACHE BOOL "" FORCE)
set(DAWN_ENABLE_INSTALL OFF CACHE BOOL "" FORCE)
set(DAWN_ENABLE_NULL OFF CACHE BOOL "" FORCE)
set(DAWN_FETCH_DEPENDENCIES ON CACHE BOOL "" FORCE)

# disable things we don't use
set(DAWN_DXC_ENABLE_ASSERTS_IN_NDEBUG OFF)
set(DAWN_ENABLE_DESKTOP_GL OFF CACHE BOOL "" FORCE)
set(DAWN_ENABLE_OPENGLES OFF CACHE BOOL "" FORCE)
set(DAWN_SUPPORTS_GLFW_FOR_WINDOWING OFF CACHE BOOL "" FORCE)
set(DAWN_USE_GLFW OFF CACHE BOOL "" FORCE)
set(DAWN_USE_WINDOWS_UI OFF CACHE BOOL "" FORCE)
set(DAWN_USE_X11 OFF CACHE BOOL "" FORCE)

set(TINT_BUILD_TESTS OFF CACHE BOOL "" FORCE)
set(TINT_BUILD_CMD_TOOLS OFF CACHE BOOL "" FORCE)
set(TINT_BUILD_GLSL_WRITER OFF CACHE BOOL "" FORCE)
set(TINT_BUILD_GLSL_VALIDATOR OFF CACHE BOOL "" FORCE)
set(TINT_BUILD_IR_BINARY OFF CACHE BOOL "" FORCE)
set(TINT_BUILD_SPV_READER OFF CACHE BOOL "" FORCE) # not needed; disabling it is a large binary size saving
set(TINT_BUILD_WGSL_WRITER ON CACHE BOOL "" FORCE) # needed to create cache key

# SPIR-V validation shouldn't be required given we're using Tint to create the SPIR-V.
if (NOT CMAKE_BUILD_TYPE STREQUAL "Debug")
set(DAWN_ENABLE_SPIRV_VALIDATION OFF CACHE BOOL "" FORCE)
endif()

if (WIN32)
# Building this requires the HLSL writer to be enabled in Tint. TBD whether we need either of these to be ON.
set(DAWN_USE_BUILT_DXC ON CACHE BOOL "" FORCE)
set(TINT_BUILD_HLSL_WRITER ON CACHE BOOL "" FORCE)

# Vulkan may optionally be included in a Windows build. Exclude until we have an explicit use case that requires it.
set(DAWN_ENABLE_VULKAN OFF CACHE BOOL "" FORCE)
endif()

onnxruntime_fetchcontent_makeavailable(dawn)
endif()

Expand Down
12 changes: 3 additions & 9 deletions cmake/onnxruntime_providers_webgpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,8 @@

source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_providers_webgpu_cc_srcs})
onnxruntime_add_static_library(onnxruntime_providers_webgpu ${onnxruntime_providers_webgpu_cc_srcs})
onnxruntime_add_include_to_target(onnxruntime_providers_webgpu onnxruntime_common onnx onnx_proto flatbuffers::flatbuffers Boost::mp11 safeint_interface)
target_link_libraries(onnxruntime_providers_webgpu dawn::webgpu_dawn)

# Copy webgpu_dawn.dll to the output directory
add_custom_command(
TARGET onnxruntime_providers_webgpu
POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_if_different "$<TARGET_FILE:dawn::webgpu_dawn>" "$<TARGET_FILE_DIR:onnxruntime_providers_webgpu>"
VERBATIM )
onnxruntime_add_include_to_target(onnxruntime_providers_webgpu
onnxruntime_common dawn::dawncpp_headers dawn::dawn_headers onnx onnx_proto flatbuffers::flatbuffers Boost::mp11 safeint_interface)
target_link_libraries(onnxruntime_providers_webgpu dawn::dawn_native dawn::dawn_proc)

set_target_properties(onnxruntime_providers_webgpu PROPERTIES FOLDER "ONNXRuntime")
12 changes: 12 additions & 0 deletions cmake/patches/dawn/dawn.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
diff --git a/src/tint/api/BUILD.cmake b/src/tint/api/BUILD.cmake
index 0037d83276..6372c4ee77 100644
--- a/src/tint/api/BUILD.cmake
+++ b/src/tint/api/BUILD.cmake
@@ -57,6 +57,7 @@ tint_target_add_dependencies(tint_api lib
tint_lang_wgsl_ast_transform
tint_lang_wgsl_common
tint_lang_wgsl_features
+ tint_lang_wgsl_inspector
tint_lang_wgsl_program
tint_lang_wgsl_sem
tint_lang_wgsl_writer_ir_to_program
8 changes: 4 additions & 4 deletions onnxruntime/core/providers/webgpu/buffer_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,10 @@ std::ostream& operator<<(std::ostream& os, BufferCacheMode mode) {

BufferManager::BufferManager(WebGpuContext& context, BufferCacheMode storage_buffer_cache_mode, BufferCacheMode uniform_buffer_cache_mode, BufferCacheMode query_resolve_buffer_cache_mode)
: context_{context},
storage_cache_{std::move(CreateBufferCacheManager(storage_buffer_cache_mode))},
uniform_cache_{std::move(CreateBufferCacheManager(uniform_buffer_cache_mode))},
query_resolve_cache_{std::move(CreateBufferCacheManager(query_resolve_buffer_cache_mode))},
default_cache_{std::move(CreateBufferCacheManager(BufferCacheMode::Disabled))} {
storage_cache_{CreateBufferCacheManager(storage_buffer_cache_mode)},
uniform_cache_{CreateBufferCacheManager(uniform_buffer_cache_mode)},
query_resolve_cache_{CreateBufferCacheManager(query_resolve_buffer_cache_mode)},
default_cache_{CreateBufferCacheManager(BufferCacheMode::Disabled)} {
}

void BufferManager::Upload(void* src, WGPUBuffer dst, size_t size) {
Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/core/providers/webgpu/program_cache_key.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@

namespace {
// append the info of an input or output to the cachekey
void AppendTensorInfo(std::ostringstream& ss, const Tensor& tensor, ProgramVariableDataType var_type, ProgramTensorMetadataDependency dependency, bool& first) {
void AppendTensorInfo(std::ostringstream& ss, const Tensor& tensor, ProgramVariableDataType var_type, ProgramTensorMetadataDependency dependency,

Check warning on line 13 in onnxruntime/core/providers/webgpu/program_cache_key.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/webgpu/program_cache_key.cc:13: Lines should be <= 120 characters long [whitespace/line_length] [2]
bool& first) {
if (first) {
first = false;
} else {
ss << '|';
}

if ((dependency & ProgramTensorMetadataDependency::Type) == ProgramTensorMetadataDependency::Type) {
#ifndef NDEBUG // if debug build
ss << var_type;
Expand All @@ -24,6 +26,7 @@
#endif
ss << ';';
}

if ((dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape) {
ss D("Dims=") << tensor.Shape().ToString();
} else if ((dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank) {
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/webgpu/program_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class ProgramArtifact {
const std::vector<int> shape_uniform_ranks;

ProgramArtifact(ProgramArtifact&&) = default;
ProgramArtifact& operator=(ProgramArtifact&&) = default;
ProgramArtifact& operator=(ProgramArtifact&&) = delete; // can't change const members.
fs-eire marked this conversation as resolved.
Show resolved Hide resolved

private:
ORT_DISALLOW_COPY_AND_ASSIGNMENT(ProgramArtifact);
Expand Down
46 changes: 23 additions & 23 deletions onnxruntime/core/providers/webgpu/shader_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,29 @@ Status ValidateVariableDependency(ProgramTensorMetadataDependency dependency, Sh
}
} // namespace

// Checks a program input against the shader variable declared for it.
// Runs three checks in order and returns the first failing Status:
//   1. the tensor's element type vs. the variable's type,
//   2. the tensor's shape (or the override shape, if requested) vs. the variable's component count,
//   3. the input's metadata dependency vs. the variable's usage.
// Returns Status::OK() when all checks pass.
Status ShaderHelper::ValidateVariable(const ProgramInput& input, const ShaderVariable& var) const {
  const auto& tensor = *input.tensor;

  ORT_RETURN_IF_ERROR(ValidateVariableDataType(tensor.GetElementType(), var.type_));

  // When no override is requested, the tensor's own shape is passed in the override slot as well.
  const auto& original_shape = tensor.Shape();
  const auto& effective_shape = input.use_override_shape ? input.override_shape : original_shape;
  ORT_RETURN_IF_ERROR(
      ValidateVariableShape(original_shape, input.use_override_shape, effective_shape, var.num_components_));

  // NOTE(review): the trailing `true` presumably marks this as an input; confirm against
  // ValidateVariableDependency's signature.
  ORT_RETURN_IF_ERROR(ValidateVariableDependency(input.dependency, var.usage_, true));

  return Status::OK();
}
// Checks a program output against the shader variable declared for it.
// Runs three checks in order and returns the first failing Status:
//   1. the tensor's element type vs. the variable's type,
//   2. the tensor's shape (or the override shape, if requested) vs. the variable's component count,
//   3. the output's metadata dependency vs. the variable's usage.
// Returns Status::OK() when all checks pass.
Status ShaderHelper::ValidateVariable(const ProgramOutput& output, const ShaderVariable& var) const {
  const auto& tensor = *output.tensor;

  ORT_RETURN_IF_ERROR(ValidateVariableDataType(tensor.GetElementType(), var.type_));

  // When no override is requested, the tensor's own shape is passed in the override slot as well.
  const auto& original_shape = tensor.Shape();
  const auto& effective_shape = output.use_override_shape ? output.override_shape : original_shape;
  ORT_RETURN_IF_ERROR(
      ValidateVariableShape(original_shape, output.use_override_shape, effective_shape, var.num_components_));

  // NOTE(review): the trailing `false` presumably marks this as an output; confirm against
  // ValidateVariableDependency's signature.
  ORT_RETURN_IF_ERROR(ValidateVariableDependency(output.dependency, var.usage_, false));

  return Status::OK();
}

#endif // NDEBUG

const ShaderVariable& ShaderHelper::AddVariableImpl(ProgramVariableScope scope,
const std::string& name,
ShaderVariable::Usage usage,
Expand Down Expand Up @@ -224,27 +247,6 @@ const ShaderVariable& ShaderHelper::AddVariableImpl(ProgramVariableScope scope,
return *var;
}

// Checks a program input against the shader variable declared for it.
// Runs three checks in order and returns the first failing Status:
//   1. the tensor's element type vs. the variable's type,
//   2. the tensor's shape (or the override shape, if requested) vs. the variable's component count,
//   3. the input's metadata dependency vs. the variable's usage.
// Returns Status::OK() when all checks pass.
Status ShaderHelper::ValidateVariable(const ProgramInput& input, const ShaderVariable& var) const {
  const auto& tensor = *input.tensor;

  ORT_RETURN_IF_ERROR(ValidateVariableDataType(tensor.GetElementType(), var.type_));

  // When no override is requested, the tensor's own shape is passed in the override slot as well.
  const auto& original_shape = tensor.Shape();
  const auto& effective_shape = input.use_override_shape ? input.override_shape : original_shape;
  ORT_RETURN_IF_ERROR(
      ValidateVariableShape(original_shape, input.use_override_shape, effective_shape, var.num_components_));

  // NOTE(review): the trailing `true` presumably marks this as an input; confirm against
  // ValidateVariableDependency's signature.
  ORT_RETURN_IF_ERROR(ValidateVariableDependency(input.dependency, var.usage_, true));

  return Status::OK();
}
// Checks a program output against the shader variable declared for it.
// Runs three checks in order and returns the first failing Status:
//   1. the tensor's element type vs. the variable's type,
//   2. the tensor's shape (or the override shape, if requested) vs. the variable's component count,
//   3. the output's metadata dependency vs. the variable's usage.
// Returns Status::OK() when all checks pass.
Status ShaderHelper::ValidateVariable(const ProgramOutput& output, const ShaderVariable& var) const {
  const auto& tensor = *output.tensor;

  ORT_RETURN_IF_ERROR(ValidateVariableDataType(tensor.GetElementType(), var.type_));

  // When no override is requested, the tensor's own shape is passed in the override slot as well.
  const auto& original_shape = tensor.Shape();
  const auto& effective_shape = output.use_override_shape ? output.override_shape : original_shape;
  ORT_RETURN_IF_ERROR(
      ValidateVariableShape(original_shape, output.use_override_shape, effective_shape, var.num_components_));

  // NOTE(review): the trailing `false` presumably marks this as an output; confirm against
  // ValidateVariableDependency's signature.
  ORT_RETURN_IF_ERROR(ValidateVariableDependency(output.dependency, var.usage_, false));

  return Status::OK();
}

Status ShaderHelper::ValidateShapeForInputsAndOutputs() const {
const auto& input_vars = vars_[static_cast<int>(ProgramVariableScope::Input)];
const auto& output_vars = vars_[static_cast<int>(ProgramVariableScope::Output)];
Expand Down Expand Up @@ -304,8 +306,6 @@ Status ShaderHelper::ValidateShapeForInputsAndOutputs() const {
return Status::OK();
}

#endif

Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector<int>& shape_uniform_ranks) const {
std::ostringstream ss;
ss.imbue(std::locale::classic());
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
#include <memory>
#include <cmath>

#include "dawn/dawn_proc.h"
#include "dawn/native/DawnNative.h"

#include "core/common/common.h"

#include "core/providers/webgpu/compute_context.h"
Expand Down Expand Up @@ -89,6 +92,8 @@ void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info
std::call_once(init_flag_, [this, &webgpu_ep_info]() {
snnn marked this conversation as resolved.
Show resolved Hide resolved
// Initialization.Step.1 - Create wgpu::Instance
if (instance_ == nullptr) {
dawnProcSetProcs(&dawn::native::GetProcs());

wgpu::InstanceDescriptor instance_desc{};
instance_desc.features.timedWaitAnyEnable = true;
instance_ = wgpu::CreateInstance(&instance_desc);
Expand Down
Loading