Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wip] MSL: Fix dynamically indexed pull interpolants #2364

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion spirv_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1673,7 +1673,10 @@ enum ExtendedDecorations

SPIRVCrossDecorationOverlappingBinding,

SPIRVCrossDecorationCount
// Tracks the variable dynamically indexing into an array used with pull interpolants
SPIRVCrossDecorationInterpolantIndexVariable,

SPIRVCrossDecorationCount,
};

struct Meta
Expand Down
281 changes: 257 additions & 24 deletions spirv_msl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1862,7 +1862,25 @@ void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std::
{
uint32_t base_id = ops[2];
if (global_var_ids.find(base_id) != global_var_ids.end())
added_arg_ids.insert(base_id);
{
const auto arg_id = [&]() -> uint32_t
{
// Check if we need to pass in a copied of array of interpolants
const auto interpOpItr = pull_model_inputs.find(base_id);
if (interpOpItr != pull_model_inputs.cend())
{
for (const auto &itr : interpOpItr->second.ops)
{
if (itr.second.count(ops[1]))
{
return itr.first.new_var_id;
}
}
}
return base_id;
}();
added_arg_ids.insert(arg_id);
}

// Use Metal's native frame-buffer fetch API for subpass inputs.
auto &type = get<SPIRType>(ops[0]);
Expand Down Expand Up @@ -2924,6 +2942,85 @@ void CompilerMSL::add_composite_variable_to_interface_block(StorageClass storage
statement(to_name(var.self), "[", i, "] = ", ib_var_ref, ".", mbr_name, ";");
}
});
if (pull_model_inputs.count(var.self))
{
const auto &lerpOps = pull_model_inputs.at(var.self).ops;
for (const auto &op : lerpOps)
{
entry_func.fixup_hooks_in.push_back(
[=]()
{
string lerp_call, lerp_name;
switch (op.first.op)
{
case GLSLstd450InterpolateAtCentroid:
lerp_name = "centroid";
lerp_call = ".interpolate_at_centroid()";
break;
case GLSLstd450InterpolateAtSample:
{
lerp_name = "sample";
const auto sid_var = maybe_get_backing_variable(op.first.offset_arg);
if (sid_var)
{
lerp_call = join(".interpolate_at_sample(", to_expression(sid_var->self), ")");
}
else
{
const auto sid_const = maybe_get<SPIRConstant>(op.first.offset_arg);
if (sid_const)
{
lerp_call =
join(".interpolate_at_sample(", to_expression(sid_const->self), ")");
}
else
{
lerp_call = join(".interpolate_at_sample(",
to_expression(builtin_sample_id_id), ")");
}
}
}
break;
case GLSLstd450InterpolateAtOffset:
{
lerp_name = "offset";
// Like Direct3D, Metal puts the (0, 0) at the upper-left corner, not the center as SPIR-V and GLSL do.
// Offset the offset by (1/2 - 1/16), or 0.4375, to compensate for this.
// It has to be (1/2 - 1/16) and not 1/2, or several CTS tests subtly break on Intel.
const auto offset_var = maybe_get_backing_variable(op.first.offset_arg);
if (offset_var)
{
lerp_call = join(".interpolate_at_offset(", to_expression(offset_var->self),
" + 0.4375)");
}
else
{
// TODO Without load/store elimination, this will incorrectly return null. Is there already a way to "walk" the
// load/store chain in spirv-cross?
// e.g., dEQP-VK.pipeline.monolithic.multisample_interpolation.nonuniform_interpolant_indexing.offset will fail,
// though it will pass if 'spirv-opt -O' is run on the spirv first.
const auto offset_const = maybe_get<SPIRConstant>(op.first.offset_arg);
if (offset_const)
{
lerp_call = join(".interpolate_at_offset(",
to_expression(offset_const->self), " + 0.4375)");
}
else
{
lerp_call = join(".interpolate_at_offset(",
to_expression(op.first.offset_arg), " + 0.4375)");
}
}
}
break;
default:
SPIRV_CROSS_THROW("Bad interpolation operator");
}
statement(to_name(op.first.new_var_id), "[", i, "] = ", ib_var_ref, ".", mbr_name,
lerp_call, ";");
});
}
}
break;

case StorageClassOutput:
Expand Down Expand Up @@ -8675,20 +8772,35 @@ void CompilerMSL::fix_up_interpolant_access_chain(const uint32_t *ops, uint32_t
}

auto *c = maybe_get<SPIRConstant>(ops[i]);
if (!c || c->specialization)
SPIRV_CROSS_THROW("Trying to dynamically index into an array interface variable using pull-model "
"interpolation. This is currently unsupported.");

