diff --git a/row_group_test.go b/row_group_test.go index 4474438c..cb85a976 100644 --- a/row_group_test.go +++ b/row_group_test.go @@ -2,6 +2,7 @@ package parquet_test import ( "bytes" + "io" "reflect" "sort" "testing" @@ -421,3 +422,27 @@ func TestMergeRowGroups(t *testing.T) { }) } } + +func TestWriteRowGroupClosesRows(t *testing.T) { + var rows []*wrappedRows + rg := wrappedRowGroup{ + RowGroup: newPeopleFile([]Person{{}}), + rowsCallback: func(r parquet.Rows) parquet.Rows { + wrapped := &wrappedRows{Rows: r} + rows = append(rows, wrapped) + return wrapped + }, + } + writer := parquet.NewWriter(io.Discard) + if _, err := writer.WriteRowGroup(rg); err != nil { + t.Fatal(err) + } + if err := writer.Close(); err != nil { + t.Fatal(err) + } + for _, r := range rows { + if !r.closed { + t.Fatal("rows not closed") + } + } +} diff --git a/writer.go b/writer.go index 00ac1dec..ef76ef4f 100644 --- a/writer.go +++ b/writer.go @@ -179,7 +179,9 @@ func (w *Writer) WriteRowGroup(rowGroup RowGroup) (int64, error) { return 0, err } w.writer.configureBloomFilters(rowGroup.ColumnChunks()) - n, err := CopyRows(w.writer, rowGroup.Rows()) + rows := rowGroup.Rows() + defer rows.Close() + n, err := CopyRows(w.writer, rows) if err != nil { return n, err }