diff --git a/source/opt/eliminate_dead_input_components_pass.cpp b/source/opt/eliminate_dead_input_components_pass.cpp index f383136d55..aa2776bbd3 100644 --- a/source/opt/eliminate_dead_input_components_pass.cpp +++ b/source/opt/eliminate_dead_input_components_pass.cpp @@ -56,21 +56,30 @@ Pass::Status EliminateDeadInputComponentsPass::Process() { continue; } const analysis::Array* arr_type = ptr_type->pointee_type()->AsArray(); - if (arr_type == nullptr) { + if (arr_type != nullptr) { + unsigned arr_len_id = arr_type->LengthId(); + Instruction* arr_len_inst = def_use_mgr->GetDef(arr_len_id); + if (arr_len_inst->opcode() != SpvOpConstant) { + continue; + } + // SPIR-V requires array size is >= 1, so this works for signed or + // unsigned size + unsigned original_max = + arr_len_inst->GetSingleWordInOperand(kConstantValueInIdx) - 1; + unsigned max_idx = FindMaxIndex(var, original_max); + if (max_idx != original_max) { + ChangeArrayLength(var, max_idx + 1); + modified = true; + } continue; } - unsigned arr_len_id = arr_type->LengthId(); - Instruction* arr_len_inst = def_use_mgr->GetDef(arr_len_id); - if (arr_len_inst->opcode() != SpvOpConstant) { - continue; - } - // SPIR-V requires array size is >= 1, so this works for signed or - // unsigned size - unsigned original_max = - arr_len_inst->GetSingleWordInOperand(kConstantValueInIdx) - 1; + const analysis::Struct* struct_type = ptr_type->pointee_type()->AsStruct(); + if (struct_type == nullptr) continue; + const auto elt_types = struct_type->element_types(); + unsigned original_max = static_cast(elt_types.size()) - 1; unsigned max_idx = FindMaxIndex(var, original_max); if (max_idx != original_max) { - ChangeArrayLength(var, max_idx + 1); + ChangeStructLength(var, max_idx + 1); modified = true; } } @@ -116,12 +125,13 @@ unsigned EliminateDeadInputComponentsPass::FindMaxIndex(Instruction& var, return seen_non_const_ac ? original_max : max; } -void EliminateDeadInputComponentsPass::ChangeArrayLength(Instruction& arr, +void EliminateDeadInputComponentsPass::ChangeArrayLength(Instruction& arr_var, unsigned length) { analysis::TypeManager* type_mgr = context()->get_type_mgr(); analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); - analysis::Pointer* ptr_type = type_mgr->GetType(arr.type_id())->AsPointer(); + analysis::Pointer* ptr_type = + type_mgr->GetType(arr_var.type_id())->AsPointer(); const analysis::Array* arr_ty = ptr_type->pointee_type()->AsArray(); assert(arr_ty && "expecting array type"); uint32_t length_id = const_mgr->GetUIntConst(length); @@ -131,15 +141,48 @@ void EliminateDeadInputComponentsPass::ChangeArrayLength(Instruction& arr, analysis::Pointer new_ptr_ty(reg_new_arr_ty, SpvStorageClassInput); analysis::Type* reg_new_ptr_ty = type_mgr->GetRegisteredType(&new_ptr_ty); uint32_t new_ptr_ty_id = type_mgr->GetTypeInstruction(reg_new_ptr_ty); - arr.SetResultType(new_ptr_ty_id); - def_use_mgr->AnalyzeInstUse(&arr); - // Move array OpVariable instruction after its new type to preserve order - USE_ASSERT(arr.GetSingleWordInOperand(kVariableStorageClassInIdx) != + arr_var.SetResultType(new_ptr_ty_id); + def_use_mgr->AnalyzeInstUse(&arr_var); + // Move arr_var after its new type to preserve order + USE_ASSERT(arr_var.GetSingleWordInOperand(kVariableStorageClassInIdx) != + SpvStorageClassFunction && + "cannot move Function variable"); + Instruction* new_ptr_ty_inst = def_use_mgr->GetDef(new_ptr_ty_id); + arr_var.RemoveFromList(); + arr_var.InsertAfter(new_ptr_ty_inst); +} + +void EliminateDeadInputComponentsPass::ChangeStructLength( + Instruction& struct_var, unsigned length) { + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + analysis::Pointer* ptr_type = + type_mgr->GetType(struct_var.type_id())->AsPointer(); + const analysis::Struct* struct_ty = ptr_type->pointee_type()->AsStruct(); + assert(struct_ty && "expecting struct type"); + const auto orig_elt_types = struct_ty->element_types(); + std::vector new_elt_types; + for (unsigned u = 0; u < length; ++u) + new_elt_types.push_back(orig_elt_types[u]); + analysis::Struct new_struct_ty(new_elt_types); + analysis::Type* reg_new_struct_ty = + type_mgr->GetRegisteredType(&new_struct_ty); + uint32_t new_struct_ty_id = type_mgr->GetTypeInstruction(reg_new_struct_ty); + uint32_t old_struct_ty_id = type_mgr->GetTypeInstruction(struct_ty); + analysis::DecorationManager* deco_mgr = context()->get_decoration_mgr(); + deco_mgr->CloneDecorations(old_struct_ty_id, new_struct_ty_id); + analysis::Pointer new_ptr_ty(reg_new_struct_ty, SpvStorageClassInput); + analysis::Type* reg_new_ptr_ty = type_mgr->GetRegisteredType(&new_ptr_ty); + uint32_t new_ptr_ty_id = type_mgr->GetTypeInstruction(reg_new_ptr_ty); + struct_var.SetResultType(new_ptr_ty_id); + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + def_use_mgr->AnalyzeInstUse(&struct_var); + // Move struct_var after its new type to preserve order + USE_ASSERT(struct_var.GetSingleWordInOperand(kVariableStorageClassInIdx) != SpvStorageClassFunction && "cannot move Function variable"); Instruction* new_ptr_ty_inst = def_use_mgr->GetDef(new_ptr_ty_id); - arr.RemoveFromList(); - arr.InsertAfter(new_ptr_ty_inst); + struct_var.RemoveFromList(); + struct_var.InsertAfter(new_ptr_ty_inst); } } // namespace opt diff --git a/source/opt/eliminate_dead_input_components_pass.h b/source/opt/eliminate_dead_input_components_pass.h index b77857f4e9..a3a133c2bb 100644 --- a/source/opt/eliminate_dead_input_components_pass.h +++ b/source/opt/eliminate_dead_input_components_pass.h @@ -30,7 +30,10 @@ class EliminateDeadInputComponentsPass : public Pass { public: explicit EliminateDeadInputComponentsPass() {} - const char* name() const override { return "reduce-load-size"; } + const char* name() const override { + return "eliminate-dead-input-components"; + } + Status Process() override; // Return the mask of preserved Analyses. @@ -51,6 +54,9 @@ class EliminateDeadInputComponentsPass : public Pass { // Change the length of the array |inst| to |length| void ChangeArrayLength(Instruction& inst, unsigned length); + + // Change the length of the struct |struct_var| to |length| + void ChangeStructLength(Instruction& struct_var, unsigned length); }; } // namespace opt diff --git a/test/opt/eliminate_dead_input_components_test.cpp b/test/opt/eliminate_dead_input_components_test.cpp index b0098f733a..822914a860 100644 --- a/test/opt/eliminate_dead_input_components_test.cpp +++ b/test/opt/eliminate_dead_input_components_test.cpp @@ -399,6 +399,70 @@ TEST_F(ElimDeadInputComponentsTest, NoElimNonIndexedAccessChain) { SinglePassRunAndMatch(text, true); } +TEST_F(ElimDeadInputComponentsTest, ElimStructMember) { + // Should eliminate uv + // + // #version 450 + // + // in Vertex { + // vec4 Cd; + // vec2 uv; + // } iVert; + // + // out vec4 fragColor; + // + // void main() + // { + // vec4 color = vec4(iVert.Cd); + // fragColor = color; + // } + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %iVert %fragColor + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 450 + OpName %main "main" + OpName %Vertex "Vertex" + OpMemberName %Vertex 0 "Cd" + OpMemberName %Vertex 1 "uv" + OpName %iVert "iVert" + OpName %fragColor "fragColor" + OpDecorate %Vertex Block + OpDecorate %iVert Location 0 + OpDecorate %fragColor Location 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %v2float = OpTypeVector %float 2 + %Vertex = OpTypeStruct %v4float %v2float +; CHECK: %Vertex = OpTypeStruct %v4float %v2float +; CHECK: [[sty:%\w+]] = OpTypeStruct %v4float +%_ptr_Input_Vertex = OpTypePointer Input %Vertex +; CHECK: [[pty:%\w+]] = OpTypePointer Input [[sty]] + %iVert = OpVariable %_ptr_Input_Vertex Input +; CHECK: %iVert = OpVariable [[pty]] Input + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float + %fragColor = OpVariable %_ptr_Output_v4float Output + %main = OpFunction %void None %3 + %5 = OpLabel + %17 = OpAccessChain %_ptr_Input_v4float %iVert %int_0 + %18 = OpLoad %v4float %17 + OpStore %fragColor %18 + OpReturn + OpFunctionEnd +)"; + + SetTargetEnv(SPV_ENV_VULKAN_1_3); + SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + SinglePassRunAndMatch(text, true); +} + } // namespace } // namespace opt } // namespace spvtools