Skip to content

Commit

Permalink
spirv-val: Add initial SPV_EXT_mesh_shader validation (#4924)
Browse files Browse the repository at this point in the history
* Move TaskEXT check to OpEmitMeshTasksEXT

* Add MeshNV for Execution Model alias
  • Loading branch information
sjfricke authored Sep 23, 2022
1 parent 265b455 commit b53d7a8
Show file tree
Hide file tree
Showing 14 changed files with 788 additions and 11 deletions.
1 change: 1 addition & 0 deletions Android.mk
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ SPVTOOLS_SRC_FILES := \
source/val/validate_instruction.cpp \
source/val/validate_memory.cpp \
source/val/validate_memory_semantics.cpp \
source/val/validate_mesh_shading.cpp \
source/val/validate_misc.cpp \
source/val/validate_mode_setting.cpp \
source/val/validate_layout.cpp \
Expand Down
1 change: 1 addition & 0 deletions BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,7 @@ static_library("spvtools_val") {
"source/val/validate_memory.cpp",
"source/val/validate_memory_semantics.cpp",
"source/val/validate_memory_semantics.h",
"source/val/validate_mesh_shading.h",
"source/val/validate_misc.cpp",
"source/val/validate_mode_setting.cpp",
"source/val/validate_non_uniform.cpp",
Expand Down
1 change: 1 addition & 0 deletions source/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ set(SPIRV_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_logicals.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_memory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_memory_semantics.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_mesh_shading.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_misc.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_mode_setting.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_non_uniform.cpp
Expand Down
14 changes: 14 additions & 0 deletions source/val/validate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,8 @@ spv_result_t ValidateBinaryUsingContextAndValidationState(
return error;
}

bool has_mask_task_nv = false;
bool has_mask_task_ext = false;
std::vector<Instruction*> visited_entry_points;
for (auto& instruction : vstate->ordered_instructions()) {
{
Expand Down Expand Up @@ -247,6 +249,11 @@ spv_result_t ValidateBinaryUsingContextAndValidationState(
}
}
visited_entry_points.push_back(inst);

has_mask_task_nv |= (execution_model == SpvExecutionModelTaskNV ||
execution_model == SpvExecutionModelMeshNV);
has_mask_task_ext |= (execution_model == SpvExecutionModelTaskEXT ||
execution_model == SpvExecutionModelMeshEXT);
}
if (inst->opcode() == SpvOpFunctionCall) {
if (!vstate->in_function_body()) {
Expand Down Expand Up @@ -298,6 +305,12 @@ spv_result_t ValidateBinaryUsingContextAndValidationState(
return vstate->diag(SPV_ERROR_INVALID_LAYOUT, nullptr)
<< "Missing required OpSamplerImageAddressingModeNV instruction.";

if (has_mask_task_ext && has_mask_task_nv)
return vstate->diag(SPV_ERROR_INVALID_LAYOUT, nullptr)
<< vstate->VkErrorID(7102)
<< "Module can't mix MeshEXT/TaskEXT with MeshNV/TaskNV Execution "
"Model.";

// Catch undefined forward references before performing further checks.
if (auto error = ValidateForwardDecls(*vstate)) return error;

Expand Down Expand Up @@ -352,6 +365,7 @@ spv_result_t ValidateBinaryUsingContextAndValidationState(
if (auto error = LiteralsPass(*vstate, &instruction)) return error;
if (auto error = RayQueryPass(*vstate, &instruction)) return error;
if (auto error = RayTracingPass(*vstate, &instruction)) return error;
if (auto error = MeshShadingPass(*vstate, &instruction)) return error;
}

// Validate the preconditions involving adjacent instructions. e.g. SpvOpPhi
Expand Down
3 changes: 3 additions & 0 deletions source/val/validate.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,9 @@ spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst);
/// Validates correctness of ray tracing instructions.
spv_result_t RayTracingPass(ValidationState_t& _, const Instruction* inst);

/// Validates correctness of mesh shading instructions.
spv_result_t MeshShadingPass(ValidationState_t& _, const Instruction* inst);

/// Calculates the reachability of basic blocks.
void ReachabilityPass(ValidationState_t& _);

Expand Down
6 changes: 1 addition & 5 deletions source/val/validate_cfg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1068,6 +1068,7 @@ spv_result_t CfgPass(ValidationState_t& _, const Instruction* inst) {
case SpvOpTerminateRayKHR:
case SpvOpEmitMeshTasksEXT:
_.current_function().RegisterBlockEnd(std::vector<uint32_t>());
// Ops with dedicated passes check for the Execution Model there
if (opcode == SpvOpKill) {
_.current_function().RegisterExecutionModelLimitation(
SpvExecutionModelFragment,
Expand All @@ -1088,11 +1089,6 @@ spv_result_t CfgPass(ValidationState_t& _, const Instruction* inst) {
SpvExecutionModelAnyHitKHR,
"OpTerminateRayKHR requires AnyHitKHR execution model");
}
if (opcode == SpvOpEmitMeshTasksEXT) {
_.current_function().RegisterExecutionModelLimitation(
SpvExecutionModelTaskEXT,
"OpEmitMeshTasksEXT requires TaskEXT execution model");
}

break;
default:
Expand Down
13 changes: 13 additions & 0 deletions source/val/validate_memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,19 @@ spv_result_t ValidateVariable(ValidationState_t& _, const Instruction* inst) {
}
}

if (inst->operands().size() > 3) {
if (storage_class == SpvStorageClassTaskPayloadWorkgroupEXT) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpVariable, <id> '" << _.getIdName(inst->id())
<< "', initializer are not allowed for TaskPayloadWorkgroupEXT";
}
if (storage_class == SpvStorageClassInput) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpVariable, <id> '" << _.getIdName(inst->id())
<< "', initializer are not allowed for Input";
}
}

if (storage_class == SpvStorageClassPhysicalStorageBuffer) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "PhysicalStorageBuffer must not be used with OpVariable.";
Expand Down
123 changes: 123 additions & 0 deletions source/val/validate_mesh_shading.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Copyright (c) 2022 The Khronos Group Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Validates ray query instructions from SPV_KHR_ray_query

#include "source/opcode.h"
#include "source/val/instruction.h"
#include "source/val/validate.h"
#include "source/val/validation_state.h"

namespace spvtools {
namespace val {

spv_result_t MeshShadingPass(ValidationState_t& _, const Instruction* inst) {
const SpvOp opcode = inst->opcode();
switch (opcode) {
case SpvOpEmitMeshTasksEXT: {
_.function(inst->function()->id())
->RegisterExecutionModelLimitation(
[](SpvExecutionModel model, std::string* message) {
if (model != SpvExecutionModelTaskEXT) {
if (message) {
*message =
"OpEmitMeshTasksEXT requires TaskEXT execution model";
}
return false;
}
return true;
});

const uint32_t group_count_x = _.GetOperandTypeId(inst, 0);
if (!_.IsUnsignedIntScalarType(group_count_x) ||
_.GetBitWidth(group_count_x) != 32) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Group Count X must be a 32-bit unsigned int scalar";
}

const uint32_t group_count_y = _.GetOperandTypeId(inst, 1);
if (!_.IsUnsignedIntScalarType(group_count_y) ||
_.GetBitWidth(group_count_y) != 32) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Group Count Y must be a 32-bit unsigned int scalar";
}

const uint32_t group_count_z = _.GetOperandTypeId(inst, 2);
if (!_.IsUnsignedIntScalarType(group_count_z) ||
_.GetBitWidth(group_count_z) != 32) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Group Count Z must be a 32-bit unsigned int scalar";
}

if (inst->operands().size() == 4) {
const auto payload = _.FindDef(inst->GetOperandAs<uint32_t>(3));
if (payload->opcode() != SpvOpVariable) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Payload must be the result of a OpVariable";
}
if (SpvStorageClass(payload->GetOperandAs<uint32_t>(2)) !=
SpvStorageClassTaskPayloadWorkgroupEXT) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Payload OpVariable must have a storage class of "
"TaskPayloadWorkgroupEXT";
}
}
break;
}

case SpvOpSetMeshOutputsEXT: {
_.function(inst->function()->id())
->RegisterExecutionModelLimitation(
[](SpvExecutionModel model, std::string* message) {
if (model != SpvExecutionModelMeshEXT) {
if (message) {
*message =
"OpSetMeshOutputsEXT requires MeshEXT execution model";
}
return false;
}
return true;
});

const uint32_t vertex_count = _.GetOperandTypeId(inst, 0);
if (!_.IsUnsignedIntScalarType(vertex_count) ||
_.GetBitWidth(vertex_count) != 32) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Vertex Count must be a 32-bit unsigned int scalar";
}

const uint32_t primitive_count = _.GetOperandTypeId(inst, 1);
if (!_.IsUnsignedIntScalarType(primitive_count) ||
_.GetBitWidth(primitive_count) != 32) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Primitive Count must be a 32-bit unsigned int scalar";
}

break;
}

case SpvOpWritePackedPrimitiveIndices4x8NV: {
// No validation rules (for the moment).
break;
}

default:
break;
}

