From e9f6667291b68ef5d82b4a193fdd84c8ef06a2cf Mon Sep 17 00:00:00 2001 From: Joel Lubinitsky <33523178+joellubi@users.noreply.github.com> Date: Wed, 31 Jul 2024 15:07:15 -0500 Subject: [PATCH] GH-43443: [Go] [IPC] Infer schema from first record if not specified (#43484) ### Rationale for this change Fixes: #43443 Makes usage of the IPC writer, and any writers that use it such as the flight writer, simpler. ### What changes are included in this PR? - Infer schema from first record if schema is not specified - IPC and Flight tests ### Are these changes tested? Yes ### Are there any user-facing changes? Any `ipc.Writer` that does not specify the optional argument `ipc.WithSchema` will no longer return an error as long as the incoming stream of records has a consistent schema. * GitHub Issue: #43443 Authored-by: Joel Lubinitsky Signed-off-by: Joel Lubinitsky --- go/arrow/flight/flight_test.go | 35 ++++++++++++++++++++++++++++++++++ go/arrow/ipc/writer.go | 8 ++++++-- go/arrow/ipc/writer_test.go | 19 ++++++++++++++++++ 3 files changed, 60 insertions(+), 2 deletions(-) diff --git a/go/arrow/flight/flight_test.go b/go/arrow/flight/flight_test.go index fe896f39a2b21..a03d839e9484d 100755 --- a/go/arrow/flight/flight_test.go +++ b/go/arrow/flight/flight_test.go @@ -23,11 +23,13 @@ import ( "io" "testing" + "github.com/apache/arrow/go/v18/arrow" "github.com/apache/arrow/go/v18/arrow/array" "github.com/apache/arrow/go/v18/arrow/flight" "github.com/apache/arrow/go/v18/arrow/internal/arrdata" "github.com/apache/arrow/go/v18/arrow/ipc" "github.com/apache/arrow/go/v18/arrow/memory" + "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" @@ -449,3 +451,36 @@ func TestReaderError(t *testing.T) { t.Fatal("should have errored") } } + +func TestWriterInferSchema(t *testing.T) { + recs, ok := arrdata.Records["primitives"] + require.True(t, ok) + + fs := flightStreamWriter{} + w := 
flight.NewRecordWriter(&fs) + + for _, rec := range recs { + require.NoError(t, w.Write(rec)) + } + + require.NoError(t, w.Close()) +} + +func TestWriterInconsistentSchema(t *testing.T) { + recs, ok := arrdata.Records["primitives"] + require.True(t, ok) + + schema := arrow.NewSchema([]arrow.Field{{Name: "unknown", Type: arrow.PrimitiveTypes.Int8}}, nil) + fs := flightStreamWriter{} + w := flight.NewRecordWriter(&fs, ipc.WithSchema(schema)) + + require.ErrorContains(t, w.Write(recs[0]), "arrow/ipc: tried to write record batch with different schema") + require.NoError(t, w.Close()) +} + +type flightStreamWriter struct{} + +// Send implements flight.DataStreamWriter. +func (f *flightStreamWriter) Send(data *flight.FlightData) error { return nil } + +var _ flight.DataStreamWriter = (*flightStreamWriter)(nil) diff --git a/go/arrow/ipc/writer.go b/go/arrow/ipc/writer.go index ca4f77d35e17f..02c67635bb2fd 100644 --- a/go/arrow/ipc/writer.go +++ b/go/arrow/ipc/writer.go @@ -159,15 +159,19 @@ func (w *Writer) Write(rec arrow.Record) (err error) { } }() + incomingSchema := rec.Schema() + if !w.started { + if w.schema == nil { + w.schema = incomingSchema + } err := w.start() if err != nil { return err } } - schema := rec.Schema() - if schema == nil || !schema.Equal(w.schema) { + if incomingSchema == nil || !incomingSchema.Equal(w.schema) { return errInconsistentSchema } diff --git a/go/arrow/ipc/writer_test.go b/go/arrow/ipc/writer_test.go index e5683243e4546..60d811e68e87e 100644 --- a/go/arrow/ipc/writer_test.go +++ b/go/arrow/ipc/writer_test.go @@ -235,3 +235,22 @@ func TestWriteWithCompressionAndMinSavings(t *testing.T) { } } } + +func TestWriterInferSchema(t *testing.T) { + bldr := array.NewRecordBuilder(memory.DefaultAllocator, arrow.NewSchema([]arrow.Field{{Name: "col", Type: arrow.PrimitiveTypes.Int8}}, nil)) + bldr.Field(0).(*array.Int8Builder).AppendValues([]int8{1, 2, 3, 4, 5}, nil) + rec := bldr.NewRecord() + defer rec.Release() + + var buf bytes.Buffer + w := 
NewWriter(&buf) + + require.NoError(t, w.Write(rec)) + require.NoError(t, w.Close()) + + r, err := NewReader(&buf) + require.NoError(t, err) + defer r.Release() + + require.True(t, r.Schema().Equal(rec.Schema())) +}