Merge pull request #700 from dolthub/fulghum/csv
Feature: `COPY FROM STDIN` support for CSV files
Showing 22 changed files with 3,012 additions and 1,276 deletions.
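This change implements the server side of PostgreSQL's `COPY ... FROM STDIN` for CSV-formatted data. For orientation, here is a minimal client-side sketch of the feature in use; it assumes a doltgres server listening locally and the `github.com/jackc/pgx/v5` client, and the connection string and table name are illustrative, not taken from this PR:

package main

import (
	"context"
	"fmt"
	"strings"

	"github.com/jackc/pgx/v5"
)

func main() {
	ctx := context.Background()

	// Illustrative connection string; adjust host/port/credentials for your server.
	conn, err := pgx.Connect(ctx, "postgres://doltgres:password@localhost:5432/mydb")
	if err != nil {
		panic(err)
	}
	defer conn.Close(ctx)

	// COPY ... FROM STDIN streams the reader's contents to the server in CopyData
	// messages; with FORMAT csv, the server parses each line as a CSV record.
	csvData := "1,100,bar\n2,200,bash\n"
	tag, err := conn.PgConn().CopyFrom(ctx, strings.NewReader(csvData),
		"COPY myTable FROM STDIN WITH (FORMAT csv)")
	if err != nil {
		panic(err)
	}
	fmt.Printf("rows copied: %d\n", tag.RowsAffected())
}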
@@ -0,0 +1,176 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package dataloader

import (
	"bufio"
	"fmt"
	"io"
	"strings"

	"github.com/dolthub/dolt/go/libraries/doltcore/table"
	"github.com/dolthub/go-mysql-server/sql"
	"github.com/sirupsen/logrus"

	"github.com/dolthub/doltgresql/server/types"
)

// CsvDataLoader is an implementation of DataLoader that reads data from chunks of CSV files and inserts them into a table.
type CsvDataLoader struct {
	results       LoadDataResults
	partialRecord string
	rowInserter   sql.RowInserter
	colTypes      []types.DoltgresType
	sch           sql.Schema
	removeHeader  bool
}

var _ DataLoader = (*CsvDataLoader)(nil)

// NewCsvDataLoader creates a new DataLoader instance that will insert records from chunks of CSV data into |table|. If
// |header| is true, the first line of the data will be treated as a header and ignored.
func NewCsvDataLoader(ctx *sql.Context, table sql.InsertableTable, header bool) (*CsvDataLoader, error) {
	colTypes, err := getColumnTypes(table.Schema())
	if err != nil {
		return nil, err
	}

	rowInserter := table.Inserter(ctx)
	rowInserter.StatementBegin(ctx)

	return &CsvDataLoader{
		rowInserter:  rowInserter,
		colTypes:     colTypes,
		sch:          table.Schema(),
		removeHeader: header,
	}, nil
}

// LoadChunk implements the DataLoader interface
func (cdl *CsvDataLoader) LoadChunk(ctx *sql.Context, data *bufio.Reader) error {
	combinedReader := newStringPrefixReader(cdl.partialRecord, data)
	cdl.partialRecord = ""

	reader, err := newCsvReader(combinedReader)
	if err != nil {
		return err
	}

	for {
		// Read the next record from the data
		if cdl.removeHeader {
			_, err := reader.readLine()
			cdl.removeHeader = false
			if err != nil {
				return err
			}
		}

		record, err := reader.ReadSqlRow()
		if err != nil {
			if ple, ok := err.(*partialLineError); ok {
				cdl.partialRecord = ple.partialLine
				break
			}

			// csvReader will return a BadRow error if it encounters an input line without the
			// correct number of columns. If we see the end of data marker, then break out of the
			// loop and return from this function without returning an error.
			if _, ok := err.(*table.BadRow); ok {
				if len(record) == 1 && record[0] == "\\." {
					break
				}
			}

			if err != io.EOF {
				return err
			}

			recordValues := make([]string, 0, len(record))
			for _, v := range record {
				recordValues = append(recordValues, fmt.Sprintf("%v", v))
			}
			cdl.partialRecord = strings.Join(recordValues, ",")
			break
		}

		// If we see the end of data marker, then break out of the loop. Normally this will happen in the code
		// above when we receive a BadRow error, since there won't be enough values, but if a table only has
		// one column, we won't get a BadRow error, and we'll handle the end of data marker here.
		if len(record) == 1 && record[0] == "\\." {
			break
		}

		if len(record) > len(cdl.colTypes) {
			return fmt.Errorf("extra data after last expected column")
		} else if len(record) < len(cdl.colTypes) {
			return fmt.Errorf(`missing data for column "%s"`, cdl.sch[len(record)].Name)
		}

		// Cast the values using I/O input
		row := make(sql.Row, len(cdl.colTypes))
		for i := range cdl.colTypes {
			if record[i] == nil {
				row[i] = nil
			} else {
				row[i], err = cdl.colTypes[i].IoInput(ctx, fmt.Sprintf("%v", record[i]))
				if err != nil {
					return err
				}
			}
		}

		// Insert the row
		if err = cdl.rowInserter.Insert(ctx, row); err != nil {
			return err
		}
		cdl.results.RowsLoaded += 1
	}

	return nil
}

// Abort implements the DataLoader interface
func (cdl *CsvDataLoader) Abort(ctx *sql.Context) error {
	defer func() {
		if closeErr := cdl.rowInserter.Close(ctx); closeErr != nil {
			logrus.Warnf("error closing rowInserter: %v", closeErr)
		}
	}()

	return cdl.rowInserter.DiscardChanges(ctx, nil)
}

// Finish implements the DataLoader interface
func (cdl *CsvDataLoader) Finish(ctx *sql.Context) (*LoadDataResults, error) {
	defer func() {
		if closeErr := cdl.rowInserter.Close(ctx); closeErr != nil {
			logrus.Warnf("error closing rowInserter: %v", closeErr)
		}
	}()

	// If there is partial data from the last chunk that hasn't been inserted, return an error.
	if cdl.partialRecord != "" {
		return nil, fmt.Errorf("partial record (%s) found at end of data load", cdl.partialRecord)
	}

	err := cdl.rowInserter.StatementComplete(ctx)
	if err != nil {
		err = cdl.rowInserter.DiscardChanges(ctx, err)
		return nil, err
	}

	return &cdl.results, nil
}
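The test file below exercises this lifecycle directly. As a summary of the intended call sequence (one LoadChunk per incoming chunk of COPY data, Finish on success, Abort on failure), here is a hedged sketch; the driver function and its chunking are hypothetical illustrations, not handler code from this PR, and it assumes the same package imports (bufio, bytes) as the code above:

// loadCsvChunks is a hypothetical driver, for illustration only: it feeds
// buffered chunks of CSV data through a CsvDataLoader the way a COPY FROM
// STDIN handler might.
func loadCsvChunks(ctx *sql.Context, table sql.InsertableTable, chunks [][]byte) (*LoadDataResults, error) {
	loader, err := NewCsvDataLoader(ctx, table, false /* no header row */)
	if err != nil {
		return nil, err
	}
	for _, chunk := range chunks {
		// A chunk may end mid-record; LoadChunk buffers the partial record
		// and prepends it to the next chunk.
		if err := loader.LoadChunk(ctx, bufio.NewReader(bytes.NewReader(chunk))); err != nil {
			_ = loader.Abort(ctx) // discard any partially loaded data
			return nil, err
		}
	}
	// Finish errors if a partial record is still buffered, and otherwise
	// reports how many rows were loaded.
	return loader.Finish(ctx)
}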
@@ -0,0 +1,218 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package dataloader

import (
	"bufio"
	"bytes"
	"context"
	"io"
	"testing"

	"github.com/dolthub/go-mysql-server/memory"
	"github.com/dolthub/go-mysql-server/sql"
	"github.com/stretchr/testify/require"

	"github.com/dolthub/doltgresql/server/types"
)

// TestCsvDataLoader tests the CsvDataLoader implementation.
func TestCsvDataLoader(t *testing.T) {
	db := memory.NewDatabase("mydb")
	provider := memory.NewDBProvider(db)

	ctx := &sql.Context{
		Context: context.Background(),
		Session: memory.NewSession(sql.NewBaseSession(), provider),
	}

	pkSchema := sql.NewPrimaryKeySchema(sql.Schema{
		{Name: "pk", Type: types.Int64, Source: "source1"},
		{Name: "c1", Type: types.Int64, Source: "source1"},
		{Name: "c2", Type: types.VarChar, Source: "source1"},
	}, 0)

	// Tests that a basic CSV document can be loaded as a single chunk.
	t.Run("basic case", func(t *testing.T) {
		table := memory.NewTable(db, "myTable", pkSchema, nil)
		dataLoader, err := NewCsvDataLoader(ctx, table, false)
		require.NoError(t, err)

		// Load all the data as a single chunk
		reader := bytes.NewReader([]byte("1,100,bar\n2,200,bash\n"))
		err = dataLoader.LoadChunk(ctx, bufio.NewReader(reader))
		require.NoError(t, err)
		results, err := dataLoader.Finish(ctx)
		require.NoError(t, err)
		require.EqualValues(t, 2, results.RowsLoaded)

		// Assert that the table contains the expected data
		assertRows(t, ctx, table, [][]any{
			{int64(1), int64(100), "bar"},
			{int64(2), int64(200), "bash"},
		})
	})

	// Tests when a CSV record is split across two chunks of data, and the
	// partial record must be buffered and prepended to the next chunk.
	t.Run("record split across two chunks", func(t *testing.T) {
		table := memory.NewTable(db, "myTable", pkSchema, nil)
		dataLoader, err := NewCsvDataLoader(ctx, table, false)
		require.NoError(t, err)

		// Load the first chunk
		reader := bytes.NewReader([]byte("1,100,ba"))
		err = dataLoader.LoadChunk(ctx, bufio.NewReader(reader))
		require.NoError(t, err)

		// Load the second chunk
		reader = bytes.NewReader([]byte("r\n2,200,bash\n"))
		err = dataLoader.LoadChunk(ctx, bufio.NewReader(reader))
		require.NoError(t, err)

		// Finish
		results, err := dataLoader.Finish(ctx)
		require.NoError(t, err)
		require.EqualValues(t, 2, results.RowsLoaded)

		// Assert that the table contains the expected data
		assertRows(t, ctx, table, [][]any{
			{int64(1), int64(100), "bar"},
			{int64(2), int64(200), "bash"},
		})
	})

	// Tests when a CSV record is split across two chunks of data, and a
	// header row is present.
	t.Run("record split across two chunks, with header", func(t *testing.T) {
		table := memory.NewTable(db, "myTable", pkSchema, nil)
		dataLoader, err := NewCsvDataLoader(ctx, table, true)
		require.NoError(t, err)

		// Load the first chunk
		reader := bytes.NewReader([]byte("pk,c1,c2\n1,100,ba"))
		err = dataLoader.LoadChunk(ctx, bufio.NewReader(reader))
		require.NoError(t, err)

		// Load the second chunk
		reader = bytes.NewReader([]byte("r\n2,200,bash\n"))
		err = dataLoader.LoadChunk(ctx, bufio.NewReader(reader))
		require.NoError(t, err)

		// Finish
		results, err := dataLoader.Finish(ctx)
		require.NoError(t, err)
		require.EqualValues(t, 2, results.RowsLoaded)

		// Assert that the table contains the expected data
		assertRows(t, ctx, table, [][]any{
			{int64(1), int64(100), "bar"},
			{int64(2), int64(200), "bash"},
		})
	})

	// Tests a CSV record that contains a quoted newline character and is split
	// across two chunks.
	t.Run("quoted newlines across two chunks", func(t *testing.T) {
		table := memory.NewTable(db, "myTable", pkSchema, nil)
		dataLoader, err := NewCsvDataLoader(ctx, table, false)
		require.NoError(t, err)

		// Load the first chunk
		reader := bytes.NewReader([]byte("1,100,\"baz\nbar\n"))
		err = dataLoader.LoadChunk(ctx, bufio.NewReader(reader))
		require.NoError(t, err)

		// Load the second chunk
		reader = bytes.NewReader([]byte("bash\"\n2,200,bash\n"))
		err = dataLoader.LoadChunk(ctx, bufio.NewReader(reader))
		require.NoError(t, err)

		// Finish
		results, err := dataLoader.Finish(ctx)
		require.NoError(t, err)
		require.EqualValues(t, 2, results.RowsLoaded)

		// Assert that the table contains the expected data
		assertRows(t, ctx, table, [][]any{
			{int64(1), int64(100), "baz\nbar\nbash"},
			{int64(2), int64(200), "bash"},
		})
	})

	// Test that calling Abort() does not insert any data into the table.
	t.Run("abort cancels data load", func(t *testing.T) {
		table := memory.NewTable(db, "myTable", pkSchema, nil)
		dataLoader, err := NewCsvDataLoader(ctx, table, false)
		require.NoError(t, err)

		// Load the first chunk
		reader := bytes.NewReader([]byte("1,100,bazbar\n"))
		err = dataLoader.LoadChunk(ctx, bufio.NewReader(reader))
		require.NoError(t, err)

		// Load the second chunk
		reader = bytes.NewReader([]byte("2,200,bash\n"))
		err = dataLoader.LoadChunk(ctx, bufio.NewReader(reader))
		require.NoError(t, err)

		// Abort
		err = dataLoader.Abort(ctx)
		require.NoError(t, err)

		// Assert that the table does not contain any of the data from the CSV load
		assertRows(t, ctx, table, [][]any{})
	})
}

// assertRows asserts that the rows in |table| match |expectedRows| and fails the test if the
// rows do not exactly match.
func assertRows(t *testing.T, ctx *sql.Context, table *memory.Table, expectedRows [][]any) {
	partitions, err := table.Partitions(ctx)
	require.NoError(t, err)

	expectedRowsIdx := 0

	for {
		partition, err := partitions.Next(ctx)
		if err == io.EOF {
			break
		}
		require.NoError(t, err)
		rows := table.GetPartition(string(partition.Key()))
		for _, row := range rows {
			if len(expectedRows) <= expectedRowsIdx {
				t.Fatalf("Expected %d rows, got more", len(expectedRows))
			}

			if len(expectedRows[expectedRowsIdx]) != len(row) {
				t.Fatalf("Expected row length %d, got %d. expectedRows: %v, rows: %v",
					len(expectedRows[expectedRowsIdx]), len(row), expectedRows, rows)
			}
			for i := range len(row) {
				if expectedRows[expectedRowsIdx][i] != row[i] {
					t.Fatalf("Expected row %v, got %v. expectedRows: %v, rows: %v",
						expectedRows[expectedRowsIdx], row, expectedRows, rows)
				}
			}

			expectedRowsIdx += 1
		}
	}

	if len(expectedRows) != expectedRowsIdx {
		t.Fatalf("Expected %d rows, got %d", len(expectedRows), expectedRowsIdx)
	}
}