Skip to content

Commit

Permalink
Merge pull request #17 from planetscale/split-into-batches
Browse files Browse the repository at this point in the history
Automatically split up each batch into an acceptable request to the Stitch API
  • Loading branch information
Phani Raj authored Aug 10, 2022
2 parents af3bd8a + fa8a055 commit 3836343
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 34 deletions.
106 changes: 72 additions & 34 deletions cmd/internal/batch_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import (
"time"
)

const MaxObjectsInBatch int = 19000
const MaxBatchRequestSize int = 20 * 1024 * 1024

type BatchWriter interface {
Flush(stream *Stream) error
Send(record *Record, stream *Stream) error
Expand Down Expand Up @@ -48,50 +51,45 @@ func (h *httpBatchWriter) Flush(stream *Stream) error {
return nil
}

h.logger.Info(fmt.Sprintf("flushing [%v] messages for stream %q", len(h.messages), stream.Name))
batches := getBatchMessages(h.messages, stream, MaxObjectsInBatch, MaxBatchRequestSize)
h.logger.Info(fmt.Sprintf("flushing [%v] messages for stream %q in [%v] batches", len(h.messages), stream.Name, len(batches)))
for _, batch := range batches {

batch := ImportBatch{
Table: stream.Name,
Schema: stream.Schema,
Messages: h.messages,
PrimaryKeys: stream.KeyProperties,
}
b, err := json.Marshal(batch)
if err != nil {
return err
}

b, err := json.Marshal(batch)
if err != nil {
return err
}
stitch, err := retryablehttp.NewRequest("POST", h.apiURL+"/v2/import/batch", bytes.NewBuffer(b))
if err != nil {
return err
}
stitch.Header.Set("Content-Type", "application/json")
stitch.Header.Set("Authorization", "Bearer "+h.apiToken)

stitch, err := retryablehttp.NewRequest("POST", h.apiURL+"/v2/import/batch", bytes.NewBuffer(b))
if err != nil {
return err
}
stitch.Header.Set("Content-Type", "application/json")
stitch.Header.Set("Authorization", "Bearer "+h.apiToken)
stitchResponse, err := h.client.Do(stitch)
if err != nil {
return err
}

stitchResponse, err := h.client.Do(stitch)
if err != nil {
return err
}
defer stitchResponse.Body.Close()

defer stitchResponse.Body.Close()
if stitchResponse.StatusCode > 203 {
body, err := ioutil.ReadAll(stitchResponse.Body)
if err != nil {
return err
}
return fmt.Errorf("server request failed with %s", body)
}

if stitchResponse.StatusCode > 203 {
body, err := ioutil.ReadAll(stitchResponse.Body)
if err != nil {
var resp BatchResponse
decoder := json.NewDecoder(stitchResponse.Body)
if err := decoder.Decode(&resp); err != nil {
return err
}
return fmt.Errorf("server request failed with %s", body)
}

var resp BatchResponse
decoder := json.NewDecoder(stitchResponse.Body)
if err := decoder.Decode(&resp); err != nil {
return err
h.logger.Info(fmt.Sprintf("Server response status : %q, message : %q", resp.Status, resp.Message))
}

h.logger.Info(fmt.Sprintf("Server response status : %q, message : %q", resp.Status, resp.Message))

h.messages = h.messages[:0]

return nil
Expand All @@ -106,6 +104,46 @@ func (h *httpBatchWriter) Send(record *Record, stream *Stream) error {
return nil
}

// getBatchMessages accepts a list of import messages
// and returns a slice of ImportBatch that can be safely uploaded.
// The rules are:
// 1. There cannot be more than 20,000 records in the request.
// 2. The size of the serialized JSON cannot be more than 20 MB.
func getBatchMessages(messages []ImportMessage, stream *Stream, maxObjectsInBatch int, maxBatchSerializedSize int) []ImportBatch {
var batches []ImportBatch
allocated := 0
unallocated := len(messages)

for unallocated > 0 {
batch := ImportBatch{
Table: stream.Name,
Schema: stream.Schema,
Messages: messages[allocated:],
PrimaryKeys: stream.KeyProperties,
}

// reduce the size of the batch until it is an acceptable size.
for batch.SizeOf() > maxBatchSerializedSize || len(batch.Messages) > maxObjectsInBatch {
// keep halving the number of messages until the batch is an acceptable size.
batch.Messages = batch.Messages[0:(len(batch.Messages) / 2)]
}

allocated += len(batch.Messages)
unallocated -= len(batch.Messages)
batches = append(batches, batch)
}

return batches
}

func (imb *ImportBatch) SizeOf() int {
b, err := json.Marshal(imb)
if err != nil {
return 0
}
return len(b)
}

func createImportMessage(record *Record) ImportMessage {
now := time.Now()
return ImportMessage{
Expand Down
31 changes: 31 additions & 0 deletions cmd/internal/batch_writer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package internal

import (
"github.com/stretchr/testify/assert"
"testing"
)

func TestCanSplitIntoBatches(t *testing.T) {
var messages []ImportMessage

n := 1
for n <= 20 {
messages = append(messages, ImportMessage{
Action: "upsert",
})

n++
}

stream := &Stream{}

batches := getBatchMessages(messages, stream, 1, 100*1024*1024)
assert.Equal(t, len(messages), len(batches))
totalMessages := 0
for _, batch := range batches {
assert.Equal(t, 1, len(batch.Messages))
totalMessages += len(batch.Messages)
}

assert.Equal(t, len(messages), totalMessages)
}

0 comments on commit 3836343

Please sign in to comment.