From 857943db2dc5af22dcda4af71d03c9903979b52b Mon Sep 17 00:00:00 2001 From: Jorge Rojas Date: Thu, 3 Oct 2024 16:39:40 -0400 Subject: [PATCH 01/10] feat: add menu modal to select a value --- app/Keymap.go | 1 + commands/commands.go | 4 +++ components/ResultsTable.go | 19 ++++++++++ components/SetValueList.go | 71 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 95 insertions(+) create mode 100644 components/SetValueList.go diff --git a/app/Keymap.go b/app/Keymap.go index f80a57d..935faa6 100644 --- a/app/Keymap.go +++ b/app/Keymap.go @@ -96,6 +96,7 @@ var Keymaps = KeymapSystem{ Bind{Key: Key{Char: 'o'}, Cmd: cmd.AppendNewRow, Description: "Append new row"}, Bind{Key: Key{Char: 'J'}, Cmd: cmd.SortDesc, Description: "Sort descending"}, Bind{Key: Key{Char: 'K'}, Cmd: cmd.SortAsc, Description: "Sort ascending"}, + Bind{Key: Key{Char: 'C'}, Cmd: cmd.SetValue, Description: "Toggle value menu to put values like NULL, EMPTY or DEFAULT"}, // Tabs Bind{Key: Key{Char: '['}, Cmd: cmd.TabPrev, Description: "Switch to previous tab"}, Bind{Key: Key{Char: ']'}, Cmd: cmd.TabNext, Description: "Switch to next tab"}, diff --git a/commands/commands.go b/commands/commands.go index 7929380..26e7834 100644 --- a/commands/commands.go +++ b/commands/commands.go @@ -60,6 +60,7 @@ const ( PreviousFoundNode TreeCollapseAll ExpandAll + SetValue // Connection NewConnection @@ -179,6 +180,9 @@ func (c Command) String() string { return "TreeCollapseAll" case ExpandAll: return "ExpandAll" + case SetValue: + return "SetValue" } + return "Unknown" } diff --git a/components/ResultsTable.go b/components/ResultsTable.go index 6b0b652..c54e30b 100644 --- a/components/ResultsTable.go +++ b/components/ResultsTable.go @@ -367,6 +367,19 @@ func (table *ResultsTable) tableInputCapture(event *tcell.EventKey) *tcell.Event } } + } else if command == commands.SetValue { + table.SetIsEditing(true) + table.SetInputCapture(nil) + + cell := table.GetCell(selectedRowIndex, selectedColumnIndex) + x, y, width := cell.GetLastPosition() + + list := NewSetValueList() + list.SetRect(x, y, width, 7) + + list.OnFinish(table.FinishSettingValue) + + list.Show(x, y, width) } if len(table.GetRecords()) > 0 { @@ -1224,3 +1237,9 @@ func (table *ResultsTable) search() { table.SetInputCapture(nil) } + +func (table *ResultsTable) FinishSettingValue() { + table.SetIsEditing(false) + table.SetInputCapture(table.tableInputCapture) + App.SetFocus(table) +} diff --git a/components/SetValueList.go b/components/SetValueList.go new file mode 100644 index 0000000..8bedc92 --- /dev/null +++ b/components/SetValueList.go @@ -0,0 +1,71 @@ +package components + +import ( + "github.com/gdamore/tcell/v2" + "github.com/jorgerojas26/lazysql/app" + "github.com/jorgerojas26/lazysql/commands" + "github.com/rivo/tview" +) + +type SetValueList struct { + *tview.List +} + +type value struct { + value string + key rune +} + +var VALUES = []value{ + {value: "NULL", key: 'n'}, + {value: "EMPTY", key: 'e'}, + {value: "DEFAULT", key: 'd'}, +} + +func NewSetValueList() *SetValueList { + list := tview.NewList() + list.SetBorder(true) + + for _, value := range VALUES { + list.AddItem(value.value, "", value.key, nil) + } + + return &SetValueList{List: list} +} + +func (list *SetValueList) OnFinish(callback func()) { + list.SetDoneFunc(func() { + list.Hide() + callback() + }) + + list.SetSelectedFunc(func(int, string, string, rune) { + list.Hide() + callback() + }) + + list.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { + command := app.Keymaps.Group(app.TableGroup).Resolve(event) + + if command == commands.SetValue { + list.Hide() + callback() + return nil + } + + return event + }) +} + +func (list *SetValueList) Show(x, y, width int) { + list.SetRect(x, y, width, len(VALUES)*2+1) + MainPages.AddPage("setValue", list, false, true) + App.SetFocus(list) + App.ForceDraw() +} + +func (list *SetValueList) Hide() { + MainPages.RemovePage("setValue") + App.SetFocus(list) + App.ForceDraw() +} From eb4f81f8b7d63a66bbd2676cb48aee9099de5022 Mon Sep 17 00:00:00 2001 From: Jorge Rojas Date: Fri, 11 Oct 2024 00:55:05 -0400 Subject: [PATCH 02/10] feat: null, empty and default italics style --- components/ResultsTable.go | 49 +++++++++++++++++++++++++++++++------- components/SetValueList.go | 17 +++++++++---- drivers/mysql.go | 18 ++++++++++---- drivers/postgres.go | 19 +++++++++++---- drivers/sqlite.go | 19 +++++++++++---- 5 files changed, 97 insertions(+), 25 deletions(-) diff --git a/components/ResultsTable.go b/components/ResultsTable.go index c54e30b..18152b7 100644 --- a/components/ResultsTable.go +++ b/components/ResultsTable.go @@ -190,11 +190,17 @@ func (table *ResultsTable) AddRows(rows [][]string) { for i, row := range rows { for j, cell := range row { tableCell := tview.NewTableCell(cell) + tableCell.SetTextColor(tview.Styles.PrimaryTextColor) + + if cell == "EMPTY&" || cell == "NULL&" { + tableCell.SetText(strings.Replace(cell, "&", "", 1)) + tableCell.SetStyle(table.GetItalicStyle()) + tableCell.SetReference(cell) + } + tableCell.SetSelectable(i > 0) tableCell.SetExpansion(1) - tableCell.SetTextColor(tview.Styles.PrimaryTextColor) - table.SetCell(i, j, tableCell) } } @@ -228,7 +234,7 @@ func (table *ResultsTable) AddInsertedRows() { tableCell.SetExpansion(1) tableCell.SetReference(inserts[i].PrimaryKeyValue) - tableCell.SetTextColor(tcell.ColorWhite.TrueColor()) + tableCell.SetTextColor(tview.Styles.PrimaryTextColor) tableCell.SetBackgroundColor(InsertColor) table.SetCell(rowIndex, j, tableCell) @@ -242,7 +248,7 @@ func (table *ResultsTable) AppendNewRow(cells []models.CellValue, index int, UUI tableCell.SetExpansion(1) tableCell.SetReference(UUID) tableCell.SetTextColor(tview.Styles.PrimaryTextColor) - tableCell.SetBackgroundColor(tcell.ColorDarkGreen) + tableCell.SetBackgroundColor(InsertColor) switch cell.Type { case models.Null: @@ -346,7 +352,7 @@ func (table *ResultsTable) tableInputCapture(event *tcell.EventKey) *tcell.Event for i, insertedRow := range *table.state.listOfDbChanges { cellReference := table.GetCell(selectedRowIndex, 0).GetReference() - if cellReference != nil && insertedRow.PrimaryKeyValue == cellReference.(string) { + if cellReference != nil && insertedRow.PrimaryKeyValue == cellReference { isAnInsertedRow = true indexOfInsertedRow = i } @@ -377,7 +383,13 @@ func (table *ResultsTable) tableInputCapture(event *tcell.EventKey) *tcell.Event list := NewSetValueList() list.SetRect(x, y, width, 7) - list.OnFinish(table.FinishSettingValue) + list.OnFinish(func(selection models.CellValueType, value string) { + table.FinishSettingValue() + + if selection >= 0 { + table.AppendNewChange(models.DmlUpdateType, table.Tree.GetSelectedDatabase(), table.Tree.GetSelectedTable(), selectedRowIndex, selectedColumnIndex, models.CellValue{Type: selection, Value: value, Column: table.GetColumnNameByIndex(selectedColumnIndex)}) + } + }) list.Show(x, y, width) } @@ -424,7 +436,13 @@ func (table *ResultsTable) UpdateRowsColor(headerColor tcell.Color, rowColor tce if i == 0 && headerColor != 0 { cell.SetTextColor(headerColor) } else { - cell.SetTextColor(rowColor) + cellReference := cell.GetReference() + + if cellReference != nil && cellReference == "EMPTY&" || cellReference == "NULL&" || cellReference == "DEFAULT&" { + cell.SetStyle(table.GetItalicStyle()) + } else { + cell.SetTextColor(rowColor) + } } } } @@ -516,6 +534,7 @@ func (table *ResultsTable) subscribeToEditorChanges() { if strings.Contains(queryLower, "select") { table.SetLoading(true) App.Draw() + rows, err := table.DBDriver.ExecuteQuery(query) table.Pagination.SetTotalRecords(len(rows)) table.Pagination.SetLimit(len(rows)) @@ -927,7 +946,7 @@ func (table *ResultsTable) AppendNewChange(changeType models.DmlType, databaseNa tableCell := table.GetCell(rowIndex, colIndex) tableCellReference := tableCell.GetReference() - isAnInsertedRow := tableCellReference != nil + isAnInsertedRow := tableCellReference != nil && tableCellReference.(string) != "NULL&" && tableCellReference.(string) != "EMPTY&" && tableCellReference.(string) != "DEFAULT&" if isAnInsertedRow { table.MutateInsertedRowCell(tableCellReference.(string), value) @@ -936,6 +955,15 @@ func (table *ResultsTable) AppendNewChange(changeType models.DmlType, databaseNa primaryKeyValue, primaryKeyColumnName := table.GetPrimaryKeyValue(rowIndex) + if changeType == models.DmlUpdateType { + switch value.Type { + case models.Null, models.Empty, models.Default: + tableCell.SetText(value.Value.(string)) + tableCell.SetStyle(tcell.StyleDefault.Italic(true)) + tableCell.SetReference(value.Value.(string) + "&") + } + } + for i, dmlChange := range *table.state.listOfDbChanges { if dmlChange.Table == tableName && dmlChange.Type == changeType && dmlChange.PrimaryKeyValue == primaryKeyValue { dmlChangeAlreadyExists = true @@ -984,6 +1012,7 @@ func (table *ResultsTable) AppendNewChange(changeType models.DmlType, databaseNa case models.DmlDeleteType: table.SetRowColor(rowIndex, DeleteColor) case models.DmlUpdateType: + tableCell.SetStyle(tcell.StyleDefault.Background(ChangeColor)) table.SetCellColor(rowIndex, colIndex, ChangeColor) } @@ -1243,3 +1272,7 @@ func (table *ResultsTable) FinishSettingValue() { table.SetInputCapture(table.tableInputCapture) App.SetFocus(table) } + +func (table *ResultsTable) GetItalicStyle() tcell.Style { + return tcell.StyleDefault.Foreground(tview.Styles.InverseTextColor).Italic(true) +} diff --git a/components/SetValueList.go b/components/SetValueList.go index 8bedc92..45de811 100644 --- a/components/SetValueList.go +++ b/components/SetValueList.go @@ -4,6 +4,7 @@ import ( "github.com/gdamore/tcell/v2" "github.com/jorgerojas26/lazysql/app" "github.com/jorgerojas26/lazysql/commands" + "github.com/jorgerojas26/lazysql/models" "github.com/rivo/tview" ) @@ -33,15 +34,21 @@ func NewSetValueList() *SetValueList { return &SetValueList{List: list} } -func (list *SetValueList) OnFinish(callback func()) { +func (list *SetValueList) OnFinish(callback func(selection models.CellValueType, value string)) { list.SetDoneFunc(func() { list.Hide() - callback() + callback(-1, "") }) - list.SetSelectedFunc(func(int, string, string, rune) { + list.SetSelectedFunc(func(_ int, _ string, _ string, shortcut rune) { list.Hide() - callback() + if shortcut == 'n' { + callback(models.Null, "NULL") + } else if shortcut == 'e' { + callback(models.Empty, "EMPTY") + } else if shortcut == 'd' { + callback(models.Default, "DEFAULT") + } }) list.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { @@ -49,7 +56,7 @@ func (list *SetValueList) OnFinish(callback func()) { if command == commands.SetValue { list.Hide() - callback() + callback(-1, "") return nil } diff --git a/drivers/mysql.go b/drivers/mysql.go index 1c47412..d042b3e 100644 --- a/drivers/mysql.go +++ b/drivers/mysql.go @@ -372,9 +372,11 @@ func (db *MySQL) GetRecords(database, table, where, sort string, offset, limit i paginatedResults = append(paginatedResults, columns) for paginatedRows.Next() { + nullStringSlice := make([]sql.NullString, len(columns)) + rowValues := make([]interface{}, len(columns)) - for i := range columns { - rowValues[i] = new(sql.RawBytes) + for i := range nullStringSlice { + rowValues[i] = &nullStringSlice[i] } err = paginatedRows.Scan(rowValues...) @@ -383,8 +385,16 @@ func (db *MySQL) GetRecords(database, table, where, sort string, offset, limit i } var row []string - for _, col := range rowValues { - row = append(row, string(*col.(*sql.RawBytes))) + for _, col := range nullStringSlice { + if col.Valid { + if col.String == "" { + row = append(row, "EMPTY&") + } else { + row = append(row, col.String) + } + } else { + row = append(row, "NULL&") + } } paginatedResults = append(paginatedResults, row) diff --git a/drivers/postgres.go b/drivers/postgres.go index a3fede9..270fd02 100644 --- a/drivers/postgres.go +++ b/drivers/postgres.go @@ -588,16 +588,27 @@ func (db *Postgres) GetRecords(database, table, where, sort string, offset, limi records = append(records, columns) for paginatedRows.Next() { + nullStringSlice := make([]sql.NullString, len(columns)) + + // Create a slice of interface{} to hold pointers to the sql.NullString slice rowValues := make([]interface{}, len(columns)) - for i := range columns { - rowValues[i] = new(sql.RawBytes) + for i := range nullStringSlice { + rowValues[i] = &nullStringSlice[i] } err = paginatedRows.Scan(rowValues...) var row []string - for _, col := range rowValues { - row = append(row, string(*col.(*sql.RawBytes))) + for _, col := range nullStringSlice { + if col.Valid { + if col.String == "" { + row = append(row, "EMPTY&") + } else { + row = append(row, col.String) + } + } else { + row = append(row, "NULL&") + } } records = append(records, row) diff --git a/drivers/sqlite.go b/drivers/sqlite.go index 24fc476..d85205c 100644 --- a/drivers/sqlite.go +++ b/drivers/sqlite.go @@ -362,9 +362,12 @@ func (db *SQLite) GetRecords(_, table, where, sort string, offset, limit int) (p paginatedResults = append(paginatedResults, columns) for paginatedRows.Next() { + nullStringSlice := make([]sql.NullString, len(columns)) + rowValues := make([]interface{}, len(columns)) - for i := range columns { - rowValues[i] = new(sql.RawBytes) + + for i := range nullStringSlice { + rowValues[i] = &nullStringSlice[i] } err = paginatedRows.Scan(rowValues...) @@ -373,8 +376,16 @@ func (db *SQLite) GetRecords(_, table, where, sort string, offset, limit int) (p } var row []string - for _, col := range rowValues { - row = append(row, string(*col.(*sql.RawBytes))) + for _, col := range nullStringSlice { + if col.Valid { + if col.String == "" { + row = append(row, "EMPTY&") + } else { + row = append(row, col.String) + } + } else { + row = append(row, "NULL&") + } } paginatedResults = append(paginatedResults, row) From f13c8243bf83c2c36d65da3f4fd91b88536a58fe Mon Sep 17 00:00:00 2001 From: Jorge Rojas Date: Fri, 11 Oct 2024 00:57:02 -0400 Subject: [PATCH 03/10] feat: menu modal fixed width --- components/ResultsTable.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/components/ResultsTable.go b/components/ResultsTable.go index 18152b7..457efca 100644 --- a/components/ResultsTable.go +++ b/components/ResultsTable.go @@ -391,7 +391,7 @@ func (table *ResultsTable) tableInputCapture(event *tcell.EventKey) *tcell.Event } }) - list.Show(x, y, width) + list.Show(x, y, 30) } if len(table.GetRecords()) > 0 { From 2088995ccd163bd63ecbd76aa3b1232d0b367f27 Mon Sep 17 00:00:00 2001 From: Jorge Rojas Date: Fri, 11 Oct 2024 01:04:39 -0400 Subject: [PATCH 04/10] fix: cell background color on set value --- components/ResultsTable.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/components/ResultsTable.go b/components/ResultsTable.go index 457efca..1355fdb 100644 --- a/components/ResultsTable.go +++ b/components/ResultsTable.go @@ -378,10 +378,9 @@ func (table *ResultsTable) tableInputCapture(event *tcell.EventKey) *tcell.Event table.SetInputCapture(nil) cell := table.GetCell(selectedRowIndex, selectedColumnIndex) - x, y, width := cell.GetLastPosition() + x, y, _ := cell.GetLastPosition() list := NewSetValueList() - list.SetRect(x, y, width, 7) list.OnFinish(func(selection models.CellValueType, value string) { table.FinishSettingValue() @@ -438,7 +437,7 @@ func (table *ResultsTable) UpdateRowsColor(headerColor tcell.Color, rowColor tce } else { cellReference := cell.GetReference() - if cellReference != nil && cellReference == "EMPTY&" || cellReference == "NULL&" || cellReference == "DEFAULT&" { + if cellReference != nil && (cellReference == "EMPTY&" || cellReference == "NULL&" || cellReference == "DEFAULT&") && (cell.BackgroundColor != DeleteColor && cell.BackgroundColor != ChangeColor && cell.BackgroundColor != InsertColor) { cell.SetStyle(table.GetItalicStyle()) } else { cell.SetTextColor(rowColor) From e46fa6b8a38095015e73ec34a9e4f649053f4bde Mon Sep 17 00:00:00 2001 From: Jorge Rojas Date: Fri, 11 Oct 2024 22:54:13 -0400 Subject: [PATCH 05/10] feat: support NULL, DEFAULT and EMPTY values --- components/ResultsTable.go | 14 ++++----- components/SetValueList.go | 31 ++++++++++++------ drivers/mysql.go | 15 +++++---- drivers/postgres.go | 26 +++++++++++----- drivers/sqlite.go | 64 +++++++++++++++++++++++++------------- 5 files changed, 98 insertions(+), 52 deletions(-) diff --git a/components/ResultsTable.go b/components/ResultsTable.go index 1355fdb..f10e856 100644 --- a/components/ResultsTable.go +++ b/components/ResultsTable.go @@ -192,7 +192,7 @@ func (table *ResultsTable) AddRows(rows [][]string) { tableCell := tview.NewTableCell(cell) tableCell.SetTextColor(tview.Styles.PrimaryTextColor) - if cell == "EMPTY&" || cell == "NULL&" { + if cell == "EMPTY&" || cell == "NULL&" || cell == "DEFAULT&" { tableCell.SetText(strings.Replace(cell, "&", "", 1)) tableCell.SetStyle(table.GetItalicStyle()) tableCell.SetReference(cell) @@ -248,16 +248,16 @@ func (table *ResultsTable) AppendNewRow(cells []models.CellValue, index int, UUI tableCell.SetExpansion(1) tableCell.SetReference(UUID) tableCell.SetTextColor(tview.Styles.PrimaryTextColor) - tableCell.SetBackgroundColor(InsertColor) switch cell.Type { - case models.Null: - case models.Default: - case models.String: - tableCell.SetText("") + case models.Null, models.Empty, models.Default: + tableCell.SetText(strings.Replace(cell.Value.(string), "&", "", 1)) + tableCell.SetStyle(table.GetItalicStyle()) + // tableCell.SetText("") tableCell.SetTextColor(tview.Styles.InverseTextColor) } + tableCell.SetBackgroundColor(InsertColor) table.SetCell(index, i, tableCell) } @@ -380,7 +380,7 @@ func (table *ResultsTable) tableInputCapture(event *tcell.EventKey) *tcell.Event cell := table.GetCell(selectedRowIndex, selectedColumnIndex) x, y, _ := cell.GetLastPosition() - list := NewSetValueList() + list := NewSetValueList(table.DBDriver.GetProvider()) list.OnFinish(func(selection models.CellValueType, value string) { table.FinishSettingValue() diff --git a/components/SetValueList.go b/components/SetValueList.go index 45de811..ecf7e7d 100644 --- a/components/SetValueList.go +++ b/components/SetValueList.go @@ -2,10 +2,11 @@ package components import ( "github.com/gdamore/tcell/v2" + "github.com/rivo/tview" + "github.com/jorgerojas26/lazysql/app" "github.com/jorgerojas26/lazysql/commands" "github.com/jorgerojas26/lazysql/models" - "github.com/rivo/tview" ) type SetValueList struct { @@ -17,16 +18,25 @@ type value struct { key rune } -var VALUES = []value{ - {value: "NULL", key: 'n'}, - {value: "EMPTY", key: 'e'}, - {value: "DEFAULT", key: 'd'}, -} +var VALUES = []value{} -func NewSetValueList() *SetValueList { +func NewSetValueList(dbProvider string) *SetValueList { list := tview.NewList() list.SetBorder(true) + if dbProvider == "sqlite3" { + VALUES = []value{ + {value: "NULL", key: 'n'}, + {value: "EMPTY", key: 'e'}, + } + } else { + VALUES = []value{ + {value: "NULL", key: 'n'}, + {value: "EMPTY", key: 'e'}, + {value: "DEFAULT", key: 'd'}, + } + } + for _, value := range VALUES { list.AddItem(value.value, "", value.key, nil) } @@ -42,11 +52,12 @@ func (list *SetValueList) OnFinish(callback func(selection models.CellValueType, list.SetSelectedFunc(func(_ int, _ string, _ string, shortcut rune) { list.Hide() - if shortcut == 'n' { + switch shortcut { + case 'n': callback(models.Null, "NULL") - } else if shortcut == 'e' { + case 'e': callback(models.Empty, "EMPTY") - } else if shortcut == 'd' { + case 'd': callback(models.Default, "DEFAULT") } }) diff --git a/drivers/mysql.go b/drivers/mysql.go index d042b3e..26ada57 100644 --- a/drivers/mysql.go +++ b/drivers/mysql.go @@ -489,9 +489,14 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange) (err error) valuesPlaceholder := []string{} for _, cell := range change.Values { + columnNames = append(columnNames, cell.Column) + switch cell.Type { - case models.Empty, models.Null, models.String: - columnNames = append(columnNames, cell.Column) + case models.Default: + valuesPlaceholder = append(valuesPlaceholder, "DEFAULT") + case models.Null: + valuesPlaceholder = append(valuesPlaceholder, "NULL") + default: valuesPlaceholder = append(valuesPlaceholder, "?") } } @@ -500,8 +505,6 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange) (err error) switch cell.Type { case models.Empty: values = append(values, "") - case models.Null: - values = append(values, sql.NullString{}) case models.String: values = append(values, cell.Value) } @@ -525,9 +528,9 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange) (err error) for i, column := range columnNames { if i == 0 { - queryStr += fmt.Sprintf(" SET `%s` = ?", column) + queryStr += fmt.Sprintf(" SET `%s` = %s", column, valuesPlaceholder[i]) } else { - queryStr += fmt.Sprintf(", `%s` = ?", column) + queryStr += fmt.Sprintf(", `%s` = %s", column, valuesPlaceholder[i]) } } diff --git a/drivers/postgres.go b/drivers/postgres.go index 270fd02..a636b62 100644 --- a/drivers/postgres.go +++ b/drivers/postgres.go @@ -590,7 +590,6 @@ func (db *Postgres) GetRecords(database, table, where, sort string, offset, limi for paginatedRows.Next() { nullStringSlice := make([]sql.NullString, len(columns)) - // Create a slice of interface{} to hold pointers to the sql.NullString slice rowValues := make([]interface{}, len(columns)) for i := range nullStringSlice { rowValues[i] = &nullStringSlice[i] @@ -800,9 +799,14 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange) (err err placeholderIndex := 1 for _, cell := range change.Values { + columnNames = append(columnNames, cell.Column) + switch cell.Type { - case models.Empty, models.Null, models.String: - columnNames = append(columnNames, cell.Column) + case models.Default: + valuesPlaceholder = append(valuesPlaceholder, "DEFAULT") + case models.Null: + valuesPlaceholder = append(valuesPlaceholder, "NULL") + default: valuesPlaceholder = append(valuesPlaceholder, fmt.Sprintf("$%d", placeholderIndex)) placeholderIndex++ } @@ -812,8 +816,6 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange) (err err switch cell.Type { case models.Empty: values = append(values, "") - case models.Null: - values = append(values, sql.NullString{}) case models.String: values = append(values, cell.Value) } @@ -844,9 +846,9 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange) (err err for i, column := range columnNames { if i == 0 { - queryStr += fmt.Sprintf(" SET \"%s\" = $1", column) + queryStr += fmt.Sprintf(" SET \"%s\" = %s", column, valuesPlaceholder[i]) } else { - queryStr += fmt.Sprintf(", \"%s\" = $%d", column, i+1) + queryStr += fmt.Sprintf(", \"%s\" = %s", column, valuesPlaceholder[i]) } } @@ -854,7 +856,15 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange) (err err copy(args, values) - queryStr += fmt.Sprintf(" WHERE \"%s\" = $%d", change.PrimaryKeyColumnName, len(columnNames)+1) + wherePlaceholder := 0 + + for _, placeholder := range valuesPlaceholder { + if strings.Contains(placeholder, "$") { + wherePlaceholder++ + } + } + + queryStr += fmt.Sprintf(" WHERE \"%s\" = $%d", change.PrimaryKeyColumnName, wherePlaceholder+1) args = append(args, change.PrimaryKeyValue) newQuery := models.Query{ diff --git a/drivers/sqlite.go b/drivers/sqlite.go index d85205c..f2f47e2 100644 --- a/drivers/sqlite.go +++ b/drivers/sqlite.go @@ -516,28 +516,28 @@ func (db *SQLite) ExecutePendingChanges(changes []models.DbDmlChange) (err error values := []interface{}{} valuesPlaceholder := []string{} - for _, cell := range change.Values { - switch cell.Type { - case models.Empty, models.Null, models.String: - columnNames = append(columnNames, cell.Column) - valuesPlaceholder = append(valuesPlaceholder, "?") + switch change.Type { + case models.DmlInsertType: + for _, cell := range change.Values { + switch cell.Type { + case models.Null: + columnNames = append(columnNames, cell.Column) + valuesPlaceholder = append(valuesPlaceholder, "NULL") + case models.Empty, models.String: + columnNames = append(columnNames, cell.Column) + valuesPlaceholder = append(valuesPlaceholder, "?") + } } - } - logger.Info("Column names", map[string]any{"columnNames": columnNames}) - - for _, cell := range change.Values { - switch cell.Type { - case models.Empty: - values = append(values, "") - case models.Null: - values = append(values, sql.NullString{}) - case models.String: - values = append(values, cell.Value) + + for _, cell := range change.Values { + switch cell.Type { + case models.Empty: + values = append(values, "") + case models.String: + values = append(values, cell.Value) + } } - } - switch change.Type { - case models.DmlInsertType: queryStr := "INSERT INTO " queryStr += db.formatTableName(change.Table) queryStr += fmt.Sprintf(" (%s) VALUES (%s)", strings.Join(columnNames, ", "), strings.Join(valuesPlaceholder, ", ")) @@ -549,14 +549,36 @@ func (db *SQLite) ExecutePendingChanges(changes []models.DbDmlChange) (err error query = append(query, newQuery) case models.DmlUpdateType: + + for _, cell := range change.Values { + switch cell.Type { + case models.Null: + columnNames = append(columnNames, cell.Column) + valuesPlaceholder = append(valuesPlaceholder, "NULL") + case models.Empty, models.String: + columnNames = append(columnNames, cell.Column) + valuesPlaceholder = append(valuesPlaceholder, "?") + /// Leaves "DEFAULT" type out because it's not supported by sqlite + } + } + + for _, cell := range change.Values { + switch cell.Type { + case models.Empty: + values = append(values, "") + case models.String: + values = append(values, cell.Value) + } + } + queryStr := "UPDATE " queryStr += db.formatTableName(change.Table) for i, column := range columnNames { if i == 0 { - queryStr += fmt.Sprintf(" SET `%s` = ?", column) + queryStr += fmt.Sprintf(" SET `%s` = %s", column, valuesPlaceholder[i]) } else { - queryStr += fmt.Sprintf(", `%s` = ?", column) + queryStr += fmt.Sprintf(", `%s` = %s", column, valuesPlaceholder[i]) } } From 517685d152a5c287fd7dbc2095adf4bdda29dfd5 Mon Sep 17 00:00:00 2001 From: Jorge Rojas Date: Sat, 12 Oct 2024 00:27:39 -0400 Subject: [PATCH 06/10] Merge branch 'main' into set-value-modal --- app/App.go | 13 +- app/Keymap.go | 15 ++ commands/commands.go | 15 ++ components/ConfirmationModal.go | 6 +- components/ConnectionForm.go | 29 +-- components/ConnectionPage.go | 4 +- components/ConnectionSelection.go | 30 +-- components/ConnectionsTable.go | 3 +- components/HelpModal.go | 10 +- components/HelpStatus.go | 2 +- components/Home.go | 48 ++--- components/Pages.go | 6 +- components/ResultTableFilter.go | 37 ++-- components/ResultsTable.go | 276 ++++++++++++++++++++------ components/ResultsTableMenu.go | 34 ++-- components/SQLEditor.go | 12 +- components/Sidebar.go | 319 ++++++++++++++++++++++++++++++ components/TabbedMenu.go | 10 +- components/Tree.go | 82 ++++---- components/constants.go | 75 ++++++- drivers/constants.go | 7 + drivers/mysql.go | 139 ++++--------- drivers/postgres.go | 307 ++++++++++++---------------- drivers/sqlite.go | 130 ++++-------- drivers/utils.go | 34 ++++ drivers/utils_test.go | 121 ++++++++++++ go.mod | 20 +- go.sum | 3 + models/models.go | 17 +- 29 files changed, 1200 insertions(+), 604 deletions(-) create mode 100644 components/Sidebar.go create mode 100644 drivers/utils.go create mode 100644 drivers/utils_test.go diff --git a/app/App.go b/app/App.go index 213f09a..a0d2c4c 100644 --- a/app/App.go +++ b/app/App.go @@ -7,6 +7,15 @@ import ( var App = tview.NewApplication() +type Theme struct { + SidebarTitleBorderColor string + tview.Theme +} + +var Styles = Theme{ + SidebarTitleBorderColor: "#666A7E", +} + func init() { theme := tview.Theme{ PrimitiveBackgroundColor: tcell.ColorDefault, @@ -14,12 +23,14 @@ func init() { MoreContrastBackgroundColor: tcell.ColorGreen, BorderColor: tcell.ColorWhite, TitleColor: tcell.ColorWhite, - GraphicsColor: tcell.ColorWhite, + GraphicsColor: tcell.ColorGray, PrimaryTextColor: tcell.ColorDefault.TrueColor(), SecondaryTextColor: tcell.ColorYellow, TertiaryTextColor: tcell.ColorGreen, InverseTextColor: tcell.ColorWhite, ContrastSecondaryTextColor: tcell.ColorBlack, } + + Styles.Theme = theme tview.Styles = theme } diff --git a/app/Keymap.go b/app/Keymap.go index 935faa6..aa846df 100644 --- a/app/Keymap.go +++ b/app/Keymap.go @@ -44,6 +44,7 @@ const ( TableGroup = "table" EditorGroup = "editor" ConnectionGroup = "connection" + SidebarGroup = "sidebar" ) // Define a global KeymapSystem object with default keybinds @@ -111,11 +112,25 @@ var Keymaps = KeymapSystem{ Bind{Key: Key{Char: '3'}, Cmd: cmd.ConstraintsMenu, Description: "Switch to constraints menu"}, Bind{Key: Key{Char: '4'}, Cmd: cmd.ForeignKeysMenu, Description: "Switch to foreign keys menu"}, Bind{Key: Key{Char: '5'}, Cmd: cmd.IndexesMenu, Description: "Switch to indexes menu"}, + // Sidebar + Bind{Key: Key{Char: 'S'}, Cmd: cmd.ToggleSidebar, Description: "Toggle sidebar"}, + Bind{Key: Key{Char: 's'}, Cmd: cmd.FocusSidebar, Description: "Focus sidebar"}, }, EditorGroup: { Bind{Key: Key{Code: tcell.KeyCtrlR}, Cmd: cmd.Execute, Description: "Execute query"}, Bind{Key: Key{Code: tcell.KeyEscape}, Cmd: cmd.UnfocusEditor, Description: "Unfocus editor"}, Bind{Key: Key{Code: tcell.KeyCtrlSpace}, Cmd: cmd.OpenInExternalEditor, Description: "Open in external editor"}, }, + SidebarGroup: { + Bind{Key: Key{Char: 's'}, Cmd: cmd.UnfocusSidebar, Description: "Focus table"}, + Bind{Key: Key{Char: 'S'}, Cmd: cmd.ToggleSidebar, Description: "Toggle sidebar"}, + Bind{Key: Key{Char: 'j'}, Cmd: cmd.MoveDown, Description: "Focus next field"}, + Bind{Key: Key{Char: 'k'}, Cmd: cmd.MoveUp, Description: "Focus previous field"}, + Bind{Key: Key{Char: 'g'}, Cmd: cmd.GotoStart, Description: "Focus first field"}, + Bind{Key: Key{Char: 'G'}, Cmd: cmd.GotoEnd, Description: "Focus last field"}, + Bind{Key: Key{Char: 'c'}, Cmd: cmd.Edit, Description: "Edit field"}, + Bind{Key: Key{Code: tcell.KeyEnter}, Cmd: cmd.CommitEdit, Description: "Add edit to pending changes"}, + Bind{Key: Key{Code: tcell.KeyEscape}, Cmd: cmd.DiscardEdit, Description: "Discard edit"}, + }, }, } diff --git a/commands/commands.go b/commands/commands.go index 26e7834..be4ec88 100644 --- a/commands/commands.go +++ b/commands/commands.go @@ -45,6 +45,8 @@ const ( UnfocusEditor Copy Edit + CommitEdit + DiscardEdit Save Delete Search @@ -61,6 +63,9 @@ const ( TreeCollapseAll ExpandAll SetValue + FocusSidebar + UnfocusSidebar + ToggleSidebar // Connection NewConnection @@ -182,6 +187,16 @@ func (c Command) String() string { return "ExpandAll" case SetValue: return "SetValue" + case FocusSidebar: + return "FocusSidebar" + case ToggleSidebar: + return "ToggleSidebar" + case UnfocusSidebar: + return "UnfocusSidebar" + case CommitEdit: + return "CommitEdit" + case DiscardEdit: + return "DiscardEdit" } return "Unknown" diff --git a/components/ConfirmationModal.go b/components/ConfirmationModal.go index c9723b9..d0f6e79 100644 --- a/components/ConfirmationModal.go +++ b/components/ConfirmationModal.go @@ -2,6 +2,8 @@ package components import ( "github.com/rivo/tview" + + "github.com/jorgerojas26/lazysql/app" ) type ConfirmationModal struct { @@ -16,8 +18,8 @@ func NewConfirmationModal(confirmationText string) *ConfirmationModal { modal.SetText("Are you sure?") } modal.AddButtons([]string{"Yes", "No"}) - modal.SetBackgroundColor(tview.Styles.PrimitiveBackgroundColor) - modal.SetTextColor(tview.Styles.PrimaryTextColor) + modal.SetBackgroundColor(app.Styles.PrimitiveBackgroundColor) + modal.SetTextColor(app.Styles.PrimaryTextColor) return &ConfirmationModal{ Modal: modal, diff --git a/components/ConnectionForm.go b/components/ConnectionForm.go index 36d6c26..30145cb 100644 --- a/components/ConnectionForm.go +++ b/components/ConnectionForm.go @@ -6,6 +6,7 @@ import ( "github.com/gdamore/tcell/v2" "github.com/rivo/tview" + "github.com/jorgerojas26/lazysql/app" "github.com/jorgerojas26/lazysql/drivers" "github.com/jorgerojas26/lazysql/helpers" "github.com/jorgerojas26/lazysql/models" @@ -23,35 +24,35 @@ func NewConnectionForm(connectionPages *models.ConnectionPages) *ConnectionForm wrapper.SetDirection(tview.FlexColumnCSS) - addForm := tview.NewForm().SetFieldBackgroundColor(tview.Styles.InverseTextColor).SetButtonBackgroundColor(tview.Styles.InverseTextColor).SetLabelColor(tview.Styles.PrimaryTextColor).SetFieldTextColor(tview.Styles.ContrastSecondaryTextColor) + addForm := tview.NewForm().SetFieldBackgroundColor(app.Styles.InverseTextColor).SetButtonBackgroundColor(tview.Styles.InverseTextColor).SetLabelColor(tview.Styles.PrimaryTextColor).SetFieldTextColor(tview.Styles.ContrastSecondaryTextColor) addForm.AddInputField("Name", "", 0, nil, nil) addForm.AddInputField("URL", "", 0, nil, nil) buttonsWrapper := tview.NewFlex().SetDirection(tview.FlexColumn) saveButton := tview.NewButton("[yellow]F1 [dark]Save") - saveButton.SetStyle(tcell.StyleDefault.Background(tview.Styles.PrimaryTextColor)) + saveButton.SetStyle(tcell.StyleDefault.Background(app.Styles.PrimaryTextColor)) saveButton.SetBorder(true) buttonsWrapper.AddItem(saveButton, 0, 1, false) buttonsWrapper.AddItem(nil, 1, 0, false) testButton := tview.NewButton("[yellow]F2 [dark]Test") - testButton.SetStyle(tcell.StyleDefault.Background(tview.Styles.PrimaryTextColor)) + testButton.SetStyle(tcell.StyleDefault.Background(app.Styles.PrimaryTextColor)) testButton.SetBorder(true) buttonsWrapper.AddItem(testButton, 0, 1, false) buttonsWrapper.AddItem(nil, 1, 0, false) connectButton := tview.NewButton("[yellow]F3 [dark]Connect") - connectButton.SetStyle(tcell.StyleDefault.Background(tview.Styles.PrimaryTextColor)) + connectButton.SetStyle(tcell.StyleDefault.Background(app.Styles.PrimaryTextColor)) connectButton.SetBorder(true) buttonsWrapper.AddItem(connectButton, 0, 1, false) buttonsWrapper.AddItem(nil, 1, 0, false) cancelButton := tview.NewButton("[yellow]Esc [dark]Cancel") - cancelButton.SetStyle(tcell.StyleDefault.Background(tcell.Color(tview.Styles.PrimaryTextColor))) + cancelButton.SetStyle(tcell.StyleDefault.Background(tcell.Color(app.Styles.PrimaryTextColor))) cancelButton.SetBorder(true) buttonsWrapper.AddItem(cancelButton, 0, 1, false) @@ -77,7 +78,7 @@ func NewConnectionForm(connectionPages *models.ConnectionPages) *ConnectionForm func (form *ConnectionForm) inputCapture(connectionPages *models.ConnectionPages) func(event *tcell.EventKey) *tcell.EventKey { return func(event *tcell.EventKey) *tcell.EventKey { if event.Key() == tcell.KeyEsc { - connectionPages.SwitchToPage("Connections") + connectionPages.SwitchToPage(pageNameConnections) } else if event.Key() == tcell.KeyF1 || event.Key() == tcell.KeyEnter { connectionName := form.GetFormItem(0).(*tview.InputField).GetText() @@ -111,7 +112,7 @@ func (form *ConnectionForm) inputCapture(connectionPages *models.ConnectionPages } switch form.Action { - case "create": + case actionNewConnection: newDatabases = append(databases, parsedDatabaseData) err := helpers.SaveConnectionConfig(newDatabases) @@ -120,7 +121,7 @@ func (form *ConnectionForm) inputCapture(connectionPages *models.ConnectionPages return event } - case "edit": + case actionEditConnection: newDatabases = make([]models.Connection, len(databases)) row, _ := ConnectionListTable.GetSelection() @@ -151,7 +152,7 @@ func (form *ConnectionForm) inputCapture(connectionPages *models.ConnectionPages } ConnectionListTable.SetConnections(newDatabases) - connectionPages.SwitchToPage("Connections") + connectionPages.SwitchToPage(pageNameConnections) } else if event.Key() == tcell.KeyF2 { connectionString := form.GetFormItem(1).(*tview.InputField).GetText() @@ -168,16 +169,16 @@ func (form *ConnectionForm) testConnection(connectionString string) { return } - form.StatusText.SetText("Connecting...").SetTextColor(tview.Styles.TertiaryTextColor) + form.StatusText.SetText("Connecting...").SetTextColor(app.Styles.TertiaryTextColor) var db drivers.Driver switch parsed.Driver { - case "mysql": + case drivers.DriverMySQL: db = &drivers.MySQL{} - case "postgres": + case drivers.DriverPostgres: db = &drivers.Postgres{} - case "sqlite3": + case drivers.DriverSqlite: db = &drivers.SQLite{} } @@ -186,7 +187,7 @@ func (form *ConnectionForm) testConnection(connectionString string) { if err != nil { form.StatusText.SetText(err.Error()).SetTextStyle(tcell.StyleDefault.Foreground(tcell.ColorRed)) } else { - form.StatusText.SetText("Connection success").SetTextColor(tview.Styles.TertiaryTextColor) + form.StatusText.SetText("Connection success").SetTextColor(app.Styles.TertiaryTextColor) } App.ForceDraw() } diff --git a/components/ConnectionPage.go b/components/ConnectionPage.go index 1e13431..8fc7ad9 100644 --- a/components/ConnectionPage.go +++ b/components/ConnectionPage.go @@ -32,8 +32,8 @@ func NewConnectionPages() *models.ConnectionPages { connectionForm := NewConnectionForm(cp) connectionSelection := NewConnectionSelection(connectionForm, cp) - cp.AddPage("Connections", connectionSelection.Flex, true, true) - cp.AddPage("ConnectionForm", connectionForm.Flex, true, false) + cp.AddPage(pageNameConnectionSelection, connectionSelection.Flex, true, true) + cp.AddPage(pageNameConnectionForm, connectionForm.Flex, true, false) return cp } diff --git a/components/ConnectionSelection.go b/components/ConnectionSelection.go index 53f34d0..c5f50e4 100644 --- a/components/ConnectionSelection.go +++ b/components/ConnectionSelection.go @@ -29,35 +29,35 @@ func NewConnectionSelection(connectionForm *ConnectionForm, connectionPages *mod buttonsWrapper := tview.NewFlex().SetDirection(tview.FlexRowCSS) newButton := tview.NewButton("[yellow]N[dark]ew") - newButton.SetStyle(tcell.StyleDefault.Background(tview.Styles.PrimitiveBackgroundColor)) + newButton.SetStyle(tcell.StyleDefault.Background(app.Styles.PrimitiveBackgroundColor)) newButton.SetBorder(true) buttonsWrapper.AddItem(newButton, 0, 1, false) buttonsWrapper.AddItem(nil, 1, 0, false) connectButton := tview.NewButton("[yellow]C[dark]onnect") - connectButton.SetStyle(tcell.StyleDefault.Background(tview.Styles.PrimitiveBackgroundColor)) + connectButton.SetStyle(tcell.StyleDefault.Background(app.Styles.PrimitiveBackgroundColor)) connectButton.SetBorder(true) buttonsWrapper.AddItem(connectButton, 0, 1, false) buttonsWrapper.AddItem(nil, 1, 0, false) editButton := tview.NewButton("[yellow]E[dark]dit") - editButton.SetStyle(tcell.StyleDefault.Background(tview.Styles.PrimitiveBackgroundColor)) + editButton.SetStyle(tcell.StyleDefault.Background(app.Styles.PrimitiveBackgroundColor)) editButton.SetBorder(true) buttonsWrapper.AddItem(editButton, 0, 1, false) buttonsWrapper.AddItem(nil, 1, 0, false) deleteButton := tview.NewButton("[yellow]D[dark]elete") - deleteButton.SetStyle(tcell.StyleDefault.Background(tview.Styles.PrimitiveBackgroundColor)) + deleteButton.SetStyle(tcell.StyleDefault.Background(app.Styles.PrimitiveBackgroundColor)) deleteButton.SetBorder(true) buttonsWrapper.AddItem(deleteButton, 0, 1, false) buttonsWrapper.AddItem(nil, 1, 0, false) quitButton := tview.NewButton("[yellow]Q[dark]uit") - quitButton.SetStyle(tcell.StyleDefault.Background(tview.Styles.PrimitiveBackgroundColor)) + quitButton.SetStyle(tcell.StyleDefault.Background(app.Styles.PrimitiveBackgroundColor)) quitButton.SetBorder(true) buttonsWrapper.AddItem(quitButton, 0, 1, false) @@ -88,18 +88,18 @@ func NewConnectionSelection(connectionForm *ConnectionForm, connectionPages *mod case commands.Connect: go cs.Connect(selectedConnection) case commands.EditConnection: - connectionPages.SwitchToPage("ConnectionForm") + connectionPages.SwitchToPage(pageNameConnectionForm) connectionForm.GetFormItemByLabel("Name").(*tview.InputField).SetText(selectedConnection.Name) connectionForm.GetFormItemByLabel("URL").(*tview.InputField).SetText(selectedConnection.URL) connectionForm.StatusText.SetText("") - connectionForm.SetAction("edit") + connectionForm.SetAction(actionEditConnection) return nil case commands.DeleteConnection: confirmationModal := NewConfirmationModal("") confirmationModal.SetDoneFunc(func(_ int, buttonLabel string) { - MainPages.RemovePage("Confirmation") + MainPages.RemovePage(pageNameConfirmation) confirmationModal = nil if buttonLabel == "Yes" { @@ -115,7 +115,7 @@ func NewConnectionSelection(connectionForm *ConnectionForm, connectionPages *mod } }) - MainPages.AddPage("Confirmation", confirmationModal, true, true) + MainPages.AddPage(pageNameConfirmation, confirmationModal, true, true) return nil } @@ -123,11 +123,11 @@ func NewConnectionSelection(connectionForm *ConnectionForm, connectionPages *mod switch command { case commands.NewConnection: - connectionForm.SetAction("create") + connectionForm.SetAction(actionNewConnection) connectionForm.GetFormItemByLabel("Name").(*tview.InputField).SetText("") connectionForm.GetFormItemByLabel("URL").(*tview.InputField).SetText("") connectionForm.StatusText.SetText("") - connectionPages.SwitchToPage("ConnectionForm") + connectionPages.SwitchToPage(pageNameConnectionForm) case commands.Quit: if wrapper.HasFocus() { app.App.Stop() @@ -145,17 +145,17 @@ func (cs *ConnectionSelection) Connect(connection models.Connection) { MainPages.SwitchToPage(connection.URL) App.Draw() } else { - cs.StatusText.SetText("Connecting...").SetTextColor(tview.Styles.TertiaryTextColor) + cs.StatusText.SetText("Connecting...").SetTextColor(app.Styles.TertiaryTextColor) App.Draw() var newDbDriver drivers.Driver switch connection.Provider { - case "mysql": + case drivers.DriverMySQL: newDbDriver = &drivers.MySQL{} - case "postgres": + case drivers.DriverPostgres: newDbDriver = &drivers.Postgres{} - case "sqlite3": + case drivers.DriverSqlite: newDbDriver = &drivers.SQLite{} } diff --git a/components/ConnectionsTable.go b/components/ConnectionsTable.go index e4ca4ab..411609d 100644 --- a/components/ConnectionsTable.go +++ b/components/ConnectionsTable.go @@ -4,6 +4,7 @@ import ( "github.com/gdamore/tcell/v2" "github.com/rivo/tview" + "github.com/jorgerojas26/lazysql/app" "github.com/jorgerojas26/lazysql/helpers" "github.com/jorgerojas26/lazysql/models" ) @@ -29,7 +30,7 @@ func NewConnectionsTable() *ConnectionsTable { } table.SetOffset(5, 0) - table.SetSelectedStyle(tcell.StyleDefault.Foreground(tview.Styles.SecondaryTextColor).Background(tview.Styles.PrimitiveBackgroundColor)) + table.SetSelectedStyle(tcell.StyleDefault.Foreground(app.Styles.SecondaryTextColor).Background(tview.Styles.PrimitiveBackgroundColor)) wrapper.AddItem(table, 0, 1, true) diff --git a/components/HelpModal.go b/components/HelpModal.go index f6e905f..27c965b 100644 --- a/components/HelpModal.go +++ b/components/HelpModal.go @@ -31,10 +31,10 @@ func NewHelpModal() *HelpModal { // table.SetBorders(true) table.SetBorder(true) - table.SetBorderColor(tview.Styles.PrimaryTextColor) + table.SetBorderColor(app.Styles.PrimaryTextColor) table.SetTitle(" Keybindings ") table.SetSelectable(true, false) - table.SetSelectedStyle(tcell.StyleDefault.Background(tview.Styles.SecondaryTextColor).Foreground(tview.Styles.ContrastSecondaryTextColor)) + table.SetSelectedStyle(tcell.StyleDefault.Background(app.Styles.SecondaryTextColor).Foreground(tview.Styles.ContrastSecondaryTextColor)) keymapGroups := app.Keymaps.Groups @@ -51,7 +51,7 @@ func NewHelpModal() *HelpModal { for groupName, keys := range keymapGroups { rowCount := table.GetRowCount() groupNameCell := tview.NewTableCell(strings.ToUpper(groupName)) - groupNameCell.SetTextColor(tview.Styles.TertiaryTextColor) + groupNameCell.SetTextColor(app.Styles.TertiaryTextColor) groupNameCell.SetSelectable(rowCount == 0) table.SetCell(rowCount, 0, tview.NewTableCell("").SetSelectable(false)) @@ -64,7 +64,7 @@ func NewHelpModal() *HelpModal { if len(keyText) < len(mostLengthyKey) { keyText = strings.Repeat(" ", len(mostLengthyKey)-len(keyText)) + keyText } - table.SetCell(rowCount+3+i, 0, tview.NewTableCell(keyText).SetAlign(tview.AlignRight).SetTextColor(tview.Styles.SecondaryTextColor)) + table.SetCell(rowCount+3+i, 0, tview.NewTableCell(keyText).SetAlign(tview.AlignRight).SetTextColor(app.Styles.SecondaryTextColor)) table.SetCell(rowCount+3+i, 1, tview.NewTableCell(key.Description).SetAlign(tview.AlignLeft).SetExpansion(1)) } @@ -75,7 +75,7 @@ func NewHelpModal() *HelpModal { table.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { command := app.Keymaps.Group(app.HomeGroup).Resolve(event) if command == commands.Quit || command == commands.HelpPopup { - MainPages.RemovePage(HelpPageName) + MainPages.RemovePage(pageNameHelp) } return event }) diff --git a/components/HelpStatus.go b/components/HelpStatus.go index dfa03cb..32dec87 100644 --- a/components/HelpStatus.go +++ b/components/HelpStatus.go @@ -12,7 +12,7 @@ type HelpStatus struct { } func NewHelpStatus() HelpStatus { - status := HelpStatus{tview.NewTextView().SetTextColor(tview.Styles.TertiaryTextColor)} + status := HelpStatus{tview.NewTextView().SetTextColor(app.Styles.TertiaryTextColor)} status.SetStatusOnTree() diff --git a/components/Home.go b/components/Home.go index cc8f07c..664ca9c 100644 --- a/components/Home.go +++ b/components/Home.go @@ -47,10 +47,10 @@ func NewHomePage(connection models.Connection, dbdriver drivers.Driver) *Home { go home.subscribeToTreeChanges() - leftWrapper.SetBorderColor(tview.Styles.InverseTextColor) + leftWrapper.SetBorderColor(app.Styles.InverseTextColor) leftWrapper.AddItem(tree.Wrapper, 0, 1, true) - rightWrapper.SetBorderColor(tview.Styles.InverseTextColor) + rightWrapper.SetBorderColor(app.Styles.InverseTextColor) rightWrapper.SetBorder(true) rightWrapper.SetDirection(tview.FlexColumnCSS) rightWrapper.SetInputCapture(home.rightWrapperInputCapture) @@ -66,7 +66,7 @@ func NewHomePage(connection models.Connection, dbdriver drivers.Driver) *Home { home.SetInputCapture(home.homeInputCapture) home.SetFocusFunc(func() { - if home.FocusedWrapper == "left" || home.FocusedWrapper == "" { + if home.FocusedWrapper == focusedWrapperLeft || home.FocusedWrapper == "" { home.focusLeftWrapper() } else { home.focusRightWrapper() @@ -82,7 +82,7 @@ func (home *Home) subscribeToTreeChanges() { for stateChange := range ch { switch stateChange.Key { - case "SelectedTable": + case eventTreeSelectedTable: databaseName := home.Tree.GetSelectedDatabase() tableName := stateChange.Value.(string) @@ -104,16 +104,20 @@ func (home *Home) subscribeToTreeChanges() { } - table.FetchRecords(func() { + results := table.FetchRecords(func() { home.focusLeftWrapper() }) + if len(results) > 1 && !table.GetShowSidebar() { // 1 because the row 0 is the column names + table.ShowSidebar(true) + } + if table.state.error == "" { home.focusRightWrapper() } app.App.ForceDraw() - case "IsFiltering": + case eventTreeIsFiltering: isFiltering := stateChange.Value.(bool) if isFiltering { home.SetInputCapture(nil) @@ -127,8 +131,8 @@ func (home *Home) subscribeToTreeChanges() { func (home *Home) focusRightWrapper() { home.Tree.RemoveHighlight() - home.RightWrapper.SetBorderColor(tview.Styles.PrimaryTextColor) - home.LeftWrapper.SetBorderColor(tview.Styles.InverseTextColor) + home.RightWrapper.SetBorderColor(app.Styles.PrimaryTextColor) + home.LeftWrapper.SetBorderColor(app.Styles.InverseTextColor) home.TabbedPane.Highlight() tab := home.TabbedPane.GetCurrentTab() @@ -136,7 +140,7 @@ func (home *Home) focusRightWrapper() { home.focusTab(tab) } - home.FocusedWrapper = "right" + home.FocusedWrapper = focusedWrapperRight } func (home *Home) focusTab(tab *Tab) { @@ -162,7 +166,7 @@ func (home *Home) focusTab(tab *Tab) { App.SetFocus(table) } - if tab.Name == EditorTabName { + if tab.Name == tabNameEditor { home.HelpStatus.SetStatusOnEditorView() } else { home.HelpStatus.SetStatusOnTableView() @@ -173,8 +177,8 @@ func (home *Home) focusTab(tab *Tab) { func (home *Home) focusLeftWrapper() { home.Tree.Highlight() - home.RightWrapper.SetBorderColor(tview.Styles.InverseTextColor) - home.LeftWrapper.SetBorderColor(tview.Styles.PrimaryTextColor) + home.RightWrapper.SetBorderColor(app.Styles.InverseTextColor) + home.LeftWrapper.SetBorderColor(app.Styles.PrimaryTextColor) tab := home.TabbedPane.GetCurrentTab() @@ -189,7 +193,7 @@ func (home *Home) focusLeftWrapper() { App.SetFocus(home.Tree) - home.FocusedWrapper = "left" + home.FocusedWrapper = focusedWrapperLeft } func (home *Home) rightWrapperInputCapture(event *tcell.EventKey) *tcell.EventKey { @@ -268,22 +272,22 @@ func (home *Home) homeInputCapture(event *tcell.EventKey) *tcell.EventKey { switch command { case commands.MoveLeft: - if table != nil && !table.GetIsEditing() && !table.GetIsFiltering() && home.FocusedWrapper == "right" { + if table != nil && !table.GetIsEditing() && !table.GetIsFiltering() && home.FocusedWrapper == focusedWrapperRight { home.focusLeftWrapper() } case commands.MoveRight: - if table != nil && !table.GetIsEditing() && !table.GetIsFiltering() && home.FocusedWrapper == "left" { + if table != nil && !table.GetIsEditing() && !table.GetIsFiltering() && home.FocusedWrapper == focusedWrapperLeft { home.focusRightWrapper() } case commands.SwitchToEditorView: - tab := home.TabbedPane.GetTabByName(EditorTabName) + tab := home.TabbedPane.GetTabByName(tabNameEditor) if tab != nil { - home.TabbedPane.SwitchToTabByName(EditorTabName) + home.TabbedPane.SwitchToTabByName(tabNameEditor) tab.Content.SetIsFiltering(true) } else { tableWithEditor := NewResultsTable(&home.ListOfDbChanges, home.Tree, home.DBDriver).WithEditor() - home.TabbedPane.AppendTab(EditorTabName, tableWithEditor, EditorTabName) + home.TabbedPane.AppendTab(tabNameEditor, tableWithEditor, tabNameEditor) tableWithEditor.SetIsFiltering(true) } home.HelpStatus.SetStatusOnEditorView() @@ -291,7 +295,7 @@ func (home *Home) homeInputCapture(event *tcell.EventKey) *tcell.EventKey { App.ForceDraw() case commands.SwitchToConnectionsView: if (table != nil && !table.GetIsEditing() && !table.GetIsFiltering() && !table.GetIsLoading()) || table == nil { - MainPages.SwitchToPage("Connections") + MainPages.SwitchToPage(pageNameConnections) } case commands.Quit: if tab != nil { @@ -308,7 +312,7 @@ func (home *Home) homeInputCapture(event *tcell.EventKey) *tcell.EventKey { confirmationModal := NewConfirmationModal("") confirmationModal.SetDoneFunc(func(_ int, buttonLabel string) { - MainPages.RemovePage("Confirmation") + MainPages.RemovePage(pageNameConfirmation) confirmationModal = nil if buttonLabel == "Yes" { @@ -328,7 +332,7 @@ func (home *Home) homeInputCapture(event *tcell.EventKey) *tcell.EventKey { } }) - MainPages.AddPage("Confirmation", confirmationModal, true, true) + MainPages.AddPage(pageNameConfirmation, confirmationModal, true, true) } case commands.HelpPopup: if table == nil || !table.GetIsEditing() { @@ -341,7 +345,7 @@ func (home *Home) homeInputCapture(event *tcell.EventKey) *tcell.EventKey { // } // return event // }) - MainPages.AddPage(HelpPageName, home.HelpModal, true, true) + MainPages.AddPage(pageNameHelp, home.HelpModal, true, true) } } diff --git a/components/Pages.go b/components/Pages.go index 257cf76..7bea031 100644 --- a/components/Pages.go +++ b/components/Pages.go @@ -2,11 +2,13 @@ package components import ( "github.com/rivo/tview" + + "github.com/jorgerojas26/lazysql/app" ) var MainPages = tview.NewPages() func init() { - MainPages.SetBackgroundColor(tview.Styles.PrimitiveBackgroundColor) - MainPages.AddPage("Connections", NewConnectionPages().Flex, true, true) + MainPages.SetBackgroundColor(app.Styles.PrimitiveBackgroundColor) + MainPages.AddPage(pageNameConnections, NewConnectionPages().Flex, true, true) } diff --git a/components/ResultTableFilter.go b/components/ResultTableFilter.go index bf526d5..002582d 100644 --- a/components/ResultTableFilter.go +++ b/components/ResultTableFilter.go @@ -4,6 +4,7 @@ import ( "github.com/gdamore/tcell/v2" "github.com/rivo/tview" + "github.com/jorgerojas26/lazysql/app" "github.com/jorgerojas26/lazysql/models" ) @@ -27,14 +28,14 @@ func NewResultsFilter() *ResultsTableFilter { recordsFilter.SetTitleAlign(tview.AlignCenter) recordsFilter.SetBorderPadding(0, 0, 1, 1) - recordsFilter.Label.SetTextColor(tview.Styles.TertiaryTextColor) + recordsFilter.Label.SetTextColor(app.Styles.TertiaryTextColor) recordsFilter.Label.SetText("WHERE") recordsFilter.Label.SetBorderPadding(0, 0, 0, 1) recordsFilter.Input.SetPlaceholder("Enter a WHERE clause to filter the results") - recordsFilter.Input.SetPlaceholderStyle(tcell.StyleDefault.Foreground(tview.Styles.PrimaryTextColor).Background(tview.Styles.PrimitiveBackgroundColor)) - recordsFilter.Input.SetFieldBackgroundColor(tview.Styles.PrimitiveBackgroundColor) - recordsFilter.Input.SetFieldTextColor(tview.Styles.PrimaryTextColor) + recordsFilter.Input.SetPlaceholderStyle(tcell.StyleDefault.Foreground(app.Styles.PrimaryTextColor).Background(tview.Styles.PrimitiveBackgroundColor)) + recordsFilter.Input.SetFieldBackgroundColor(app.Styles.PrimitiveBackgroundColor) + recordsFilter.Input.SetFieldTextColor(app.Styles.PrimaryTextColor) recordsFilter.Input.SetDoneFunc(func(key tcell.Key) { switch key { case tcell.KeyEnter: @@ -50,7 +51,7 @@ func NewResultsFilter() *ResultsTableFilter { } }) - recordsFilter.Input.SetAutocompleteStyles(tview.Styles.PrimitiveBackgroundColor, tcell.StyleDefault.Foreground(tview.Styles.PrimaryTextColor).Background(tview.Styles.PrimitiveBackgroundColor), tcell.StyleDefault.Foreground(tview.Styles.SecondaryTextColor).Background(tview.Styles.PrimitiveBackgroundColor)) + recordsFilter.Input.SetAutocompleteStyles(app.Styles.PrimitiveBackgroundColor, tcell.StyleDefault.Foreground(tview.Styles.PrimaryTextColor).Background(tview.Styles.PrimitiveBackgroundColor), tcell.StyleDefault.Foreground(tview.Styles.SecondaryTextColor).Background(tview.Styles.PrimitiveBackgroundColor)) recordsFilter.AddItem(recordsFilter.Label, 6, 0, false) recordsFilter.AddItem(recordsFilter.Input, 0, 1, false) @@ -67,7 +68,7 @@ func (filter *ResultsTableFilter) Subscribe() chan models.StateChange { func (filter *ResultsTableFilter) Publish(message string) { for _, sub := range filter.subscribers { sub <- models.StateChange{ - Key: "Filter", + Key: eventResultsTableFiltering, Value: message, } } @@ -87,29 +88,29 @@ func (filter *ResultsTableFilter) SetIsFiltering(filtering bool) { // Function to blur func (filter *ResultsTableFilter) RemoveHighlight() { - filter.SetBorderColor(tview.Styles.InverseTextColor) - filter.Label.SetTextColor(tview.Styles.InverseTextColor) - filter.Input.SetPlaceholderTextColor(tview.Styles.InverseTextColor) - filter.Input.SetFieldTextColor(tview.Styles.InverseTextColor) + filter.SetBorderColor(app.Styles.InverseTextColor) + filter.Label.SetTextColor(app.Styles.InverseTextColor) + filter.Input.SetPlaceholderTextColor(app.Styles.InverseTextColor) + filter.Input.SetFieldTextColor(app.Styles.InverseTextColor) } func (filter *ResultsTableFilter) RemoveLocalHighlight() { filter.SetBorderColor(tcell.ColorWhite) - filter.Label.SetTextColor(tview.Styles.TertiaryTextColor) - filter.Input.SetPlaceholderTextColor(tview.Styles.InverseTextColor) - filter.Input.SetFieldTextColor(tview.Styles.InverseTextColor) + filter.Label.SetTextColor(app.Styles.TertiaryTextColor) + filter.Input.SetPlaceholderTextColor(app.Styles.InverseTextColor) + filter.Input.SetFieldTextColor(app.Styles.InverseTextColor) } func (filter *ResultsTableFilter) Highlight() { filter.SetBorderColor(tcell.ColorWhite) - filter.Label.SetTextColor(tview.Styles.TertiaryTextColor) + filter.Label.SetTextColor(app.Styles.TertiaryTextColor) filter.Input.SetPlaceholderTextColor(tcell.ColorWhite) - filter.Input.SetFieldTextColor(tview.Styles.PrimaryTextColor) + filter.Input.SetFieldTextColor(app.Styles.PrimaryTextColor) } func (filter *ResultsTableFilter) HighlightLocal() { - filter.SetBorderColor(tview.Styles.PrimaryTextColor) - filter.Label.SetTextColor(tview.Styles.TertiaryTextColor) + filter.SetBorderColor(app.Styles.PrimaryTextColor) + filter.Label.SetTextColor(app.Styles.TertiaryTextColor) filter.Input.SetPlaceholderTextColor(tcell.ColorWhite) - filter.Input.SetFieldTextColor(tview.Styles.PrimaryTextColor) + filter.Input.SetFieldTextColor(app.Styles.PrimaryTextColor) } diff --git a/components/ResultsTable.go b/components/ResultsTable.go index f10e856..7651d2a 100644 --- a/components/ResultsTable.go +++ b/components/ResultsTable.go @@ -31,6 +31,7 @@ type ResultsTableState struct { isEditing bool isFiltering bool isLoading bool + showSidebar bool } type ResultsTable struct { @@ -47,16 +48,10 @@ type ResultsTable struct { EditorPages *tview.Pages ResultsInfo *tview.TextView Tree *Tree + Sidebar *Sidebar DBDriver drivers.Driver } -var ( - ErrorModal = tview.NewModal() - ChangeColor = tcell.ColorDarkOrange - InsertColor = tcell.ColorDarkGreen - DeleteColor = tcell.ColorRed -) - func NewResultsTable(listOfDbChanges *[]models.DbDmlChange, tree *Tree, dbdriver drivers.Driver) *ResultsTable { state := &ResultsTableState{ records: [][]string{}, @@ -67,6 +62,7 @@ func NewResultsTable(listOfDbChanges *[]models.DbDmlChange, tree *Tree, dbdriver isEditing: false, isLoading: false, listOfDbChanges: listOfDbChanges, + showSidebar: false, } wrapper := tview.NewFlex() @@ -76,22 +72,24 @@ func NewResultsTable(listOfDbChanges *[]models.DbDmlChange, tree *Tree, dbdriver errorModal.AddButtons([]string{"Ok"}) errorModal.SetText("An error occurred") errorModal.SetBackgroundColor(tcell.ColorRed) - errorModal.SetTextColor(tview.Styles.PrimaryTextColor) - errorModal.SetButtonStyle(tcell.StyleDefault.Foreground(tview.Styles.PrimaryTextColor)) + errorModal.SetTextColor(app.Styles.PrimaryTextColor) + errorModal.SetButtonStyle(tcell.StyleDefault.Foreground(app.Styles.PrimaryTextColor)) errorModal.SetFocus(0) loadingModal := tview.NewModal() loadingModal.SetText("Loading...") - loadingModal.SetBackgroundColor(tview.Styles.PrimitiveBackgroundColor) - loadingModal.SetTextColor(tview.Styles.SecondaryTextColor) + loadingModal.SetBackgroundColor(app.Styles.PrimitiveBackgroundColor) + loadingModal.SetTextColor(app.Styles.SecondaryTextColor) pages := tview.NewPages() - pages.AddPage("table", wrapper, true, true) - pages.AddPage("error", errorModal, true, false) - pages.AddPage("loading", loadingModal, false, false) + pages.AddPage(pageNameTable, wrapper, true, true) + pages.AddPage(pageNameTableError, errorModal, true, false) + pages.AddPage(pageNameTableLoading, loadingModal, false, false) pagination := NewPagination() + sidebar := NewSidebar() + table := &ResultsTable{ Table: tview.NewTable(), state: state, @@ -103,15 +101,25 @@ func NewResultsTable(listOfDbChanges *[]models.DbDmlChange, tree *Tree, dbdriver Editor: nil, Tree: tree, DBDriver: dbdriver, + Sidebar: sidebar, } table.SetSelectable(true, true) table.SetBorders(true) table.SetFixed(1, 0) table.SetInputCapture(table.tableInputCapture) - table.SetSelectedStyle(tcell.StyleDefault.Background(tview.Styles.SecondaryTextColor).Foreground(tview.Styles.ContrastSecondaryTextColor)) + table.SetSelectedStyle(tcell.StyleDefault.Background(app.Styles.SecondaryTextColor).Foreground(tview.Styles.ContrastSecondaryTextColor)) + table.Page.AddPage(pageNameSidebar, table.Sidebar, false, false) + + table.SetSelectionChangedFunc(func(row, col int) { + if table.GetShowSidebar() { + logger.Info("table.SetSelectionChangedFunc", map[string]any{"row": row, "col": col}) + go table.UpdateSidebar() + } + }) go table.subscribeToTreeChanges() + go table.subscribeToSidebarChanges() return table } @@ -159,12 +167,12 @@ func (table *ResultsTable) WithEditor() *ResultsTable { resultsInfoWrapper := tview.NewFlex().SetDirection(tview.FlexColumnCSS) resultsInfoText := tview.NewTextView() resultsInfoText.SetBorder(true) - resultsInfoText.SetBorderColor(tview.Styles.PrimaryTextColor) - resultsInfoText.SetTextColor(tview.Styles.PrimaryTextColor) + resultsInfoText.SetBorderColor(app.Styles.PrimaryTextColor) + resultsInfoText.SetTextColor(app.Styles.PrimaryTextColor) resultsInfoWrapper.AddItem(resultsInfoText, 3, 0, false) - editorPages.AddPage("Table", tableWrapper, true, false) - editorPages.AddPage("ResultsInfo", resultsInfoWrapper, true, true) + editorPages.AddPage(pageNameTableEditorTable, tableWrapper, true, false) + editorPages.AddPage(pageNameTableEditorResultsInfo, resultsInfoWrapper, true, true) table.EditorPages = editorPages table.ResultsInfo = resultsInfoText @@ -180,17 +188,59 @@ func (table *ResultsTable) subscribeToTreeChanges() { ch := table.Tree.Subscribe() for stateChange := range ch { - if stateChange.Key == "SelectedDatabase" { + if stateChange.Key == eventTreeSelectedDatabase { table.SetDatabaseName(stateChange.Value.(string)) } } } +func (table *ResultsTable) subscribeToSidebarChanges() { + ch := table.Sidebar.Subscribe() + + for stateChange := range ch { + switch stateChange.Key { + case eventSidebarEditing: + editing := stateChange.Value.(bool) + table.SetIsEditing(editing) + case eventSidebarUnfocusing: + App.SetFocus(table) + App.ForceDraw() + case eventSidebarToggling: + table.ShowSidebar(false) + App.ForceDraw() + case eventSidebarCommitEditing: + params := stateChange.Value.(models.SidebarEditingCommitParams) + + table.SetInputCapture(table.tableInputCapture) + table.SetIsEditing(false) + + row, _ := table.GetSelection() + changedColumnIndex := table.GetColumnIndexByName(params.ColumnName) + tableCell := table.GetCell(row, changedColumnIndex) + + tableCell.SetText(params.NewValue) + + cellValue := models.CellValue{ + Type: models.String, + Column: params.ColumnName, + Value: params.NewValue, + TableColumnIndex: changedColumnIndex, + TableRowIndex: row, + } + + logger.Info("eventSidebarCommitEditing", map[string]any{"cellValue": cellValue, "params": params, "rowIndex": row, "changedColumnIndex": changedColumnIndex}) + table.AppendNewChange(models.DmlUpdateType, row, changedColumnIndex, cellValue) + + App.ForceDraw() + } + } +} + func (table *ResultsTable) AddRows(rows [][]string) { for i, row := range rows { for j, cell := range row { tableCell := tview.NewTableCell(cell) - tableCell.SetTextColor(tview.Styles.PrimaryTextColor) + tableCell.SetTextColor(app.Styles.PrimaryTextColor) if cell == "EMPTY&" || cell == "NULL&" || cell == "DEFAULT&" { tableCell.SetText(strings.Replace(cell, "&", "", 1)) @@ -234,8 +284,8 @@ func (table *ResultsTable) AddInsertedRows() { tableCell.SetExpansion(1) tableCell.SetReference(inserts[i].PrimaryKeyValue) - tableCell.SetTextColor(tview.Styles.PrimaryTextColor) - tableCell.SetBackgroundColor(InsertColor) + tableCell.SetTextColor(app.Styles.PrimaryTextColor) + tableCell.SetBackgroundColor(colorTableInsert) table.SetCell(rowIndex, j, tableCell) } @@ -247,17 +297,18 @@ func (table *ResultsTable) AppendNewRow(cells []models.CellValue, index int, UUI tableCell := tview.NewTableCell(cell.Value.(string)) tableCell.SetExpansion(1) tableCell.SetReference(UUID) - tableCell.SetTextColor(tview.Styles.PrimaryTextColor) + tableCell.SetTextColor(app.Styles.PrimaryTextColor) + tableCell.SetBackgroundColor(tcell.ColorDarkGreen) switch cell.Type { case models.Null, models.Empty, models.Default: tableCell.SetText(strings.Replace(cell.Value.(string), "&", "", 1)) tableCell.SetStyle(table.GetItalicStyle()) // tableCell.SetText("") - tableCell.SetTextColor(tview.Styles.InverseTextColor) + tableCell.SetTextColor(app.Styles.InverseTextColor) } - tableCell.SetBackgroundColor(InsertColor) + tableCell.SetBackgroundColor(colorTableInsert) table.SetCell(index, i, tableCell) } @@ -315,7 +366,11 @@ func (table *ResultsTable) tableInputCapture(event *tcell.EventKey) *tcell.Event } if command == commands.Edit { - table.StartEditingCell(selectedRowIndex, selectedColumnIndex, nil) + table.StartEditingCell(selectedRowIndex, selectedColumnIndex, func(_ string, _, _ int) { + if table.GetShowSidebar() { + table.UpdateSidebar() + } + }) } else if command == commands.GotoNext { if selectedColumnIndex+1 < colCount { table.Select(selectedRowIndex, selectedColumnIndex+1) @@ -369,7 +424,7 @@ func (table *ResultsTable) tableInputCapture(event *tcell.EventKey) *tcell.Event } } } else { - table.AppendNewChange(models.DmlDeleteType, table.GetDatabaseName(), table.GetTableName(), selectedRowIndex, -1, models.CellValue{}) + table.AppendNewChange(models.DmlDeleteType, selectedRowIndex, -1, models.CellValue{}) } } @@ -386,11 +441,17 @@ func (table *ResultsTable) tableInputCapture(event *tcell.EventKey) *tcell.Event table.FinishSettingValue() if selection >= 0 { - table.AppendNewChange(models.DmlUpdateType, table.Tree.GetSelectedDatabase(), table.Tree.GetSelectedTable(), selectedRowIndex, selectedColumnIndex, models.CellValue{Type: selection, Value: value, Column: table.GetColumnNameByIndex(selectedColumnIndex)}) + table.AppendNewChange(models.DmlUpdateType, selectedRowIndex, selectedColumnIndex, models.CellValue{Type: selection, Value: value, Column: table.GetColumnNameByIndex(selectedColumnIndex)}) } }) list.Show(x, y, 30) + } else if command == commands.ToggleSidebar { + table.ShowSidebar(!table.GetShowSidebar()) + } else if command == commands.FocusSidebar { + if table.GetShowSidebar() { + App.SetFocus(table.Sidebar) + } } if len(table.GetRecords()) > 0 { @@ -437,7 +498,7 @@ func (table *ResultsTable) UpdateRowsColor(headerColor tcell.Color, rowColor tce } else { cellReference := cell.GetReference() - if cellReference != nil && (cellReference == "EMPTY&" || cellReference == "NULL&" || cellReference == "DEFAULT&") && (cell.BackgroundColor != DeleteColor && cell.BackgroundColor != ChangeColor && cell.BackgroundColor != InsertColor) { + if cellReference != nil && (cellReference == "EMPTY&" || cellReference == "NULL&" || cellReference == "DEFAULT&") && (cell.BackgroundColor != colorTableDelete && cell.BackgroundColor != colorTableChange && cell.BackgroundColor != colorTableInsert) { cell.SetStyle(table.GetItalicStyle()) } else { cell.SetTextColor(rowColor) @@ -448,10 +509,10 @@ func (table *ResultsTable) UpdateRowsColor(headerColor tcell.Color, rowColor tce } func (table *ResultsTable) RemoveHighlightTable() { - table.SetBorderColor(tview.Styles.InverseTextColor) - table.SetBordersColor(tview.Styles.InverseTextColor) - table.SetTitleColor(tview.Styles.InverseTextColor) - table.UpdateRowsColor(tview.Styles.InverseTextColor, tview.Styles.InverseTextColor) + table.SetBorderColor(app.Styles.InverseTextColor) + table.SetBordersColor(app.Styles.InverseTextColor) + table.SetTitleColor(app.Styles.InverseTextColor) + table.UpdateRowsColor(app.Styles.InverseTextColor, tview.Styles.InverseTextColor) } func (table *ResultsTable) RemoveHighlightAll() { @@ -465,10 +526,10 @@ func (table *ResultsTable) RemoveHighlightAll() { } func (table *ResultsTable) HighlightTable() { - table.SetBorderColor(tview.Styles.PrimaryTextColor) - table.SetBordersColor(tview.Styles.PrimaryTextColor) - table.SetTitleColor(tview.Styles.PrimaryTextColor) - table.UpdateRowsColor(tview.Styles.PrimaryTextColor, tview.Styles.PrimaryTextColor) + table.SetBorderColor(app.Styles.PrimaryTextColor) + table.SetBordersColor(app.Styles.PrimaryTextColor) + table.SetTitleColor(app.Styles.PrimaryTextColor) + table.UpdateRowsColor(app.Styles.PrimaryTextColor, tview.Styles.PrimaryTextColor) } func (table *ResultsTable) HighlightAll() { @@ -486,7 +547,7 @@ func (table *ResultsTable) subscribeToFilterChanges() { for stateChange := range ch { switch stateChange.Key { - case "Filter": + case eventResultsTableFiltering: if stateChange.Value != "" { rows := table.FetchRecords(nil) @@ -525,7 +586,7 @@ func (table *ResultsTable) subscribeToEditorChanges() { for stateChange := range ch { switch stateChange.Key { - case "Query": + case eventSQLEditorQuery: query := stateChange.Value.(string) if query != "" { queryLower := strings.ToLower(query) @@ -562,7 +623,7 @@ func (table *ResultsTable) subscribeToEditorChanges() { } table.SetLoading(false) } - table.EditorPages.SwitchToPage("Table") + table.EditorPages.SwitchToPage(pageNameTable) App.Draw() } else { table.SetRecords([][]string{}) @@ -578,13 +639,13 @@ func (table *ResultsTable) subscribeToEditorChanges() { } else { table.SetResultsInfo(result) table.SetLoading(false) - table.EditorPages.SwitchToPage("ResultsInfo") + table.EditorPages.SwitchToPage(pageNameTableEditorResultsInfo) App.SetFocus(table.Editor) App.Draw() } } } - case "Escape": + case eventSQLEditorEscape: table.SetIsFiltering(false) App.SetFocus(table) table.HighlightTable() @@ -649,6 +710,17 @@ func (table *ResultsTable) GetColumnNameByIndex(index int) string { return "" } +func (table *ResultsTable) GetColumnIndexByName(columnName string) int { + for i := 0; i < table.GetColumnCount(); i++ { + cell := table.GetCell(0, i) + if cell.Text == columnName { + return i + } + } + + return -1 +} + func (table *ResultsTable) GetIsLoading() bool { return table.state.isLoading } @@ -657,6 +729,10 @@ func (table *ResultsTable) GetIsFiltering() bool { return table.state.isFiltering } +func (table *ResultsTable) GetShowSidebar() bool { + return table.state.showSidebar +} + // Setters func (table *ResultsTable) SetRecords(rows [][]string) { @@ -694,7 +770,7 @@ func (table *ResultsTable) SetError(err string, done func()) { table.Error.SetText(err) table.Error.SetDoneFunc(func(_ int, _ string) { table.state.error = "" - table.Page.HidePage("error") + table.Page.HidePage(pageNameTableError) if table.GetIsFiltering() { if table.Editor != nil { App.SetFocus(table.Editor) @@ -708,7 +784,7 @@ func (table *ResultsTable) SetError(err string, done func()) { done() } }) - table.Page.ShowPage("error") + table.Page.ShowPage(pageNameTableError) App.SetFocus(table.Error) App.ForceDraw() } @@ -721,7 +797,7 @@ func (table *ResultsTable) SetLoading(show bool) { defer func() { if r := recover(); r != nil { logger.Error("ResultsTable.go:800 => Recovered from panic", map[string]any{"error": r}) - _ = table.Page.HidePage("loading") + _ = table.Page.HidePage(pageNameTableLoading) if table.state.error != "" { App.SetFocus(table.Error) } else { @@ -732,11 +808,11 @@ func (table *ResultsTable) SetLoading(show bool) { table.state.isLoading = show if show { - table.Page.ShowPage("loading") + table.Page.ShowPage(pageNameTableLoading) App.SetFocus(table.Loading) App.ForceDraw() } else { - table.Page.HidePage("loading") + table.Page.HidePage(pageNameTableLoading) if table.state.error != "" { App.SetFocus(table.Error) } else { @@ -791,7 +867,7 @@ func (table *ResultsTable) SetSortedBy(column string, direction string) { tableCell := tview.NewTableCell(col[0]) tableCell.SetSelectable(false) tableCell.SetExpansion(1) - tableCell.SetTextColor(tview.Styles.PrimaryTextColor) + tableCell.SetTextColor(app.Styles.PrimaryTextColor) if col[0] == column { tableCell.SetText(fmt.Sprintf("%s %s", col[0], iconDirection)) @@ -858,8 +934,8 @@ func (table *ResultsTable) StartEditingCell(row int, col int, callback func(newV cell := table.GetCell(row, col) inputField := tview.NewInputField() inputField.SetText(cell.Text) - inputField.SetFieldBackgroundColor(tview.Styles.PrimaryTextColor) - inputField.SetFieldTextColor(tview.Styles.PrimitiveBackgroundColor) + inputField.SetFieldBackgroundColor(app.Styles.PrimaryTextColor) + inputField.SetFieldTextColor(app.Styles.PrimitiveBackgroundColor) inputField.SetDoneFunc(func(key tcell.Key) { table.SetIsEditing(false) @@ -871,7 +947,7 @@ func (table *ResultsTable) StartEditingCell(row int, col int, callback func(newV cell.SetText(newValue) if currentValue != newValue { - table.AppendNewChange(models.DmlUpdateType, table.GetDatabaseName(), table.GetTableName(), row, col, models.CellValue{Type: models.String, Value: newValue, Column: columnName}) + table.AppendNewChange(models.DmlUpdateType, row, col, models.CellValue{Type: models.String, Value: newValue, Column: columnName, TableColumnIndex: col, TableRowIndex: row}) } switch key { @@ -899,7 +975,7 @@ func (table *ResultsTable) StartEditingCell(row int, col int, callback func(newV if key == tcell.KeyEnter || key == tcell.KeyEscape { table.SetInputCapture(table.tableInputCapture) - table.Page.RemovePage("edit") + table.Page.RemovePage(pageNameTableEditCell) App.SetFocus(table) } @@ -910,7 +986,7 @@ func (table *ResultsTable) StartEditingCell(row int, col int, callback func(newV x, y, width := cell.GetLastPosition() inputField.SetRect(x, y, width+1, 1) - table.Page.AddPage("edit", inputField, false, true) + table.Page.AddPage(pageNameTableEditCell, inputField, false, true) App.SetFocus(inputField) } @@ -937,7 +1013,10 @@ func (table *ResultsTable) MutateInsertedRowCell(rowID string, newValue models.C } } -func (table *ResultsTable) AppendNewChange(changeType models.DmlType, databaseName, tableName string, rowIndex int, colIndex int, value models.CellValue) { +func (table *ResultsTable) AppendNewChange(changeType models.DmlType, rowIndex int, colIndex int, value models.CellValue) { + databaseName := table.GetDatabaseName() + tableName := table.GetTableName() + dmlChangeAlreadyExists := false // If the column has a reference, it means it's an inserted rowIndex @@ -989,18 +1068,18 @@ func (table *ResultsTable) AppendNewChange(changeType models.DmlType, databaseNa } else { (*table.state.listOfDbChanges)[i].Values = append((*table.state.listOfDbChanges)[i].Values[:valueIndex], (*table.state.listOfDbChanges)[i].Values[valueIndex+1:]...) } - table.SetCellColor(rowIndex, colIndex, tview.Styles.PrimitiveBackgroundColor) + table.SetCellColor(rowIndex, colIndex, app.Styles.PrimitiveBackgroundColor) } else { (*table.state.listOfDbChanges)[i].Values[valueIndex] = value } } else { (*table.state.listOfDbChanges)[i].Values = append((*table.state.listOfDbChanges)[i].Values, value) - table.SetCellColor(rowIndex, colIndex, ChangeColor) + table.SetCellColor(rowIndex, colIndex, colorTableChange) } case models.DmlDeleteType: *table.state.listOfDbChanges = append((*table.state.listOfDbChanges)[:i], (*table.state.listOfDbChanges)[i+1:]...) - table.SetRowColor(rowIndex, tview.Styles.PrimitiveBackgroundColor) + table.SetRowColor(rowIndex, app.Styles.PrimitiveBackgroundColor) } } } @@ -1009,10 +1088,10 @@ func (table *ResultsTable) AppendNewChange(changeType models.DmlType, databaseNa switch changeType { case models.DmlDeleteType: - table.SetRowColor(rowIndex, DeleteColor) + table.SetRowColor(rowIndex, colorTableDelete) case models.DmlUpdateType: - tableCell.SetStyle(tcell.StyleDefault.Background(ChangeColor)) - table.SetCellColor(rowIndex, colIndex, ChangeColor) + tableCell.SetStyle(tcell.StyleDefault.Background(colorTableChange)) + table.SetCellColor(rowIndex, colIndex, colorTableChange) } newDmlChange := models.DbDmlChange{ @@ -1027,6 +1106,8 @@ func (table *ResultsTable) AppendNewChange(changeType models.DmlType, databaseNa *table.state.listOfDbChanges = append(*table.state.listOfDbChanges, newDmlChange) } + + logger.Info("AppendNewChange", map[string]any{"listOfDbChanges": *table.state.listOfDbChanges}) } func (table *ResultsTable) GetPrimaryKeyValue(rowIndex int) (string, string) { @@ -1038,7 +1119,7 @@ func (table *ResultsTable) GetPrimaryKeyValue(rowIndex int) (string, string) { primaryKeyValue := "" switch provider { - case "mysql": + case drivers.DriverMySQL: keyColumnIndex := -1 primaryKeyColumnIndex := -1 @@ -1059,7 +1140,7 @@ func (table *ResultsTable) GetPrimaryKeyValue(rowIndex int) (string, string) { primaryKeyValue = table.GetRecords()[rowIndex][primaryKeyColumnIndex] } - case "postgres": + case drivers.DriverPostgres: keyColumnIndex := -1 constraintTypeColumnIndex := -1 constraintNameColumnIndex := -1 @@ -1101,7 +1182,7 @@ func (table *ResultsTable) GetPrimaryKeyValue(rowIndex int) (string, string) { primaryKeyValue = table.GetRecords()[rowIndex][primaryKeyColumnIndex] } - case "sqlite3": + case drivers.DriverSqlite: keyColumnIndex := -1 primaryKeyColumnIndex := -1 @@ -1144,7 +1225,7 @@ func (table *ResultsTable) appendNewRow() { for i, column := range dbColumns { if i != 0 { // Skip the first row because they are the column names (e.x "Field", "Type", "Null", "Key", "Default", "Extra") - newRow[i-1] = models.CellValue{Type: models.Default, Column: column[0], Value: "DEFAULT"} + newRow[i-1] = models.CellValue{Type: models.Default, Column: column[0], Value: "DEFAULT", TableRowIndex: newRowTableIndex, TableColumnIndex: i} } } @@ -1275,3 +1356,68 @@ func (table *ResultsTable) FinishSettingValue() { func (table *ResultsTable) GetItalicStyle() tcell.Style { return tcell.StyleDefault.Foreground(tview.Styles.InverseTextColor).Italic(true) } + +func (table *ResultsTable) ShowSidebar(show bool) { + table.state.showSidebar = show + + if show { + table.UpdateSidebar() + table.Page.SendToFront(pageNameSidebar) + table.Page.ShowPage(pageNameSidebar) + } else { + table.Page.HidePage(pageNameSidebar) + App.SetFocus(table) + } +} + +func (table *ResultsTable) UpdateSidebar() { + columns := table.GetColumns() + selectedRow, _ := table.GetSelection() + + if selectedRow > 0 { + tableX, _, _, tableHeight := table.GetRect() + _, _, tableInnerWidth, _ := table.GetInnerRect() + _, tableMenuY, _, tableMenuHeight := table.Menu.GetRect() + _, _, _, tableFilterHeight := table.Filter.GetRect() + _, _, _, tablePaginationHeight := table.Pagination.GetRect() + + sidebarWidth := (tableInnerWidth / 4) + sidebarHeight := tableHeight + tableMenuHeight + tableFilterHeight + tablePaginationHeight + 1 + + table.Sidebar.SetRect(tableX+tableInnerWidth-sidebarWidth, tableMenuY, sidebarWidth, sidebarHeight) + table.Sidebar.Clear() + + for i := 1; i < len(columns); i++ { + name := columns[i][0] + colType := columns[i][1] + + text := table.GetCell(selectedRow, i-1).Text + title := name + + repeatCount := sidebarWidth - len(name) - len(colType) - 4 // idk why 4 is needed, but it works. + + if repeatCount <= 0 { + repeatCount = 1 + } + + title += fmt.Sprintf("[%s]", app.Styles.SidebarTitleBorderColor) + strings.Repeat("-", repeatCount) + title += colType + + pendingEditExist := false + + for _, dmlChange := range *table.state.listOfDbChanges { + if dmlChange.Type == models.DmlUpdateType { + for _, v := range dmlChange.Values { + if v.Column == name && v.TableRowIndex == selectedRow && v.TableColumnIndex == i-1 { + pendingEditExist = true + break + } + } + } + } + + table.Sidebar.AddField(title, text, sidebarWidth, pendingEditExist) + } + + } +} diff --git a/components/ResultsTableMenu.go b/components/ResultsTableMenu.go index c907fca..e37ba9b 100644 --- a/components/ResultsTableMenu.go +++ b/components/ResultsTableMenu.go @@ -4,6 +4,8 @@ import ( "fmt" "github.com/rivo/tview" + + "github.com/jorgerojas26/lazysql/app" ) type ResultsTableMenuState struct { @@ -17,11 +19,11 @@ type ResultsTableMenu struct { } var menuItems = []string{ - "Records", - "Columns", - "Constraints", - "Foreign Keys", - "Indexes", + menuRecords, + menuColumns, + menuConstraints, + menuForeignKeys, + menuIndexes, } func NewResultsTableMenu() *ResultsTableMenu { @@ -46,17 +48,17 @@ func NewResultsTableMenu() *ResultsTableMenu { textview := tview.NewTextView().SetText(text) if i == 0 { - textview.SetTextColor(tview.Styles.PrimaryTextColor) + textview.SetTextColor(app.Styles.PrimaryTextColor) } size := 15 switch item { - case "Constraints": + case menuConstraints: size = 19 - case "Foreign Keys": + case menuForeignKeys: size = 20 - case "Indexes": + case menuIndexes: size = 16 } @@ -80,29 +82,29 @@ func (menu *ResultsTableMenu) SetSelectedOption(option int) { itemCount := menu.GetItemCount() for i := 0; i < itemCount; i++ { - menu.GetItem(i).(*tview.TextView).SetTextColor(tview.Styles.PrimaryTextColor) + menu.GetItem(i).(*tview.TextView).SetTextColor(app.Styles.PrimaryTextColor) } - menu.GetItem(option - 1).(*tview.TextView).SetTextColor(tview.Styles.SecondaryTextColor) + menu.GetItem(option - 1).(*tview.TextView).SetTextColor(app.Styles.SecondaryTextColor) } } func (menu *ResultsTableMenu) SetBlur() { - menu.SetBorderColor(tview.Styles.InverseTextColor) + menu.SetBorderColor(app.Styles.InverseTextColor) for _, item := range menu.MenuItems { - item.SetTextColor(tview.Styles.InverseTextColor) + item.SetTextColor(app.Styles.InverseTextColor) } } func (menu *ResultsTableMenu) SetFocus() { - menu.SetBorderColor(tview.Styles.PrimaryTextColor) + menu.SetBorderColor(app.Styles.PrimaryTextColor) for i, item := range menu.MenuItems { if i+1 == menu.GetSelectedOption() { - item.SetTextColor(tview.Styles.SecondaryTextColor) + item.SetTextColor(app.Styles.SecondaryTextColor) } else { - item.SetTextColor(tview.Styles.PrimaryTextColor) + item.SetTextColor(app.Styles.PrimaryTextColor) } } } diff --git a/components/SQLEditor.go b/components/SQLEditor.go index b3e4389..3302259 100644 --- a/components/SQLEditor.go +++ b/components/SQLEditor.go @@ -38,10 +38,10 @@ func NewSQLEditor() *SQLEditor { command := app.Keymaps.Group(app.EditorGroup).Resolve(event) if command == commands.Execute { - sqlEditor.Publish("Query", sqlEditor.GetText()) + sqlEditor.Publish(eventSQLEditorQuery, sqlEditor.GetText()) return nil } else if command == commands.UnfocusEditor { - sqlEditor.Publish("Escape", "") + sqlEditor.Publish(eventSQLEditorEscape, "") } else if command == commands.OpenInExternalEditor && runtime.GOOS == "linux" { // ----- THIS IS A LINUX-ONLY FEATURE, for now @@ -79,13 +79,13 @@ func (s *SQLEditor) SetIsFocused(isFocused bool) { } func (s *SQLEditor) Highlight() { - s.SetBorderColor(tview.Styles.PrimaryTextColor) - s.SetTextStyle(tcell.StyleDefault.Foreground(tview.Styles.PrimaryTextColor)) + s.SetBorderColor(app.Styles.PrimaryTextColor) + s.SetTextStyle(tcell.StyleDefault.Foreground(app.Styles.PrimaryTextColor)) } func (s *SQLEditor) SetBlur() { - s.SetBorderColor(tview.Styles.InverseTextColor) - s.SetTextStyle(tcell.StyleDefault.Foreground(tview.Styles.InverseTextColor)) + s.SetBorderColor(app.Styles.InverseTextColor) + s.SetTextStyle(tcell.StyleDefault.Foreground(app.Styles.InverseTextColor)) } /* diff --git a/components/Sidebar.go b/components/Sidebar.go new file mode 100644 index 0000000..40211df --- /dev/null +++ b/components/Sidebar.go @@ -0,0 +1,319 @@ +package components + +import ( + "strings" + + "github.com/gdamore/tcell/v2" + "github.com/rivo/tview" + + "github.com/jorgerojas26/lazysql/app" + "github.com/jorgerojas26/lazysql/commands" + "github.com/jorgerojas26/lazysql/models" +) + +type SidebarState struct { + currentFieldIndex int +} + +type SidebarFieldParameters struct { + OriginalValue string + Height int +} + +type Sidebar struct { + *tview.Frame + Flex *tview.Flex + state *SidebarState + FieldParameters []*SidebarFieldParameters + subscribers []chan models.StateChange +} + +func NewSidebar() *Sidebar { + flex := tview.NewFlex().SetDirection(tview.FlexColumnCSS) + frame := tview.NewFrame(flex) + frame.SetBackgroundColor(app.Styles.PrimitiveBackgroundColor) + frame.SetBorder(true) + frame.SetBorders(0, 0, 0, 0, 0, 0) + + sidebarState := &SidebarState{ + currentFieldIndex: 0, + } + + newSidebar := &Sidebar{ + Frame: frame, + Flex: flex, + state: sidebarState, + subscribers: []chan models.StateChange{}, + } + + newSidebar.SetInputCapture(newSidebar.inputCapture) + + newSidebar.SetBlurFunc(func() { + newSidebar.SetCurrentFieldIndex(0) + }) + + return newSidebar +} + +func (sidebar *Sidebar) AddField(title, text string, fieldWidth int, pendingEdit bool) { + field := tview.NewTextArea() + field.SetWrap(true) + field.SetDisabled(true) + + field.SetBorder(true) + field.SetTitle(title) + field.SetTitleAlign(tview.AlignLeft) + field.SetTitleColor(app.Styles.PrimaryTextColor) + field.SetText(text, true) + field.SetTextStyle(tcell.StyleDefault.Background(app.Styles.PrimitiveBackgroundColor).Foreground(tview.Styles.SecondaryTextColor)) + + if pendingEdit { + sidebar.SetEditedStyles(field) + } + + textLength := len(field.GetText()) + + itemFixedSize := 3 + + if textLength >= fieldWidth*3 { + itemFixedSize = 5 + } else if textLength >= fieldWidth { + itemFixedSize = 4 + } else { + field.SetWrap(false) + } + + field.SetFocusFunc(func() { + _, y, _, h := field.GetRect() + _, _, _, mph := sidebar.GetRect() + + if y >= mph { + hidingFieldIndex := 0 + fieldCount := sidebar.Flex.GetItemCount() + + for i := 0; i < fieldCount; i++ { + f := sidebar.Flex.GetItem(i) + _, _, _, h := f.GetRect() + if h != 0 { + hidingFieldIndex = i + break + } + } + + sidebar.Flex.ResizeItem(sidebar.Flex.GetItem(hidingFieldIndex), 0, 0) + } else if h == 0 { + sidebar.Flex.ResizeItem(field, itemFixedSize, 0) + } + }) + + fieldParameters := &SidebarFieldParameters{ + Height: itemFixedSize, + OriginalValue: text, + } + + sidebar.FieldParameters = append(sidebar.FieldParameters, fieldParameters) + sidebar.Flex.AddItem(field, itemFixedSize, 0, true) +} + +func (sidebar *Sidebar) FocusNextField() { + newIndex := sidebar.GetCurrentFieldIndex() + 1 + + if newIndex >= sidebar.Flex.GetItemCount() { + return + } + + item := sidebar.Flex.GetItem(newIndex) + + if item == nil { + return + } + + sidebar.SetCurrentFieldIndex(newIndex) + App.SetFocus(item) + App.ForceDraw() +} + +func (sidebar *Sidebar) FocusPreviousField() { + newIndex := sidebar.GetCurrentFieldIndex() - 1 + + if newIndex < 0 { + return + } + + item := sidebar.Flex.GetItem(newIndex) + + if item == nil { + return + } + + sidebar.SetCurrentFieldIndex(newIndex) + App.SetFocus(item) + App.ForceDraw() +} + +func (sidebar *Sidebar) FocusFirstField() { + sidebar.SetCurrentFieldIndex(0) + App.SetFocus(sidebar.Flex.GetItem(0)) + + fieldCount := sidebar.Flex.GetItemCount() + + for i := 0; i < fieldCount; i++ { + field := sidebar.Flex.GetItem(i) + height := sidebar.FieldParameters[i].Height + sidebar.Flex.ResizeItem(field, height, 0) + } +} + +func (sidebar *Sidebar) FocusLastField() { + newIndex := sidebar.Flex.GetItemCount() - 1 + sidebar.SetCurrentFieldIndex(newIndex) + App.SetFocus(sidebar.Flex.GetItem(newIndex)) + + _, _, _, ph := sidebar.GetRect() + + hSum := 0 + + for i := sidebar.Flex.GetItemCount() - 1; i >= 0; i-- { + field := sidebar.Flex.GetItem(i).(*tview.TextArea) + _, _, _, h := field.GetRect() + + hSum += h + + if hSum >= ph { + sidebar.Flex.ResizeItem(field, 0, 0) + } + } +} + +func (sidebar *Sidebar) FocusField(index int) { + sidebar.SetCurrentFieldIndex(index) + App.SetFocus(sidebar.Flex.GetItem(index)) +} + +func (sidebar *Sidebar) Clear() { + sidebar.FieldParameters = make([]*SidebarFieldParameters, 0) + sidebar.Flex.Clear() +} + +func (sidebar *Sidebar) EditTextCurrentField() { + index := sidebar.GetCurrentFieldIndex() + item := sidebar.Flex.GetItem(index).(*tview.TextArea) + + sidebar.SetEditingStyles(item) +} + +func (sidebar *Sidebar) inputCapture(event *tcell.EventKey) *tcell.EventKey { + command := app.Keymaps.Group(app.SidebarGroup).Resolve(event) + + switch command { + case commands.UnfocusSidebar: + sidebar.Publish(models.StateChange{Key: eventSidebarUnfocusing, Value: nil}) + case commands.ToggleSidebar: + sidebar.Publish(models.StateChange{Key: eventSidebarToggling, Value: nil}) + case commands.MoveDown: + sidebar.FocusNextField() + case commands.MoveUp: + sidebar.FocusPreviousField() + case commands.GotoStart: + sidebar.FocusFirstField() + case commands.GotoEnd: + sidebar.FocusLastField() + case commands.Edit: + sidebar.Publish(models.StateChange{Key: eventSidebarEditing, Value: true}) + + currentItemIndex := sidebar.GetCurrentFieldIndex() + item := sidebar.Flex.GetItem(currentItemIndex).(*tview.TextArea) + text := item.GetText() + + columnName := item.GetTitle() + columnNameSplit := strings.Split(columnName, "[") + columnName = columnNameSplit[0] + + sidebar.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { + command := app.Keymaps.Group(app.SidebarGroup).Resolve(event) + + switch command { + case commands.CommitEdit: + sidebar.SetInputCapture(sidebar.inputCapture) + originalValue := sidebar.FieldParameters[currentItemIndex].OriginalValue + newText := item.GetText() + + if originalValue == newText { + sidebar.SetDisabledStyles(item) + } else { + sidebar.SetEditedStyles(item) + sidebar.Publish(models.StateChange{Key: eventSidebarCommitEditing, Value: models.SidebarEditingCommitParams{ColumnName: columnName, NewValue: newText}}) + } + + return nil + case commands.DiscardEdit: + sidebar.SetInputCapture(sidebar.inputCapture) + sidebar.SetDisabledStyles(item) + item.SetText(text, true) + sidebar.Publish(models.StateChange{Key: eventSidebarEditing, Value: false}) + return nil + } + + return event + }) + + sidebar.EditTextCurrentField() + + return nil + } + return event +} + +func (sidebar *Sidebar) SetEditingStyles(item *tview.TextArea) { + item.SetBackgroundColor(app.Styles.SecondaryTextColor) + item.SetTextStyle(tcell.StyleDefault.Background(app.Styles.SecondaryTextColor).Foreground(tview.Styles.ContrastSecondaryTextColor)) + item.SetTitleColor(app.Styles.ContrastSecondaryTextColor) + item.SetBorderColor(app.Styles.SecondaryTextColor) + + item.SetWrap(true) + item.SetDisabled(false) +} + +func (sidebar *Sidebar) SetDisabledStyles(item *tview.TextArea) { + item.SetBackgroundColor(app.Styles.PrimitiveBackgroundColor) + item.SetTextStyle(tcell.StyleDefault.Background(app.Styles.PrimitiveBackgroundColor).Foreground(tview.Styles.SecondaryTextColor)) + item.SetTitleColor(app.Styles.PrimaryTextColor) + item.SetBorderColor(app.Styles.BorderColor) + + item.SetWrap(true) + item.SetDisabled(true) +} + +func (sidebar *Sidebar) SetEditedStyles(item *tview.TextArea) { + item.SetBackgroundColor(colorTableChange) + item.SetTextStyle(tcell.StyleDefault.Background(colorTableChange).Foreground(tview.Styles.ContrastSecondaryTextColor)) + item.SetTitleColor(app.Styles.ContrastSecondaryTextColor) + item.SetBorderColor(app.Styles.ContrastSecondaryTextColor) + + item.SetWrap(true) + item.SetDisabled(true) +} + +// Getters +func (sidebar *Sidebar) GetCurrentFieldIndex() int { + return sidebar.state.currentFieldIndex +} + +// Setters +func (sidebar *Sidebar) SetCurrentFieldIndex(index int) { + sidebar.state.currentFieldIndex = index +} + +// Subscribe to changes in the sidebar state +func (sidebar *Sidebar) Subscribe() chan models.StateChange { + subscriber := make(chan models.StateChange) + sidebar.subscribers = append(sidebar.subscribers, subscriber) + return subscriber +} + +// Publish subscribers of changes in the sidebar state +func (sidebar *Sidebar) Publish(change models.StateChange) { + for _, subscriber := range sidebar.subscribers { + subscriber <- change + } +} diff --git a/components/TabbedMenu.go b/components/TabbedMenu.go index 75dfd0f..2ee7bf7 100644 --- a/components/TabbedMenu.go +++ b/components/TabbedMenu.go @@ -241,9 +241,9 @@ func (t *TabbedPane) HighlightTabHeader(tab *Tab) { for i := 0; tabToHighlight != nil && i < t.state.Length; i++ { if tabToHighlight.Header == tab.Header { - tabToHighlight.Header.SetTextColor(tview.Styles.SecondaryTextColor) + tabToHighlight.Header.SetTextColor(app.Styles.SecondaryTextColor) } else { - tabToHighlight.Header.SetTextColor(tview.Styles.PrimaryTextColor) + tabToHighlight.Header.SetTextColor(app.Styles.PrimaryTextColor) } tabToHighlight = tabToHighlight.NextTab } @@ -254,9 +254,9 @@ func (t *TabbedPane) Highlight() { for i := 0; tab != nil && i < t.state.Length; i++ { if tab == t.state.CurrentTab { - tab.Header.SetTextColor(tview.Styles.SecondaryTextColor) + tab.Header.SetTextColor(app.Styles.SecondaryTextColor) } else { - tab.Header.SetTextColor(tview.Styles.PrimaryTextColor) + tab.Header.SetTextColor(app.Styles.PrimaryTextColor) } tab = tab.NextTab } @@ -266,7 +266,7 @@ func (t *TabbedPane) SetBlur() { tab := t.state.FirstTab for i := 0; tab != nil && i < t.state.Length; i++ { - tab.Header.SetTextColor(tview.Styles.InverseTextColor) + tab.Header.SetTextColor(app.Styles.InverseTextColor) tab = tab.NextTab } } diff --git a/components/Tree.go b/components/Tree.go index 8d1e743..17ac075 100644 --- a/components/Tree.go +++ b/components/Tree.go @@ -49,7 +49,7 @@ func NewTree(dbName string, dbdriver drivers.Driver) *Tree { } tree.SetTopLevel(1) - tree.SetGraphicsColor(tview.Styles.PrimaryTextColor) + tree.SetGraphicsColor(app.Styles.PrimaryTextColor) // tree.SetBorder(true) tree.SetTitle("Databases") tree.SetTitleAlign(tview.AlignLeft) @@ -77,7 +77,7 @@ func NewTree(dbName string, dbdriver drivers.Driver) *Tree { childNode := tview.NewTreeNode(database) childNode.SetExpanded(false) childNode.SetReference(database) - childNode.SetColor(tview.Styles.PrimaryTextColor) + childNode.SetColor(app.Styles.PrimaryTextColor) rootNode.AddChild(childNode) go func(database string, node *tview.TreeNode) { @@ -95,7 +95,7 @@ func NewTree(dbName string, dbdriver drivers.Driver) *Tree { tree.SetFocusFunc(nil) }) - selectedNodeTextColor := fmt.Sprintf("[black:%s]", tview.Styles.SecondaryTextColor.Name()) + selectedNodeTextColor := fmt.Sprintf("[black:%s]", app.Styles.SecondaryTextColor.Name()) previouslyFocusedNode := tree.GetCurrentNode() previouslyFocusedNode.SetText(selectedNodeTextColor + previouslyFocusedNode.GetText()) @@ -253,33 +253,33 @@ func NewTree(dbName string, dbdriver drivers.Driver) *Tree { return event }) - tree.Filter.SetFieldStyle(tcell.StyleDefault.Background(tview.Styles.PrimitiveBackgroundColor).Foreground(tview.Styles.PrimaryTextColor)) - tree.Filter.SetPlaceholderStyle(tcell.StyleDefault.Background(tview.Styles.PrimitiveBackgroundColor).Foreground(tview.Styles.InverseTextColor)) + tree.Filter.SetFieldStyle(tcell.StyleDefault.Background(app.Styles.PrimitiveBackgroundColor).Foreground(tview.Styles.PrimaryTextColor)) + tree.Filter.SetPlaceholderStyle(tcell.StyleDefault.Background(app.Styles.PrimitiveBackgroundColor).Foreground(tview.Styles.InverseTextColor)) tree.Filter.SetBorderPadding(0, 0, 0, 0) - tree.Filter.SetBorderColor(tview.Styles.PrimaryTextColor) + tree.Filter.SetBorderColor(app.Styles.PrimaryTextColor) tree.Filter.SetLabel("Search: ") - tree.Filter.SetLabelColor(tview.Styles.InverseTextColor) + tree.Filter.SetLabelColor(app.Styles.InverseTextColor) tree.Filter.SetFocusFunc(func() { - tree.Filter.SetLabelColor(tview.Styles.TertiaryTextColor) - tree.Filter.SetFieldTextColor(tview.Styles.PrimaryTextColor) + tree.Filter.SetLabelColor(app.Styles.TertiaryTextColor) + tree.Filter.SetFieldTextColor(app.Styles.PrimaryTextColor) }) tree.Filter.SetBlurFunc(func() { if tree.Filter.GetText() == "" { - tree.Filter.SetLabelColor(tview.Styles.InverseTextColor) + tree.Filter.SetLabelColor(app.Styles.InverseTextColor) } else { - tree.Filter.SetLabelColor(tview.Styles.TertiaryTextColor) + tree.Filter.SetLabelColor(app.Styles.TertiaryTextColor) } - tree.Filter.SetFieldTextColor(tview.Styles.InverseTextColor) + tree.Filter.SetFieldTextColor(app.Styles.InverseTextColor) }) - tree.FoundNodeCountInput.SetFieldStyle(tcell.StyleDefault.Background(tview.Styles.PrimitiveBackgroundColor).Foreground(tview.Styles.PrimaryTextColor)) + tree.FoundNodeCountInput.SetFieldStyle(tcell.StyleDefault.Background(app.Styles.PrimitiveBackgroundColor).Foreground(tview.Styles.PrimaryTextColor)) tree.Wrapper.SetDirection(tview.FlexRow) tree.Wrapper.SetBorder(true) tree.Wrapper.SetBorderPadding(0, 0, 1, 1) - tree.Wrapper.SetTitleColor(tview.Styles.PrimaryTextColor) + tree.Wrapper.SetTitleColor(app.Styles.PrimaryTextColor) tree.Wrapper.AddItem(tree.Filter, 1, 0, false) tree.Wrapper.AddItem(tree.FoundNodeCountInput, 1, 0, false) @@ -300,14 +300,14 @@ func (tree *Tree) databasesToNodes(children map[string][]string, node *tview.Tre rootNode = tview.NewTreeNode(key) rootNode.SetExpanded(false) rootNode.SetReference(key) - rootNode.SetColor(tview.Styles.PrimaryTextColor) + rootNode.SetColor(app.Styles.PrimaryTextColor) node.AddChild(rootNode) } for _, child := range values { childNode := tview.NewTreeNode(child) childNode.SetExpanded(defaultExpanded) - childNode.SetColor(tview.Styles.PrimaryTextColor) + childNode.SetColor(app.Styles.PrimaryTextColor) if tree.DBDriver.GetProvider() == "sqlite3" { childNode.SetReference(child) } else if tree.DBDriver.GetProvider() == "postgres" { @@ -388,7 +388,7 @@ func (tree *Tree) GetIsFiltering() bool { func (tree *Tree) SetSelectedDatabase(database string) { tree.state.selectedDatabase = database tree.Publish(models.StateChange{ - Key: "SelectedDatabase", + Key: eventTreeSelectedDatabase, Value: database, }) } @@ -396,7 +396,7 @@ func (tree *Tree) SetSelectedDatabase(database string) { func (tree *Tree) SetSelectedTable(table string) { tree.state.selectedTable = table tree.Publish(models.StateChange{ - Key: "SelectedTable", + Key: eventTreeSelectedTable, Value: table, }) } @@ -404,17 +404,17 @@ func (tree *Tree) SetSelectedTable(table string) { func (tree *Tree) SetIsFiltering(isFiltering bool) { tree.state.isFiltering = isFiltering tree.Publish(models.StateChange{ - Key: "IsFiltering", + Key: eventTreeIsFiltering, Value: isFiltering, }) } // Blur func func (tree *Tree) RemoveHighlight() { - tree.SetBorderColor(tview.Styles.InverseTextColor) - tree.SetGraphicsColor(tview.Styles.InverseTextColor) - tree.SetTitleColor(tview.Styles.InverseTextColor) - // tree.GetRoot().SetColor(tview.Styles.InverseTextColor) + tree.SetBorderColor(app.Styles.InverseTextColor) + tree.SetGraphicsColor(app.Styles.InverseTextColor) + tree.SetTitleColor(app.Styles.InverseTextColor) + // tree.GetRoot().SetColor(app.Styles.InverseTextColor) childrens := tree.GetRoot().GetChildren() @@ -423,8 +423,8 @@ func (tree *Tree) RemoveHighlight() { childrenIsCurrentNode := children.GetReference() == tree.GetCurrentNode().GetReference() - if !childrenIsCurrentNode && currentColor == tview.Styles.PrimaryTextColor { - children.SetColor(tview.Styles.InverseTextColor) + if !childrenIsCurrentNode && currentColor == app.Styles.PrimaryTextColor { + children.SetColor(app.Styles.InverseTextColor) } childrenOfChildren := children.GetChildren() @@ -434,8 +434,8 @@ func (tree *Tree) RemoveHighlight() { childrenIsCurrentNode := children.GetReference() == tree.GetCurrentNode().GetReference() - if !childrenIsCurrentNode && currentColor == tview.Styles.PrimaryTextColor { - children.SetColor(tview.Styles.InverseTextColor) + if !childrenIsCurrentNode && currentColor == app.Styles.PrimaryTextColor { + children.SetColor(app.Styles.InverseTextColor) } } @@ -444,21 +444,21 @@ func (tree *Tree) RemoveHighlight() { } func (tree *Tree) ForceRemoveHighlight() { - tree.SetBorderColor(tview.Styles.InverseTextColor) - tree.SetGraphicsColor(tview.Styles.InverseTextColor) - tree.SetTitleColor(tview.Styles.InverseTextColor) - tree.GetRoot().SetColor(tview.Styles.InverseTextColor) + tree.SetBorderColor(app.Styles.InverseTextColor) + tree.SetGraphicsColor(app.Styles.InverseTextColor) + tree.SetTitleColor(app.Styles.InverseTextColor) + tree.GetRoot().SetColor(app.Styles.InverseTextColor) childrens := tree.GetRoot().GetChildren() for _, children := range childrens { - children.SetColor(tview.Styles.InverseTextColor) + children.SetColor(app.Styles.InverseTextColor) childrenOfChildren := children.GetChildren() for _, children := range childrenOfChildren { - children.SetColor(tview.Styles.InverseTextColor) + children.SetColor(app.Styles.InverseTextColor) } } @@ -466,26 +466,26 @@ func (tree *Tree) ForceRemoveHighlight() { // Focus func func (tree *Tree) Highlight() { - tree.SetBorderColor(tview.Styles.PrimaryTextColor) - tree.SetGraphicsColor(tview.Styles.PrimaryTextColor) - tree.SetTitleColor(tview.Styles.PrimaryTextColor) - tree.GetRoot().SetColor(tview.Styles.PrimaryTextColor) + tree.SetBorderColor(app.Styles.PrimaryTextColor) + tree.SetGraphicsColor(app.Styles.PrimaryTextColor) + tree.SetTitleColor(app.Styles.PrimaryTextColor) + tree.GetRoot().SetColor(app.Styles.PrimaryTextColor) childrens := tree.GetRoot().GetChildren() for _, children := range childrens { currentColor := children.GetColor() - if currentColor == tview.Styles.InverseTextColor { - children.SetColor(tview.Styles.PrimaryTextColor) + if currentColor == app.Styles.InverseTextColor { + children.SetColor(app.Styles.PrimaryTextColor) childrenOfChildren := children.GetChildren() for _, children := range childrenOfChildren { currentColor := children.GetColor() - if currentColor == tview.Styles.InverseTextColor { - children.SetColor(tview.Styles.PrimaryTextColor) + if currentColor == app.Styles.InverseTextColor { + children.SetColor(app.Styles.PrimaryTextColor) } } diff --git a/components/constants.go b/components/constants.go index 6422737..2bd13b5 100644 --- a/components/constants.go +++ b/components/constants.go @@ -1,10 +1,79 @@ package components -import "github.com/jorgerojas26/lazysql/app" +import ( + "github.com/gdamore/tcell/v2" + + "github.com/jorgerojas26/lazysql/app" +) var App = app.App +// Pages +const ( + // General + pageNameHelp string = "Help" + pageNameConfirmation string = "Confirmation" + pageNameConnections string = "Connections" + + // Results table + pageNameTable string = "Table" + pageNameTableError string = "TableError" + pageNameTableLoading string = "TableLoading" + pageNameTableEditorTable string = "TableEditorTable" + pageNameTableEditorResultsInfo string = "TableEditorResultsInfo" + pageNameTableEditCell string = "TableEditCell" + + // Sidebar + pageNameSidebar string = "Sidebar" + + // Connections + pageNameConnectionSelection string = "ConnectionSelection" + pageNameConnectionForm string = "ConnectionForm" +) + +// Tabs +const ( + tabNameEditor string = "Editor" +) + +// Events +const ( + eventSidebarEditing string = "EditingSidebar" + eventSidebarUnfocusing string = "UnfocusingSidebar" + eventSidebarToggling string = "TogglingSidebar" + eventSidebarCommitEditing string = "CommitEditingSidebar" + + eventSQLEditorQuery string = "Query" + eventSQLEditorEscape string = "Escape" + + eventResultsTableFiltering string = "FilteringResultsTable" + + eventTreeSelectedDatabase string = "SelectedDatabase" + eventTreeSelectedTable string = "SelectedTable" + eventTreeIsFiltering string = "IsFiltering" +) + +// Results table menu items +const ( + menuRecords string = "Records" + menuColumns string = "Columns" + menuConstraints string = "Constraints" + menuForeignKeys string = "Foreign Keys" + menuIndexes string = "Indexes" +) + +// Actions +const ( + actionNewConnection string = "NewConnection" + actionEditConnection string = "EditConnection" +) + +// Misc (until i find a better name) const ( - EditorTabName string = "Editor" - HelpPageName string = "Help" + focusedWrapperLeft string = "left" + focusedWrapperRight string = "right" + + colorTableChange = tcell.ColorOrange + colorTableInsert = tcell.ColorDarkGreen + colorTableDelete = tcell.ColorRed ) diff --git a/drivers/constants.go b/drivers/constants.go index 5f46aaa..31d8d6f 100644 --- a/drivers/constants.go +++ b/drivers/constants.go @@ -3,3 +3,10 @@ package drivers const ( DefaultRowLimit = 300 ) + +// Drivers +const ( + DriverMySQL string = "mysql" + DriverPostgres string = "postgres" + DriverSqlite string = "sqlite3" +) diff --git a/drivers/mysql.go b/drivers/mysql.go index 26ada57..ebd301b 100644 --- a/drivers/mysql.go +++ b/drivers/mysql.go @@ -8,7 +8,6 @@ import ( "github.com/xo/dburl" - "github.com/jorgerojas26/lazysql/helpers/logger" "github.com/jorgerojas26/lazysql/models" ) @@ -22,7 +21,7 @@ func (db *MySQL) TestConnection(urlstr string) (err error) { } func (db *MySQL) Connect(urlstr string) (err error) { - db.SetProvider("mysql") + db.SetProvider(DriverMySQL) db.Connection, err = dburl.Open(urlstr) if err != nil { @@ -44,12 +43,6 @@ func (db *MySQL) GetDatabases() ([]string, error) { if err != nil { return nil, err } - - rowsErr := rows.Err() - if rowsErr != nil { - return nil, rowsErr - } - defer rows.Close() for rows.Next() { @@ -62,6 +55,9 @@ func (db *MySQL) GetDatabases() ([]string, error) { databases = append(databases, database) } } + if err := rows.Err(); err != nil { + return nil, err + } return databases, nil } @@ -72,20 +68,12 @@ func (db *MySQL) GetTables(database string) (map[string][]string, error) { } rows, err := db.Connection.Query(fmt.Sprintf("SHOW TABLES FROM `%s`", database)) - - rowsErr := rows.Err() - if rowsErr != nil { - return nil, rowsErr - } - - defer rows.Close() - - tables := make(map[string][]string) - if err != nil { return nil, err } + defer rows.Close() + tables := make(map[string][]string) for rows.Next() { var table string err = rows.Scan(&table) @@ -95,6 +83,9 @@ func (db *MySQL) GetTables(database string) (map[string][]string, error) { tables[database] = append(tables[database], table) } + if err := rows.Err(); err != nil { + return nil, err + } return tables, nil } @@ -115,12 +106,6 @@ func (db *MySQL) GetTableColumns(database, table string) (results [][]string, er if err != nil { return nil, err } - - rowsErr := rows.Err() - if rowsErr != nil { - return nil, rowsErr - } - defer rows.Close() columns, err := rows.Columns() @@ -149,6 +134,9 @@ func (db *MySQL) GetTableColumns(database, table string) (results [][]string, er results = append(results, row) } + if err := rows.Err(); err != nil { + return nil, err + } return } @@ -168,12 +156,6 @@ func (db *MySQL) GetConstraints(database, table string) (results [][]string, err if err != nil { return nil, err } - - rowsErr := rows.Err() - if rowsErr != nil { - return nil, rowsErr - } - defer rows.Close() columns, err := rows.Columns() @@ -201,6 +183,9 @@ func (db *MySQL) GetConstraints(database, table string) (results [][]string, err results = append(results, row) } + if err := rows.Err(); err != nil { + return nil, err + } return } @@ -220,12 +205,6 @@ func (db *MySQL) GetForeignKeys(database, table string) (results [][]string, err if err != nil { return nil, err } - - rowsErr := rows.Err() - if rowsErr != nil { - return nil, rowsErr - } - defer rows.Close() columns, err := rows.Columns() @@ -253,6 +232,9 @@ func (db *MySQL) GetForeignKeys(database, table string) (results [][]string, err results = append(results, row) } + if err := rows.Err(); err != nil { + return nil, err + } return } @@ -273,12 +255,6 @@ func (db *MySQL) GetIndexes(database, table string) (results [][]string, err err if err != nil { return nil, err } - - rowsErr := rows.Err() - if rowsErr != nil { - return nil, rowsErr - } - defer rows.Close() columns, err := rows.Columns() @@ -306,6 +282,9 @@ func (db *MySQL) GetIndexes(database, table string) (results [][]string, err err results = append(results, row) } + if err := rows.Err(); err != nil { + return nil, err + } return } @@ -340,30 +319,8 @@ func (db *MySQL) GetRecords(database, table, where, sort string, offset, limit i if err != nil { return nil, 0, err } - - rowsErr := paginatedRows.Err() - - if rowsErr != nil { - return nil, 0, rowsErr - } - defer paginatedRows.Close() - countQuery := "SELECT COUNT(*) FROM " - countQuery += fmt.Sprintf("`%s`.", database) - countQuery += fmt.Sprintf("`%s`", table) - - rows := db.Connection.QueryRow(countQuery) - - if err != nil { - return nil, 0, err - } - - err = rows.Scan(&totalRecords) - if err != nil { - return nil, 0, err - } - columns, err := paginatedRows.Columns() if err != nil { return nil, 0, err @@ -398,7 +355,20 @@ func (db *MySQL) GetRecords(database, table, where, sort string, offset, limit i } paginatedResults = append(paginatedResults, row) - + } + if err := paginatedRows.Err(); err != nil { + return nil, 0, err + } + // close to release the connection + if err := paginatedRows.Close(); err != nil { + return nil, 0, err + } + countQuery := "SELECT COUNT(*) FROM " + countQuery += fmt.Sprintf("`%s`.", database) + countQuery += fmt.Sprintf("`%s`", table) + row := db.Connection.QueryRow(countQuery) + if err := row.Scan(&totalRecords); err != nil { + return nil, 0, err } return @@ -409,12 +379,6 @@ func (db *MySQL) ExecuteQuery(query string) (results [][]string, err error) { if err != nil { return nil, err } - - rowsErr := rows.Err() - if rowsErr != nil { - return nil, rowsErr - } - defer rows.Close() columns, err := rows.Columns() @@ -441,7 +405,9 @@ func (db *MySQL) ExecuteQuery(query string) (results [][]string, err error) { } results = append(results, row) - + } + if err := rows.Err(); err != nil { + return nil, err } return @@ -481,7 +447,7 @@ func (db *MySQL) ExecuteDMLStatement(query string) (result string, err error) { } func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange) (err error) { - var query []models.Query + var queries []models.Query for _, change := range changes { columnNames := []string{} @@ -521,7 +487,7 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange) (err error) Args: values, } - query = append(query, newQuery) + queries = append(queries, newQuery) case models.DmlUpdateType: queryStr := "UPDATE " queryStr += db.formatTableName(change.Database, change.Table) @@ -546,7 +512,7 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange) (err error) Args: args, } - query = append(query, newQuery) + queries = append(queries, newQuery) case models.DmlDeleteType: queryStr := "DELETE FROM " queryStr += db.formatTableName(change.Database, change.Table) @@ -557,29 +523,10 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange) (err error) Args: []interface{}{change.PrimaryKeyValue}, } - query = append(query, newQuery) - } - } - - trx, err := db.Connection.Begin() - if err != nil { - return err - } - - for _, query := range query { - logger.Info(query.Query, map[string]any{"args": query.Args}) - _, err := trx.Exec(query.Query, query.Args...) - if err != nil { - return err + queries = append(queries, newQuery) } } - - err = trx.Commit() - if err != nil { - return err - } - - return nil + return queriesInTransaction(db.Connection, queries) } func (db *MySQL) SetProvider(provider string) { diff --git a/drivers/postgres.go b/drivers/postgres.go index a636b62..b55bd76 100644 --- a/drivers/postgres.go +++ b/drivers/postgres.go @@ -31,7 +31,7 @@ func (db *Postgres) TestConnection(urlstr string) error { } func (db *Postgres) Connect(urlstr string) (err error) { - db.SetProvider("postgres") + db.SetProvider(DriverPostgres) db.Connection, err = dburl.Open(urlstr) if err != nil { @@ -67,16 +67,8 @@ func (db *Postgres) GetDatabases() (databases []string, err error) { if err != nil { return nil, err } - defer rows.Close() - rowsErr := rows.Err() - - if rowsErr != nil { - err = rowsErr - return nil, err - } - for rows.Next() { var database string err := rows.Scan(&database) @@ -85,6 +77,9 @@ func (db *Postgres) GetDatabases() (databases []string, err error) { } databases = append(databases, database) } + if err := rows.Err(); err != nil { + return nil, err + } return databases, nil } @@ -113,29 +108,23 @@ func (db *Postgres) GetTables(database string) (tables map[string][]string, err query := "SELECT table_name, table_schema FROM information_schema.tables WHERE table_catalog = $1" rows, err := db.Connection.Query(query, database) + if err != nil { + return nil, err + } + defer rows.Close() - if rows != nil { - rowsErr := rows.Err() - - if rowsErr != nil { - err = rowsErr - } - - defer rows.Close() - - for rows.Next() { - var tableName string - var tableSchema string - - err = rows.Scan(&tableName, &tableSchema) - - tables[tableSchema] = append(tables[tableSchema], tableName) - + for rows.Next() { + var ( + tableName string + tableSchema string + ) + if err := rows.Scan(&tableName, &tableSchema); err != nil { + return nil, err } + tables[tableSchema] = append(tables[tableSchema], tableName) } - - if err != nil { + if err := rows.Err(); err != nil { return nil, err } @@ -176,45 +165,38 @@ func (db *Postgres) GetTableColumns(database, table string) (results [][]string, query := "SELECT column_name, data_type, is_nullable, column_default FROM information_schema.columns WHERE table_catalog = $1 AND table_schema = $2 AND table_name = $3 ORDER by ordinal_position" rows, err := db.Connection.Query(query, database, tableSchema, tableName) + if err != nil { + return nil, err + } + defer rows.Close() - if rows != nil { - - rowsErr := rows.Err() - - if rowsErr != nil { - err = rowsErr - } + columns, columnsError := rows.Columns() + if columnsError != nil { + err = columnsError + } - defer rows.Close() + results = append(results, columns) - columns, columnsError := rows.Columns() + for rows.Next() { + rowValues := make([]interface{}, len(columns)) - if columnsError != nil { - err = columnsError + for i := range columns { + rowValues[i] = new(sql.RawBytes) } - results = append(results, columns) - - for rows.Next() { - rowValues := make([]interface{}, len(columns)) - - for i := range columns { - rowValues[i] = new(sql.RawBytes) - } - - err = rows.Scan(rowValues...) - - var row []string - for _, col := range rowValues { - row = append(row, string(*col.(*sql.RawBytes))) - } + if err := rows.Scan(rowValues...); err != nil { + return nil, err + } - results = append(results, row) + var row []string + for _, col := range rowValues { + row = append(row, string(*col.(*sql.RawBytes))) } + results = append(results, row) } - if err != nil { + if err := rows.Err(); err != nil { return nil, err } @@ -268,43 +250,36 @@ func (db *Postgres) GetConstraints(database, table string) (constraints [][]stri AND tc.table_schema = '%s' AND tc.table_name = '%s' `, tableSchema, tableName)) + if err != nil { + return nil, err + } + defer rows.Close() - if rows != nil { + columns, columnsError := rows.Columns() + if columnsError != nil { + err = columnsError + } - rowsErr := rows.Err() + constraints = append(constraints, columns) - if rowsErr != nil { - err = rowsErr + for rows.Next() { + rowValues := make([]interface{}, len(columns)) + for i := range columns { + rowValues[i] = new(sql.RawBytes) } - defer rows.Close() - - columns, columnsError := rows.Columns() - - if columnsError != nil { - err = columnsError + if err := rows.Scan(rowValues...); err != nil { + return nil, err } - constraints = append(constraints, columns) - - for rows.Next() { - rowValues := make([]interface{}, len(columns)) - for i := range columns { - rowValues[i] = new(sql.RawBytes) - } - - err = rows.Scan(rowValues...) - - var row []string - for _, col := range rowValues { - row = append(row, string(*col.(*sql.RawBytes))) - } - - constraints = append(constraints, row) + var row []string + for _, col := range rowValues { + row = append(row, string(*col.(*sql.RawBytes))) } - } - if err != nil { + constraints = append(constraints, row) + } + if err := rows.Err(); err != nil { return nil, err } @@ -359,43 +334,36 @@ func (db *Postgres) GetForeignKeys(database, table string) (foreignKeys [][]stri AND tc.table_schema = '%s' AND tc.table_name = '%s' `, tableSchema, tableName)) + if err != nil { + return nil, err + } + defer rows.Close() - if rows != nil { + columns, columnsError := rows.Columns() + if columnsError != nil { + err = columnsError + } - rowsErr := rows.Err() + foreignKeys = append(foreignKeys, columns) - if rowsErr != nil { - err = rowsErr + for rows.Next() { + rowValues := make([]interface{}, len(columns)) + for i := range columns { + rowValues[i] = new(sql.RawBytes) } - defer rows.Close() - - columns, columnsError := rows.Columns() - - if columnsError != nil { - err = columnsError + if err := rows.Scan(rowValues...); err != nil { + return nil, err } - foreignKeys = append(foreignKeys, columns) - - for rows.Next() { - rowValues := make([]interface{}, len(columns)) - for i := range columns { - rowValues[i] = new(sql.RawBytes) - } - - err = rows.Scan(rowValues...) - - var row []string - for _, col := range rowValues { - row = append(row, string(*col.(*sql.RawBytes))) - } - - foreignKeys = append(foreignKeys, row) + var row []string + for _, col := range rowValues { + row = append(row, string(*col.(*sql.RawBytes))) } - } - if err != nil { + foreignKeys = append(foreignKeys, row) + } + if err := rows.Err(); err != nil { return nil, err } @@ -459,43 +427,36 @@ func (db *Postgres) GetIndexes(database, table string) (indexes [][]string, err t.relname, i.relname `, tableSchema, tableName)) + if err != nil { + return nil, err + } + defer rows.Close() - if rows != nil { + columns, columnsError := rows.Columns() + if columnsError != nil { + err = columnsError + } - rowsErr := rows.Err() + indexes = append(indexes, columns) - if rowsErr != nil { - err = rowsErr + for rows.Next() { + rowValues := make([]interface{}, len(columns)) + for i := range columns { + rowValues[i] = new(sql.RawBytes) } - defer rows.Close() - - columns, columnsError := rows.Columns() - - if columnsError != nil { - err = columnsError + if err := rows.Scan(rowValues...); err != nil { + return nil, err } - indexes = append(indexes, columns) - - for rows.Next() { - rowValues := make([]interface{}, len(columns)) - for i := range columns { - rowValues[i] = new(sql.RawBytes) - } - - err = rows.Scan(rowValues...) - - var row []string - for _, col := range rowValues { - row = append(row, string(*col.(*sql.RawBytes))) - } - - indexes = append(indexes, row) + var row []string + for _, col := range rowValues { + row = append(row, string(*col.(*sql.RawBytes))) } - } - if err != nil { + indexes = append(indexes, row) + } + if err := rows.Err(); err != nil { return nil, err } @@ -555,6 +516,10 @@ func (db *Postgres) GetRecords(database, table, where, sort string, offset, limi query += " LIMIT $1 OFFSET $2" paginatedRows, err := db.Connection.Query(query, limit, offset) + if err != nil { + return nil, 0, err + } + defer paginatedRows.Close() if paginatedRows != nil { @@ -568,17 +533,11 @@ func (db *Postgres) GetRecords(database, table, where, sort string, offset, limi countQuery := "SELECT COUNT(*) FROM " countQuery += formattedTableName - - rows := db.Connection.QueryRow(countQuery) - - rowsErr = rows.Err() - - if rowsErr != nil { - err = rowsErr + row := db.Connection.QueryRow(countQuery) + if err := row.Scan(&totalRecords); err != nil { + return nil, 0, err } - err = rows.Scan(&totalRecords) - columns, columnsError := paginatedRows.Columns() if columnsError != nil { @@ -595,7 +554,9 @@ func (db *Postgres) GetRecords(database, table, where, sort string, offset, limi rowValues[i] = &nullStringSlice[i] } - err = paginatedRows.Scan(rowValues...) + if err := paginatedRows.Scan(rowValues...); err != nil { + return nil, 0, err + } var row []string for _, col := range nullStringSlice { @@ -615,7 +576,11 @@ func (db *Postgres) GetRecords(database, table, where, sort string, offset, limi } } - if err != nil { + if err := paginatedRows.Err(); err != nil { + return nil, 0, err + } + // close to release the connection + if err := paginatedRows.Close(); err != nil { return nil, 0, err } @@ -750,15 +715,8 @@ func (db *Postgres) ExecuteQuery(query string) (results [][]string, err error) { if err != nil { return nil, err } - defer rows.Close() - rowsErr := rows.Err() - - if rowsErr != nil { - err = rowsErr - } - columns, err := rows.Columns() if err != nil { return nil, err @@ -783,14 +741,16 @@ func (db *Postgres) ExecuteQuery(query string) (results [][]string, err error) { } results = append(results, row) - + } + if err := rows.Err(); err != nil { + return nil, err } return } func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange) (err error) { - var query []models.Query + var queries []models.Query for _, change := range changes { columnNames := []string{} @@ -840,7 +800,7 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange) (err err Args: values, } - query = append(query, newQuery) + queries = append(queries, newQuery) case models.DmlUpdateType: queryStr := "UPDATE " + formattedTableName @@ -872,7 +832,7 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange) (err err Args: args, } - query = append(query, newQuery) + queries = append(queries, newQuery) case models.DmlDeleteType: queryStr := "DELETE FROM " + formattedTableName queryStr += fmt.Sprintf(" WHERE %s = $1", change.PrimaryKeyColumnName) @@ -882,29 +842,10 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange) (err err Args: []interface{}{change.PrimaryKeyValue}, } - query = append(query, newQuery) - } - } - - trx, err := db.Connection.Begin() - if err != nil { - return err - } - - for _, query := range query { - logger.Info(query.Query, map[string]any{"args": query.Args}) - _, err := trx.Exec(query.Query, query.Args...) - if err != nil { - return err + queries = append(queries, newQuery) } } - - err = trx.Commit() - if err != nil { - return err - } - - return nil + return queriesInTransaction(db.Connection, queries) } func (db *Postgres) SetProvider(provider string) { diff --git a/drivers/sqlite.go b/drivers/sqlite.go index f2f47e2..8e368d9 100644 --- a/drivers/sqlite.go +++ b/drivers/sqlite.go @@ -9,7 +9,6 @@ import ( // import sqlite driver _ "modernc.org/sqlite" - "github.com/jorgerojas26/lazysql/helpers/logger" "github.com/jorgerojas26/lazysql/models" ) @@ -23,7 +22,7 @@ func (db *SQLite) TestConnection(urlstr string) (err error) { } func (db *SQLite) Connect(urlstr string) (err error) { - db.SetProvider("sqlite3") + db.SetProvider(DriverSqlite) db.Connection, err = sql.Open("sqlite", urlstr) if err != nil { @@ -45,12 +44,6 @@ func (db *SQLite) GetDatabases() ([]string, error) { if err != nil { return nil, err } - - rowsErr := rows.Err() - if rowsErr != nil { - return nil, rowsErr - } - defer rows.Close() for rows.Next() { @@ -65,6 +58,9 @@ func (db *SQLite) GetDatabases() ([]string, error) { databases = append(databases, dbName) } + if err := rows.Err(); err != nil { + return nil, err + } return databases, nil } @@ -78,12 +74,6 @@ func (db *SQLite) GetTables(database string) (map[string][]string, error) { if err != nil { return nil, err } - - rowsErr := rows.Err() - if rowsErr != nil { - return nil, rowsErr - } - defer rows.Close() tables := make(map[string][]string) @@ -97,6 +87,9 @@ func (db *SQLite) GetTables(database string) (map[string][]string, error) { tables[database] = append(tables[database], table) } + if err := rows.Err(); err != nil { + return nil, err + } return tables, nil } @@ -110,12 +103,6 @@ func (db *SQLite) GetTableColumns(_, table string) (results [][]string, err erro if err != nil { return nil, err } - - rowsErr := rows.Err() - if rowsErr != nil { - return nil, rowsErr - } - defer rows.Close() columns, err := rows.Columns() @@ -148,6 +135,9 @@ func (db *SQLite) GetTableColumns(_, table string) (results [][]string, err erro results = append(results, row[1:]) } + if err := rows.Err(); err != nil { + return nil, err + } return } @@ -164,12 +154,6 @@ func (db *SQLite) GetConstraints(_, table string) (results [][]string, err error if err != nil { return nil, err } - - rowsErr := rows.Err() - if rowsErr != nil { - return nil, rowsErr - } - defer rows.Close() columns, err := rows.Columns() @@ -201,6 +185,9 @@ func (db *SQLite) GetConstraints(_, table string) (results [][]string, err error results = append(results, row) } + if err := rows.Err(); err != nil { + return nil, err + } return } @@ -214,12 +201,6 @@ func (db *SQLite) GetForeignKeys(_, table string) (results [][]string, err error if err != nil { return nil, err } - - rowsErr := rows.Err() - if rowsErr != nil { - return nil, rowsErr - } - defer rows.Close() columns, err := rows.Columns() @@ -251,6 +232,9 @@ func (db *SQLite) GetForeignKeys(_, table string) (results [][]string, err error results = append(results, row) } + if err := rows.Err(); err != nil { + return nil, err + } return } @@ -264,12 +248,6 @@ func (db *SQLite) GetIndexes(_, table string) (results [][]string, err error) { if err != nil { return nil, err } - - rowsErr := rows.Err() - if rowsErr != nil { - return nil, rowsErr - } - defer rows.Close() columns, err := rows.Columns() @@ -301,6 +279,9 @@ func (db *SQLite) GetIndexes(_, table string) (results [][]string, err error) { results = append(results, row) } + if err := rows.Err(); err != nil { + return nil, err + } return } @@ -331,29 +312,8 @@ func (db *SQLite) GetRecords(_, table, where, sort string, offset, limit int) (p if err != nil { return nil, 0, err } - - rowsErr := paginatedRows.Err() - - if rowsErr != nil { - return nil, 0, rowsErr - } - defer paginatedRows.Close() - countQuery := "SELECT COUNT(*) FROM " - countQuery += db.formatTableName(table) - - rows := db.Connection.QueryRow(countQuery) - - if err != nil { - return nil, 0, err - } - - err = rows.Scan(&totalRecords) - if err != nil { - return nil, 0, err - } - columns, err := paginatedRows.Columns() if err != nil { return nil, 0, err @@ -389,7 +349,20 @@ func (db *SQLite) GetRecords(_, table, where, sort string, offset, limit int) (p } paginatedResults = append(paginatedResults, row) + } + if err := paginatedRows.Err(); err != nil { + return nil, 0, err + } + // close to release the connection + if err := paginatedRows.Close(); err != nil { + return nil, 0, err + } + countQuery := "SELECT COUNT(*) FROM " + countQuery += db.formatTableName(table) + row := db.Connection.QueryRow(countQuery) + if err := row.Scan(&totalRecords); err != nil { + return nil, 0, err } return @@ -400,12 +373,6 @@ func (db *SQLite) ExecuteQuery(query string) (results [][]string, err error) { if err != nil { return nil, err } - - rowsErr := rows.Err() - if rowsErr != nil { - return nil, rowsErr - } - defer rows.Close() columns, err := rows.Columns() @@ -436,7 +403,9 @@ func (db *SQLite) ExecuteQuery(query string) (results [][]string, err error) { } results = append(results, row) - + } + if err := rows.Err(); err != nil { + return nil, err } return @@ -509,7 +478,7 @@ func (db *SQLite) ExecuteDMLStatement(query string) (result string, err error) { } func (db *SQLite) ExecutePendingChanges(changes []models.DbDmlChange) (err error) { - var query []models.Query + var queries []models.Query for _, change := range changes { columnNames := []string{} @@ -547,7 +516,7 @@ func (db *SQLite) ExecutePendingChanges(changes []models.DbDmlChange) (err error Args: values, } - query = append(query, newQuery) + queries = append(queries, newQuery) case models.DmlUpdateType: for _, cell := range change.Values { @@ -594,7 +563,7 @@ func (db *SQLite) ExecutePendingChanges(changes []models.DbDmlChange) (err error Args: args, } - query = append(query, newQuery) + queries = append(queries, newQuery) case models.DmlDeleteType: queryStr := "DELETE FROM " queryStr += db.formatTableName(change.Table) @@ -605,29 +574,10 @@ func (db *SQLite) ExecutePendingChanges(changes []models.DbDmlChange) (err error Args: []interface{}{change.PrimaryKeyValue}, } - query = append(query, newQuery) - } - } - - trx, err := db.Connection.Begin() - if err != nil { - return err - } - - for _, query := range query { - logger.Info(query.Query, map[string]any{"args": query.Args}) - _, err := trx.Exec(query.Query, query.Args...) - if err != nil { - return err + queries = append(queries, newQuery) } } - - err = trx.Commit() - if err != nil { - return err - } - - return nil + return queriesInTransaction(db.Connection, queries) } func (db *SQLite) SetProvider(provider string) { diff --git a/drivers/utils.go b/drivers/utils.go new file mode 100644 index 0000000..d880917 --- /dev/null +++ b/drivers/utils.go @@ -0,0 +1,34 @@ +package drivers + +import ( + "database/sql" + "errors" + + "github.com/jorgerojas26/lazysql/helpers/logger" + "github.com/jorgerojas26/lazysql/models" +) + +func queriesInTransaction(db *sql.DB, queries []models.Query) (err error) { + trx, err := db.Begin() + if err != nil { + return err + } + defer func() { + rErr := trx.Rollback() + // sql.ErrTxDone is returned when trx.Commit was already called + if !errors.Is(rErr, sql.ErrTxDone) { + err = errors.Join(err, rErr) + } + }() + + for _, query := range queries { + logger.Info(query.Query, map[string]any{"args": query.Args}) + if _, err := trx.Exec(query.Query, query.Args...); err != nil { + return err + } + } + if err := trx.Commit(); err != nil { + return err + } + return nil +} diff --git a/drivers/utils_test.go b/drivers/utils_test.go new file mode 100644 index 0000000..847f486 --- /dev/null +++ b/drivers/utils_test.go @@ -0,0 +1,121 @@ +package drivers + +import ( + "errors" + "strings" + "testing" + + gomock "github.com/DATA-DOG/go-sqlmock" + + "github.com/jorgerojas26/lazysql/models" +) + +func Test_queriesInTransaction(t *testing.T) { + tests := []struct { + setMockExpectations func(mock gomock.Sqlmock) + assertErr func(t *testing.T, err error) + name string + queries []models.Query + }{ + { + name: "successful transaction", + queries: []models.Query{ + {Query: "SELECT * FROM table"}, + }, + setMockExpectations: func(mock gomock.Sqlmock) { + mock.ExpectBegin() + mock.ExpectExec("SELECT \\* FROM table").WillReturnResult(gomock.NewResult(0, 0)) + mock.ExpectCommit() + }, + }, + { + name: "unsuccessful commit", + queries: []models.Query{ + {Query: "SELECT * FROM table"}, + }, + setMockExpectations: func(mock gomock.Sqlmock) { + mock.ExpectBegin() + mock.ExpectExec("SELECT \\* FROM table").WillReturnResult(gomock.NewResult(0, 0)) + mock.ExpectCommit().WillReturnError(errors.New("commit error")) + }, + assertErr: func(t *testing.T, err error) { + t.Helper() + if !strings.Contains(err.Error(), "commit error") { + t.Errorf("expected error to contain 'commit error', got %v", err) + } + }, + }, + { + name: "failed query", + queries: []models.Query{ + {Query: "SELECT * FROM table"}, + }, + setMockExpectations: func(mock gomock.Sqlmock) { + mock.ExpectBegin() + mock.ExpectExec("SELECT \\* FROM table").WillReturnError(errors.New("query error")) + mock.ExpectRollback() + }, + assertErr: func(t *testing.T, err error) { + t.Helper() + if !strings.Contains(err.Error(), "query error") { + t.Errorf("expected error to contain 'commit error', got %v", err) + } + }, + }, + { + name: "failed 2nd query of three", + queries: []models.Query{ + {Query: "SELECT * FROM table"}, + {Query: "SELECT * FROM table"}, + {Query: "SELECT * FROM table"}, + }, + setMockExpectations: func(mock gomock.Sqlmock) { + mock.ExpectBegin() + mock.ExpectExec("SELECT \\* FROM table").WillReturnResult(gomock.NewResult(0, 0)) + mock.ExpectExec("SELECT \\* FROM table").WillReturnError(errors.New("query error")) + mock.ExpectRollback() + }, + assertErr: func(t *testing.T, err error) { + t.Helper() + if !strings.Contains(err.Error(), "query error") { + t.Errorf("expected error to contain 'commit error', got %v", err) + } + }, + }, + { + name: "failed query and rollback", + queries: []models.Query{ + {Query: "SELECT * FROM table"}, + }, + setMockExpectations: func(mock gomock.Sqlmock) { + mock.ExpectBegin() + mock.ExpectExec("SELECT \\* FROM table").WillReturnError(errors.New("query error")) + mock.ExpectRollback().WillReturnError(errors.New("rollback error")) + }, + assertErr: func(t *testing.T, err error) { + t.Helper() + errMsg := err.Error() + if !strings.Contains(errMsg, "query error") { + t.Errorf("expected error to contain 'commit error', got %v", err) + } + if !strings.Contains(errMsg, "rollback error") { + t.Errorf("expected error to contain 'rollback error', got %v", err) + } + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db, mock, err := gomock.New() + if err != nil { + t.Fatal(err) + } + defer db.Close() + tt.setMockExpectations(mock) + queryErr := queriesInTransaction(db, tt.queries) + if tt.assertErr != nil { + tt.assertErr(t, queryErr) + } + }) + } +} diff --git a/go.mod b/go.mod index 41c4dc5..140d36a 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,12 @@ module github.com/jorgerojas26/lazysql go 1.20 require ( + github.com/DATA-DOG/go-sqlmock v1.5.2 github.com/atotto/clipboard v0.1.4 github.com/gdamore/tcell/v2 v2.7.0 github.com/go-sql-driver/mysql v1.7.1 github.com/google/uuid v1.6.0 + github.com/lib/pq v1.10.9 github.com/pelletier/go-toml/v2 v2.1.1 github.com/rivo/tview v0.0.0-20240101144852-b3bd1aa5e9f2 github.com/xo/dburl v0.20.2 @@ -15,10 +17,17 @@ require ( require ( github.com/dustin/go-humanize v1.0.1 // indirect + github.com/gdamore/encoding v1.0.0 // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect + github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-runewidth v0.0.15 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/rivo/uniseg v0.4.3 // indirect + golang.org/x/sys v0.22.0 // indirect + golang.org/x/term v0.15.0 // indirect + golang.org/x/text v0.14.0 // indirect modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect modernc.org/libc v1.55.3 // indirect modernc.org/mathutil v1.6.0 // indirect @@ -26,14 +35,3 @@ require ( modernc.org/strutil v1.2.0 // indirect modernc.org/token v1.1.0 // indirect ) - -require ( - github.com/gdamore/encoding v1.0.0 // indirect - github.com/lib/pq v1.10.9 - github.com/lucasb-eyer/go-colorful v1.2.0 // indirect - github.com/mattn/go-runewidth v0.0.15 // indirect - github.com/rivo/uniseg v0.4.3 // indirect - golang.org/x/sys v0.22.0 // indirect - golang.org/x/term v0.15.0 // indirect - golang.org/x/text v0.14.0 // indirect -) diff --git a/go.sum b/go.sum index 03a9305..17c9db8 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= +github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -16,6 +18,7 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= +github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= diff --git a/models/models.go b/models/models.go index e8eb5ea..e0af427 100644 --- a/models/models.go +++ b/models/models.go @@ -37,9 +37,11 @@ const ( ) type CellValue struct { - Type CellValueType - Column string - Value interface{} + Value interface{} + Column string + TableColumnIndex int + TableRowIndex int + Type CellValueType } const ( @@ -49,12 +51,12 @@ const ( ) type DbDmlChange struct { - Type DmlType Database string Table string - Values []CellValue PrimaryKeyColumnName string PrimaryKeyValue string + Values []CellValue + Type DmlType } type DatabaseTableColumn struct { @@ -70,3 +72,8 @@ type Query struct { Query string Args []interface{} } + +type SidebarEditingCommitParams struct { + ColumnName string + NewValue string +} From cabef158ec0d523d729ceac7fbe0e8d260197ba6 Mon Sep 17 00:00:00 2001 From: Jorge Rojas Date: Sat, 12 Oct 2024 00:52:08 -0400 Subject: [PATCH 07/10] feat: add set value menu to the sidebar --- app/Keymap.go | 1 + components/ResultsTable.go | 4 ++-- components/Sidebar.go | 33 +++++++++++++++++++++++++++++++-- models/models.go | 1 + 4 files changed, 35 insertions(+), 4 deletions(-) diff --git a/app/Keymap.go b/app/Keymap.go index aa846df..493e956 100644 --- a/app/Keymap.go +++ b/app/Keymap.go @@ -131,6 +131,7 @@ var Keymaps = KeymapSystem{ Bind{Key: Key{Char: 'c'}, Cmd: cmd.Edit, Description: "Edit field"}, Bind{Key: Key{Code: tcell.KeyEnter}, Cmd: cmd.CommitEdit, Description: "Add edit to pending changes"}, Bind{Key: Key{Code: tcell.KeyEscape}, Cmd: cmd.DiscardEdit, Description: "Discard edit"}, + Bind{Key: Key{Char: 'C'}, Cmd: cmd.SetValue, Description: "Toggle value menu to put values like NULL, EMPTY or DEFAULT"}, }, }, } diff --git a/components/ResultsTable.go b/components/ResultsTable.go index 7651d2a..e038d91 100644 --- a/components/ResultsTable.go +++ b/components/ResultsTable.go @@ -88,7 +88,7 @@ func NewResultsTable(listOfDbChanges *[]models.DbDmlChange, tree *Tree, dbdriver pagination := NewPagination() - sidebar := NewSidebar() + sidebar := NewSidebar(dbdriver.GetProvider()) table := &ResultsTable{ Table: tview.NewTable(), @@ -221,7 +221,7 @@ func (table *ResultsTable) subscribeToSidebarChanges() { tableCell.SetText(params.NewValue) cellValue := models.CellValue{ - Type: models.String, + Type: params.Type, Column: params.ColumnName, Value: params.NewValue, TableColumnIndex: changedColumnIndex, diff --git a/components/Sidebar.go b/components/Sidebar.go index 40211df..062f912 100644 --- a/components/Sidebar.go +++ b/components/Sidebar.go @@ -12,6 +12,7 @@ import ( ) type SidebarState struct { + dbProvider string currentFieldIndex int } @@ -28,7 +29,7 @@ type Sidebar struct { subscribers []chan models.StateChange } -func NewSidebar() *Sidebar { +func NewSidebar(dbProvider string) *Sidebar { flex := tview.NewFlex().SetDirection(tview.FlexColumnCSS) frame := tview.NewFrame(flex) frame.SetBackgroundColor(app.Styles.PrimitiveBackgroundColor) @@ -37,6 +38,7 @@ func NewSidebar() *Sidebar { sidebarState := &SidebarState{ currentFieldIndex: 0, + dbProvider: dbProvider, } newSidebar := &Sidebar{ @@ -242,7 +244,7 @@ func (sidebar *Sidebar) inputCapture(event *tcell.EventKey) *tcell.EventKey { sidebar.SetDisabledStyles(item) } else { sidebar.SetEditedStyles(item) - sidebar.Publish(models.StateChange{Key: eventSidebarCommitEditing, Value: models.SidebarEditingCommitParams{ColumnName: columnName, NewValue: newText}}) + sidebar.Publish(models.StateChange{Key: eventSidebarCommitEditing, Value: models.SidebarEditingCommitParams{ColumnName: columnName, Type: models.String, NewValue: newText}}) } return nil @@ -259,6 +261,33 @@ func (sidebar *Sidebar) inputCapture(event *tcell.EventKey) *tcell.EventKey { sidebar.EditTextCurrentField() + return nil + case commands.SetValue: + currentItemIndex := sidebar.GetCurrentFieldIndex() + item := sidebar.Flex.GetItem(currentItemIndex).(*tview.TextArea) + x, y, _, _ := item.GetRect() + + columnName := item.GetTitle() + columnNameSplit := strings.Split(columnName, "[") + columnName = columnNameSplit[0] + + list := NewSetValueList(sidebar.state.dbProvider) + + sidebar.Publish(models.StateChange{Key: eventSidebarEditing, Value: true}) + + list.OnFinish(func(selection models.CellValueType, value string) { + sidebar.Publish(models.StateChange{Key: eventSidebarEditing, Value: false}) + App.SetFocus(item) + + if selection >= 0 { + sidebar.SetEditedStyles(item) + item.SetText(value, true) + sidebar.Publish(models.StateChange{Key: eventSidebarCommitEditing, Value: models.SidebarEditingCommitParams{ColumnName: columnName, Type: selection, NewValue: value}}) + } + }) + + list.Show(x, y, 30) + return nil } return event diff --git a/models/models.go b/models/models.go index e0af427..94d8210 100644 --- a/models/models.go +++ b/models/models.go @@ -76,4 +76,5 @@ type Query struct { type SidebarEditingCommitParams struct { ColumnName string NewValue string + Type CellValueType } From 17f3d843cf74c6080b5b1a5c38bd7170d7c70a88 Mon Sep 17 00:00:00 2001 From: Jorge Rojas Date: Sat, 12 Oct 2024 01:04:02 -0400 Subject: [PATCH 08/10] chore: fix lint --- components/ResultsTable.go | 1 - 1 file changed, 1 deletion(-) diff --git a/components/ResultsTable.go b/components/ResultsTable.go index e91a4e1..08b3fea 100644 --- a/components/ResultsTable.go +++ b/components/ResultsTable.go @@ -284,7 +284,6 @@ func (table *ResultsTable) AddInsertedRows() { tableCell.SetExpansion(1) tableCell.SetReference(inserts[i].PrimaryKeyValue) - tableCell.SetTextColor(app.Styles.PrimaryTextColor) tableCell.SetBackgroundColor(colorTableInsert) From 88a97da094b16c0cb40d603bf8d68d5f49688553 Mon Sep 17 00:00:00 2001 From: Jorge Rojas Date: Sat, 12 Oct 2024 01:10:10 -0400 Subject: [PATCH 09/10] chore: removes some magic strings --- components/SetValueList.go | 7 ++++--- components/constants.go | 3 +++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/components/SetValueList.go b/components/SetValueList.go index ecf7e7d..ab24bca 100644 --- a/components/SetValueList.go +++ b/components/SetValueList.go @@ -6,6 +6,7 @@ import ( "github.com/jorgerojas26/lazysql/app" "github.com/jorgerojas26/lazysql/commands" + "github.com/jorgerojas26/lazysql/drivers" "github.com/jorgerojas26/lazysql/models" ) @@ -24,7 +25,7 @@ func NewSetValueList(dbProvider string) *SetValueList { list := tview.NewList() list.SetBorder(true) - if dbProvider == "sqlite3" { + if dbProvider == drivers.DriverSqlite { VALUES = []value{ {value: "NULL", key: 'n'}, {value: "EMPTY", key: 'e'}, @@ -77,13 +78,13 @@ func (list *SetValueList) OnFinish(callback func(selection models.CellValueType, func (list *SetValueList) Show(x, y, width int) { list.SetRect(x, y, width, len(VALUES)*2+1) - MainPages.AddPage("setValue", list, false, true) + MainPages.AddPage(pageNameSetValue, list, false, true) App.SetFocus(list) App.ForceDraw() } func (list *SetValueList) Hide() { - MainPages.RemovePage("setValue") + MainPages.RemovePage(pageNameSetValue) App.SetFocus(list) App.ForceDraw() } diff --git a/components/constants.go b/components/constants.go index 2bd13b5..46625d6 100644 --- a/components/constants.go +++ b/components/constants.go @@ -29,6 +29,9 @@ const ( // Connections pageNameConnectionSelection string = "ConnectionSelection" pageNameConnectionForm string = "ConnectionForm" + + // SetValueList + pageNameSetValue string = "SetValue" ) // Tabs From 9bbda7b4313e66713e6a525289593349368d9ccf Mon Sep 17 00:00:00 2001 From: Jorge Rojas Date: Sun, 13 Oct 2024 21:49:44 -0400 Subject: [PATCH 10/10] fix: postgres get records --- drivers/mysql.go | 1 + drivers/postgres.go | 59 +++++++++++++++++---------------------------- 2 files changed, 23 insertions(+), 37 deletions(-) diff --git a/drivers/mysql.go b/drivers/mysql.go index ebd301b..b71f91c 100644 --- a/drivers/mysql.go +++ b/drivers/mysql.go @@ -363,6 +363,7 @@ func (db *MySQL) GetRecords(database, table, where, sort string, offset, limit i if err := paginatedRows.Close(); err != nil { return nil, 0, err } + countQuery := "SELECT COUNT(*) FROM " countQuery += fmt.Sprintf("`%s`.", database) countQuery += fmt.Sprintf("`%s`", table) diff --git a/drivers/postgres.go b/drivers/postgres.go index c25b3c9..59a55a0 100644 --- a/drivers/postgres.go +++ b/drivers/postgres.go @@ -529,54 +529,32 @@ func (db *Postgres) GetRecords(database, table, where, sort string, offset, limi records = append(records, columns) for paginatedRows.Next() { + nullStringSlice := make([]sql.NullString, len(columns)) + rowValues := make([]interface{}, len(columns)) - for i := range columns { - rowValues[i] = new(sql.RawBytes) + for i := range nullStringSlice { + rowValues[i] = &nullStringSlice[i] } - countQuery := "SELECT COUNT(*) FROM " - countQuery += formattedTableName - row := db.Connection.QueryRow(countQuery) - if err := row.Scan(&totalRecords); err != nil { + if err := paginatedRows.Scan(rowValues...); err != nil { return nil, 0, err } - columns, columnsError := paginatedRows.Columns() - - if columnsError != nil { - err = columnsError - } - - records = append(records, columns) - - for paginatedRows.Next() { - nullStringSlice := make([]sql.NullString, len(columns)) - - rowValues := make([]interface{}, len(columns)) - for i := range nullStringSlice { - rowValues[i] = &nullStringSlice[i] - } - - if err := paginatedRows.Scan(rowValues...); err != nil { - return nil, 0, err - } - - var row []string - for _, col := range nullStringSlice { - if col.Valid { - if col.String == "" { - row = append(row, "EMPTY&") - } else { - row = append(row, col.String) - } + var row []string + for _, col := range nullStringSlice { + if col.Valid { + if col.String == "" { + row = append(row, "EMPTY&") } else { - row = append(row, "NULL&") + row = append(row, col.String) } + } else { + row = append(row, "NULL&") } + } - records = append(records, row) + records = append(records, row) - } } if err := paginatedRows.Err(); err != nil { @@ -587,6 +565,13 @@ func (db *Postgres) GetRecords(database, table, where, sort string, offset, limi return nil, 0, err } + countQuery := "SELECT COUNT(*) FROM " + countQuery += formattedTableName + row := db.Connection.QueryRow(countQuery) + if err := row.Scan(&totalRecords); err != nil { + return nil, 0, err + } + return }