Skip to content

Commit

Permalink
Refactor http calls to support message payloads (#646)
Browse files Browse the repository at this point in the history
To enable client retries in the future, the http calls now operates directly
on message payloads instead of relying solely on `io.Writer`s. This
refactor replaces write operations with send operations at message
boundaries. It introduces a new interface, `messagePayload`, implemented
by `*bytes.Reader`s and `*envelopes`, enabling them to adopt sender
operations.
  • Loading branch information
emcfarlane authored Dec 6, 2023
1 parent 90ed22b commit cb84690
Show file tree
Hide file tree
Showing 8 changed files with 212 additions and 50 deletions.
54 changes: 50 additions & 4 deletions duplex_http_call.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,21 +94,21 @@ func newDuplexHTTPCall(
}
}

// Write to the request body.
func (d *duplexHTTPCall) Write(data []byte) (int, error) {
// Send sends a message to the server.
func (d *duplexHTTPCall) Send(payload messsagePayload) (int64, error) {
isFirst := d.ensureRequestMade()
// Before we send any data, check if the context has been canceled.
if err := d.ctx.Err(); err != nil {
return 0, wrapIfContextError(err)
}
if isFirst && data == nil {
if isFirst && payload.Len() == 0 {
// On first write a nil Send is used to send request headers. Avoid
// writing a zero-length payload to avoid superfluous errors with close.
return 0, nil
}
// It's safe to write to this side of the pipe while net/http concurrently
// reads from the other side.
bytesWritten, err := d.requestBodyWriter.Write(data)
bytesWritten, err := payload.WriteTo(d.requestBodyWriter)
if err != nil && errors.Is(err, io.ErrClosedPipe) {
// Signal that the stream is closed with the more-typical io.EOF instead of
// io.ErrClosedPipe. This makes it easier for protocol-specific wrappers to
Expand Down Expand Up @@ -295,6 +295,52 @@ func (d *duplexHTTPCall) makeRequest() {
}
}

// messsagePayload is a sized and seekable message payload. The interface is
// implemented by [*bytes.Reader] and *envelope.
type messsagePayload interface {
io.Reader
io.WriterTo
io.Seeker
Len() int
}

// nopPayload is a message payload that does nothing. It's used to send headers
// to the server.
type nopPayload struct{}

var _ messsagePayload = nopPayload{}

func (nopPayload) Read([]byte) (int, error) {
return 0, io.EOF
}
func (nopPayload) WriteTo(io.Writer) (int64, error) {
return 0, nil
}
func (nopPayload) Seek(int64, int) (int64, error) {
return 0, nil
}
func (nopPayload) Len() int {
return 0
}

// messageSender sends a message payload. The interface is implemented by
// [*duplexHTTPCall] and writeSender.
type messageSender interface {
Send(messsagePayload) (int64, error)
}

// writeSender is a sender that writes to an [io.Writer]. Useful for wrapping
// [http.ResponseWriter].
type writeSender struct {
writer io.Writer
}

var _ messageSender = writeSender{}

func (w writeSender) Send(payload messsagePayload) (int64, error) {
return payload.WriteTo(w.writer)
}

// See: https://cs.opensource.google/go/go/+/refs/tags/go1.20.1:src/net/http/clone.go;l=22-33
func cloneURL(oldURL *url.URL) *url.URL {
if oldURL == nil {
Expand Down
93 changes: 82 additions & 11 deletions envelope.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,84 @@ var errSpecialEnvelope = errorf(
// message length. gRPC and Connect interpret the bitwise flags differently, so
// envelope leaves their interpretation up to the caller.
type envelope struct {
Data *bytes.Buffer
Flags uint8
Data *bytes.Buffer
Flags uint8
offset int64
}

var _ messsagePayload = (*envelope)(nil)

func (e *envelope) IsSet(flag uint8) bool {
return e.Flags&flag == flag
}

// Read implements [io.Reader].
func (e *envelope) Read(data []byte) (readN int, err error) {
if e.offset < 5 {
prefix := makeEnvelopePrefix(e.Flags, e.Data.Len())
readN = copy(data, prefix[e.offset:])
e.offset += int64(readN)
if e.offset < 5 {
return readN, nil
}
data = data[readN:]
}
n := copy(data, e.Data.Bytes()[e.offset-5:])
e.offset += int64(n)
readN += n
if readN == 0 && e.offset == int64(e.Data.Len()+5) {
err = io.EOF
}
return readN, err
}

// WriteTo implements [io.WriterTo].
func (e *envelope) WriteTo(dst io.Writer) (wroteN int64, err error) {
if e.offset < 5 {
prefix := makeEnvelopePrefix(e.Flags, e.Data.Len())
prefixN, err := dst.Write(prefix[e.offset:])
e.offset += int64(prefixN)
wroteN += int64(prefixN)
if e.offset < 5 {
return wroteN, err
}
}
n, err := dst.Write(e.Data.Bytes()[e.offset-5:])
e.offset += int64(n)
wroteN += int64(n)
return wroteN, err
}

// Seek implements [io.Seeker]. Based on the implementation of [bytes.Reader].
func (e *envelope) Seek(offset int64, whence int) (int64, error) {
var abs int64
switch whence {
case io.SeekStart:
abs = offset
case io.SeekCurrent:
abs = e.offset + offset
case io.SeekEnd:
abs = int64(e.Data.Len()) + offset
default:
return 0, errors.New("connect.envelope.Seek: invalid whence")
}
if abs < 0 {
return 0, errors.New("connect.envelope.Seek: negative position")
}
e.offset = abs
return abs, nil
}

// Len returns the number of bytes of the unread portion of the envelope.
func (e *envelope) Len() int {
if length := int(int64(e.Data.Len()) + 5 - e.offset); length > 0 {
return length
}
return 0
}

type envelopeWriter struct {
writer io.Writer
sender messageSender
codec Codec
compressMinBytes int
compressionPool *compressionPool
Expand All @@ -59,7 +127,9 @@ type envelopeWriter struct {

func (w *envelopeWriter) Marshal(message any) *Error {
if message == nil {
if _, err := w.writer.Write(nil); err != nil {
// Send no-op message to create the request and send headers.
payload := nopPayload{}
if _, err := w.sender.Send(payload); err != nil {
if connectErr, ok := asError(err); ok {
return connectErr
}
Expand Down Expand Up @@ -137,18 +207,12 @@ func (w *envelopeWriter) marshal(message any) *Error {
}

func (w *envelopeWriter) write(env *envelope) *Error {
prefix := [5]byte{}
prefix[0] = env.Flags
binary.BigEndian.PutUint32(prefix[1:5], uint32(env.Data.Len()))
if _, err := w.writer.Write(prefix[:]); err != nil {
if _, err := w.sender.Send(env); err != nil {
if connectErr, ok := asError(err); ok {
return connectErr
}
return errorf(CodeUnknown, "write envelope: %w", err)
}
if _, err := io.Copy(w.writer, env.Data); err != nil {
return errorf(CodeUnknown, "write message: %w", err)
}
return nil
}

Expand Down Expand Up @@ -279,3 +343,10 @@ func (r *envelopeReader) Read(env *envelope) *Error {
env.Flags = prefixes[0]
return nil
}

func makeEnvelopePrefix(flags uint8, size int) [5]byte {
prefix := [5]byte{}
prefix[0] = flags
binary.BigEndian.PutUint32(prefix[1:5], uint32(size))
return prefix
}
90 changes: 67 additions & 23 deletions envelope_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,43 +16,87 @@ package connect

import (
"bytes"
"encoding/binary"
"io"
"testing"

"connectrpc.com/connect/internal/assert"
)

func TestEnvelope_read(t *testing.T) {
func TestEnvelope(t *testing.T) {
t.Parallel()

head := [5]byte{}
payload := []byte(`{"number": 42}`)
binary.BigEndian.PutUint32(head[1:], uint32(len(payload)))

head := makeEnvelopePrefix(0, len(payload))
buf := &bytes.Buffer{}
buf.Write(head[:])
buf.Write(payload)

t.Run("full", func(t *testing.T) {
t.Run("read", func(t *testing.T) {
t.Parallel()
env := &envelope{Data: &bytes.Buffer{}}
rdr := envelopeReader{
reader: bytes.NewReader(buf.Bytes()),
}
assert.Nil(t, rdr.Read(env))
assert.Equal(t, payload, env.Data.Bytes())
t.Run("full", func(t *testing.T) {
t.Parallel()
env := &envelope{Data: &bytes.Buffer{}}
rdr := envelopeReader{
reader: bytes.NewReader(buf.Bytes()),
}
assert.Nil(t, rdr.Read(env))
assert.Equal(t, payload, env.Data.Bytes())
})
t.Run("byteByByte", func(t *testing.T) {
t.Parallel()
env := &envelope{Data: &bytes.Buffer{}}
rdr := envelopeReader{
reader: byteByByteReader{
reader: bytes.NewReader(buf.Bytes()),
},
}
assert.Nil(t, rdr.Read(env))
assert.Equal(t, payload, env.Data.Bytes())
})
})
t.Run("byteByByte", func(t *testing.T) {
t.Run("write", func(t *testing.T) {
t.Parallel()
env := &envelope{Data: &bytes.Buffer{}}
rdr := envelopeReader{
reader: byteByByteReader{
reader: bytes.NewReader(buf.Bytes()),
},
}
assert.Nil(t, rdr.Read(env))
assert.Equal(t, payload, env.Data.Bytes())
t.Run("full", func(t *testing.T) {
t.Parallel()
dst := &bytes.Buffer{}
wtr := envelopeWriter{
sender: writeSender{writer: dst},
}
env := &envelope{Data: bytes.NewBuffer(payload)}
err := wtr.Write(env)
assert.Nil(t, err)
assert.Equal(t, buf.Bytes(), dst.Bytes())
})
t.Run("partial", func(t *testing.T) {
t.Parallel()
dst := &bytes.Buffer{}
env := &envelope{Data: bytes.NewBuffer(payload)}
_, err := io.CopyN(dst, env, 2)
assert.Nil(t, err)
_, err = env.WriteTo(dst)
assert.Nil(t, err)
assert.Equal(t, buf.Bytes(), dst.Bytes())
})
})
t.Run("seek", func(t *testing.T) {
t.Parallel()
t.Run("start", func(t *testing.T) {
t.Parallel()
dst1 := &bytes.Buffer{}
dst2 := &bytes.Buffer{}
env := &envelope{Data: bytes.NewBuffer(payload)}
_, err := io.CopyN(dst1, env, 2)
assert.Nil(t, err)
assert.Equal(t, env.Len(), len(payload)+3)
_, err = env.Seek(0, io.SeekStart)
assert.Nil(t, err)
assert.Equal(t, env.Len(), len(payload)+5)
_, err = io.CopyN(dst2, env, 2)
assert.Nil(t, err)
assert.Equal(t, dst1.Bytes(), dst2.Bytes())
_, err = env.WriteTo(dst2)
assert.Nil(t, err)
assert.Equal(t, dst2.Bytes(), buf.Bytes())
assert.Equal(t, env.Len(), 0)
})
})
}

Expand Down
2 changes: 1 addition & 1 deletion error_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func (w *ErrorWriter) writeConnectStreaming(response http.ResponseWriter, err er
response.WriteHeader(http.StatusOK)
marshaler := &connectStreamingMarshaler{
envelopeWriter: envelopeWriter{
writer: response,
sender: writeSender{writer: response},
bufferPool: w.bufferPool,
},
}
Expand Down
13 changes: 7 additions & 6 deletions protocol_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ func (h *connectHandler) NewConn(
request: request,
responseWriter: responseWriter,
marshaler: connectUnaryMarshaler{
writer: responseWriter,
sender: writeSender{writer: responseWriter},
codec: codec,
compressMinBytes: h.CompressMinBytes,
compressionName: responseCompression,
Expand All @@ -280,7 +280,7 @@ func (h *connectHandler) NewConn(
responseWriter: responseWriter,
marshaler: connectStreamingMarshaler{
envelopeWriter: envelopeWriter{
writer: responseWriter,
sender: writeSender{responseWriter},
codec: codec,
compressMinBytes: h.CompressMinBytes,
compressionPool: h.CompressionPools.Get(responseCompression),
Expand Down Expand Up @@ -375,7 +375,7 @@ func (c *connectClient) NewConn(
bufferPool: c.BufferPool,
marshaler: connectUnaryRequestMarshaler{
connectUnaryMarshaler: connectUnaryMarshaler{
writer: duplexCall,
sender: duplexCall,
codec: c.Codec,
compressMinBytes: c.CompressMinBytes,
compressionName: c.CompressionName,
Expand Down Expand Up @@ -415,7 +415,7 @@ func (c *connectClient) NewConn(
codec: c.Codec,
marshaler: connectStreamingMarshaler{
envelopeWriter: envelopeWriter{
writer: duplexCall,
sender: duplexCall,
codec: c.Codec,
compressMinBytes: c.CompressMinBytes,
compressionPool: c.CompressionPools.Get(c.CompressionName),
Expand Down Expand Up @@ -892,7 +892,7 @@ func (u *connectStreamingUnmarshaler) EndStreamError() *Error {
}

type connectUnaryMarshaler struct {
writer io.Writer
sender messageSender
codec Codec
compressMinBytes int
compressionName string
Expand Down Expand Up @@ -938,7 +938,8 @@ func (m *connectUnaryMarshaler) Marshal(message any) *Error {
}

func (m *connectUnaryMarshaler) write(data []byte) *Error {
if _, err := m.writer.Write(data); err != nil {
payload := bytes.NewReader(data)
if _, err := m.sender.Send(payload); err != nil {
if connectErr, ok := asError(err); ok {
return connectErr
}
Expand Down
2 changes: 1 addition & 1 deletion protocol_connect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func TestConnectEndOfResponseCanonicalTrailers(t *testing.T) {
assert.Nil(t, err)

writer := envelopeWriter{
writer: &buffer,
sender: writeSender{writer: &buffer},
bufferPool: bufferPool,
}
err = writer.Write(&envelope{
Expand Down
Loading

0 comments on commit cb84690

Please sign in to comment.