Skip to content

Commit

Permalink
[wip] MSL: Fix dynamically indexed pull interpolants
Browse files Browse the repository at this point in the history
Related to KhronosGroup#1796.
  • Loading branch information
ncesario-lunarg committed Jul 25, 2024
1 parent 68d4011 commit 94c618c
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 38 deletions.
5 changes: 4 additions & 1 deletion spirv_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1673,7 +1673,10 @@ enum ExtendedDecorations

SPIRVCrossDecorationOverlappingBinding,

SPIRVCrossDecorationCount
// Tracks the variable dynamically indexing into an array used with pull interpolants
SPIRVCrossDecorationInterpolantIndexVariable,

SPIRVCrossDecorationCount,
};

struct Meta
Expand Down
171 changes: 135 additions & 36 deletions spirv_msl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2924,6 +2924,31 @@ void CompilerMSL::add_composite_variable_to_interface_block(StorageClass storage
statement(to_name(var.self), "[", i, "] = ", ib_var_ref, ".", mbr_name, ";");
}
});
if (pull_model_inputs.count(var.self)) {
const auto& lerpOps = pull_model_inputs.at(var.self);
for (const auto op : lerpOps) {
entry_func.fixup_hooks_in.push_back([=, &var]() {
string lerp_call, lerp_name;
switch (op) {
case GLSLstd450InterpolateAtCentroid:
lerp_name = "centroid";
lerp_call = ".interpolate_at_centroid()";
break;
case GLSLstd450InterpolateAtSample:
lerp_name = "sample";
lerp_call = join(".interpolate_at_sample(", to_expression(builtin_sample_id_id), ")");
break;
case GLSLstd450InterpolateAtOffset:
lerp_name = "center";
lerp_call = ".interpolate_at_center()";
break;
default:
SPIRV_CROSS_THROW("Bad interpolation operator");
}
statement(to_name(var.self) + "_" + lerp_name, "[", i, "] = ", ib_var_ref, ".", mbr_name, lerp_call, ";");
});
}
}
break;

case StorageClassOutput:
Expand Down Expand Up @@ -8675,20 +8700,29 @@ void CompilerMSL::fix_up_interpolant_access_chain(const uint32_t *ops, uint32_t
}

auto *c = maybe_get<SPIRConstant>(ops[i]);
if (!c || c->specialization)
SPIRV_CROSS_THROW("Trying to dynamically index into an array interface variable using pull-model "
"interpolation. This is currently unsupported.");

if (type->parent_type)
type = &get<SPIRType>(type->parent_type);
else if (type->basetype == SPIRType::Struct)
else if (type->basetype == SPIRType::Struct) {
if (!c || c->specialization) {
SPIRV_CROSS_THROW("Trying to dynamically index into an array interface variable using pull-model "
"interpolation. This is currently unsupported.");
}
type = &get<SPIRType>(type->member_types[c->scalar()]);
}

if (!has_extended_decoration(ops[2], SPIRVCrossDecorationInterfaceMemberIndex) &&
i - 3 == var_type.array.size())
continue;

interface_index += c->scalar();
if (c) {
interface_index += c->scalar();
} else {
const auto indexVar = maybe_get_backing_variable(ops[i]);
if (indexVar) {
set_extended_decoration(ops[1], SPIRVCrossDecorationInterpolantIndexVariable, indexVar->self);
}
}
}
// Save this to the access chain itself so we can recover it later when calling an interpolation function.
set_extended_decoration(ops[1], SPIRVCrossDecorationInterfaceMemberIndex, interface_index);
Expand Down Expand Up @@ -10715,25 +10749,37 @@ void CompilerMSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop,
break;

case GLSLstd450InterpolateAtCentroid:
{
// We can't just emit the expression normally, because the qualified name contains a call to the default
// interpolate method, or refers to a local variable. We saved the interface index we need; use it to construct
// the base for the method call.
uint32_t interface_index = get_extended_decoration(args[0], SPIRVCrossDecorationInterfaceMemberIndex);
string component;
if (has_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr))
{
uint32_t index_expr = get_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr);
auto *c = maybe_get<SPIRConstant>(index_expr);
if (!c || c->specialization)
component = join("[", to_expression(index_expr), "]");
else
component = join(".", index_to_swizzle(c->scalar()));
}
emit_op(result_type, id,
join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index),
".interpolate_at_centroid()", component),
should_forward(args[0]));
{
// We can't just emit the expression normally, because the qualified name contains a call to the default
// interpolate method, or refers to a local variable. We saved the interface index we need; use it to construct
// the base for the method call.
uint32_t interface_index = get_extended_decoration(args[0], SPIRVCrossDecorationInterfaceMemberIndex);
string component;
if (has_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr))
{
uint32_t index_expr = get_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr);
auto *c = maybe_get<SPIRConstant>(index_expr);
if (!c || c->specialization)
component = join("[", to_expression(index_expr), "]");
else
component = join(".", index_to_swizzle(c->scalar()));
}
const auto var = maybe_get_backing_variable(args[0]);
if (var) {
auto array_index = std::to_string(interface_index);
if (has_extended_decoration(args[0], SPIRVCrossDecorationInterpolantIndexVariable)) {
const auto indexVarId = get_extended_decoration(args[0], SPIRVCrossDecorationInterpolantIndexVariable);
array_index = to_name(indexVarId);
}
emit_op(result_type, id,
join(to_name(var->self, true) + "_centroid", "[", array_index , "]", component),
should_forward(args[0]));
} else {
emit_op(result_type, id,
join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index),
".interpolate_at_centroid()", component),
should_forward(args[0]));
}
break;
}

