Skip to content

Commit

Permalink
Support explicitly defined offsets in aggregate type
Browse files Browse the repository at this point in the history
  • Loading branch information
natgavrilenko committed Nov 5, 2024
1 parent fa5a25c commit 550134e
Show file tree
Hide file tree
Showing 30 changed files with 426 additions and 299 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,7 @@ public Expression makeFloatCast(Expression operand, FloatType targetType, boolea
// -----------------------------------------------------------------------------------------------------------------
// Aggregates

public Expression makeConstruct(List<Expression> arguments) {
final AggregateType type = types.getAggregateType(arguments.stream().map(Expression::getType).toList());
public Expression makeConstruct(Type type, List<Expression> arguments) {
return new ConstructExpr(type, arguments);
}

Expand Down Expand Up @@ -302,11 +301,11 @@ public Expression makeGeneralZero(Type type) {
}
return makeArray(arrayType.getElementType(), zeroes, true);
} else if (type instanceof AggregateType structType) {
List<Expression> zeroes = new ArrayList<>(structType.getDirectFields().size());
for (Type fieldType : structType.getDirectFields()) {
zeroes.add(makeGeneralZero(fieldType));
List<Expression> zeroes = new ArrayList<>(structType.getTypeOffsets().size());
for (TypeOffset typeOffset : structType.getTypeOffsets()) {
zeroes.add(makeGeneralZero(typeOffset.type()));
}
return makeConstruct(zeroes);
return makeConstruct(structType, zeroes);
} else if (type instanceof IntegerType intType) {
return makeZero(intType);
} else if (type instanceof BooleanType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import com.dat3m.dartagnan.expression.base.NaryExpressionBase;
import com.dat3m.dartagnan.expression.type.AggregateType;
import com.dat3m.dartagnan.expression.type.ArrayType;
import com.dat3m.dartagnan.expression.type.TypeOffset;

import java.util.List;
import java.util.stream.Collectors;
Expand All @@ -20,7 +21,8 @@ public ConstructExpr(Type type, List<Expression> arguments) {
checkArgument(type instanceof AggregateType || type instanceof ArrayType,
"Non-constructible type %s.", type);
checkArgument(!(type instanceof AggregateType a) ||
arguments.stream().map(Expression::getType).toList().equals(a.getDirectFields()),
arguments.stream().map(Expression::getType).toList()
.equals(a.getTypeOffsets().stream().map(TypeOffset::type).toList()),
"Arguments do not match the constructor signature.");
checkArgument(!(type instanceof ArrayType a) ||
!a.hasKnownNumElements() ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@
import com.dat3m.dartagnan.expression.base.UnaryExpressionBase;
import com.dat3m.dartagnan.expression.type.AggregateType;
import com.dat3m.dartagnan.expression.type.ArrayType;
import com.dat3m.dartagnan.expression.type.TypeOffset;
import com.google.common.base.Preconditions;

import java.util.List;

import static com.google.common.base.Preconditions.checkArgument;

public final class ExtractExpr extends UnaryExpressionBase<Type, ExpressionKind.Other> {
Expand All @@ -25,13 +28,14 @@ private static Type extractType(Expression expr, int index) {
Preconditions.checkArgument(exprType instanceof AggregateType || exprType instanceof ArrayType,
"Cannot extract from a non-aggregate expression: (%s)[%d].", expr, index);
if (exprType instanceof AggregateType aggregateType) {
return aggregateType.getDirectFields().get(index);
} else {
final ArrayType arrayType = (ArrayType) exprType;
checkArgument(0 <= index && (!arrayType.hasKnownNumElements() || index < arrayType.getNumElements()),
"Index %s out of bounds [0,%s].", index, arrayType.getNumElements() - 1);
return arrayType.getElementType();
final List<TypeOffset> typeOffsets = aggregateType.getTypeOffsets();
checkArgument(0 <= index && index < typeOffsets.size());
return typeOffsets.get(index).type();
}
final ArrayType arrayType = (ArrayType) exprType;
checkArgument(0 <= index && (!arrayType.hasKnownNumElements() || index < arrayType.getNumElements()),
"Index %s out of bounds [0,%s].", index, arrayType.getNumElements() - 1);
return arrayType.getElementType();
}

public int getFieldIndex() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ public Expression visitConstructExpression(ConstructExpr construct) {
for (final Expression argument : construct.getOperands()) {
arguments.add(argument.accept(this));
}
return expressions.makeConstruct(arguments);
return expressions.makeConstruct(construct.getType(), arguments);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,61 @@

import com.dat3m.dartagnan.expression.Type;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public final class AggregateType implements Type {
import static com.dat3m.dartagnan.expression.type.TypeFactory.paddedSize;

private final List<Type> fields;
public class AggregateType implements Type {

AggregateType(List<Type> directFields) {
this.fields = List.copyOf(directFields);
private static final TypeFactory types = TypeFactory.getInstance();

private final List<TypeOffset> directFields;

AggregateType(List<Type> fields) {
this(fields, computeDefaultOffsets(fields));
}

AggregateType(List<Type> fields, List<Integer> offsets) {
this.directFields = IntStream.range(0, fields.size()).boxed().map(i -> new TypeOffset(fields.get(i), offsets.get(i))).toList();
}

private static List<Integer> computeDefaultOffsets(List<Type> fields) {
List<Integer> offsets = new ArrayList<>();
int maxAlignment = 0;
int maxOffset = 0;
if (!fields.isEmpty()) {
maxAlignment = types.getAlignment(fields.get(0));
maxOffset = types.getMemorySizeInBytes(fields.get(0));
offsets.add(0);
}
for (int i = 1; i < fields.size(); i++) {
maxAlignment = Integer.max(types.getAlignment(fields.get(i)), maxAlignment);
maxOffset = paddedSize(maxOffset, maxAlignment);
offsets.add(maxOffset);
maxOffset += types.getMemorySizeInBytes(fields.get(i));
}
return offsets;
}

public List<Type> getDirectFields() {
return fields;
public List<TypeOffset> getTypeOffsets() {
return directFields;
}

@Override
public int hashCode() {
return fields.hashCode();
return directFields.hashCode();
}

@Override
public boolean equals(Object obj) {
return this == obj || obj instanceof AggregateType o && fields.equals(o.fields);
return this == obj || obj instanceof AggregateType o && directFields.equals(o.directFields);
}

@Override
public String toString() {
return fields.stream().map(Type::toString).collect(Collectors.joining(", ", "{ ", " }"));
return directFields.stream().map(f -> f.offset() + ": " + f.type()).collect(Collectors.joining(", ", "{ ", " }"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import com.dat3m.dartagnan.expression.Type;
import com.dat3m.dartagnan.utils.Normalizer;
import com.google.common.math.IntMath;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.math.RoundingMode;
import java.util.*;
import java.util.stream.IntStream;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
Expand Down Expand Up @@ -74,6 +75,22 @@ public AggregateType getAggregateType(List<Type> fields) {
return typeNormalizer.normalize(new AggregateType(fields));
}

public AggregateType getAggregateType(List<Type> fields, List<Integer> offsets) {
checkNotNull(fields);
checkNotNull(offsets);
checkArgument(fields.stream().noneMatch(t -> t == voidType), "Void fields are not allowed");
checkArgument(fields.size() == offsets.size(), "The number of offsets does not match the number of fields");
checkArgument(offsets.stream().noneMatch(o -> o < 0), "Offset cannot be negative");
checkArgument(offsets.isEmpty() || offsets.get(0) == 0, "The first offset must be zero");
checkArgument(IntStream.range(1, offsets.size()).boxed().allMatch(
i -> offsets.get(i) >= offsets.get(i - 1) + Integer.max(0, getMemorySizeInBytes(fields.get(i - 1)))),
"Offset is too small");
checkArgument(IntStream.range(0, offsets.size() - 1).boxed().allMatch(
i -> getMemorySizeInBytes(fields.get(i)) > 0),
"Only the last element of a structure can have unknown size");
return typeNormalizer.normalize(new AggregateType(fields, offsets));
}

public ArrayType getArrayType(Type element) {
return typeNormalizer.normalize(new ArrayType(element, -1));
}
Expand All @@ -92,7 +109,50 @@ public IntegerType getByteType() {
}

public int getMemorySizeInBytes(Type type) {
return TypeLayout.of(type).totalSizeInBytes();
if (type instanceof BooleanType) {
return 1;
}
if (type instanceof IntegerType integerType) {
return IntMath.divide(integerType.getBitWidth(), 8, RoundingMode.CEILING);
}
if (type instanceof FloatType floatType) {
return IntMath.divide(floatType.getBitWidth(), 8, RoundingMode.CEILING);
}
if (type instanceof ArrayType arrayType) {
Type elType = arrayType.getElementType();
return getMemorySizeInBytes(elType) * arrayType.getNumElements();
}
if (type instanceof AggregateType aType) {
List<TypeOffset> typeOffsets = aType.getTypeOffsets();
if (typeOffsets.isEmpty()) {
return 0;
}
TypeOffset lastTypeOffset = typeOffsets.get(typeOffsets.size() - 1);
int unpaddedSize = lastTypeOffset.offset() + getMemorySizeInBytes(lastTypeOffset.type());
return paddedSize(unpaddedSize, getAlignment(type));
}
throw new UnsupportedOperationException("Cannot compute memory layout of type " + type);
}

public int getAlignment(Type type) {
if (type instanceof BooleanType || type instanceof IntegerType || type instanceof FloatType) {
return getMemorySizeInBytes(type);
}
if (type instanceof ArrayType arrayType) {
return getMemorySizeInBytes(arrayType.getElementType());
}
if (type instanceof AggregateType aType) {
return aType.getTypeOffsets().stream().map(o -> getAlignment(o.type())).max(Integer::compare).orElseThrow();
}
throw new UnsupportedOperationException("Cannot compute memory layout of type " + type);
}

public static int paddedSize(int size, int alignment) {
int mod = size % alignment;
if (mod > 0) {
return size + alignment - mod;
}
return size;
}

public int getMemorySizeInBits(Type type) {
Expand All @@ -119,16 +179,13 @@ public Map<Integer, Type> decomposeIntoPrimitives(Type type) {
}
}
} else if (type instanceof AggregateType aggregateType) {
final List<Type> fields = aggregateType.getDirectFields();
for (int i = 0; i < fields.size(); i++) {
final int offset = getOffsetInBytes(aggregateType, i);
final Map<Integer, Type> innerDecomposition = decomposeIntoPrimitives(fields.get(i));
for (TypeOffset typeOffset : aggregateType.getTypeOffsets()) {
final Map<Integer, Type> innerDecomposition = decomposeIntoPrimitives(typeOffset.type());
if (innerDecomposition == null) {
return null;
}

for (Map.Entry<Integer, Type> entry : innerDecomposition.entrySet()) {
decomposition.put(entry.getKey() + offset, entry.getValue());
decomposition.put(typeOffset.offset() + entry.getKey(), entry.getValue());
}
}
} else {
Expand All @@ -147,12 +204,7 @@ public static boolean isStaticType(Type type) {
return aType.hasKnownNumElements() && isStaticType(aType.getElementType());
}
if (type instanceof AggregateType aType) {
for (Type elType : aType.getDirectFields()) {
if (!isStaticType(elType)) {
return false;
}
}
return true;
return aType.getTypeOffsets().stream().allMatch(o -> isStaticType(o.type()));
}
throw new UnsupportedOperationException("Cannot compute if type '" + type + "' is static");
}
Expand All @@ -162,12 +214,15 @@ public static boolean isStaticTypeOf(Type staticType, Type runtimeType) {
return true;
}
if (staticType instanceof AggregateType aStaticType && runtimeType instanceof AggregateType aRuntimeType) {
int size = aStaticType.getDirectFields().size();
if (size != aRuntimeType.getDirectFields().size()) {
int size = aStaticType.getTypeOffsets().size();
if (size != aRuntimeType.getTypeOffsets().size()) {
return false;
}
for (int i = 0; i < size; i++) {
if (!isStaticTypeOf(aStaticType.getDirectFields().get(i), aRuntimeType.getDirectFields().get(i))) {
TypeOffset staticTypeOffset = aStaticType.getTypeOffsets().get(i);
TypeOffset runtimeTypeOffset = aRuntimeType.getTypeOffsets().get(i);
if (staticTypeOffset.offset() != runtimeTypeOffset.offset()
|| !isStaticTypeOf(staticTypeOffset.type(), runtimeTypeOffset.type())) {
return false;
}
}
Expand Down

This file was deleted.

Loading

0 comments on commit 550134e

Please sign in to comment.