diff --git a/src/reaper/database.go b/src/reaper/database.go index 1a73576c3c..1029d24268 100644 --- a/src/reaper/database.go +++ b/src/reaper/database.go @@ -988,7 +988,7 @@ func CloseResourcesForRemovedMonitoredDBs(metricsWriter *sinks.MultiWriter, curr // or to be ignored due to current instance state for roleChangedDB := range shutDownDueToRoleChange { - if db := currentDBs.GetDatabase(roleChangedDB); db != nil { + if db := currentDBs.GetMonitoredDatabase(roleChangedDB); db != nil { db.Conn.Close() } _ = metricsWriter.SyncMetrics(roleChangedDB, "", "remove") diff --git a/src/reaper/reaper.go b/src/reaper/reaper.go index 2bf4500732..b7b6d5893c 100644 --- a/src/reaper/reaper.go +++ b/src/reaper/reaper.go @@ -78,7 +78,7 @@ func (r *Reaper) Reap(mainContext context.Context) (err error) { continue } } - if monitoredDbs, err = monitoredDbs.Expand(); err != nil { + if monitoredDbs, err = monitoredDbs.ResolveDatabases(); err != nil { logger.Error(err) continue } diff --git a/src/sources/postgres.go b/src/sources/postgres.go index f54672a97d..a7edf30487 100644 --- a/src/sources/postgres.go +++ b/src/sources/postgres.go @@ -1,5 +1,8 @@ package sources +// This file contains the implementation of the ReaderWriter interface for the PostgreSQL database. +// Monitored sources are stored in the `pgwatch3.source` table in the configuration database. + import ( "context" @@ -30,14 +33,14 @@ func (r *dbSourcesReaderWriter) WriteMonitoredDatabases(dbs MonitoredDatabases) } defer func() { _ = tx.Rollback(context.Background()) }() for _, md := range dbs { - if err = updateDatabase(tx, md); err != nil { + if err = r.updateDatabase(tx, md); err != nil { return err } } return tx.Commit(context.Background()) } -func updateDatabase(conn db.PgxIface, md *MonitoredDatabase) (err error) { +func (r *dbSourcesReaderWriter) updateDatabase(conn db.PgxIface, md *MonitoredDatabase) (err error) { sql := `insert into pgwatch3.source( name, "group", dbtype, connstr, config, config_standby, preset_config, preset_config_standby, is_superuser, include_pattern, exclude_pattern, custom_tags, host_config, only_if_master) @@ -54,7 +57,7 @@ host_config = $13, only_if_master = $14` } func (r *dbSourcesReaderWriter) UpdateDatabase(md *MonitoredDatabase) error { - return updateDatabase(r.configDb, md) + return r.updateDatabase(r.configDb, md) } func (r *dbSourcesReaderWriter) DeleteDatabase(name string) error { diff --git a/src/sources/postgres_test.go b/src/sources/postgres_test.go new file mode 100644 index 0000000000..3f954e4209 --- /dev/null +++ b/src/sources/postgres_test.go @@ -0,0 +1,152 @@ +package sources_test + +import ( + "errors" + "testing" + + "github.com/pashagolub/pgxmock/v3" + "github.com/stretchr/testify/assert" + + "github.com/cybertec-postgresql/pgwatch3/sources" +) + +func TestNewPostgresSourcesReaderWriter(t *testing.T) { + a := assert.New(t) + conn, err := pgxmock.NewPool() + a.NoError(err) + conn.ExpectPing() + + pgrw, err := sources.NewPostgresSourcesReaderWriter(ctx, conn) + a.NoError(err) + a.NotNil(t, pgrw) + a.NoError(conn.ExpectationsWereMet()) +} + +func TestGetMonitoredDatabases(t *testing.T) { + a := assert.New(t) + conn, err := pgxmock.NewPool() + a.NoError(err) + conn.ExpectPing() + conn.ExpectQuery(`select \/\* pgwatch3_generated \*\/`).WillReturnRows(pgxmock.NewRows([]string{ + "name", "group", "dbtype", "connstr", "config", "config_standby", "preset_config", + "preset_config_standby", "is_superuser", "include_pattern", "exclude_pattern", + "custom_tags", "host_config", "only_if_master", "is_enabled", + }).AddRow( + "db1", "group1", sources.Kind("postgres"), "postgres://user:pass@localhost:5432/db1", + map[string]float64{"metric": 60}, map[string]float64{"standby_metric": 60}, "exhaustive", "exhaustive", + true, ".*", `\_.+`, map[string]string{"tag": "value"}, nil, true, true, + )) + pgrw, err := sources.NewPostgresSourcesReaderWriter(ctx, conn) + a.NoError(err) + + dbs, err := pgrw.GetMonitoredDatabases() + a.NoError(err) + a.Len(dbs, 1) + a.NoError(conn.ExpectationsWereMet()) + + // check failed query + conn.ExpectQuery(`select \/\* pgwatch3_generated \*\/`).WillReturnError(errors.New("failed query")) + dbs, err = pgrw.GetMonitoredDatabases() + a.Error(err) + a.Nil(dbs) + a.NoError(conn.ExpectationsWereMet()) +} +func TestDeleteDatabase(t *testing.T) { + a := assert.New(t) + conn, err := pgxmock.NewPool() + a.NoError(err) + conn.ExpectPing() + conn.ExpectExec(`delete from pgwatch3\.source where name = \$1`).WithArgs("db1").WillReturnResult(pgxmock.NewResult("DELETE", 1)) + pgrw, err := sources.NewPostgresSourcesReaderWriter(ctx, conn) + a.NoError(err) + + err = pgrw.DeleteDatabase("db1") + a.NoError(err) + a.NoError(conn.ExpectationsWereMet()) +} + +func TestUpdateDatabase(t *testing.T) { + a := assert.New(t) + conn, err := pgxmock.NewPool() + a.NoError(err) + + md := &sources.MonitoredDatabase{} + conn.ExpectPing() + conn.ExpectExec(`insert into pgwatch3\.source`).WithArgs( + md.DBUniqueName, md.Group, md.Kind, + md.ConnStr, md.Metrics, md.MetricsStandby, md.PresetMetrics, md.PresetMetricsStandby, + md.IsSuperuser, md.IncludePattern, md.ExcludePattern, md.CustomTags, + md.HostConfig, md.OnlyIfMaster, + ).WillReturnResult(pgxmock.NewResult("UPDATE", 1)) + + pgrw, err := sources.NewPostgresSourcesReaderWriter(ctx, conn) + a.NoError(err) + err = pgrw.UpdateDatabase(md) + a.NoError(err) + a.NoError(conn.ExpectationsWereMet()) +} + +func TestWriteMonitoredDatabases(t *testing.T) { + var ( + pgrw sources.ReaderWriter + err error + ) + a := assert.New(t) + conn, err := pgxmock.NewPool() + a.NoError(err) + md := &sources.MonitoredDatabase{} + mds := sources.MonitoredDatabases{md} + + t.Run("happy path", func(*testing.T) { + conn.ExpectPing() + conn.ExpectBegin() + conn.ExpectExec(`truncate pgwatch3\.source`).WillReturnResult(pgxmock.NewResult("TRUNCATE", 1)) + conn.ExpectExec(`insert into pgwatch3\.source`).WithArgs( + md.DBUniqueName, md.Group, md.Kind, + md.ConnStr, md.Metrics, md.MetricsStandby, md.PresetMetrics, md.PresetMetricsStandby, + md.IsSuperuser, md.IncludePattern, md.ExcludePattern, md.CustomTags, + md.HostConfig, md.OnlyIfMaster, + ).WillReturnResult(pgxmock.NewResult("INSERT", 1)) + conn.ExpectCommit() + conn.ExpectRollback() // deferred rollback + + pgrw, err = sources.NewPostgresSourcesReaderWriter(ctx, conn) + a.NoError(err) + err = pgrw.WriteMonitoredDatabases(mds) + a.NoError(err) + a.NoError(conn.ExpectationsWereMet()) + }) + + t.Run("failed transaction begin", func(*testing.T) { + conn.ExpectBegin().WillReturnError(errors.New("failed transaction begin")) + + err = pgrw.WriteMonitoredDatabases(mds) + a.Error(err) + a.NoError(conn.ExpectationsWereMet()) + }) + + t.Run("failed truncate", func(*testing.T) { + conn.ExpectBegin() + conn.ExpectExec(`truncate pgwatch3\.source`).WillReturnError(errors.New("failed truncate")) + + err = pgrw.WriteMonitoredDatabases(mds) + a.Error(err) + a.NoError(conn.ExpectationsWereMet()) + }) + + t.Run("failed insert", func(*testing.T) { + conn.ExpectBegin() + conn.ExpectExec(`truncate pgwatch3\.source`).WillReturnResult(pgxmock.NewResult("TRUNCATE", 1)) + conn.ExpectExec(`insert into pgwatch3\.source`).WithArgs( + md.DBUniqueName, md.Group, md.Kind, + md.ConnStr, md.Metrics, md.MetricsStandby, md.PresetMetrics, md.PresetMetricsStandby, + md.IsSuperuser, md.IncludePattern, md.ExcludePattern, md.CustomTags, + md.HostConfig, md.OnlyIfMaster, + ).WillReturnError(errors.New("failed insert")) + conn.ExpectRollback() + + err = pgrw.WriteMonitoredDatabases(mds) + a.Error(err) + a.NoError(conn.ExpectationsWereMet()) + }) +} diff --git a/src/sources/patroni.go b/src/sources/resolver.go similarity index 80% rename from src/sources/patroni.go rename to src/sources/resolver.go index e8ce7a959f..1ae1605313 100644 --- a/src/sources/patroni.go +++ b/src/sources/resolver.go @@ -1,5 +1,9 @@ package sources +// This file contains the implemendation of Patroni and PostgrSQL resolvers for continuous monitoring. +// Patroni resolver will return the list of databases from the Patroni cluster. +// Postgres resolver will return the list of databases from the given Postgres instance. + import ( "context" "crypto/tls" @@ -21,6 +25,37 @@ import ( client "go.etcd.io/etcd/client/v3" ) +// ResolveDatabases() updates list of monitored objects from continuous monitoring sources, e.g. patroni +func (mds MonitoredDatabases) ResolveDatabases() (MonitoredDatabases, error) { + resolvedDbs := make(MonitoredDatabases, 0, len(mds)) + for _, md := range mds { + if !md.IsEnabled { + continue + } + dbs, err := md.ResolveDatabases() + if err != nil { + return nil, err + } + if len(dbs) == 0 { + resolvedDbs = append(resolvedDbs, md) + continue + } + resolvedDbs = append(resolvedDbs, dbs...) + } + return resolvedDbs, nil +} + +// ResolveDatabases() return a slice of found databases for continuous monitoring sources, e.g. patroni +func (md *MonitoredDatabase) ResolveDatabases() (MonitoredDatabases, error) { + switch md.Kind { + case SourcePatroni, SourcePatroniContinuous, SourcePatroniNamespace: + return ResolveDatabasesFromPatroni(md) + case SourcePostgresContinuous: + return ResolveDatabasesFromPostgres(md) + } + return nil, nil +} + type PatroniClusterMember struct { Scope string Name string @@ -321,3 +356,44 @@ func ResolveDatabasesFromPatroni(ce *MonitoredDatabase) ([]*MonitoredDatabase, e return md, err } + +// "resolving" reads all the DB names from the given host/port, additionally matching/not matching specified regex patterns +func ResolveDatabasesFromPostgres(md *MonitoredDatabase) (resolvedDbs MonitoredDatabases, err error) { + var ( + c db.PgxPoolIface + dbname string + rows pgx.Rows + ) + c, err = db.New(context.TODO(), md.ConnStr) + if err != nil { + return + } + defer c.Close() + + sql := `select /* pgwatch3_generated */ + quote_ident(datname)::text as datname_escaped + from pg_database + where not datistemplate + and datallowconn + and has_database_privilege (datname, 'CONNECT') + and case when length(trim($1)) > 0 then datname ~ $1 else true end + and case when length(trim($2)) > 0 then not datname ~ $2 else true end` + + if rows, err = c.Query(context.TODO(), sql, md.IncludePattern, md.ExcludePattern); err != nil { + return nil, err + } + for rows.Next() { + if err = rows.Scan(&dbname); err != nil { + return nil, err + } + rdb := md.Clone() + rdb.DBUniqueName += "_" + dbname + rdb.SetDatabaseName(dbname) + resolvedDbs = append(resolvedDbs, rdb) + } + + if err := rows.Err(); err != nil { + return nil, err + } + return +} diff --git a/src/sources/resolver_test.go b/src/sources/resolver_test.go new file mode 100644 index 0000000000..85cea7375f --- /dev/null +++ b/src/sources/resolver_test.go @@ -0,0 +1,45 @@ +package sources_test + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/cybertec-postgresql/pgwatch3/sources" + testcontainers "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/modules/postgres" + "github.com/testcontainers/testcontainers-go/wait" +) + +func TestMonitoredDatabase_ResolveDatabasesFromPostgres(t *testing.T) { + pgContainer, err := postgres.RunContainer(ctx, + testcontainers.WithImage("docker.io/postgres:16-alpine"), + postgres.WithDatabase("mydatabase"), + testcontainers.WithWaitStrategy( + wait.ForLog("database system is ready to accept connections"). + WithOccurrence(2). + WithStartupTimeout(5*time.Second)), + ) + assert.NoError(t, err) + defer func() { assert.NoError(t, pgContainer.Terminate(ctx)) }() + + // Create a new MonitoredDatabase instance + md := &sources.MonitoredDatabase{DBUniqueName: "continuous", Kind: sources.SourcePostgresContinuous} + md.ConnStr, err = pgContainer.ConnectionString(ctx) + assert.NoError(t, err) + + // Call the ResolveDatabasesFromPostgres method + dbs, err := sources.ResolveDatabasesFromPostgres(md) + assert.NoError(t, err) + assert.True(t, len(dbs) == 2) //postgres and mydatabase + + // check the "continuous_mydatabase" + db := dbs.GetMonitoredDatabase(md.DBUniqueName + "_mydatabase") + assert.NotNil(t, db) + assert.Equal(t, "mydatabase", db.GetDatabaseName()) + + //check unexpected database + db = dbs.GetMonitoredDatabase(md.DBUniqueName + "_unexpected") + assert.Nil(t, db) +} diff --git a/src/sources/types.go b/src/sources/types.go index 845b2ced73..a29d97a9e9 100644 --- a/src/sources/types.go +++ b/src/sources/types.go @@ -104,83 +104,11 @@ func (md *MonitoredDatabase) IsPostgresSource() bool { return md.Kind != SourcePgBouncer && md.Kind != SourcePgPool } -// ExpandDatabases() return a slice of found databases for continuous monitoring sources, e.g. patroni -func (md *MonitoredDatabase) ExpandDatabases() (MonitoredDatabases, error) { - switch md.Kind { - case SourcePatroni, SourcePatroniContinuous, SourcePatroniNamespace: - return ResolveDatabasesFromPatroni(md) - case SourcePostgresContinuous: - return md.ResolveDatabasesFromPostgres() - } - return nil, nil -} - -// "resolving" reads all the DB names from the given host/port, additionally matching/not matching specified regex patterns -func (md *MonitoredDatabase) ResolveDatabasesFromPostgres() (resolvedDbs MonitoredDatabases, err error) { - var ( - c db.PgxPoolIface - dbname string - rows pgx.Rows - ) - c, err = db.New(context.TODO(), md.ConnStr) - if err != nil { - return - } - defer c.Close() - - sql := `select /* pgwatch3_generated */ - quote_ident(datname)::text as datname_escaped - from pg_database - where not datistemplate - and datallowconn - and has_database_privilege (datname, 'CONNECT') - and case when length(trim($1)) > 0 then datname ~ $1 else true end - and case when length(trim($2)) > 0 then not datname ~ $2 else true end` - - if rows, err = c.Query(context.TODO(), sql, md.IncludePattern, md.ExcludePattern); err != nil { - return nil, err - } - for rows.Next() { - if err = rows.Scan(&dbname); err != nil { - return nil, err - } - rdb := md.Clone() - rdb.DBUniqueName += "_" + dbname - rdb.SetDatabaseName(dbname) - resolvedDbs = append(resolvedDbs, rdb) - } - - if err := rows.Err(); err != nil { - return nil, err - } - return -} - type MonitoredDatabases []*MonitoredDatabase -// Expand() updates list of monitored objects from continuous monitoring sources, e.g. patroni -func (mds MonitoredDatabases) Expand() (MonitoredDatabases, error) { - resolvedDbs := make(MonitoredDatabases, 0, len(mds)) - for _, md := range mds { - if !md.IsEnabled { - continue - } - dbs, err := md.ExpandDatabases() - if err != nil { - return nil, err - } - if len(dbs) == 0 { - resolvedDbs = append(resolvedDbs, md) - continue - } - resolvedDbs = append(resolvedDbs, dbs...) - } - return resolvedDbs, nil -} - -func (mds MonitoredDatabases) GetDatabase(name string) *MonitoredDatabase { +func (mds MonitoredDatabases) GetMonitoredDatabase(DBUniqueName string) *MonitoredDatabase { for _, md := range mds { - if md.DBUniqueName == name { + if md.DBUniqueName == DBUniqueName { return md } } @@ -192,11 +120,11 @@ func (mds MonitoredDatabases) SyncFromReader(r Reader) (MonitoredDatabases, erro if err != nil { return nil, err } - if newMDs, err = newMDs.Expand(); err != nil { + if newMDs, err = newMDs.ResolveDatabases(); err != nil { return nil, err } for _, newMD := range newMDs { - if md := mds.GetDatabase(newMD.DBUniqueName); md != nil { + if md := mds.GetMonitoredDatabase(newMD.DBUniqueName); md != nil { newMD.Conn = md.Conn } } diff --git a/src/sources/types_test.go b/src/sources/types_test.go index 3f2193693d..f350efccd4 100644 --- a/src/sources/types_test.go +++ b/src/sources/types_test.go @@ -88,26 +88,3 @@ func TestMonitoredDatabase_IsPostgresSource(t *testing.T) { md.Kind = sources.SourcePatroni assert.True(t, md.IsPostgresSource(), "IsPostgresSource() = false, want true") } - -func TestMonitoredDatabase_ResolveDatabasesFromPostgres(t *testing.T) { - pgContainer, err := postgres.RunContainer(ctx, - testcontainers.WithImage("docker.io/postgres:16-alpine"), - postgres.WithDatabase("mydatabase"), - testcontainers.WithWaitStrategy( - wait.ForLog("database system is ready to accept connections"). - WithOccurrence(2). - WithStartupTimeout(5*time.Second)), - ) - assert.NoError(t, err) - defer func() { assert.NoError(t, pgContainer.Terminate(ctx)) }() - - // Create a new MonitoredDatabase instance - md := &sources.MonitoredDatabase{} - md.ConnStr, err = pgContainer.ConnectionString(ctx) - assert.NoError(t, err) - - // Call the ResolveDatabasesFromPostgres method - dbs, err := md.ResolveDatabasesFromPostgres() - assert.NoError(t, err) - assert.True(t, len(dbs) == 2) //postgres and mydatabase -} diff --git a/src/sources/yaml.go b/src/sources/yaml.go index a207945b39..ed17c4c741 100644 --- a/src/sources/yaml.go +++ b/src/sources/yaml.go @@ -1,5 +1,7 @@ package sources +// This file contains the implementation of the ReaderWriter interface for the YAML file. + import ( "context" "io/fs"