Merge pull request #700 from dolthub/fulghum/csv
Feature: `COPY FROM STDIN` support for CSV files
fulghum authored Sep 19, 2024
2 parents 200b5e9 + d46078c commit 1c13450
Showing 22 changed files with 3,012 additions and 1,276 deletions.
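As context for reviewers, the client-side flow this feature enables looks roughly like the sketch below. This is not code from the PR: the connection string, table name, CSV file, and the use of pgx's CopyFrom are illustrative assumptions about how a Postgres client might stream CSV data to a Doltgres server with COPY FROM STDIN.

package main

import (
	"context"
	"log"
	"os"

	"github.com/jackc/pgx/v5"
)

func main() {
	ctx := context.Background()

	// Hypothetical connection string for a locally running Doltgres server.
	conn, err := pgx.Connect(ctx, "postgres://postgres:password@localhost:5432/mydb")
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close(ctx)

	// Stream a local CSV file (no header row) to the server. The server-side
	// loader added in this PR consumes the data in chunks as it arrives.
	f, err := os.Open("users.csv")
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	if _, err = conn.PgConn().CopyFrom(ctx, f, "COPY users FROM STDIN WITH (FORMAT csv)"); err != nil {
		log.Fatal(err)
	}
}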
176 changes: 176 additions & 0 deletions core/dataloader/csvdataloader.go
@@ -0,0 +1,176 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package dataloader

import (
	"bufio"
	"fmt"
	"io"
	"strings"

	"github.com/dolthub/dolt/go/libraries/doltcore/table"
	"github.com/dolthub/go-mysql-server/sql"
	"github.com/sirupsen/logrus"

	"github.com/dolthub/doltgresql/server/types"
)

// CsvDataLoader is an implementation of DataLoader that reads data from chunks of CSV files and inserts them into a table.
type CsvDataLoader struct {
	results       LoadDataResults
	partialRecord string
	rowInserter   sql.RowInserter
	colTypes      []types.DoltgresType
	sch           sql.Schema
	removeHeader  bool
}

var _ DataLoader = (*CsvDataLoader)(nil)

// NewCsvDataLoader creates a new DataLoader instance that will insert records from chunks of CSV data into |table|. If
// |header| is true, the first line of the data will be treated as a header and ignored.
func NewCsvDataLoader(ctx *sql.Context, table sql.InsertableTable, header bool) (*CsvDataLoader, error) {
	colTypes, err := getColumnTypes(table.Schema())
	if err != nil {
		return nil, err
	}

	rowInserter := table.Inserter(ctx)
	rowInserter.StatementBegin(ctx)

	return &CsvDataLoader{
		rowInserter:  rowInserter,
		colTypes:     colTypes,
		sch:          table.Schema(),
		removeHeader: header,
	}, nil
}

// LoadChunk implements the DataLoader interface
func (cdl *CsvDataLoader) LoadChunk(ctx *sql.Context, data *bufio.Reader) error {
	combinedReader := newStringPrefixReader(cdl.partialRecord, data)
	cdl.partialRecord = ""

	reader, err := newCsvReader(combinedReader)
	if err != nil {
		return err
	}

	for {
		// Read the next record from the data
		if cdl.removeHeader {
			_, err := reader.readLine()
			cdl.removeHeader = false
			if err != nil {
				return err
			}
		}

		record, err := reader.ReadSqlRow()
		if err != nil {
			if ple, ok := err.(*partialLineError); ok {
				cdl.partialRecord = ple.partialLine
				break
			}

			// csvReader will return a BadRow error if it encounters an input line without the
			// correct number of columns. If we see the end of data marker, then break out of the
			// loop and return from this function without returning an error.
			if _, ok := err.(*table.BadRow); ok {
				if len(record) == 1 && record[0] == "\\." {
					break
				}
			}

			if err != io.EOF {
				return err
			}

			recordValues := make([]string, 0, len(record))
			for _, v := range record {
				recordValues = append(recordValues, fmt.Sprintf("%v", v))
			}
			cdl.partialRecord = strings.Join(recordValues, ",")
			break
		}

		// If we see the end of data marker, then break out of the loop. Normally this will happen in the code
		// above when we receive a BadRow error, since there won't be enough values, but if a table only has
		// one column, we won't get a BadRow error, and we'll handle the end of data marker here.
		if len(record) == 1 && record[0] == "\\." {
			break
		}

		if len(record) > len(cdl.colTypes) {
			return fmt.Errorf("extra data after last expected column")
		} else if len(record) < len(cdl.colTypes) {
			return fmt.Errorf(`missing data for column "%s"`, cdl.sch[len(record)].Name)
		}

		// Cast the values using I/O input
		row := make(sql.Row, len(cdl.colTypes))
		for i := range cdl.colTypes {
			if record[i] == nil {
				row[i] = nil
			} else {
				row[i], err = cdl.colTypes[i].IoInput(ctx, fmt.Sprintf("%v", record[i]))
				if err != nil {
					return err
				}
			}
		}

		// Insert the row
		if err = cdl.rowInserter.Insert(ctx, row); err != nil {
			return err
		}
		cdl.results.RowsLoaded += 1
	}

	return nil
}

// Abort implements the DataLoader interface
func (cdl *CsvDataLoader) Abort(ctx *sql.Context) error {
	defer func() {
		if closeErr := cdl.rowInserter.Close(ctx); closeErr != nil {
			logrus.Warnf("error closing rowInserter: %v", closeErr)
		}
	}()

	return cdl.rowInserter.DiscardChanges(ctx, nil)
}

// Finish implements the DataLoader interface
func (cdl *CsvDataLoader) Finish(ctx *sql.Context) (*LoadDataResults, error) {
	defer func() {
		if closeErr := cdl.rowInserter.Close(ctx); closeErr != nil {
			logrus.Warnf("error closing rowInserter: %v", closeErr)
		}
	}()

	// If there is partial data from the last chunk that hasn't been inserted, return an error.
	if cdl.partialRecord != "" {
		return nil, fmt.Errorf("partial record (%s) found at end of data load", cdl.partialRecord)
	}

	err := cdl.rowInserter.StatementComplete(ctx)
	if err != nil {
		err = cdl.rowInserter.DiscardChanges(ctx, err)
		return nil, err
	}

	return &cdl.results, nil
}
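LoadChunk above buffers any trailing partial record in cdl.partialRecord and prepends it to the next chunk through newStringPrefixReader, a helper defined in another file of this PR. As a rough sketch of the idea only (the prefixReader name below is hypothetical, not the PR's implementation), the same effect can be achieved with the standard library:

package dataloader

import (
	"io"
	"strings"
)

// prefixReader is a conceptual stand-in (for illustration only) for
// newStringPrefixReader: it prepends the partial CSV record buffered from the
// previous chunk to the reader wrapping the next chunk of COPY data, so the
// CSV reader sees one continuous stream.
func prefixReader(prefix string, chunk io.Reader) io.Reader {
	if prefix == "" {
		return chunk
	}
	return io.MultiReader(strings.NewReader(prefix), chunk)
}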
218 changes: 218 additions & 0 deletions core/dataloader/csvdataloader_test.go
@@ -0,0 +1,218 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package dataloader

import (
	"bufio"
	"bytes"
	"context"
	"io"
	"testing"

	"github.com/dolthub/go-mysql-server/memory"
	"github.com/dolthub/go-mysql-server/sql"
	"github.com/stretchr/testify/require"

	"github.com/dolthub/doltgresql/server/types"
)

// TestCsvDataLoader tests the CsvDataLoader implementation.
func TestCsvDataLoader(t *testing.T) {
	db := memory.NewDatabase("mydb")
	provider := memory.NewDBProvider(db)

	ctx := &sql.Context{
		Context: context.Background(),
		Session: memory.NewSession(sql.NewBaseSession(), provider),
	}

	pkSchema := sql.NewPrimaryKeySchema(sql.Schema{
		{Name: "pk", Type: types.Int64, Source: "source1"},
		{Name: "c1", Type: types.Int64, Source: "source1"},
		{Name: "c2", Type: types.VarChar, Source: "source1"},
	}, 0)

	// Tests that a basic CSV document can be loaded as a single chunk.
	t.Run("basic case", func(t *testing.T) {
		table := memory.NewTable(db, "myTable", pkSchema, nil)
		dataLoader, err := NewCsvDataLoader(ctx, table, false)
		require.NoError(t, err)

		// Load all the data as a single chunk
		reader := bytes.NewReader([]byte("1,100,bar\n2,200,bash\n"))
		err = dataLoader.LoadChunk(ctx, bufio.NewReader(reader))
		require.NoError(t, err)
		results, err := dataLoader.Finish(ctx)
		require.NoError(t, err)
		require.EqualValues(t, 2, results.RowsLoaded)

		// Assert that the table contains the expected data
		assertRows(t, ctx, table, [][]any{
			{int64(1), int64(100), "bar"},
			{int64(2), int64(200), "bash"},
		})
	})

	// Tests when a CSV record is split across two chunks of data, and the
	// partial record must be buffered and prepended to the next chunk.
	t.Run("record split across two chunks", func(t *testing.T) {
		table := memory.NewTable(db, "myTable", pkSchema, nil)
		dataLoader, err := NewCsvDataLoader(ctx, table, false)
		require.NoError(t, err)

		// Load the first chunk
		reader := bytes.NewReader([]byte("1,100,ba"))
		err = dataLoader.LoadChunk(ctx, bufio.NewReader(reader))
		require.NoError(t, err)

		// Load the second chunk
		reader = bytes.NewReader([]byte("r\n2,200,bash\n"))
		err = dataLoader.LoadChunk(ctx, bufio.NewReader(reader))
		require.NoError(t, err)

		// Finish
		results, err := dataLoader.Finish(ctx)
		require.NoError(t, err)
		require.EqualValues(t, 2, results.RowsLoaded)

		// Assert that the table contains the expected data
		assertRows(t, ctx, table, [][]any{
			{int64(1), int64(100), "bar"},
			{int64(2), int64(200), "bash"},
		})
	})

	// Tests when a CSV record is split across two chunks of data, and a
	// header row is present.
	t.Run("record split across two chunks, with header", func(t *testing.T) {
		table := memory.NewTable(db, "myTable", pkSchema, nil)
		dataLoader, err := NewCsvDataLoader(ctx, table, true)
		require.NoError(t, err)

		// Load the first chunk
		reader := bytes.NewReader([]byte("pk,c1,c2\n1,100,ba"))
		err = dataLoader.LoadChunk(ctx, bufio.NewReader(reader))
		require.NoError(t, err)

		// Load the second chunk
		reader = bytes.NewReader([]byte("r\n2,200,bash\n"))
		err = dataLoader.LoadChunk(ctx, bufio.NewReader(reader))
		require.NoError(t, err)

		// Finish
		results, err := dataLoader.Finish(ctx)
		require.NoError(t, err)
		require.EqualValues(t, 2, results.RowsLoaded)

		// Assert that the table contains the expected data
		assertRows(t, ctx, table, [][]any{
			{int64(1), int64(100), "bar"},
			{int64(2), int64(200), "bash"},
		})
	})

	// Tests a CSV record that contains a quoted newline character and is split
	// across two chunks.
	t.Run("quoted newlines across two chunks", func(t *testing.T) {
		table := memory.NewTable(db, "myTable", pkSchema, nil)
		dataLoader, err := NewCsvDataLoader(ctx, table, false)
		require.NoError(t, err)

		// Load the first chunk
		reader := bytes.NewReader([]byte("1,100,\"baz\nbar\n"))
		err = dataLoader.LoadChunk(ctx, bufio.NewReader(reader))
		require.NoError(t, err)

		// Load the second chunk
		reader = bytes.NewReader([]byte("bash\"\n2,200,bash\n"))
		err = dataLoader.LoadChunk(ctx, bufio.NewReader(reader))
		require.NoError(t, err)

		// Finish
		results, err := dataLoader.Finish(ctx)
		require.NoError(t, err)
		require.EqualValues(t, 2, results.RowsLoaded)

		// Assert that the table contains the expected data
		assertRows(t, ctx, table, [][]any{
			{int64(1), int64(100), "baz\nbar\nbash"},
			{int64(2), int64(200), "bash"},
		})
	})

	// Test that calling Abort() does not insert any data into the table.
	t.Run("abort cancels data load", func(t *testing.T) {
		table := memory.NewTable(db, "myTable", pkSchema, nil)
		dataLoader, err := NewCsvDataLoader(ctx, table, false)
		require.NoError(t, err)

		// Load the first chunk
		reader := bytes.NewReader([]byte("1,100,bazbar\n"))
		err = dataLoader.LoadChunk(ctx, bufio.NewReader(reader))
		require.NoError(t, err)

		// Load the second chunk
		reader = bytes.NewReader([]byte("2,200,bash\n"))
		err = dataLoader.LoadChunk(ctx, bufio.NewReader(reader))
		require.NoError(t, err)

		// Abort
		err = dataLoader.Abort(ctx)
		require.NoError(t, err)

		// Assert that the table does not contain any of the data from the CSV load
		assertRows(t, ctx, table, [][]any{})
	})
}

// assertRows asserts that the rows in |table| match |expectedRows| and fails the test if the
// rows do not exactly match.
func assertRows(t *testing.T, ctx *sql.Context, table *memory.Table, expectedRows [][]any) {
	partitions, err := table.Partitions(ctx)
	require.NoError(t, err)

	expectedRowsIdx := 0

	for {
		partition, err := partitions.Next(ctx)
		if err == io.EOF {
			break
		}
		require.NoError(t, err)
		rows := table.GetPartition(string(partition.Key()))
		for _, row := range rows {
			if len(expectedRows) <= expectedRowsIdx {
				t.Fatalf("Expected %d rows, got more", len(expectedRows))
			}

			if len(expectedRows[expectedRowsIdx]) != len(row) {
				t.Fatalf("Expected row length %d, got %d. expectedRows: %v, rows: %v",
					len(expectedRows[expectedRowsIdx]), len(row), expectedRows, rows)
			}
			for i := range len(row) {
				if expectedRows[expectedRowsIdx][i] != row[i] {
					t.Fatalf("Expected row %v, got %v. expectedRows: %v, rows: %v",
						expectedRows[expectedRowsIdx], row, expectedRows, rows)
				}
			}

			expectedRowsIdx += 1
		}
	}

	if len(expectedRows) != expectedRowsIdx {
		t.Fatalf("Expected %d rows, got %d", len(expectedRows), expectedRowsIdx)
	}
}