Skip to content

Commit

Permalink
[wip] MSL: Fix dynamically indexed pull interpolants
Browse files Browse the repository at this point in the history
Related to KhronosGroup#1796.
  • Loading branch information
ncesario-lunarg committed Jul 25, 2024
1 parent 68d4011 commit c566a32
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 25 deletions.
111 changes: 87 additions & 24 deletions spirv_msl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2924,6 +2924,31 @@ void CompilerMSL::add_composite_variable_to_interface_block(StorageClass storage
statement(to_name(var.self), "[", i, "] = ", ib_var_ref, ".", mbr_name, ";");
}
});
if (pull_model_inputs.count(var.self)) {
const auto& lerpOps = pull_model_inputs.at(var.self);
for (const auto op : lerpOps) {
entry_func.fixup_hooks_in.push_back([=, &var]() {
string lerp_call, lerp_name;
switch (op) {
case GLSLstd450InterpolateAtCentroid:
lerp_name = "centroid";
lerp_call = ".interpolate_at_centroid()";
break;
case GLSLstd450InterpolateAtSample:
lerp_name = "sample";
lerp_call = join(".interpolate_at_sample(", to_expression(builtin_sample_id_id), ")");
break;
case GLSLstd450InterpolateAtOffset:
lerp_name = "center";
lerp_call = ".interpolate_at_center()";
break;
default:
SPIRV_CROSS_THROW("Bad interpolation operator");
}
statement(to_name(var.self) + "_" + lerp_name, "[", i, "] = ", ib_var_ref, ".", mbr_name, lerp_call, ";");
});
}
}
break;

case StorageClassOutput:
Expand Down Expand Up @@ -8643,7 +8668,8 @@ bool CompilerMSL::access_chain_needs_stage_io_builtin_translation(uint32_t base)
void CompilerMSL::fix_up_interpolant_access_chain(const uint32_t *ops, uint32_t length)
{
auto *var = maybe_get_backing_variable(ops[2]);
if (!var || !pull_model_inputs.count(var->self))
const auto vt = get_pointee_type(ops[0]);
if (!var || !pull_model_inputs.count(var->self) || !is_scalar(vt))
return;
// Get the base index.
uint32_t interface_index;
Expand Down Expand Up @@ -8675,9 +8701,9 @@ void CompilerMSL::fix_up_interpolant_access_chain(const uint32_t *ops, uint32_t
}

auto *c = maybe_get<SPIRConstant>(ops[i]);
if (!c || c->specialization)
SPIRV_CROSS_THROW("Trying to dynamically index into an array interface variable using pull-model "
"interpolation. This is currently unsupported.");
if (!c || c->specialization) {
auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
}

if (type->parent_type)
type = &get<SPIRType>(type->parent_type);
Expand Down Expand Up @@ -10715,25 +10741,33 @@ void CompilerMSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop,
break;

case GLSLstd450InterpolateAtCentroid:
{
// We can't just emit the expression normally, because the qualified name contains a call to the default
// interpolate method, or refers to a local variable. We saved the interface index we need; use it to construct
// the base for the method call.
uint32_t interface_index = get_extended_decoration(args[0], SPIRVCrossDecorationInterfaceMemberIndex);
string component;
if (has_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr))
{
uint32_t index_expr = get_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr);
auto *c = maybe_get<SPIRConstant>(index_expr);
if (!c || c->specialization)
component = join("[", to_expression(index_expr), "]");
else
component = join(".", index_to_swizzle(c->scalar()));
}
emit_op(result_type, id,
join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index),
".interpolate_at_centroid()", component),
should_forward(args[0]));
{
// We can't just emit the expression normally, because the qualified name contains a call to the default
// interpolate method, or refers to a local variable. We saved the interface index we need; use it to construct
// the base for the method call.
uint32_t interface_index = get_extended_decoration(args[0], SPIRVCrossDecorationInterfaceMemberIndex);
string component;
if (has_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr))
{
uint32_t index_expr = get_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr);
auto *c = maybe_get<SPIRConstant>(index_expr);
if (!c || c->specialization)
component = join("[", to_expression(index_expr), "]");
else
component = join(".", index_to_swizzle(c->scalar()));
}
const auto var = maybe_get_backing_variable(args[0]);
if (var) {
emit_op(result_type, id,
//join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index),
join(to_name(var->self, true) + "_centroid", "[", interface_index, "]", component),
should_forward(args[0]));
} else {
emit_op(result_type, id,
join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index),
".interpolate_at_centroid()", component),
should_forward(args[0]));
}
break;
}

Expand Down Expand Up @@ -15803,6 +15837,34 @@ std::string CompilerMSL::variable_decl(const SPIRType &type, const std::string &
return CompilerGLSL::variable_decl(type, name, id);
}

std::string CompilerMSL::variable_decl(const SPIRVariable &variable) {
auto res = CompilerGLSL::variable_decl(variable);

if (pull_model_inputs.count(variable.self)) {
const auto &type = get_variable_data_type(variable);
auto& interpOps = pull_model_inputs.at(variable.self);
for (const auto op : interpOps) {
res += ";\n";
string lerp_name;
switch (op) {
case GLSLstd450InterpolateAtCentroid:
lerp_name = "centroid";
break;
case GLSLstd450InterpolateAtSample:
lerp_name = "sample";
break;
case GLSLstd450InterpolateAtOffset:
lerp_name = "offset";
break;
default:
SPIRV_CROSS_THROW("Bad interpolation operator.");
}
res += join(to_qualifiers_glsl(variable.self), variable_decl(type, to_name(variable.self), variable.self) + "_" + lerp_name, " = {}");
}
}
return res;
}

std::string CompilerMSL::sampler_type(const SPIRType &type, uint32_t id, bool member)
{
auto *var = maybe_get<SPIRVariable>(id);
Expand Down Expand Up @@ -17451,7 +17513,8 @@ bool CompilerMSL::OpCodePreprocessor::handle(Op opcode, const uint32_t *args, ui
auto *var = compiler.maybe_get_backing_variable(args[4]);
if (var)
{
compiler.pull_model_inputs.insert(var->self);
auto& interpOps = compiler.pull_model_inputs[var->self]; // get or create the op set
interpOps.emplace(op_450);
auto &var_type = compiler.get_variable_element_type(*var);
// In addition, if this variable has a 'Sample' decoration, we need the sample ID
// in order to do default interpolation.
Expand Down
3 changes: 2 additions & 1 deletion spirv_msl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,7 @@ class CompilerMSL : public CompilerGLSL

// GCC workaround of lambdas calling protected functions (for older GCC versions)
std::string variable_decl(const SPIRType &type, const std::string &name, uint32_t id = 0) override;
std::string variable_decl(const SPIRVariable &variable) override;

std::string image_type_glsl(const SPIRType &type, uint32_t id, bool member) override;
std::string sampler_type(const SPIRType &type, uint32_t id, bool member);
Expand Down Expand Up @@ -1228,7 +1229,7 @@ class CompilerMSL : public CompilerGLSL
SmallVector<std::pair<uint32_t, uint32_t>> buffer_aliases_argument;
SmallVector<uint32_t> buffer_aliases_discrete;
std::unordered_set<uint32_t> atomic_image_vars_emulated; // Emulate texture2D atomic operations
std::unordered_set<uint32_t> pull_model_inputs;
std::unordered_map<uint32_t, std::unordered_set<GLSLstd450>> pull_model_inputs;
std::unordered_set<uint32_t> recursive_inputs;

SmallVector<SPIRVariable *> entry_point_bindings;
Expand Down

0 comments on commit c566a32

Please sign in to comment.