diff --git a/cmd/internal/batch_writer.go b/cmd/internal/batch_writer.go
index ad8bcee..f35d9ca 100644
--- a/cmd/internal/batch_writer.go
+++ b/cmd/internal/batch_writer.go
@@ -9,6 +9,9 @@ import (
 	"time"
 )
 
+const MaxObjectsInBatch int = 19000
+const MaxBatchRequestSize int = 20 * 1024 * 1024
+
 type BatchWriter interface {
 	Flush(stream *Stream) error
 	Send(record *Record, stream *Stream) error
@@ -48,50 +51,45 @@ func (h *httpBatchWriter) Flush(stream *Stream) error {
 		return nil
 	}
 
-	h.logger.Info(fmt.Sprintf("flushing [%v] messages for stream %q", len(h.messages), stream.Name))
+	batches := getBatchMessages(h.messages, stream, MaxObjectsInBatch, MaxBatchRequestSize)
+	h.logger.Info(fmt.Sprintf("flushing [%v] messages for stream %q in [%v] batches", len(h.messages), stream.Name, len(batches)))
+	for _, batch := range batches {
-	batch := ImportBatch{
-		Table:       stream.Name,
-		Schema:      stream.Schema,
-		Messages:    h.messages,
-		PrimaryKeys: stream.KeyProperties,
-	}
+		b, err := json.Marshal(batch)
+		if err != nil {
+			return err
+		}
 
-	b, err := json.Marshal(batch)
-	if err != nil {
-		return err
-	}
+		stitch, err := retryablehttp.NewRequest("POST", h.apiURL+"/v2/import/batch", bytes.NewBuffer(b))
+		if err != nil {
+			return err
+		}
+		stitch.Header.Set("Content-Type", "application/json")
+		stitch.Header.Set("Authorization", "Bearer "+h.apiToken)
 
-	stitch, err := retryablehttp.NewRequest("POST", h.apiURL+"/v2/import/batch", bytes.NewBuffer(b))
-	if err != nil {
-		return err
-	}
-	stitch.Header.Set("Content-Type", "application/json")
-	stitch.Header.Set("Authorization", "Bearer "+h.apiToken)
+		stitchResponse, err := h.client.Do(stitch)
+		if err != nil {
+			return err
+		}
 
-	stitchResponse, err := h.client.Do(stitch)
-	if err != nil {
-		return err
-	}
+		defer stitchResponse.Body.Close()
 
-	defer stitchResponse.Body.Close()
+		if stitchResponse.StatusCode > 203 {
+			body, err := ioutil.ReadAll(stitchResponse.Body)
+			if err != nil {
+				return err
+			}
+			return fmt.Errorf("server request failed with %s", body)
+		}
 
-	if stitchResponse.StatusCode > 203 {
-		body, err := ioutil.ReadAll(stitchResponse.Body)
-		if err != nil {
+		var resp BatchResponse
+		decoder := json.NewDecoder(stitchResponse.Body)
+		if err := decoder.Decode(&resp); err != nil {
 			return err
 		}
-		return fmt.Errorf("server request failed with %s", body)
-	}
 
-	var resp BatchResponse
-	decoder := json.NewDecoder(stitchResponse.Body)
-	if err := decoder.Decode(&resp); err != nil {
-		return err
+		h.logger.Info(fmt.Sprintf("Server response status : %q, message : %q", resp.Status, resp.Message))
 	}
-
-	h.logger.Info(fmt.Sprintf("Server response status : %q, message : %q", resp.Status, resp.Message))
-
 	h.messages = h.messages[:0]
 
 	return nil
@@ -106,6 +104,46 @@ func (h *httpBatchWriter) Send(record *Record, stream *Stream) error {
 	return nil
 }
 
+// getBatchMessages accepts a list of import messages
+// and returns a slice of ImportBatch that can be safely uploaded.
+// The rules are:
+// 1. There cannot be more than 20,000 records in the request.
+// 2. The size of the serialized JSON cannot be more than 20 MB.
+func getBatchMessages(messages []ImportMessage, stream *Stream, maxObjectsInBatch int, maxBatchSerializedSize int) []ImportBatch {
+	var batches []ImportBatch
+	allocated := 0
+	unallocated := len(messages)
+
+	for unallocated > 0 {
+		batch := ImportBatch{
+			Table:       stream.Name,
+			Schema:      stream.Schema,
+			Messages:    messages[allocated:],
+			PrimaryKeys: stream.KeyProperties,
+		}
+
+		// reduce the size of the batch until it is an acceptable size.
+		for batch.SizeOf() > maxBatchSerializedSize || len(batch.Messages) > maxObjectsInBatch {
+			// keep halving the number of messages until the batch is an acceptable size.
+			batch.Messages = batch.Messages[0:(len(batch.Messages) / 2)]
+		}
+
+		allocated += len(batch.Messages)
+		unallocated -= len(batch.Messages)
+		batches = append(batches, batch)
+	}
+
+	return batches
+}
+
+func (imb *ImportBatch) SizeOf() int {
+	b, err := json.Marshal(imb)
+	if err != nil {
+		return 0
+	}
+	return len(b)
+}
+
 func createImportMessage(record *Record) ImportMessage {
 	now := time.Now()
 	return ImportMessage{
diff --git a/cmd/internal/batch_writer_test.go b/cmd/internal/batch_writer_test.go
new file mode 100644
index 0000000..c89aa67
--- /dev/null
+++ b/cmd/internal/batch_writer_test.go
@@ -0,0 +1,31 @@
+package internal
+
+import (
+	"github.com/stretchr/testify/assert"
+	"testing"
+)
+
+func TestCanSplitIntoBatches(t *testing.T) {
+	var messages []ImportMessage
+
+	n := 1
+	for n <= 20 {
+		messages = append(messages, ImportMessage{
+			Action: "upsert",
+		})
+
+		n++
+	}
+
+	stream := &Stream{}
+
+	batches := getBatchMessages(messages, stream, 1, 100*1024*1024)
+	assert.Equal(t, len(messages), len(batches))
+	totalMessages := 0
+	for _, batch := range batches {
+		assert.Equal(t, 1, len(batch.Messages))
+		totalMessages += len(batch.Messages)
+	}
+
+	assert.Equal(t, len(messages), totalMessages)
+}
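
For context on how the splitting in getBatchMessages behaves end to end, here is a minimal, self-contained sketch of the same halving approach. The payload type, sizeOf helper, and splitIntoBatches function are illustrative stand-ins and are not part of this codebase; the sketch also keeps a floor of one message per batch, a safeguard the patch itself does not add.

package main

import (
	"encoding/json"
	"fmt"
)

// payload is a hypothetical stand-in for an import message.
type payload struct {
	Action string `json:"action"`
}

// sizeOf reports the serialized JSON size of a slice of payloads,
// analogous to what ImportBatch.SizeOf does for a whole batch.
func sizeOf(items []payload) int {
	b, err := json.Marshal(items)
	if err != nil {
		return 0
	}
	return len(b)
}

// splitIntoBatches takes the remaining tail of items and halves it until the
// candidate batch satisfies both the object-count and byte-size caps,
// then moves on to the next tail.
func splitIntoBatches(items []payload, maxObjects, maxBytes int) [][]payload {
	var batches [][]payload
	allocated := 0

	for allocated < len(items) {
		candidate := items[allocated:]
		// The len > 1 floor keeps at least one message per batch;
		// the patch's loop halves without this guard.
		for len(candidate) > 1 && (sizeOf(candidate) > maxBytes || len(candidate) > maxObjects) {
			candidate = candidate[:len(candidate)/2]
		}
		batches = append(batches, candidate)
		allocated += len(candidate)
	}
	return batches
}

func main() {
	items := make([]payload, 20)
	for i := range items {
		items[i] = payload{Action: "upsert"}
	}

	// Cap each batch at 1 object with a generous byte limit,
	// mirroring the test above: expect 20 batches of one message each.
	batches := splitIntoBatches(items, 1, 100*1024*1024)
	fmt.Println("batches:", len(batches))
}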