if (type->parent_type)
type = &get<SPIRType>(type->parent_type);
else if (type->basetype == SPIRType::Struct)
{
if (!c || c->specialization)
{
SPIRV_CROSS_THROW("Trying to dynamically index into an array interface variable using pull-model "
"interpolation. This is currently unsupported.");
}
type = &get<SPIRType>(type->member_types[c->scalar()]);
}

if (!has_extended_decoration(ops[2], SPIRVCrossDecorationInterfaceMemberIndex) &&
i - 3 == var_type.array.size())
continue;

interface_index += c->scalar();
if (c)
{
interface_index += c->scalar();
}
else
{
const auto indexVar = maybe_get_backing_variable(ops[i]);
if (indexVar)
{
set_extended_decoration(ops[1], SPIRVCrossDecorationInterpolantIndexVariable, indexVar->self);
}
}
}
// Save this to the access chain itself so we can recover it later when calling an interpolation function.
set_extended_decoration(ops[1], SPIRVCrossDecorationInterfaceMemberIndex, interface_index);
Expand Down Expand Up @@ -10730,10 +10842,27 @@ void CompilerMSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop,
else
component = join(".", index_to_swizzle(c->scalar()));
}
emit_op(result_type, id,
join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index),
".interpolate_at_centroid()", component),
should_forward(args[0]));
const auto var = maybe_get_backing_variable(args[0]);
if (var)
{
auto array_index = std::to_string(interface_index);
if (has_extended_decoration(args[0], SPIRVCrossDecorationInterpolantIndexVariable))
{
const auto indexVarId = get_extended_decoration(args[0], SPIRVCrossDecorationInterpolantIndexVariable);
array_index = to_name(indexVarId);
}
// There should only be one copy for centroid interpolation
const auto interpOpCopyId = pull_model_inputs.at(var->self).ops.cbegin()->first.new_var_id;
emit_op(result_type, id, join(to_name(interpOpCopyId, true), "[", array_index, "]", component),
should_forward(args[0]));
}
else
{
emit_op(result_type, id,
join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index),
".interpolate_at_centroid()", component),
should_forward(args[0]));
}
break;
}

Expand All @@ -10750,10 +10879,38 @@ void CompilerMSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop,
else
component = join(".", index_to_swizzle(c->scalar()));
}
emit_op(result_type, id,
join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index),
".interpolate_at_sample(", to_expression(args[1]), ")", component),
should_forward(args[0]) && should_forward(args[1]));
const auto var = maybe_get_backing_variable(args[0]);
if (var)
{
const auto &interpOps = pull_model_inputs.at(var->self).ops;
const auto interpOpCopyId = [&interpOps, &args]() -> int
{
for (const auto &op : interpOps)
{
if (op.first.offset_arg == args[1])
{
return op.first.new_var_id;
}
}
// TODO Shouldn't get here
return 0;
}();
auto array_index = std::to_string(interface_index);
if (has_extended_decoration(args[0], SPIRVCrossDecorationInterpolantIndexVariable))
{
const auto indexVarId = get_extended_decoration(args[0], SPIRVCrossDecorationInterpolantIndexVariable);
array_index = to_name(indexVarId);
}
emit_op(result_type, id, join(to_name(interpOpCopyId, true), "[", array_index, "]", component),
should_forward(args[0]));
}
else
{
emit_op(result_type, id,
join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index),
".interpolate_at_sample(", to_expression(args[1]), ")", component),
should_forward(args[0]) && should_forward(args[1]));
}
break;
}

Expand All @@ -10770,13 +10927,41 @@ void CompilerMSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop,
else
component = join(".", index_to_swizzle(c->scalar()));
}
// Like Direct3D, Metal puts the (0, 0) at the upper-left corner, not the center as SPIR-V and GLSL do.
// Offset the offset by (1/2 - 1/16), or 0.4375, to compensate for this.
// It has to be (1/2 - 1/16) and not 1/2, or several CTS tests subtly break on Intel.
emit_op(result_type, id,
join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index),
".interpolate_at_offset(", to_expression(args[1]), " + 0.4375)", component),
should_forward(args[0]) && should_forward(args[1]));
const auto var = maybe_get_backing_variable(args[0]);
if (var)
{
auto array_index = std::to_string(interface_index);
if (has_extended_decoration(args[0], SPIRVCrossDecorationInterpolantIndexVariable))
{
const auto indexVarId = get_extended_decoration(args[0], SPIRVCrossDecorationInterpolantIndexVariable);
array_index = to_name(indexVarId);
}
const auto &interpOps = pull_model_inputs.at(var->self).ops;
const auto interpOpCopyId = [&interpOps, &args]() -> int
{
for (const auto &op : interpOps)
{
if (op.first.offset_arg == args[1])
{
return op.first.new_var_id;
}
}
// TODO Shouldn't get here
return 0;
}();
emit_op(result_type, id, join(to_name(interpOpCopyId, true), "[", array_index, "]", component),
should_forward(args[0]));
}
else
{
// Like Direct3D, Metal puts the (0, 0) at the upper-left corner, not the center as SPIR-V and GLSL do.
// Offset the offset by (1/2 - 1/16), or 0.4375, to compensate for this.
// It has to be (1/2 - 1/16) and not 1/2, or several CTS tests subtly break on Intel.
emit_op(result_type, id,
join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index),
".interpolate_at_offset(", to_expression(args[1]), " + 0.4375)", component),
should_forward(args[0]) && should_forward(args[1]));
}
break;
}

