Skip to content

Commit

Permalink
feat(indexer/postgres): add insert/update/delete functionality (#21186)
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronc authored Sep 4, 2024
1 parent 4b78f15 commit 292d7b4
Show file tree
Hide file tree
Showing 12 changed files with 951 additions and 1 deletion.
61 changes: 61 additions & 0 deletions indexer/postgres/delete.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package postgres

import (
"context"
"fmt"
"io"
"strings"
)

// delete deletes the row with the provided key from the table.
func (tm *objectIndexer) delete(ctx context.Context, conn dbConn, key interface{}) error {
buf := new(strings.Builder)
var params []interface{}
var err error
if !tm.options.disableRetainDeletions && tm.typ.RetainDeletions {
params, err = tm.retainDeleteSqlAndParams(buf, key)
} else {
params, err = tm.deleteSqlAndParams(buf, key)
}
if err != nil {
return err
}

sqlStr := buf.String()
tm.options.logger.Info("Delete", "sql", sqlStr, "params", params)
_, err = conn.ExecContext(ctx, sqlStr, params...)
return err
}

// deleteSqlAndParams generates a DELETE statement and binding parameters for the provided key.
func (tm *objectIndexer) deleteSqlAndParams(w io.Writer, key interface{}) ([]interface{}, error) {
_, err := fmt.Fprintf(w, "DELETE FROM %q", tm.tableName())
if err != nil {
return nil, err
}

_, keyParams, err := tm.whereSqlAndParams(w, key, 1)
if err != nil {
return nil, err
}

_, err = fmt.Fprintf(w, ";")
return keyParams, err
}

// retainDeleteSqlAndParams generates an UPDATE statement to set the _deleted column to true for the provided key
// which is used when the table is set to retain deletions mode.
func (tm *objectIndexer) retainDeleteSqlAndParams(w io.Writer, key interface{}) ([]interface{}, error) {
_, err := fmt.Fprintf(w, "UPDATE %q SET _deleted = TRUE", tm.tableName())
if err != nil {
return nil, err
}

_, keyParams, err := tm.whereSqlAndParams(w, key, 1)
if err != nil {
return nil, err
}

_, err = fmt.Fprintf(w, ";")
return keyParams, err
}
2 changes: 2 additions & 0 deletions indexer/postgres/indexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ func StartIndexer(params indexer.InitParams) (indexer.InitResult, error) {
opts := options{
disableRetainDeletions: config.DisableRetainDeletions,
logger: params.Logger,
addressCodec: params.AddressCodec,
}

idx := &indexerImpl{
Expand All @@ -85,6 +86,7 @@ func StartIndexer(params indexer.InitParams) (indexer.InitResult, error) {

return indexer.InitResult{
Listener: idx.listener(),
View: idx,
}, nil
}

Expand Down
116 changes: 116 additions & 0 deletions indexer/postgres/insert_update.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package postgres

import (
"context"
"fmt"
"io"
"strings"
)

// insertUpdate inserts or updates the row with the provided key and value.
func (tm *objectIndexer) insertUpdate(ctx context.Context, conn dbConn, key, value interface{}) error {
exists, err := tm.exists(ctx, conn, key)
if err != nil {
return err
}

buf := new(strings.Builder)
var params []interface{}
if exists {
if len(tm.typ.ValueFields) == 0 {
// special case where there are no value fields, so we can't update anything
return nil
}

params, err = tm.updateSql(buf, key, value)
} else {
params, err = tm.insertSql(buf, key, value)
}
if err != nil {
return err
}

sqlStr := buf.String()
if tm.options.logger != nil {
tm.options.logger.Debug("Insert or Update", "sql", sqlStr, "params", params)
}
_, err = conn.ExecContext(ctx, sqlStr, params...)
return err
}

// insertSql generates an INSERT statement and binding parameters for the provided key and value.
func (tm *objectIndexer) insertSql(w io.Writer, key, value interface{}) ([]interface{}, error) {
keyParams, keyCols, err := tm.bindKeyParams(key)
if err != nil {
return nil, err
}

valueParams, valueCols, err := tm.bindValueParams(value)
if err != nil {
return nil, err
}

var allParams []interface{}
allParams = append(allParams, keyParams...)
allParams = append(allParams, valueParams...)

allCols := make([]string, 0, len(keyCols)+len(valueCols))
allCols = append(allCols, keyCols...)
allCols = append(allCols, valueCols...)

var paramBindings []string
for i := 1; i <= len(allCols); i++ {
paramBindings = append(paramBindings, fmt.Sprintf("$%d", i))
}

_, err = fmt.Fprintf(w, "INSERT INTO %q (%s) VALUES (%s);", tm.tableName(),
strings.Join(allCols, ", "),
strings.Join(paramBindings, ", "),
)
return allParams, err
}

// updateSql generates an UPDATE statement and binding parameters for the provided key and value.
func (tm *objectIndexer) updateSql(w io.Writer, key, value interface{}) ([]interface{}, error) {
_, err := fmt.Fprintf(w, "UPDATE %q SET ", tm.tableName())
if err != nil {
return nil, err
}

valueParams, valueCols, err := tm.bindValueParams(value)
if err != nil {
return nil, err
}

paramIdx := 1
for i, col := range valueCols {
if i > 0 {
_, err = fmt.Fprintf(w, ", ")
if err != nil {
return nil, err
}
}
_, err = fmt.Fprintf(w, "%s = $%d", col, paramIdx)
if err != nil {
return nil, err
}

paramIdx++
}

if !tm.options.disableRetainDeletions && tm.typ.RetainDeletions {
_, err = fmt.Fprintf(w, ", _deleted = FALSE")
if err != nil {
return nil, err
}
}

_, keyParams, err := tm.whereSqlAndParams(w, key, paramIdx)
if err != nil {
return nil, err
}

allParams := append(valueParams, keyParams...)
_, err = fmt.Fprintf(w, ";")
return allParams, err
}
28 changes: 28 additions & 0 deletions indexer/postgres/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,34 @@ func (i *indexerImpl) listener() appdata.Listener {
_, err := i.tx.Exec("INSERT INTO block (number) VALUES ($1)", data.Height)
return err
},
OnObjectUpdate: func(data appdata.ObjectUpdateData) error {
module := data.ModuleName
mod, ok := i.modules[module]
if !ok {
return fmt.Errorf("module %s not initialized", module)
}

for _, update := range data.Updates {
if i.logger != nil {
i.logger.Debug("OnObjectUpdate", "module", module, "type", update.TypeName, "key", update.Key, "delete", update.Delete, "value", update.Value)
}
tm, ok := mod.tables[update.TypeName]
if !ok {
return fmt.Errorf("object type %s not found in schema for module %s", update.TypeName, module)
}

var err error
if update.Delete {
err = tm.delete(i.ctx, i.tx, update.Key)
} else {
err = tm.insertUpdate(i.ctx, i.tx, update.Key, update.Value)
}
if err != nil {
return err
}
}
return nil
},
Commit: func(data appdata.CommitData) (func() error, error) {
err := i.tx.Commit()
if err != nil {
Expand Down
8 changes: 7 additions & 1 deletion indexer/postgres/options.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package postgres

import "cosmossdk.io/schema/logutil"
import (
"cosmossdk.io/schema/addressutil"
"cosmossdk.io/schema/logutil"
)

// options are the options for module and object indexers.
type options struct {
Expand All @@ -9,4 +12,7 @@ type options struct {

// logger is the logger for the indexer to use. It may be nil.
logger logutil.Logger

// addressCodec is the codec for encoding and decoding addresses. It is expected to be non-nil.
addressCodec addressutil.AddressCodec
}
116 changes: 116 additions & 0 deletions indexer/postgres/params.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package postgres

import (
"fmt"
"time"

"cosmossdk.io/schema"
)

// bindKeyParams binds the key to the key columns.
func (tm *objectIndexer) bindKeyParams(key interface{}) ([]interface{}, []string, error) {
n := len(tm.typ.KeyFields)
if n == 0 {
// singleton, set _id = 1
return []interface{}{1}, []string{"_id"}, nil
} else if n == 1 {
return tm.bindParams(tm.typ.KeyFields, []interface{}{key})
} else {
key, ok := key.([]interface{})
if !ok {
return nil, nil, fmt.Errorf("expected key to be a slice")
}

return tm.bindParams(tm.typ.KeyFields, key)
}
}

func (tm *objectIndexer) bindValueParams(value interface{}) (params []interface{}, valueCols []string, err error) {
n := len(tm.typ.ValueFields)
if n == 0 {
return nil, nil, nil
} else if valueUpdates, ok := value.(schema.ValueUpdates); ok {
var e error
var fields []schema.Field
var params []interface{}
if err := valueUpdates.Iterate(func(name string, value interface{}) bool {
field, ok := tm.valueFields[name]
if !ok {
e = fmt.Errorf("unknown column %q", name)
return false
}
fields = append(fields, field)
params = append(params, value)
return true
}); err != nil {
return nil, nil, err
}
if e != nil {
return nil, nil, e
}

return tm.bindParams(fields, params)
} else if n == 1 {
return tm.bindParams(tm.typ.ValueFields, []interface{}{value})
} else {
values, ok := value.([]interface{})
if !ok {
return nil, nil, fmt.Errorf("expected values to be a slice")
}

return tm.bindParams(tm.typ.ValueFields, values)
}
}

func (tm *objectIndexer) bindParams(fields []schema.Field, values []interface{}) ([]interface{}, []string, error) {
names := make([]string, 0, len(fields))
params := make([]interface{}, 0, len(fields))
for i, field := range fields {
if i >= len(values) {
return nil, nil, fmt.Errorf("missing value for field %q", field.Name)
}

param, err := tm.bindParam(field, values[i])
if err != nil {
return nil, nil, err
}

name, err := tm.updatableColumnName(field)
if err != nil {
return nil, nil, err
}

names = append(names, name)
params = append(params, param)
}
return params, names, nil
}

func (tm *objectIndexer) bindParam(field schema.Field, value interface{}) (param interface{}, err error) {
param = value
if value == nil {
if !field.Nullable {
return nil, fmt.Errorf("expected non-null value for field %q", field.Name)
}
} else if field.Kind == schema.TimeKind {
t, ok := value.(time.Time)
if !ok {
return nil, fmt.Errorf("expected time.Time value for field %q, got %T", field.Name, value)
}

param = t.UnixNano()
} else if field.Kind == schema.DurationKind {
t, ok := value.(time.Duration)
if !ok {
return nil, fmt.Errorf("expected time.Duration value for field %q, got %T", field.Name, value)
}

param = int64(t)
} else if field.Kind == schema.AddressKind {
param, err = tm.options.addressCodec.BytesToString(value.([]byte))
if err != nil {
return nil, fmt.Errorf("address encoding failed for field %q: %w", field.Name, err)
}
}
return
}
Loading

0 comments on commit 292d7b4

Please sign in to comment.