Skip to content

Commit

Permalink
Spirv OpUndef (#769)
Browse files Browse the repository at this point in the history
  • Loading branch information
natgavrilenko authored Nov 6, 2024
1 parent d2ccf4f commit 4d03388
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ private Set<Class<?>> getChildVisitors() {
VisitorOpsFunction.class,
VisitorOpsLogical.class,
VisitorOpsMemory.class,
VisitorOpsMisc.class,
VisitorOpsSetting.class,
VisitorOpsType.class
);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package com.dat3m.dartagnan.parsers.program.visitors.spirv;

import com.dat3m.dartagnan.exception.ParsingException;
import com.dat3m.dartagnan.expression.Expression;
import com.dat3m.dartagnan.expression.Type;
import com.dat3m.dartagnan.expression.type.VoidType;
import com.dat3m.dartagnan.parsers.SpirvBaseVisitor;
import com.dat3m.dartagnan.parsers.SpirvParser;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.builders.ProgramBuilder;

import java.util.Set;

public class VisitorOpsMisc extends SpirvBaseVisitor<Expression> {

private final ProgramBuilder builder;

public VisitorOpsMisc(ProgramBuilder builder) {
this.builder = builder;
}

@Override
public Expression visitOpUndef(SpirvParser.OpUndefContext ctx) {
String id = ctx.idResult().getText();
Type type = builder.getType(ctx.idResultType().getText());
if (!(type instanceof VoidType)) {
Expression expression = builder.makeUndefinedValue(type);
return builder.addExpression(id, expression);
}
throw new ParsingException("Illegal definition '%s': " +
"OpUndef cannot have void type", id);
}

public Set<String> getSupportedOps() {
return Set.of(
"OpUndef"
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,19 @@ public class MemoryTransformer extends ExprTransformer {
// Thread / Subgroup / Workgroup / QueueFamily / Device
private static final List<String> namePrefixes = List.of("T", "S", "W", "Q", "D");

private final Program program;
private final Function function;
private final BuiltIn builtIn;
private final List<? extends Map<MemoryObject, MemoryObject>> scopeMapping;
private final Map<MemoryObject, ScopedPointerVariable> pointerMapping;
private final List<IntUnaryOperator> scopeIdProvider;
private final List<IntUnaryOperator> namePrefixIdxProvider;
private Map<Register, Register> registerMapping;
private Map<NonDetValue, NonDetValue> nonDetMapping;
private int tid;

public MemoryTransformer(ThreadGrid grid, Function function, BuiltIn builtIn, Set<ScopedPointerVariable> variables) {
this.program = function.getProgram();
this.function = function;
this.builtIn = builtIn;
this.scopeMapping = Stream.generate(() -> new HashMap<MemoryObject, MemoryObject>()).limit(namePrefixes.size()).toList();
Expand Down Expand Up @@ -65,13 +68,19 @@ public void setThread(Thread thread) {
builtIn.setThreadId(tid);
registerMapping = function.getRegisters().stream().collect(
toMap(r -> r, r -> thread.getOrNewRegister(r.getName(), r.getType())));
nonDetMapping = new HashMap<>();
}

@Override
public Expression visitRegister(Register register) {
return registerMapping.get(register);
}

@Override
public Expression visitNonDetValue(NonDetValue nonDetValue) {
return nonDetMapping.computeIfAbsent(nonDetValue, x -> (NonDetValue) program.newConstant(x.getType()));
}

@Override
public Expression visitMemoryObject(MemoryObject memObj) {
String storageClass = pointerMapping.get(memObj).getScopeId();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ public static Iterable<Object[]> data() throws IOException {
{"uninitialized-forall.spv.dis", 1, FAIL},
{"uninitialized-private-exists.spv.dis", 1, PASS},
{"uninitialized-private-forall.spv.dis", 1, FAIL},
{"undef-exists.spv.dis", 1, PASS},
{"undef-forall.spv.dis", 1, FAIL},
{"read-write.spv.dis", 1, PASS},
{"vector-init.spv.dis", 1, PASS},
{"vector.spv.dis", 1, PASS},
Expand Down
33 changes: 33 additions & 0 deletions dartagnan/src/test/resources/spirv/basic/undef-exists.spv.dis
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
; @Input: %out = {0, 0}
; @Output: exists (%out[0] == %out[1])
; @Config: 2, 1, 1
; SPIR-V
; Version: 1.0
; Schema: 0
OpCapability Shader
OpCapability VulkanMemoryModel
OpMemoryModel Logical Vulkan
OpEntryPoint GLCompute %main "main" %ids
OpSource GLSL 450
OpDecorate %ids BuiltIn GlobalInvocationId
%void = OpTypeVoid
%bool = OpTypeBool
%func = OpTypeFunction %void
%uint = OpTypeInt 32 0
%v3uint = OpTypeVector %uint 3
%v2uint = OpTypeVector %uint 2
%ptr_uint = OpTypePointer Private %uint
%ptr_v3uint = OpTypePointer Input %v3uint
%ptr_v2uint = OpTypePointer Output %v2uint
%c0 = OpConstant %uint 0
%ids = OpVariable %ptr_v3uint Input
%out = OpVariable %ptr_v2uint Output
%undef = OpUndef %uint
%main = OpFunction %void None %func
%label = OpLabel
%id_ptr = OpAccessChain %ptr_uint %ids %c0
%id = OpLoad %uint %id_ptr
%ptr_out = OpAccessChain %ptr_uint %out %id
OpStore %ptr_out %undef
OpReturn
OpFunctionEnd
33 changes: 33 additions & 0 deletions dartagnan/src/test/resources/spirv/basic/undef-forall.spv.dis
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
; @Input: %out = {0, 0}
; @Output: forall (%out[0] == %out[1])
; @Config: 2, 1, 1
; SPIR-V
; Version: 1.0
; Schema: 0
OpCapability Shader
OpCapability VulkanMemoryModel
OpMemoryModel Logical Vulkan
OpEntryPoint GLCompute %main "main" %ids
OpSource GLSL 450
OpDecorate %ids BuiltIn GlobalInvocationId
%void = OpTypeVoid
%bool = OpTypeBool
%func = OpTypeFunction %void
%uint = OpTypeInt 32 0
%v3uint = OpTypeVector %uint 3
%v2uint = OpTypeVector %uint 2
%ptr_uint = OpTypePointer Private %uint
%ptr_v3uint = OpTypePointer Input %v3uint
%ptr_v2uint = OpTypePointer Output %v2uint
%c0 = OpConstant %uint 0
%ids = OpVariable %ptr_v3uint Input
%out = OpVariable %ptr_v2uint Output
%undef = OpUndef %uint
%main = OpFunction %void None %func
%label = OpLabel
%id_ptr = OpAccessChain %ptr_uint %ids %c0
%id = OpLoad %uint %id_ptr
%ptr_out = OpAccessChain %ptr_uint %out %id
OpStore %ptr_out %undef
OpReturn
OpFunctionEnd

0 comments on commit 4d03388

Please sign in to comment.