Skip to content

Commit

Permalink
Merge pull request #79 from xataio/add-snapshot-store
Browse files Browse the repository at this point in the history
Add snapshot store
  • Loading branch information
eminano authored Oct 15, 2024
2 parents b4c9619 + 8849ad3 commit 801ad51
Show file tree
Hide file tree
Showing 8 changed files with 623 additions and 4 deletions.
6 changes: 4 additions & 2 deletions internal/postgres/mocks/mock_pg_querier.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ import (
type Querier struct {
QueryRowFn func(ctx context.Context, query string, args ...any) postgres.Row
QueryFn func(ctx context.Context, query string, args ...any) (postgres.Rows, error)
ExecFn func(context.Context, string, ...any) (postgres.CommandTag, error)
ExecFn func(context.Context, uint, string, ...any) (postgres.CommandTag, error)
ExecInTxFn func(context.Context, func(tx postgres.Tx) error) error
CloseFn func(context.Context) error
execCalls uint
}

func (m *Querier) QueryRow(ctx context.Context, query string, args ...any) postgres.Row {
Expand All @@ -25,7 +26,8 @@ func (m *Querier) Query(ctx context.Context, query string, args ...any) (postgre
}

func (m *Querier) Exec(ctx context.Context, query string, args ...any) (postgres.CommandTag, error) {
return m.ExecFn(ctx, query, args...)
m.execCalls++
return m.ExecFn(ctx, m.execCalls, query, args...)
}

func (m *Querier) ExecInTx(ctx context.Context, fn func(tx postgres.Tx) error) error {
Expand Down
56 changes: 56 additions & 0 deletions internal/postgres/mocks/mock_rows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// SPDX-License-Identifier: Apache-2.0

package mocks

import (
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
)

type Rows struct {
CloseFn func()
ErrFn func() error
FieldDescriptionsFn func() []pgconn.FieldDescription
NextFn func(i uint) bool
ScanFn func(dest ...any) error
ValuesFn func() ([]any, error)
RawValuesFn func() [][]byte
nextCalls uint
}

func (m *Rows) Close() {
m.CloseFn()
}

func (m *Rows) Err() error {
return m.ErrFn()
}

func (m *Rows) CommandTag() pgconn.CommandTag {
return pgconn.CommandTag{}
}

func (m *Rows) FieldDescriptions() []pgconn.FieldDescription {
return m.FieldDescriptionsFn()
}

func (m *Rows) Next() bool {
m.nextCalls++
return m.NextFn(m.nextCalls)
}

func (m *Rows) Scan(dest ...any) error {
return m.ScanFn(dest...)
}

func (m *Rows) Values() ([]any, error) {
return m.ValuesFn()
}

func (m *Rows) RawValues() [][]byte {
return m.RawValuesFn()
}

func (m *Rows) Conn() *pgx.Conn {
return &pgx.Conn{}
}
4 changes: 2 additions & 2 deletions pkg/schemalog/postgres/pg_schemalog_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func TestStore_Ack(t *testing.T) {
{
name: "ok",
querier: &pgmocks.Querier{
ExecFn: func(_ context.Context, query string, args ...any) (pglib.CommandTag, error) {
ExecFn: func(_ context.Context, _ uint, query string, args ...any) (pglib.CommandTag, error) {
require.Len(t, args, 2)
require.Equal(t, args[0], testID.String())
require.Equal(t, args[1], testSchema)
Expand All @@ -133,7 +133,7 @@ func TestStore_Ack(t *testing.T) {
{
name: "error - executing update query",
querier: &pgmocks.Querier{
ExecFn: func(_ context.Context, query string, args ...any) (pglib.CommandTag, error) {
ExecFn: func(_ context.Context, _ uint, query string, args ...any) (pglib.CommandTag, error) {
return pglib.CommandTag{CommandTag: pgconn.NewCommandTag("")}, errTest
},
},
Expand Down
32 changes: 32 additions & 0 deletions pkg/snapshot/snapshot.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// SPDX-License-Identifier: Apache-2.0

package snapshot

type Snapshot struct {
SchemaName string
TableName string
IdentityColumnNames []string
Status Status
Error error
}

type Status string

const (
StatusRequested = Status("requested")
StatusInProgress = Status("in progress")
StatusCompleted = Status("completed")
)

func (s *Snapshot) IsValid() bool {
return s != nil && s.SchemaName != "" && s.TableName != "" && len(s.IdentityColumnNames) > 0
}

func (s *Snapshot) MarkCompleted(err error) {
s.Status = StatusCompleted
s.Error = err
}

func (s *Snapshot) MarkInProgress() {
s.Status = StatusInProgress
}
27 changes: 27 additions & 0 deletions pkg/snapshot/store/mocks/mock_snapshot_store.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// SPDX-License-Identifier: Apache-2.0

package mocks

import (
"context"

"github.com/xataio/pgstream/pkg/snapshot"
)

type Store struct {
CreateSnapshotRequestFn func(context.Context, *snapshot.Snapshot) error
UpdateSnapshotRequestFn func(context.Context, *snapshot.Snapshot) error
GetSnapshotRequestsFn func(ctx context.Context, status snapshot.Status) ([]*snapshot.Snapshot, error)
}

func (m *Store) CreateSnapshotRequest(ctx context.Context, s *snapshot.Snapshot) error {
return m.CreateSnapshotRequestFn(ctx, s)
}

func (m *Store) UpdateSnapshotRequest(ctx context.Context, s *snapshot.Snapshot) error {
return m.UpdateSnapshotRequestFn(ctx, s)
}

func (m *Store) GetSnapshotRequests(ctx context.Context, status snapshot.Status) ([]*snapshot.Snapshot, error) {
return m.GetSnapshotRequestsFn(ctx, status)
}
116 changes: 116 additions & 0 deletions pkg/snapshot/store/postgres/pg_snapshot_store.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
// SPDX-License-Identifier: Apache-2.0

package postgres

import (
"context"
"fmt"

"github.com/lib/pq"
"github.com/xataio/pgstream/internal/postgres"
"github.com/xataio/pgstream/pkg/snapshot"
"github.com/xataio/pgstream/pkg/snapshot/store"
)

type Store struct {
conn postgres.Querier
}

const queryLimit = 1000

func NewStore(ctx context.Context, url string) (*Store, error) {
conn, err := postgres.NewConnPool(ctx, url)
if err != nil {
return nil, err
}

s := &Store{
conn: conn,
}

// create snapshots table if it doesn't exist
if err := s.createTable(ctx); err != nil {
return nil, fmt.Errorf("creating snapshots table: %w", err)
}

return s, nil
}

func (s *Store) Close() error {
return s.conn.Close(context.Background())
}

func (s *Store) CreateSnapshotRequest(ctx context.Context, req *snapshot.Snapshot) error {
query := fmt.Sprintf(`INSERT INTO %s (schema_name, table_name, identity_column_names, created_at, updated_at, status)
VALUES($1, $2, $3,'now()','now()','requested')`, snapshotsTable())
_, err := s.conn.Exec(ctx, query, req.SchemaName, req.TableName, pq.StringArray(req.IdentityColumnNames))
if err != nil {
return fmt.Errorf("error creating snapshot request: %w", err)
}
return nil
}

func (s *Store) UpdateSnapshotRequest(ctx context.Context, req *snapshot.Snapshot) error {
errStr := ""
if req.Error != nil {
errStr = req.Error.Error()
}
query := fmt.Sprintf(`UPDATE %s SET status = '%s', error = '%s', updated_at = 'now()'
WHERE schema_name = '%s' and table_name = '%s' and status != 'completed'`, snapshotsTable(), req.Status, errStr, req.SchemaName, req.TableName)
_, err := s.conn.Exec(ctx, query)
if err != nil {
return fmt.Errorf("error updating snapshot request: %w", err)
}
return nil
}

func (s *Store) GetSnapshotRequests(ctx context.Context, status snapshot.Status) ([]*snapshot.Snapshot, error) {
query := fmt.Sprintf(`SELECT schema_name,table_name,identity_column_names,status FROM %s
WHERE status = '%s' ORDER BY req_id ASC LIMIT %d`, snapshotsTable(), status, queryLimit)
rows, err := s.conn.Query(ctx, query)
if err != nil {
return nil, fmt.Errorf("error getting snapshot requests: %w", err)
}
defer rows.Close()

snapshots := []*snapshot.Snapshot{}
for rows.Next() {
snapshot := &snapshot.Snapshot{}
if err := rows.Scan(&snapshot.SchemaName, &snapshot.TableName, &snapshot.IdentityColumnNames, &snapshot.Status); err != nil {
return nil, fmt.Errorf("scanning snapshot row: %w", err)
}

snapshots = append(snapshots, snapshot)
}

return snapshots, nil
}

func (s *Store) createTable(ctx context.Context) error {
createQuery := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s(
req_id SERIAL PRIMARY KEY,
schema_name TEXT,
table_name TEXT,
identity_column_names TEXT[],
created_at TIMESTAMP WITH TIME ZONE,
updated_at TIMESTAMP WITH TIME ZONE,
status TEXT CHECK (status IN ('requested', 'in progress', 'completed')),
error TEXT )`, snapshotsTable())
_, err := s.conn.Exec(ctx, createQuery)
if err != nil {
return fmt.Errorf("error creating snapshots postgres table: %w", err)
}

uniqueIndexQuery := fmt.Sprintf(`CREATE UNIQUE INDEX IF NOT EXISTS schema_table_status_unique_index
ON %s(schema_name,table_name) WHERE status != 'completed'`, snapshotsTable())
_, err = s.conn.Exec(ctx, uniqueIndexQuery)
if err != nil {
return fmt.Errorf("error creating unique index on snapshots postgres table: %w", err)
}

return err
}

func snapshotsTable() string {
return fmt.Sprintf("%s.%s", pq.QuoteIdentifier(store.SchemaName), pq.QuoteIdentifier(store.TableName))
}
Loading

0 comments on commit 801ad51

Please sign in to comment.