Expand Down Expand Up @@ -15803,6 +15988,22 @@ std::string CompilerMSL::variable_decl(const SPIRType &type, const std::string &
return CompilerGLSL::variable_decl(type, name, id);
}

std::string CompilerMSL::variable_decl(const SPIRVariable &variable)
{
auto res = CompilerGLSL::variable_decl(variable);

if (pull_model_inputs.count(variable.self))
{
auto &interpOps = pull_model_inputs.at(variable.self).ops;
for (const auto &op : interpOps)
{
const auto interp_var = get<SPIRVariable>(op.first.new_var_id);
res += ";\n " + CompilerGLSL::variable_decl(interp_var);
}
}
return res;
}

std::string CompilerMSL::sampler_type(const SPIRType &type, uint32_t id, bool member)
{
auto *var = maybe_get<SPIRVariable>(id);
Expand Down Expand Up @@ -17438,11 +17639,22 @@ bool CompilerMSL::OpCodePreprocessor::handle(Op opcode, const uint32_t *args, ui
if (compiler.get<SPIRExtension>(extension_set).ext == SPIRExtension::GLSL)
{
auto op_450 = static_cast<GLSLstd450>(args[3]);
std::string op_name;
switch (op_450)
{
case GLSLstd450InterpolateAtCentroid:
op_name = "centroid";
break;
case GLSLstd450InterpolateAtSample:
op_name = "sample";
break;
case GLSLstd450InterpolateAtOffset:
op_name = "offset";
break;
default:
break;
}
if (!op_name.empty())
{
if (!compiler.msl_options.supports_msl_version(2, 3))
SPIRV_CROSS_THROW("Pull-model interpolation requires MSL 2.3.");
Expand All @@ -17451,7 +17663,31 @@ bool CompilerMSL::OpCodePreprocessor::handle(Op opcode, const uint32_t *args, ui
auto *var = compiler.maybe_get_backing_variable(args[4]);
if (var)
{
compiler.pull_model_inputs.insert(var->self);
auto &interpOps = compiler.pull_model_inputs[var->self]; // get or create the op set
const auto next_id = compiler.ir.increase_bound_by(1);
auto &interpOpsRefs = [&]() -> std::unordered_set<uint32_t> &
{
if (op_450 == GLSLstd450InterpolateAtOffset || op_450 == GLSLstd450InterpolateAtSample)
{
return interpOps.ops[PullModelOp{ op_450, args[5], next_id }];
}
else
{
return interpOps.ops[PullModelOp{ op_450, 0, next_id }];
}
}();
interpOpsRefs.insert(args[4]);

compiler.set<SPIRVariable>(next_id, var->basetype, StorageClassFunction, 0, var->self);
auto &meta = compiler.ir.meta[next_id];
meta = compiler.ir.meta[var->self];
meta.decoration.alias += "_" + op_name + "_" + std::to_string(next_id);

//auto &entry_func = compiler.get<SPIRFunction>(compiler.ir.default_entry_point);
//// TODO The declaration of this local variable needs to happen when var gets declared.
//// For now, this is hacked in CompilerMSL::variable_decl.
//entry_func.local_variables.insert(entry_func.local_variables.begin(), next_id);

auto &var_type = compiler.get_variable_element_type(*var);
// In addition, if this variable has a 'Sample' decoration, we need the sample ID
// in order to do default interpolation.
Expand All @@ -17474,9 +17710,6 @@ bool CompilerMSL::OpCodePreprocessor::handle(Op opcode, const uint32_t *args, ui
}
break;
}
default:
break;
}
}
break;
}
Expand Down
Loading
Loading