diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestStructVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestStructVector.java index 68f5e14dabb9b..a5faecaa484b0 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestStructVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestStructVector.java @@ -19,12 +19,11 @@ import static org.junit.Assert.*; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.nio.ByteBuffer; +import java.util.*; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.util.hash.ArrowBufHasher; import org.apache.arrow.vector.complex.AbstractStructVector; import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.StructVector; @@ -35,15 +34,19 @@ import org.apache.arrow.vector.holders.ComplexHolder; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.ArrowType.Struct; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.TestExtensionType; import org.apache.arrow.vector.util.TransferPair; import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import javax.xml.crypto.dsig.TransformService; + public class TestStructVector { private BufferAllocator allocator; @@ -82,6 +85,45 @@ public void testMakeTransferPair() { } } + @Test + public void testStructVectorWithExtensionTypes() { + TestExtensionType.UuidType uuidType = new TestExtensionType.UuidType(); + Field uuidField = new Field("struct_child", FieldType.nullable(uuidType), null); + Field structField = new Field("struct", FieldType.nullable(new ArrowType.Struct()), List.of(uuidField)); + // throws + StructVector s1 = new StructVector(structField, allocator, null); + // doesn't throw + StructVector s2 = (StructVector) structField.createVector(allocator); + s1.close(); + s2.close(); + } + + @Test + public void testStructVectorTransferPairWithExtensionType() { + TestExtensionType.UuidType uuidType = new TestExtensionType.UuidType(); + Field uuidField = new Field("uuid_child", FieldType.nullable(uuidType), null); + Field structField = new Field("struct", FieldType.nullable(new ArrowType.Struct()), List.of(uuidField)); + + StructVector s1 = (StructVector) structField.createVector(allocator); + TestExtensionType.UuidVector uuidVector = s1.addOrGet("uuid_child", FieldType.nullable(uuidType), TestExtensionType.UuidVector.class); + s1.setValueCount(1); + uuidVector.set(0, new UUID(1, 2)); + s1.setIndexDefined(0); + + TransferPair tp = s1.getTransferPair(structField, allocator); + final StructVector toVector = (StructVector) tp.getTo(); + assertEquals(s1.getField(), toVector.getField()); + assertEquals(s1.getField().getChildren().get(0), toVector.getField().getChildren().get(0)); + // also fails but probably another issue +// assertEquals(s1.getValueCount(), toVector.getValueCount()); +// assertEquals(s1, toVector); + + s1.close(); + toVector.close(); + } + + + @Test public void testAllocateAfterReAlloc() throws Exception { Map metadata = new HashMap<>(); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java b/java/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java index 872b2f3934b07..bb31e168685c2 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java @@ -32,16 +32,13 @@ import java.nio.file.Paths; import java.nio.file.StandardOpenOption; import java.util.Collections; +import java.util.Map; import java.util.UUID; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.memory.util.hash.ArrowBufHasher; -import org.apache.arrow.vector.ExtensionTypeVector; -import org.apache.arrow.vector.FieldVector; -import org.apache.arrow.vector.FixedSizeBinaryVector; -import org.apache.arrow.vector.Float4Vector; -import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.*; import org.apache.arrow.vector.compare.Range; import org.apache.arrow.vector.compare.RangeEqualsVisitor; import org.apache.arrow.vector.complex.StructVector; @@ -49,6 +46,7 @@ import org.apache.arrow.vector.ipc.ArrowFileWriter; import org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType; +import org.apache.arrow.vector.util.TransferPair; import org.apache.arrow.vector.util.VectorBatchAppender; import org.apache.arrow.vector.validate.ValidateVectorVisitor; import org.junit.Assert; @@ -85,21 +83,21 @@ public void roundtripUuid() throws IOException { final ArrowFileReader reader = new ArrowFileReader(channel, allocator)) { reader.loadNextBatch(); final VectorSchemaRoot readerRoot = reader.getVectorSchemaRoot(); - Assert.assertEquals(root.getSchema(), readerRoot.getSchema()); + assertEquals(root.getSchema(), readerRoot.getSchema()); final Field field = readerRoot.getSchema().getFields().get(0); final UuidType expectedType = new UuidType(); - Assert.assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_NAME), + assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_NAME), expectedType.extensionName()); - Assert.assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_METADATA), + assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_METADATA), expectedType.serialize()); final ExtensionTypeVector deserialized = (ExtensionTypeVector) readerRoot.getFieldVectors().get(0); - Assert.assertEquals(vector.getValueCount(), deserialized.getValueCount()); + assertEquals(vector.getValueCount(), deserialized.getValueCount()); for (int i = 0; i < vector.getValueCount(); i++) { - Assert.assertEquals(vector.isNull(i), deserialized.isNull(i)); + assertEquals(vector.isNull(i), deserialized.isNull(i)); if (!vector.isNull(i)) { - Assert.assertEquals(vector.getObject(i), deserialized.getObject(i)); + assertEquals(vector.getObject(i), deserialized.getObject(i)); } } } @@ -138,23 +136,23 @@ public void readUnderlyingType() throws IOException { final ArrowFileReader reader = new ArrowFileReader(channel, allocator)) { reader.loadNextBatch(); final VectorSchemaRoot readerRoot = reader.getVectorSchemaRoot(); - Assert.assertEquals(1, readerRoot.getSchema().getFields().size()); - Assert.assertEquals("a", readerRoot.getSchema().getFields().get(0).getName()); - Assert.assertTrue(readerRoot.getSchema().getFields().get(0).getType() instanceof ArrowType.FixedSizeBinary); - Assert.assertEquals(16, + assertEquals(1, readerRoot.getSchema().getFields().size()); + assertEquals("a", readerRoot.getSchema().getFields().get(0).getName()); + assertTrue(readerRoot.getSchema().getFields().get(0).getType() instanceof ArrowType.FixedSizeBinary); + assertEquals(16, ((ArrowType.FixedSizeBinary) readerRoot.getSchema().getFields().get(0).getType()).getByteWidth()); final Field field = readerRoot.getSchema().getFields().get(0); final UuidType expectedType = new UuidType(); - Assert.assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_NAME), + assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_NAME), expectedType.extensionName()); - Assert.assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_METADATA), + assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_METADATA), expectedType.serialize()); final FixedSizeBinaryVector deserialized = (FixedSizeBinaryVector) readerRoot.getFieldVectors().get(0); - Assert.assertEquals(vector.getValueCount(), deserialized.getValueCount()); + assertEquals(vector.getValueCount(), deserialized.getValueCount()); for (int i = 0; i < vector.getValueCount(); i++) { - Assert.assertEquals(vector.isNull(i), deserialized.isNull(i)); + assertEquals(vector.isNull(i), deserialized.isNull(i)); if (!vector.isNull(i)) { final UUID uuid = vector.getObject(i); final ByteBuffer bb = ByteBuffer.allocate(16); @@ -210,26 +208,26 @@ public void roundtripLocation() throws IOException { final ArrowFileReader reader = new ArrowFileReader(channel, allocator)) { reader.loadNextBatch(); final VectorSchemaRoot readerRoot = reader.getVectorSchemaRoot(); - Assert.assertEquals(root.getSchema(), readerRoot.getSchema()); + assertEquals(root.getSchema(), readerRoot.getSchema()); final Field field = readerRoot.getSchema().getFields().get(0); final LocationType expectedType = new LocationType(); - Assert.assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_NAME), + assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_NAME), expectedType.extensionName()); - Assert.assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_METADATA), + assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_METADATA), expectedType.serialize()); final ExtensionTypeVector deserialized = (ExtensionTypeVector) readerRoot.getFieldVectors().get(0); - Assert.assertTrue(deserialized instanceof LocationVector); - Assert.assertEquals("location", deserialized.getName()); + assertTrue(deserialized instanceof LocationVector); + assertEquals("location", deserialized.getName()); StructVector deserStruct = (StructVector) deserialized.getUnderlyingVector(); Assert.assertNotNull(deserStruct.getChild("Latitude")); Assert.assertNotNull(deserStruct.getChild("Longitude")); - Assert.assertEquals(vector.getValueCount(), deserialized.getValueCount()); + assertEquals(vector.getValueCount(), deserialized.getValueCount()); for (int i = 0; i < vector.getValueCount(); i++) { - Assert.assertEquals(vector.isNull(i), deserialized.isNull(i)); + assertEquals(vector.isNull(i), deserialized.isNull(i)); if (!vector.isNull(i)) { - Assert.assertEquals(vector.getObject(i), deserialized.getObject(i)); + assertEquals(vector.getObject(i), deserialized.getObject(i)); } } } @@ -278,11 +276,11 @@ public void testVectorCompare() { } } - static class UuidType extends ExtensionType { + public static class UuidType extends ExtensionType { @Override public ArrowType storageType() { - return new ArrowType.FixedSizeBinary(16); + return new FixedSizeBinary(16); } @Override @@ -314,10 +312,17 @@ public FieldVector getNewVector(String name, FieldType fieldType, BufferAllocato } } - static class UuidVector extends ExtensionTypeVector { + public static class UuidVector extends ExtensionTypeVector { + private final Field field; public UuidVector(String name, BufferAllocator allocator, FixedSizeBinaryVector underlyingVector) { super(name, allocator, underlyingVector); + this.field = new Field(name, FieldType.nullable(new UuidType()), null); + } + + @Override + public Field getField() { + return field; } @Override @@ -342,6 +347,34 @@ public void set(int index, UUID uuid) { bb.putLong(uuid.getLeastSignificantBits()); getUnderlyingVector().set(index, bb.array()); } + + @Override + public TransferPair makeTransferPair(ValueVector to) { + ValueVector targetUnderlyingVector = ((UuidVector) to).getUnderlyingVector(); + TransferPair tp = getUnderlyingVector().makeTransferPair(targetUnderlyingVector); + + return new TransferPair() { + @Override + public void transfer() { + tp.transfer(); + } + + @Override + public void splitAndTransfer(int startIndex, int length) { + tp.splitAndTransfer(startIndex, length); + } + + @Override + public ValueVector getTo() { + return to; + } + + @Override + public void copyValueSafe(int fromIndex, int toIndex) { + tp.copyValueSafe(fromIndex, toIndex); + } + }; + } } static class LocationType extends ExtensionType { @@ -407,7 +440,7 @@ public int hashCode(int index, ArrowBufHasher hasher) { } @Override - public java.util.Map getObject(int index) { + public Map getObject(int index) { return getUnderlyingVector().getObject(index); }