Skip to content

Commit

Permalink
Merge pull request #12 from Paperchain/develop
Browse files Browse the repository at this point in the history
Adds new insert method with runtime Primary Key population
  • Loading branch information
shashankgroovy authored Sep 8, 2021
2 parents cacf2c4 + 57496f8 commit e68115e
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 94 deletions.
8 changes: 3 additions & 5 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
module github.com/Paperchain/papergres

go 1.12
go 1.16

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/jmoiron/sqlx v1.3.4
github.com/lib/pq v1.10.2
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/testify v1.2.2
github.com/lib/pq v1.10.3
github.com/stretchr/testify v1.7.0
)
15 changes: 11 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs=
github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/jmoiron/sqlx v1.3.4 h1:wv+0IJZfL5z0uZoUjlpKgHkgaFSYD+r9CfrXjEXsO7w=
github.com/jmoiron/sqlx v1.3.4/go.mod h1:2BljVx/86SuTyjE+aPYlHCTNvZrnJXghYGpNiXLBMCQ=
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8=
github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lib/pq v1.10.3 h1:v9QZf2Sn6AmjXtQeFpdoq/eaNtYP6IN+7lcrygsIAtg=
github.com/lib/pq v1.10.3/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg=
github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
165 changes: 89 additions & 76 deletions papergres.go
Original file line number Diff line number Diff line change
Expand Up @@ -407,20 +407,31 @@ func (db *Database) GenerateInsert(obj interface{}) *Query {
return db.Schema("public").GenerateInsert(obj)
}

// Insert inserts the passed in object
// Insert inserts the passed in object in DB.
//
// NOTE: It does not account for client side generated value for a
// primary key and expects that the logic for populating value
// of primary key should reside in the database as sequence.
//
// DO NOT use Insert() if you wish to populate a client side generated
// value in primary key, use InsertWithPK() instead.
func (s *Schema) Insert(obj interface{}) *Result {
sql := insertSQL(obj, s.Name)
args := insertArgs(obj)
sql := insertSQL(obj, s.Name, false)
args := insertArgs(obj, false)
return s.Database.Query(sql, args...).Exec()
}

// GenerateInsert generates the insert query for the given object
func (s *Schema) GenerateInsert(obj interface{}) *Query {
sql := insertSQL(obj, s.Name)
args := insertArgs(obj)
q := s.Database.Query(sql, args...)
q.insert = true
return q
// InsertWithPK performs inserts on objects and persists the Primary key value
// to DB as well. It will fail to insert duplicate values to DB.
//
// NOTE: Proceed with caution!
// Only use this method you wish to persist a client side generated value in
// Primary key and don't rely on database sequenece to autogenerate
// PrimaryKey values.
func (s *Schema) InsertWithPK(obj interface{}) *Result {
sql := insertSQL(obj, s.Name, true)
args := insertArgs(obj, true)
return s.Database.Query(sql, args...).ExecNonQuery()
}

// InsertAll inserts a slice of objects concurrently.
Expand All @@ -436,89 +447,90 @@ func (s *Schema) InsertAll(objs interface{}) ([]*Result, error) {
return nil, err
}
if len(slice) == 0 {
return nil, errors.New("Empty slice")
return nil, errors.New("empty slice")
}

// sql, args := insertMultipleSQL(objs, s.Name)
// return s.Database.Query(sql, args...).ExecNonQuery(), nil

// now turn the objs into a repeat query and exec
return s.GenerateInsert(slice[0]).Repeat(len(slice),
func(i int) (dest interface{}, args []interface{}) {
args = insertArgs(slice[i])
args = insertArgs(slice[i], false)
return
}).Exec()
}

// GenerateInsert generates the insert query for the given object
func (s *Schema) GenerateInsert(obj interface{}) *Query {
return s.generateInsertQuery(obj, false)
}

// GenerateInsertWithPK generates the insert query for the given object in which
// PrimaryKey value is also supposed to be populated during insert.
func (s *Schema) GenerateInsertWithPK(obj interface{}) *Query {
return s.generateInsertQuery(obj, true)
}

// generateInsertQuery constructs an insert query for the given object
func (s *Schema) generateInsertQuery(obj interface{}, withPK bool) *Query {
sql := insertSQL(obj, s.Name, withPK)
args := insertArgs(obj, withPK)
q := s.Database.Query(sql, args...)
q.insert = true
return q
}

