Skip to content

Commit

Permalink
Vertex Loader dynamic stride support
Browse files Browse the repository at this point in the history
  • Loading branch information
etang-cw committed Oct 4, 2023
1 parent 0287ae6 commit 7d58bbd
Show file tree
Hide file tree
Showing 8 changed files with 248 additions and 11 deletions.
4 changes: 4 additions & 0 deletions main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,7 @@ struct CLIArguments
uint32_t msl_fixed_subgroup_size = 0;
CompilerMSL::Options::IndexType msl_vertex_index_type = CompilerMSL::Options::IndexType::None;
bool msl_use_pixel_type_loads = false;
bool msl_dynamic_vertex_stride = false;
bool msl_force_sample_rate_shading = false;
bool msl_manual_helper_invocation_updates = true;
bool msl_check_discarded_frag_stores = false;
Expand Down Expand Up @@ -950,6 +951,7 @@ static void print_help_msl()
"\t\tIntended for Vulkan Portability implementations where VK_EXT_subgroup_size_control is not supported or disabled.\n"
"\t\tIf 0, assume variable subgroup size as actually exposed by Metal.\n"
"\t[--msl-use-pixel-type-loads]:\n\t\tEnable use of MSL pixel-type loads (e.g. rgb9e5<float3>).\n"
"\t[--msl-dynamic-vertex-stride]:\n\t\tEnable dynamic strides in shader vertex loader.\n"
"\t[--msl-force-sample-rate-shading]:\n\t\tForce fragment shaders to run per sample.\n"
"\t\tThis adds a [[sample_id]] parameter if none is already present.\n"
"\t[--msl-no-manual-helper-invocation-updates]:\n\t\tDo not manually update the HelperInvocation builtin when a fragment is discarded.\n"
Expand Down Expand Up @@ -1235,6 +1237,7 @@ static string compile_iteration(const CLIArguments &args, std::vector<uint32_t>
msl_opts.emulate_subgroups = args.msl_emulate_subgroups;
msl_opts.fixed_subgroup_size = args.msl_fixed_subgroup_size;
msl_opts.vertex_index_type = args.msl_vertex_index_type;
msl_opts.vertex_loader_dynamic_stride = args.msl_dynamic_vertex_stride;
msl_opts.use_pixel_type_loads = args.msl_use_pixel_type_loads;
msl_opts.force_sample_rate_shading = args.msl_force_sample_rate_shading;
msl_opts.manual_helper_invocation_updates = args.msl_manual_helper_invocation_updates;
Expand Down Expand Up @@ -1980,6 +1983,7 @@ static int main_inner(int argc, char *argv[])
THROW("Bad index type");
});
cbs.add("--msl-use-pixel-type-loads", [&args](CLIParser &) { args.msl_use_pixel_type_loads = true; });
cbs.add("--msl-dynamic-vertex-stride", [&args](CLIParser &) { args.msl_dynamic_vertex_stride = true; });
cbs.add("--msl-force-sample-rate-shading", [&args](CLIParser &) { args.msl_force_sample_rate_shading = true; });
cbs.add("--msl-no-manual-helper-invocation-updates",
[&args](CLIParser &) { args.msl_manual_helper_invocation_updates = false; });
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#include <metal_stdlib>
#include <simd/simd.h>

using namespace metal;

// Unpacks three sign-less small floats from one 32-bit word into a half3.
// Input bits 0-10, 11-21 and 22-31 are each moved so that the field ends at bit 14
// of its own 16-bit lane (bit 15, the half sign bit, is always left clear); the
// three lanes are then reinterpreted bit-for-bit as half values.
static half3 spvLoadVertexRG11B10Half(uint value)
{
    const uint lane0 = (value << 4) & 0x7ff0;  // input bits 0-10  -> lane bits 4-14
    const uint lane1 = (value >> 7) & 0x7ff0;  // input bits 11-21 -> lane bits 4-14
    const uint lane2 = (value >> 17) & 0x7fe0; // input bits 22-31 -> lane bits 5-14
    const ushort3 half_bits = ushort3(lane0, lane1, lane2);
    return as_type<half3>(half_bits);
}
// Decodes a 32-bit shared-exponent value into a float3.
// The top 5 bits select a power-of-two scale that is applied to three 9-bit
// unsigned mantissas stored in bits 0-8, 9-17 and 18-26.
static float3 spvLoadVertexRGB9E5Float(uint value)
{
    const uint shared_exp = value >> 27;
    // NOTE(review): -(15 + 9) presumably folds the exponent bias (15) and the
    // 9 mantissa bits into a single scale factor - confirm against the format spec.
    const float scale = exp2(float(shared_exp)) * exp2(float(-(15 + 9)));
    const uint m0 = value & 0x1ff;
    const uint m1 = extract_bits(value, 9, 9);
    const uint m2 = extract_bits(value, 18, 9);
    return float3(uint3(m0, m1, m2)) * scale;
}
// Return type of the vertex entry point main0; the only output is the position.
struct main0_out
{
float4 gl_Position [[position]];
};

// Decoded vertex inputs consumed by main0.
// In this shader the struct is produced by spvLoadVertex (main0 has no stage_in
// argument). Attribute index 2 is not declared; a7 is the only non-float member.
struct main0_in
{
float4 a0 [[attribute(0)]];
float4 a1 [[attribute(1)]];
float4 a3 [[attribute(3)]];
float4 a4 [[attribute(4)]];
float4 a5 [[attribute(5)]];
float4 a6 [[attribute(6)]];
uint a7 [[attribute(7)]];
float4 a8 [[attribute(8)]];
};

// Raw per-element layout of vertex buffer 0 as read by spvLoadVertex.
// a0 occupies bytes 0-3, spvPad4 is one filler byte, and a1 is three tightly
// packed bytes at bytes 5-7.
struct spvVertexData0
{
uchar4 a0;
uchar spvPad4;
packed_uchar3 a1;
};
// Compile-time check that the struct alignment is exactly 4.
static_assert(alignof(spvVertexData0) == 4, "Unexpected alignment");

// Raw per-element layout of vertex buffer 1 as read by spvLoadVertex.
// Four ushort filler elements (8 bytes) precede a3, which is a packed vector of
// four 16-bit signed integers.
struct spvVertexData1
{
ushort spvPad0[4];
packed_short4 a3;
};
// Compile-time check that the struct alignment is exactly 2.
static_assert(alignof(spvVertexData1) == 2, "Unexpected alignment");

// Raw per-element layout of vertex buffer 2 as read by spvLoadVertex.
// Each member is one packed 32-bit word that spvLoadVertex expands:
// a4 via unpack_unorm10a2_to_float, a5 via spvLoadVertexRGB9E5Float,
// a6 via spvLoadVertexRG11B10Half.
struct spvVertexData2
{
uint a4;
uint a5;
uint a6;
};
// Compile-time check that the struct alignment is exactly 4.
static_assert(alignof(spvVertexData2) == 4, "Unexpected alignment");

// Raw per-element layout of vertex buffer 3 as read by spvLoadVertex.
// The filler members place a8 at byte 1 and a7 at byte 32
// (1 + 1 + 15 * 2 = 32). Note that a8 precedes a7 in memory.
struct spvVertexData3
{
uchar spvPad0;
uchar a8;
ushort spvPad2[15];
uint a7;
};
// Compile-time check that the struct alignment is exactly 4.
static_assert(alignof(spvVertexData3) == 4, "Unexpected alignment");

// Builds the decoded main0_in from one raw element of each of the four vertex buffers.
//   data0 -> a0, a1    data1 -> a3    data2 -> a4, a5, a6    data3 -> a7, a8
// Attributes that carry fewer than four components are completed with 0 for the
// missing middle components and 1 for w.
main0_in spvLoadVertex(const device spvVertexData0& data0, const device spvVertexData1& data1, const device spvVertexData2& data2, const device spvVertexData3& data3)
{
    main0_in result;

    // Buffer 0: a0 is four normalized bytes; a1 is three bytes converted to float
    // without normalization, with the first and third components swapped (.bgr).
    const uint a0_packed = as_type<uint>(data0.a0);
    result.a0 = unpack_unorm4x8_to_float(a0_packed);
    const float3 a1_raw = float3(uchar3(data0.a1));
    result.a1 = float4(a1_raw.bgr, 1);

    // Buffer 1: signed 16-bit values scaled by 1/32767 and clamped below at -1,
    // so that -32768 does not map below -1.
    const float4 a3_scaled = float4(short4(data1.a3)) * (1.f / 32767);
    result.a3 = max(a3_scaled, -1.f);

    // Buffer 2: three packed 32-bit words, each with its own decoder.
    result.a4 = unpack_unorm10a2_to_float(data2.a4);
    const float3 a5_rgb = spvLoadVertexRGB9E5Float(data2.a5);
    result.a5 = float4(a5_rgb, 1);
    const half3 a6_rgb = spvLoadVertexRG11B10Half(data2.a6);
    result.a6 = float4(float3(a6_rgb), 1);

    // Buffer 3: a7 is passed through unchanged; a8 is a single normalized byte.
    result.a7 = data3.a7;
    const float a8_unorm = float(data3.a8) * (1.f / 255);
    result.a8 = float4(a8_unorm, 0, 0, 1);

    return result;
}

// Vertex entry point with manual vertex fetch and runtime strides.
// Each spvVertexBufferN is a raw byte pointer (buffers 0-3); the byte stride of
// binding N is read from spvVertexStrides[N] (buffer 19). gl_BaseVertex is
// declared as an argument but is not referenced in this body.
vertex main0_out main0(device const uchar* spvVertexBuffer0 [[buffer(0)]], device const uchar* spvVertexBuffer1 [[buffer(1)]], device const uchar* spvVertexBuffer2 [[buffer(2)]], device const uchar* spvVertexBuffer3 [[buffer(3)]], uint gl_VertexIndex [[vertex_id]], uint gl_BaseVertex [[base_vertex]], uint gl_InstanceIndex [[instance_id]], uint gl_BaseInstance [[base_instance]], const device uint* spvVertexStrides [[buffer(19)]])
{
main0_out out = {};
// Element address = buffer base + stride * element index, then reinterpreted as
// the matching spvVertexDataN. Element index per buffer:
//   0: gl_InstanceIndex
//   1: gl_VertexIndex
//   2: gl_BaseInstance (the same element for every instance of the draw)
//   3: gl_BaseInstance + (gl_InstanceIndex - gl_BaseInstance) / 4
//      (advances once every 4 instances)
main0_in in = spvLoadVertex(*reinterpret_cast<device const spvVertexData0*>(spvVertexBuffer0 + spvVertexStrides[0] * gl_InstanceIndex), *reinterpret_cast<device const spvVertexData1*>(spvVertexBuffer1 + spvVertexStrides[1] * gl_VertexIndex), *reinterpret_cast<device const spvVertexData2*>(spvVertexBuffer2 + spvVertexStrides[2] * gl_BaseInstance), *reinterpret_cast<device const spvVertexData3*>(spvVertexBuffer3 + spvVertexStrides[3] * (gl_BaseInstance + (gl_InstanceIndex - gl_BaseInstance) / 4)));
// Left-to-right sum of every input; the scalar a7 is converted to float and
// splatted to all four components before being added.
out.gl_Position = ((((((in.a0 + in.a1) + in.a3) + in.a4) + in.a5) + in.a6) + float4(float(in.a7))) + in.a8;
return out;
}

87 changes: 87 additions & 0 deletions reference/shaders-msl/vert/attrs.vertex-loader.dynamic-stride.vert
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#include <metal_stdlib>
#include <simd/simd.h>

using namespace metal;

// Unpacks three sign-less small floats from one 32-bit word into a half3.
// Input bits 0-10, 11-21 and 22-31 are each moved so that the field ends at bit 14
// of its own 16-bit lane (bit 15, the half sign bit, is always left clear); the
// three lanes are then reinterpreted bit-for-bit as half values.
static half3 spvLoadVertexRG11B10Half(uint value)
{
    const uint lane0 = (value << 4) & 0x7ff0;  // input bits 0-10  -> lane bits 4-14
    const uint lane1 = (value >> 7) & 0x7ff0;  // input bits 11-21 -> lane bits 4-14
    const uint lane2 = (value >> 17) & 0x7fe0; // input bits 22-31 -> lane bits 5-14
    const ushort3 half_bits = ushort3(lane0, lane1, lane2);
    return as_type<half3>(half_bits);
}
// Decodes a 32-bit shared-exponent value into a float3.
// The top 5 bits select a power-of-two scale that is applied to three 9-bit
// unsigned mantissas stored in bits 0-8, 9-17 and 18-26.
static float3 spvLoadVertexRGB9E5Float(uint value)
{
    const uint shared_exp = value >> 27;
    // NOTE(review): -(15 + 9) presumably folds the exponent bias (15) and the
    // 9 mantissa bits into a single scale factor - confirm against the format spec.
    const float scale = exp2(float(shared_exp)) * exp2(float(-(15 + 9)));
    const uint m0 = value & 0x1ff;
    const uint m1 = extract_bits(value, 9, 9);
    const uint m2 = extract_bits(value, 18, 9);
    return float3(uint3(m0, m1, m2)) * scale;
}
// Return type of the vertex entry point main0; the only output is the position.
struct main0_out
{
float4 gl_Position [[position]];
};

// Decoded vertex inputs consumed by main0.
// In this shader the struct is produced by spvLoadVertex (main0 has no stage_in
// argument). Attribute index 2 is not declared; a7 is the only non-float member.
struct main0_in
{
float4 a0 [[attribute(0)]];
float4 a1 [[attribute(1)]];
float4 a3 [[attribute(3)]];
float4 a4 [[attribute(4)]];
float4 a5 [[attribute(5)]];
float4 a6 [[attribute(6)]];
uint a7 [[attribute(7)]];
float4 a8 [[attribute(8)]];
};

// Raw per-element layout of vertex buffer 0 as read by spvLoadVertex.
// a0 occupies bytes 0-3, spvPad4 is one filler byte, and a1 is three tightly
// packed bytes at bytes 5-7.
struct spvVertexData0
{
uchar4 a0;
uchar spvPad4;
packed_uchar3 a1;
};
// Compile-time check that the struct alignment is exactly 4.
static_assert(alignof(spvVertexData0) == 4, "Unexpected alignment");

// Raw per-element layout of vertex buffer 1 as read by spvLoadVertex.
// Four ushort filler elements (8 bytes) precede a3, which is a packed vector of
// four 16-bit signed integers.
struct spvVertexData1
{
ushort spvPad0[4];
packed_short4 a3;
};
// Compile-time check that the struct alignment is exactly 2.
static_assert(alignof(spvVertexData1) == 2, "Unexpected alignment");

// Raw per-element layout of vertex buffer 2 as read by spvLoadVertex.
// Each member is one packed 32-bit word that spvLoadVertex expands:
// a4 via unpack_unorm10a2_to_float, a5 via spvLoadVertexRGB9E5Float,
// a6 via spvLoadVertexRG11B10Half.
struct spvVertexData2
{
uint a4;
uint a5;
uint a6;
};
// Compile-time check that the struct alignment is exactly 4.
static_assert(alignof(spvVertexData2) == 4, "Unexpected alignment");

// Raw per-element layout of vertex buffer 3 as read by spvLoadVertex.
// The filler members place a8 at byte 1 and a7 at byte 32
// (1 + 1 + 15 * 2 = 32). Note that a8 precedes a7 in memory.
struct spvVertexData3
{
uchar spvPad0;
uchar a8;
ushort spvPad2[15];
uint a7;
};
// Compile-time check that the struct alignment is exactly 4.
static_assert(alignof(spvVertexData3) == 4, "Unexpected alignment");

// Builds the decoded main0_in from one raw element of each of the four vertex buffers.
//   data0 -> a0, a1    data1 -> a3    data2 -> a4, a5, a6    data3 -> a7, a8
// Attributes that carry fewer than four components are completed with 0 for the
// missing middle components and 1 for w.
main0_in spvLoadVertex(const device spvVertexData0& data0, const device spvVertexData1& data1, const device spvVertexData2& data2, const device spvVertexData3& data3)
{
    main0_in result;

    // Buffer 0: a0 is four normalized bytes; a1 is three bytes converted to float
    // without normalization, with the first and third components swapped (.bgr).
    const uint a0_packed = as_type<uint>(data0.a0);
    result.a0 = unpack_unorm4x8_to_float(a0_packed);
    const float3 a1_raw = float3(uchar3(data0.a1));
    result.a1 = float4(a1_raw.bgr, 1);

    // Buffer 1: signed 16-bit values scaled by 1/32767 and clamped below at -1,
    // so that -32768 does not map below -1.
    const float4 a3_scaled = float4(short4(data1.a3)) * (1.f / 32767);
    result.a3 = max(a3_scaled, -1.f);

    // Buffer 2: three packed 32-bit words, each with its own decoder.
    result.a4 = unpack_unorm10a2_to_float(data2.a4);
    const float3 a5_rgb = spvLoadVertexRGB9E5Float(data2.a5);
    result.a5 = float4(a5_rgb, 1);
    const half3 a6_rgb = spvLoadVertexRG11B10Half(data2.a6);
    result.a6 = float4(float3(a6_rgb), 1);

    // Buffer 3: a7 is passed through unchanged; a8 is a single normalized byte.
    result.a7 = data3.a7;
    const float a8_unorm = float(data3.a8) * (1.f / 255);
    result.a8 = float4(a8_unorm, 0, 0, 1);

    return result;
}

// Vertex entry point with manual vertex fetch and runtime strides.
// Each spvVertexBufferN is a raw byte pointer (buffers 0-3); the byte stride of
// binding N is read from spvVertexStrides[N] (buffer 19). gl_BaseVertex is
// declared as an argument but is not referenced in this body.
vertex main0_out main0(device const uchar* spvVertexBuffer0 [[buffer(0)]], device const uchar* spvVertexBuffer1 [[buffer(1)]], device const uchar* spvVertexBuffer2 [[buffer(2)]], device const uchar* spvVertexBuffer3 [[buffer(3)]], uint gl_VertexIndex [[vertex_id]], uint gl_BaseVertex [[base_vertex]], uint gl_InstanceIndex [[instance_id]], uint gl_BaseInstance [[base_instance]], const device uint* spvVertexStrides [[buffer(19)]])
{
main0_out out = {};
// Element address = buffer base + stride * element index, then reinterpreted as
// the matching spvVertexDataN. Element index per buffer:
//   0: gl_InstanceIndex
//   1: gl_VertexIndex
//   2: gl_BaseInstance (the same element for every instance of the draw)
//   3: gl_BaseInstance + (gl_InstanceIndex - gl_BaseInstance) / 4
//      (advances once every 4 instances)
main0_in in = spvLoadVertex(*reinterpret_cast<device const spvVertexData0*>(spvVertexBuffer0 + spvVertexStrides[0] * gl_InstanceIndex), *reinterpret_cast<device const spvVertexData1*>(spvVertexBuffer1 + spvVertexStrides[1] * gl_VertexIndex), *reinterpret_cast<device const spvVertexData2*>(spvVertexBuffer2 + spvVertexStrides[2] * gl_BaseInstance), *reinterpret_cast<device const spvVertexData3*>(spvVertexBuffer3 + spvVertexStrides[3] * (gl_BaseInstance + (gl_InstanceIndex - gl_BaseInstance) / 4)));
// Left-to-right sum of every input; the scalar a7 is converted to float and
// splatted to all four components before being added.
out.gl_Position = ((((((in.a0 + in.a1) + in.a3) + in.a4) + in.a5) + in.a6) + float4(float(in.a7))) + in.a8;
return out;
}

15 changes: 15 additions & 0 deletions shaders-msl/vert/attrs.vertex-loader.dynamic-stride.vert
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#version 310 es

// Vertex inputs. Location 2 is intentionally or incidentally left unused
// (NOTE(review): the reason is not visible here); a7 is the only integer input.
layout (location = 0) in vec4 a0;
layout (location = 1) in vec4 a1;
layout (location = 3) in vec4 a3;
layout (location = 4) in vec4 a4;
layout (location = 5) in vec4 a5;
layout (location = 6) in vec4 a6;
layout (location = 7) in uint a7;
layout (location = 8) in vec4 a8;

// Writes the left-to-right sum of every input to gl_Position, so each attribute
// contributes to the result. a7 is converted to float and, being a scalar, is
// added to all four components.
void main()
{
gl_Position = a0 + a1 + a3 + a4 + a5 + a6 + float(a7) + a8;
}
50 changes: 43 additions & 7 deletions spirv_msl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7517,7 +7517,7 @@ void CompilerMSL::prepare_shader_vertex_loader()
const Meta &type_meta = ir.meta[var.basetype];
assert(type_meta.members.size() == type.member_types.size());

vertex_loader_writer.init(vertex_attributes, vertex_bindings);
vertex_loader_writer.init(vertex_attributes, vertex_bindings, msl_options.vertex_loader_dynamic_stride);
for (size_t i = 0; i < type_meta.members.size(); i++)
{
const Meta::Decoration &meta = type_meta.members[i];
Expand All @@ -7536,9 +7536,16 @@ void CompilerMSL::prepare_shader_vertex_loader()
continue;
if (!load.empty())
load.append(", ");
std::string istr = std::to_string(i);
if (msl_options.vertex_loader_dynamic_stride)
{
load.append("*reinterpret_cast<device const spvVertexData");
load.append(istr);
load.append("*>(");
}
load.append("spvVertexBuffer");
load.append(std::to_string(i));
if (binding.stride == 0)
load.append(istr);
if (binding.stride == 0 && !msl_options.vertex_loader_dynamic_stride)
{
load.append("[0]");
}
Expand All @@ -7559,22 +7566,35 @@ void CompilerMSL::prepare_shader_vertex_loader()
default:
SPIRV_CROSS_THROW("Unrecognized vertex binding rate");
}
load.push_back('[');
if (msl_options.vertex_loader_dynamic_stride)
{
load.append(" + spvVertexStrides[");
load.append(istr);
load.append("] * ");
}
else
{
load.push_back('[');
}
if (binding.divisor <= 1)
{
load.append(binding.divisor == 0 ? base : index);
}
else
{
if (msl_options.vertex_loader_dynamic_stride)
load.push_back('(');
load.append(base);
load.append(" + (");
load.append(index);
load.append(" - ");
load.append(base);
load.append(") / ");
load.append(std::to_string(binding.divisor));
if (msl_options.vertex_loader_dynamic_stride)
load.push_back(')');
}
load.push_back(']');
load.push_back(msl_options.vertex_loader_dynamic_stride ? ')' : ']');
}
}

Expand Down Expand Up @@ -12804,8 +12824,15 @@ string CompilerMSL::entry_point_arg_stage_in()
if (!decl.empty())
decl.append(", ");
std::string istr = std::to_string(i);
decl.append("device const spvVertexData");
decl.append(istr);
if (msl_options.vertex_loader_dynamic_stride)
{
decl.append("device const uchar");
}
else
{
decl.append("device const spvVertexData");
decl.append(istr);
}
decl.append("* spvVertexBuffer");
decl.append(istr);
decl.append(" [[buffer(");
Expand Down Expand Up @@ -12991,6 +13018,15 @@ void CompilerMSL::entry_point_args_builtin(string &ep_args)
if (needs_base_instance_arg == TriState::Yes)
ep_args += built_in_func_arg(BuiltInBaseInstance, !ep_args.empty());

if (get_using_shader_vertex_loader() && msl_options.vertex_loader_dynamic_stride)
{
if (!ep_args.empty())
ep_args.append(", ");
ep_args.append("const device uint* spvVertexStrides [[buffer(");
ep_args.append(std::to_string(msl_options.vertex_loader_dynamic_stride_buffer_index));
ep_args.append(")]]");
}

if (capture_output_to_buffer)
{
// Add parameters to hold the indirect draw parameters and the shader output. This has to be handled
Expand Down
6 changes: 5 additions & 1 deletion spirv_msl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,7 @@ class CompilerMSL : public CompilerGLSL
uint32_t shader_input_buffer_index = 22;
uint32_t shader_index_buffer_index = 21;
uint32_t shader_patch_input_buffer_index = 20;
uint32_t vertex_loader_dynamic_stride_buffer_index = 19;
uint32_t shader_input_wg_index = 0;
uint32_t device_index = 0;
uint32_t enable_frag_output_mask = 0xffffffff;
Expand Down Expand Up @@ -689,6 +690,9 @@ class CompilerMSL : public CompilerGLSL
// different shaders for these three scenarios.
IndexType vertex_index_type = IndexType::None;

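// Enable dynamic strides in the shader vertex loader: per-binding strides are read at runtime
// from the `spvVertexStrides` buffer (bound at vertex_loader_dynamic_stride_buffer_index)
// instead of being fixed when the shader is generated.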
bool vertex_loader_dynamic_stride = false;

// Allows shaders to load from types like `rgb9e5<float3>` or `rgba16unorm<float4>`
// (Supported on Apple GPUs, I'm guessing the cutoff is Apple4 but Apple doesn't document it...)
// Compiles to dedicated load-and-expand instructions, which are more efficient than expanding with the ALU.
Expand Down Expand Up @@ -1095,7 +1099,7 @@ class CompilerMSL : public CompilerGLSL

public:
/// Initialize the vertex loader writer
void init(const VectorView<MSLVertexAttribute> &in_attributes, const VectorView<MSLVertexBinding> &in_bindings);
void init(const VectorView<MSLVertexAttribute> &in_attributes, const VectorView<MSLVertexBinding> &in_bindings, bool dynamic_stride);
/// Initializes a used attribute
void load(const Meta::Decoration &meta, const SPIRType &type);
SPVFuncImpl get_function_for_loading_vertex(uint32_t attribute, bool has_pixel_type_loads);
Expand Down
8 changes: 5 additions & 3 deletions spirv_msl_vertex_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,8 @@ static uint8_t get_align_log2(uint32_t value, uint8_t max)
}

void CompilerMSL::MSLVertexLoaderWriter::init(const VectorView<MSLVertexAttribute> &in_attributes,
const VectorView<MSLVertexBinding> &in_bindings)
const VectorView<MSLVertexBinding> &in_bindings,
bool dynamic_stride)
{
memset(attributes, 0, sizeof(attributes));
memset(bindings, 0, sizeof(bindings));
Expand All @@ -362,8 +363,9 @@ void CompilerMSL::MSLVertexLoaderWriter::init(const VectorView<MSLVertexAttribut
{
if (binding.binding >= MaxBindings)
continue;
bindings[binding.binding].stride = binding.stride;
bindings[binding.binding].struct_size = binding.stride;
uint32_t stride = dynamic_stride ? 0 : binding.stride;
bindings[binding.binding].stride = stride;
bindings[binding.binding].struct_size = stride;
bindings[binding.binding].rate = binding.rate;
bindings[binding.binding].divisor = binding.divisor;
bindings[binding.binding].valid = true;
Expand Down
2 changes: 2 additions & 0 deletions test_shaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,8 @@ def cross_compile_msl(shader, spirv, opt, iterations, paths):
msl_args.extend(['--msl-vertex-index-type', 'uint16'])
if '.pixel-loads.' in shader:
msl_args.append('--msl-use-pixel-type-loads')
if '.dynamic-stride.' in shader:
msl_args.append('--msl-dynamic-vertex-stride')
if '.vertex-loader.' in shader:
# Some vertex bindings for testing
msl_args.extend(['--msl-vertex-binding', '0', '8', 'instance', '1'])
Expand Down

0 comments on commit 7d58bbd

Please sign in to comment.