Skip to content

Commit

Permalink
feat: some improvement
Browse files Browse the repository at this point in the history
  • Loading branch information
gfyrag committed Dec 12, 2022
1 parent f9066d6 commit 6f5b11b
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 17 deletions.
8 changes: 6 additions & 2 deletions pgtesting/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@ import (
)
func TestMain(m *testing.M) {
pgtesting.CreatePostgresServer()
if err := pgtesting.CreatePostgresServer(); err != nil {
log.Fatal(err)
}
code := m.Run()
pgtesting.DestroyPostgresServer()
if err := pgtesting.DestroyPostgresServer(); err != nil {
log.Fatal(err)
}
os.Exit(code)
}
Expand Down
115 changes: 100 additions & 15 deletions pgtesting/pkg/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package pgtesting
import (
"context"
"fmt"
"log"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -42,21 +41,29 @@ func (s *pgServer) NewDatabase(t *testing.T) *pgDatabase {
_, err := s.conn.Exec(context.Background(), fmt.Sprintf(`CREATE DATABASE "%s"`, databaseName))
require.NoError(t, err)

t.Cleanup(func() {
s.lock.Lock()
defer s.lock.Unlock()

_, _ = s.conn.Exec(context.Background(), fmt.Sprintf(`DROP DATABASE "%s"`, databaseName))
})

return &pgDatabase{
url: s.dsn(databaseName),
}
}

func (s *pgServer) Close() {
func (s *pgServer) Close() error {
if s.conn == nil {
return
return nil
}
if err := s.conn.Close(context.Background()); err != nil {
log.Fatal("error closing connection: ", err)
return err
}
if err := s.destroy(); err != nil {
log.Fatal("error destroying pg server: ", err)
return err
}
return nil
}

var srv *pgServer
Expand All @@ -65,11 +72,86 @@ func NewPostgresDatabase(t *testing.T) *pgDatabase {
return srv.NewDatabase(t)
}

func DestroyPostgresServer() {
srv.Close()
func DestroyPostgresServer() error {
return srv.Close()
}

type config struct {
initialDatabaseName string
initialUserPassword string
initialUsername string
statusCheckInterval time.Duration
maximumWaitingTime time.Duration
context context.Context
}

func (c config) validate() error {
if c.statusCheckInterval == 0 {
return errors.New("status check interval must be greater than 0")
}
if c.initialUsername == "" {
return errors.New("initial username must be defined")
}
if c.initialUserPassword == "" {
return errors.New("initial user password must be defined")
}
if c.initialDatabaseName == "" {
return errors.New("initial database name must be defined")
}
return nil
}

type option func(opts *config)

func WithInitialDatabaseName(name string) option {
return func(opts *config) {
opts.initialDatabaseName = name
}
}

func WithInitialUser(username, pwd string) option {
return func(opts *config) {
opts.initialUserPassword = pwd
opts.initialUsername = username
}
}

func WithStatusCheckInterval(d time.Duration) option {
return func(opts *config) {
opts.statusCheckInterval = d
}
}

func CreatePostgresServer() error {
func WithMaximumWaitingTime(d time.Duration) option {
return func(opts *config) {
opts.maximumWaitingTime = d
}
}

func WithContext(ctx context.Context) option {
return func(opts *config) {
opts.context = ctx
}
}

var defaultOptions = []option{
WithStatusCheckInterval(200 * time.Millisecond),
WithInitialUser("root", "root"),
WithMaximumWaitingTime(5 * time.Second),
WithInitialDatabaseName("formance"),
WithContext(context.Background()),
}

func CreatePostgresServer(opts ...option) error {

cfg := config{}
for _, opt := range append(defaultOptions, opts...) {
opt(&cfg)
}

if err := cfg.validate(); err != nil {
return errors.Wrap(err, "validating config")
}

pool, err := dockertest.NewPool("")
if err != nil {
Expand All @@ -80,9 +162,9 @@ func CreatePostgresServer() error {
Repository: "postgres",
Tag: "15-alpine",
Env: []string{
"POSTGRES_USER=root",
"POSTGRES_PASSWORD=root",
"POSTGRES_DB=formance",
fmt.Sprintf("POSTGRES_USER=%s", cfg.initialUsername),
fmt.Sprintf("POSTGRES_PASSWORD=%s", cfg.initialUserPassword),
fmt.Sprintf("POSTGRES_DB=%s", cfg.initialDatabaseName),
},
Entrypoint: nil,
Cmd: []string{"-c", "superuser-reserved-connections=0"},
Expand All @@ -99,12 +181,15 @@ func CreatePostgresServer() error {
}

try := time.Duration(0)
delay := 200 * time.Millisecond
for try*delay < 5*time.Second {
srv.conn, err = pgx.Connect(context.Background(), srv.dsn("formance"))
for try*cfg.statusCheckInterval < cfg.maximumWaitingTime {
srv.conn, err = pgx.Connect(context.Background(), srv.dsn(cfg.initialDatabaseName))
if err != nil {
try++
<-time.After(delay)
select {
case <-cfg.context.Done():
return cfg.context.Err()
case <-time.After(cfg.statusCheckInterval):
}
continue
}
return nil
Expand Down

0 comments on commit 6f5b11b

Please sign in to comment.