Skip to content

Commit

Permalink
Adding test for #41728
Browse files Browse the repository at this point in the history
  • Loading branch information
FiV0 committed May 20, 2024
1 parent cc951e6 commit 934466e
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,11 @@

import static org.junit.Assert.*;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.nio.ByteBuffer;
import java.util.*;

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.util.hash.ArrowBufHasher;
import org.apache.arrow.vector.complex.AbstractStructVector;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.complex.StructVector;
Expand All @@ -35,15 +34,19 @@
import org.apache.arrow.vector.holders.ComplexHolder;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.Types.MinorType;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.ArrowType.Struct;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.TestExtensionType;
import org.apache.arrow.vector.util.TransferPair;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import javax.xml.crypto.dsig.TransformService;

public class TestStructVector {

private BufferAllocator allocator;
Expand Down Expand Up @@ -82,6 +85,45 @@ public void testMakeTransferPair() {
}
}

@Test
public void testStructVectorWithExtensionTypes() {
TestExtensionType.UuidType uuidType = new TestExtensionType.UuidType();
Field uuidField = new Field("struct_child", FieldType.nullable(uuidType), null);
Field structField = new Field("struct", FieldType.nullable(new ArrowType.Struct()), List.of(uuidField));
// throws
StructVector s1 = new StructVector(structField, allocator, null);
// doesn't throw
StructVector s2 = (StructVector) structField.createVector(allocator);
s1.close();
s2.close();
}

@Test
public void testStructVectorTransferPairWithExtensionType() {
TestExtensionType.UuidType uuidType = new TestExtensionType.UuidType();
Field uuidField = new Field("uuid_child", FieldType.nullable(uuidType), null);
Field structField = new Field("struct", FieldType.nullable(new ArrowType.Struct()), List.of(uuidField));

StructVector s1 = (StructVector) structField.createVector(allocator);
TestExtensionType.UuidVector uuidVector = s1.addOrGet("uuid_child", FieldType.nullable(uuidType), TestExtensionType.UuidVector.class);
s1.setValueCount(1);
uuidVector.set(0, new UUID(1, 2));
s1.setIndexDefined(0);

TransferPair tp = s1.getTransferPair(structField, allocator);
final StructVector toVector = (StructVector) tp.getTo();
assertEquals(s1.getField(), toVector.getField());
assertEquals(s1.getField().getChildren().get(0), toVector.getField().getChildren().get(0));
// also fails but probably another issue
// assertEquals(s1.getValueCount(), toVector.getValueCount());
// assertEquals(s1, toVector);

s1.close();
toVector.close();
}



@Test
public void testAllocateAfterReAlloc() throws Exception {
Map<String, String> metadata = new HashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,23 +32,21 @@
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.util.Collections;
import java.util.Map;
import java.util.UUID;

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.memory.util.hash.ArrowBufHasher;
import org.apache.arrow.vector.ExtensionTypeVector;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.FixedSizeBinaryVector;
import org.apache.arrow.vector.Float4Vector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.*;
import org.apache.arrow.vector.compare.Range;
import org.apache.arrow.vector.compare.RangeEqualsVisitor;
import org.apache.arrow.vector.complex.StructVector;
import org.apache.arrow.vector.ipc.ArrowFileReader;
import org.apache.arrow.vector.ipc.ArrowFileWriter;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType;
import org.apache.arrow.vector.util.TransferPair;
import org.apache.arrow.vector.util.VectorBatchAppender;
import org.apache.arrow.vector.validate.ValidateVectorVisitor;
import org.junit.Assert;
Expand Down Expand Up @@ -85,21 +83,21 @@ public void roundtripUuid() throws IOException {
final ArrowFileReader reader = new ArrowFileReader(channel, allocator)) {
reader.loadNextBatch();
final VectorSchemaRoot readerRoot = reader.getVectorSchemaRoot();
Assert.assertEquals(root.getSchema(), readerRoot.getSchema());
assertEquals(root.getSchema(), readerRoot.getSchema());

final Field field = readerRoot.getSchema().getFields().get(0);
final UuidType expectedType = new UuidType();
Assert.assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_NAME),
assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_NAME),
expectedType.extensionName());
Assert.assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_METADATA),
assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_METADATA),
expectedType.serialize());

final ExtensionTypeVector deserialized = (ExtensionTypeVector) readerRoot.getFieldVectors().get(0);
Assert.assertEquals(vector.getValueCount(), deserialized.getValueCount());
assertEquals(vector.getValueCount(), deserialized.getValueCount());
for (int i = 0; i < vector.getValueCount(); i++) {
Assert.assertEquals(vector.isNull(i), deserialized.isNull(i));
assertEquals(vector.isNull(i), deserialized.isNull(i));
if (!vector.isNull(i)) {
Assert.assertEquals(vector.getObject(i), deserialized.getObject(i));
assertEquals(vector.getObject(i), deserialized.getObject(i));
}
}
}
Expand Down Expand Up @@ -138,23 +136,23 @@ public void readUnderlyingType() throws IOException {
final ArrowFileReader reader = new ArrowFileReader(channel, allocator)) {
reader.loadNextBatch();
final VectorSchemaRoot readerRoot = reader.getVectorSchemaRoot();
Assert.assertEquals(1, readerRoot.getSchema().getFields().size());
Assert.assertEquals("a", readerRoot.getSchema().getFields().get(0).getName());
Assert.assertTrue(readerRoot.getSchema().getFields().get(0).getType() instanceof ArrowType.FixedSizeBinary);
Assert.assertEquals(16,
assertEquals(1, readerRoot.getSchema().getFields().size());
assertEquals("a", readerRoot.getSchema().getFields().get(0).getName());
assertTrue(readerRoot.getSchema().getFields().get(0).getType() instanceof ArrowType.FixedSizeBinary);
assertEquals(16,
((ArrowType.FixedSizeBinary) readerRoot.getSchema().getFields().get(0).getType()).getByteWidth());

final Field field = readerRoot.getSchema().getFields().get(0);
final UuidType expectedType = new UuidType();
Assert.assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_NAME),
assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_NAME),
expectedType.extensionName());
Assert.assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_METADATA),
assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_METADATA),
expectedType.serialize());

