Skip to content

Commit

Permalink
Simplify and optimize tunnel reader
Browse files Browse the repository at this point in the history
  • Loading branch information
robinbraemer committed Feb 25, 2022
1 parent 25836fd commit 98f4e78
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 69 deletions.
6 changes: 3 additions & 3 deletions api/buf.lock
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ deps:
owner: googleapis
repository: googleapis
branch: main
commit: bdfe2a6dbea44b2aadd202aef3b02139
digest: b1-fsLQE8NjJb-wMtT1cRL6GOXR5WX__qNZRUzlZUo0VBs=
create_time: 2022-02-20T15:03:24.321088Z
commit: 0b64ae0918a6421b8deeffb20bd7c58a
digest: b1-rcRLiZYvmis9EDBO1iRVlWakV5U988o2oNLyldwzo6Q=
create_time: 2022-02-24T15:04:40.484238Z
98 changes: 52 additions & 46 deletions tunnel_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"errors"
"io"
"net"
"os"
"time"
)
Expand All @@ -25,41 +26,40 @@ func newTunnelReader(
ctx context.Context,
readFn readFn,
) TunnelReader {
r := &tunnelReader{
readChan := make(chan interface{})
go readLoop(ctx, readFn, readChan)
return &tunnelReader{
ctx: ctx,
deadline: newDeadline(),
readNext: make(chan chan struct{}),
readChan: readChan,
}
go func() {
var data []byte
for {
select {
case res := <-r.readNext:
data, r.err = readFn()
_, _ = r.buf.Write(data)
select {
case res <- struct{}{}:
case <-r.timeout.Done(): // read retryable
case <-ctx.Done():
return // stop read loop
}
if errors.Is(r.err, io.EOF) {
return // stop read loop
}
case <-ctx.Done():
return // stop read loop
}
}

func readLoop(ctx context.Context, rd readFn, out chan<- interface{}) {
var v interface{}
var err error
for {
v, err = rd()
if err != nil {
v = err
}
select {
case out <- v:
case <-ctx.Done():
return
}
if errors.Is(err, io.EOF) {
return
}
}()
return r
}
}

type readFn func() ([]byte, error)

type tunnelReader struct {
ctx context.Context
*deadline
readNext chan chan struct{}
readChan <-chan interface{}
buf bytes.Buffer
err error
}
Expand All @@ -72,35 +72,41 @@ func (t *tunnelReader) Read(p []byte) (int, error) {
if t.buf.Len() != 0 {
return t.buf.Read(p)
}
// try read more data, can be more or less than len(p)
res := make(chan struct{})
// trigger next read
// Check reader is already closed or timed out
select {
case t.readNext <- res:
case <-t.timeout.Done():
return 0, os.ErrDeadlineExceeded
return t.timedOut()
case <-t.ctx.Done():
return 0, t.ctx.Err()
return t.ctxDone()
default:
}
// wait until we can read from buf
// receive read
select {
case <-res:
err := t.err
if t.err != nil {
if !errors.Is(t.err, io.EOF) {
t.err = nil
}
return 0, err
case v := <-t.readChan:
if t.err, _ = v.(error); t.err != nil {
return 0, t.err
}
b := v.([]byte)
if len(b) > len(p) {
// buffer last bytes for next Read
t.buf.Write(b[len(p):])
return copy(p, b[:len(p)]), nil
}
// successful read
return copy(p, b), nil
case <-t.ctx.Done():
return t.ctxDone()
case <-t.timeout.Done():
return t.timedOut()
}
}

func (t *tunnelReader) ctxDone() (int, error) {
if errors.Is(t.ctx.Err(), context.DeadlineExceeded) {
return 0, os.ErrDeadlineExceeded
case <-t.ctx.Done():
return 0, t.ctx.Err()
}
// It returns the number of bytes read (0 <= n <= len(p)) and any error encountered.
// If some data is available but not len(p) bytes,
// Read conventionally returns what is available instead of waiting for more,
// as specified by the io.Reader interface.
return t.buf.Read(p)
return 0, net.ErrClosed
}

func (t *tunnelReader) timedOut() (int, error) {
return 0, os.ErrDeadlineExceeded
}
64 changes: 45 additions & 19 deletions tunnel_reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,55 +2,79 @@ package connect

import (
"context"
"github.com/stretchr/testify/require"
"io"
"os"
"testing"
"time"

"github.com/stretchr/testify/require"
)

type out struct {
data []byte
err error
}

func newReader(ctx context.Context, o ...out) TunnelReader {
i := 0
return newTunnelReader(ctx, func() ([]byte, error) {
r := o[i]
i++
return r.data, r.err
})
}

func TestTunnelReader_Read(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
var c int
r := newTunnelReader(ctx, func() ([]byte, error) {
if c == 1 {
time.Sleep(time.Hour) // block
return nil, nil
}
c++
return []byte("hello"), nil
})

r := newReader(ctx,
out{data: []byte("hello")},
out{err: io.EOF},
)

b := make([]byte, 2)
n, err := r.Read(b)
require.NoError(t, err)
require.Equal(t, len(b), n)
require.Equal(t, "he", string(b))

b = make([]byte, 3)
b = make([]byte, 3+10)
n, err = r.Read(b)
require.NoError(t, err)
require.Equal(t, len(b), n)
require.Equal(t, "llo", string(b))
require.Equal(t, 3, n)
require.Equal(t, "llo", string(b[:3]))

b = make([]byte, 1)
n, err = r.Read(b)
require.ErrorIs(t, err, context.DeadlineExceeded)
require.NotNil(t, err)
require.ErrorIs(t, err, io.EOF)
require.Equal(t, []byte{0}, b)
require.Equal(t, 0, n)

n, err = r.Read(b)
require.NotNil(t, err)
require.ErrorIs(t, err, io.EOF)
require.Equal(t, 0, n)
}

func TestTunnelReader_ReadDeadline(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
r := newTunnelReader(ctx, func() ([]byte, error) {
return []byte("hello"), nil
})

r := newReader(ctx,
out{data: []byte("hello")},
out{err: io.EOF},
)

err := r.SetDeadline(time.Now().Add(time.Second / 2))
require.NoError(t, err)

time.Sleep(time.Second)
b := make([]byte, 5)
n, err := r.Read(b)
require.Empty(t, n)
require.NotNil(t, err)
require.ErrorIs(t, err, os.ErrDeadlineExceeded)
require.Equal(t, []byte{0, 0, 0, 0, 0}, b)
require.Equal(t, 0, n)
Expand All @@ -60,16 +84,18 @@ func TestTunnelReader_ReadDeadline(t *testing.T) {
err = r.SetDeadline(time.Now().Add(time.Second * 1))
require.NoError(t, err)

b = make([]byte, 5)
b = make([]byte, 10)
n, err = r.Read(b)
require.NoError(t, err)
require.Equal(t, "hello", string(b))
require.Equal(t, "hello", string(b[:5]))
require.Equal(t, 5, n)

time.Sleep(time.Millisecond * 1300)

b = make([]byte, 5)
n, err = r.Read(b)
require.Empty(t, n)
require.NotNil(t, err)
require.ErrorIs(t, err, os.ErrDeadlineExceeded)
require.Equal(t, []byte{0, 0, 0, 0, 0}, b)
require.Equal(t, 0, n)
Expand Down
5 changes: 4 additions & 1 deletion tunnel_writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ package connect

import (
"context"
"github.com/stretchr/testify/require"
"os"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func TestTunnelWriter_Write(t *testing.T) {
Expand Down Expand Up @@ -50,6 +51,7 @@ func TestTunnelWriter_WriteDeadline(t *testing.T) {

time.Sleep(time.Second)
n, err := w.Write(b)
require.NotNil(t, err)
require.ErrorIs(t, err, os.ErrDeadlineExceeded)
require.Equal(t, 0, n)
time.Sleep(time.Millisecond * 50)
Expand All @@ -65,6 +67,7 @@ func TestTunnelWriter_WriteDeadline(t *testing.T) {
time.Sleep(time.Millisecond * 1300)

n, err = w.Write(b)
require.NotNil(t, err)
require.ErrorIs(t, err, os.ErrDeadlineExceeded)
require.Equal(t, 0, n)
}

0 comments on commit 98f4e78

Please sign in to comment.