diff --git a/source/fuzz/fuzzer_pass_donate_modules.cpp b/source/fuzz/fuzzer_pass_donate_modules.cpp index 33813d2f25..0587a5069b 100644 --- a/source/fuzz/fuzzer_pass_donate_modules.cpp +++ b/source/fuzz/fuzzer_pass_donate_modules.cpp @@ -299,43 +299,41 @@ void FuzzerPassDonateModules::HandleTypesAndValues( } break; case SpvOpTypeFunction: { // It is not OK to have multiple function types that use identical ids - // for their return and parameter types. We thus first look for a - // matching function type in the recipient module and use the id of this - // type if a match is found. Otherwise we add a remapped version of the - // function type. - - // Build a sequence of types used as parameters for the function type. - std::vector parameter_types; - // We start iterating at 1 because 0 is the function's return type. - for (uint32_t index = 1; index < type_or_value.NumInOperands(); - index++) { - parameter_types.push_back(GetIRContext()->get_type_mgr()->GetType( - original_id_to_donated_id->at( - type_or_value.GetSingleWordInOperand(index)))); + // for their return and parameter types. We thus go through all + // existing function types to look for a match. We do not use the + // type manager here because we want to regard two function types that + // are structurally identical but that differ with respect to the + // actual ids used for pointer types as different. + // + // Example: + // + // %1 = OpTypeVoid + // %2 = OpTypeInt 32 0 + // %3 = OpTypePointer Function %2 + // %4 = OpTypePointer Function %2 + // %5 = OpTypeFunction %1 %3 + // %6 = OpTypeFunction %1 %4 + // + // We regard %5 and %6 as distinct function types here, even though + // they both have the form "uint32* -> void" + + std::vector return_and_parameter_types; + for (uint32_t i = 0; i < type_or_value.NumInOperands(); i++) { + return_and_parameter_types.push_back(original_id_to_donated_id->at( + type_or_value.GetSingleWordInOperand(i))); } - // Make a type object corresponding to the function type. - opt::analysis::Function function_type( - GetIRContext()->get_type_mgr()->GetType( - original_id_to_donated_id->at( - type_or_value.GetSingleWordInOperand(0))), - parameter_types); - - // Check whether a function type corresponding to this this type object - // is already declared by the module. - auto function_type_id = - GetIRContext()->get_type_mgr()->GetId(&function_type); - if (function_type_id) { - // A suitable existing function was found - use its id. - new_result_id = function_type_id; + uint32_t existing_function_id = fuzzerutil::FindFunctionType( + GetIRContext(), return_and_parameter_types); + if (existing_function_id) { + new_result_id = existing_function_id; } else { // No match was found, so add a remapped version of the function type // to the module, with a fresh id. new_result_id = GetFuzzerContext()->GetFreshId(); std::vector argument_type_ids; - for (uint32_t index = 1; index < type_or_value.NumInOperands(); - index++) { + for (uint32_t i = 1; i < type_or_value.NumInOperands(); i++) { argument_type_ids.push_back(original_id_to_donated_id->at( - type_or_value.GetSingleWordInOperand(index))); + type_or_value.GetSingleWordInOperand(i))); } ApplyTransformation(TransformationAddTypeFunction( new_result_id, diff --git a/source/fuzz/fuzzer_util.cpp b/source/fuzz/fuzzer_util.cpp index 82d761cc0d..085246e7c9 100644 --- a/source/fuzz/fuzzer_util.cpp +++ b/source/fuzz/fuzzer_util.cpp @@ -258,33 +258,36 @@ uint32_t WalkCompositeTypeIndices( auto should_be_composite_type = context->get_def_use_mgr()->GetDef(sub_object_type_id); assert(should_be_composite_type && "The type should exist."); - if (SpvOpTypeArray == should_be_composite_type->opcode()) { - auto array_length = GetArraySize(*should_be_composite_type, context); - if (array_length == 0 || index >= array_length) { - return 0; + switch (should_be_composite_type->opcode()) { + case SpvOpTypeArray: { + auto array_length = GetArraySize(*should_be_composite_type, context); + if (array_length == 0 || index >= array_length) { + return 0; + } + sub_object_type_id = + should_be_composite_type->GetSingleWordInOperand(0); + break; } - sub_object_type_id = should_be_composite_type->GetSingleWordInOperand(0); - } else if (SpvOpTypeMatrix == should_be_composite_type->opcode()) { - auto matrix_column_count = - should_be_composite_type->GetSingleWordInOperand(1); - if (index >= matrix_column_count) { - return 0; + case SpvOpTypeMatrix: + case SpvOpTypeVector: { + auto count = should_be_composite_type->GetSingleWordInOperand(1); + if (index >= count) { + return 0; + } + sub_object_type_id = + should_be_composite_type->GetSingleWordInOperand(0); + break; } - sub_object_type_id = should_be_composite_type->GetSingleWordInOperand(0); - } else if (SpvOpTypeStruct == should_be_composite_type->opcode()) { - if (index >= GetNumberOfStructMembers(*should_be_composite_type)) { - return 0; + case SpvOpTypeStruct: { + if (index >= GetNumberOfStructMembers(*should_be_composite_type)) { + return 0; + } + sub_object_type_id = + should_be_composite_type->GetSingleWordInOperand(index); + break; } - sub_object_type_id = - should_be_composite_type->GetSingleWordInOperand(index); - } else if (SpvOpTypeVector == should_be_composite_type->opcode()) { - auto vector_length = should_be_composite_type->GetSingleWordInOperand(1); - if (index >= vector_length) { + default: return 0; - } - sub_object_type_id = should_be_composite_type->GetSingleWordInOperand(0); - } else { - return 0; } } return sub_object_type_id; @@ -347,6 +350,35 @@ bool IsMergeOrContinue(opt::IRContext* ir_context, uint32_t block_id) { return result; } +uint32_t FindFunctionType(opt::IRContext* ir_context, + const std::vector& type_ids) { + // Look through the existing types for a match. + for (auto& type_or_value : ir_context->types_values()) { + if (type_or_value.opcode() != SpvOpTypeFunction) { + // We are only interested in function types. + continue; + } + if (type_or_value.NumInOperands() != type_ids.size()) { + // Not a match: different numbers of arguments. + continue; + } + // Check whether the return type and argument types match. + bool input_operands_match = true; + for (uint32_t i = 0; i < type_or_value.NumInOperands(); i++) { + if (type_ids[i] != type_or_value.GetSingleWordInOperand(i)) { + input_operands_match = false; + break; + } + } + if (input_operands_match) { + // Everything matches. + return type_or_value.result_id(); + } + } + // No match was found. + return 0; +} + } // namespace fuzzerutil } // namespace fuzz diff --git a/source/fuzz/fuzzer_util.h b/source/fuzz/fuzzer_util.h index cbd81cd987..f0a2953fd2 100644 --- a/source/fuzz/fuzzer_util.h +++ b/source/fuzz/fuzzer_util.h @@ -131,6 +131,12 @@ bool IsNonFunctionTypeId(opt::IRContext* ir_context, uint32_t id); // Returns true if and only if |block_id| is a merge block or continue target bool IsMergeOrContinue(opt::IRContext* ir_context, uint32_t block_id); +// Returns the result id of an instruction of the form: +// %id = OpTypeFunction |type_ids| +// or 0 if no such instruction exists. +uint32_t FindFunctionType(opt::IRContext* ir_context, + const std::vector& type_ids); + } // namespace fuzzerutil } // namespace fuzz diff --git a/source/fuzz/transformation.cpp b/source/fuzz/transformation.cpp index c7aae58777..8037af15e4 100644 --- a/source/fuzz/transformation.cpp +++ b/source/fuzz/transformation.cpp @@ -16,6 +16,7 @@ #include +#include "source/fuzz/fuzzer_util.h" #include "source/fuzz/transformation_add_constant_boolean.h" #include "source/fuzz/transformation_add_constant_composite.h" #include "source/fuzz/transformation_add_constant_scalar.h" @@ -159,5 +160,18 @@ std::unique_ptr Transformation::FromMessage( return nullptr; } +bool Transformation::CheckIdIsFreshAndNotUsedByThisTransformation( + uint32_t id, opt::IRContext* context, + std::set* ids_used_by_this_transformation) { + if (!fuzzerutil::IsFreshId(context, id)) { + return false; + } + if (ids_used_by_this_transformation->count(id) != 0) { + return false; + } + ids_used_by_this_transformation->insert(id); + return true; +} + } // namespace fuzz } // namespace spvtools diff --git a/source/fuzz/transformation.h b/source/fuzz/transformation.h index c6b852fd78..dbe803f35c 100644 --- a/source/fuzz/transformation.h +++ b/source/fuzz/transformation.h @@ -83,6 +83,15 @@ class Transformation { // representation of a transformation given by |message|. static std::unique_ptr FromMessage( const protobufs::Transformation& message); + + // Helper that returns true if and only if (a) |id| is a fresh id for the + // module, and (b) |id| is not in |ids_used_by_this_transformation|, a set of + // ids already known to be in use by a transformation. This is useful when + // checking id freshness for a transformation that uses many ids, all of which + // must be distinct. + static bool CheckIdIsFreshAndNotUsedByThisTransformation( + uint32_t id, opt::IRContext* context, + std::set* ids_used_by_this_transformation); }; } // namespace fuzz diff --git a/source/fuzz/transformation_outline_function.cpp b/source/fuzz/transformation_outline_function.cpp index b50b9c5fa8..1b308c4daf 100644 --- a/source/fuzz/transformation_outline_function.cpp +++ b/source/fuzz/transformation_outline_function.cpp @@ -368,20 +368,6 @@ protobufs::Transformation TransformationOutlineFunction::ToMessage() const { return result; } -bool TransformationOutlineFunction:: - CheckIdIsFreshAndNotUsedByThisTransformation( - uint32_t id, opt::IRContext* context, - std::set* ids_used_by_this_transformation) const { - if (!fuzzerutil::IsFreshId(context, id)) { - return false; - } - if (ids_used_by_this_transformation->count(id) != 0) { - return false; - } - ids_used_by_this_transformation->insert(id); - return true; -} - std::vector TransformationOutlineFunction::GetRegionInputIds( opt::IRContext* context, const std::set& region_set, opt::BasicBlock* region_exit_block) { @@ -540,15 +526,16 @@ TransformationOutlineFunction::PrepareFunctionPrototype( // not exist there cannot already be a function type with this struct as its // return type. if (region_output_ids.empty()) { + std::vector return_and_parameter_types; opt::analysis::Void void_type; return_type_id = context->get_type_mgr()->GetId(&void_type); - std::vector argument_types; + return_and_parameter_types.push_back(return_type_id); for (auto id : region_input_ids) { - argument_types.push_back(context->get_type_mgr()->GetType( - context->get_def_use_mgr()->GetDef(id)->type_id())); + return_and_parameter_types.push_back( + context->get_def_use_mgr()->GetDef(id)->type_id()); } - opt::analysis::Function function_type(&void_type, argument_types); - function_type_id = context->get_type_mgr()->GetId(&function_type); + function_type_id = + fuzzerutil::FindFunctionType(context, return_and_parameter_types); } // If no existing function type was found, we need to create one. diff --git a/source/fuzz/transformation_outline_function.h b/source/fuzz/transformation_outline_function.h index b59e66092b..43bdf3b85c 100644 --- a/source/fuzz/transformation_outline_function.h +++ b/source/fuzz/transformation_outline_function.h @@ -128,15 +128,6 @@ class TransformationOutlineFunction : public Transformation { opt::BasicBlock* region_exit_block); private: - // A helper method for the applicability check. Returns true if and only if - // |id| is (a) a fresh id for the module, and (b) an id that has not - // previously been subject to this check. We use this to check whether the - // ids given for the transformation are not only fresh but also different from - // one another. - bool CheckIdIsFreshAndNotUsedByThisTransformation( - uint32_t id, opt::IRContext* context, - std::set* ids_used_by_this_transformation) const; - // Ensures that the module's id bound is at least the maximum of any fresh id // associated with the transformation. void UpdateModuleIdBoundForFreshIds( diff --git a/test/fuzz/fuzzer_pass_donate_modules_test.cpp b/test/fuzz/fuzzer_pass_donate_modules_test.cpp index 988b675ba2..7342dd426a 100644 --- a/test/fuzz/fuzzer_pass_donate_modules_test.cpp +++ b/test/fuzz/fuzzer_pass_donate_modules_test.cpp @@ -438,6 +438,58 @@ TEST(FuzzerPassDonateModulesTest, DonationWithInputAndOutputVariables) { ASSERT_TRUE(IsEqual(env, after_transformation, recipient_context.get())); } +TEST(FuzzerPassDonateModulesTest, DonateFunctionTypeWithDifferentPointers) { + std::string recipient_and_donor_shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 0 + %7 = OpTypePointer Function %6 + %8 = OpTypeFunction %2 %7 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %9 = OpVariable %7 Function + %10 = OpFunctionCall %2 %11 %9 + OpReturn + OpFunctionEnd + %11 = OpFunction %2 None %8 + %12 = OpFunctionParameter %7 + %13 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_5; + const auto consumer = nullptr; + const auto recipient_context = BuildModule( + env, consumer, recipient_and_donor_shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, recipient_context.get())); + + const auto donor_context = BuildModule( + env, consumer, recipient_and_donor_shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, donor_context.get())); + + FactManager fact_manager; + + FuzzerContext fuzzer_context(MakeUnique(0).get(), 100); + protobufs::TransformationSequence transformation_sequence; + + FuzzerPassDonateModules fuzzer_pass(recipient_context.get(), &fact_manager, + &fuzzer_context, &transformation_sequence, + {}); + + fuzzer_pass.DonateSingleModule(donor_context.get()); + + // We just check that the result is valid. Checking to what it should be + // exactly equal to would be very fragile. + ASSERT_TRUE(IsValid(env, recipient_context.get())); +} + TEST(FuzzerPassDonateModulesTest, Miscellaneous1) { std::string recipient_shader = R"( OpCapability Shader diff --git a/test/fuzz/transformation_outline_function_test.cpp b/test/fuzz/transformation_outline_function_test.cpp index 4f828b6da4..5cd1437675 100644 --- a/test/fuzz/transformation_outline_function_test.cpp +++ b/test/fuzz/transformation_outline_function_test.cpp @@ -2040,6 +2040,91 @@ TEST(TransformationOutlineFunctionTest, Miscellaneous3) { ASSERT_TRUE(IsEqual(env, after_transformation, context.get())); } +TEST(TransformationOutlineFunctionTest, Miscellaneous4) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %6 "main" + OpExecutionMode %6 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %21 = OpTypeBool + %100 = OpTypeInt 32 0 + %101 = OpTypePointer Function %100 + %102 = OpTypePointer Function %100 + %103 = OpTypeFunction %2 %101 + %6 = OpFunction %2 None %3 + %7 = OpLabel + %104 = OpVariable %102 Function + OpBranch %80 + %80 = OpLabel + %105 = OpLoad %100 %104 + OpBranch %106 + %106 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_5; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + TransformationOutlineFunction transformation( + /*entry_block*/ 80, + /*exit_block*/ 106, + /*new_function_struct_return_type_id*/ 300, + /*new_function_type_id*/ 301, + /*new_function_id*/ 302, + /*new_function_region_entry_block*/ 304, + /*new_caller_result_id*/ 305, + /*new_callee_result_id*/ 306, + /*input_id_to_fresh_id*/ {{104, 307}}, + /*output_id_to_fresh_id*/ {}); + + ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); + transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context.get())); + + std::string after_transformation = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %6 "main" + OpExecutionMode %6 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %21 = OpTypeBool + %100 = OpTypeInt 32 0 + %101 = OpTypePointer Function %100 + %102 = OpTypePointer Function %100 + %103 = OpTypeFunction %2 %101 + %301 = OpTypeFunction %2 %102 + %6 = OpFunction %2 None %3 + %7 = OpLabel + %104 = OpVariable %102 Function + OpBranch %80 + %80 = OpLabel + %305 = OpFunctionCall %2 %302 %104 + OpReturn + OpFunctionEnd + %302 = OpFunction %2 None %301 + %307 = OpFunctionParameter %102 + %304 = OpLabel + %105 = OpLoad %100 %307 + OpBranch %106 + %106 = OpLabel + OpReturn + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, after_transformation, context.get())); +} + } // namespace } // namespace fuzz } // namespace spvtools