return SPV_SUCCESS;
}

} // namespace val
} // namespace spvtools
47 changes: 47 additions & 0 deletions source/val/validate_mode_setting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,39 @@ spv_result_t ValidateEntryPoint(ValidationState_t& _, const Instruction* inst) {
"OutputTriangleStrip execution modes.";
}
break;
case SpvExecutionModelMeshEXT:
if (!execution_modes ||
1 != std::count_if(execution_modes->begin(), execution_modes->end(),
[](const SpvExecutionMode& mode) {
switch (mode) {
case SpvExecutionModeOutputPoints:
case SpvExecutionModeOutputLinesEXT:
case SpvExecutionModeOutputTrianglesEXT:
return true;
default:
return false;
}
})) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "MeshEXT execution model entry points must specify exactly "
"one of OutputPoints, OutputLinesEXT, or "
"OutputTrianglesEXT Execution Modes.";
} else if (2 != std::count_if(
execution_modes->begin(), execution_modes->end(),
[](const SpvExecutionMode& mode) {
switch (mode) {
case SpvExecutionModeOutputPrimitivesEXT:
case SpvExecutionModeOutputVertices:
return true;
default:
return false;
}
})) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "MeshEXT execution model entry points must specify both "
"OutputPrimitivesEXT and OutputVertices Execution Modes.";
}
break;
default:
break;
}
Expand Down Expand Up @@ -443,6 +476,20 @@ spv_result_t ValidateExecutionMode(ValidationState_t& _,
}
}
break;
case SpvExecutionModeOutputLinesEXT:
case SpvExecutionModeOutputTrianglesEXT:
case SpvExecutionModeOutputPrimitivesEXT:
if (!std::all_of(models->begin(), models->end(),
[](const SpvExecutionModel& model) {
return (model == SpvExecutionModelMeshEXT ||
model == SpvExecutionModelMeshNV);
})) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Execution mode can only be used with the MeshEXT or MeshNV "
"execution "
"model.";
}
break;
case SpvExecutionModePixelCenterInteger:
case SpvExecutionModeOriginUpperLeft:
case SpvExecutionModeOriginLowerLeft:
Expand Down
11 changes: 7 additions & 4 deletions source/val/validate_ray_tracing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,10 @@ spv_result_t RayTracingPass(ValidationState_t& _, const Instruction* inst) {
if (payload->opcode() != SpvOpVariable) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Payload must be the result of a OpVariable";
} else if (payload->word(3) != SpvStorageClassRayPayloadKHR &&
payload->word(3) != SpvStorageClassIncomingRayPayloadKHR) {
} else if (payload->GetOperandAs<uint32_t>(2) !=
SpvStorageClassRayPayloadKHR &&
payload->GetOperandAs<uint32_t>(2) !=
SpvStorageClassIncomingRayPayloadKHR) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Payload must have storage class RayPayloadKHR or "
"IncomingRayPayloadKHR";
Expand Down Expand Up @@ -185,8 +187,9 @@ spv_result_t RayTracingPass(ValidationState_t& _, const Instruction* inst) {
if (callable_data->opcode() != SpvOpVariable) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Callable Data must be the result of a OpVariable";
} else if (callable_data->word(3) != SpvStorageClassCallableDataKHR &&
callable_data->word(3) !=
} else if (callable_data->GetOperandAs<uint32_t>(2) !=
SpvStorageClassCallableDataKHR &&
callable_data->GetOperandAs<uint32_t>(2) !=
SpvStorageClassIncomingCallableDataKHR) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Callable Data must have storage class CallableDataKHR or "
Expand Down
17 changes: 17 additions & 0 deletions source/val/validation_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,21 @@ void ValidationState_t::RegisterStorageClassConsumer(
}
return true;
});
} else if (storage_class == SpvStorageClassTaskPayloadWorkgroupEXT) {
function(consumer->function()->id())
->RegisterExecutionModelLimitation(
[](SpvExecutionModel model, std::string* message) {
if (model != SpvExecutionModelTaskEXT &&
model != SpvExecutionModelMeshEXT) {
if (message) {
*message =
"TaskPayloadWorkgroupEXT Storage Class is limited to "
"TaskEXT and MeshKHR execution model";
}
return false;
}
return true;
});
}
}

Expand Down Expand Up @@ -2110,6 +2125,8 @@ std::string ValidationState_t::VkErrorID(uint32_t id,
return VUID_WRAP(VUID-StandaloneSpirv-Uniform-06925);
case 6997:
return VUID_WRAP(VUID-StandaloneSpirv-SubgroupVoteKHR-06997);
case 7102:
return VUID_WRAP(VUID-StandaloneSpirv-MeshEXT-07102);
case 7320:
return VUID_WRAP(VUID-StandaloneSpirv-ExecutionModel-07320);
case 7290:
Expand Down
4 changes: 2 additions & 2 deletions test/val/val_id_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2091,9 +2091,9 @@ TEST_F(ValidateIdWithMessage, OpVariableGood) {
TEST_F(ValidateIdWithMessage, OpVariableInitializerConstantGood) {
std::string spirv = kGLSL450MemoryModel + R"(
%1 = OpTypeInt 32 0
%2 = OpTypePointer Input %1
%2 = OpTypePointer Output %1
%3 = OpConstant %1 42
%4 = OpVariable %2 Input %3)";
%4 = OpVariable %2 Output %3)";
CompileSuccessfully(spirv.c_str());
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
}
Expand Down
Loading

0 comments on commit b53d7a8

Please sign in to comment.