Expand All @@ -10750,10 +10796,22 @@ void CompilerMSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop,
else
component = join(".", index_to_swizzle(c->scalar()));
}
emit_op(result_type, id,
join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index),
".interpolate_at_sample(", to_expression(args[1]), ")", component),
should_forward(args[0]) && should_forward(args[1]));
const auto var = maybe_get_backing_variable(args[0]);
if (var) {
auto array_index = std::to_string(interface_index);
if (has_extended_decoration(args[0], SPIRVCrossDecorationInterpolantIndexVariable)) {
const auto indexVarId = get_extended_decoration(args[0], SPIRVCrossDecorationInterpolantIndexVariable);
array_index = to_name(indexVarId);
}
emit_op(result_type, id,
join(to_name(var->self, true) + "_sample", "[", array_index , "]", component),
should_forward(args[0]));
} else {
emit_op(result_type, id,
join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index),
".interpolate_at_sample(", to_expression(args[1]), ")", component),
should_forward(args[0]) && should_forward(args[1]));
}
break;
}

Expand All @@ -10770,13 +10828,25 @@ void CompilerMSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop,
else
component = join(".", index_to_swizzle(c->scalar()));
}
// Like Direct3D, Metal puts the (0, 0) at the upper-left corner, not the center as SPIR-V and GLSL do.
// Offset the offset by (1/2 - 1/16), or 0.4375, to compensate for this.
// It has to be (1/2 - 1/16) and not 1/2, or several CTS tests subtly break on Intel.
emit_op(result_type, id,
join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index),
".interpolate_at_offset(", to_expression(args[1]), " + 0.4375)", component),
should_forward(args[0]) && should_forward(args[1]));
const auto var = maybe_get_backing_variable(args[0]);
if (var) {
auto array_index = std::to_string(interface_index);
if (has_extended_decoration(args[0], SPIRVCrossDecorationInterpolantIndexVariable)) {
const auto indexVarId = get_extended_decoration(args[0], SPIRVCrossDecorationInterpolantIndexVariable);
array_index = to_name(indexVarId);
}
emit_op(result_type, id,
join(to_name(var->self, true) + "_offset", "[", array_index , "]", component),
should_forward(args[0]));
} else {
// Like Direct3D, Metal puts the (0, 0) at the upper-left corner, not the center as SPIR-V and GLSL do.
// Offset the offset by (1/2 - 1/16), or 0.4375, to compensate for this.
// It has to be (1/2 - 1/16) and not 1/2, or several CTS tests subtly break on Intel.
emit_op(result_type, id,
join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index),
".interpolate_at_offset(", to_expression(args[1]), " + 0.4375)", component),
should_forward(args[0]) && should_forward(args[1]));
}
break;
}

Expand Down Expand Up @@ -15803,6 +15873,34 @@ std::string CompilerMSL::variable_decl(const SPIRType &type, const std::string &
return CompilerGLSL::variable_decl(type, name, id);
}

std::string CompilerMSL::variable_decl(const SPIRVariable &variable) {
auto res = CompilerGLSL::variable_decl(variable);

if (pull_model_inputs.count(variable.self)) {
const auto &type = get_variable_data_type(variable);
auto& interpOps = pull_model_inputs.at(variable.self);
for (const auto op : interpOps) {
res += ";\n";
string lerp_name;
switch (op) {
case GLSLstd450InterpolateAtCentroid:
lerp_name = "centroid";
break;
case GLSLstd450InterpolateAtSample:
lerp_name = "sample";
break;
case GLSLstd450InterpolateAtOffset:
lerp_name = "offset";
break;
default:
SPIRV_CROSS_THROW("Bad interpolation operator.");
}
res += join(to_qualifiers_glsl(variable.self), variable_decl(type, to_name(variable.self), variable.self) + "_" + lerp_name, " = {}");
}
}
return res;
}

std::string CompilerMSL::sampler_type(const SPIRType &type, uint32_t id, bool member)
{
auto *var = maybe_get<SPIRVariable>(id);
Expand Down Expand Up @@ -17451,7 +17549,8 @@ bool CompilerMSL::OpCodePreprocessor::handle(Op opcode, const uint32_t *args, ui
auto *var = compiler.maybe_get_backing_variable(args[4]);
if (var)
{
compiler.pull_model_inputs.insert(var->self);
auto& interpOps = compiler.pull_model_inputs[var->self]; // get or create the op set
interpOps.emplace(op_450);
auto &var_type = compiler.get_variable_element_type(*var);
// In addition, if this variable has a 'Sample' decoration, we need the sample ID
// in order to do default interpolation.
Expand Down
3 changes: 2 additions & 1 deletion spirv_msl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,7 @@ class CompilerMSL : public CompilerGLSL

// GCC workaround of lambdas calling protected functions (for older GCC versions)
std::string variable_decl(const SPIRType &type, const std::string &name, uint32_t id = 0) override;
std::string variable_decl(const SPIRVariable &variable) override;

std::string image_type_glsl(const SPIRType &type, uint32_t id, bool member) override;
std::string sampler_type(const SPIRType &type, uint32_t id, bool member);
Expand Down Expand Up @@ -1228,7 +1229,7 @@ class CompilerMSL : public CompilerGLSL
SmallVector<std::pair<uint32_t, uint32_t>> buffer_aliases_argument;
SmallVector<uint32_t> buffer_aliases_discrete;
std::unordered_set<uint32_t> atomic_image_vars_emulated; // Emulate texture2D atomic operations
std::unordered_set<uint32_t> pull_model_inputs;
std::unordered_map<uint32_t, std::unordered_set<GLSLstd450>> pull_model_inputs;
std::unordered_set<uint32_t> recursive_inputs;

SmallVector<SPIRVariable *> entry_point_bindings;
Expand Down

0 comments on commit 94c618c

Please sign in to comment.