diff --git a/js/src/vector.ts b/js/src/vector.ts index f36c691e1bd27..6c2bbbb86a7d2 100644 --- a/js/src/vector.ts +++ b/js/src/vector.ts @@ -399,6 +399,9 @@ export class DictionaryVector extends Vector; constructor(data: Data>, view: View> = new DictionaryView(data.dictionary, new IntVector(data.indices))) { super(data as Data, view); + if (view instanceof ValidityView) { + view = (view as any).view; + } if (data instanceof DictionaryData && view instanceof DictionaryView) { this.indices = view.indices; this.dictionary = data.dictionary; diff --git a/js/test/unit/vector-tests.ts b/js/test/unit/vector-tests.ts index e2be229834f8e..3eb3fbe0195b3 100644 --- a/js/test/unit/vector-tests.ts +++ b/js/test/unit/vector-tests.ts @@ -17,14 +17,15 @@ import { TextEncoder } from 'text-encoding-utf-8'; import Arrow from '../Arrow'; -import { type, TypedArray, TypedArrayConstructor } from '../../src/Arrow'; +import { type, TypedArray, TypedArrayConstructor, Vector } from '../../src/Arrow'; +import { packBools } from '../../src/util/bit' const utf8Encoder = new TextEncoder('utf-8'); -const { BoolData, FlatData, FlatListData } = Arrow.data; -const { IntVector, FloatVector, BoolVector, Utf8Vector } = Arrow.vector; +const { BoolData, FlatData, FlatListData, DictionaryData } = Arrow.data; +const { IntVector, FloatVector, BoolVector, Utf8Vector, DictionaryVector } = Arrow.vector; const { - Utf8, Bool, + Dictionary, Utf8, Bool, Float16, Float32, Float64, Int8, Int16, Int32, Int64, Uint8, Uint16, Uint32, Uint64, @@ -310,6 +311,54 @@ describe(`Utf8Vector`, () => { let offset = 0; const offsets = Uint32Array.of(0, ...values.map((d) => { offset += d.length; return offset; })); const vector = new Utf8Vector(new FlatListData(new Utf8(), n, null, offsets, utf8Encoder.encode(values.join('')))); + basicVectorTests(vector, values, ['abc', '123']); + describe(`sliced`, () => { + basicVectorTests(vector.slice(1,3), values.slice(1,3), ['foo', 'abc']); + }); +}); + +describe(`DictionaryVector`, () => { + const dictionary = ['foo', 'bar', 'baz']; + const extras = ['abc', '123']; // values to search for that should NOT be found + let offset = 0; + const offsets = Uint32Array.of(0, ...dictionary.map((d) => { offset += d.length; return offset; })); + const dictionary_vec = new Utf8Vector(new FlatListData(new Utf8(), dictionary.length, null, offsets, utf8Encoder.encode(dictionary.join('')))); + + const indices = Array.from({length: 50}, () => Math.random() * 3 | 0); + + describe(`index with nullCount == 0`, () => { + const indices_data = new FlatData(new Int32(), indices.length, new Uint8Array(0), indices); + + const values = Array.from(indices).map((d) => dictionary[d]); + const vector = new DictionaryVector(new DictionaryData(new Dictionary(dictionary_vec.type, indices_data.type), dictionary_vec, indices_data)); + + basicVectorTests(vector, values, extras); + + describe(`sliced`, () => { + basicVectorTests(vector.slice(10, 20), values.slice(10,20), extras); + }) + }); + + describe(`index with nullCount > 0`, () => { + const validity = Array.from({length: indices.length}, () => Math.random() > 0.2 ? true : false); + const indices_data = new FlatData(new Int32(), indices.length, packBools(validity), indices, 0, validity.reduce((acc, d) => acc + (d ? 0 : 1), 0)); + const values = Array.from(indices).map((d, i) => validity[i] ? dictionary[d] : null); + const vector = new DictionaryVector(new DictionaryData(new Dictionary(dictionary_vec.type, indices_data.type), dictionary_vec, indices_data)); + + basicVectorTests(vector, values, ['abc', '123']); + describe(`sliced`, () => { + basicVectorTests(vector.slice(10, 20), values.slice(10,20), extras); + }); + }); +}); + +// Creates some basic tests for the given vector. +// Verifies that: +// - `get` and the native iterator return the same data as `values` +// - `indexOf` returns the same indices as `values` +function basicVectorTests(vector: Vector, values: any[], extras: any[]) { + const n = values.length; + test(`gets expected values`, () => { let i = -1; while (++i < n) { @@ -325,14 +374,14 @@ describe(`Utf8Vector`, () => { } }); test(`indexOf returns expected values`, () => { - let testValues = values.concat(['abc', '12345']); + let testValues = values.concat(extras); for (const value of testValues) { const expected = values.indexOf(value); expect(vector.indexOf(value)).toEqual(expected); } }); -}); +} function toMap(entries: Record, keys: string[]) { return keys.reduce((map, key) => {