From 213a8ce6d5c299a4c255330b41a2f5dfa05fec80 Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Thu, 21 Mar 2019 18:57:23 +0000 Subject: [PATCH] SpirvShader: Implement OpSwitch Tests: dEQP-VK.spirv_assembly.instruction.compute.* Tests: dEQP-VK.spirv_assembly.instruction.graphics.* Bug: b/128527271 Change-Id: I7ba31ca504a582a4d36d25ef2747fb1c1607bade Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/27775 Presubmit-Ready: Ben Clayton Tested-by: Ben Clayton Reviewed-by: Nicolas Capens Kokoro-Presubmit: kokoro --- src/Pipeline/SpirvShader.cpp | 38 +++ src/Pipeline/SpirvShader.hpp | 1 + tests/VulkanUnitTests/unittests.cpp | 452 ++++++++++++++++++++++++++++ 3 files changed, 491 insertions(+) diff --git a/src/Pipeline/SpirvShader.cpp b/src/Pipeline/SpirvShader.cpp index c226f1312bff3..cc81c9fbcd144 100644 --- a/src/Pipeline/SpirvShader.cpp +++ b/src/Pipeline/SpirvShader.cpp @@ -1179,6 +1179,8 @@ namespace sw case Block::Simple: case Block::StructuredBranchConditional: case Block::UnstructuredBranchConditional: + case Block::StructuredSwitch: + case Block::UnstructuredSwitch: if (id != mainBlockId) { // Emit all preceeding blocks and set the activeLaneMask. @@ -1404,6 +1406,9 @@ namespace sw case spv::OpBranchConditional: return EmitBranchConditional(insn, state); + case spv::OpSwitch: + return EmitSwitch(insn, state); + case spv::OpUnreachable: return EmitUnreachable(insn, state); @@ -2638,6 +2643,39 @@ namespace sw return EmitResult::Terminator; } + SpirvShader::EmitResult SpirvShader::EmitSwitch(InsnIterator insn, EmitState *state) const + { + auto block = getBlock(state->currentBlock); + ASSERT(block.branchInstruction == insn); + + auto selId = Object::ID(block.branchInstruction.word(1)); + + auto sel = GenericValue(this, state->routine, selId); + ASSERT_MSG(getType(getObject(selId).type).sizeInComponents == 1, "Selector must be a scalar"); + + auto numCases = (block.branchInstruction.wordCount() - 3) / 2; + + // TODO: Optimize for case where all lanes take same path. + + SIMD::Int defaultLaneMask = state->activeLaneMask(); + + // Gather up the case label matches and calculate defaultLaneMask. + std::vector> caseLabelMatches; + caseLabelMatches.reserve(numCases); + for (uint32_t i = 0; i < numCases; i++) + { + auto label = block.branchInstruction.word(i * 2 + 3); + auto caseBlockId = Block::ID(block.branchInstruction.word(i * 2 + 4)); + auto caseLabelMatch = CmpEQ(sel.Int(0), SIMD::Int(label)); + state->addOutputActiveLaneMaskEdge(caseBlockId, caseLabelMatch); + defaultLaneMask &= ~caseLabelMatch; + } + + auto defaultBlockId = Block::ID(block.branchInstruction.word(2)); + state->addOutputActiveLaneMaskEdge(defaultBlockId, defaultLaneMask); + + return EmitResult::Terminator; + } SpirvShader::EmitResult SpirvShader::EmitUnreachable(InsnIterator insn, EmitState *state) const { diff --git a/src/Pipeline/SpirvShader.hpp b/src/Pipeline/SpirvShader.hpp index 8268c710dfbbb..6d727459b1646 100644 --- a/src/Pipeline/SpirvShader.hpp +++ b/src/Pipeline/SpirvShader.hpp @@ -602,6 +602,7 @@ namespace sw EmitResult EmitAll(InsnIterator insn, EmitState *state) const; EmitResult EmitBranch(InsnIterator insn, EmitState *state) const; EmitResult EmitBranchConditional(InsnIterator insn, EmitState *state) const; + EmitResult EmitSwitch(InsnIterator insn, EmitState *state) const; EmitResult EmitUnreachable(InsnIterator insn, EmitState *state) const; EmitResult EmitReturn(InsnIterator insn, EmitState *state) const; EmitResult EmitPhi(InsnIterator insn, EmitState *state) const; diff --git a/tests/VulkanUnitTests/unittests.cpp b/tests/VulkanUnitTests/unittests.cpp index 2dca716ea92df..d2bfcc51c6c30 100644 --- a/tests/VulkanUnitTests/unittests.cpp +++ b/tests/VulkanUnitTests/unittests.cpp @@ -917,3 +917,455 @@ TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchConditionalPhi) test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 0 ? 1 : 2; }); } +TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchEmptyCases) +{ + std::stringstream src; + src << + "OpCapability Shader\n" + "OpMemoryModel Logical GLSL450\n" + "OpEntryPoint GLCompute %1 \"main\" %2\n" + "OpExecutionMode %1 LocalSize " << + GetParam().localSizeX << " " << + GetParam().localSizeY << " " << + GetParam().localSizeZ << "\n" << + "OpDecorate %3 ArrayStride 4\n" + "OpMemberDecorate %4 0 Offset 0\n" + "OpDecorate %4 BufferBlock\n" + "OpDecorate %5 DescriptorSet 0\n" + "OpDecorate %5 Binding 1\n" + "OpDecorate %2 BuiltIn GlobalInvocationId\n" + "OpDecorate %6 DescriptorSet 0\n" + "OpDecorate %6 Binding 0\n" + "%7 = OpTypeVoid\n" + "%8 = OpTypeFunction %7\n" // void() + "%9 = OpTypeInt 32 1\n" // int32 + "%10 = OpTypeInt 32 0\n" // uint32 + "%11 = OpTypeBool\n" + "%3 = OpTypeRuntimeArray %9\n" // int32[] + "%4 = OpTypeStruct %3\n" // struct{ int32[] } + "%12 = OpTypePointer Uniform %4\n" // struct{ int32[] }* + "%5 = OpVariable %12 Uniform\n" // struct{ int32[] }* in + "%13 = OpConstant %9 0\n" // int32(0) + "%14 = OpConstant %9 2\n" // int32(2) + "%15 = OpConstant %10 0\n" // uint32(0) + "%16 = OpTypeVector %10 3\n" // vec4 + "%17 = OpTypePointer Input %16\n" // vec4* + "%2 = OpVariable %17 Input\n" // gl_GlobalInvocationId + "%18 = OpTypePointer Input %10\n" // uint32* + "%6 = OpVariable %12 Uniform\n" // struct{ int32[] }* out + "%19 = OpTypePointer Uniform %9\n" // int32* + "%1 = OpFunction %7 None %8\n" // -- Function begin -- + "%20 = OpLabel\n" + "%21 = OpAccessChain %18 %2 %15\n" // &gl_GlobalInvocationId.x + "%22 = OpLoad %10 %21\n" // gl_GlobalInvocationId.x + "%23 = OpAccessChain %19 %6 %13 %22\n" // &in.arr[gl_GlobalInvocationId.x] + "%24 = OpLoad %9 %23\n" // in.arr[gl_GlobalInvocationId.x] + "%25 = OpAccessChain %19 %5 %13 %22\n" // &out.arr[gl_GlobalInvocationId.x] + // Start of branch logic + // %24 = in value + "%26 = OpSMod %9 %24 %14\n" // in % 2 + "OpSelectionMerge %27 None\n" + "OpSwitch %26 %27 0 %28 1 %29\n" + "%28 = OpLabel\n" // (in % 2) == 0 + "OpBranch %27\n" + "%29 = OpLabel\n" // (in % 2) == 1 + "OpBranch %27\n" + "%27 = OpLabel\n" + // %26 = out value + // End of branch logic + "OpStore %25 %26\n" // use SSA value from previous block + "OpReturn\n" + "OpFunctionEnd\n"; + + test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i%2; }); +} + +TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchStore) +{ + std::stringstream src; + src << + "OpCapability Shader\n" + "OpMemoryModel Logical GLSL450\n" + "OpEntryPoint GLCompute %1 \"main\" %2\n" + "OpExecutionMode %1 LocalSize " << + GetParam().localSizeX << " " << + GetParam().localSizeY << " " << + GetParam().localSizeZ << "\n" << + "OpDecorate %3 ArrayStride 4\n" + "OpMemberDecorate %4 0 Offset 0\n" + "OpDecorate %4 BufferBlock\n" + "OpDecorate %5 DescriptorSet 0\n" + "OpDecorate %5 Binding 1\n" + "OpDecorate %2 BuiltIn GlobalInvocationId\n" + "OpDecorate %6 DescriptorSet 0\n" + "OpDecorate %6 Binding 0\n" + "%7 = OpTypeVoid\n" + "%8 = OpTypeFunction %7\n" // void() + "%9 = OpTypeInt 32 1\n" // int32 + "%10 = OpTypeInt 32 0\n" // uint32 + "%11 = OpTypeBool\n" + "%3 = OpTypeRuntimeArray %9\n" // int32[] + "%4 = OpTypeStruct %3\n" // struct{ int32[] } + "%12 = OpTypePointer Uniform %4\n" // struct{ int32[] }* + "%5 = OpVariable %12 Uniform\n" // struct{ int32[] }* in + "%13 = OpConstant %9 0\n" // int32(0) + "%14 = OpConstant %9 1\n" // int32(1) + "%15 = OpConstant %9 2\n" // int32(2) + "%16 = OpConstant %10 0\n" // uint32(0) + "%17 = OpTypeVector %10 3\n" // vec4 + "%18 = OpTypePointer Input %17\n" // vec4* + "%2 = OpVariable %18 Input\n" // gl_GlobalInvocationId + "%19 = OpTypePointer Input %10\n" // uint32* + "%6 = OpVariable %12 Uniform\n" // struct{ int32[] }* out + "%20 = OpTypePointer Uniform %9\n" // int32* + "%1 = OpFunction %7 None %8\n" // -- Function begin -- + "%21 = OpLabel\n" + "%22 = OpAccessChain %19 %2 %16\n" // &gl_GlobalInvocationId.x + "%23 = OpLoad %10 %22\n" // gl_GlobalInvocationId.x + "%24 = OpAccessChain %20 %6 %13 %23\n" // &in.arr[gl_GlobalInvocationId.x] + "%25 = OpLoad %9 %24\n" // in.arr[gl_GlobalInvocationId.x] + "%26 = OpAccessChain %20 %5 %13 %23\n" // &out.arr[gl_GlobalInvocationId.x] + // Start of branch logic + // %25 = in value + "%27 = OpSMod %9 %25 %15\n" // in % 2 + "OpSelectionMerge %28 None\n" + "OpSwitch %27 %28 0 %29 1 %30\n" + "%29 = OpLabel\n" // (in % 2) == 0 + "OpStore %26 %15\n" // write 2 + "OpBranch %28\n" + "%30 = OpLabel\n" // (in % 2) == 1 + "OpStore %26 %14\n" // write 1 + "OpBranch %28\n" + "%28 = OpLabel\n" + // End of branch logic + "OpReturn\n" + "OpFunctionEnd\n"; + + test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 0 ? 2 : 1; }); +} + +TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchCaseReturn) +{ + std::stringstream src; + src << + "OpCapability Shader\n" + "OpMemoryModel Logical GLSL450\n" + "OpEntryPoint GLCompute %1 \"main\" %2\n" + "OpExecutionMode %1 LocalSize " << + GetParam().localSizeX << " " << + GetParam().localSizeY << " " << + GetParam().localSizeZ << "\n" << + "OpDecorate %3 ArrayStride 4\n" + "OpMemberDecorate %4 0 Offset 0\n" + "OpDecorate %4 BufferBlock\n" + "OpDecorate %5 DescriptorSet 0\n" + "OpDecorate %5 Binding 1\n" + "OpDecorate %2 BuiltIn GlobalInvocationId\n" + "OpDecorate %6 DescriptorSet 0\n" + "OpDecorate %6 Binding 0\n" + "%7 = OpTypeVoid\n" + "%8 = OpTypeFunction %7\n" // void() + "%9 = OpTypeInt 32 1\n" // int32 + "%10 = OpTypeInt 32 0\n" // uint32 + "%11 = OpTypeBool\n" + "%3 = OpTypeRuntimeArray %9\n" // int32[] + "%4 = OpTypeStruct %3\n" // struct{ int32[] } + "%12 = OpTypePointer Uniform %4\n" // struct{ int32[] }* + "%5 = OpVariable %12 Uniform\n" // struct{ int32[] }* in + "%13 = OpConstant %9 0\n" // int32(0) + "%14 = OpConstant %9 1\n" // int32(1) + "%15 = OpConstant %9 2\n" // int32(2) + "%16 = OpConstant %10 0\n" // uint32(0) + "%17 = OpTypeVector %10 3\n" // vec4 + "%18 = OpTypePointer Input %17\n" // vec4* + "%2 = OpVariable %18 Input\n" // gl_GlobalInvocationId + "%19 = OpTypePointer Input %10\n" // uint32* + "%6 = OpVariable %12 Uniform\n" // struct{ int32[] }* out + "%20 = OpTypePointer Uniform %9\n" // int32* + "%1 = OpFunction %7 None %8\n" // -- Function begin -- + "%21 = OpLabel\n" + "%22 = OpAccessChain %19 %2 %16\n" // &gl_GlobalInvocationId.x + "%23 = OpLoad %10 %22\n" // gl_GlobalInvocationId.x + "%24 = OpAccessChain %20 %6 %13 %23\n" // &in.arr[gl_GlobalInvocationId.x] + "%25 = OpLoad %9 %24\n" // in.arr[gl_GlobalInvocationId.x] + "%26 = OpAccessChain %20 %5 %13 %23\n" // &out.arr[gl_GlobalInvocationId.x] + // Start of branch logic + // %25 = in value + "%27 = OpSMod %9 %25 %15\n" // in % 2 + "OpSelectionMerge %28 None\n" + "OpSwitch %27 %28 0 %29 1 %30\n" + "%29 = OpLabel\n" // (in % 2) == 0 + "OpBranch %28\n" + "%30 = OpLabel\n" // (in % 2) == 1 + "OpReturn\n" + "%28 = OpLabel\n" + "OpStore %26 %14\n" // write 1 + // End of branch logic + "OpReturn\n" + "OpFunctionEnd\n"; + + test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 1 ? 0 : 1; }); +} + +TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchDefaultReturn) +{ + std::stringstream src; + src << + "OpCapability Shader\n" + "OpMemoryModel Logical GLSL450\n" + "OpEntryPoint GLCompute %1 \"main\" %2\n" + "OpExecutionMode %1 LocalSize " << + GetParam().localSizeX << " " << + GetParam().localSizeY << " " << + GetParam().localSizeZ << "\n" << + "OpDecorate %3 ArrayStride 4\n" + "OpMemberDecorate %4 0 Offset 0\n" + "OpDecorate %4 BufferBlock\n" + "OpDecorate %5 DescriptorSet 0\n" + "OpDecorate %5 Binding 1\n" + "OpDecorate %2 BuiltIn GlobalInvocationId\n" + "OpDecorate %6 DescriptorSet 0\n" + "OpDecorate %6 Binding 0\n" + "%7 = OpTypeVoid\n" + "%8 = OpTypeFunction %7\n" // void() + "%9 = OpTypeInt 32 1\n" // int32 + "%10 = OpTypeInt 32 0\n" // uint32 + "%11 = OpTypeBool\n" + "%3 = OpTypeRuntimeArray %9\n" // int32[] + "%4 = OpTypeStruct %3\n" // struct{ int32[] } + "%12 = OpTypePointer Uniform %4\n" // struct{ int32[] }* + "%5 = OpVariable %12 Uniform\n" // struct{ int32[] }* in + "%13 = OpConstant %9 0\n" // int32(0) + "%14 = OpConstant %9 1\n" // int32(1) + "%15 = OpConstant %9 2\n" // int32(2) + "%16 = OpConstant %10 0\n" // uint32(0) + "%17 = OpTypeVector %10 3\n" // vec4 + "%18 = OpTypePointer Input %17\n" // vec4* + "%2 = OpVariable %18 Input\n" // gl_GlobalInvocationId + "%19 = OpTypePointer Input %10\n" // uint32* + "%6 = OpVariable %12 Uniform\n" // struct{ int32[] }* out + "%20 = OpTypePointer Uniform %9\n" // int32* + "%1 = OpFunction %7 None %8\n" // -- Function begin -- + "%21 = OpLabel\n" + "%22 = OpAccessChain %19 %2 %16\n" // &gl_GlobalInvocationId.x + "%23 = OpLoad %10 %22\n" // gl_GlobalInvocationId.x + "%24 = OpAccessChain %20 %6 %13 %23\n" // &in.arr[gl_GlobalInvocationId.x] + "%25 = OpLoad %9 %24\n" // in.arr[gl_GlobalInvocationId.x] + "%26 = OpAccessChain %20 %5 %13 %23\n" // &out.arr[gl_GlobalInvocationId.x] + // Start of branch logic + // %25 = in value + "%27 = OpSMod %9 %25 %15\n" // in % 2 + "OpSelectionMerge %28 None\n" + "OpSwitch %27 %29 1 %30\n" + "%30 = OpLabel\n" // (in % 2) == 1 + "OpBranch %28\n" + "%29 = OpLabel\n" // (in % 2) != 1 + "OpReturn\n" + "%28 = OpLabel\n" // merge + "OpStore %26 %14\n" // write 1 + // End of branch logic + "OpReturn\n" + "OpFunctionEnd\n"; + + test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 1 ? 1 : 0; }); +} + +TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchCaseFallthrough) +{ + std::stringstream src; + src << + "OpCapability Shader\n" + "OpMemoryModel Logical GLSL450\n" + "OpEntryPoint GLCompute %1 \"main\" %2\n" + "OpExecutionMode %1 LocalSize " << + GetParam().localSizeX << " " << + GetParam().localSizeY << " " << + GetParam().localSizeZ << "\n" << + "OpDecorate %3 ArrayStride 4\n" + "OpMemberDecorate %4 0 Offset 0\n" + "OpDecorate %4 BufferBlock\n" + "OpDecorate %5 DescriptorSet 0\n" + "OpDecorate %5 Binding 1\n" + "OpDecorate %2 BuiltIn GlobalInvocationId\n" + "OpDecorate %6 DescriptorSet 0\n" + "OpDecorate %6 Binding 0\n" + "%7 = OpTypeVoid\n" + "%8 = OpTypeFunction %7\n" // void() + "%9 = OpTypeInt 32 1\n" // int32 + "%10 = OpTypeInt 32 0\n" // uint32 + "%11 = OpTypeBool\n" + "%3 = OpTypeRuntimeArray %9\n" // int32[] + "%4 = OpTypeStruct %3\n" // struct{ int32[] } + "%12 = OpTypePointer Uniform %4\n" // struct{ int32[] }* + "%5 = OpVariable %12 Uniform\n" // struct{ int32[] }* in + "%13 = OpConstant %9 0\n" // int32(0) + "%14 = OpConstant %9 1\n" // int32(1) + "%15 = OpConstant %9 2\n" // int32(2) + "%16 = OpConstant %10 0\n" // uint32(0) + "%17 = OpTypeVector %10 3\n" // vec4 + "%18 = OpTypePointer Input %17\n" // vec4* + "%2 = OpVariable %18 Input\n" // gl_GlobalInvocationId + "%19 = OpTypePointer Input %10\n" // uint32* + "%6 = OpVariable %12 Uniform\n" // struct{ int32[] }* out + "%20 = OpTypePointer Uniform %9\n" // int32* + "%1 = OpFunction %7 None %8\n" // -- Function begin -- + "%21 = OpLabel\n" + "%22 = OpAccessChain %19 %2 %16\n" // &gl_GlobalInvocationId.x + "%23 = OpLoad %10 %22\n" // gl_GlobalInvocationId.x + "%24 = OpAccessChain %20 %6 %13 %23\n" // &in.arr[gl_GlobalInvocationId.x] + "%25 = OpLoad %9 %24\n" // in.arr[gl_GlobalInvocationId.x] + "%26 = OpAccessChain %20 %5 %13 %23\n" // &out.arr[gl_GlobalInvocationId.x] + // Start of branch logic + // %25 = in value + "%27 = OpSMod %9 %25 %15\n" // in % 2 + "OpSelectionMerge %28 None\n" + "OpSwitch %27 %29 0 %30 1 %31\n" + "%30 = OpLabel\n" // (in % 2) == 0 + "%32 = OpIAdd %9 %27 %14\n" // generate an intermediate + "OpStore %26 %32\n" // write a value (overwritten later) + "OpBranch %31\n" // fallthrough + "%31 = OpLabel\n" // (in % 2) == 1 + "OpStore %26 %15\n" // write 2 + "OpBranch %28\n" + "%29 = OpLabel\n" // unreachable + "OpUnreachable\n" + "%28 = OpLabel\n" // merge + // End of branch logic + "OpReturn\n" + "OpFunctionEnd\n"; + + test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return 2; }); +} + +TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchDefaultFallthrough) +{ + std::stringstream src; + src << + "OpCapability Shader\n" + "OpMemoryModel Logical GLSL450\n" + "OpEntryPoint GLCompute %1 \"main\" %2\n" + "OpExecutionMode %1 LocalSize " << + GetParam().localSizeX << " " << + GetParam().localSizeY << " " << + GetParam().localSizeZ << "\n" << + "OpDecorate %3 ArrayStride 4\n" + "OpMemberDecorate %4 0 Offset 0\n" + "OpDecorate %4 BufferBlock\n" + "OpDecorate %5 DescriptorSet 0\n" + "OpDecorate %5 Binding 1\n" + "OpDecorate %2 BuiltIn GlobalInvocationId\n" + "OpDecorate %6 DescriptorSet 0\n" + "OpDecorate %6 Binding 0\n" + "%7 = OpTypeVoid\n" + "%8 = OpTypeFunction %7\n" // void() + "%9 = OpTypeInt 32 1\n" // int32 + "%10 = OpTypeInt 32 0\n" // uint32 + "%11 = OpTypeBool\n" + "%3 = OpTypeRuntimeArray %9\n" // int32[] + "%4 = OpTypeStruct %3\n" // struct{ int32[] } + "%12 = OpTypePointer Uniform %4\n" // struct{ int32[] }* + "%5 = OpVariable %12 Uniform\n" // struct{ int32[] }* in + "%13 = OpConstant %9 0\n" // int32(0) + "%14 = OpConstant %9 1\n" // int32(1) + "%15 = OpConstant %9 2\n" // int32(2) + "%16 = OpConstant %10 0\n" // uint32(0) + "%17 = OpTypeVector %10 3\n" // vec4 + "%18 = OpTypePointer Input %17\n" // vec4* + "%2 = OpVariable %18 Input\n" // gl_GlobalInvocationId + "%19 = OpTypePointer Input %10\n" // uint32* + "%6 = OpVariable %12 Uniform\n" // struct{ int32[] }* out + "%20 = OpTypePointer Uniform %9\n" // int32* + "%1 = OpFunction %7 None %8\n" // -- Function begin -- + "%21 = OpLabel\n" + "%22 = OpAccessChain %19 %2 %16\n" // &gl_GlobalInvocationId.x + "%23 = OpLoad %10 %22\n" // gl_GlobalInvocationId.x + "%24 = OpAccessChain %20 %6 %13 %23\n" // &in.arr[gl_GlobalInvocationId.x] + "%25 = OpLoad %9 %24\n" // in.arr[gl_GlobalInvocationId.x] + "%26 = OpAccessChain %20 %5 %13 %23\n" // &out.arr[gl_GlobalInvocationId.x] + // Start of branch logic + // %25 = in value + "%27 = OpSMod %9 %25 %15\n" // in % 2 + "OpSelectionMerge %28 None\n" + "OpSwitch %27 %29 0 %30 1 %31\n" + "%30 = OpLabel\n" // (in % 2) == 0 + "%32 = OpIAdd %9 %27 %14\n" // generate an intermediate + "OpStore %26 %32\n" // write a value (overwritten later) + "OpBranch %29\n" // fallthrough + "%29 = OpLabel\n" // default + "%33 = OpIAdd %9 %27 %14\n" // generate an intermediate + "OpStore %26 %33\n" // write a value (overwritten later) + "OpBranch %31\n" // fallthrough + "%31 = OpLabel\n" // (in % 2) == 1 + "OpStore %26 %15\n" // write 2 + "OpBranch %28\n" + "%28 = OpLabel\n" // merge + // End of branch logic + "OpReturn\n" + "OpFunctionEnd\n"; + + test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return 2; }); +} + +TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchPhi) +{ + std::stringstream src; + src << + "OpCapability Shader\n" + "OpMemoryModel Logical GLSL450\n" + "OpEntryPoint GLCompute %1 \"main\" %2\n" + "OpExecutionMode %1 LocalSize " << + GetParam().localSizeX << " " << + GetParam().localSizeY << " " << + GetParam().localSizeZ << "\n" << + "OpDecorate %3 ArrayStride 4\n" + "OpMemberDecorate %4 0 Offset 0\n" + "OpDecorate %4 BufferBlock\n" + "OpDecorate %5 DescriptorSet 0\n" + "OpDecorate %5 Binding 1\n" + "OpDecorate %2 BuiltIn GlobalInvocationId\n" + "OpDecorate %6 DescriptorSet 0\n" + "OpDecorate %6 Binding 0\n" + "%7 = OpTypeVoid\n" + "%8 = OpTypeFunction %7\n" // void() + "%9 = OpTypeInt 32 1\n" // int32 + "%10 = OpTypeInt 32 0\n" // uint32 + "%11 = OpTypeBool\n" + "%3 = OpTypeRuntimeArray %9\n" // int32[] + "%4 = OpTypeStruct %3\n" // struct{ int32[] } + "%12 = OpTypePointer Uniform %4\n" // struct{ int32[] }* + "%5 = OpVariable %12 Uniform\n" // struct{ int32[] }* in + "%13 = OpConstant %9 0\n" // int32(0) + "%14 = OpConstant %9 1\n" // int32(1) + "%15 = OpConstant %9 2\n" // int32(2) + "%16 = OpConstant %10 0\n" // uint32(0) + "%17 = OpTypeVector %10 3\n" // vec4 + "%18 = OpTypePointer Input %17\n" // vec4* + "%2 = OpVariable %18 Input\n" // gl_GlobalInvocationId + "%19 = OpTypePointer Input %10\n" // uint32* + "%6 = OpVariable %12 Uniform\n" // struct{ int32[] }* out + "%20 = OpTypePointer Uniform %9\n" // int32* + "%1 = OpFunction %7 None %8\n" // -- Function begin -- + "%21 = OpLabel\n" + "%22 = OpAccessChain %19 %2 %16\n" // &gl_GlobalInvocationId.x + "%23 = OpLoad %10 %22\n" // gl_GlobalInvocationId.x + "%24 = OpAccessChain %20 %6 %13 %23\n" // &in.arr[gl_GlobalInvocationId.x] + "%25 = OpLoad %9 %24\n" // in.arr[gl_GlobalInvocationId.x] + "%26 = OpAccessChain %20 %5 %13 %23\n" // &out.arr[gl_GlobalInvocationId.x] + // Start of branch logic + // %25 = in value + "%27 = OpSMod %9 %25 %15\n" // in % 2 + "OpSelectionMerge %28 None\n" + "OpSwitch %27 %29 1 %30\n" + "%30 = OpLabel\n" // (in % 2) == 1 + "OpBranch %28\n" + "%29 = OpLabel\n" // (in % 2) != 1 + "OpBranch %28\n" + "%28 = OpLabel\n" // merge + "%31 = OpPhi %9 %14 %30 %15 %29\n" // (in % 2) == 1 ? 1 : 2 + "OpStore %26 %31\n" + // End of branch logic + "OpReturn\n" + "OpFunctionEnd\n"; + + test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 1 ? 1 : 2; }); +}