diff --git a/reference/opt/shaders-msl/comp/shared-zero-init-simple.comp b/reference/opt/shaders-msl/comp/shared-zero-init-simple.comp new file mode 100644 index 000000000..b193a2036 --- /dev/null +++ b/reference/opt/shaders-msl/comp/shared-zero-init-simple.comp @@ -0,0 +1,31 @@ +#pragma clang diagnostic ignored "-Wsometimes-uninitialized" +#include +#include + +using namespace metal; + +struct SSBO +{ + float in_data[1]; +}; + +struct SSBO2 +{ + float out_data[1]; +}; + +constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(4u, 1u, 1u); + +kernel void main0(const device SSBO& _22 [[buffer(0)]], device SSBO2& _32 [[buffer(1)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]]) +{ + threadgroup float sShared; + { + if (gl_LocalInvocationIndex == 0) + { + sShared = 0.0; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + _32.out_data[gl_GlobalInvocationID.x] = sShared + _22.in_data[gl_GlobalInvocationID.x]; +} + diff --git a/reference/opt/shaders-msl/comp/shared-zero-init.comp b/reference/opt/shaders-msl/comp/shared-zero-init.comp new file mode 100644 index 000000000..41b41d749 --- /dev/null +++ b/reference/opt/shaders-msl/comp/shared-zero-init.comp @@ -0,0 +1,90 @@ +#pragma clang diagnostic ignored "-Wmissing-prototypes" +#pragma clang diagnostic ignored "-Wmissing-braces" + +#include +#include + +using namespace metal; + +template +struct spvUnsafeArray +{ + T elements[Num ? 
Num : 1]; + + thread T& operator [] (size_t pos) thread + { + return elements[pos]; + } + constexpr const thread T& operator [] (size_t pos) const thread + { + return elements[pos]; + } + + device T& operator [] (size_t pos) device + { + return elements[pos]; + } + constexpr const device T& operator [] (size_t pos) const device + { + return elements[pos]; + } + + constexpr const constant T& operator [] (size_t pos) const constant + { + return elements[pos]; + } + + threadgroup T& operator [] (size_t pos) threadgroup + { + return elements[pos]; + } + constexpr const threadgroup T& operator [] (size_t pos) const threadgroup + { + return elements[pos]; + } +}; + +struct SSBO +{ + float in_data[1]; +}; + +struct SSBO2 +{ + float out_data[1]; +}; + +constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(4u, 1u, 1u); + +constant spvUnsafeArray _31 = spvUnsafeArray({ 0.0, 0.0, 0.0, 0.0 }); + +kernel void main0(const device SSBO& _22 [[buffer(0)]], device SSBO2& _48 [[buffer(1)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]]) +{ + threadgroup spvUnsafeArray sShared; + { + threadgroup uint *sShared_ptr = (threadgroup uint *)&sShared; + uint sShared_sz = sizeof(sShared); + uint sShared_pos = gl_LocalInvocationIndex; + uint sShared_stride = gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z; + while (sizeof(uint) * sShared_pos < sShared_sz) + { + sShared_ptr[sShared_pos] = 0u; + sShared_pos += sShared_stride; + } + if (gl_LocalInvocationIndex == 0) + { + sShared_pos = (sShared_sz / sizeof(uint)) * sizeof(uint); + threadgroup uchar *sShared_ptr2 = (threadgroup uchar *)&sShared; + while (sShared_pos < sShared_sz) + { + sShared_ptr2[sShared_pos] = '\0'; + sShared_pos++; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + sShared[gl_LocalInvocationIndex] += _22.in_data[gl_GlobalInvocationID.x]; + threadgroup_barrier(mem_flags::mem_threadgroup); + _48.out_data[gl_GlobalInvocationID.x] 
= sShared[3u - gl_LocalInvocationIndex]; +} + diff --git a/reference/opt/shaders/comp/spec-const-arraydim-init.comp b/reference/opt/shaders/comp/spec-const-arraydim-init.comp new file mode 100644 index 000000000..3a74bf9ee --- /dev/null +++ b/reference/opt/shaders/comp/spec-const-arraydim-init.comp @@ -0,0 +1,27 @@ +#version 450 +#extension GL_EXT_null_initializer : require +layout(local_size_x = 2, local_size_y = 1, local_size_z = 1) in; + +struct Data +{ + float a; + float b; +}; + +#ifndef SPIRV_CROSS_CONSTANT_ID_0 +#define SPIRV_CROSS_CONSTANT_ID_0 2 +#endif +const int arraySize = SPIRV_CROSS_CONSTANT_ID_0; +const Data _25[arraySize] = { }; + +layout(binding = 0, std430) buffer SSBO +{ + Data outdata[]; +} _11; + +void main() +{ + _11.outdata[gl_WorkGroupID.x].a = _25[gl_WorkGroupID.x].a; + _11.outdata[gl_WorkGroupID.x].b = _25[gl_WorkGroupID.x].b; +} + diff --git a/reference/shaders-msl/comp/shared-zero-init-simple.comp b/reference/shaders-msl/comp/shared-zero-init-simple.comp new file mode 100644 index 000000000..f8e21f350 --- /dev/null +++ b/reference/shaders-msl/comp/shared-zero-init-simple.comp @@ -0,0 +1,33 @@ +#pragma clang diagnostic ignored "-Wsometimes-uninitialized" +#include +#include + +using namespace metal; + +struct SSBO +{ + float in_data[1]; +}; + +struct SSBO2 +{ + float out_data[1]; +}; + +constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(4u, 1u, 1u); + +kernel void main0(const device SSBO& _22 [[buffer(0)]], device SSBO2& _32 [[buffer(1)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]]) +{ + threadgroup float sShared; + { + if (gl_LocalInvocationIndex == 0) + { + sShared = 0.0; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + uint ident = gl_GlobalInvocationID.x; + float idata = _22.in_data[ident]; + _32.out_data[ident] = sShared + idata; +} + diff --git a/reference/shaders-msl/comp/shared-zero-init.comp 
b/reference/shaders-msl/comp/shared-zero-init.comp new file mode 100644 index 000000000..2d0e55b96 --- /dev/null +++ b/reference/shaders-msl/comp/shared-zero-init.comp @@ -0,0 +1,92 @@ +#pragma clang diagnostic ignored "-Wmissing-prototypes" +#pragma clang diagnostic ignored "-Wmissing-braces" + +#include +#include + +using namespace metal; + +template +struct spvUnsafeArray +{ + T elements[Num ? Num : 1]; + + thread T& operator [] (size_t pos) thread + { + return elements[pos]; + } + constexpr const thread T& operator [] (size_t pos) const thread + { + return elements[pos]; + } + + device T& operator [] (size_t pos) device + { + return elements[pos]; + } + constexpr const device T& operator [] (size_t pos) const device + { + return elements[pos]; + } + + constexpr const constant T& operator [] (size_t pos) const constant + { + return elements[pos]; + } + + threadgroup T& operator [] (size_t pos) threadgroup + { + return elements[pos]; + } + constexpr const threadgroup T& operator [] (size_t pos) const threadgroup + { + return elements[pos]; + } +}; + +struct SSBO +{ + float in_data[1]; +}; + +struct SSBO2 +{ + float out_data[1]; +}; + +constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(4u, 1u, 1u); + +constant spvUnsafeArray _31 = spvUnsafeArray({ 0.0, 0.0, 0.0, 0.0 }); + +kernel void main0(const device SSBO& _22 [[buffer(0)]], device SSBO2& _48 [[buffer(1)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]]) +{ + threadgroup spvUnsafeArray sShared; + { + threadgroup uint *sShared_ptr = (threadgroup uint *)&sShared; + uint sShared_sz = sizeof(sShared); + uint sShared_pos = gl_LocalInvocationIndex; + uint sShared_stride = gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z; + while (sizeof(uint) * sShared_pos < sShared_sz) + { + sShared_ptr[sShared_pos] = 0u; + sShared_pos += sShared_stride; + } + if (gl_LocalInvocationIndex == 0) + { + sShared_pos = (sShared_sz / sizeof(uint)) * 
sizeof(uint); + threadgroup uchar *sShared_ptr2 = (threadgroup uchar *)&sShared; + while (sShared_pos < sShared_sz) + { + sShared_ptr2[sShared_pos] = '\0'; + sShared_pos++; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + uint ident = gl_GlobalInvocationID.x; + float idata = _22.in_data[ident]; + sShared[gl_LocalInvocationIndex] += idata; + threadgroup_barrier(mem_flags::mem_threadgroup); + _48.out_data[ident] = sShared[(4u - gl_LocalInvocationIndex) - 1u]; +} + diff --git a/reference/shaders/comp/spec-const-arraydim-init.comp b/reference/shaders/comp/spec-const-arraydim-init.comp new file mode 100644 index 000000000..3a74bf9ee --- /dev/null +++ b/reference/shaders/comp/spec-const-arraydim-init.comp @@ -0,0 +1,27 @@ +#version 450 +#extension GL_EXT_null_initializer : require +layout(local_size_x = 2, local_size_y = 1, local_size_z = 1) in; + +struct Data +{ + float a; + float b; +}; + +#ifndef SPIRV_CROSS_CONSTANT_ID_0 +#define SPIRV_CROSS_CONSTANT_ID_0 2 +#endif +const int arraySize = SPIRV_CROSS_CONSTANT_ID_0; +const Data _25[arraySize] = { }; + +layout(binding = 0, std430) buffer SSBO +{ + Data outdata[]; +} _11; + +void main() +{ + _11.outdata[gl_WorkGroupID.x].a = _25[gl_WorkGroupID.x].a; + _11.outdata[gl_WorkGroupID.x].b = _25[gl_WorkGroupID.x].b; +} + diff --git a/shaders-msl/comp/shared-zero-init-simple.comp b/shaders-msl/comp/shared-zero-init-simple.comp new file mode 100644 index 000000000..fe9bac5ad --- /dev/null +++ b/shaders-msl/comp/shared-zero-init-simple.comp @@ -0,0 +1,24 @@ +#version 450 +#extension GL_EXT_null_initializer : enable +layout(local_size_x = 4) in; + +shared float sShared = {}; + +layout(std430, binding = 0) readonly buffer SSBO +{ + float in_data[]; +}; + +layout(std430, binding = 1) writeonly buffer SSBO2 +{ + float out_data[]; +}; + +void main() +{ + uint ident = gl_GlobalInvocationID.x; + float idata = in_data[ident]; + + out_data[ident] = sShared + idata; +} + diff --git a/shaders-msl/comp/shared-zero-init.comp 
b/shaders-msl/comp/shared-zero-init.comp new file mode 100644 index 000000000..f30522c77 --- /dev/null +++ b/shaders-msl/comp/shared-zero-init.comp @@ -0,0 +1,28 @@ +#version 450 +#extension GL_EXT_null_initializer : enable +layout(local_size_x = 4) in; + +shared float sShared[gl_WorkGroupSize.x] = {}; + +layout(std430, binding = 0) readonly buffer SSBO +{ + float in_data[]; +}; + +layout(std430, binding = 1) writeonly buffer SSBO2 +{ + float out_data[]; +}; + +void main() +{ + uint ident = gl_GlobalInvocationID.x; + float idata = in_data[ident]; + + sShared[gl_LocalInvocationIndex] += idata; + memoryBarrierShared(); + barrier(); + + out_data[ident] = sShared[gl_WorkGroupSize.x - gl_LocalInvocationIndex - 1u]; +} + diff --git a/shaders/comp/spec-const-arraydim-init.comp b/shaders/comp/spec-const-arraydim-init.comp new file mode 100644 index 000000000..0999b12e0 --- /dev/null +++ b/shaders/comp/spec-const-arraydim-init.comp @@ -0,0 +1,22 @@ +#version 450 +#extension GL_EXT_null_initializer : require + +layout(constant_id = 0) const int arraySize = 2; +layout(local_size_x = 2) in; + +struct Data +{ + float a; + float b; +}; + +layout(std430, binding = 0) buffer SSBO +{ + Data outdata[]; +}; + +void main() +{ + Data d[arraySize] = {}; + outdata[gl_WorkGroupID.x] = d[gl_WorkGroupID.x]; +} diff --git a/spirv_common.hpp b/spirv_common.hpp index b70536d9e..a4783269a 100644 --- a/spirv_common.hpp +++ b/spirv_common.hpp @@ -1410,6 +1410,10 @@ struct SPIRConstant : IVariant // If true, this is a LUT, and should always be declared in the outer scope. bool is_used_as_lut = false; + // True if this is a null constant of array type whose length is a specialization constant. + // Such a constant may require special handling when emitted as an initializer. + bool is_null_array_specialized_length = false; + // For composites which are constant arrays, etc. 
SmallVector subconstants; diff --git a/spirv_cross.cpp b/spirv_cross.cpp index 952006991..0f0b12db5 100644 --- a/spirv_cross.cpp +++ b/spirv_cross.cpp @@ -4911,13 +4911,16 @@ void Compiler::make_constant_null(uint32_t id, uint32_t type) uint32_t parent_id = ir.increase_bound_by(1); make_constant_null(parent_id, constant_type.parent_type); - if (!constant_type.array_size_literal.back()) - SPIRV_CROSS_THROW("Array size of OpConstantNull must be a literal."); - - SmallVector elements(constant_type.array.back()); - for (uint32_t i = 0; i < constant_type.array.back(); i++) + // The array size of OpConstantNull can be either a literal or a specialization constant. + // In the latter case, we cannot take the value as-is, as it can be changed to anything. + // Instead, we build a single placeholder element and flag the constant so that initializer emission can special-case it. + bool is_literal_array_size = constant_type.array_size_literal.back(); + uint32_t count = is_literal_array_size ? constant_type.array.back() : 1; + SmallVector elements(count); + for (uint32_t i = 0; i < count; i++) elements[i] = parent_id; - set(id, type, elements.data(), uint32_t(elements.size()), false); + auto &constant = set(id, type, elements.data(), uint32_t(elements.size()), false); + constant.is_null_array_specialized_length = !is_literal_array_size; } else if (!constant_type.member_types.empty()) { diff --git a/spirv_cross_parsed_ir.cpp b/spirv_cross_parsed_ir.cpp index b05afeb3f..397e40f4d 100644 --- a/spirv_cross_parsed_ir.cpp +++ b/spirv_cross_parsed_ir.cpp @@ -1050,16 +1050,21 @@ void ParsedIR::make_constant_null(uint32_t id, uint32_t type, bool add_to_typed_ uint32_t parent_id = increase_bound_by(1); make_constant_null(parent_id, constant_type.parent_type, add_to_typed_id_set); - if (!constant_type.array_size_literal.back()) - SPIRV_CROSS_THROW("Array size of OpConstantNull must be a literal."); - - SmallVector elements(constant_type.array.back()); - for (uint32_t i = 0; i < constant_type.array.back(); i++) + // The array size of OpConstantNull 
can be either literal or specialization constant. + // In the latter case, we cannot take the value as-is, as it can be changed to anything. + // Instead, we build a single placeholder element and flag the constant so that initializer emission can special-case it. + bool is_literal_array_size = constant_type.array_size_literal.back(); + uint32_t count = is_literal_array_size ? constant_type.array.back() : 1; + + SmallVector elements(count); + for (uint32_t i = 0; i < count; i++) elements[i] = parent_id; if (add_to_typed_id_set) add_typed_id(TypeConstant, id); - variant_set(ids[id], type, elements.data(), uint32_t(elements.size()), false).self = id; + auto& constant = variant_set(ids[id], type, elements.data(), uint32_t(elements.size()), false); + constant.self = id; + constant.is_null_array_specialized_length = !is_literal_array_size; } else if (!constant_type.member_types.empty()) { diff --git a/spirv_glsl.cpp b/spirv_glsl.cpp index 18441d992..d358e1020 100644 --- a/spirv_glsl.cpp +++ b/spirv_glsl.cpp @@ -681,6 +681,8 @@ string CompilerGLSL::compile() backend.requires_relaxed_precision_analysis = options.es || options.vulkan_semantics; backend.support_precise_qualifier = (!options.es && options.version >= 400) || (options.es && options.version >= 320); + backend.constant_null_initializer = "{ }"; + backend.requires_matching_array_initializer = true; if (is_legacy_es()) backend.support_case_fallthrough = false; @@ -5902,6 +5904,11 @@ string CompilerGLSL::constant_expression(const SPIRConstant &c, { return backend.null_pointer_literal; } + else if (c.is_null_array_specialized_length && backend.requires_matching_array_initializer) + { + require_extension_internal("GL_EXT_null_initializer"); + return backend.constant_null_initializer; + } else if (!c.subconstants.empty()) { // Handles Arrays and structures. 
@@ -15805,13 +15812,24 @@ string CompilerGLSL::variable_decl(const SPIRVariable &variable) else if (options.force_zero_initialized_variables && type_can_zero_initialize(type)) res += join(" = ", to_zero_initialized_expression(get_variable_data_type_id(variable))); } - else if (variable.initializer && !variable_decl_is_remapped_storage(variable, StorageClassWorkgroup)) + else if (variable.initializer) { - uint32_t expr = variable.initializer; - if (ir.ids[expr].get_type() != TypeUndef) - res += join(" = ", to_initializer_expression(variable)); - else if (options.force_zero_initialized_variables && type_can_zero_initialize(type)) - res += join(" = ", to_zero_initialized_expression(get_variable_data_type_id(variable))); + if (!variable_decl_is_remapped_storage(variable, StorageClassWorkgroup)) + { + uint32_t expr = variable.initializer; + if (ir.ids[expr].get_type() != TypeUndef) + res += join(" = ", to_initializer_expression(variable)); + else if (options.force_zero_initialized_variables && type_can_zero_initialize(type)) + res += join(" = ", to_zero_initialized_expression(get_variable_data_type_id(variable))); + } + else + { + // Workgroup memory requires special handling. First, it can only be null-initialized. + // GLSL handles this with a null initializer, while other backends require more work after the decl. NOTE(review): the extension is requested even when constant_null_initializer is empty; confirm that is intended. + require_extension_internal("GL_EXT_null_initializer"); + if (!backend.constant_null_initializer.empty()) + res += join(" = ", backend.constant_null_initializer); + } } return res; @@ -16572,6 +16590,12 @@ void CompilerGLSL::emit_function(SPIRFunction &func, const Bitset &return_flags) // Comes from MSL which can push global variables as local variables in main function. add_local_variable_name(var.self); statement(variable_decl(var), ";"); + + // "Real" workgroup variables in compute shaders need extra care. + // They need to be initialized with an extra routine as they come in arbitrary form. 
+ if (var.storage == StorageClassWorkgroup && var.initializer) + emit_workgroup_initialization(var); + var.deferred_declaration = false; } else if (var.storage == StorageClassPrivate) @@ -16678,6 +16702,10 @@ void CompilerGLSL::emit_fixup() } } +void CompilerGLSL::emit_workgroup_initialization(const SPIRVariable &) +{ +} + void CompilerGLSL::flush_phi(BlockID from, BlockID to) { auto &child = get(to); diff --git a/spirv_glsl.hpp b/spirv_glsl.hpp index b8c920c00..5ae266eb7 100644 --- a/spirv_glsl.hpp +++ b/spirv_glsl.hpp @@ -453,6 +453,7 @@ class CompilerGLSL : public Compiler virtual std::string variable_decl(const SPIRType &type, const std::string &name, uint32_t id = 0); virtual bool variable_decl_is_remapped_storage(const SPIRVariable &var, spv::StorageClass storage) const; virtual std::string to_func_call_arg(const SPIRFunction::Parameter &arg, uint32_t id); + virtual void emit_workgroup_initialization(const SPIRVariable &var); struct TextureFunctionBaseArguments { @@ -625,6 +626,7 @@ class CompilerGLSL : public Compiler const char *uint16_t_literal_suffix = "us"; const char *nonuniform_qualifier = "nonuniformEXT"; const char *boolean_mix_function = "mix"; + std::string constant_null_initializer = ""; SPIRType::BaseType boolean_in_struct_remapped_type = SPIRType::Boolean; bool swizzle_is_function = false; bool shared_is_implied = false; @@ -632,6 +634,7 @@ class CompilerGLSL : public Compiler bool explicit_struct_type = false; bool use_initializer_list = false; bool use_typed_initializer_list = false; + bool requires_matching_array_initializer = false; bool can_declare_struct_inline = true; bool can_declare_arrays_inline = true; bool native_row_major_matrix = true; diff --git a/spirv_msl.cpp b/spirv_msl.cpp index 21cd3d173..7d07e2d45 100644 --- a/spirv_msl.cpp +++ b/spirv_msl.cpp @@ -272,7 +272,9 @@ void CompilerMSL::build_implicit_builtins() (active_input_builtins.get(BuiltInVertexId) || active_input_builtins.get(BuiltInVertexIndex) || 
active_input_builtins.get(BuiltInBaseVertex) || active_input_builtins.get(BuiltInInstanceId) || active_input_builtins.get(BuiltInInstanceIndex) || active_input_builtins.get(BuiltInBaseInstance)); - bool need_local_invocation_index = (msl_options.emulate_subgroups && active_input_builtins.get(BuiltInSubgroupId)) || is_mesh_shader(); + bool need_local_invocation_index = + (msl_options.emulate_subgroups && active_input_builtins.get(BuiltInSubgroupId)) || is_mesh_shader() || + needs_workgroup_zero_init; bool need_workgroup_size = msl_options.emulate_subgroups && active_input_builtins.get(BuiltInNumSubgroups); bool force_frag_depth_passthrough = get_execution_model() == ExecutionModelFragment && !uses_explicit_early_fragment_test() && need_subpass_input && @@ -1649,6 +1651,7 @@ string CompilerMSL::compile() analyze_image_and_sampler_usage(); analyze_sampled_image_usage(); analyze_interlocked_resource_usage(); + analyze_workgroup_variables(); preprocess_op_codes(); build_implicit_builtins(); @@ -5550,6 +5553,10 @@ void CompilerMSL::emit_header() if (suppress_incompatible_pointer_types_discard_qualifiers) statement("#pragma clang diagnostic ignored \"-Wincompatible-pointer-types-discards-qualifiers\""); + // Disable warning about "sometimes uninitialized" when zero-initializing simple threadgroup variables + if (suppress_sometimes_unitialized) + statement("#pragma clang diagnostic ignored \"-Wsometimes-uninitialized\""); + // Disable warning about missing braces for array template to make arrays a value type if (spv_function_implementations.count(SPVFuncImplUnsafeArray) != 0) statement("#pragma clang diagnostic ignored \"-Wmissing-braces\""); @@ -17620,6 +17627,23 @@ void CompilerMSL::analyze_sampled_image_usage() } } +void CompilerMSL::analyze_workgroup_variables() +{ + ir.for_each_typed_id([&](uint32_t, SPIRVariable &var) { + // If a workgroup variable has an initializer, it can only be ConstantNull (zero init) + if (var.storage == StorageClassWorkgroup && var.initializer) + 
{ + needs_workgroup_zero_init = true; + + // The MSL compiler does not like the routine that initializes simple threadgroup variables, + // falsely claiming the variable is "sometimes uninitialized". Suppress the warning. + auto &type = get_variable_data_type(var); + if (type.array.empty() && type.member_types.empty()) + suppress_sometimes_unitialized = true; + } + }); +} + bool CompilerMSL::SampledImageScanner::handle(spv::Op opcode, const uint32_t *args, uint32_t length) { switch (opcode) @@ -19406,6 +19430,70 @@ void CompilerMSL::emit_mesh_tasks(SPIRBlock &block) statement("return;"); } +void CompilerMSL::emit_workgroup_initialization(const SPIRVariable &var) +{ + auto &type = get_variable_data_type(var); + + begin_scope(); + + if (type.array.empty() && type.member_types.empty()) + { + // For simple shared variables, we just initialize them from the thread whose local invocation index is 0. + // We use short to represent bool for threadgroup variables to work around a compiler bug, + // so we do a temporary fixup here. Alas. (see the type_to_glsl method) + bool is_boolean = type.basetype == SPIRType::Boolean; + if (is_boolean) + type.basetype = SPIRType::Short; + + statement("if (gl_LocalInvocationIndex == 0)"); + begin_scope(); + statement(to_name(var.self), " = ", to_initializer_expression(var), ";"); + end_scope(); + + if (is_boolean) + type.basetype = SPIRType::Boolean; + } + else + { + // Otherwise, we use a loop to cooperatively initialize the memory within the group + + // First, we define a few variable names. + string var_name = to_name(var.self); + string var_ptr_name = join(var_name, "_ptr"); + string var_size_name = join(var_name, "_sz"); + string var_pos_name = join(var_name, "_pos"); + string var_stride_name = join(var_name, "_stride"); + string var_ptr2_name = join(var_name, "_ptr2"); + + statement("threadgroup uint *", var_ptr_name, " = (threadgroup uint *)&", var_name, ";"); + statement("uint ", var_size_name, " = ", "sizeof(", var_name, ");"); + statement("uint ", var_pos_name, " = 
gl_LocalInvocationIndex;"); + statement("uint ", var_stride_name, " = gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z;"); + + statement("while (sizeof(uint) * ", var_pos_name, " < ", var_size_name, ")"); + begin_scope(); + statement(var_ptr_name, "[", var_pos_name, "] = 0u;"); + statement(var_pos_name, " += ", var_stride_name, ";"); + end_scope(); + + statement("if (gl_LocalInvocationIndex == 0)"); + begin_scope(); + statement(var_pos_name, " = (", var_size_name, " / sizeof(uint)) * sizeof(uint);"); + statement("threadgroup uchar *", var_ptr2_name, " = (threadgroup uchar *)&", var_name, ";"); + + statement("while (", var_pos_name, " < ", var_size_name, ")"); + begin_scope(); + statement(var_ptr2_name, "[", var_pos_name, "] = '\\0';"); + statement(var_pos_name, "++;"); + end_scope(); + end_scope(); + } + + statement("threadgroup_barrier(mem_flags::mem_threadgroup);"); + + end_scope(); +} + string CompilerMSL::additional_fixed_sample_mask_str() const { char print_buffer[32]; diff --git a/spirv_msl.hpp b/spirv_msl.hpp index d4f565e68..c1a581f39 100644 --- a/spirv_msl.hpp +++ b/spirv_msl.hpp @@ -881,6 +881,7 @@ class CompilerMSL : public CompilerGLSL void emit_mesh_entry_point(); void emit_mesh_outputs(); void emit_mesh_tasks(SPIRBlock &block) override; + void emit_workgroup_initialization(const SPIRVariable &var) override; // Allow Metal to use the array template to make arrays a value type std::string type_to_array_glsl(const SPIRType &type, uint32_t variable_id) override; @@ -1142,6 +1143,7 @@ class CompilerMSL : public CompilerGLSL void emit_store_statement(uint32_t lhs_expression, uint32_t rhs_expression) override; void analyze_sampled_image_usage(); + void analyze_workgroup_variables(); bool access_chain_needs_stage_io_builtin_translation(uint32_t base) override; bool prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type, spv::StorageClass storage, @@ -1224,6 +1226,7 @@ class CompilerMSL : public CompilerGLSL bool 
needs_subgroup_size = false; bool needs_sample_id = false; bool needs_helper_invocation = false; + bool needs_workgroup_zero_init = false; bool writes_to_depth = false; std::string qual_pos_var_name; std::string stage_in_var_name = "in"; @@ -1286,6 +1289,7 @@ class CompilerMSL : public CompilerGLSL bool suppress_missing_prototypes = false; bool suppress_incompatible_pointer_types_discard_qualifiers = false; + bool suppress_sometimes_unitialized = false; void add_spv_func_and_recompile(SPVFuncImpl spv_func);