Skip to content

Commit

Permalink
test: fix api unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Elias Van Ootegem <elias@vega.xyz>
  • Loading branch information
EVODelavega committed Jun 12, 2024
1 parent dfdd3ba commit 0eedcd2
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 21 deletions.
20 changes: 2 additions & 18 deletions datanode/api/trading_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package api_test
import (
"context"
"fmt"
"io"
"net"
"testing"
"time"
Expand All @@ -40,8 +39,6 @@ import (

"github.com/golang/mock/gomock"
"github.com/golang/protobuf/proto"
"github.com/jackc/pgconn"
"github.com/jackc/pgx/v4"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
Expand Down Expand Up @@ -98,9 +95,8 @@ func getTestGRPCServer(t *testing.T, ctx context.Context) (tidy func(), conn *gr

conf.CandlesV2.CandleStore.DefaultCandleIntervals = ""

sqlConn := &sqlstore.ConnectionSource{
Connection: dummyConnection{},
}
sqlConn := &sqlstore.ConnectionSource{}
sqlConn.ToggleTest() // ensure calls to query and copyTo do not fail

bro, err := broker.New(ctx, logging.NewTestLogger(), conf.Broker, "", eventSource)
if err != nil {
Expand Down Expand Up @@ -257,18 +253,6 @@ func getTestGRPCServer(t *testing.T, ctx context.Context) (tidy func(), conn *gr
return tidy, conn, mockCoreServiceClient, err
}

type dummyConnection struct {
sqlstore.Connection
}

func (d dummyConnection) Query(context.Context, string, ...interface{}) (pgx.Rows, error) {
return nil, pgx.ErrNoRows
}

func (d dummyConnection) CopyTo(context.Context, io.Writer, string, ...any) (pgconn.CommandTag, error) {
return pgconn.CommandTag{}, nil
}

func TestSubmitTransaction(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
Expand Down
18 changes: 15 additions & 3 deletions datanode/sqlstore/connection_tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ import (
)

type ConnectionSource struct {
pool *pgxpool.Pool
log *logging.Logger
pool *pgxpool.Pool
log *logging.Logger
isTest bool
}

type wrappedTx struct {
Expand All @@ -36,7 +37,6 @@ type (
)

func NewTransactionalConnectionSource(ctx context.Context, log *logging.Logger, connConfig ConnectionConfig) (*ConnectionSource, error) {
// func NewConnSource(ctx context.Context, log *logging.Logger, connConfig ConnectionConfig) (*ConnectionSource, error) {
pool, err := CreateConnectionPool(ctx, connConfig)
if err != nil {
return nil, fmt.Errorf("failed to create connection pool: %w", err)
Expand All @@ -47,6 +47,10 @@ func NewTransactionalConnectionSource(ctx context.Context, log *logging.Logger,
}, nil
}

func (c *ConnectionSource) ToggleTest() {
c.isTest = true
}

func (c *ConnectionSource) WithConnection(ctx context.Context) (context.Context, error) {
poolConn, err := c.pool.Acquire(ctx)
if err != nil {
Expand Down Expand Up @@ -141,6 +145,10 @@ func (c *ConnectionSource) Commit(ctx context.Context) error {
}

func (c *ConnectionSource) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) {
// this is nasty, but required for the API tests currently.
if c.isTest && c.pool == nil {
return nil, pgx.ErrNoRows
}
if tx, ok := ctx.Value(txKey{}).(*wrappedTx); ok {
return tx.tx.Query(ctx, sql, args...)
}
Expand Down Expand Up @@ -191,6 +199,10 @@ func (c *ConnectionSource) CopyFrom(ctx context.Context, tableName pgx.Identifie
}

func (c *ConnectionSource) CopyTo(ctx context.Context, w io.Writer, sql string, args ...any) (pgconn.CommandTag, error) {
// this is nasty, but required for the API tests currently.
if c.isTest && c.pool == nil {
return pgconn.CommandTag{}, nil
}
var err error
sql, err = SanitizeSql(sql, args...)
if err != nil {
Expand Down

0 comments on commit 0eedcd2

Please sign in to comment.