Skip to content

Commit

Permalink
ccl/sqlproxyccl: add postgres interceptors for message forwarding
Browse files Browse the repository at this point in the history
Informs cockroachdb#76000.

This commit implements postgres interceptors, namely FrontendInterceptor and
BackendInterceptor, as described in the sqlproxy connection migration RFC.
These interceptors will be used as building blocks for the forwarder component
that we will be adding in a later PR. Since the forwarder component has not
been added, a simple proxy test (i.e. TestSimpleProxy) has been added to
illustrate how the frontend and backend interceptors can be used within the
proxy.

Release note: None
  • Loading branch information
jaylim-crl authored and RajivTS committed Mar 6, 2022
1 parent edf6e7a commit 1ed6790
Show file tree
Hide file tree
Showing 11 changed files with 1,552 additions and 0 deletions.
1 change: 1 addition & 0 deletions pkg/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ ALL_TESTS = [
"//pkg/ccl/spanconfigccl/spanconfigsqlwatcherccl:spanconfigsqlwatcherccl_test",
"//pkg/ccl/sqlproxyccl/denylist:denylist_test",
"//pkg/ccl/sqlproxyccl/idle:idle_test",
"//pkg/ccl/sqlproxyccl/interceptor:interceptor_test",
"//pkg/ccl/sqlproxyccl/tenant:tenant_test",
"//pkg/ccl/sqlproxyccl/throttler:throttler_test",
"//pkg/ccl/sqlproxyccl:sqlproxyccl_test",
Expand Down
38 changes: 38 additions & 0 deletions pkg/ccl/sqlproxyccl/interceptor/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")

go_library(
name = "interceptor",
srcs = [
"backend_interceptor.go",
"base.go",
"chunkreader.go",
"frontend_interceptor.go",
],
importpath = "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/interceptor",
visibility = ["//visibility:public"],
deps = [
"//pkg/sql/pgwire/pgwirebase",
"//pkg/util",
"@com_github_cockroachdb_errors//:errors",
"@com_github_jackc_pgproto3_v2//:pgproto3",
],
)

go_test(
name = "interceptor_test",
srcs = [
"backend_interceptor_test.go",
"base_test.go",
"chunkreader_test.go",
"frontend_interceptor_test.go",
"interceptor_test.go",
],
embed = [":interceptor"],
deps = [
"//pkg/sql/pgwire/pgwirebase",
"//pkg/util/leaktest",
"@com_github_cockroachdb_errors//:errors",
"@com_github_jackc_pgproto3_v2//:pgproto3",
"@com_github_stretchr_testify//require",
],
)
71 changes: 71 additions & 0 deletions pkg/ccl/sqlproxyccl/interceptor/backend_interceptor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Copyright 2022 The Cockroach Authors.
//
// Licensed as a CockroachDB Enterprise file under the Cockroach Community
// License (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt

package interceptor

import (
"io"

"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgwirebase"
"github.com/jackc/pgproto3/v2"
)

// BackendInterceptor is a server int/erceptor for the Postgres backend protocol.
type BackendInterceptor pgInterceptor

// NewBackendInterceptor creates a BackendInterceptor. bufSize must be at least
// the size of a pgwire message header.
func NewBackendInterceptor(src io.Reader, dst io.Writer, bufSize int) (*BackendInterceptor, error) {
pgi, err := newPgInterceptor(src, dst, bufSize)
if err != nil {
return nil, err
}
return (*BackendInterceptor)(pgi), nil
}

// PeekMsg returns the header of the current pgwire message without advancing
// the interceptor.
//
// See pgInterceptor.PeekMsg for more information.
func (bi *BackendInterceptor) PeekMsg() (typ pgwirebase.ClientMessageType, size int, err error) {
byteType, size, err := (*pgInterceptor)(bi).PeekMsg()
return pgwirebase.ClientMessageType(byteType), size, err
}

// WriteMsg writes the given bytes to the writer dst.
//
// See pgInterceptor.WriteMsg for more information.
func (bi *BackendInterceptor) WriteMsg(data pgproto3.FrontendMessage) (n int, err error) {
return (*pgInterceptor)(bi).WriteMsg(data.Encode(nil))
}

// ReadMsg decodes the current pgwire message and returns a FrontendMessage.
// This also advances the interceptor to the next message.
//
// See pgInterceptor.ReadMsg for more information.
func (bi *BackendInterceptor) ReadMsg() (msg pgproto3.FrontendMessage, err error) {
msgBytes, err := (*pgInterceptor)(bi).ReadMsg()
if err != nil {
return nil, err
}
// errPanicWriter is used here because Receive must not Write.
return pgproto3.NewBackend(newChunkReader(msgBytes), &errPanicWriter{}).Receive()
}

// ForwardMsg sends the current pgwire message to the destination without any
// decoding, and advances the interceptor to the next message.
//
// See pgInterceptor.ForwardMsg for more information.
func (bi *BackendInterceptor) ForwardMsg() (n int, err error) {
return (*pgInterceptor)(bi).ForwardMsg()
}

// Close closes the interceptor, and prevents further operations on it.
func (bi *BackendInterceptor) Close() {
(*pgInterceptor)(bi).Close()
}
117 changes: 117 additions & 0 deletions pkg/ccl/sqlproxyccl/interceptor/backend_interceptor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// Copyright 2022 The Cockroach Authors.
//
// Licensed as a CockroachDB Enterprise file under the Cockroach Community
// License (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt

package interceptor_test

import (
"bytes"
"testing"

"github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/interceptor"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgwirebase"
"github.com/cockroachdb/cockroach/pkg/util/leaktest"
"github.com/jackc/pgproto3/v2"
"github.com/stretchr/testify/require"
)

// TestBackendInterceptor tests the BackendInterceptor. Note that the tests
// here are shallow. For detailed ones, see the tests for the internal
// interceptor in base_test.go.
func TestBackendInterceptor(t *testing.T) {
defer leaktest.AfterTest(t)()

q := (&pgproto3.Query{String: "SELECT 1"}).Encode(nil)

t.Run("bufSize too small", func(t *testing.T) {
bi, err := interceptor.NewBackendInterceptor(nil /* src */, nil /* dst */, 1)
require.Error(t, err)
require.Nil(t, bi)
})

t.Run("PeekMsg returns the right message type", func(t *testing.T) {
src := new(bytes.Buffer)
_, err := src.Write(q)
require.NoError(t, err)

bi, err := interceptor.NewBackendInterceptor(src, nil /* dst */, 16)
require.NoError(t, err)
require.NotNil(t, bi)

typ, size, err := bi.PeekMsg()
require.NoError(t, err)
require.Equal(t, pgwirebase.ClientMsgSimpleQuery, typ)
require.Equal(t, 9, size)

bi.Close()
typ, size, err = bi.PeekMsg()
require.EqualError(t, err, interceptor.ErrInterceptorClosed.Error())
require.Equal(t, pgwirebase.ClientMessageType(0), typ)
require.Equal(t, 0, size)
})

t.Run("WriteMsg writes data to dst", func(t *testing.T) {
dst := new(bytes.Buffer)
bi, err := interceptor.NewBackendInterceptor(nil /* src */, dst, 10)
require.NoError(t, err)
require.NotNil(t, bi)

// This is a backend interceptor, so writing goes to the server.
toSend := &pgproto3.Query{String: "SELECT 1"}
n, err := bi.WriteMsg(toSend)
require.NoError(t, err)
require.Equal(t, 14, n)
require.Equal(t, 14, dst.Len())

bi.Close()
n, err = bi.WriteMsg(toSend)
require.EqualError(t, err, interceptor.ErrInterceptorClosed.Error())
require.Equal(t, 0, n)
})

t.Run("ReadMsg decodes the message correctly", func(t *testing.T) {
src := new(bytes.Buffer)
_, err := src.Write(q)
require.NoError(t, err)

bi, err := interceptor.NewBackendInterceptor(src, nil /* dst */, 16)
require.NoError(t, err)
require.NotNil(t, bi)

msg, err := bi.ReadMsg()
require.NoError(t, err)
rmsg, ok := msg.(*pgproto3.Query)
require.True(t, ok)
require.Equal(t, "SELECT 1", rmsg.String)

bi.Close()
msg, err = bi.ReadMsg()
require.EqualError(t, err, interceptor.ErrInterceptorClosed.Error())
require.Nil(t, msg)
})

t.Run("ForwardMsg forwards data to dst", func(t *testing.T) {
src := new(bytes.Buffer)
_, err := src.Write(q)
require.NoError(t, err)
dst := new(bytes.Buffer)

bi, err := interceptor.NewBackendInterceptor(src, dst, 16)
require.NoError(t, err)
require.NotNil(t, bi)

n, err := bi.ForwardMsg()
require.NoError(t, err)
require.Equal(t, 14, n)
require.Equal(t, 14, dst.Len())

bi.Close()
n, err = bi.ForwardMsg()
require.EqualError(t, err, interceptor.ErrInterceptorClosed.Error())
require.Equal(t, 0, n)
})
}
Loading

0 comments on commit 1ed6790

Please sign in to comment.