diff --git a/components/home.go b/components/home.go index 5c6c1ce..948ef01 100644 --- a/components/home.go +++ b/components/home.go @@ -236,12 +236,11 @@ func (home *Home) rightWrapperInputCapture(event *tcell.EventKey) *tcell.EventKe if tab != nil { table := tab.Content - if ((table.Menu != nil && table.Menu.GetSelectedOption() == 1) || table.Menu == nil) && !table.Pagination.GetIsFirstPage() && !table.GetIsLoading() { + if ((table.Menu != nil && table.Menu.GetSelectedOption() == 1) || + table.Menu == nil) && !table.Pagination.GetIsFirstPage() && !table.GetIsLoading() { table.Pagination.SetOffset(table.Pagination.GetOffset() - table.Pagination.GetLimit()) table.FetchRecords(nil) - } - } case commands.PageNext: @@ -250,7 +249,8 @@ func (home *Home) rightWrapperInputCapture(event *tcell.EventKey) *tcell.EventKe if tab != nil { table := tab.Content - if ((table.Menu != nil && table.Menu.GetSelectedOption() == 1) || table.Menu == nil) && !table.Pagination.GetIsLastPage() && !table.GetIsLoading() { + if ((table.Menu != nil && table.Menu.GetSelectedOption() == 1) || + table.Menu == nil) && !table.Pagination.GetIsLastPage() && !table.GetIsLoading() { table.Pagination.SetOffset(table.Pagination.GetOffset() + table.Pagination.GetLimit()) table.FetchRecords(nil) } diff --git a/components/pagination.go b/components/pagination.go index 78bb8cf..31d1228 100644 --- a/components/pagination.go +++ b/components/pagination.go @@ -40,7 +40,6 @@ func NewPagination() *Pagination { TotalRecords: 0, }, } - } func (pagination *Pagination) GetOffset() int { @@ -67,13 +66,16 @@ func (pagination *Pagination) SetTotalRecords(total int) { pagination.state.TotalRecords = total offset := pagination.GetOffset() - limit := pagination.GetLimit() + offset + if offset < total { + offset++ + } + limit := pagination.GetLimit() + offset if limit > total { limit = total } - pagination.textView.SetText(fmt.Sprintf("%d-%d of %d rows", offset+1, limit, total)) + pagination.textView.SetText(fmt.Sprintf("%d-%d of %d rows", offset, limit, total)) } func (pagination *Pagination) SetLimit(limit int) { diff --git a/drivers/postgres.go b/drivers/postgres.go index d4a28a3..29b9fb7 100644 --- a/drivers/postgres.go +++ b/drivers/postgres.go @@ -30,13 +30,14 @@ func (db *Postgres) TestConnection(urlstr string) error { return db.Connect(urlstr) } -func (db *Postgres) Connect(urlstr string) (err error) { +func (db *Postgres) Connect(urlstr string) error { db.SetProvider(DriverPostgres) - db.Connection, err = dburl.Open(urlstr) + connection, err := dburl.Open(urlstr) if err != nil { return err } + db.Connection = connection err = db.Connection.Ping() if err != nil { @@ -45,30 +46,29 @@ func (db *Postgres) Connect(urlstr string) (err error) { db.Urlstr = urlstr - // get current database - + // Get the current database. rows := db.Connection.QueryRow("SELECT current_database();") database := "" - err = rows.Scan(&database) - - db.CurrentDatabase = database - db.PreviousDatabase = database if err != nil { return err } + db.CurrentDatabase = database + db.PreviousDatabase = database + return nil } -func (db *Postgres) GetDatabases() (databases []string, err error) { +func (db *Postgres) GetDatabases() ([]string, error) { rows, err := db.Connection.Query("SELECT datname FROM pg_database;") if err != nil { return nil, err } defer rows.Close() + var databases []string for rows.Next() { var database string err := rows.Scan(&database) @@ -84,9 +84,7 @@ func (db *Postgres) GetDatabases() (databases []string, err error) { return databases, nil } -func (db *Postgres) GetTables(database string) (tables map[string][]string, err error) { - tables = make(map[string][]string) - +func (db *Postgres) GetTables(database string) (map[string][]string, error) { logger.Info("GetTables", map[string]any{"database": database}) if database == "" { @@ -94,7 +92,7 @@ func (db *Postgres) GetTables(database string) (tables map[string][]string, err } if database != db.CurrentDatabase { - err = db.SwitchDatabase(database) + err := db.SwitchDatabase(database) if err != nil { return nil, err } @@ -113,6 +111,7 @@ func (db *Postgres) GetTables(database string) (tables map[string][]string, err } defer rows.Close() + tables := make(map[string][]string) for rows.Next() { var ( tableName string @@ -131,11 +130,10 @@ func (db *Postgres) GetTables(database string) (tables map[string][]string, err return tables, nil } -func (db *Postgres) GetTableColumns(database, table string) (results [][]string, err error) { +func (db *Postgres) GetTableColumns(database, table string) ([][]string, error) { if database == "" { return nil, errors.New("database name is required") } - if table == "" { return nil, errors.New("table name is required") } @@ -147,7 +145,7 @@ func (db *Postgres) GetTableColumns(database, table string) (results [][]string, } if database != db.CurrentDatabase { - err = db.SwitchDatabase(database) + err := db.SwitchDatabase(database) if err != nil { return nil, err } @@ -170,13 +168,12 @@ func (db *Postgres) GetTableColumns(database, table string) (results [][]string, } defer rows.Close() - columns, columnsError := rows.Columns() - if columnsError != nil { - err = columnsError + columns, err := rows.Columns() + if err != nil { + return nil, err } - results = append(results, columns) - + results := [][]string{columns} for rows.Next() { rowValues := make([]interface{}, len(columns)) @@ -200,26 +197,24 @@ func (db *Postgres) GetTableColumns(database, table string) (results [][]string, return nil, err } - return + return results, nil } -func (db *Postgres) GetConstraints(database, table string) (constraints [][]string, err error) { +func (db *Postgres) GetConstraints(database, table string) ([][]string, error) { if database == "" { return nil, errors.New("database name is required") } - if table == "" { return nil, errors.New("table name is required") } splitTableString := strings.Split(table, ".") - if len(splitTableString) == 1 { return nil, errors.New("table must be in the format schema.table") } if database != db.CurrentDatabase { - err = db.SwitchDatabase(database) + err := db.SwitchDatabase(database) if err != nil { return nil, err } @@ -235,7 +230,7 @@ func (db *Postgres) GetConstraints(database, table string) (constraints [][]stri tableName := splitTableString[1] rows, err := db.Connection.Query(fmt.Sprintf(` - SELECT + SELECT tc.constraint_name, kcu.column_name, tc.constraint_type @@ -255,13 +250,12 @@ func (db *Postgres) GetConstraints(database, table string) (constraints [][]stri } defer rows.Close() - columns, columnsError := rows.Columns() - if columnsError != nil { - err = columnsError + columns, err := rows.Columns() + if err != nil { + return nil, err } - constraints = append(constraints, columns) - + constraints := [][]string{columns} for rows.Next() { rowValues := make([]interface{}, len(columns)) for i := range columns { @@ -283,26 +277,24 @@ func (db *Postgres) GetConstraints(database, table string) (constraints [][]stri return nil, err } - return + return constraints, nil } -func (db *Postgres) GetForeignKeys(database, table string) (foreignKeys [][]string, err error) { +func (db *Postgres) GetForeignKeys(database, table string) ([][]string, error) { if database == "" { return nil, errors.New("database name is required") } - if table == "" { return nil, errors.New("table name is required") } splitTableString := strings.Split(table, ".") - if len(splitTableString) == 1 { return nil, errors.New("table must be in the format schema.table") } if database != db.CurrentDatabase { - err = db.SwitchDatabase(database) + err := db.SwitchDatabase(database) if err != nil { return nil, err } @@ -318,7 +310,7 @@ func (db *Postgres) GetForeignKeys(database, table string) (foreignKeys [][]stri tableName := splitTableString[1] rows, err := db.Connection.Query(fmt.Sprintf(` - SELECT + SELECT tc.constraint_name, kcu.column_name, ccu.table_name AS foreign_table_name, @@ -339,13 +331,12 @@ func (db *Postgres) GetForeignKeys(database, table string) (foreignKeys [][]stri } defer rows.Close() - columns, columnsError := rows.Columns() - if columnsError != nil { - err = columnsError + columns, err := rows.Columns() + if err != nil { + return nil, err } - foreignKeys = append(foreignKeys, columns) - + foreignKeys := [][]string{columns} for rows.Next() { rowValues := make([]interface{}, len(columns)) for i := range columns { @@ -367,26 +358,24 @@ func (db *Postgres) GetForeignKeys(database, table string) (foreignKeys [][]stri return nil, err } - return + return foreignKeys, nil } -func (db *Postgres) GetIndexes(database, table string) (indexes [][]string, err error) { +func (db *Postgres) GetIndexes(database, table string) ([][]string, error) { if database == "" { return nil, errors.New("database name is required") } - if table == "" { return nil, errors.New("table name is required") } splitTableString := strings.Split(table, ".") - if len(splitTableString) == 1 { return nil, errors.New("table must be in the format schema.table") } if database != db.CurrentDatabase { - err = db.SwitchDatabase(database) + err := db.SwitchDatabase(database) if err != nil { return nil, err } @@ -402,7 +391,7 @@ func (db *Postgres) GetIndexes(database, table string) (indexes [][]string, err tableName := splitTableString[1] rows, err := db.Connection.Query(fmt.Sprintf(` - SELECT + SELECT i.relname AS index_name, a.attname AS column_name, am.amname AS type @@ -432,13 +421,12 @@ func (db *Postgres) GetIndexes(database, table string) (indexes [][]string, err } defer rows.Close() - columns, columnsError := rows.Columns() - if columnsError != nil { - err = columnsError + columns, err := rows.Columns() + if err != nil { + return nil, err } - indexes = append(indexes, columns) - + indexes := [][]string{columns} for rows.Next() { rowValues := make([]interface{}, len(columns)) for i := range columns { @@ -460,26 +448,24 @@ func (db *Postgres) GetIndexes(database, table string) (indexes [][]string, err return nil, err } - return + return indexes, nil } -func (db *Postgres) GetRecords(database, table, where, sort string, offset, limit int) (records [][]string, totalRecords int, err error) { +func (db *Postgres) GetRecords(database, table, where, sort string, offset, limit int) ([][]string, int, error) { if database == "" { return nil, 0, errors.New("database name is required") } - if table == "" { return nil, 0, errors.New("table name is required") } splitTableString := strings.Split(table, ".") - if len(splitTableString) == 1 { return nil, 0, errors.New("table must be in the format schema.table") } if database != db.CurrentDatabase { - err = db.SwitchDatabase(database) + err := db.SwitchDatabase(database) if err != nil { return nil, 0, err } @@ -496,10 +482,6 @@ func (db *Postgres) GetRecords(database, table, where, sort string, offset, limi formattedTableName := db.formatTableName(tableSchema, tableName) - if limit == 0 { - limit = DefaultRowLimit - } - query := "SELECT * FROM " query += formattedTableName @@ -513,6 +495,10 @@ func (db *Postgres) GetRecords(database, table, where, sort string, offset, limi query += " LIMIT $1 OFFSET $2" + if limit == 0 { + limit = DefaultRowLimit + } + paginatedRows, err := db.Connection.Query(query, limit, offset) if err != nil { return nil, 0, err @@ -524,8 +510,7 @@ func (db *Postgres) GetRecords(database, table, where, sort string, offset, limi return nil, 0, columnsError } - records = append(records, columns) - + records := [][]string{columns} for paginatedRows.Next() { nullStringSlice := make([]sql.NullString, len(columns)) @@ -552,7 +537,6 @@ func (db *Postgres) GetRecords(database, table, where, sort string, offset, limi } records = append(records, row) - } if err := paginatedRows.Err(); err != nil { @@ -565,49 +549,49 @@ func (db *Postgres) GetRecords(database, table, where, sort string, offset, limi countQuery := "SELECT COUNT(*) FROM " countQuery += formattedTableName + + if where != "" { + countQuery += fmt.Sprintf(" %s", where) + } + row := db.Connection.QueryRow(countQuery) + + var totalRecords int if err := row.Scan(&totalRecords); err != nil { return nil, 0, err } - return + return records, totalRecords, nil } -func (db *Postgres) UpdateRecord(database, table, column, value, primaryKeyColumnName, primaryKeyValue string) (err error) { +func (db *Postgres) UpdateRecord(database, table, column, value, primaryKeyColumnName, primaryKeyValue string) error { if database == "" { return errors.New("database name is required") } - if table == "" { return errors.New("table name is required") } - if column == "" { return errors.New("column name is required") } - if value == "" { return errors.New("value is required") } - if primaryKeyColumnName == "" { return errors.New("primary key column name is required") } - if primaryKeyValue == "" { return errors.New("primary key value is required") } splitTableString := strings.Split(table, ".") - if len(splitTableString) == 1 { return errors.New("table must be in the format schema.table") } switchDatabaseOnError := false - if database != db.CurrentDatabase { - err = db.SwitchDatabase(database) + err := db.SwitchDatabase(database) if err != nil { return err } @@ -623,8 +607,7 @@ func (db *Postgres) UpdateRecord(database, table, column, value, primaryKeyColum query += formattedTableName query += fmt.Sprintf(" SET \"%s\" = $1 WHERE \"%s\" = $2", column, primaryKeyColumnName) - _, err = db.Connection.Exec(query, value, primaryKeyValue) - + _, err := db.Connection.Exec(query, value, primaryKeyValue) if err != nil && switchDatabaseOnError { err = db.SwitchDatabase(db.PreviousDatabase) } @@ -632,33 +615,28 @@ func (db *Postgres) UpdateRecord(database, table, column, value, primaryKeyColum return err } -func (db *Postgres) DeleteRecord(database, table, primaryKeyColumnName, primaryKeyValue string) (err error) { +func (db *Postgres) DeleteRecord(database, table, primaryKeyColumnName, primaryKeyValue string) error { if database == "" { return errors.New("database name is required") } - if table == "" { return errors.New("table name is required") } - if primaryKeyColumnName == "" { return errors.New("primary key column name is required") } - if primaryKeyValue == "" { return errors.New("primary key value is required") } splitTableString := strings.Split(table, ".") - if len(splitTableString) == 1 { return errors.New("table must be in the format schema.table") } switchDatabaseOnError := false - if database != db.CurrentDatabase { - err = db.SwitchDatabase(database) + err := db.SwitchDatabase(database) if err != nil { return err } @@ -674,8 +652,7 @@ func (db *Postgres) DeleteRecord(database, table, primaryKeyColumnName, primaryK query += formattedTableName query += fmt.Sprintf(" WHERE \"%s\" = $1", primaryKeyColumnName) - _, err = db.Connection.Exec(query, primaryKeyValue) - + _, err := db.Connection.Exec(query, primaryKeyValue) if err != nil && switchDatabaseOnError { err = db.SwitchDatabase(db.PreviousDatabase) } @@ -692,11 +669,10 @@ func (db *Postgres) ExecuteDMLStatement(query string) (result string, err error) if err != nil { return result, err } - return fmt.Sprintf("%d rows affected", rowsAffected), nil } -func (db *Postgres) ExecuteQuery(query string) (results [][]string, err error) { +func (db *Postgres) ExecuteQuery(query string) ([][]string, error) { rows, err := db.Connection.Query(query) if err != nil { return nil, err @@ -708,8 +684,7 @@ func (db *Postgres) ExecuteQuery(query string) (results [][]string, err error) { return nil, err } - results = append(results, columns) - + results := [][]string{columns} for rows.Next() { rowValues := make([]interface{}, len(columns)) for i := range columns { @@ -732,10 +707,10 @@ func (db *Postgres) ExecuteQuery(query string) (results [][]string, err error) { return nil, err } - return + return results, nil } -func (db *Postgres) ExecutePendingChanges(changes []models.DBDMLChange) (err error) { +func (db *Postgres) ExecutePendingChanges(changes []models.DBDMLChange) error { var queries []models.Query for _, change := range changes { @@ -847,20 +822,19 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DBDMLChange) (err err queries = append(queries, newQuery) } } + return queriesInTransaction(db.Connection, queries) } -func (db *Postgres) GetPrimaryKeyColumnNames(database, table string) (primaryKeyColumnName []string, err error) { +func (db *Postgres) GetPrimaryKeyColumnNames(database, table string) ([]string, error) { if database == "" { return nil, errors.New("database name is required") } - if table == "" { return nil, errors.New("table name is required") } splitTableString := strings.Split(table, ".") - if len(splitTableString) != 2 { return nil, errors.New("table must be in the format schema.table") } @@ -869,7 +843,7 @@ func (db *Postgres) GetPrimaryKeyColumnNames(database, table string) (primaryKey tableName := splitTableString[1] if database != db.CurrentDatabase { - err = db.SwitchDatabase(database) + err := db.SwitchDatabase(database) if err != nil { return nil, err } @@ -900,6 +874,7 @@ func (db *Postgres) GetPrimaryKeyColumnNames(database, table string) (primaryKey defer row.Close() + var primaryKeyColumnName []string for row.Next() { var colName string err = row.Scan(&colName)