diff --git a/cmd/trade.go b/cmd/trade.go index f8da3e33e..8db4d7511 100644 --- a/cmd/trade.go +++ b/cmd/trade.go @@ -539,7 +539,7 @@ func runTradeCmd(options inputs) { } var e error - db, e = database.ConnectInitializedDatabase(botConfig.PostgresDbConfig, upgradeScripts) + db, e = database.ConnectInitializedDatabase(botConfig.PostgresDbConfig, upgradeScripts, version) if e != nil { logger.Fatal(l, fmt.Errorf("problem encountered while initializing the db: %s", e)) } diff --git a/support/database/schema.go b/support/database/schema.go index 479608149..679bfe2df 100644 --- a/support/database/schema.go +++ b/support/database/schema.go @@ -10,6 +10,7 @@ import ( tables */ const SqlDbVersionTableCreate = "CREATE TABLE IF NOT EXISTS db_version (version INTEGER NOT NULL, date_completed_utc TIMESTAMP WITHOUT TIME ZONE NOT NULL, num_scripts INTEGER NOT NULL, time_elapsed_millis BIGINT NOT NULL, PRIMARY KEY (version))" +const SqlDbVersionTableAlter1 = "ALTER TABLE db_version ADD COLUMN code_version_string TEXT" /* queries diff --git a/support/database/upgrade.go b/support/database/upgrade.go index 1ccc28824..74c25359d 100644 --- a/support/database/upgrade.go +++ b/support/database/upgrade.go @@ -12,7 +12,8 @@ import ( ) // sqlDbVersionTableInsertTemplate inserts into the db_version table -const sqlDbVersionTableInsertTemplate = "INSERT INTO db_version (version, date_completed_utc, num_scripts, time_elapsed_millis) VALUES (%d, '%s', %d, %d)" +const sqlDbVersionTableInsertTemplate1 = "INSERT INTO db_version (version, date_completed_utc, num_scripts, time_elapsed_millis) VALUES (%d, '%s', %d, %d)" +const sqlDbVersionTableInsertTemplate2 = "INSERT INTO db_version (version, date_completed_utc, num_scripts, time_elapsed_millis, code_version_string) VALUES (%d, '%s', %d, %d, '%s')" // UpgradeScript encapsulates a script to be run to upgrade the database from one version to the next type UpgradeScript struct { @@ -31,8 +32,13 @@ func MakeUpgradeScript(version uint32, command string, moreCommands ...string) * } } +var UpgradeScripts = []*UpgradeScript{ + MakeUpgradeScript(1, SqlDbVersionTableCreate), + MakeUpgradeScript(2, SqlDbVersionTableAlter1), +} + // ConnectInitializedDatabase creates a database with the required metadata tables -func ConnectInitializedDatabase(postgresDbConfig *postgresdb.Config, upgradeScripts []*UpgradeScript) (*sql.DB, error) { +func ConnectInitializedDatabase(postgresDbConfig *postgresdb.Config, upgradeScripts []*UpgradeScript, codeVersionString string) (*sql.DB, error) { dbCreated, e := postgresdb.CreateDatabaseIfNotExists(postgresDbConfig) if e != nil { if strings.Contains(e.Error(), "connect: connection refused") { @@ -53,7 +59,7 @@ func ConnectInitializedDatabase(postgresDbConfig *postgresdb.Config, upgradeScri // don't defer db.Close() here becuase we want it open for the life of the application for now log.Printf("creating db schema and running upgrade scripts ...\n") - e = runUpgradeScripts(db, upgradeScripts) + e = runUpgradeScripts(db, upgradeScripts, codeVersionString) if e != nil { return nil, fmt.Errorf("could not run upgrade scripts: %s", e) } @@ -62,16 +68,20 @@ func ConnectInitializedDatabase(postgresDbConfig *postgresdb.Config, upgradeScri return db, nil } -func runUpgradeScripts(db *sql.DB, scripts []*UpgradeScript) error { - currentDbVersion, e := QueryDbVersion(db) - if e != nil { - if !strings.Contains(e.Error(), "relation \"db_version\" does not exist") { - return fmt.Errorf("could not fetch current db version: %s", e) - } - currentDbVersion = 0 - } +func runUpgradeScripts(db *sql.DB, scripts []*UpgradeScript, codeVersionString string) error { + // save feature flags for the db_version table here + hasCodeVersionString := false for _, script := range scripts { + // fetch the db version inside the for loop because it constantly gets updated + currentDbVersion, e := QueryDbVersion(db) + if e != nil { + if !strings.Contains(e.Error(), "relation \"db_version\" does not exist") { + return fmt.Errorf("could not fetch current db version: %s", e) + } + currentDbVersion = 0 + } + if script.version <= currentDbVersion { log.Printf(" skipping upgrade script for version %d because current db version (%d) is equal or ahead\n", script.version, currentDbVersion) continue @@ -95,17 +105,34 @@ func runUpgradeScripts(db *sql.DB, scripts []*UpgradeScript) error { endTimeMillis := time.Now().UnixNano() / int64(time.Millisecond) elapsedMillis := endTimeMillis - startTimeMillis + // update feature flags here where required after running a script so we don't need to hard-code version numbers which can be different for different consumers of this API + for _, command := range script.commands { + if command == SqlDbVersionTableAlter1 { + // if we have run this alter table command it means the database version has the code_version_string feature + hasCodeVersionString = true + } + } + // add entry to db_version table - sqlInsertDbVersion := fmt.Sprintf(sqlDbVersionTableInsertTemplate, + sqlInsertDbVersion := fmt.Sprintf(sqlDbVersionTableInsertTemplate1, script.version, startTime.Format(postgresdb.TimestampFormatString), len(script.commands), elapsedMillis, ) + if hasCodeVersionString { + sqlInsertDbVersion = fmt.Sprintf(sqlDbVersionTableInsertTemplate2, + script.version, + startTime.Format(postgresdb.TimestampFormatString), + len(script.commands), + elapsedMillis, + codeVersionString, + ) + } _, e = db.Exec(sqlInsertDbVersion) if e != nil { // duplicate insert should return an error - return fmt.Errorf("could not execute sql insert values statement in db_version table for db version %d (%s): %s", script.version, sqlInsertDbVersion, e) + return fmt.Errorf("could not add an entry to the db_version table for upgrade script (db_version=%d) for current db version %d (%s): %s", script.version, currentDbVersion, sqlInsertDbVersion, e) } // commit transaction diff --git a/support/database/upgrade_test.go b/support/database/upgrade_test.go new file mode 100644 index 000000000..0c58ce6de --- /dev/null +++ b/support/database/upgrade_test.go @@ -0,0 +1,293 @@ +package database + +import ( + "database/sql" + "fmt" + "os" + "reflect" + "strconv" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/stellar/kelp/support/postgresdb" +) + +func preTest(t *testing.T) (*sql.DB, string) { + // use ToLower because the database name cannot be uppercase in postgres + dbname := fmt.Sprintf("test_database_%s_%d", strings.ToLower(t.Name()), time.Now().UnixNano()) + postgresDbConfig := &postgresdb.Config{ + Host: "localhost", + Port: 5432, + DbName: dbname, + User: os.Getenv("POSTGRES_USER"), + SSLEnable: false, + } + + // create empty database + db, e := ConnectInitializedDatabase(postgresDbConfig, []*UpgradeScript{}, "") + if e != nil { + panic(e) + } + + return db, dbname +} + +// execWithNewManagedConnection creates a new connection which is "managed" (i.e. closing is handled) so the passed in function can just use it to do their business +func execWithNewManagedConnection(fn func(db *sql.DB)) { + postgresDbConfig := &postgresdb.Config{ + Host: "localhost", + Port: 5432, + DbName: "postgres", + User: os.Getenv("POSTGRES_USER"), + SSLEnable: false, + } + // connect to the db + db, e := sql.Open("postgres", postgresDbConfig.MakeConnectStringWithoutDB()) + if e != nil { + panic(e) + } + // defer closing this new connection + defer db.Close() + + // delegate to passed in function + fn(db) +} + +func dropDatabaseWithNewConnection(dbname string) { + execWithNewManagedConnection(func(db *sql.DB) { + _, e := db.Exec(fmt.Sprintf("DROP DATABASE %s", dbname)) + if e != nil { + panic(e) + } + }) +} + +func postTestWithDbClose(db *sql.DB, dbname string) { + // defer statements are executed in LIFO order + + // second delete the database (internally creates a new db connection and then deletes it) + defer dropDatabaseWithNewConnection(dbname) + + // first close the existing db connection + defer db.Close() +} + +func checkDatabaseExistsWithNewConnection(dbname string) bool { + hasDatabase := false + execWithNewManagedConnection(func(db *sql.DB) { + rows, e := db.Query(fmt.Sprintf("SELECT datname FROM pg_database WHERE datname = '%s'", dbname)) + if e != nil { + panic(e) + } + + hasDatabase = rows.Next() + }) + return hasDatabase +} + +func getNumTablesInDb(db *sql.DB) int { + // run the query -- note that we need to be connected to the database of interest + tablesQueryResult, e := db.Query("select COUNT(*) from pg_stat_user_tables") + if e != nil { + panic(e) + } + defer tablesQueryResult.Close() // remembering to defer closing the query + + tablesQueryResult.Next() // remembering to call Next() before Scan() + var count int + e = tablesQueryResult.Scan(&count) + if e != nil { + panic(e) + } + + return count +} + +func checkTableExists(db *sql.DB, tableName string) bool { + tablesQueryResult, e := db.Query(fmt.Sprintf("select tablename from pg_catalog.pg_tables where tablename = '%s'", tableName)) + if e != nil { + panic(e) + } + defer tablesQueryResult.Close() // remembering to defer closing the query + + return tablesQueryResult.Next() +} + +type tableColumn struct { + columnName string // `db:"column_name"` + ordinalPosition int // `db:"ordinal_position"` + columnDefault interface{} // `db:"column_default"` + isNullable string // `db:"is_nullable"` // uses "YES" / "NO" instead of a boolean + dataType string // `db:"data_type"` + characterMaximumLength interface{} // `db:"character_maximum_length"` +} + +func assertTableColumnsEqual(t *testing.T, want *tableColumn, actual *tableColumn) { + assert.Equal(t, want.columnName, actual.columnName) + assert.Equal(t, want.ordinalPosition, actual.ordinalPosition) + assert.Equal(t, want.columnDefault, actual.columnDefault) + assert.Equal(t, want.isNullable, actual.isNullable) + assert.Equal(t, want.dataType, actual.dataType) + assert.Equal(t, want.characterMaximumLength, actual.characterMaximumLength) +} + +func getTableSchema(db *sql.DB, tableName string) []tableColumn { + schemaQueryResult, e := db.Query(fmt.Sprintf("SELECT column_name, ordinal_position, column_default, is_nullable, data_type, character_maximum_length FROM information_schema.columns WHERE table_schema = 'public' AND table_name = '%s'", tableName)) + if e != nil { + panic(e) + } + defer schemaQueryResult.Close() // remembering to defer closing the query + + items := []tableColumn{} + for schemaQueryResult.Next() { // remembering to call Next() before Scan() + var item tableColumn + e = schemaQueryResult.Scan(&item.columnName, &item.ordinalPosition, &item.columnDefault, &item.isNullable, &item.dataType, &item.characterMaximumLength) + if e != nil { + panic(e) + } + + items = append(items, item) + } + + return items +} + +func queryAllRows(db *sql.DB, tableName string) [][]interface{} { + queryResult, e := db.Query(fmt.Sprintf("SELECT * FROM %s", tableName)) + if e != nil { + panic(e) + } + defer queryResult.Close() // remembering to defer closing the query + + allRows := [][]interface{}{} + for queryResult.Next() { // remembering to call Next() before Scan() + // we want to generically query the table + colTypes, e := queryResult.ColumnTypes() + if e != nil { + panic(e) + } + columnValues := []interface{}{} + for i := 0; i < len(colTypes); i++ { + columnValues = append(columnValues, new(interface{})) + } + + e = queryResult.Scan(columnValues...) + if e != nil { + panic(e) + } + + allRows = append(allRows, columnValues) + } + + return allRows +} + +func TestCurrentClassTestInfra(t *testing.T) { + // run the preTest + db, dbname := preTest(t) + // assert state of preTest + assert.NotNil(t, db) + assert.Equal(t, strings.ToLower(dbname), dbname) + assert.True(t, checkDatabaseExistsWithNewConnection(dbname)) + assert.Equal(t, 0, getNumTablesInDb(db)) + + // run the postTest + postTestWithDbClose(db, dbname) + // assert state after the postTest + assert.False(t, checkDatabaseExistsWithNewConnection(dbname)) +} + +func TestUpgradeScripts(t *testing.T) { + // run the preTest and defer running the postTest + db, dbname := preTest(t) + defer postTestWithDbClose(db, dbname) + + // run the upgrade scripts + codeVersionString := "someCodeVersion" + runUpgradeScripts(db, UpgradeScripts, codeVersionString) + + // assert current state of the database + assert.Equal(t, 1, getNumTablesInDb(db)) + assert.True(t, checkTableExists(db, "db_version")) + + // check schema of db_version table + columns := getTableSchema(db, "db_version") + assert.Equal(t, 5, len(columns), fmt.Sprintf("%v", columns)) + assertTableColumnsEqual(t, &tableColumn{ + columnName: "version", + ordinalPosition: 1, + columnDefault: nil, + isNullable: "NO", + dataType: "integer", + characterMaximumLength: nil, + }, &columns[0]) + assertTableColumnsEqual(t, &tableColumn{ + columnName: "date_completed_utc", + ordinalPosition: 2, + columnDefault: nil, + isNullable: "NO", + dataType: "timestamp without time zone", + characterMaximumLength: nil, + }, &columns[1]) + assertTableColumnsEqual(t, &tableColumn{ + columnName: "num_scripts", + ordinalPosition: 3, + columnDefault: nil, + isNullable: "NO", + dataType: "integer", + characterMaximumLength: nil, + }, &columns[2]) + assertTableColumnsEqual(t, &tableColumn{ + columnName: "time_elapsed_millis", + ordinalPosition: 4, + columnDefault: nil, + isNullable: "NO", + dataType: "bigint", + characterMaximumLength: nil, + }, &columns[3]) + assertTableColumnsEqual(t, &tableColumn{ + columnName: "code_version_string", + ordinalPosition: 5, + columnDefault: nil, + isNullable: "YES", + dataType: "text", + characterMaximumLength: nil, + }, &columns[4]) + + // check entries of db_version table + allRows := queryAllRows(db, "db_version") + assert.Equal(t, 2, len(allRows)) + // first code_version_string is nil becuase the field was not supported at the time when the upgrade script was run, and only in version 2 of + // the database do we add the field. See UpgradeScripts and runUpgradeScripts() for more details + validateDBVersionRow(t, allRows[0], 1, time.Now(), 1, 10, nil) + validateDBVersionRow(t, allRows[1], 2, time.Now(), 1, 10, &codeVersionString) +} + +func validateDBVersionRow( + t *testing.T, + actualRow []interface{}, + wantVersion int, + wantDateCompletedUTC time.Time, + wantNumScripts int, + wantTimeElapsedMillis int, + wantCodeVersionString *string, +) { + // first check length + if assert.Equal(t, 5, len(actualRow)) { + assert.Equal(t, fmt.Sprintf("%d", wantVersion), fmt.Sprintf("%v", reflect.ValueOf(actualRow[0]).Elem())) + assert.Equal(t, wantDateCompletedUTC.Format("20060102"), reflect.ValueOf(actualRow[1]).Elem().Interface().(time.Time).Format("20060102")) + assert.Equal(t, fmt.Sprintf("%v", wantNumScripts), fmt.Sprintf("%v", reflect.ValueOf(actualRow[2]).Elem())) + elapsed, e := strconv.Atoi(fmt.Sprintf("%v", reflect.ValueOf(actualRow[3]).Elem())) + if assert.Nil(t, e) { + assert.LessOrEqual(t, elapsed, wantTimeElapsedMillis) + } + if wantCodeVersionString == nil { + assert.Equal(t, "", fmt.Sprintf("%v", reflect.ValueOf(actualRow[4]).Elem())) + } else { + assert.Equal(t, *wantCodeVersionString, fmt.Sprintf("%v", reflect.ValueOf(actualRow[4]).Elem())) + } + } +}