final FixedSizeBinaryVector deserialized = (FixedSizeBinaryVector) readerRoot.getFieldVectors().get(0);
Assert.assertEquals(vector.getValueCount(), deserialized.getValueCount());
assertEquals(vector.getValueCount(), deserialized.getValueCount());
for (int i = 0; i < vector.getValueCount(); i++) {
Assert.assertEquals(vector.isNull(i), deserialized.isNull(i));
assertEquals(vector.isNull(i), deserialized.isNull(i));
if (!vector.isNull(i)) {
final UUID uuid = vector.getObject(i);
final ByteBuffer bb = ByteBuffer.allocate(16);
Expand Down Expand Up @@ -210,26 +208,26 @@ public void roundtripLocation() throws IOException {
final ArrowFileReader reader = new ArrowFileReader(channel, allocator)) {
reader.loadNextBatch();
final VectorSchemaRoot readerRoot = reader.getVectorSchemaRoot();
Assert.assertEquals(root.getSchema(), readerRoot.getSchema());
assertEquals(root.getSchema(), readerRoot.getSchema());

final Field field = readerRoot.getSchema().getFields().get(0);
final LocationType expectedType = new LocationType();
Assert.assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_NAME),
assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_NAME),
expectedType.extensionName());
Assert.assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_METADATA),
assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_METADATA),
expectedType.serialize());

final ExtensionTypeVector deserialized = (ExtensionTypeVector) readerRoot.getFieldVectors().get(0);
Assert.assertTrue(deserialized instanceof LocationVector);
Assert.assertEquals("location", deserialized.getName());
assertTrue(deserialized instanceof LocationVector);
assertEquals("location", deserialized.getName());
StructVector deserStruct = (StructVector) deserialized.getUnderlyingVector();
Assert.assertNotNull(deserStruct.getChild("Latitude"));
Assert.assertNotNull(deserStruct.getChild("Longitude"));
Assert.assertEquals(vector.getValueCount(), deserialized.getValueCount());
assertEquals(vector.getValueCount(), deserialized.getValueCount());
for (int i = 0; i < vector.getValueCount(); i++) {
Assert.assertEquals(vector.isNull(i), deserialized.isNull(i));
assertEquals(vector.isNull(i), deserialized.isNull(i));
if (!vector.isNull(i)) {
Assert.assertEquals(vector.getObject(i), deserialized.getObject(i));
assertEquals(vector.getObject(i), deserialized.getObject(i));
}
}
}
Expand Down Expand Up @@ -278,11 +276,11 @@ public void testVectorCompare() {
}
}

static class UuidType extends ExtensionType {
public static class UuidType extends ExtensionType {

@Override
public ArrowType storageType() {
return new ArrowType.FixedSizeBinary(16);
return new FixedSizeBinary(16);
}

@Override
Expand Down Expand Up @@ -314,10 +312,17 @@ public FieldVector getNewVector(String name, FieldType fieldType, BufferAllocato
}
}

static class UuidVector extends ExtensionTypeVector<FixedSizeBinaryVector> {
public static class UuidVector extends ExtensionTypeVector<FixedSizeBinaryVector> {
private final Field field;

public UuidVector(String name, BufferAllocator allocator, FixedSizeBinaryVector underlyingVector) {
super(name, allocator, underlyingVector);
this.field = new Field(name, FieldType.nullable(new UuidType()), null);
}

@Override
public Field getField() {
return field;
}

@Override
Expand All @@ -342,6 +347,34 @@ public void set(int index, UUID uuid) {
bb.putLong(uuid.getLeastSignificantBits());
getUnderlyingVector().set(index, bb.array());
}

@Override
public TransferPair makeTransferPair(ValueVector to) {
ValueVector targetUnderlyingVector = ((UuidVector) to).getUnderlyingVector();
TransferPair tp = getUnderlyingVector().makeTransferPair(targetUnderlyingVector);

return new TransferPair() {
@Override
public void transfer() {
tp.transfer();
}

@Override
public void splitAndTransfer(int startIndex, int length) {
tp.splitAndTransfer(startIndex, length);
}

@Override
public ValueVector getTo() {
return to;
}

@Override
public void copyValueSafe(int fromIndex, int toIndex) {
tp.copyValueSafe(fromIndex, toIndex);
}
};
}
}

static class LocationType extends ExtensionType {
Expand Down Expand Up @@ -407,7 +440,7 @@ public int hashCode(int index, ArrowBufHasher hasher) {
}

@Override
public java.util.Map<String, ?> getObject(int index) {
public Map<String, ?> getObject(int index) {
return getUnderlyingVector().getObject(index);
}

Expand Down

0 comments on commit 934466e

Please sign in to comment.