diff --git a/source/opt/desc_sroa.cpp b/source/opt/desc_sroa.cpp index 8da0c864fe..2c0f4829f2 100644 --- a/source/opt/desc_sroa.cpp +++ b/source/opt/desc_sroa.cpp @@ -54,9 +54,10 @@ Pass::Status DescriptorScalarReplacement::Process() { bool DescriptorScalarReplacement::ReplaceCandidate(Instruction* var) { std::vector access_chain_work_list; std::vector load_work_list; + std::vector entry_point_work_list; bool failed = !get_def_use_mgr()->WhileEachUser( - var->result_id(), - [this, &access_chain_work_list, &load_work_list](Instruction* use) { + var->result_id(), [this, &access_chain_work_list, &load_work_list, + &entry_point_work_list](Instruction* use) { if (use->opcode() == spv::Op::OpName) { return true; } @@ -73,6 +74,9 @@ bool DescriptorScalarReplacement::ReplaceCandidate(Instruction* var) { case spv::Op::OpLoad: load_work_list.push_back(use); return true; + case spv::Op::OpEntryPoint: + entry_point_work_list.push_back(use); + return true; default: context()->EmitErrorMessage( "Variable cannot be replaced: invalid instruction", use); @@ -95,6 +99,11 @@ bool DescriptorScalarReplacement::ReplaceCandidate(Instruction* var) { return false; } } + for (Instruction* use : entry_point_work_list) { + if (!ReplaceEntryPoint(var, use)) { + return false; + } + } return true; } @@ -147,6 +156,42 @@ bool DescriptorScalarReplacement::ReplaceAccessChain(Instruction* var, return true; } +bool DescriptorScalarReplacement::ReplaceEntryPoint(Instruction* var, + Instruction* use) { + // Build a new |OperandList| for |use| that removes |var| and adds its + // replacement variables. + Instruction::OperandList new_operands; + + // Copy all operands except |var|. + bool found = false; + for (uint32_t idx = 0; idx < use->NumOperands(); idx++) { + Operand& op = use->GetOperand(idx); + if (op.type == SPV_OPERAND_TYPE_ID && op.words[0] == var->result_id()) { + found = true; + } else { + new_operands.emplace_back(op); + } + } + + if (!found) { + context()->EmitErrorMessage( + "Variable cannot be replaced: invalid instruction", use); + return false; + } + + // Add all new replacement variables. + uint32_t num_replacement_vars = + descsroautil::GetNumberOfElementsForArrayOrStruct(context(), var); + for (uint32_t i = 0; i < num_replacement_vars; i++) { + new_operands.push_back( + {SPV_OPERAND_TYPE_ID, {GetReplacementVariable(var, i)}}); + } + + use->ReplaceOperands(new_operands); + context()->UpdateDefUse(use); + return true; +} + uint32_t DescriptorScalarReplacement::GetReplacementVariable(Instruction* var, uint32_t idx) { auto replacement_vars = replacement_variables_.find(var); diff --git a/source/opt/desc_sroa.h b/source/opt/desc_sroa.h index 6a24fd8714..901be3e98b 100644 --- a/source/opt/desc_sroa.h +++ b/source/opt/desc_sroa.h @@ -64,6 +64,11 @@ class DescriptorScalarReplacement : public Pass { // otherwise. bool ReplaceLoadedValue(Instruction* var, Instruction* value); + // Replaces the given composite variable |var| in the OpEntryPoint with the + // new replacement variables, one for each element of the array |var|. Returns + // |true| if successful, and |false| otherwise. + bool ReplaceEntryPoint(Instruction* var, Instruction* use); + // Replaces the given OpCompositeExtract |extract| and all of its references // with an OpLoad of a replacement variable. |var| is the variable with // composite type whose value is being used by |extract|. Assumes that diff --git a/test/opt/desc_sroa_test.cpp b/test/opt/desc_sroa_test.cpp index 7a118f988e..5c166d83f7 100644 --- a/test/opt/desc_sroa_test.cpp +++ b/test/opt/desc_sroa_test.cpp @@ -918,6 +918,74 @@ TEST_F(DescriptorScalarReplacementTest, DecorateStringForReflect) { SinglePassRunAndMatch(shader, true); } +TEST_F(DescriptorScalarReplacementTest, ExpandArrayInOpEntryPoint) { + const std::string text = R"(; SPIR-V +; Version: 1.6 +; Bound: 31 +; Schema: 0 + OpCapability Shader + OpMemoryModel Logical GLSL450 + +; CHECK: OpEntryPoint GLCompute %main "main" %output_0_ %output_1_ + + OpEntryPoint GLCompute %main "main" %output + OpExecutionMode %main LocalSize 1 1 1 + OpSource HLSL 670 + OpName %type_RWByteAddressBuffer "type.RWByteAddressBuffer" + OpName %output "output" + OpName %main "main" + OpName %src_main "src.main" + OpName %bb_entry "bb.entry" + +; CHECK: OpDecorate %output_1_ DescriptorSet 0 +; CHECK: OpDecorate %output_1_ Binding 1 +; CHECK: OpDecorate %output_0_ DescriptorSet 0 +; CHECK: OpDecorate %output_0_ Binding 0 + + OpDecorate %output DescriptorSet 0 + OpDecorate %output Binding 0 + + OpDecorate %_runtimearr_uint ArrayStride 4 + OpMemberDecorate %type_RWByteAddressBuffer 0 Offset 0 + OpDecorate %type_RWByteAddressBuffer Block + %int = OpTypeInt 32 1 + %int_1 = OpConstant %int 1 + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %uint_2 = OpConstant %uint 2 + %uint_32 = OpConstant %uint 32 +%_runtimearr_uint = OpTypeRuntimeArray %uint +%type_RWByteAddressBuffer = OpTypeStruct %_runtimearr_uint +%_arr_type_RWByteAddressBuffer_uint_2 = OpTypeArray %type_RWByteAddressBuffer %uint_2 +%_ptr_StorageBuffer__arr_type_RWByteAddressBuffer_uint_2 = OpTypePointer StorageBuffer %_arr_type_RWByteAddressBuffer_uint_2 + %void = OpTypeVoid + %23 = OpTypeFunction %void +%_ptr_StorageBuffer_type_RWByteAddressBuffer = OpTypePointer StorageBuffer %type_RWByteAddressBuffer +%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint + +; CHECK: %output_1_ = OpVariable %_ptr_StorageBuffer_type_RWByteAddressBuffer StorageBuffer +; CHECK: %output_0_ = OpVariable %_ptr_StorageBuffer_type_RWByteAddressBuffer StorageBuffer + + %output = OpVariable %_ptr_StorageBuffer__arr_type_RWByteAddressBuffer_uint_2 StorageBuffer + + %main = OpFunction %void None %23 + %26 = OpLabel + %27 = OpFunctionCall %void %src_main + OpReturn + OpFunctionEnd + %src_main = OpFunction %void None %23 + %bb_entry = OpLabel + %28 = OpAccessChain %_ptr_StorageBuffer_type_RWByteAddressBuffer %output %int_1 + %29 = OpShiftRightLogical %uint %uint_0 %uint_2 + %30 = OpAccessChain %_ptr_StorageBuffer_uint %28 %uint_0 %29 + OpStore %30 %uint_32 + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(text, false); +} + } // namespace } // namespace opt } // namespace spvtools