diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorSpirv.java b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorSpirv.java index f8c54d890c..764233597d 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorSpirv.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorSpirv.java @@ -146,6 +146,7 @@ private Set> getChildVisitors() { VisitorOpsFunction.class, VisitorOpsLogical.class, VisitorOpsMemory.class, + VisitorOpsMisc.class, VisitorOpsSetting.class, VisitorOpsType.class ); diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsMisc.java b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsMisc.java new file mode 100644 index 0000000000..96b1de6a85 --- /dev/null +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsMisc.java @@ -0,0 +1,38 @@ +package com.dat3m.dartagnan.parsers.program.visitors.spirv; + +import com.dat3m.dartagnan.exception.ParsingException; +import com.dat3m.dartagnan.expression.Expression; +import com.dat3m.dartagnan.expression.Type; +import com.dat3m.dartagnan.expression.type.VoidType; +import com.dat3m.dartagnan.parsers.SpirvBaseVisitor; +import com.dat3m.dartagnan.parsers.SpirvParser; +import com.dat3m.dartagnan.parsers.program.visitors.spirv.builders.ProgramBuilder; + +import java.util.Set; + +public class VisitorOpsMisc extends SpirvBaseVisitor { + + private final ProgramBuilder builder; + + public VisitorOpsMisc(ProgramBuilder builder) { + this.builder = builder; + } + + @Override + public Expression visitOpUndef(SpirvParser.OpUndefContext ctx) { + String id = ctx.idResult().getText(); + Type type = builder.getType(ctx.idResultType().getText()); + if (!(type instanceof VoidType)) { + Expression expression = builder.makeUndefinedValue(type); + return builder.addExpression(id, expression); + } + throw new ParsingException("Illegal definition '%s': " + + "OpUndef cannot have void type", id); + } + + public Set getSupportedOps() { + return Set.of( + "OpUndef" + ); + } +} diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/utils/MemoryTransformer.java b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/utils/MemoryTransformer.java index 610ca8c89f..6a668a1030 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/utils/MemoryTransformer.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/utils/MemoryTransformer.java @@ -28,6 +28,7 @@ public class MemoryTransformer extends ExprTransformer { // Thread / Subgroup / Workgroup / QueueFamily / Device private static final List namePrefixes = List.of("T", "S", "W", "Q", "D"); + private final Program program; private final Function function; private final BuiltIn builtIn; private final List> scopeMapping; @@ -35,9 +36,11 @@ public class MemoryTransformer extends ExprTransformer { private final List scopeIdProvider; private final List namePrefixIdxProvider; private Map registerMapping; + private Map nonDetMapping; private int tid; public MemoryTransformer(ThreadGrid grid, Function function, BuiltIn builtIn, Set variables) { + this.program = function.getProgram(); this.function = function; this.builtIn = builtIn; this.scopeMapping = Stream.generate(() -> new HashMap()).limit(namePrefixes.size()).toList(); @@ -65,6 +68,7 @@ public void setThread(Thread thread) { builtIn.setThreadId(tid); registerMapping = function.getRegisters().stream().collect( toMap(r -> r, r -> thread.getOrNewRegister(r.getName(), r.getType()))); + nonDetMapping = new HashMap<>(); } @Override @@ -72,6 +76,11 @@ public Expression visitRegister(Register register) { return registerMapping.get(register); } + @Override + public Expression visitNonDetValue(NonDetValue nonDetValue) { + return nonDetMapping.computeIfAbsent(nonDetValue, x -> (NonDetValue) program.newConstant(x.getType())); + } + @Override public Expression visitMemoryObject(MemoryObject memObj) { String storageClass = pointerMapping.get(memObj).getScopeId(); diff --git a/dartagnan/src/test/java/com/dat3m/dartagnan/spirv/basic/SpirvAssertionsTest.java b/dartagnan/src/test/java/com/dat3m/dartagnan/spirv/basic/SpirvAssertionsTest.java index 24fa6af270..b1074afee3 100644 --- a/dartagnan/src/test/java/com/dat3m/dartagnan/spirv/basic/SpirvAssertionsTest.java +++ b/dartagnan/src/test/java/com/dat3m/dartagnan/spirv/basic/SpirvAssertionsTest.java @@ -62,6 +62,8 @@ public static Iterable data() throws IOException { {"uninitialized-forall.spv.dis", 1, FAIL}, {"uninitialized-private-exists.spv.dis", 1, PASS}, {"uninitialized-private-forall.spv.dis", 1, FAIL}, + {"undef-exists.spv.dis", 1, PASS}, + {"undef-forall.spv.dis", 1, FAIL}, {"read-write.spv.dis", 1, PASS}, {"vector-init.spv.dis", 1, PASS}, {"vector.spv.dis", 1, PASS}, diff --git a/dartagnan/src/test/resources/spirv/basic/undef-exists.spv.dis b/dartagnan/src/test/resources/spirv/basic/undef-exists.spv.dis new file mode 100644 index 0000000000..00af2c9db1 --- /dev/null +++ b/dartagnan/src/test/resources/spirv/basic/undef-exists.spv.dis @@ -0,0 +1,33 @@ +; @Input: %out = {0, 0} +; @Output: exists (%out[0] == %out[1]) +; @Config: 2, 1, 1 +; SPIR-V +; Version: 1.0 +; Schema: 0 + OpCapability Shader + OpCapability VulkanMemoryModel + OpMemoryModel Logical Vulkan + OpEntryPoint GLCompute %main "main" %ids + OpSource GLSL 450 + OpDecorate %ids BuiltIn GlobalInvocationId + %void = OpTypeVoid + %bool = OpTypeBool + %func = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %v3uint = OpTypeVector %uint 3 + %v2uint = OpTypeVector %uint 2 + %ptr_uint = OpTypePointer Private %uint + %ptr_v3uint = OpTypePointer Input %v3uint + %ptr_v2uint = OpTypePointer Output %v2uint + %c0 = OpConstant %uint 0 + %ids = OpVariable %ptr_v3uint Input + %out = OpVariable %ptr_v2uint Output + %undef = OpUndef %uint + %main = OpFunction %void None %func + %label = OpLabel + %id_ptr = OpAccessChain %ptr_uint %ids %c0 + %id = OpLoad %uint %id_ptr + %ptr_out = OpAccessChain %ptr_uint %out %id + OpStore %ptr_out %undef + OpReturn + OpFunctionEnd diff --git a/dartagnan/src/test/resources/spirv/basic/undef-forall.spv.dis b/dartagnan/src/test/resources/spirv/basic/undef-forall.spv.dis new file mode 100644 index 0000000000..b17eed1423 --- /dev/null +++ b/dartagnan/src/test/resources/spirv/basic/undef-forall.spv.dis @@ -0,0 +1,33 @@ +; @Input: %out = {0, 0} +; @Output: forall (%out[0] == %out[1]) +; @Config: 2, 1, 1 +; SPIR-V +; Version: 1.0 +; Schema: 0 + OpCapability Shader + OpCapability VulkanMemoryModel + OpMemoryModel Logical Vulkan + OpEntryPoint GLCompute %main "main" %ids + OpSource GLSL 450 + OpDecorate %ids BuiltIn GlobalInvocationId + %void = OpTypeVoid + %bool = OpTypeBool + %func = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %v3uint = OpTypeVector %uint 3 + %v2uint = OpTypeVector %uint 2 + %ptr_uint = OpTypePointer Private %uint + %ptr_v3uint = OpTypePointer Input %v3uint + %ptr_v2uint = OpTypePointer Output %v2uint + %c0 = OpConstant %uint 0 + %ids = OpVariable %ptr_v3uint Input + %out = OpVariable %ptr_v2uint Output + %undef = OpUndef %uint + %main = OpFunction %void None %func + %label = OpLabel + %id_ptr = OpAccessChain %ptr_uint %ids %c0 + %id = OpLoad %uint %id_ptr + %ptr_out = OpAccessChain %ptr_uint %out %id + OpStore %ptr_out %undef + OpReturn + OpFunctionEnd