diff --git a/source/fuzz/fact_manager.cpp b/source/fuzz/fact_manager.cpp index 5c0814aab4..9672653a10 100644 --- a/source/fuzz/fact_manager.cpp +++ b/source/fuzz/fact_manager.cpp @@ -801,7 +801,7 @@ bool FactManager::DataSynonymFacts::IsSynonymous( //============================== //============================== -// Dead id facts +// Dead block facts // The purpose of this class is to group the fields and data used to represent // facts about data blocks. @@ -829,10 +829,41 @@ bool FactManager::DeadBlockFacts::BlockIsDead(uint32_t block_id) const { // End of dead block facts //============================== +//============================== +// Livesafe function facts + +// The purpose of this class is to group the fields and data used to represent +// facts about livesafe functions. +class FactManager::LivesafeFunctionFacts { + public: + // See method in FactManager which delegates to this method. + void AddFact(const protobufs::FactFunctionIsLivesafe& fact); + + // See method in FactManager which delegates to this method. + bool FunctionIsLivesafe(uint32_t function_id) const; + + private: + std::set livesafe_function_ids_; +}; + +void FactManager::LivesafeFunctionFacts::AddFact( + const protobufs::FactFunctionIsLivesafe& fact) { + livesafe_function_ids_.insert(fact.function_id()); +} + +bool FactManager::LivesafeFunctionFacts::FunctionIsLivesafe( + uint32_t function_id) const { + return livesafe_function_ids_.count(function_id) != 0; +} + +// End of livesafe function facts +//============================== + FactManager::FactManager() : uniform_constant_facts_(MakeUnique()), data_synonym_facts_(MakeUnique()), - dead_block_facts_(MakeUnique()) {} + dead_block_facts_(MakeUnique()), + livesafe_function_facts_(MakeUnique()) {} FactManager::~FactManager() = default; @@ -860,6 +891,9 @@ bool FactManager::AddFact(const fuzz::protobufs::Fact& fact, case protobufs::Fact::kBlockIsDeadFact: dead_block_facts_->AddFact(fact.block_is_dead_fact()); return true; + case protobufs::Fact::kFunctionIsLivesafeFact: + livesafe_function_facts_->AddFact(fact.function_is_livesafe_fact()); + return true; default: assert(false && "Unknown fact type."); return false; @@ -941,5 +975,15 @@ void FactManager::AddFactBlockIsDead(uint32_t block_id) { dead_block_facts_->AddFact(fact); } +bool FactManager::FunctionIsLivesafe(uint32_t function_id) const { + return livesafe_function_facts_->FunctionIsLivesafe(function_id); +} + +void FactManager::AddFactFunctionIsLivesafe(uint32_t function_id) { + protobufs::FactFunctionIsLivesafe fact; + fact.set_function_id(function_id); + livesafe_function_facts_->AddFact(fact); +} + } // namespace fuzz } // namespace spvtools diff --git a/source/fuzz/fact_manager.h b/source/fuzz/fact_manager.h index f035fcc052..20f270154f 100644 --- a/source/fuzz/fact_manager.h +++ b/source/fuzz/fact_manager.h @@ -61,6 +61,9 @@ class FactManager { // Records the fact that |block_id| is dead. void AddFactBlockIsDead(uint32_t block_id); + // Records the fact that |function_id| is livesafe. + void AddFactFunctionIsLivesafe(uint32_t function_id); + // The fact manager is responsible for managing a few distinct categories of // facts. In principle there could be different fact managers for each kind // of fact, but in practice providing one 'go to' place for facts is @@ -143,6 +146,16 @@ class FactManager { // End of dead block facts //============================== + //============================== + // Querying facts about livesafe function + + // Returns true if and ony if |function_id| is the id of a function known + // to be livesafe. + bool FunctionIsLivesafe(uint32_t function_id) const; + + // End of dead block facts + //============================== + private: // For each distinct kind of fact to be managed, we use a separate opaque // class type. @@ -159,6 +172,11 @@ class FactManager { class DeadBlockFacts; // Opaque class for management of dead block facts. std::unique_ptr dead_block_facts_; // Unique pointer to internal data. + + class LivesafeFunctionFacts; // Opaque class for management of livesafe + // function facts. + std::unique_ptr + livesafe_function_facts_; // Unique pointer to internal data. }; } // namespace fuzz diff --git a/source/fuzz/fuzzer_context.cpp b/source/fuzz/fuzzer_context.cpp index 559aecb356..afffcf54c8 100644 --- a/source/fuzz/fuzzer_context.cpp +++ b/source/fuzz/fuzzer_context.cpp @@ -38,6 +38,7 @@ const std::pair kChanceOfAdjustingSelectionControl = {20, const std::pair kChanceOfConstructingComposite = {20, 50}; const std::pair kChanceOfCopyingObject = {20, 50}; const std::pair kChanceOfDonatingAdditionalModule = {5, 50}; +const std::pair kChanceOfMakingDonorLivesafe = {40, 60}; const std::pair kChanceOfMergingBlocks = {20, 95}; const std::pair kChanceOfMovingBlockDown = {20, 50}; const std::pair kChanceOfObfuscatingConstant = {10, 90}; @@ -49,6 +50,7 @@ const std::pair kChanceOfSplittingBlock = {40, 95}; // Keep them in alphabetical order. const uint32_t kDefaultMaxLoopControlPartialCount = 100; const uint32_t kDefaultMaxLoopControlPeelCount = 100; +const uint32_t kDefaultMaxLoopLimit = 20; // Default functions for controlling how deep to go during recursive // generation/transformation. Keep them in alphabetical order. @@ -89,6 +91,8 @@ FuzzerContext::FuzzerContext(RandomGenerator* random_generator, chance_of_copying_object_ = ChooseBetweenMinAndMax(kChanceOfCopyingObject); chance_of_donating_additional_module_ = ChooseBetweenMinAndMax(kChanceOfDonatingAdditionalModule); + chance_of_making_donor_livesafe_ = + ChooseBetweenMinAndMax(kChanceOfMakingDonorLivesafe); chance_of_merging_blocks_ = ChooseBetweenMinAndMax(kChanceOfMergingBlocks); chance_of_moving_block_down_ = ChooseBetweenMinAndMax(kChanceOfMovingBlockDown); @@ -101,6 +105,7 @@ FuzzerContext::FuzzerContext(RandomGenerator* random_generator, chance_of_splitting_block_ = ChooseBetweenMinAndMax(kChanceOfSplittingBlock); max_loop_control_partial_count_ = kDefaultMaxLoopControlPartialCount; max_loop_control_peel_count_ = kDefaultMaxLoopControlPeelCount; + max_loop_limit_ = kDefaultMaxLoopLimit; } FuzzerContext::~FuzzerContext() = default; diff --git a/source/fuzz/fuzzer_context.h b/source/fuzz/fuzzer_context.h index 7cfe15b731..cc4337093d 100644 --- a/source/fuzz/fuzzer_context.h +++ b/source/fuzz/fuzzer_context.h @@ -85,6 +85,9 @@ class FuzzerContext { uint32_t GetChanceOfDonatingAdditionalModule() { return chance_of_donating_additional_module_; } + uint32_t ChanceOfMakingDonorLivesafe() { + return chance_of_making_donor_livesafe_; + } uint32_t GetChanceOfMergingBlocks() { return chance_of_merging_blocks_; } uint32_t GetChanceOfMovingBlockDown() { return chance_of_moving_block_down_; } uint32_t GetChanceOfObfuscatingConstant() { @@ -103,6 +106,9 @@ class FuzzerContext { uint32_t GetRandomLoopControlPartialCount() { return random_generator_->RandomUint32(max_loop_control_partial_count_); } + uint32_t GetRandomLoopLimit() { + return random_generator_->RandomUint32(max_loop_limit_); + } // Functions to control how deeply to recurse. // Keep them in alphabetical order. @@ -129,6 +135,7 @@ class FuzzerContext { uint32_t chance_of_constructing_composite_; uint32_t chance_of_copying_object_; uint32_t chance_of_donating_additional_module_; + uint32_t chance_of_making_donor_livesafe_; uint32_t chance_of_merging_blocks_; uint32_t chance_of_moving_block_down_; uint32_t chance_of_obfuscating_constant_; @@ -141,6 +148,7 @@ class FuzzerContext { // Keep them in alphabetical order. uint32_t max_loop_control_partial_count_; uint32_t max_loop_control_peel_count_; + uint32_t max_loop_limit_; // Functions to determine with what probability to go deeper when generating // or mutating constructs recursively. diff --git a/source/fuzz/fuzzer_pass.cpp b/source/fuzz/fuzzer_pass.cpp index 1da53f4262..9d891a59d7 100644 --- a/source/fuzz/fuzzer_pass.cpp +++ b/source/fuzz/fuzzer_pass.cpp @@ -15,6 +15,11 @@ #include "source/fuzz/fuzzer_pass.h" #include "source/fuzz/instruction_descriptor.h" +#include "source/fuzz/transformation_add_constant_scalar.h" +#include "source/fuzz/transformation_add_global_undef.h" +#include "source/fuzz/transformation_add_type_boolean.h" +#include "source/fuzz/transformation_add_type_int.h" +#include "source/fuzz/transformation_add_type_pointer.h" namespace spvtools { namespace fuzz { @@ -128,5 +133,73 @@ void FuzzerPass::MaybeAddTransformationBeforeEachInstruction( } } +uint32_t FuzzerPass::FindOrCreateBoolType() { + opt::analysis::Bool bool_type; + auto existing_id = GetIRContext()->get_type_mgr()->GetId(&bool_type); + if (existing_id) { + return existing_id; + } + auto result = GetFuzzerContext()->GetFreshId(); + ApplyTransformation(TransformationAddTypeBoolean(result)); + return result; +} + +uint32_t FuzzerPass::FindOrCreate32BitIntegerType(bool is_signed) { + opt::analysis::Integer int_type(32, is_signed); + auto existing_id = GetIRContext()->get_type_mgr()->GetId(&int_type); + if (existing_id) { + return existing_id; + } + auto result = GetFuzzerContext()->GetFreshId(); + ApplyTransformation(TransformationAddTypeInt(result, 32, is_signed)); + return result; +} + +uint32_t FuzzerPass::FindOrCreatePointerTo32BitIntegerType( + bool is_signed, SpvStorageClass storage_class) { + auto uint32_type_id = FindOrCreate32BitIntegerType(is_signed); + opt::analysis::Pointer pointer_type( + GetIRContext()->get_type_mgr()->GetType(uint32_type_id), storage_class); + auto existing_id = GetIRContext()->get_type_mgr()->GetId(&pointer_type); + if (existing_id) { + return existing_id; + } + auto result = GetFuzzerContext()->GetFreshId(); + ApplyTransformation( + TransformationAddTypePointer(result, storage_class, uint32_type_id)); + return result; +} + +uint32_t FuzzerPass::FindOrCreate32BitIntegerConstant(uint32_t word, + bool is_signed) { + auto uint32_type_id = FindOrCreate32BitIntegerType(is_signed); + opt::analysis::IntConstant int_constant( + GetIRContext()->get_type_mgr()->GetType(uint32_type_id)->AsInteger(), + {word}); + auto existing_constant = + GetIRContext()->get_constant_mgr()->FindConstant(&int_constant); + if (existing_constant) { + return GetIRContext() + ->get_constant_mgr() + ->GetDefiningInstruction(existing_constant) + ->result_id(); + } + auto result = GetFuzzerContext()->GetFreshId(); + ApplyTransformation( + TransformationAddConstantScalar(result, uint32_type_id, {word})); + return result; +} + +uint32_t FuzzerPass::FindOrCreateGlobalUndef(uint32_t type_id) { + for (auto& inst : GetIRContext()->types_values()) { + if (inst.opcode() == SpvOpUndef && inst.type_id() == type_id) { + return inst.result_id(); + } + } + auto result = GetFuzzerContext()->GetFreshId(); + ApplyTransformation(TransformationAddGlobalUndef(result, type_id)); + return result; +} + } // namespace fuzz } // namespace spvtools diff --git a/source/fuzz/fuzzer_pass.h b/source/fuzz/fuzzer_pass.h index e3bf8e8d79..09f831f4b9 100644 --- a/source/fuzz/fuzzer_pass.h +++ b/source/fuzz/fuzzer_pass.h @@ -89,7 +89,7 @@ class FuzzerPass { const protobufs::InstructionDescriptor& instruction_descriptor)> maybe_apply_transformation); - // A generic helper for applying a transforamtion that should be appplicable + // A generic helper for applying a transformation that should be applicable // by construction, and adding it to the sequence of applied transformations. template void ApplyTransformation(const TransformationType& transformation) { @@ -99,6 +99,33 @@ class FuzzerPass { *GetTransformations()->add_transformation() = transformation.ToMessage(); } + // Returns the id of an OpTypeBool instruction. If such an instruction does + // not exist, a transformation is applied to add it. + uint32_t FindOrCreateBoolType(); + + // Returns the id of an OpTypeInt instruction, with width 32 and signedness + // specified by |is_signed|. If such an instruction does not exist, a + // transformation is applied to add it. + uint32_t FindOrCreate32BitIntegerType(bool is_signed); + + // Returns the id of an OpTypePointer instruction, with a 32-bit integer base + // type of signedness specified by |is_signed|. If the pointer type or + // required integer base type do not exist, transformations are applied to add + // them. + uint32_t FindOrCreatePointerTo32BitIntegerType(bool is_signed, + SpvStorageClass storage_class); + + // Returns the id of an OpConstant instruction, with 32-bit integer type of + // signedness specified by |is_signed|, with |word| as its value. If either + // the required integer type or the constant do not exist, transformations are + // applied to add them. + uint32_t FindOrCreate32BitIntegerConstant(uint32_t word, bool is_signed); + + // Returns the result id of an instruction of the form: + // %id = OpUndef %|type_id| + // If no such instruction exists, a transformation is applied to add it. + uint32_t FindOrCreateGlobalUndef(uint32_t type_id); + private: opt::IRContext* ir_context_; FactManager* fact_manager_; diff --git a/source/fuzz/fuzzer_pass_donate_modules.cpp b/source/fuzz/fuzzer_pass_donate_modules.cpp index 0587a5069b..b0b9d27ff8 100644 --- a/source/fuzz/fuzzer_pass_donate_modules.cpp +++ b/source/fuzz/fuzzer_pass_donate_modules.cpp @@ -62,13 +62,20 @@ void FuzzerPassDonateModules::Apply() { GetFuzzerContext()->RandomIndex(donor_suppliers_))(); assert(donor_ir_context != nullptr && "Supplying of donor failed"); // Donate the supplied module. - DonateSingleModule(donor_ir_context.get()); + // + // Randomly decide whether to make the module livesafe (see + // FactFunctionIsLivesafe); doing so allows it to be used for live code + // injection but restricts its behaviour to allow this, and means that its + // functions cannot be transformed as if they were arbitrary dead code. + bool make_livesafe = GetFuzzerContext()->ChoosePercentage( + GetFuzzerContext()->ChanceOfMakingDonorLivesafe()); + DonateSingleModule(donor_ir_context.get(), make_livesafe); } while (GetFuzzerContext()->ChoosePercentage( GetFuzzerContext()->GetChanceOfDonatingAdditionalModule())); } void FuzzerPassDonateModules::DonateSingleModule( - opt::IRContext* donor_ir_context) { + opt::IRContext* donor_ir_context, bool make_livesafe) { // The ids used by the donor module may very well clash with ids defined in // the recipient module. Furthermore, some instructions defined in the donor // module will be equivalent to instructions defined in the recipient module, @@ -91,7 +98,7 @@ void FuzzerPassDonateModules::DonateSingleModule( HandleExternalInstructionImports(donor_ir_context, &original_id_to_donated_id); HandleTypesAndValues(donor_ir_context, &original_id_to_donated_id); - HandleFunctions(donor_ir_context, &original_id_to_donated_id); + HandleFunctions(donor_ir_context, &original_id_to_donated_id, make_livesafe); // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3115) Handle some // kinds of decoration. @@ -420,7 +427,8 @@ void FuzzerPassDonateModules::HandleTypesAndValues( void FuzzerPassDonateModules::HandleFunctions( opt::IRContext* donor_ir_context, - std::map* original_id_to_donated_id) { + std::map* original_id_to_donated_id, + bool make_livesafe) { // Get the ids of functions in the donor module, topologically sorted // according to the donor's call graph. auto topological_order = @@ -506,7 +514,146 @@ void FuzzerPassDonateModules::HandleFunctions( : 0, input_operands)); }); - ApplyTransformation(TransformationAddFunction(donated_instructions)); + + if (make_livesafe) { + // Various types and constants must be in place for a function to be made + // live-safe. Add them if not already present. + FindOrCreateBoolType(); // Needed for comparisons + FindOrCreatePointerTo32BitIntegerType( + false, SpvStorageClassFunction); // Needed for adding loop limiters + FindOrCreate32BitIntegerConstant( + 0, false); // Needed for initializing loop limiters + FindOrCreate32BitIntegerConstant( + 1, false); // Needed for incrementing loop limiters + + // Get a fresh id for the variable that will be used as a loop limiter. + const uint32_t loop_limiter_variable_id = + GetFuzzerContext()->GetFreshId(); + // Choose a random loop limit, and add the required constant to the + // module if not already there. + const uint32_t loop_limit = FindOrCreate32BitIntegerConstant( + GetFuzzerContext()->GetRandomLoopLimit(), false); + + // Consider every loop header in the function to donate, and create a + // structure capturing the ids to be used for manipulating the loop + // limiter each time the loop is iterated. + std::vector loop_limiters; + for (auto& block : *function_to_donate) { + if (block.IsLoopHeader()) { + protobufs::LoopLimiterInfo loop_limiter; + // Grab the loop header's id, mapped to its donated value. + loop_limiter.set_loop_header_id( + original_id_to_donated_id->at(block.id())); + // Get fresh ids that will be used to load the loop limiter, increment + // it, compare it with the loop limit, and an id for a new block that + // will contain the loop's original terminator. + loop_limiter.set_load_id(GetFuzzerContext()->GetFreshId()); + loop_limiter.set_increment_id(GetFuzzerContext()->GetFreshId()); + loop_limiter.set_compare_id(GetFuzzerContext()->GetFreshId()); + loop_limiter.set_logical_op_id(GetFuzzerContext()->GetFreshId()); + loop_limiters.emplace_back(loop_limiter); + } + } + + // Consider every access chain in the function to donate, and create a + // structure containing the ids necessary to clamp the access chain + // indices to be in-bounds. + std::vector + access_chain_clamping_info; + for (auto& block : *function_to_donate) { + for (auto& inst : block) { + switch (inst.opcode()) { + case SpvOpAccessChain: + case SpvOpInBoundsAccessChain: { + protobufs::AccessChainClampingInfo clamping_info; + clamping_info.set_access_chain_id( + original_id_to_donated_id->at(inst.result_id())); + + auto base_object = donor_ir_context->get_def_use_mgr()->GetDef( + inst.GetSingleWordInOperand(0)); + assert(base_object && "The base object must exist."); + auto pointer_type = donor_ir_context->get_def_use_mgr()->GetDef( + base_object->type_id()); + assert(pointer_type && + pointer_type->opcode() == SpvOpTypePointer && + "The base object must have pointer type."); + + auto should_be_composite_type = + donor_ir_context->get_def_use_mgr()->GetDef( + pointer_type->GetSingleWordInOperand(1)); + + // Walk the access chain, creating fresh ids to facilitate + // clamping each index. For simplicity we do this for every + // index, even though constant indices will not end up being + // clamped. + for (uint32_t index = 1; index < inst.NumInOperands(); index++) { + auto compare_and_select_ids = + clamping_info.add_compare_and_select_ids(); + compare_and_select_ids->set_first( + GetFuzzerContext()->GetFreshId()); + compare_and_select_ids->set_second( + GetFuzzerContext()->GetFreshId()); + + // Get the bound for the component being indexed into. + uint32_t bound = + TransformationAddFunction::GetBoundForCompositeIndex( + donor_ir_context, *should_be_composite_type); + const uint32_t index_id = inst.GetSingleWordInOperand(index); + auto index_inst = + donor_ir_context->get_def_use_mgr()->GetDef(index_id); + auto index_type_inst = + donor_ir_context->get_def_use_mgr()->GetDef( + index_inst->type_id()); + assert(index_type_inst->opcode() == SpvOpTypeInt); + assert(index_type_inst->GetSingleWordInOperand(0) == 32); + opt::analysis::Integer* index_int_type = + donor_ir_context->get_type_mgr() + ->GetType(index_type_inst->result_id()) + ->AsInteger(); + if (index_inst->opcode() != SpvOpConstant) { + // We will have to clamp this index, so we need a constant + // whose value is one less than the bound, to compare + // against and to use as the clamped value. + FindOrCreate32BitIntegerConstant(bound - 1, + index_int_type->IsSigned()); + } + should_be_composite_type = + TransformationAddFunction::FollowCompositeIndex( + donor_ir_context, *should_be_composite_type, index_id); + } + access_chain_clamping_info.push_back(clamping_info); + break; + } + default: + break; + } + } + } + + // If the function contains OpKill or OpUnreachable instructions, and has + // non-void return type, then we need a value %v to use in order to turn + // these into instructions of the form OpReturn %v. + uint32_t kill_unreachable_return_value_id; + auto function_return_type_inst = + donor_ir_context->get_def_use_mgr()->GetDef( + function_to_donate->type_id()); + if (function_return_type_inst->opcode() == SpvOpTypeVoid) { + // The return type is void, so we don't need a return value. + kill_unreachable_return_value_id = 0; + } else { + // We do need a return value; we use OpUndef. + kill_unreachable_return_value_id = + FindOrCreateGlobalUndef(function_return_type_inst->type_id()); + } + // Add the function in a livesafe manner. + ApplyTransformation(TransformationAddFunction( + donated_instructions, loop_limiter_variable_id, loop_limit, + loop_limiters, kill_unreachable_return_value_id, + access_chain_clamping_info)); + } else { + // Add the function in a non-livesafe manner. + ApplyTransformation(TransformationAddFunction(donated_instructions)); + } } } diff --git a/source/fuzz/fuzzer_pass_donate_modules.h b/source/fuzz/fuzzer_pass_donate_modules.h index d719e878eb..ef529db707 100644 --- a/source/fuzz/fuzzer_pass_donate_modules.h +++ b/source/fuzz/fuzzer_pass_donate_modules.h @@ -38,8 +38,10 @@ class FuzzerPassDonateModules : public FuzzerPass { void Apply() override; // Donates the global declarations and functions of |donor_ir_context| into - // the fuzzer pass's IR context. - void DonateSingleModule(opt::IRContext* donor_ir_context); + // the fuzzer pass's IR context. |make_livesafe| dictates whether the + // functions of the donated module will be made livesafe (see + // FactFunctionIsLivesafe). + void DonateSingleModule(opt::IRContext* donor_ir_context, bool make_livesafe); private: // Adapts a storage class coming from a donor module so that it will work @@ -68,9 +70,12 @@ class FuzzerPassDonateModules : public FuzzerPass { // functions in |donor_ir_context|'s call graph in a reverse-topologically- // sorted order (leaves-to-root), adding each function to the recipient // module, rewritten to use fresh ids and using |original_id_to_donated_id| to - // remap ids. + // remap ids. The |make_livesafe| argument captures whether the functions in + // the module are required to be made livesafe before being added to the + // recipient. void HandleFunctions(opt::IRContext* donor_ir_context, - std::map* original_id_to_donated_id); + std::map* original_id_to_donated_id, + bool make_livesafe); // Returns the ids of all functions in |context| in a topological order in // relation to the call graph of |context|, which is assumed to be recursion- diff --git a/source/fuzz/fuzzer_util.cpp b/source/fuzz/fuzzer_util.cpp index 085246e7c9..b2ace38f17 100644 --- a/source/fuzz/fuzzer_util.cpp +++ b/source/fuzz/fuzzer_util.cpp @@ -103,10 +103,10 @@ bool PhiIdsOkForNewEdge( } phi_index++; } - // Return false if not all of the ids for extending OpPhi instructions are - // needed. This might turn out to be stricter than necessary; perhaps it would - // be OK just to not use the ids in this case. - return phi_index == static_cast(phi_ids.size()); + // We allow some of the ids provided for extending OpPhi instructions to be + // unused. Their presence does no harm, and requiring a perfect match may + // make transformations less likely to cleanly apply. + return true; } uint32_t MaybeGetBoolConstantId(opt::IRContext* context, bool value) { @@ -158,13 +158,11 @@ void AddUnreachableEdgeAndUpdateOpPhis( break; } assert(phi_index < static_cast(phi_ids.size()) && - "There should be exactly one phi id per OpPhi instruction."); + "There should be at least one phi id per OpPhi instruction."); inst.AddOperand({SPV_OPERAND_TYPE_ID, {phi_ids[phi_index]}}); inst.AddOperand({SPV_OPERAND_TYPE_ID, {bb_from->id()}}); phi_index++; } - assert(phi_index == static_cast(phi_ids.size()) && - "There should be exactly one phi id per OpPhi instruction."); } } diff --git a/source/fuzz/protobufs/spvtoolsfuzz.proto b/source/fuzz/protobufs/spvtoolsfuzz.proto index 08f39867aa..2f37a7d967 100644 --- a/source/fuzz/protobufs/spvtoolsfuzz.proto +++ b/source/fuzz/protobufs/spvtoolsfuzz.proto @@ -167,6 +167,7 @@ message Fact { FactConstantUniform constant_uniform_fact = 1; FactDataSynonym data_synonym_fact = 2; FactBlockIsDead block_is_dead_fact = 3; + FactFunctionIsLivesafe function_is_livesafe_fact = 4; } } @@ -210,6 +211,77 @@ message FactBlockIsDead { uint32 block_id = 1; } +message FactFunctionIsLivesafe { + + // Records the fact that a function is guaranteed to be "livesafe", meaning + // that it will not make out-of-bounds accesses, does not contain reachable + // OpKill or OpUnreachable instructions, does not contain loops that will + // execute for large numbers of iterations, and only invokes other livesafe + // functions. + + uint32 function_id = 1; +} + +message AccessChainClampingInfo { + + // When making a function livesafe it is necessary to clamp the indices that + // occur as operands to access chain instructions so that they are guaranteed + // to be in bounds. This message type allows an access chain instruction to + // have an associated sequence of ids that are reserved for comparing an + // access chain index with a bound (e.g. an array size), and selecting + // between the access chain index (if it is within bounds) and the bound (if + // it is not). + // + // This allows turning an instruction of the form: + // + // %result = OpAccessChain %type %object ... %index ... + // + // into: + // + // %t1 = OpULessThanEqual %bool %index %bound_minus_one + // %t2 = OpSelect %int_type %t1 %index %bound_minus_one + // %result = OpAccessChain %type %object ... %t2 ... + + // The result id of an OpAccessChain or OpInBoundsAccessChain instruction. + uint32 access_chain_id = 1; + + // A series of pairs of fresh ids, one per access chain index, for the results + // of a compare instruction and a select instruction, serving the roles of %t1 + // and %t2 in the above example. + repeated UInt32Pair compare_and_select_ids = 2; + +} + +message LoopLimiterInfo { + + // Structure capturing the information required to manipulate a loop limiter + // at a loop header. + + // The header for the loop. + uint32 loop_header_id = 1; + + // A fresh id into which the loop limiter's current value can be loaded. + uint32 load_id = 2; + + // A fresh id that can be used to increment the loaded value by 1. + uint32 increment_id = 3; + + // A fresh id that can be used to compare the loaded value with the loop + // limit. + uint32 compare_id = 4; + + // A fresh id that can be used to compute the conjunction or disjunction of + // an original loop exit condition with |compare_id|, if the loop's back edge + // block can conditionally exit the loop. + uint32 logical_op_id = 5; + + // A sequence of ids suitable for extending OpPhi instructions of the loop + // merge block if it did not previously have an incoming edge from the loop + // back edge block. + repeated uint32 phi_id = 6; + +} + message TransformationSequence { repeated Transformation transformation = 1; } @@ -366,6 +438,33 @@ message TransformationAddFunction { // The series of instructions that comprise the function. repeated Instruction instruction = 1; + // True if and only if the given function should be made livesafe (see + // FactFunctionIsLivesafe for definition). + bool is_livesafe = 2; + + // Fresh id for a new variable that will serve as a "loop limiter" for the + // function; only relevant if |is_livesafe| holds. + uint32 loop_limiter_variable_id = 3; + + // Id of an existing unsigned integer constant providing the maximum value + // that the loop limiter can reach before the loop is broken from; only + // relevant if |is_livesafe| holds. + uint32 loop_limit_constant_id = 4; + + // Fresh ids for each loop in the function that allow the loop limiter to be + // manipulated; only relevant if |is_livesafe| holds. + repeated LoopLimiterInfo loop_limiter_info = 5; + + // Id of an existing global value with the same return type as the function + // that can be used to replace OpKill and OpReachable instructions with + // ReturnValue instructions. Ignored if the function has void return type. + uint32 kill_unreachable_return_value_id = 6; + + // A mapping (represented as a sequence) from every access chain result id in + // the function to the ids required to clamp its indices to ensure they are in + // bounds. + repeated AccessChainClampingInfo access_chain_clamping_info = 7; + } message TransformationAddGlobalUndef { diff --git a/source/fuzz/transformation_add_function.cpp b/source/fuzz/transformation_add_function.cpp index 5e53961bf1..8b8b2ddd50 100644 --- a/source/fuzz/transformation_add_function.cpp +++ b/source/fuzz/transformation_add_function.cpp @@ -29,11 +29,95 @@ TransformationAddFunction::TransformationAddFunction( for (auto& instruction : instructions) { *message_.add_instruction() = instruction; } + message_.set_is_livesafe(false); +} + +TransformationAddFunction::TransformationAddFunction( + const std::vector& instructions, + uint32_t loop_limiter_variable_id, uint32_t loop_limit_constant_id, + const std::vector& loop_limiters, + uint32_t kill_unreachable_return_value_id, + const std::vector& + access_chain_clampers) { + for (auto& instruction : instructions) { + *message_.add_instruction() = instruction; + } + message_.set_is_livesafe(true); + message_.set_loop_limiter_variable_id(loop_limiter_variable_id); + message_.set_loop_limit_constant_id(loop_limit_constant_id); + for (auto& loop_limiter : loop_limiters) { + *message_.add_loop_limiter_info() = loop_limiter; + } + message_.set_kill_unreachable_return_value_id( + kill_unreachable_return_value_id); + for (auto& access_clamper : access_chain_clampers) { + *message_.add_access_chain_clamping_info() = access_clamper; + } } bool TransformationAddFunction::IsApplicable( opt::IRContext* context, - const spvtools::fuzz::FactManager& /*unused*/) const { + const spvtools::fuzz::FactManager& fact_manager) const { + // This transformation may use a lot of ids, all of which need to be fresh + // and distinct. This set tracks them. + std::set ids_used_by_this_transformation; + + // Ensure that all result ids in the new function are fresh and distinct. + for (auto& instruction : message_.instruction()) { + if (instruction.result_id()) { + if (!CheckIdIsFreshAndNotUsedByThisTransformation( + instruction.result_id(), context, + &ids_used_by_this_transformation)) { + return false; + } + } + } + + if (message_.is_livesafe()) { + // Ensure that all ids provided for making the function livesafe are fresh + // and distinct. + if (!CheckIdIsFreshAndNotUsedByThisTransformation( + message_.loop_limiter_variable_id(), context, + &ids_used_by_this_transformation)) { + return false; + } + for (auto& loop_limiter_info : message_.loop_limiter_info()) { + if (!CheckIdIsFreshAndNotUsedByThisTransformation( + loop_limiter_info.load_id(), context, + &ids_used_by_this_transformation)) { + return false; + } + if (!CheckIdIsFreshAndNotUsedByThisTransformation( + loop_limiter_info.increment_id(), context, + &ids_used_by_this_transformation)) { + return false; + } + if (!CheckIdIsFreshAndNotUsedByThisTransformation( + loop_limiter_info.compare_id(), context, + &ids_used_by_this_transformation)) { + return false; + } + if (!CheckIdIsFreshAndNotUsedByThisTransformation( + loop_limiter_info.logical_op_id(), context, + &ids_used_by_this_transformation)) { + return false; + } + } + for (auto& access_chain_clamping_info : + message_.access_chain_clamping_info()) { + for (auto& pair : access_chain_clamping_info.compare_and_select_ids()) { + if (!CheckIdIsFreshAndNotUsedByThisTransformation( + pair.first(), context, &ids_used_by_this_transformation)) { + return false; + } + if (!CheckIdIsFreshAndNotUsedByThisTransformation( + pair.second(), context, &ids_used_by_this_transformation)) { + return false; + } + } + } + } + // Because checking all the conditions for a function to be valid is a big // job that the SPIR-V validator can already do, a "try it and see" approach // is taken here. @@ -47,18 +131,49 @@ bool TransformationAddFunction::IsApplicable( if (!TryToAddFunction(cloned_module.get())) { return false; } - // Having managed to add the new function to the cloned module, we ascertain - // whether the cloned module is still valid. If it is, the transformation is - // applicable. + + if (message_.is_livesafe()) { + // We make the cloned module livesafe. + if (!TryToMakeFunctionLivesafe(cloned_module.get(), fact_manager)) { + return false; + } + } + + // Having managed to add the new function to the cloned module, and + // potentially also made it livesafe, we ascertain whether the cloned module + // is still valid. If it is, the transformation is applicable. return fuzzerutil::IsValid(cloned_module.get()); } void TransformationAddFunction::Apply( - opt::IRContext* context, spvtools::fuzz::FactManager* /*unused*/) const { - auto success = TryToAddFunction(context); + opt::IRContext* context, spvtools::fuzz::FactManager* fact_manager) const { + // Add the function to the module. As the transformation is applicable, this + // should succeed. + bool success = TryToAddFunction(context); assert(success && "The function should be successfully added."); (void)(success); // Keep release builds happy (otherwise they may complain // that |success| is not used). + + if (message_.is_livesafe()) { + // Make the function livesafe, which also should succeed. + success = TryToMakeFunctionLivesafe(context, *fact_manager); + assert(success && "It should be possible to make the function livesafe."); + (void)(success); // Keep release builds happy. + + // Inform the fact manager that the function is livesafe. + assert(message_.instruction(0).opcode() == SpvOpFunction && + "The first instruction of an 'add function' transformation must be " + "OpFunction."); + fact_manager->AddFactFunctionIsLivesafe( + message_.instruction(0).result_id()); + } else { + // Inform the fact manager that all blocks in the function are dead. + for (auto& inst : message_.instruction()) { + if (inst.opcode() == SpvOpLabel) { + fact_manager->AddFactBlockIsDead(inst.result_id()); + } + } + } context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone); } @@ -149,8 +264,634 @@ bool TransformationAddFunction::TryToAddFunction( new_function->SetFunctionEnd( InstructionFromMessage(context, message_.instruction(instruction_index))); context->AddFunction(std::move(new_function)); + + context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone); + return true; } +bool TransformationAddFunction::TryToMakeFunctionLivesafe( + opt::IRContext* context, const FactManager& fact_manager) const { + assert(message_.is_livesafe() && "Precondition: is_livesafe must hold."); + + // Get a pointer to the added function. + opt::Function* added_function = nullptr; + for (auto& function : *context->module()) { + if (function.result_id() == message_.instruction(0).result_id()) { + added_function = &function; + break; + } + } + assert(added_function && "The added function should have been found."); + + if (!TryToAddLoopLimiters(context, added_function)) { + // Adding loop limiters did not work; bail out. + return false; + } + + // Consider all the instructions in the function, and: + // - attempt to replace OpKill and OpUnreachable with return instructions + // - attempt to clamp access chains to be within bounds + // - check that OpFunctionCall instructions are only to livesafe functions + for (auto& block : *added_function) { + for (auto& inst : block) { + switch (inst.opcode()) { + case SpvOpKill: + case SpvOpUnreachable: + if (!TryToTurnKillOrUnreachableIntoReturn(context, added_function, + &inst)) { + return false; + } + break; + case SpvOpAccessChain: + case SpvOpInBoundsAccessChain: + if (!TryToClampAccessChainIndices(context, &inst)) { + return false; + } + break; + case SpvOpFunctionCall: + // A livesafe function my only call other livesafe functions. + if (!fact_manager.FunctionIsLivesafe( + inst.GetSingleWordInOperand(0))) { + return false; + } + default: + break; + } + } + } + return true; +} + +bool TransformationAddFunction::TryToAddLoopLimiters( + opt::IRContext* context, opt::Function* added_function) const { + // Collect up all the loop headers so that we can subsequently add loop + // limiting logic. + std::vector loop_headers; + for (auto& block : *added_function) { + if (block.IsLoopHeader()) { + loop_headers.push_back(&block); + } + } + + if (loop_headers.empty()) { + // There are no loops, so no need to add any loop limiters. + return true; + } + + // Check that the module contains appropriate ingredients for declaring and + // manipulating a loop limiter. + + auto loop_limit_constant_id_instr = + context->get_def_use_mgr()->GetDef(message_.loop_limit_constant_id()); + if (!loop_limit_constant_id_instr || + loop_limit_constant_id_instr->opcode() != SpvOpConstant) { + // The loop limit constant id instruction must exist and have an + // appropriate opcode. + return false; + } + + auto loop_limit_type = context->get_def_use_mgr()->GetDef( + loop_limit_constant_id_instr->type_id()); + if (loop_limit_type->opcode() != SpvOpTypeInt || + loop_limit_type->GetSingleWordInOperand(0) != 32) { + // The type of the loop limit constant must be 32-bit integer. It + // doesn't actually matter whether the integer is signed or not. + return false; + } + + // Find the id of the "unsigned int" type. + opt::analysis::Integer unsigned_int_type(32, false); + uint32_t unsigned_int_type_id = + context->get_type_mgr()->GetId(&unsigned_int_type); + if (!unsigned_int_type_id) { + // Unsigned int is not available; we need this type in order to add loop + // limiters. + return false; + } + auto registered_unsigned_int_type = + context->get_type_mgr()->GetRegisteredType(&unsigned_int_type); + + // Look for 0 of type unsigned int. + opt::analysis::IntConstant zero(registered_unsigned_int_type->AsInteger(), + {0}); + auto registered_zero = context->get_constant_mgr()->FindConstant(&zero); + if (!registered_zero) { + // We need 0 in order to be able to initialize loop limiters. + return false; + } + uint32_t zero_id = context->get_constant_mgr() + ->GetDefiningInstruction(registered_zero) + ->result_id(); + + // Look for 1 of type unsigned int. + opt::analysis::IntConstant one(registered_unsigned_int_type->AsInteger(), + {1}); + auto registered_one = context->get_constant_mgr()->FindConstant(&one); + if (!registered_one) { + // We need 1 in order to be able to increment loop limiters. + return false; + } + uint32_t one_id = context->get_constant_mgr() + ->GetDefiningInstruction(registered_one) + ->result_id(); + + // Look for pointer-to-unsigned int type. + opt::analysis::Pointer pointer_to_unsigned_int_type( + registered_unsigned_int_type, SpvStorageClassFunction); + uint32_t pointer_to_unsigned_int_type_id = + context->get_type_mgr()->GetId(&pointer_to_unsigned_int_type); + if (!pointer_to_unsigned_int_type_id) { + // We need pointer-to-unsigned int in order to declare the loop limiter + // variable. + return false; + } + + // Look for bool type. + opt::analysis::Bool bool_type; + uint32_t bool_type_id = context->get_type_mgr()->GetId(&bool_type); + if (!bool_type_id) { + // We need bool in order to compare the loop limiter's value with the loop + // limit constant. + return false; + } + + // Declare the loop limiter variable at the start of the function's entry + // block, via an instruction of the form: + // %loop_limiter_var = SpvOpVariable %ptr_to_uint Function %zero + added_function->begin()->begin()->InsertBefore(MakeUnique( + context, SpvOpVariable, pointer_to_unsigned_int_type_id, + message_.loop_limiter_variable_id(), + opt::Instruction::OperandList( + {{SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}, + {SPV_OPERAND_TYPE_ID, {zero_id}}}))); + // Update the module's id bound since we have added the loop limiter + // variable id. + fuzzerutil::UpdateModuleIdBound(context, message_.loop_limiter_variable_id()); + + // Consider each loop in turn. + for (auto loop_header : loop_headers) { + // Look for the loop's back-edge block. This is a predecessor of the loop + // header that is dominated by the loop header. + uint32_t back_edge_block_id = 0; + for (auto pred : context->cfg()->preds(loop_header->id())) { + if (context->GetDominatorAnalysis(added_function) + ->Dominates(loop_header->id(), pred)) { + back_edge_block_id = pred; + break; + } + } + if (!back_edge_block_id) { + // The loop's back-edge block must be unreachable. This means that the + // loop cannot iterate, so there is no need to make it lifesafe; we can + // move on from this loop. + continue; + } + auto back_edge_block = context->cfg()->block(back_edge_block_id); + + // Go through the sequence of loop limiter infos and find the one + // corresponding to this loop. + bool found = false; + protobufs::LoopLimiterInfo loop_limiter_info; + for (auto& info : message_.loop_limiter_info()) { + if (info.loop_header_id() == loop_header->id()) { + loop_limiter_info = info; + found = true; + break; + } + } + if (!found) { + // We don't have loop limiter info for this loop header. + return false; + } + + // The back-edge block either has the form: + // + // (1) + // + // %l = OpLabel + // ... instructions ... + // OpBranch %loop_header + // + // (2) + // + // %l = OpLabel + // ... instructions ... + // OpBranchConditional %c %loop_header %loop_merge + // + // (3) + // + // %l = OpLabel + // ... instructions ... + // OpBranchConditional %c %loop_merge %loop_header + // + // We turn these into the following: + // + // (1) + // + // %l = OpLabel + // ... instructions ... + // %t1 = OpLoad %uint32 %loop_limiter + // %t2 = OpIAdd %uint32 %t1 %one + // OpStore %loop_limiter %t2 + // %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit + // OpBranchConditional %t3 %loop_merge %loop_header + // + // (2) + // + // %l = OpLabel + // ... instructions ... + // %t1 = OpLoad %uint32 %loop_limiter + // %t2 = OpIAdd %uint32 %t1 %one + // OpStore %loop_limiter %t2 + // %t3 = OpULessThan %bool %t1 %loop_limit + // %t4 = OpLogicalAnd %bool %c %t3 + // OpBranchConditional %t4 %loop_header %loop_merge + // + // (3) + // + // %l = OpLabel + // ... instructions ... + // %t1 = OpLoad %uint32 %loop_limiter + // %t2 = OpIAdd %uint32 %t1 %one + // OpStore %loop_limiter %t2 + // %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit + // %t4 = OpLogicalOr %bool %c %t3 + // OpBranchConditional %t4 %loop_merge %loop_header + + auto back_edge_block_terminator = back_edge_block->terminator(); + bool compare_using_greater_than_equal; + if (back_edge_block_terminator->opcode() == SpvOpBranch) { + compare_using_greater_than_equal = true; + } else { + assert(back_edge_block_terminator->opcode() == SpvOpBranchConditional); + assert(((back_edge_block_terminator->GetSingleWordInOperand(1) == + loop_header->id() && + back_edge_block_terminator->GetSingleWordInOperand(2) == + loop_header->MergeBlockId()) || + (back_edge_block_terminator->GetSingleWordInOperand(2) == + loop_header->id() && + back_edge_block_terminator->GetSingleWordInOperand(1) == + loop_header->MergeBlockId())) && + "A back edge edge block must branch to" + " either the loop header or merge"); + compare_using_greater_than_equal = + back_edge_block_terminator->GetSingleWordInOperand(1) == + loop_header->MergeBlockId(); + } + + std::vector> new_instructions; + + // Add a load from the loop limiter variable, of the form: + // %t1 = OpLoad %uint32 %loop_limiter + new_instructions.push_back(MakeUnique( + context, SpvOpLoad, unsigned_int_type_id, loop_limiter_info.load_id(), + opt::Instruction::OperandList( + {{SPV_OPERAND_TYPE_ID, {message_.loop_limiter_variable_id()}}}))); + + // Increment the loaded value: + // %t2 = OpIAdd %uint32 %t1 %one + new_instructions.push_back(MakeUnique( + context, SpvOpIAdd, unsigned_int_type_id, + loop_limiter_info.increment_id(), + opt::Instruction::OperandList( + {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.load_id()}}, + {SPV_OPERAND_TYPE_ID, {one_id}}}))); + + // Store the incremented value back to the loop limiter variable: + // OpStore %loop_limiter %t2 + new_instructions.push_back(MakeUnique( + context, SpvOpStore, 0, 0, + opt::Instruction::OperandList( + {{SPV_OPERAND_TYPE_ID, {message_.loop_limiter_variable_id()}}, + {SPV_OPERAND_TYPE_ID, {loop_limiter_info.increment_id()}}}))); + + // Compare the loaded value with the loop limit; either: + // %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit + // or + // %t3 = OpULessThan %bool %t1 %loop_limit + new_instructions.push_back(MakeUnique( + context, + compare_using_greater_than_equal ? SpvOpUGreaterThanEqual + : SpvOpULessThan, + bool_type_id, loop_limiter_info.compare_id(), + opt::Instruction::OperandList( + {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.load_id()}}, + {SPV_OPERAND_TYPE_ID, {message_.loop_limit_constant_id()}}}))); + + if (back_edge_block_terminator->opcode() == SpvOpBranchConditional) { + new_instructions.push_back(MakeUnique( + context, + compare_using_greater_than_equal ? SpvOpLogicalOr : SpvOpLogicalAnd, + bool_type_id, loop_limiter_info.logical_op_id(), + opt::Instruction::OperandList( + {{SPV_OPERAND_TYPE_ID, + {back_edge_block_terminator->GetSingleWordInOperand(0)}}, + {SPV_OPERAND_TYPE_ID, {loop_limiter_info.compare_id()}}}))); + } + + // Add the new instructions at the end of the back edge block, before the + // terminator and any loop merge instruction (as the back edge block can + // be the loop header). + if (back_edge_block->GetLoopMergeInst()) { + back_edge_block->GetLoopMergeInst()->InsertBefore( + std::move(new_instructions)); + } else { + back_edge_block_terminator->InsertBefore(std::move(new_instructions)); + } + + if (back_edge_block_terminator->opcode() == SpvOpBranchConditional) { + back_edge_block_terminator->SetInOperand( + 0, {loop_limiter_info.logical_op_id()}); + } else { + assert(back_edge_block_terminator->opcode() == SpvOpBranch && + "Back-edge terminator must be OpBranch or OpBranchConditional"); + + // Check that, if the merge block starts with OpPhi instructions, suitable + // ids have been provided to give these instructions a value corresponding + // to the new incoming edge from the back edge block. + auto merge_block = context->cfg()->block(loop_header->MergeBlockId()); + if (!fuzzerutil::PhiIdsOkForNewEdge(context, back_edge_block, merge_block, + loop_limiter_info.phi_id())) { + return false; + } + + // Augment OpPhi instructions at the loop merge with the given ids. + uint32_t phi_index = 0; + for (auto& inst : *merge_block) { + if (inst.opcode() != SpvOpPhi) { + break; + } + assert(phi_index < + static_cast(loop_limiter_info.phi_id().size()) && + "There should be at least one phi id per OpPhi instruction."); + inst.AddOperand( + {SPV_OPERAND_TYPE_ID, {loop_limiter_info.phi_id(phi_index)}}); + inst.AddOperand({SPV_OPERAND_TYPE_ID, {back_edge_block_id}}); + phi_index++; + } + + // Add the new edge, by changing OpBranch to OpBranchConditional. + // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3162): This + // could be a problem if the merge block was originally unreachable: it + // might now be dominated by other blocks that it appears earlier than in + // the module. + back_edge_block_terminator->SetOpcode(SpvOpBranchConditional); + back_edge_block_terminator->SetInOperands(opt::Instruction::OperandList( + {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.compare_id()}}, + {SPV_OPERAND_TYPE_ID, {loop_header->MergeBlockId()} + + }, + {SPV_OPERAND_TYPE_ID, {loop_header->id()}}})); + } + + // Update the module's id bound with respect to the various ids that + // have been used for loop limiter manipulation. + fuzzerutil::UpdateModuleIdBound(context, loop_limiter_info.load_id()); + fuzzerutil::UpdateModuleIdBound(context, loop_limiter_info.increment_id()); + fuzzerutil::UpdateModuleIdBound(context, loop_limiter_info.compare_id()); + fuzzerutil::UpdateModuleIdBound(context, loop_limiter_info.logical_op_id()); + } + return true; +} + +bool TransformationAddFunction::TryToTurnKillOrUnreachableIntoReturn( + opt::IRContext* context, opt::Function* added_function, + opt::Instruction* kill_or_unreachable_inst) const { + assert((kill_or_unreachable_inst->opcode() == SpvOpKill || + kill_or_unreachable_inst->opcode() == SpvOpUnreachable) && + "Precondition: instruction must be OpKill or OpUnreachable."); + + // Get the function's return type. + auto function_return_type_inst = + context->get_def_use_mgr()->GetDef(added_function->type_id()); + + if (function_return_type_inst->opcode() == SpvOpTypeVoid) { + // The function has void return type, so change this instruction to + // OpReturn. + kill_or_unreachable_inst->SetOpcode(SpvOpReturn); + } else { + // The function has non-void return type, so change this instruction + // to OpReturnValue, using the value id provided with the + // transformation. + + // We first check that the id, %id, provided with the transformation + // specifically to turn OpKill and OpUnreachable instructions into + // OpReturnValue %id has the same type as the function's return type. + if (context->get_def_use_mgr() + ->GetDef(message_.kill_unreachable_return_value_id()) + ->type_id() != function_return_type_inst->result_id()) { + return false; + } + kill_or_unreachable_inst->SetOpcode(SpvOpReturnValue); + kill_or_unreachable_inst->SetInOperands( + {{SPV_OPERAND_TYPE_ID, {message_.kill_unreachable_return_value_id()}}}); + } + return true; +} + +bool TransformationAddFunction::TryToClampAccessChainIndices( + opt::IRContext* context, opt::Instruction* access_chain_inst) const { + assert((access_chain_inst->opcode() == SpvOpAccessChain || + access_chain_inst->opcode() == SpvOpInBoundsAccessChain) && + "Precondition: instruction must be OpAccessChain or " + "OpInBoundsAccessChain."); + + // Find the AccessChainClampingInfo associated with this access chain. + const protobufs::AccessChainClampingInfo* access_chain_clamping_info = + nullptr; + for (auto& clamping_info : message_.access_chain_clamping_info()) { + if (clamping_info.access_chain_id() == access_chain_inst->result_id()) { + access_chain_clamping_info = &clamping_info; + break; + } + } + if (!access_chain_clamping_info) { + // No access chain clamping information was found; the function cannot be + // made livesafe. + return false; + } + + // Check that there is a (compare_id, select_id) pair for every + // index associated with the instruction. + if (static_cast( + access_chain_clamping_info->compare_and_select_ids().size()) != + access_chain_inst->NumInOperands() - 1) { + return false; + } + + // Walk the access chain, clamping each index to be within bounds if it is + // not a constant. + auto base_object = context->get_def_use_mgr()->GetDef( + access_chain_inst->GetSingleWordInOperand(0)); + assert(base_object && "The base object must exist."); + auto pointer_type = + context->get_def_use_mgr()->GetDef(base_object->type_id()); + assert(pointer_type && pointer_type->opcode() == SpvOpTypePointer && + "The base object must have pointer type."); + auto should_be_composite_type = context->get_def_use_mgr()->GetDef( + pointer_type->GetSingleWordInOperand(1)); + + // Consider each index input operand in turn (operand 0 is the base object). + for (uint32_t index = 1; index < access_chain_inst->NumInOperands(); + index++) { + // We are going to turn: + // + // %result = OpAccessChain %type %object ... %index ... + // + // into: + // + // %t1 = OpULessThanEqual %bool %index %bound_minus_one + // %t2 = OpSelect %int_type %t1 %index %bound_minus_one + // %result = OpAccessChain %type %object ... %t2 ... + // + // ... unless %index is already a constant. + + // Get the bound for the composite being indexed into; e.g. the number of + // columns of matrix or the size of an array. + uint32_t bound = + GetBoundForCompositeIndex(context, *should_be_composite_type); + + // Get the instruction associated with the index and figure out its integer + // type. + const uint32_t index_id = access_chain_inst->GetSingleWordInOperand(index); + auto index_inst = context->get_def_use_mgr()->GetDef(index_id); + auto index_type_inst = + context->get_def_use_mgr()->GetDef(index_inst->type_id()); + assert(index_type_inst->opcode() == SpvOpTypeInt); + assert(index_type_inst->GetSingleWordInOperand(0) == 32); + opt::analysis::Integer* index_int_type = + context->get_type_mgr() + ->GetType(index_type_inst->result_id()) + ->AsInteger(); + + if (index_inst->opcode() != SpvOpConstant) { + // The index is non-constant so we need to clamp it. + assert(should_be_composite_type->opcode() != SpvOpTypeStruct && + "Access chain indices into structures are required to be " + "constants."); + opt::analysis::IntConstant bound_minus_one(index_int_type, {bound - 1}); + if (!context->get_constant_mgr()->FindConstant(&bound_minus_one)) { + // We do not have an integer constant whose value is |bound| -1. + return false; + } + + opt::analysis::Bool bool_type; + uint32_t bool_type_id = context->get_type_mgr()->GetId(&bool_type); + if (!bool_type_id) { + // Bool type is not declared; we cannot do a comparison. + return false; + } + + uint32_t bound_minus_one_id = + context->get_constant_mgr() + ->GetDefiningInstruction(&bound_minus_one) + ->result_id(); + + uint32_t compare_id = + access_chain_clamping_info->compare_and_select_ids(index - 1).first(); + uint32_t select_id = + access_chain_clamping_info->compare_and_select_ids(index - 1) + .second(); + std::vector> new_instructions; + + // Compare the index with the bound via an instruction of the form: + // %t1 = OpULessThanEqual %bool %index %bound_minus_one + new_instructions.push_back(MakeUnique( + context, SpvOpULessThanEqual, bool_type_id, compare_id, + opt::Instruction::OperandList( + {{SPV_OPERAND_TYPE_ID, {index_inst->result_id()}}, + {SPV_OPERAND_TYPE_ID, {bound_minus_one_id}}}))); + + // Select the index if in-bounds, otherwise one less than the bound: + // %t2 = OpSelect %int_type %t1 %index %bound_minus_one + new_instructions.push_back(MakeUnique( + context, SpvOpSelect, index_type_inst->result_id(), select_id, + opt::Instruction::OperandList( + {{SPV_OPERAND_TYPE_ID, {compare_id}}, + {SPV_OPERAND_TYPE_ID, {index_inst->result_id()}}, + {SPV_OPERAND_TYPE_ID, {bound_minus_one_id}}}))); + + // Add the new instructions before the access chain + access_chain_inst->InsertBefore(std::move(new_instructions)); + + // Replace %index with %t2. + access_chain_inst->SetInOperand(index, {select_id}); + fuzzerutil::UpdateModuleIdBound(context, compare_id); + fuzzerutil::UpdateModuleIdBound(context, select_id); + } else { + // TODO(afd): At present the SPIR-V spec is not clear on whether + // statically out-of-bounds indices mean that a module is invalid (so + // that it should be rejected by the validator), or that such accesses + // yield undefined results. Via the following assertion, we assume that + // functions added to the module do not feature statically out-of-bounds + // accesses. + // Assert that the index is smaller (unsigned) than this value. + // Return false if it is not (to keep compilers happy). + if (index_inst->GetSingleWordInOperand(0) >= bound) { + assert(false && + "The function has a statically out-of-bounds access; " + "this should not occur."); + return false; + } + } + should_be_composite_type = + FollowCompositeIndex(context, *should_be_composite_type, index_id); + } + return true; +} + +uint32_t TransformationAddFunction::GetBoundForCompositeIndex( + opt::IRContext* context, const opt::Instruction& composite_type_inst) { + switch (composite_type_inst.opcode()) { + case SpvOpTypeArray: + return fuzzerutil::GetArraySize(composite_type_inst, context); + case SpvOpTypeMatrix: + case SpvOpTypeVector: + return composite_type_inst.GetSingleWordInOperand(1); + case SpvOpTypeStruct: { + return fuzzerutil::GetNumberOfStructMembers(composite_type_inst); + } + default: + assert(false && "Unknown composite type."); + return 0; + } +} + +opt::Instruction* TransformationAddFunction::FollowCompositeIndex( + opt::IRContext* context, const opt::Instruction& composite_type_inst, + uint32_t index_id) { + uint32_t sub_object_type_id; + switch (composite_type_inst.opcode()) { + case SpvOpTypeArray: + sub_object_type_id = composite_type_inst.GetSingleWordInOperand(0); + break; + case SpvOpTypeMatrix: + case SpvOpTypeVector: + sub_object_type_id = composite_type_inst.GetSingleWordInOperand(0); + break; + case SpvOpTypeStruct: { + auto index_inst = context->get_def_use_mgr()->GetDef(index_id); + assert(index_inst->opcode() == SpvOpConstant); + assert( + context->get_def_use_mgr()->GetDef(index_inst->type_id())->opcode() == + SpvOpTypeInt); + assert(context->get_def_use_mgr() + ->GetDef(index_inst->type_id()) + ->GetSingleWordInOperand(0) == 32); + uint32_t index_value = index_inst->GetSingleWordInOperand(0); + sub_object_type_id = + composite_type_inst.GetSingleWordInOperand(index_value); + break; + } + default: + assert(false && "Unknown composite type."); + sub_object_type_id = 0; + break; + } + assert(sub_object_type_id && "No sub-object found."); + return context->get_def_use_mgr()->GetDef(sub_object_type_id); +} + } // namespace fuzz } // namespace spvtools diff --git a/source/fuzz/transformation_add_function.h b/source/fuzz/transformation_add_function.h index fee2732a90..848b799fca 100644 --- a/source/fuzz/transformation_add_function.h +++ b/source/fuzz/transformation_add_function.h @@ -28,26 +28,56 @@ class TransformationAddFunction : public Transformation { explicit TransformationAddFunction( const protobufs::TransformationAddFunction& message); + // Creates a transformation to add a non live-safe function. explicit TransformationAddFunction( const std::vector& instructions); + // Creates a transformation to add a live-safe function. + TransformationAddFunction( + const std::vector& instructions, + uint32_t loop_limiter_variable_id, uint32_t loop_limit_constant_id, + const std::vector& loop_limiters, + uint32_t kill_unreachable_return_value_id, + const std::vector& + access_chain_clampers); + // - |message_.instruction| must correspond to a sufficiently well-formed // sequence of instructions that a function can be created from them + // - If |message_.is_livesafe| holds then |message_| must contain suitable + // ingredients to make the function livesafe, and the function must only + // invoke other livesafe functions // - Adding the created function to the module must lead to a valid module. bool IsApplicable(opt::IRContext* context, const FactManager& fact_manager) const override; - // Adds the function defined by |message_.instruction| to the module + // Adds the function defined by |message_.instruction| to the module, making + // it livesafe if |message_.is_livesafe| holds. void Apply(opt::IRContext* context, FactManager* fact_manager) const override; protobufs::Transformation ToMessage() const override; + // Helper method that returns the bound for indexing into a composite of type + // |composite_type_inst|, i.e. the number of fields of a struct, the size of + // an array, the number of components of a vector, or the number of columns of + // a matrix. + static uint32_t GetBoundForCompositeIndex( + opt::IRContext* context, const opt::Instruction& composite_type_inst); + + // Helper method that, given composite type |composite_type_inst|, returns the + // type of the sub-object at index |index_id|, which is required to be in- + // bounds. + static opt::Instruction* FollowCompositeIndex( + opt::IRContext* context, const opt::Instruction& composite_type_inst, + uint32_t index_id); + private: // Attempts to create a function from the series of instructions in - // |message_.instruction| and add it to |context|. Returns false if this is - // not possible due to the messages not respecting the basic structure of a - // function, e.g. if there is no OpFunction instruction or no blocks; in this - // case |context| is left in an indeterminate state. + // |message_.instruction| and add it to |context|. + // + // Returns false if adding the function is not possible due to the messages + // not respecting the basic structure of a function, e.g. if there is no + // OpFunction instruction or no blocks; in this case |context| is left in an + // indeterminate state. // // Otherwise returns true. Whether |context| is valid after addition of the // function depends on the contents of |message_.instruction|. @@ -61,6 +91,30 @@ class TransformationAddFunction : public Transformation { // to add the function. bool TryToAddFunction(opt::IRContext* context) const; + // Should only be called if |message_.is_livesafe| holds. Attempts to make + // the function livesafe (see FactFunctionIsLivesafe for a definition). + // Returns false if this is not possible, due to |message_| or |context| not + // containing sufficient ingredients (such as types and fresh ids) to add + // the instrumentation necessary to make the function livesafe. + bool TryToMakeFunctionLivesafe(opt::IRContext* context, + const FactManager& fact_manager) const; + + // A helper for TryToMakeFunctionLivesafe that tries to add loop-limiting + // logic. + bool TryToAddLoopLimiters(opt::IRContext* context, + opt::Function* added_function) const; + + // A helper for TryToMakeFunctionLivesafe that tries to replace OpKill and + // OpUnreachable instructions into return instructions. + bool TryToTurnKillOrUnreachableIntoReturn( + opt::IRContext* context, opt::Function* added_function, + opt::Instruction* kill_or_unreachable_inst) const; + + // A helper for TryToMakeFunctionLivesafe that tries to clamp access chain + // indices so that they are guaranteed to be in-bounds. + bool TryToClampAccessChainIndices(opt::IRContext* context, + opt::Instruction* access_chain_inst) const; + protobufs::TransformationAddFunction message_; }; diff --git a/source/fuzz/transformation_outline_function.cpp b/source/fuzz/transformation_outline_function.cpp index 1b308c4daf..c097e6cfe6 100644 --- a/source/fuzz/transformation_outline_function.cpp +++ b/source/fuzz/transformation_outline_function.cpp @@ -254,11 +254,21 @@ bool TransformationOutlineFunction::IsApplicable( if (input_id_to_fresh_id_map.count(id) == 0) { return false; } - // Furthermore, no region input id is allowed to be the result of an access - // chain. This is because region input ids will become function parameters, - // and it is not legal to pass an access chain as a function parameter. - if (context->get_def_use_mgr()->GetDef(id)->opcode() == SpvOpAccessChain) { - return false; + // Furthermore, if the input id has pointer type it must be an OpVariable + // or OpFunctionParameter. + auto input_id_inst = context->get_def_use_mgr()->GetDef(id); + if (context->get_def_use_mgr() + ->GetDef(input_id_inst->type_id()) + ->opcode() == SpvOpTypePointer) { + switch (input_id_inst->opcode()) { + case SpvOpFunctionParameter: + case SpvOpVariable: + // These are OK. + break; + default: + // Anything else is not OK. + return false; + } } } diff --git a/test/fuzz/fuzzer_pass_donate_modules_test.cpp b/test/fuzz/fuzzer_pass_donate_modules_test.cpp index 7342dd426a..0d202b7a98 100644 --- a/test/fuzz/fuzzer_pass_donate_modules_test.cpp +++ b/test/fuzz/fuzzer_pass_donate_modules_test.cpp @@ -195,14 +195,15 @@ TEST(FuzzerPassDonateModulesTest, BasicDonation) { FactManager fact_manager; - FuzzerContext fuzzer_context(MakeUnique(0).get(), 100); + auto prng = MakeUnique(0); + FuzzerContext fuzzer_context(prng.get(), 100); protobufs::TransformationSequence transformation_sequence; FuzzerPassDonateModules fuzzer_pass(recipient_context.get(), &fact_manager, &fuzzer_context, &transformation_sequence, {}); - fuzzer_pass.DonateSingleModule(donor_context.get()); + fuzzer_pass.DonateSingleModule(donor_context.get(), false); // We just check that the result is valid. Checking to what it should be // exactly equal to would be very fragile. @@ -276,7 +277,7 @@ TEST(FuzzerPassDonateModulesTest, DonationWithUniforms) { &fuzzer_context, &transformation_sequence, {}); - fuzzer_pass.DonateSingleModule(donor_context.get()); + fuzzer_pass.DonateSingleModule(donor_context.get(), false); ASSERT_TRUE(IsValid(env, recipient_context.get())); @@ -397,7 +398,7 @@ TEST(FuzzerPassDonateModulesTest, DonationWithInputAndOutputVariables) { &fuzzer_context, &transformation_sequence, {}); - fuzzer_pass.DonateSingleModule(donor_context.get()); + fuzzer_pass.DonateSingleModule(donor_context.get(), false); ASSERT_TRUE(IsValid(env, recipient_context.get())); @@ -483,7 +484,7 @@ TEST(FuzzerPassDonateModulesTest, DonateFunctionTypeWithDifferentPointers) { &fuzzer_context, &transformation_sequence, {}); - fuzzer_pass.DonateSingleModule(donor_context.get()); + fuzzer_pass.DonateSingleModule(donor_context.get(), false); // We just check that the result is valid. Checking to what it should be // exactly equal to would be very fragile. @@ -660,7 +661,7 @@ TEST(FuzzerPassDonateModulesTest, Miscellaneous1) { &fuzzer_context, &transformation_sequence, {}); - fuzzer_pass.DonateSingleModule(donor_context.get()); + fuzzer_pass.DonateSingleModule(donor_context.get(), false); // We just check that the result is valid. Checking to what it should be // exactly equal to would be very fragile. diff --git a/test/fuzz/transformation_add_dead_break_test.cpp b/test/fuzz/transformation_add_dead_break_test.cpp index 1dd0c9d4e9..d60fc1fc52 100644 --- a/test/fuzz/transformation_add_dead_break_test.cpp +++ b/test/fuzz/transformation_add_dead_break_test.cpp @@ -1948,9 +1948,6 @@ TEST(TransformationAddDeadBreakTest, PhiInstructions) { // Not applicable because two OpPhis (not just one) need to be updated at 20 ASSERT_FALSE(TransformationAddDeadBreak(23, 20, true, {13}) .IsApplicable(context.get(), fact_manager)); - // Not applicable because only two OpPhis (not three) need to be updated at 20 - ASSERT_FALSE(TransformationAddDeadBreak(23, 20, true, {13, 21, 22}) - .IsApplicable(context.get(), fact_manager)); // Not applicable because the given ids do not have types that match the // OpPhis at 20, in order ASSERT_FALSE(TransformationAddDeadBreak(23, 20, true, {21, 13}) diff --git a/test/fuzz/transformation_add_function_test.cpp b/test/fuzz/transformation_add_function_test.cpp index 66130be2bc..3bc7620510 100644 --- a/test/fuzz/transformation_add_function_test.cpp +++ b/test/fuzz/transformation_add_function_test.cpp @@ -20,6 +20,44 @@ namespace spvtools { namespace fuzz { namespace { +protobufs::AccessChainClampingInfo MakeAccessClampingInfo( + uint32_t access_chain_id, + const std::vector>& compare_and_select_ids) { + protobufs::AccessChainClampingInfo result; + result.set_access_chain_id(access_chain_id); + for (auto& compare_and_select_id : compare_and_select_ids) { + auto pair = result.add_compare_and_select_ids(); + pair->set_first(compare_and_select_id.first); + pair->set_second(compare_and_select_id.second); + } + return result; +} + +std::vector GetInstructionsForFunction( + spv_target_env env, const MessageConsumer& consumer, + const std::string& donor, uint32_t function_id) { + std::vector result; + const auto donor_context = + BuildModule(env, consumer, donor, kFuzzAssembleOption); + assert(IsValid(env, donor_context.get()) && "The given donor must be valid."); + for (auto& function : *donor_context->module()) { + if (function.result_id() == function_id) { + function.ForEachInst([&result](opt::Instruction* inst) { + opt::Instruction::OperandList input_operands; + for (uint32_t i = 0; i < inst->NumInOperands(); i++) { + input_operands.push_back(inst->GetInOperand(i)); + } + result.push_back(MakeInstructionMessage(inst->opcode(), inst->type_id(), + inst->result_id(), + input_operands)); + }); + break; + } + } + assert(!result.empty() && "The required function should have been found."); + return result; +} + TEST(TransformationAddFunctionTest, BasicTest) { std::string shader = R"( OpCapability Shader @@ -190,6 +228,12 @@ TEST(TransformationAddFunctionTest, BasicTest) { OpFunctionEnd )"; ASSERT_TRUE(IsEqual(env, after_transformation1, context.get())); + ASSERT_TRUE(fact_manager.BlockIsDead(14)); + ASSERT_TRUE(fact_manager.BlockIsDead(21)); + ASSERT_TRUE(fact_manager.BlockIsDead(22)); + ASSERT_TRUE(fact_manager.BlockIsDead(23)); + ASSERT_TRUE(fact_manager.BlockIsDead(24)); + ASSERT_TRUE(fact_manager.BlockIsDead(25)); TransformationAddFunction transformation2(std::vector( {MakeInstructionMessage( @@ -320,6 +364,7 @@ TEST(TransformationAddFunctionTest, BasicTest) { OpFunctionEnd )"; ASSERT_TRUE(IsEqual(env, after_transformation2, context.get())); + ASSERT_TRUE(fact_manager.BlockIsDead(16)); } TEST(TransformationAddFunctionTest, InapplicableTransformations) { @@ -442,6 +487,2261 @@ TEST(TransformationAddFunctionTest, InapplicableTransformations) { .IsApplicable(context.get(), fact_manager)); } +TEST(TransformationAddFunctionTest, LoopLimiters) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 0 + %7 = OpTypePointer Function %6 + %8 = OpConstant %6 0 + %9 = OpConstant %6 1 + %10 = OpConstant %6 5 + %11 = OpTypeBool + %12 = OpConstantTrue %11 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + + std::vector instructions; + instructions.push_back(MakeInstructionMessage( + SpvOpFunction, 2, 30, + {{SPV_OPERAND_TYPE_FUNCTION_CONTROL, {SpvFunctionControlMaskNone}}, + {SPV_OPERAND_TYPE_TYPE_ID, {3}}})); + instructions.push_back(MakeInstructionMessage(SpvOpLabel, 0, 31, {})); + instructions.push_back( + MakeInstructionMessage(SpvOpBranch, 0, 0, {{SPV_OPERAND_TYPE_ID, {20}}})); + instructions.push_back(MakeInstructionMessage(SpvOpLabel, 0, 20, {})); + instructions.push_back(MakeInstructionMessage( + SpvOpLoopMerge, 0, 0, + {{SPV_OPERAND_TYPE_ID, {21}}, + {SPV_OPERAND_TYPE_ID, {22}}, + {SPV_OPERAND_TYPE_LOOP_CONTROL, {SpvLoopControlMaskNone}}})); + instructions.push_back(MakeInstructionMessage(SpvOpBranchConditional, 0, 0, + {{SPV_OPERAND_TYPE_ID, {12}}, + {SPV_OPERAND_TYPE_ID, {23}}, + {SPV_OPERAND_TYPE_ID, {21}}})); + instructions.push_back(MakeInstructionMessage(SpvOpLabel, 0, 23, {})); + instructions.push_back(MakeInstructionMessage( + SpvOpLoopMerge, 0, 0, + {{SPV_OPERAND_TYPE_ID, {25}}, + {SPV_OPERAND_TYPE_ID, {26}}, + {SPV_OPERAND_TYPE_LOOP_CONTROL, {SpvLoopControlMaskNone}}})); + instructions.push_back( + MakeInstructionMessage(SpvOpBranch, 0, 0, {{SPV_OPERAND_TYPE_ID, {28}}})); + instructions.push_back(MakeInstructionMessage(SpvOpLabel, 0, 28, {})); + instructions.push_back(MakeInstructionMessage(SpvOpBranchConditional, 0, 0, + {{SPV_OPERAND_TYPE_ID, {12}}, + {SPV_OPERAND_TYPE_ID, {26}}, + {SPV_OPERAND_TYPE_ID, {25}}})); + instructions.push_back(MakeInstructionMessage(SpvOpLabel, 0, 26, {})); + instructions.push_back( + MakeInstructionMessage(SpvOpBranch, 0, 0, {{SPV_OPERAND_TYPE_ID, {23}}})); + instructions.push_back(MakeInstructionMessage(SpvOpLabel, 0, 25, {})); + instructions.push_back(MakeInstructionMessage( + SpvOpLoopMerge, 0, 0, + {{SPV_OPERAND_TYPE_ID, {24}}, + {SPV_OPERAND_TYPE_ID, {27}}, + {SPV_OPERAND_TYPE_LOOP_CONTROL, {SpvLoopControlMaskNone}}})); + instructions.push_back(MakeInstructionMessage(SpvOpBranchConditional, 0, 0, + {{SPV_OPERAND_TYPE_ID, {12}}, + {SPV_OPERAND_TYPE_ID, {24}}, + {SPV_OPERAND_TYPE_ID, {27}}})); + instructions.push_back(MakeInstructionMessage(SpvOpLabel, 0, 27, {})); + instructions.push_back( + MakeInstructionMessage(SpvOpBranch, 0, 0, {{SPV_OPERAND_TYPE_ID, {25}}})); + instructions.push_back(MakeInstructionMessage(SpvOpLabel, 0, 24, {})); + instructions.push_back( + MakeInstructionMessage(SpvOpBranch, 0, 0, {{SPV_OPERAND_TYPE_ID, {22}}})); + instructions.push_back(MakeInstructionMessage(SpvOpLabel, 0, 22, {})); + instructions.push_back( + MakeInstructionMessage(SpvOpBranch, 0, 0, {{SPV_OPERAND_TYPE_ID, {20}}})); + instructions.push_back(MakeInstructionMessage(SpvOpLabel, 0, 21, {})); + instructions.push_back(MakeInstructionMessage(SpvOpReturn, 0, 0, {})); + instructions.push_back(MakeInstructionMessage(SpvOpFunctionEnd, 0, 0, {})); + + FactManager fact_manager; + + const auto context1 = BuildModule(env, consumer, shader, kFuzzAssembleOption); + const auto context2 = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context1.get())); + + TransformationAddFunction add_dead_function(instructions); + ASSERT_TRUE(add_dead_function.IsApplicable(context1.get(), fact_manager)); + add_dead_function.Apply(context1.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context1.get())); + // The added function should not be deemed livesafe. + ASSERT_FALSE(fact_manager.FunctionIsLivesafe(30)); + + std::string added_as_dead_code = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 0 + %7 = OpTypePointer Function %6 + %8 = OpConstant %6 0 + %9 = OpConstant %6 1 + %10 = OpConstant %6 5 + %11 = OpTypeBool + %12 = OpConstantTrue %11 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + %30 = OpFunction %2 None %3 + %31 = OpLabel + OpBranch %20 + %20 = OpLabel + OpLoopMerge %21 %22 None + OpBranchConditional %12 %23 %21 + %23 = OpLabel + OpLoopMerge %25 %26 None + OpBranch %28 + %28 = OpLabel + OpBranchConditional %12 %26 %25 + %26 = OpLabel + OpBranch %23 + %25 = OpLabel + OpLoopMerge %24 %27 None + OpBranchConditional %12 %24 %27 + %27 = OpLabel + OpBranch %25 + %24 = OpLabel + OpBranch %22 + %22 = OpLabel + OpBranch %20 + %21 = OpLabel + OpReturn + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, added_as_dead_code, context1.get())); + + protobufs::LoopLimiterInfo loop_limiter1; + loop_limiter1.set_loop_header_id(20); + loop_limiter1.set_load_id(101); + loop_limiter1.set_increment_id(102); + loop_limiter1.set_compare_id(103); + loop_limiter1.set_logical_op_id(104); + + protobufs::LoopLimiterInfo loop_limiter2; + loop_limiter2.set_loop_header_id(23); + loop_limiter2.set_load_id(105); + loop_limiter2.set_increment_id(106); + loop_limiter2.set_compare_id(107); + loop_limiter2.set_logical_op_id(108); + + protobufs::LoopLimiterInfo loop_limiter3; + loop_limiter3.set_loop_header_id(25); + loop_limiter3.set_load_id(109); + loop_limiter3.set_increment_id(110); + loop_limiter3.set_compare_id(111); + loop_limiter3.set_logical_op_id(112); + + std::vector loop_limiters = { + loop_limiter1, loop_limiter2, loop_limiter3}; + + TransformationAddFunction add_livesafe_function(instructions, 100, 10, + loop_limiters, 0, {}); + ASSERT_TRUE(add_livesafe_function.IsApplicable(context2.get(), fact_manager)); + add_livesafe_function.Apply(context2.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context2.get())); + // The added function should indeed be deemed livesafe. + ASSERT_TRUE(fact_manager.FunctionIsLivesafe(30)); + std::string added_as_livesafe_code = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 0 + %7 = OpTypePointer Function %6 + %8 = OpConstant %6 0 + %9 = OpConstant %6 1 + %10 = OpConstant %6 5 + %11 = OpTypeBool + %12 = OpConstantTrue %11 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + %30 = OpFunction %2 None %3 + %31 = OpLabel + %100 = OpVariable %7 Function %8 + OpBranch %20 + %20 = OpLabel + OpLoopMerge %21 %22 None + OpBranchConditional %12 %23 %21 + %23 = OpLabel + OpLoopMerge %25 %26 None + OpBranch %28 + %28 = OpLabel + OpBranchConditional %12 %26 %25 + %26 = OpLabel + %105 = OpLoad %6 %100 + %106 = OpIAdd %6 %105 %9 + OpStore %100 %106 + %107 = OpUGreaterThanEqual %11 %105 %10 + OpBranchConditional %107 %25 %23 + %25 = OpLabel + OpLoopMerge %24 %27 None + OpBranchConditional %12 %24 %27 + %27 = OpLabel + %109 = OpLoad %6 %100 + %110 = OpIAdd %6 %109 %9 + OpStore %100 %110 + %111 = OpUGreaterThanEqual %11 %109 %10 + OpBranchConditional %111 %24 %25 + %24 = OpLabel + OpBranch %22 + %22 = OpLabel + %101 = OpLoad %6 %100 + %102 = OpIAdd %6 %101 %9 + OpStore %100 %102 + %103 = OpUGreaterThanEqual %11 %101 %10 + OpBranchConditional %103 %21 %20 + %21 = OpLabel + OpReturn + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, added_as_livesafe_code, context2.get())); +} + +TEST(TransformationAddFunctionTest, KillAndUnreachableInVoidFunction) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %8 = OpTypeFunction %2 %7 + %13 = OpConstant %6 2 + %14 = OpTypeBool + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + + std::vector instructions; + + instructions.push_back(MakeInstructionMessage( + SpvOpFunction, 2, 10, + {{SPV_OPERAND_TYPE_FUNCTION_CONTROL, {SpvFunctionControlMaskNone}}, + {SPV_OPERAND_TYPE_TYPE_ID, {8}}})); + instructions.push_back( + MakeInstructionMessage(SpvOpFunctionParameter, 7, 9, {})); + instructions.push_back(MakeInstructionMessage(SpvOpLabel, 0, 11, {})); + instructions.push_back( + MakeInstructionMessage(SpvOpLoad, 6, 12, {{SPV_OPERAND_TYPE_ID, {9}}})); + instructions.push_back(MakeInstructionMessage( + SpvOpIEqual, 14, 15, + {{SPV_OPERAND_TYPE_ID, {12}}, {SPV_OPERAND_TYPE_ID, {13}}})); + instructions.push_back(MakeInstructionMessage( + SpvOpSelectionMerge, 0, 0, + {{SPV_OPERAND_TYPE_ID, {17}}, + {SPV_OPERAND_TYPE_SELECTION_CONTROL, {SpvSelectionControlMaskNone}}})); + instructions.push_back(MakeInstructionMessage(SpvOpBranchConditional, 0, 0, + {{SPV_OPERAND_TYPE_ID, {15}}, + {SPV_OPERAND_TYPE_ID, {16}}, + {SPV_OPERAND_TYPE_ID, {17}}})); + instructions.push_back(MakeInstructionMessage(SpvOpLabel, 0, 16, {})); + instructions.push_back(MakeInstructionMessage(SpvOpUnreachable, 0, 0, {})); + instructions.push_back(MakeInstructionMessage(SpvOpLabel, 0, 17, {})); + instructions.push_back(MakeInstructionMessage(SpvOpKill, 0, 0, {})); + instructions.push_back(MakeInstructionMessage(SpvOpFunctionEnd, 0, 0, {})); + + FactManager fact_manager; + + const auto context1 = BuildModule(env, consumer, shader, kFuzzAssembleOption); + const auto context2 = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context1.get())); + + TransformationAddFunction add_dead_function(instructions); + ASSERT_TRUE(add_dead_function.IsApplicable(context1.get(), fact_manager)); + add_dead_function.Apply(context1.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context1.get())); + // The added function should not be deemed livesafe. + ASSERT_FALSE(fact_manager.FunctionIsLivesafe(10)); + + std::string added_as_dead_code = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %8 = OpTypeFunction %2 %7 + %13 = OpConstant %6 2 + %14 = OpTypeBool + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + %10 = OpFunction %2 None %8 + %9 = OpFunctionParameter %7 + %11 = OpLabel + %12 = OpLoad %6 %9 + %15 = OpIEqual %14 %12 %13 + OpSelectionMerge %17 None + OpBranchConditional %15 %16 %17 + %16 = OpLabel + OpUnreachable + %17 = OpLabel + OpKill + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, added_as_dead_code, context1.get())); + + TransformationAddFunction add_livesafe_function(instructions, 0, 0, {}, 0, + {}); + ASSERT_TRUE(add_livesafe_function.IsApplicable(context2.get(), fact_manager)); + add_livesafe_function.Apply(context2.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context2.get())); + // The added function should indeed be deemed livesafe. + ASSERT_TRUE(fact_manager.FunctionIsLivesafe(10)); + std::string added_as_livesafe_code = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %8 = OpTypeFunction %2 %7 + %13 = OpConstant %6 2 + %14 = OpTypeBool + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + %10 = OpFunction %2 None %8 + %9 = OpFunctionParameter %7 + %11 = OpLabel + %12 = OpLoad %6 %9 + %15 = OpIEqual %14 %12 %13 + OpSelectionMerge %17 None + OpBranchConditional %15 %16 %17 + %16 = OpLabel + OpReturn + %17 = OpLabel + OpReturn + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, added_as_livesafe_code, context2.get())); +} + +TEST(TransformationAddFunctionTest, KillAndUnreachableInNonVoidFunction) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %8 = OpTypeFunction %2 %7 + %50 = OpTypeFunction %6 %7 + %13 = OpConstant %6 2 + %14 = OpTypeBool + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + + std::vector instructions; + + instructions.push_back(MakeInstructionMessage( + SpvOpFunction, 6, 10, + {{SPV_OPERAND_TYPE_FUNCTION_CONTROL, {SpvFunctionControlMaskNone}}, + {SPV_OPERAND_TYPE_TYPE_ID, {50}}})); + instructions.push_back( + MakeInstructionMessage(SpvOpFunctionParameter, 7, 9, {})); + instructions.push_back(MakeInstructionMessage(SpvOpLabel, 0, 11, {})); + instructions.push_back( + MakeInstructionMessage(SpvOpLoad, 6, 12, {{SPV_OPERAND_TYPE_ID, {9}}})); + instructions.push_back(MakeInstructionMessage( + SpvOpIEqual, 14, 15, + {{SPV_OPERAND_TYPE_ID, {12}}, {SPV_OPERAND_TYPE_ID, {13}}})); + instructions.push_back(MakeInstructionMessage( + SpvOpSelectionMerge, 0, 0, + {{SPV_OPERAND_TYPE_ID, {17}}, + {SPV_OPERAND_TYPE_SELECTION_CONTROL, {SpvSelectionControlMaskNone}}})); + instructions.push_back(MakeInstructionMessage(SpvOpBranchConditional, 0, 0, + {{SPV_OPERAND_TYPE_ID, {15}}, + {SPV_OPERAND_TYPE_ID, {16}}, + {SPV_OPERAND_TYPE_ID, {17}}})); + instructions.push_back(MakeInstructionMessage(SpvOpLabel, 0, 16, {})); + instructions.push_back(MakeInstructionMessage(SpvOpUnreachable, 0, 0, {})); + instructions.push_back(MakeInstructionMessage(SpvOpLabel, 0, 17, {})); + instructions.push_back(MakeInstructionMessage(SpvOpKill, 0, 0, {})); + instructions.push_back(MakeInstructionMessage(SpvOpFunctionEnd, 0, 0, {})); + + FactManager fact_manager; + + const auto context1 = BuildModule(env, consumer, shader, kFuzzAssembleOption); + const auto context2 = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context1.get())); + + TransformationAddFunction add_dead_function(instructions); + ASSERT_TRUE(add_dead_function.IsApplicable(context1.get(), fact_manager)); + add_dead_function.Apply(context1.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context1.get())); + // The added function should not be deemed livesafe. + ASSERT_FALSE(fact_manager.FunctionIsLivesafe(10)); + + std::string added_as_dead_code = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %8 = OpTypeFunction %2 %7 + %50 = OpTypeFunction %6 %7 + %13 = OpConstant %6 2 + %14 = OpTypeBool + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + %10 = OpFunction %6 None %50 + %9 = OpFunctionParameter %7 + %11 = OpLabel + %12 = OpLoad %6 %9 + %15 = OpIEqual %14 %12 %13 + OpSelectionMerge %17 None + OpBranchConditional %15 %16 %17 + %16 = OpLabel + OpUnreachable + %17 = OpLabel + OpKill + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, added_as_dead_code, context1.get())); + + TransformationAddFunction add_livesafe_function(instructions, 0, 0, {}, 13, + {}); + ASSERT_TRUE(add_livesafe_function.IsApplicable(context2.get(), fact_manager)); + add_livesafe_function.Apply(context2.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context2.get())); + // The added function should indeed be deemed livesafe. + ASSERT_TRUE(fact_manager.FunctionIsLivesafe(10)); + std::string added_as_livesafe_code = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %8 = OpTypeFunction %2 %7 + %50 = OpTypeFunction %6 %7 + %13 = OpConstant %6 2 + %14 = OpTypeBool + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + %10 = OpFunction %6 None %50 + %9 = OpFunctionParameter %7 + %11 = OpLabel + %12 = OpLoad %6 %9 + %15 = OpIEqual %14 %12 %13 + OpSelectionMerge %17 None + OpBranchConditional %15 %16 %17 + %16 = OpLabel + OpReturnValue %13 + %17 = OpLabel + OpReturnValue %13 + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, added_as_livesafe_code, context2.get())); +} + +TEST(TransformationAddFunctionTest, ClampedAccessChains) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %100 = OpTypeBool + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %15 = OpTypeInt 32 0 + %102 = OpTypePointer Function %15 + %8 = OpTypeFunction %2 %7 %102 %7 + %16 = OpConstant %15 5 + %17 = OpTypeArray %6 %16 + %18 = OpTypeArray %17 %16 + %19 = OpTypePointer Private %18 + %20 = OpVariable %19 Private + %21 = OpConstant %6 0 + %23 = OpTypePointer Private %6 + %26 = OpTypePointer Function %17 + %29 = OpTypePointer Private %17 + %33 = OpConstant %6 4 + %200 = OpConstant %15 4 + %35 = OpConstant %15 10 + %36 = OpTypeArray %6 %35 + %37 = OpTypePointer Private %36 + %38 = OpVariable %37 Private + %54 = OpTypeFloat 32 + %55 = OpTypeVector %54 4 + %56 = OpTypePointer Private %55 + %57 = OpVariable %56 Private + %59 = OpTypeVector %54 3 + %60 = OpTypeMatrix %59 2 + %61 = OpTypePointer Private %60 + %62 = OpVariable %61 Private + %64 = OpTypePointer Private %54 + %69 = OpConstant %54 2 + %71 = OpConstant %6 1 + %72 = OpConstant %6 2 + %201 = OpConstant %15 2 + %73 = OpConstant %6 3 + %202 = OpConstant %15 3 + %203 = OpConstant %6 1 + %204 = OpConstant %6 9 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + + std::vector instructions; + + instructions.push_back(MakeInstructionMessage( + SpvOpFunction, 2, 12, + {{SPV_OPERAND_TYPE_FUNCTION_CONTROL, {SpvFunctionControlMaskNone}}, + {SPV_OPERAND_TYPE_TYPE_ID, {8}}})); + instructions.push_back( + MakeInstructionMessage(SpvOpFunctionParameter, 7, 9, {})); + instructions.push_back( + MakeInstructionMessage(SpvOpFunctionParameter, 102, 10, {})); + instructions.push_back( + MakeInstructionMessage(SpvOpFunctionParameter, 7, 11, {})); + instructions.push_back(MakeInstructionMessage(SpvOpLabel, 0, 13, {})); + + instructions.push_back(MakeInstructionMessage( + SpvOpVariable, 7, 14, + {{SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}})); + instructions.push_back(MakeInstructionMessage( + SpvOpVariable, 26, 27, + {{SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}})); + instructions.push_back( + MakeInstructionMessage(SpvOpLoad, 6, 22, {{SPV_OPERAND_TYPE_ID, {11}}})); + instructions.push_back(MakeInstructionMessage(SpvOpAccessChain, 23, 24, + {{SPV_OPERAND_TYPE_ID, {20}}, + {SPV_OPERAND_TYPE_ID, {21}}, + {SPV_OPERAND_TYPE_ID, {22}}})); + instructions.push_back( + MakeInstructionMessage(SpvOpLoad, 6, 25, {{SPV_OPERAND_TYPE_ID, {24}}})); + instructions.push_back(MakeInstructionMessage( + SpvOpStore, 0, 0, + {{SPV_OPERAND_TYPE_ID, {14}}, {SPV_OPERAND_TYPE_ID, {25}}})); + instructions.push_back( + MakeInstructionMessage(SpvOpLoad, 15, 28, {{SPV_OPERAND_TYPE_ID, {10}}})); + instructions.push_back(MakeInstructionMessage( + SpvOpAccessChain, 29, 30, + {{SPV_OPERAND_TYPE_ID, {20}}, {SPV_OPERAND_TYPE_ID, {28}}})); + instructions.push_back( + MakeInstructionMessage(SpvOpLoad, 17, 31, {{SPV_OPERAND_TYPE_ID, {30}}})); + instructions.push_back(MakeInstructionMessage( + SpvOpStore, 0, 0, + {{SPV_OPERAND_TYPE_ID, {27}}, {SPV_OPERAND_TYPE_ID, {31}}})); + instructions.push_back( + MakeInstructionMessage(SpvOpLoad, 6, 32, {{SPV_OPERAND_TYPE_ID, {9}}})); + instructions.push_back(MakeInstructionMessage( + SpvOpInBoundsAccessChain, 7, 34, + {{SPV_OPERAND_TYPE_ID, {27}}, {SPV_OPERAND_TYPE_ID, {32}}})); + instructions.push_back(MakeInstructionMessage( + SpvOpStore, 0, 0, + {{SPV_OPERAND_TYPE_ID, {34}}, {SPV_OPERAND_TYPE_ID, {33}}})); + instructions.push_back( + MakeInstructionMessage(SpvOpLoad, 6, 39, {{SPV_OPERAND_TYPE_ID, {9}}})); + instructions.push_back(MakeInstructionMessage( + SpvOpAccessChain, 23, 40, + {{SPV_OPERAND_TYPE_ID, {38}}, {SPV_OPERAND_TYPE_ID, {33}}})); + instructions.push_back( + MakeInstructionMessage(SpvOpLoad, 6, 41, {{SPV_OPERAND_TYPE_ID, {40}}})); + instructions.push_back(MakeInstructionMessage( + SpvOpInBoundsAccessChain, 23, 42, + {{SPV_OPERAND_TYPE_ID, {38}}, {SPV_OPERAND_TYPE_ID, {39}}})); + instructions.push_back(MakeInstructionMessage( + SpvOpStore, 0, 0, + {{SPV_OPERAND_TYPE_ID, {42}}, {SPV_OPERAND_TYPE_ID, {41}}})); + instructions.push_back( + MakeInstructionMessage(SpvOpLoad, 15, 43, {{SPV_OPERAND_TYPE_ID, {10}}})); + instructions.push_back( + MakeInstructionMessage(SpvOpLoad, 6, 44, {{SPV_OPERAND_TYPE_ID, {11}}})); + instructions.push_back( + MakeInstructionMessage(SpvOpLoad, 6, 45, {{SPV_OPERAND_TYPE_ID, {9}}})); + instructions.push_back( + MakeInstructionMessage(SpvOpLoad, 15, 46, {{SPV_OPERAND_TYPE_ID, {10}}})); + instructions.push_back(MakeInstructionMessage( + SpvOpIAdd, 6, 47, + {{SPV_OPERAND_TYPE_ID, {45}}, {SPV_OPERAND_TYPE_ID, {46}}})); + instructions.push_back(MakeInstructionMessage( + SpvOpAccessChain, 23, 48, + {{SPV_OPERAND_TYPE_ID, {38}}, {SPV_OPERAND_TYPE_ID, {47}}})); + instructions.push_back( + MakeInstructionMessage(SpvOpLoad, 6, 49, {{SPV_OPERAND_TYPE_ID, {48}}})); + instructions.push_back(MakeInstructionMessage(SpvOpInBoundsAccessChain, 23, + 50, + {{SPV_OPERAND_TYPE_ID, {20}}, + {SPV_OPERAND_TYPE_ID, {43}}, + {SPV_OPERAND_TYPE_ID, {44}}})); + instructions.push_back( + MakeInstructionMessage(SpvOpLoad, 6, 51, {{SPV_OPERAND_TYPE_ID, {50}}})); + instructions.push_back(MakeInstructionMessage( + SpvOpIAdd, 6, 52, + {{SPV_OPERAND_TYPE_ID, {51}}, {SPV_OPERAND_TYPE_ID, {49}}})); + instructions.push_back(MakeInstructionMessage(SpvOpAccessChain, 23, 53, + {{SPV_OPERAND_TYPE_ID, {20}}, + {SPV_OPERAND_TYPE_ID, {43}}, + {SPV_OPERAND_TYPE_ID, {44}}})); + instructions.push_back(MakeInstructionMessage( + SpvOpStore, 0, 0, + {{SPV_OPERAND_TYPE_ID, {53}}, {SPV_OPERAND_TYPE_ID, {52}}})); + instructions.push_back( + MakeInstructionMessage(SpvOpLoad, 15, 58, {{SPV_OPERAND_TYPE_ID, {10}}})); + instructions.push_back( + MakeInstructionMessage(SpvOpLoad, 6, 63, {{SPV_OPERAND_TYPE_ID, {11}}})); + instructions.push_back(MakeInstructionMessage(SpvOpAccessChain, 64, 65, + {{SPV_OPERAND_TYPE_ID, {62}}, + {SPV_OPERAND_TYPE_ID, {21}}, + {SPV_OPERAND_TYPE_ID, {63}}})); + instructions.push_back(MakeInstructionMessage(SpvOpAccessChain, 64, 101, + {{SPV_OPERAND_TYPE_ID, {62}}, + {SPV_OPERAND_TYPE_ID, {45}}, + {SPV_OPERAND_TYPE_ID, {46}}})); + instructions.push_back( + MakeInstructionMessage(SpvOpLoad, 54, 66, {{SPV_OPERAND_TYPE_ID, {65}}})); + instructions.push_back(MakeInstructionMessage( + SpvOpAccessChain, 64, 67, + {{SPV_OPERAND_TYPE_ID, {57}}, {SPV_OPERAND_TYPE_ID, {58}}})); + instructions.push_back(MakeInstructionMessage( + SpvOpStore, 0, 0, + {{SPV_OPERAND_TYPE_ID, {67}}, {SPV_OPERAND_TYPE_ID, {66}}})); + instructions.push_back( + MakeInstructionMessage(SpvOpLoad, 6, 68, {{SPV_OPERAND_TYPE_ID, {9}}})); + instructions.push_back(MakeInstructionMessage( + SpvOpInBoundsAccessChain, 64, 70, + {{SPV_OPERAND_TYPE_ID, {57}}, {SPV_OPERAND_TYPE_ID, {68}}})); + instructions.push_back(MakeInstructionMessage( + SpvOpStore, 0, 0, + {{SPV_OPERAND_TYPE_ID, {70}}, {SPV_OPERAND_TYPE_ID, {69}}})); + instructions.push_back(MakeInstructionMessage(SpvOpReturn, 0, 0, {})); + instructions.push_back(MakeInstructionMessage(SpvOpFunctionEnd, 0, 0, {})); + + FactManager fact_manager; + + const auto context1 = BuildModule(env, consumer, shader, kFuzzAssembleOption); + const auto context2 = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context1.get())); + + TransformationAddFunction add_dead_function(instructions); + ASSERT_TRUE(add_dead_function.IsApplicable(context1.get(), fact_manager)); + add_dead_function.Apply(context1.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context1.get())); + + std::string added_as_dead_code = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %100 = OpTypeBool + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %15 = OpTypeInt 32 0 + %102 = OpTypePointer Function %15 + %8 = OpTypeFunction %2 %7 %102 %7 + %16 = OpConstant %15 5 + %17 = OpTypeArray %6 %16 + %18 = OpTypeArray %17 %16 + %19 = OpTypePointer Private %18 + %20 = OpVariable %19 Private + %21 = OpConstant %6 0 + %23 = OpTypePointer Private %6 + %26 = OpTypePointer Function %17 + %29 = OpTypePointer Private %17 + %33 = OpConstant %6 4 + %200 = OpConstant %15 4 + %35 = OpConstant %15 10 + %36 = OpTypeArray %6 %35 + %37 = OpTypePointer Private %36 + %38 = OpVariable %37 Private + %54 = OpTypeFloat 32 + %55 = OpTypeVector %54 4 + %56 = OpTypePointer Private %55 + %57 = OpVariable %56 Private + %59 = OpTypeVector %54 3 + %60 = OpTypeMatrix %59 2 + %61 = OpTypePointer Private %60 + %62 = OpVariable %61 Private + %64 = OpTypePointer Private %54 + %69 = OpConstant %54 2 + %71 = OpConstant %6 1 + %72 = OpConstant %6 2 + %201 = OpConstant %15 2 + %73 = OpConstant %6 3 + %202 = OpConstant %15 3 + %203 = OpConstant %6 1 + %204 = OpConstant %6 9 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + %12 = OpFunction %2 None %8 + %9 = OpFunctionParameter %7 + %10 = OpFunctionParameter %102 + %11 = OpFunctionParameter %7 + %13 = OpLabel + %14 = OpVariable %7 Function + %27 = OpVariable %26 Function + %22 = OpLoad %6 %11 + %24 = OpAccessChain %23 %20 %21 %22 + %25 = OpLoad %6 %24 + OpStore %14 %25 + %28 = OpLoad %15 %10 + %30 = OpAccessChain %29 %20 %28 + %31 = OpLoad %17 %30 + OpStore %27 %31 + %32 = OpLoad %6 %9 + %34 = OpInBoundsAccessChain %7 %27 %32 + OpStore %34 %33 + %39 = OpLoad %6 %9 + %40 = OpAccessChain %23 %38 %33 + %41 = OpLoad %6 %40 + %42 = OpInBoundsAccessChain %23 %38 %39 + OpStore %42 %41 + %43 = OpLoad %15 %10 + %44 = OpLoad %6 %11 + %45 = OpLoad %6 %9 + %46 = OpLoad %15 %10 + %47 = OpIAdd %6 %45 %46 + %48 = OpAccessChain %23 %38 %47 + %49 = OpLoad %6 %48 + %50 = OpInBoundsAccessChain %23 %20 %43 %44 + %51 = OpLoad %6 %50 + %52 = OpIAdd %6 %51 %49 + %53 = OpAccessChain %23 %20 %43 %44 + OpStore %53 %52 + %58 = OpLoad %15 %10 + %63 = OpLoad %6 %11 + %65 = OpAccessChain %64 %62 %21 %63 + %101 = OpAccessChain %64 %62 %45 %46 + %66 = OpLoad %54 %65 + %67 = OpAccessChain %64 %57 %58 + OpStore %67 %66 + %68 = OpLoad %6 %9 + %70 = OpInBoundsAccessChain %64 %57 %68 + OpStore %70 %69 + OpReturn + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, added_as_dead_code, context1.get())); + + std::vector access_chain_clamping_info; + access_chain_clamping_info.push_back( + MakeAccessClampingInfo(24, {{1001, 2001}, {1002, 2002}})); + access_chain_clamping_info.push_back( + MakeAccessClampingInfo(30, {{1003, 2003}})); + access_chain_clamping_info.push_back( + MakeAccessClampingInfo(34, {{1004, 2004}})); + access_chain_clamping_info.push_back( + MakeAccessClampingInfo(40, {{1005, 2005}})); + access_chain_clamping_info.push_back( + MakeAccessClampingInfo(42, {{1006, 2006}})); + access_chain_clamping_info.push_back( + MakeAccessClampingInfo(48, {{1007, 2007}})); + access_chain_clamping_info.push_back( + MakeAccessClampingInfo(50, {{1008, 2008}, {1009, 2009}})); + access_chain_clamping_info.push_back( + MakeAccessClampingInfo(53, {{1010, 2010}, {1011, 2011}})); + access_chain_clamping_info.push_back( + MakeAccessClampingInfo(65, {{1012, 2012}, {1013, 2013}})); + access_chain_clamping_info.push_back( + MakeAccessClampingInfo(101, {{1014, 2014}, {1015, 2015}})); + access_chain_clamping_info.push_back( + MakeAccessClampingInfo(67, {{1016, 2016}})); + access_chain_clamping_info.push_back( + MakeAccessClampingInfo(70, {{1017, 2017}})); + + TransformationAddFunction add_livesafe_function(instructions, 0, 0, {}, 13, + access_chain_clamping_info); + ASSERT_TRUE(add_livesafe_function.IsApplicable(context2.get(), fact_manager)); + add_livesafe_function.Apply(context2.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context2.get())); + std::string added_as_livesafe_code = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %100 = OpTypeBool + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypePointer Function %6 + %15 = OpTypeInt 32 0 + %102 = OpTypePointer Function %15 + %8 = OpTypeFunction %2 %7 %102 %7 + %16 = OpConstant %15 5 + %17 = OpTypeArray %6 %16 + %18 = OpTypeArray %17 %16 + %19 = OpTypePointer Private %18 + %20 = OpVariable %19 Private + %21 = OpConstant %6 0 + %23 = OpTypePointer Private %6 + %26 = OpTypePointer Function %17 + %29 = OpTypePointer Private %17 + %33 = OpConstant %6 4 + %200 = OpConstant %15 4 + %35 = OpConstant %15 10 + %36 = OpTypeArray %6 %35 + %37 = OpTypePointer Private %36 + %38 = OpVariable %37 Private + %54 = OpTypeFloat 32 + %55 = OpTypeVector %54 4 + %56 = OpTypePointer Private %55 + %57 = OpVariable %56 Private + %59 = OpTypeVector %54 3 + %60 = OpTypeMatrix %59 2 + %61 = OpTypePointer Private %60 + %62 = OpVariable %61 Private + %64 = OpTypePointer Private %54 + %69 = OpConstant %54 2 + %71 = OpConstant %6 1 + %72 = OpConstant %6 2 + %201 = OpConstant %15 2 + %73 = OpConstant %6 3 + %202 = OpConstant %15 3 + %203 = OpConstant %6 1 + %204 = OpConstant %6 9 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + %12 = OpFunction %2 None %8 + %9 = OpFunctionParameter %7 + %10 = OpFunctionParameter %102 + %11 = OpFunctionParameter %7 + %13 = OpLabel + %14 = OpVariable %7 Function + %27 = OpVariable %26 Function + %22 = OpLoad %6 %11 + %1002 = OpULessThanEqual %100 %22 %33 + %2002 = OpSelect %6 %1002 %22 %33 + %24 = OpAccessChain %23 %20 %21 %2002 + %25 = OpLoad %6 %24 + OpStore %14 %25 + %28 = OpLoad %15 %10 + %1003 = OpULessThanEqual %100 %28 %200 + %2003 = OpSelect %15 %1003 %28 %200 + %30 = OpAccessChain %29 %20 %2003 + %31 = OpLoad %17 %30 + OpStore %27 %31 + %32 = OpLoad %6 %9 + %1004 = OpULessThanEqual %100 %32 %33 + %2004 = OpSelect %6 %1004 %32 %33 + %34 = OpInBoundsAccessChain %7 %27 %2004 + OpStore %34 %33 + %39 = OpLoad %6 %9 + %40 = OpAccessChain %23 %38 %33 + %41 = OpLoad %6 %40 + %1006 = OpULessThanEqual %100 %39 %204 + %2006 = OpSelect %6 %1006 %39 %204 + %42 = OpInBoundsAccessChain %23 %38 %2006 + OpStore %42 %41 + %43 = OpLoad %15 %10 + %44 = OpLoad %6 %11 + %45 = OpLoad %6 %9 + %46 = OpLoad %15 %10 + %47 = OpIAdd %6 %45 %46 + %1007 = OpULessThanEqual %100 %47 %204 + %2007 = OpSelect %6 %1007 %47 %204 + %48 = OpAccessChain %23 %38 %2007 + %49 = OpLoad %6 %48 + %1008 = OpULessThanEqual %100 %43 %200 + %2008 = OpSelect %15 %1008 %43 %200 + %1009 = OpULessThanEqual %100 %44 %33 + %2009 = OpSelect %6 %1009 %44 %33 + %50 = OpInBoundsAccessChain %23 %20 %2008 %2009 + %51 = OpLoad %6 %50 + %52 = OpIAdd %6 %51 %49 + %1010 = OpULessThanEqual %100 %43 %200 + %2010 = OpSelect %15 %1010 %43 %200 + %1011 = OpULessThanEqual %100 %44 %33 + %2011 = OpSelect %6 %1011 %44 %33 + %53 = OpAccessChain %23 %20 %2010 %2011 + OpStore %53 %52 + %58 = OpLoad %15 %10 + %63 = OpLoad %6 %11 + %1013 = OpULessThanEqual %100 %63 %72 + %2013 = OpSelect %6 %1013 %63 %72 + %65 = OpAccessChain %64 %62 %21 %2013 + %1014 = OpULessThanEqual %100 %45 %71 + %2014 = OpSelect %6 %1014 %45 %71 + %1015 = OpULessThanEqual %100 %46 %201 + %2015 = OpSelect %15 %1015 %46 %201 + %101 = OpAccessChain %64 %62 %2014 %2015 + %66 = OpLoad %54 %65 + %1016 = OpULessThanEqual %100 %58 %202 + %2016 = OpSelect %15 %1016 %58 %202 + %67 = OpAccessChain %64 %57 %2016 + OpStore %67 %66 + %68 = OpLoad %6 %9 + %1017 = OpULessThanEqual %100 %68 %73 + %2017 = OpSelect %6 %1017 %68 %73 + %70 = OpInBoundsAccessChain %64 %57 %2017 + OpStore %70 %69 + OpReturn + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, added_as_livesafe_code, context2.get())); +} + +TEST(TransformationAddFunctionTest, LivesafeCanCallLivesafe) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + + std::vector instructions; + + instructions.push_back(MakeInstructionMessage( + SpvOpFunction, 2, 8, + {{SPV_OPERAND_TYPE_FUNCTION_CONTROL, {SpvFunctionControlMaskNone}}, + {SPV_OPERAND_TYPE_TYPE_ID, {3}}})); + instructions.push_back(MakeInstructionMessage(SpvOpLabel, 0, 9, {})); + instructions.push_back(MakeInstructionMessage(SpvOpFunctionCall, 2, 11, + {{SPV_OPERAND_TYPE_ID, {6}}})); + instructions.push_back(MakeInstructionMessage(SpvOpReturn, 0, 0, {})); + instructions.push_back(MakeInstructionMessage(SpvOpFunctionEnd, 0, 0, {})); + + FactManager fact_manager1; + FactManager fact_manager2; + + // Mark function 6 as livesafe. + fact_manager2.AddFactFunctionIsLivesafe(6); + + const auto context1 = BuildModule(env, consumer, shader, kFuzzAssembleOption); + const auto context2 = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context1.get())); + + TransformationAddFunction add_dead_function(instructions); + ASSERT_TRUE(add_dead_function.IsApplicable(context1.get(), fact_manager1)); + add_dead_function.Apply(context1.get(), &fact_manager1); + ASSERT_TRUE(IsValid(env, context1.get())); + + std::string added_as_live_or_dead_code = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + OpReturn + OpFunctionEnd + %8 = OpFunction %2 None %3 + %9 = OpLabel + %11 = OpFunctionCall %2 %6 + OpReturn + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, added_as_live_or_dead_code, context1.get())); + + TransformationAddFunction add_livesafe_function(instructions, 0, 0, {}, 0, + {}); + ASSERT_TRUE( + add_livesafe_function.IsApplicable(context2.get(), fact_manager2)); + add_livesafe_function.Apply(context2.get(), &fact_manager2); + ASSERT_TRUE(IsValid(env, context2.get())); + ASSERT_TRUE(IsEqual(env, added_as_live_or_dead_code, context2.get())); +} + +TEST(TransformationAddFunctionTest, LivesafeOnlyCallsLivesafe) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + OpKill + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + + std::vector instructions; + + instructions.push_back(MakeInstructionMessage( + SpvOpFunction, 2, 8, + {{SPV_OPERAND_TYPE_FUNCTION_CONTROL, {SpvFunctionControlMaskNone}}, + {SPV_OPERAND_TYPE_TYPE_ID, {3}}})); + instructions.push_back(MakeInstructionMessage(SpvOpLabel, 0, 9, {})); + instructions.push_back(MakeInstructionMessage(SpvOpFunctionCall, 2, 11, + {{SPV_OPERAND_TYPE_ID, {6}}})); + instructions.push_back(MakeInstructionMessage(SpvOpReturn, 0, 0, {})); + instructions.push_back(MakeInstructionMessage(SpvOpFunctionEnd, 0, 0, {})); + + FactManager fact_manager; + + const auto context1 = BuildModule(env, consumer, shader, kFuzzAssembleOption); + const auto context2 = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context1.get())); + + TransformationAddFunction add_dead_function(instructions); + ASSERT_TRUE(add_dead_function.IsApplicable(context1.get(), fact_manager)); + add_dead_function.Apply(context1.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context1.get())); + + std::string added_as_dead_code = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + OpKill + OpFunctionEnd + %8 = OpFunction %2 None %3 + %9 = OpLabel + %11 = OpFunctionCall %2 %6 + OpReturn + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, added_as_dead_code, context1.get())); + + TransformationAddFunction add_livesafe_function(instructions, 0, 0, {}, 0, + {}); + ASSERT_FALSE( + add_livesafe_function.IsApplicable(context2.get(), fact_manager)); +} + +TEST(TransformationAddFunctionTest, + LoopLimitersBackEdgeBlockEndsWithConditional1) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %11 = OpConstant %8 0 + %18 = OpConstant %8 10 + %19 = OpTypeBool + %26 = OpConstantTrue %19 + %27 = OpConstantFalse %19 + %28 = OpTypeInt 32 0 + %29 = OpTypePointer Function %28 + %30 = OpConstant %28 0 + %31 = OpConstant %28 1 + %32 = OpConstant %28 5 + %22 = OpConstant %8 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::string donor = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %11 = OpConstant %8 0 + %18 = OpConstant %8 10 + %19 = OpTypeBool + %26 = OpConstantTrue %19 + %27 = OpConstantFalse %19 + %28 = OpTypeInt 32 0 + %29 = OpTypePointer Function %28 + %30 = OpConstant %28 0 + %31 = OpConstant %28 1 + %32 = OpConstant %28 5 + %22 = OpConstant %8 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + %10 = OpVariable %9 Function + OpStore %10 %11 + OpBranch %12 + %12 = OpLabel + OpLoopMerge %14 %15 None + OpBranch %15 + %15 = OpLabel + %17 = OpLoad %8 %10 + %20 = OpSLessThan %19 %17 %18 + %21 = OpLoad %8 %10 + %23 = OpIAdd %8 %21 %22 + OpStore %10 %23 + OpBranchConditional %20 %12 %14 + %14 = OpLabel + OpReturn + OpFunctionEnd + )"; + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + + FactManager fact_manager; + + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + // Make a sequence of instruction messages corresponding to function %8 in + // |donor|. + std::vector instructions = + GetInstructionsForFunction(env, consumer, donor, 6); + + protobufs::LoopLimiterInfo loop_limiter_info; + loop_limiter_info.set_loop_header_id(12); + loop_limiter_info.set_load_id(102); + loop_limiter_info.set_increment_id(103); + loop_limiter_info.set_compare_id(104); + loop_limiter_info.set_logical_op_id(105); + TransformationAddFunction add_livesafe_function(instructions, 100, 32, + {loop_limiter_info}, 0, {}); + ASSERT_TRUE(add_livesafe_function.IsApplicable(context.get(), fact_manager)); + add_livesafe_function.Apply(context.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context.get())); + std::string expected = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %11 = OpConstant %8 0 + %18 = OpConstant %8 10 + %19 = OpTypeBool + %26 = OpConstantTrue %19 + %27 = OpConstantFalse %19 + %28 = OpTypeInt 32 0 + %29 = OpTypePointer Function %28 + %30 = OpConstant %28 0 + %31 = OpConstant %28 1 + %32 = OpConstant %28 5 + %22 = OpConstant %8 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + %100 = OpVariable %29 Function %30 + %10 = OpVariable %9 Function + OpStore %10 %11 + OpBranch %12 + %12 = OpLabel + OpLoopMerge %14 %15 None + OpBranch %15 + %15 = OpLabel + %17 = OpLoad %8 %10 + %20 = OpSLessThan %19 %17 %18 + %21 = OpLoad %8 %10 + %23 = OpIAdd %8 %21 %22 + OpStore %10 %23 + %102 = OpLoad %28 %100 + %103 = OpIAdd %28 %102 %31 + OpStore %100 %103 + %104 = OpULessThan %19 %102 %32 + %105 = OpLogicalAnd %19 %20 %104 + OpBranchConditional %105 %12 %14 + %14 = OpLabel + OpReturn + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, expected, context.get())); +} + +TEST(TransformationAddFunctionTest, + LoopLimitersBackEdgeBlockEndsWithConditional2) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %11 = OpConstant %8 0 + %18 = OpConstant %8 10 + %19 = OpTypeBool + %26 = OpConstantTrue %19 + %27 = OpConstantFalse %19 + %28 = OpTypeInt 32 0 + %29 = OpTypePointer Function %28 + %30 = OpConstant %28 0 + %31 = OpConstant %28 1 + %32 = OpConstant %28 5 + %22 = OpConstant %8 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::string donor = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %11 = OpConstant %8 0 + %18 = OpConstant %8 10 + %19 = OpTypeBool + %26 = OpConstantTrue %19 + %27 = OpConstantFalse %19 + %28 = OpTypeInt 32 0 + %29 = OpTypePointer Function %28 + %30 = OpConstant %28 0 + %31 = OpConstant %28 1 + %32 = OpConstant %28 5 + %22 = OpConstant %8 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + %10 = OpVariable %9 Function + OpStore %10 %11 + OpBranch %12 + %12 = OpLabel + OpLoopMerge %14 %15 None + OpBranch %15 + %15 = OpLabel + %17 = OpLoad %8 %10 + %20 = OpSLessThan %19 %17 %18 + %21 = OpLoad %8 %10 + %23 = OpIAdd %8 %21 %22 + OpStore %10 %23 + %50 = OpLogicalNot %19 %20 + OpBranchConditional %50 %14 %12 + %14 = OpLabel + OpReturn + OpFunctionEnd + )"; + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + + FactManager fact_manager; + + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + // Make a sequence of instruction messages corresponding to function %8 in + // |donor|. + std::vector instructions = + GetInstructionsForFunction(env, consumer, donor, 6); + + protobufs::LoopLimiterInfo loop_limiter_info; + loop_limiter_info.set_loop_header_id(12); + loop_limiter_info.set_load_id(102); + loop_limiter_info.set_increment_id(103); + loop_limiter_info.set_compare_id(104); + loop_limiter_info.set_logical_op_id(105); + TransformationAddFunction add_livesafe_function(instructions, 100, 32, + {loop_limiter_info}, 0, {}); + ASSERT_TRUE(add_livesafe_function.IsApplicable(context.get(), fact_manager)); + add_livesafe_function.Apply(context.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context.get())); + std::string expected = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %11 = OpConstant %8 0 + %18 = OpConstant %8 10 + %19 = OpTypeBool + %26 = OpConstantTrue %19 + %27 = OpConstantFalse %19 + %28 = OpTypeInt 32 0 + %29 = OpTypePointer Function %28 + %30 = OpConstant %28 0 + %31 = OpConstant %28 1 + %32 = OpConstant %28 5 + %22 = OpConstant %8 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + %100 = OpVariable %29 Function %30 + %10 = OpVariable %9 Function + OpStore %10 %11 + OpBranch %12 + %12 = OpLabel + OpLoopMerge %14 %15 None + OpBranch %15 + %15 = OpLabel + %17 = OpLoad %8 %10 + %20 = OpSLessThan %19 %17 %18 + %21 = OpLoad %8 %10 + %23 = OpIAdd %8 %21 %22 + OpStore %10 %23 + %50 = OpLogicalNot %19 %20 + %102 = OpLoad %28 %100 + %103 = OpIAdd %28 %102 %31 + OpStore %100 %103 + %104 = OpUGreaterThanEqual %19 %102 %32 + %105 = OpLogicalOr %19 %50 %104 + OpBranchConditional %105 %14 %12 + %14 = OpLabel + OpReturn + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, expected, context.get())); +} + +TEST(TransformationAddFunctionTest, LoopLimitersHeaderIsBackEdgeBlock) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %11 = OpConstant %8 0 + %18 = OpConstant %8 10 + %19 = OpTypeBool + %26 = OpConstantTrue %19 + %27 = OpConstantFalse %19 + %28 = OpTypeInt 32 0 + %29 = OpTypePointer Function %28 + %30 = OpConstant %28 0 + %31 = OpConstant %28 1 + %32 = OpConstant %28 5 + %22 = OpConstant %8 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::string donor = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %11 = OpConstant %8 0 + %18 = OpConstant %8 10 + %19 = OpTypeBool + %26 = OpConstantTrue %19 + %27 = OpConstantFalse %19 + %28 = OpTypeInt 32 0 + %29 = OpTypePointer Function %28 + %30 = OpConstant %28 0 + %31 = OpConstant %28 1 + %32 = OpConstant %28 5 + %22 = OpConstant %8 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + %10 = OpVariable %9 Function + OpStore %10 %11 + OpBranch %12 + %12 = OpLabel + %17 = OpLoad %8 %10 + %20 = OpSLessThan %19 %17 %18 + %21 = OpLoad %8 %10 + %23 = OpIAdd %8 %21 %22 + OpStore %10 %23 + %50 = OpLogicalNot %19 %20 + OpLoopMerge %14 %12 None + OpBranchConditional %50 %14 %12 + %14 = OpLabel + OpReturn + OpFunctionEnd + )"; + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + + FactManager fact_manager; + + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + // Make a sequence of instruction messages corresponding to function %8 in + // |donor|. + std::vector instructions = + GetInstructionsForFunction(env, consumer, donor, 6); + + protobufs::LoopLimiterInfo loop_limiter_info; + loop_limiter_info.set_loop_header_id(12); + loop_limiter_info.set_load_id(102); + loop_limiter_info.set_increment_id(103); + loop_limiter_info.set_compare_id(104); + loop_limiter_info.set_logical_op_id(105); + TransformationAddFunction add_livesafe_function(instructions, 100, 32, + {loop_limiter_info}, 0, {}); + ASSERT_TRUE(add_livesafe_function.IsApplicable(context.get(), fact_manager)); + add_livesafe_function.Apply(context.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context.get())); + std::string expected = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %11 = OpConstant %8 0 + %18 = OpConstant %8 10 + %19 = OpTypeBool + %26 = OpConstantTrue %19 + %27 = OpConstantFalse %19 + %28 = OpTypeInt 32 0 + %29 = OpTypePointer Function %28 + %30 = OpConstant %28 0 + %31 = OpConstant %28 1 + %32 = OpConstant %28 5 + %22 = OpConstant %8 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + %100 = OpVariable %29 Function %30 + %10 = OpVariable %9 Function + OpStore %10 %11 + OpBranch %12 + %12 = OpLabel + %17 = OpLoad %8 %10 + %20 = OpSLessThan %19 %17 %18 + %21 = OpLoad %8 %10 + %23 = OpIAdd %8 %21 %22 + OpStore %10 %23 + %50 = OpLogicalNot %19 %20 + %102 = OpLoad %28 %100 + %103 = OpIAdd %28 %102 %31 + OpStore %100 %103 + %104 = OpUGreaterThanEqual %19 %102 %32 + %105 = OpLogicalOr %19 %50 %104 + OpLoopMerge %14 %12 None + OpBranchConditional %105 %14 %12 + %14 = OpLabel + OpReturn + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, expected, context.get())); +} + +TEST(TransformationAddFunctionTest, InfiniteLoop1) { + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %11 = OpConstant %8 0 + %18 = OpConstant %8 10 + %19 = OpTypeBool + %26 = OpConstantTrue %19 + %27 = OpConstantFalse %19 + %28 = OpTypeInt 32 0 + %29 = OpTypePointer Function %28 + %30 = OpConstant %28 0 + %31 = OpConstant %28 1 + %32 = OpConstant %28 5 + %22 = OpConstant %8 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::string donor = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %11 = OpConstant %8 0 + %18 = OpConstant %8 10 + %19 = OpTypeBool + %26 = OpConstantTrue %19 + %27 = OpConstantFalse %19 + %28 = OpTypeInt 32 0 + %29 = OpTypePointer Function %28 + %30 = OpConstant %28 0 + %31 = OpConstant %28 1 + %32 = OpConstant %28 5 + %22 = OpConstant %8 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + %10 = OpVariable %9 Function + OpStore %10 %11 + OpBranch %12 + %12 = OpLabel + OpLoopMerge %14 %12 None + OpBranch %12 + %14 = OpLabel + OpReturn + OpFunctionEnd + )"; + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + + FactManager fact_manager; + + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + // Make a sequence of instruction messages corresponding to function %8 in + // |donor|. + std::vector instructions = + GetInstructionsForFunction(env, consumer, donor, 6); + + protobufs::LoopLimiterInfo loop_limiter_info; + loop_limiter_info.set_loop_header_id(12); + loop_limiter_info.set_load_id(102); + loop_limiter_info.set_increment_id(103); + loop_limiter_info.set_compare_id(104); + loop_limiter_info.set_logical_op_id(105); + TransformationAddFunction add_livesafe_function(instructions, 100, 32, + {loop_limiter_info}, 0, {}); + ASSERT_TRUE(add_livesafe_function.IsApplicable(context.get(), fact_manager)); + add_livesafe_function.Apply(context.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context.get())); + std::string expected = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %11 = OpConstant %8 0 + %18 = OpConstant %8 10 + %19 = OpTypeBool + %26 = OpConstantTrue %19 + %27 = OpConstantFalse %19 + %28 = OpTypeInt 32 0 + %29 = OpTypePointer Function %28 + %30 = OpConstant %28 0 + %31 = OpConstant %28 1 + %32 = OpConstant %28 5 + %22 = OpConstant %8 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + %100 = OpVariable %29 Function %30 + %10 = OpVariable %9 Function + OpStore %10 %11 + OpBranch %12 + %12 = OpLabel + %102 = OpLoad %28 %100 + %103 = OpIAdd %28 %102 %31 + OpStore %100 %103 + %104 = OpUGreaterThanEqual %19 %102 %32 + OpLoopMerge %14 %12 None + OpBranchConditional %104 %14 %12 + %14 = OpLabel + OpReturn + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, expected, context.get())); +} + +TEST(TransformationAddFunctionTest, UnreachableContinueConstruct) { + // This captures the case where the loop's continue construct is statically + // unreachable. In this case the loop cannot iterate and so we do not add + // a loop limiter. (The reason we do not just add one anyway is that + // detecting which block would be the back-edge block is difficult in the + // absence of reliable dominance information.) + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %11 = OpConstant %8 0 + %18 = OpConstant %8 10 + %19 = OpTypeBool + %23 = OpConstant %8 1 + %26 = OpConstantTrue %19 + %27 = OpConstantFalse %19 + %28 = OpTypeInt 32 0 + %29 = OpTypePointer Function %28 + %30 = OpConstant %28 0 + %31 = OpConstant %28 1 + %32 = OpConstant %28 5 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::string donor = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %11 = OpConstant %8 0 + %18 = OpConstant %8 10 + %19 = OpTypeBool + %23 = OpConstant %8 1 + %26 = OpConstantTrue %19 + %27 = OpConstantFalse %19 + %28 = OpTypeInt 32 0 + %29 = OpTypePointer Function %28 + %30 = OpConstant %28 0 + %31 = OpConstant %28 1 + %32 = OpConstant %28 5 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + %10 = OpVariable %9 Function + OpStore %10 %11 + OpBranch %12 + %12 = OpLabel + OpLoopMerge %14 %15 None + OpBranch %16 + %16 = OpLabel + %17 = OpLoad %8 %10 + %20 = OpSLessThan %19 %17 %18 + OpBranchConditional %20 %13 %14 + %13 = OpLabel + OpBranch %14 + %15 = OpLabel + %22 = OpLoad %8 %10 + %24 = OpIAdd %8 %22 %23 + OpStore %10 %24 + OpBranch %12 + %14 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + + FactManager fact_manager; + + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + // Make a sequence of instruction messages corresponding to function %8 in + // |donor|. + std::vector instructions = + GetInstructionsForFunction(env, consumer, donor, 6); + + protobufs::LoopLimiterInfo loop_limiter_info; + loop_limiter_info.set_loop_header_id(12); + loop_limiter_info.set_load_id(102); + loop_limiter_info.set_increment_id(103); + loop_limiter_info.set_compare_id(104); + loop_limiter_info.set_logical_op_id(105); + TransformationAddFunction add_livesafe_function(instructions, 100, 32, + {loop_limiter_info}, 0, {}); + ASSERT_TRUE(add_livesafe_function.IsApplicable(context.get(), fact_manager)); + add_livesafe_function.Apply(context.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context.get())); + std::string expected = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %8 = OpTypeInt 32 1 + %9 = OpTypePointer Function %8 + %11 = OpConstant %8 0 + %18 = OpConstant %8 10 + %19 = OpTypeBool + %23 = OpConstant %8 1 + %26 = OpConstantTrue %19 + %27 = OpConstantFalse %19 + %28 = OpTypeInt 32 0 + %29 = OpTypePointer Function %28 + %30 = OpConstant %28 0 + %31 = OpConstant %28 1 + %32 = OpConstant %28 5 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + %6 = OpFunction %2 None %3 + %7 = OpLabel + %100 = OpVariable %29 Function %30 + %10 = OpVariable %9 Function + OpStore %10 %11 + OpBranch %12 + %12 = OpLabel + OpLoopMerge %14 %15 None + OpBranch %16 + %16 = OpLabel + %17 = OpLoad %8 %10 + %20 = OpSLessThan %19 %17 %18 + OpBranchConditional %20 %13 %14 + %13 = OpLabel + OpBranch %14 + %15 = OpLabel + %22 = OpLoad %8 %10 + %24 = OpIAdd %8 %22 %23 + OpStore %10 %24 + OpBranch %12 + %14 = OpLabel + OpReturn + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, expected, context.get())); +} + +TEST(TransformationAddFunctionTest, LoopLimitersAndOpPhi1) { + // This captures the scenario where breaking a loop due to a loop limiter + // requires patching up OpPhi instructions occurring at the loop merge block. + + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %50 = OpTypeInt 32 0 + %51 = OpConstant %50 0 + %52 = OpConstant %50 1 + %53 = OpTypePointer Function %50 + %7 = OpTypeFunction %6 + %10 = OpTypePointer Function %6 + %12 = OpConstant %6 0 + %19 = OpConstant %6 100 + %20 = OpTypeBool + %23 = OpConstant %6 20 + %28 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::string donor = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %7 = OpTypeFunction %6 + %10 = OpTypePointer Function %6 + %12 = OpConstant %6 0 + %19 = OpConstant %6 100 + %20 = OpTypeBool + %23 = OpConstant %6 20 + %28 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %36 = OpFunctionCall %6 %8 + OpReturn + OpFunctionEnd + %8 = OpFunction %6 None %7 + %9 = OpLabel + %11 = OpVariable %10 Function + OpStore %11 %12 + OpBranch %13 + %13 = OpLabel + %37 = OpPhi %6 %12 %9 %32 %16 + OpLoopMerge %15 %16 None + OpBranch %17 + %17 = OpLabel + %21 = OpSLessThan %20 %37 %19 + OpBranchConditional %21 %14 %15 + %14 = OpLabel + %24 = OpSGreaterThan %20 %37 %23 + OpSelectionMerge %26 None + OpBranchConditional %24 %25 %26 + %25 = OpLabel + %29 = OpIAdd %6 %37 %28 + OpStore %11 %29 + OpBranch %15 + %26 = OpLabel + OpBranch %16 + %16 = OpLabel + %32 = OpIAdd %6 %37 %28 + OpStore %11 %32 + OpBranch %13 + %15 = OpLabel + %38 = OpPhi %6 %37 %17 %29 %25 + OpReturnValue %38 + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + + FactManager fact_manager; + + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + // Make a sequence of instruction messages corresponding to function %8 in + // |donor|. + std::vector instructions = + GetInstructionsForFunction(env, consumer, donor, 8); + + protobufs::LoopLimiterInfo loop_limiter_info; + loop_limiter_info.set_loop_header_id(13); + loop_limiter_info.set_load_id(102); + loop_limiter_info.set_increment_id(103); + loop_limiter_info.set_compare_id(104); + loop_limiter_info.set_logical_op_id(105); + + TransformationAddFunction no_op_phi_data(instructions, 100, 28, + {loop_limiter_info}, 0, {}); + // The loop limiter info is not good enough; it does not include ids to patch + // up the OpPhi at the loop merge. + ASSERT_FALSE(no_op_phi_data.IsApplicable(context.get(), fact_manager)); + + // Add a phi id for the new edge from the loop back edge block to the loop + // merge. + loop_limiter_info.add_phi_id(28); + TransformationAddFunction with_op_phi_data(instructions, 100, 28, + {loop_limiter_info}, 0, {}); + ASSERT_TRUE(with_op_phi_data.IsApplicable(context.get(), fact_manager)); + with_op_phi_data.Apply(context.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context.get())); + std::string expected = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %50 = OpTypeInt 32 0 + %51 = OpConstant %50 0 + %52 = OpConstant %50 1 + %53 = OpTypePointer Function %50 + %7 = OpTypeFunction %6 + %10 = OpTypePointer Function %6 + %12 = OpConstant %6 0 + %19 = OpConstant %6 100 + %20 = OpTypeBool + %23 = OpConstant %6 20 + %28 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + %8 = OpFunction %6 None %7 + %9 = OpLabel + %100 = OpVariable %53 Function %51 + %11 = OpVariable %10 Function + OpStore %11 %12 + OpBranch %13 + %13 = OpLabel + %37 = OpPhi %6 %12 %9 %32 %16 + OpLoopMerge %15 %16 None + OpBranch %17 + %17 = OpLabel + %21 = OpSLessThan %20 %37 %19 + OpBranchConditional %21 %14 %15 + %14 = OpLabel + %24 = OpSGreaterThan %20 %37 %23 + OpSelectionMerge %26 None + OpBranchConditional %24 %25 %26 + %25 = OpLabel + %29 = OpIAdd %6 %37 %28 + OpStore %11 %29 + OpBranch %15 + %26 = OpLabel + OpBranch %16 + %16 = OpLabel + %32 = OpIAdd %6 %37 %28 + OpStore %11 %32 + %102 = OpLoad %50 %100 + %103 = OpIAdd %50 %102 %52 + OpStore %100 %103 + %104 = OpUGreaterThanEqual %20 %102 %28 + OpBranchConditional %104 %15 %13 + %15 = OpLabel + %38 = OpPhi %6 %37 %17 %29 %25 %28 %16 + OpReturnValue %38 + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, expected, context.get())); +} + +TEST(TransformationAddFunctionTest, LoopLimitersAndOpPhi2) { + // This captures the scenario where the loop merge block already has an OpPhi + // with the loop back edge block as a predecessor. + + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %50 = OpTypeInt 32 0 + %51 = OpConstant %50 0 + %52 = OpConstant %50 1 + %53 = OpTypePointer Function %50 + %7 = OpTypeFunction %6 + %10 = OpTypePointer Function %6 + %12 = OpConstant %6 0 + %19 = OpConstant %6 100 + %20 = OpTypeBool + %60 = OpConstantTrue %20 + %23 = OpConstant %6 20 + %28 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::string donor = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %50 = OpTypeInt 32 0 + %51 = OpConstant %50 0 + %52 = OpConstant %50 1 + %53 = OpTypePointer Function %50 + %7 = OpTypeFunction %6 + %10 = OpTypePointer Function %6 + %12 = OpConstant %6 0 + %19 = OpConstant %6 100 + %20 = OpTypeBool + %60 = OpConstantTrue %20 + %23 = OpConstant %6 20 + %28 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + %8 = OpFunction %6 None %7 + %9 = OpLabel + %11 = OpVariable %10 Function + OpStore %11 %12 + OpBranch %13 + %13 = OpLabel + %37 = OpPhi %6 %12 %9 %32 %16 + OpLoopMerge %15 %16 None + OpBranch %17 + %17 = OpLabel + %21 = OpSLessThan %20 %37 %19 + OpBranchConditional %21 %14 %15 + %14 = OpLabel + %24 = OpSGreaterThan %20 %37 %23 + OpSelectionMerge %26 None + OpBranchConditional %24 %25 %26 + %25 = OpLabel + %29 = OpIAdd %6 %37 %28 + OpStore %11 %29 + OpBranch %15 + %26 = OpLabel + OpBranch %16 + %16 = OpLabel + %32 = OpIAdd %6 %37 %28 + OpStore %11 %32 + OpBranchConditional %60 %15 %13 + %15 = OpLabel + %38 = OpPhi %6 %37 %17 %29 %25 %23 %16 + OpReturnValue %38 + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_4; + const auto consumer = nullptr; + + FactManager fact_manager; + + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + // Make a sequence of instruction messages corresponding to function %8 in + // |donor|. + std::vector instructions = + GetInstructionsForFunction(env, consumer, donor, 8); + + protobufs::LoopLimiterInfo loop_limiter_info; + loop_limiter_info.set_loop_header_id(13); + loop_limiter_info.set_load_id(102); + loop_limiter_info.set_increment_id(103); + loop_limiter_info.set_compare_id(104); + loop_limiter_info.set_logical_op_id(105); + + TransformationAddFunction transformation(instructions, 100, 28, + {loop_limiter_info}, 0, {}); + ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager)); + transformation.Apply(context.get(), &fact_manager); + ASSERT_TRUE(IsValid(env, context.get())); + std::string expected = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeInt 32 1 + %50 = OpTypeInt 32 0 + %51 = OpConstant %50 0 + %52 = OpConstant %50 1 + %53 = OpTypePointer Function %50 + %7 = OpTypeFunction %6 + %10 = OpTypePointer Function %6 + %12 = OpConstant %6 0 + %19 = OpConstant %6 100 + %20 = OpTypeBool + %60 = OpConstantTrue %20 + %23 = OpConstant %6 20 + %28 = OpConstant %6 1 + %4 = OpFunction %2 None %3 + %5 = OpLabel + OpReturn + OpFunctionEnd + %8 = OpFunction %6 None %7 + %9 = OpLabel + %100 = OpVariable %53 Function %51 + %11 = OpVariable %10 Function + OpStore %11 %12 + OpBranch %13 + %13 = OpLabel + %37 = OpPhi %6 %12 %9 %32 %16 + OpLoopMerge %15 %16 None + OpBranch %17 + %17 = OpLabel + %21 = OpSLessThan %20 %37 %19 + OpBranchConditional %21 %14 %15 + %14 = OpLabel + %24 = OpSGreaterThan %20 %37 %23 + OpSelectionMerge %26 None + OpBranchConditional %24 %25 %26 + %25 = OpLabel + %29 = OpIAdd %6 %37 %28 + OpStore %11 %29 + OpBranch %15 + %26 = OpLabel + OpBranch %16 + %16 = OpLabel + %32 = OpIAdd %6 %37 %28 + OpStore %11 %32 + %102 = OpLoad %50 %100 + %103 = OpIAdd %50 %102 %52 + OpStore %100 %103 + %104 = OpUGreaterThanEqual %20 %102 %28 + %105 = OpLogicalOr %20 %60 %104 + OpBranchConditional %105 %15 %13 + %15 = OpLabel + %38 = OpPhi %6 %37 %17 %29 %25 %23 %16 + OpReturnValue %38 + OpFunctionEnd + )"; + ASSERT_TRUE(IsEqual(env, expected, context.get())); +} + } // namespace } // namespace fuzz } // namespace spvtools diff --git a/test/fuzz/transformation_outline_function_test.cpp b/test/fuzz/transformation_outline_function_test.cpp index 5cd1437675..91e17337f4 100644 --- a/test/fuzz/transformation_outline_function_test.cpp +++ b/test/fuzz/transformation_outline_function_test.cpp @@ -1656,6 +1656,64 @@ TEST(TransformationOutlineFunctionTest, DoNotOutlineRegionThatUsesAccessChain) { ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); } +TEST(TransformationOutlineFunctionTest, + DoNotOutlineRegionThatUsesCopiedObject) { + // Copying a variable leads to a pointer, but one that cannot be passed as a + // function parameter, as it is not a memory object. + std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %4 "main" + OpExecutionMode %4 OriginUpperLeft + OpSource ESSL 310 + OpName %4 "main" + %2 = OpTypeVoid + %3 = OpTypeFunction %2 + %6 = OpTypeFloat 32 + %7 = OpTypeVector %6 4 + %8 = OpTypePointer Function %7 + %9 = OpTypePointer Function %6 + %18 = OpTypeInt 32 0 + %19 = OpConstant %18 0 + %4 = OpFunction %2 None %3 + %5 = OpLabel + %10 = OpVariable %8 Function + OpBranch %11 + %11 = OpLabel + %20 = OpCopyObject %8 %10 + OpBranch %13 + %13 = OpLabel + %12 = OpAccessChain %9 %20 %19 + %14 = OpLoad %6 %12 + OpBranch %15 + %15 = OpLabel + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_5; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + + TransformationOutlineFunction transformation( + /*entry_block*/ 13, + /*exit_block*/ 15, + /*new_function_struct_return_type_id*/ 200, + /*new_function_type_id*/ 201, + /*new_function_id*/ 202, + /*new_function_region_entry_block*/ 204, + /*new_caller_result_id*/ 205, + /*new_callee_result_id*/ 206, + /*input_id_to_fresh_id*/ {{20, 207}}, + /*output_id_to_fresh_id*/ {}); + + ASSERT_FALSE(transformation.IsApplicable(context.get(), fact_manager)); +} + TEST(TransformationOutlineFunctionTest, Miscellaneous1) { // This tests outlining of some non-trivial code.