Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatically split up each batch into an acceptable request to the Stitch API #17

Merged
merged 1 commit into from
Aug 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 72 additions & 34 deletions cmd/internal/batch_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import (
"time"
)

// Limits for a single Stitch batch import request. NOTE(review): Stitch's
// documented caps appear to be 20,000 records and 20 MB of serialized JSON
// per request; MaxObjectsInBatch presumably sits below the record cap to
// leave headroom — confirm against the Stitch Import API docs.
const MaxObjectsInBatch int = 19000
const MaxBatchRequestSize int = 20 * 1024 * 1024

type BatchWriter interface {
Flush(stream *Stream) error
Send(record *Record, stream *Stream) error
Expand Down Expand Up @@ -48,50 +51,45 @@ func (h *httpBatchWriter) Flush(stream *Stream) error {
return nil
}

h.logger.Info(fmt.Sprintf("flushing [%v] messages for stream %q", len(h.messages), stream.Name))
batches := getBatchMessages(h.messages, stream, MaxObjectsInBatch, MaxBatchRequestSize)
h.logger.Info(fmt.Sprintf("flushing [%v] messages for stream %q in [%v] batches", len(h.messages), stream.Name, len(batches)))
for _, batch := range batches {

batch := ImportBatch{
Table: stream.Name,
Schema: stream.Schema,
Messages: h.messages,
PrimaryKeys: stream.KeyProperties,
}
b, err := json.Marshal(batch)
if err != nil {
return err
}

b, err := json.Marshal(batch)
if err != nil {
return err
}
stitch, err := retryablehttp.NewRequest("POST", h.apiURL+"/v2/import/batch", bytes.NewBuffer(b))
if err != nil {
return err
}
stitch.Header.Set("Content-Type", "application/json")
stitch.Header.Set("Authorization", "Bearer "+h.apiToken)

stitch, err := retryablehttp.NewRequest("POST", h.apiURL+"/v2/import/batch", bytes.NewBuffer(b))
if err != nil {
return err
}
stitch.Header.Set("Content-Type", "application/json")
stitch.Header.Set("Authorization", "Bearer "+h.apiToken)
stitchResponse, err := h.client.Do(stitch)
if err != nil {
return err
}

stitchResponse, err := h.client.Do(stitch)
if err != nil {
return err
}
defer stitchResponse.Body.Close()

defer stitchResponse.Body.Close()
if stitchResponse.StatusCode > 203 {
body, err := ioutil.ReadAll(stitchResponse.Body)
if err != nil {
return err
}
return fmt.Errorf("server request failed with %s", body)
}

if stitchResponse.StatusCode > 203 {
body, err := ioutil.ReadAll(stitchResponse.Body)
if err != nil {
var resp BatchResponse
decoder := json.NewDecoder(stitchResponse.Body)
if err := decoder.Decode(&resp); err != nil {
return err
}
return fmt.Errorf("server request failed with %s", body)
}

var resp BatchResponse
decoder := json.NewDecoder(stitchResponse.Body)
if err := decoder.Decode(&resp); err != nil {
return err
h.logger.Info(fmt.Sprintf("Server response status : %q, message : %q", resp.Status, resp.Message))
}

h.logger.Info(fmt.Sprintf("Server response status : %q, message : %q", resp.Status, resp.Message))

h.messages = h.messages[:0]

return nil
Expand All @@ -106,6 +104,46 @@ func (h *httpBatchWriter) Send(record *Record, stream *Stream) error {
return nil
}

// getBatchMessages splits messages into a slice of ImportBatch values that
// can each be safely uploaded. The rules enforced per batch are:
//  1. No more than maxObjectsInBatch records.
//  2. The serialized JSON is no larger than maxBatchSerializedSize bytes.
//
// A batch always contains at least one message so the split is guaranteed
// to make progress; a single message whose serialization alone exceeds
// maxBatchSerializedSize is emitted as its own (oversized) batch rather
// than looping forever.
func getBatchMessages(messages []ImportMessage, stream *Stream, maxObjectsInBatch int, maxBatchSerializedSize int) []ImportBatch {
	var batches []ImportBatch
	allocated := 0
	unallocated := len(messages)

	for unallocated > 0 {
		batch := ImportBatch{
			Table:       stream.Name,
			Schema:      stream.Schema,
			Messages:    messages[allocated:],
			PrimaryKeys: stream.KeyProperties,
		}

		// Keep halving the number of messages until the batch satisfies
		// both limits. The len > 1 guard fixes an infinite loop in the
		// original: halving could reach zero messages (e.g. one message
		// bigger than the size limit), after which allocated/unallocated
		// never advanced and empty batches were appended forever.
		for len(batch.Messages) > 1 &&
			(batch.SizeOf() > maxBatchSerializedSize || len(batch.Messages) > maxObjectsInBatch) {
			batch.Messages = batch.Messages[:len(batch.Messages)/2]
		}

		allocated += len(batch.Messages)
		unallocated -= len(batch.Messages)
		batches = append(batches, batch)
	}

	return batches
}

// SizeOf reports the size in bytes of the batch's JSON serialization.
// If marshalling fails, the size is reported as 0.
func (imb *ImportBatch) SizeOf() int {
	if serialized, err := json.Marshal(imb); err == nil {
		return len(serialized)
	}
	return 0
}

func createImportMessage(record *Record) ImportMessage {
now := time.Now()
return ImportMessage{
Expand Down
31 changes: 31 additions & 0 deletions cmd/internal/batch_writer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package internal

import (
"github.com/stretchr/testify/assert"
"testing"
)

// TestCanSplitIntoBatches verifies that getBatchMessages, given a
// one-record-per-batch limit, produces one batch per message and that no
// messages are lost in the split.
func TestCanSplitIntoBatches(t *testing.T) {
	var messages []ImportMessage

	// Idiomatic counted loop replaces the original manual n := 1 / n++ form.
	for i := 0; i < 20; i++ {
		messages = append(messages, ImportMessage{
			Action: "upsert",
		})
	}

	stream := &Stream{}

	batches := getBatchMessages(messages, stream, 1, 100*1024*1024)
	assert.Equal(t, len(messages), len(batches))
	totalMessages := 0
	for _, batch := range batches {
		assert.Equal(t, 1, len(batch.Messages))
		totalMessages += len(batch.Messages)
	}

	assert.Equal(t, len(messages), totalMessages)
}