// insertSQL generates insert SQL string for a given object and schema
func insertSQL(obj interface{}, schema string) string {
func insertSQL(obj interface{}, schema string, withPK bool) string {
// Construct the table name prefixed with schema name
tname := goToSQLName(getTypeName(obj))
tname = fmt.Sprintf("%s.%s", schema, tname)

// Construct the first component of insert statement
sql := fmt.Sprintf("INSERT INTO %s (", tname)

fields, primary := prepareFields(obj)
// NOTE: An object is represented as a slice of Fields, where each Field
// represents a column.
// Get list of columns to populate.
fields, primary := prepareFields(obj, withPK)

// Based on the number of columns, create value placeholders
var values string
for i, f := range fields {
sql += fmt.Sprintf("\n\t%s,", getColumnName(f))
values += fmt.Sprintf("\n\t$%v,", i+1)
}

// Add source data placeholders
// something like this: `VALUES ($1, $2, $3, $4, $5, $6)`
sql = strings.TrimRight(sql, ",")
sql += "\n)\nVALUES ("
sql += values
sql = strings.TrimRight(sql, ",")
sql += "\n)\n"

// Add last line to capture primary key
sql += fmt.Sprintf("RETURNING %s as LastInsertId;", getColumnName(primary))

return sql
}

func insertMultipleSQL(entities interface{}, schema string) (string, []interface{}) {
slice, err := convertToSlice(entities)
if err != nil {
return "", nil
}

tname := goToSQLName(getTypeName(slice[0]))
tname = fmt.Sprintf("%s.%s", schema, tname)
sql := fmt.Sprintf("INSERT INTO %s (", tname)

fields, _ := prepareFields(slice[0])
for _, f := range fields {
sql += fmt.Sprintf("\n\t%s,", getColumnName(f))
}
sql = strings.TrimRight(sql, ",")
sql += "\n)\nVALUES \n"

args := make([]interface{}, len(slice)*len(fields))

counter := 0
for _, e := range slice {
fs, _ := prepareFields(e)
sql += fmt.Sprintf("(")
for _, f := range fs {
sql += fmt.Sprintf("$%v, ", counter+1)
args[counter] = f.Value
counter++
}
sql = strings.TrimRight(sql, ", ")
sql += fmt.Sprintf("),\n")
}
sql = strings.TrimRight(sql, ",\n")

return sql, args
}

// getColumnName returns a Field's associated Tag name if it is supplied.
// Else, it constructs a snake_case value from Field.Name value and returns it.
// Example:
// For a Field with `Name` as 'TransactionSource' if `db: transaction_source` is
// present in Tag then it'll be used else it'll be constructed.
func getColumnName(f *Field) string {
var columnName string
if f.Tag != "" {
columnName = f.Tag
} else {
columnName = goToSQLName(f.Name)
return columnName
}

columnName = goToSQLName(f.Name)
return columnName
}

// goToSQLName converts a string from camel case to snake case
// e.g. TransactionSource to transaction_source
func goToSQLName(name string) string {
var s string
for _, c := range name {
Expand All @@ -533,22 +545,27 @@ func goToSQLName(name string) string {
}

// insertArgs creates the insert arg slice for an object
func insertArgs(obj interface{}) []interface{} {
final, _ := prepareFields(obj)
func insertArgs(obj interface{}, withPK bool) []interface{} {
final, _ := prepareFields(obj, withPK)
args := make([]interface{}, len(final))
for i, f := range final {
args[i] = f.Value
}
return args
}

// prepareFields performs necessary transformations for the insert statement
func prepareFields(obj interface{}) (nfields []*Field, primary *Field) {
// prepareFields performs necessary transformations for the insert statement.
// If `withPK` is false: It does not account a primary key Field to list of
// fields to append, else, primary key is also considered.
func prepareFields(obj interface{}, withPK bool) (nfields []*Field, primary *Field) {
fields := fields(obj)
for i, f := range fields {
if i == 0 {
for _, f := range fields {

if f.IsPrimary {
primary = f
continue
if !withPK {
continue
}
}
nfields = append(nfields, f)
}
Expand Down Expand Up @@ -688,7 +705,14 @@ func exec(q *Query, nonQuery bool) *Result {
if err != nil {
return err
}
meta.RowsAffected, _ = res.RowsAffected()
meta.LastInsertId, err = res.LastInsertId()
if err != nil {
return err
}
meta.RowsAffected, err = res.RowsAffected()
if err != nil {
return err
}
r.setMeta(meta)
return nil
}
Expand Down Expand Up @@ -724,10 +748,10 @@ func (r *Result) setMeta(m meta) {
r.LastInsertId.ID = m.LastInsertId
r.RowsAffected.Count = m.RowsAffected
if m.LastInsertId == 0 {
r.LastInsertId.Err = errors.New("No LastInsertId returned")
r.LastInsertId.Err = errors.New("no LastInsertId returned")
}
if m.RowsAffected == -1 {
r.RowsAffected.Err = errors.New("No RowsAffected returned")
r.RowsAffected.Err = errors.New("no RowsAffected returned")
}
}

Expand Down Expand Up @@ -801,17 +825,6 @@ func getLen(i interface{}) int {
return 0
}

// cutFirstIndex cuts the string on the first occurrence of the sep.
// cutFirstIndex("hey.o", ".") => ("hey", "o")
// if index not found, returns (s, "")
func cutFirstIndex(s, sep string) (first, rest string) {
idx := strings.Index(s, sep)
if idx == -1 {
return s, ""
}
return s[:idx], s[idx+1:]
}

/* --- Start Connection String Functionality --- */

// SSLMode defines all possible SSL options
Expand Down
27 changes: 18 additions & 9 deletions reflect_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package papergres
import (
"errors"
"reflect"
"strconv"
)

// GetTypeName gets the type name of an object
Expand Down Expand Up @@ -45,12 +46,14 @@ func convertToSlice(v interface{}) ([]interface{}, error) {
return s, nil
}

// Field is a struct field
// Field is a struct field that represents a single entity of an object.
// To set a field as primary add `db_pk:true` to tag.
type Field struct {
Typeof string
Name string
Tag string
Value interface{}
Value interface{}
Typeof string
Name string
Tag string
IsPrimary bool
}

// Fields returns a struct's fields and their values
Expand All @@ -64,13 +67,19 @@ func fields(v interface{}) []*Field {

fields := make([]*Field, val.NumField())
vtype := val.Type()

for i := 0; i < val.NumField(); i++ {
f := val.Field(i)

// Get primary key value from db_pk tag
isPrimary, _ := strconv.ParseBool(vtype.Field(i).Tag.Get("db_pk"))

field := &Field{
Typeof: getTypeName(f.Interface()),
Name: vtype.Field(i).Name,
Value: f.Interface(),
Tag: vtype.Field(i).Tag.Get("db"),
Value: f.Interface(),
Typeof: getTypeName(f.Interface()),
Name: vtype.Field(i).Name,
Tag: vtype.Field(i).Tag.Get("db"),
IsPrimary: isPrimary,
}
fields[i] = field
}
Expand Down

0 comments on commit e68115e

Please sign in to comment.