Skip to content

Commit

Permalink
feat: support NULL, DEFAULT and EMPTY values
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgerojas26 committed Oct 12, 2024
1 parent 2088995 commit e46fa6b
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 52 deletions.
14 changes: 7 additions & 7 deletions components/ResultsTable.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ func (table *ResultsTable) AddRows(rows [][]string) {
tableCell := tview.NewTableCell(cell)
tableCell.SetTextColor(tview.Styles.PrimaryTextColor)

if cell == "EMPTY&" || cell == "NULL&" {
if cell == "EMPTY&" || cell == "NULL&" || cell == "DEFAULT&" {
tableCell.SetText(strings.Replace(cell, "&", "", 1))
tableCell.SetStyle(table.GetItalicStyle())
tableCell.SetReference(cell)
Expand Down Expand Up @@ -248,16 +248,16 @@ func (table *ResultsTable) AppendNewRow(cells []models.CellValue, index int, UUI
tableCell.SetExpansion(1)
tableCell.SetReference(UUID)
tableCell.SetTextColor(tview.Styles.PrimaryTextColor)
tableCell.SetBackgroundColor(InsertColor)

switch cell.Type {
case models.Null:
case models.Default:
case models.String:
tableCell.SetText("")
case models.Null, models.Empty, models.Default:
tableCell.SetText(strings.Replace(cell.Value.(string), "&", "", 1))
tableCell.SetStyle(table.GetItalicStyle())
// tableCell.SetText("")
tableCell.SetTextColor(tview.Styles.InverseTextColor)
}

tableCell.SetBackgroundColor(InsertColor)
table.SetCell(index, i, tableCell)
}

Expand Down Expand Up @@ -380,7 +380,7 @@ func (table *ResultsTable) tableInputCapture(event *tcell.EventKey) *tcell.Event
cell := table.GetCell(selectedRowIndex, selectedColumnIndex)
x, y, _ := cell.GetLastPosition()

list := NewSetValueList()
list := NewSetValueList(table.DBDriver.GetProvider())

list.OnFinish(func(selection models.CellValueType, value string) {
table.FinishSettingValue()
Expand Down
31 changes: 21 additions & 10 deletions components/SetValueList.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ package components

import (
"github.com/gdamore/tcell/v2"
"github.com/rivo/tview"

"github.com/jorgerojas26/lazysql/app"
"github.com/jorgerojas26/lazysql/commands"
"github.com/jorgerojas26/lazysql/models"
"github.com/rivo/tview"
)

type SetValueList struct {
Expand All @@ -17,16 +18,25 @@ type value struct {
key rune
}

var VALUES = []value{
{value: "NULL", key: 'n'},
{value: "EMPTY", key: 'e'},
{value: "DEFAULT", key: 'd'},
}
var VALUES = []value{}

func NewSetValueList() *SetValueList {
func NewSetValueList(dbProvider string) *SetValueList {
list := tview.NewList()
list.SetBorder(true)

if dbProvider == "sqlite3" {
VALUES = []value{
{value: "NULL", key: 'n'},
{value: "EMPTY", key: 'e'},
}
} else {
VALUES = []value{
{value: "NULL", key: 'n'},
{value: "EMPTY", key: 'e'},
{value: "DEFAULT", key: 'd'},
}
}

for _, value := range VALUES {
list.AddItem(value.value, "", value.key, nil)
}
Expand All @@ -42,11 +52,12 @@ func (list *SetValueList) OnFinish(callback func(selection models.CellValueType,

list.SetSelectedFunc(func(_ int, _ string, _ string, shortcut rune) {
list.Hide()
if shortcut == 'n' {
switch shortcut {
case 'n':
callback(models.Null, "NULL")
} else if shortcut == 'e' {
case 'e':
callback(models.Empty, "EMPTY")
} else if shortcut == 'd' {
case 'd':
callback(models.Default, "DEFAULT")
}
})
Expand Down
15 changes: 9 additions & 6 deletions drivers/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -489,9 +489,14 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange) (err error)
valuesPlaceholder := []string{}

for _, cell := range change.Values {
columnNames = append(columnNames, cell.Column)

switch cell.Type {
case models.Empty, models.Null, models.String:
columnNames = append(columnNames, cell.Column)
case models.Default:
valuesPlaceholder = append(valuesPlaceholder, "DEFAULT")
case models.Null:
valuesPlaceholder = append(valuesPlaceholder, "NULL")
default:
valuesPlaceholder = append(valuesPlaceholder, "?")
}
}
Expand All @@ -500,8 +505,6 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange) (err error)
switch cell.Type {
case models.Empty:
values = append(values, "")
case models.Null:
values = append(values, sql.NullString{})
case models.String:
values = append(values, cell.Value)
}
Expand All @@ -525,9 +528,9 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange) (err error)

for i, column := range columnNames {
if i == 0 {
queryStr += fmt.Sprintf(" SET `%s` = ?", column)
queryStr += fmt.Sprintf(" SET `%s` = %s", column, valuesPlaceholder[i])
} else {
queryStr += fmt.Sprintf(", `%s` = ?", column)
queryStr += fmt.Sprintf(", `%s` = %s", column, valuesPlaceholder[i])
}
}

Expand Down
26 changes: 18 additions & 8 deletions drivers/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,6 @@ func (db *Postgres) GetRecords(database, table, where, sort string, offset, limi
for paginatedRows.Next() {
nullStringSlice := make([]sql.NullString, len(columns))

// Create a slice of interface{} to hold pointers to the sql.NullString slice
rowValues := make([]interface{}, len(columns))
for i := range nullStringSlice {
rowValues[i] = &nullStringSlice[i]
Expand Down Expand Up @@ -800,9 +799,14 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange) (err err
placeholderIndex := 1

for _, cell := range change.Values {
columnNames = append(columnNames, cell.Column)

switch cell.Type {
case models.Empty, models.Null, models.String:
columnNames = append(columnNames, cell.Column)
case models.Default:
valuesPlaceholder = append(valuesPlaceholder, "DEFAULT")
case models.Null:
valuesPlaceholder = append(valuesPlaceholder, "NULL")
default:
valuesPlaceholder = append(valuesPlaceholder, fmt.Sprintf("$%d", placeholderIndex))
placeholderIndex++
}
Expand All @@ -812,8 +816,6 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange) (err err
switch cell.Type {
case models.Empty:
values = append(values, "")
case models.Null:
values = append(values, sql.NullString{})
case models.String:
values = append(values, cell.Value)
}
Expand Down Expand Up @@ -844,17 +846,25 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange) (err err

for i, column := range columnNames {
if i == 0 {
queryStr += fmt.Sprintf(" SET \"%s\" = $1", column)
queryStr += fmt.Sprintf(" SET \"%s\" = %s", column, valuesPlaceholder[i])
} else {
queryStr += fmt.Sprintf(", \"%s\" = $%d", column, i+1)
queryStr += fmt.Sprintf(", \"%s\" = %s", column, valuesPlaceholder[i])
}
}

args := make([]interface{}, len(values))

copy(args, values)

queryStr += fmt.Sprintf(" WHERE \"%s\" = $%d", change.PrimaryKeyColumnName, len(columnNames)+1)
wherePlaceholder := 0

for _, placeholder := range valuesPlaceholder {
if strings.Contains(placeholder, "$") {
wherePlaceholder++
}
}

queryStr += fmt.Sprintf(" WHERE \"%s\" = $%d", change.PrimaryKeyColumnName, wherePlaceholder+1)
args = append(args, change.PrimaryKeyValue)

newQuery := models.Query{
Expand Down
64 changes: 43 additions & 21 deletions drivers/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -516,28 +516,28 @@ func (db *SQLite) ExecutePendingChanges(changes []models.DbDmlChange) (err error
values := []interface{}{}
valuesPlaceholder := []string{}

for _, cell := range change.Values {
switch cell.Type {
case models.Empty, models.Null, models.String:
columnNames = append(columnNames, cell.Column)
valuesPlaceholder = append(valuesPlaceholder, "?")
switch change.Type {
case models.DmlInsertType:
for _, cell := range change.Values {
switch cell.Type {
case models.Null:
columnNames = append(columnNames, cell.Column)
valuesPlaceholder = append(valuesPlaceholder, "NULL")
case models.Empty, models.String:
columnNames = append(columnNames, cell.Column)
valuesPlaceholder = append(valuesPlaceholder, "?")
}
}
}
logger.Info("Column names", map[string]any{"columnNames": columnNames})

for _, cell := range change.Values {
switch cell.Type {
case models.Empty:
values = append(values, "")
case models.Null:
values = append(values, sql.NullString{})
case models.String:
values = append(values, cell.Value)

for _, cell := range change.Values {
switch cell.Type {
case models.Empty:
values = append(values, "")
case models.String:
values = append(values, cell.Value)
}
}
}

switch change.Type {
case models.DmlInsertType:
queryStr := "INSERT INTO "
queryStr += db.formatTableName(change.Table)
queryStr += fmt.Sprintf(" (%s) VALUES (%s)", strings.Join(columnNames, ", "), strings.Join(valuesPlaceholder, ", "))
Expand All @@ -549,14 +549,36 @@ func (db *SQLite) ExecutePendingChanges(changes []models.DbDmlChange) (err error

query = append(query, newQuery)
case models.DmlUpdateType:

for _, cell := range change.Values {
switch cell.Type {
case models.Null:
columnNames = append(columnNames, cell.Column)
valuesPlaceholder = append(valuesPlaceholder, "NULL")
case models.Empty, models.String:
columnNames = append(columnNames, cell.Column)
valuesPlaceholder = append(valuesPlaceholder, "?")
/// Leaves "DEFAULT" type out because it's not supported by sqlite
}
}

for _, cell := range change.Values {
switch cell.Type {
case models.Empty:
values = append(values, "")
case models.String:
values = append(values, cell.Value)
}
}

queryStr := "UPDATE "
queryStr += db.formatTableName(change.Table)

for i, column := range columnNames {
if i == 0 {
queryStr += fmt.Sprintf(" SET `%s` = ?", column)
queryStr += fmt.Sprintf(" SET `%s` = %s", column, valuesPlaceholder[i])
} else {
queryStr += fmt.Sprintf(", `%s` = ?", column)
queryStr += fmt.Sprintf(", `%s` = %s", column, valuesPlaceholder[i])
}
}

Expand Down

0 comments on commit e46fa6b

Please sign in to comment.