Skip to content

Commit

Permalink
feat: add custom sql support and several json helper function
Browse files Browse the repository at this point in the history
  • Loading branch information
jiangwei2514 committed Jul 6, 2024
1 parent 47b22f0 commit 68c18e9
Show file tree
Hide file tree
Showing 4 changed files with 750 additions and 20 deletions.
19 changes: 18 additions & 1 deletion builder/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ var (
errLockModeValueType = errors.New(`[builder] the value of "_lockMode" must be of string type`)
errNotAllowedLockMode = errors.New(`[builder] the value of "_lockMode" is not allowed`)
errLimitType = errors.New(`[builder] the value of "_limit" must be one of int,uint,int64,uint64`)
errCustomValueType = errors.New(`[builder] the value of "_custom_" must impl Comparable`)

errWhereInterfaceSliceType = `[builder] the value of "xxx %s" must be of []interface{} type`
errEmptySliceCondition = `[builder] the value of "%s" must contain at least one element`
Expand Down Expand Up @@ -252,7 +253,15 @@ func getWhereConditions(where map[string]interface{}, ignoreKeys map[string]stru
var comparables []Comparable
var field, operator string
var err error
for key, val := range where {
// to keep the result in certain order
keys := make([]string, 0, len(where))
for key := range where {
keys = append(keys, key)
}
defaultSortAlgorithm(keys)

for _, key := range keys {
val := where[key]
if _, ok := ignoreKeys[key]; ok {
continue
}
Expand All @@ -278,6 +287,14 @@ func getWhereConditions(where map[string]interface{}, ignoreKeys map[string]stru
comparables = append(comparables, OrWhere(orWhereComparable))
continue
}
if strings.HasPrefix(key, "_custom_") {
v, ok := val.(Comparable)
if !ok {
return nil, errCustomValueType
}
comparables = append(comparables, v)
continue
}
field, operator, err = splitKey(key, val)
if nil != err {
return nil, err
Expand Down
49 changes: 30 additions & 19 deletions builder/dao.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ var (
}
)

//the order of a map is unpredicatable so we need a sort algorithm to sort the fields
//and make it predicatable
// the order of a map is unpredicatable so we need a sort algorithm to sort the fields
// and make it predicatable
var (
defaultSortAlgorithm = sort.Strings
)

//Comparable requires type implements the Build method
// Comparable requires type implements the Build method
type Comparable interface {
Build() ([]string, []interface{})
}
Expand Down Expand Up @@ -123,58 +123,58 @@ func (l NotLike) Build() ([]string, []interface{}) {
return cond, vals
}

//Eq means equal(=)
// Eq means equal(=)
type Eq map[string]interface{}

//Build implements the Comparable interface
// Build implements the Comparable interface
func (e Eq) Build() ([]string, []interface{}) {
return build(e, "=")
}

//Ne means Not Equal(!=)
// Ne means Not Equal(!=)
type Ne map[string]interface{}

//Build implements the Comparable interface
// Build implements the Comparable interface
func (n Ne) Build() ([]string, []interface{}) {
return build(n, "!=")
}

//Lt means less than(<)
// Lt means less than(<)
type Lt map[string]interface{}

//Build implements the Comparable interface
// Build implements the Comparable interface
func (l Lt) Build() ([]string, []interface{}) {
return build(l, "<")
}

//Lte means less than or equal(<=)
// Lte means less than or equal(<=)
type Lte map[string]interface{}

//Build implements the Comparable interface
// Build implements the Comparable interface
func (l Lte) Build() ([]string, []interface{}) {
return build(l, "<=")
}

//Gt means greater than(>)
// Gt means greater than(>)
type Gt map[string]interface{}

//Build implements the Comparable interface
// Build implements the Comparable interface
func (g Gt) Build() ([]string, []interface{}) {
return build(g, ">")
}

//Gte means greater than or equal(>=)
// Gte means greater than or equal(>=)
type Gte map[string]interface{}

//Build implements the Comparable interface
// Build implements the Comparable interface
func (g Gte) Build() ([]string, []interface{}) {
return build(g, ">=")
}

//In means in
// In means in
type In map[string][]interface{}

//Build implements the Comparable interface
// Build implements the Comparable interface
func (i In) Build() ([]string, []interface{}) {
if nil == i || 0 == len(i) {
return nil, nil
Expand All @@ -199,10 +199,10 @@ func buildIn(field string, vals []interface{}) (cond string) {
return
}

//NotIn means not in
// NotIn means not in
type NotIn map[string][]interface{}

//Build implements the Comparable interface
// Build implements the Comparable interface
func (i NotIn) Build() ([]string, []interface{}) {
if nil == i || 0 == len(i) {
return nil, nil
Expand Down Expand Up @@ -416,6 +416,17 @@ func resolveUpdate(update map[string]interface{}) (sets string, vals []interface
sb.WriteString(fmt.Sprintf("%s=%s,", k, v))
continue
}
if strings.HasPrefix(k, "_custom_") {
if custom, ok := v.(Comparable); ok {
sql, val := custom.Build()
for _, s := range sql {
sb.WriteString(s)
sb.WriteByte(',')
}
vals = append(vals, val...)
}
continue
}
vals = append(vals, v)
sb.WriteString(fmt.Sprintf("%s=?,", quoteField(k)))
}
Expand Down
182 changes: 182 additions & 0 deletions builder/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import (
"context"
"database/sql"
"reflect"
"sort"
"strconv"
"strings"
)

// AggregateQuery is a helper function to execute the aggregate query and return the result
Expand Down Expand Up @@ -174,3 +176,183 @@ func isZero(v reflect.Value) bool {
}
return true
}

type rawSql struct {
sqlCond string
values []interface{}
}

func (r rawSql) Build() ([]string, []interface{}) {
return []string{r.sqlCond}, r.values
}

func Custom(query string, args ...interface{}) Comparable {
return rawSql{sqlCond: query, values: args}
}

// JsonContains aim to check target json contains all items in given obj;if check certain value just use direct
// where := map[string]interface{}{"your_json_field.'$.path_to_key' =": val}
//
// notice: fullJsonPath should hard code, never from user input;
// jsonLike only support json element like array,map,string,number etc., struct input will result panic!!!
//
// usage where := map[string]interface{}{"_custom_xxx": builder.JsonContains("my_json->'$.my_data.list'", 7)}
//
// usage where := map[string]interface{}{"_custom_xxx": builder.JsonContains("my_json->'$'", []int{1,2})}
//
// usage where := map[string]interface{}{"_custom_xxx": builder.JsonContains("my_json->'$.user_info'", map[string]any{"name": "", "age": 18})}
func JsonContains(fullJsonPath string, jsonLike interface{}) Comparable {
// MEMBER OF cant not deal null in json array
if jsonLike == nil {
return rawSql{
sqlCond: "JSON_CONTAINS(" + fullJsonPath + ",'null')",
values: nil,
}
}

s, v := genJsonObj(jsonLike)
// jsonLike is number, string, bool
_, ok := jsonLike.(string) // this check avoid eg jsonLike "JSONa"
if ok || !strings.HasPrefix(s, "JSON") {
return rawSql{
sqlCond: "(" + s + " MEMBER OF(" + fullJsonPath + "))",
values: v,
}
}
// jsonLike is array or map
return rawSql{
sqlCond: "JSON_CONTAINS(" + fullJsonPath + "," + s + ")",
values: v,
}
}

// JsonSet aim to simply set/update json field operation;
//
// notice: jsonPath should hard code, never from user input;
//
// usage update := map[string]interface{}{"_custom_xxx": builder.JsonSet(field, "$.code", 1, "$.user_info", map[string]any{"name": "", "age": 18})}
func JsonSet(field string, pathAndValuePair ...interface{}) Comparable {
return jsonUpdateCall("JSON_SET", field, pathAndValuePair...)
}

// JsonArrayAppend gen JsonObj and call MySQL JSON_ARRAY_APPEND function;
// usage update := map[string]interface{}{"_custom_xxx": builder.JsonArrayAppend(field, "$", 1, "$[last]", []string{"2","3"}}
func JsonArrayAppend(field string, pathAndValuePair ...interface{}) Comparable {
return jsonUpdateCall("JSON_ARRAY_APPEND", field, pathAndValuePair...)
}

// JsonArrayInsert gen JsonObj and call MySQL JSON_ARRAY_INSERT function; insert at index
// usage update := map[string]interface{}{"_custom_xxx": builder.JsonArrayInsert(field, "$[0]", 1, "$[0]", []string{"2","3"}}
func JsonArrayInsert(field string, pathAndValuePair ...interface{}) Comparable {
return jsonUpdateCall("JSON_ARRAY_INSERT", field, pathAndValuePair...)
}

// JsonRemove call MySQL JSON_REMOVE function; remove element from Array or Map
// path removed in order, prev remove affect the later operation, maybe the array shrink
//
// remove last array element; update := map[string]interface{}{"_custom_xxx":builder.JsonRemove(field,'$.list[last]')}
// remove element; update := map[string]interface{}{"_custom_xxx":builder.JsonRemove(field,'$.key0')}
func JsonRemove(field string, path ...string) Comparable {
if len(path) == 0 {
// do nothing, update xxx set a=a;
return rawSql{
sqlCond: field + "=" + field,
values: nil,
}
}
return rawSql{
sqlCond: field + "=JSON_REMOVE(" + field + ",'" + strings.Join(path, "','") + "')",
values: nil,
}
}

// jsonUpdateCall build args then call fn
func jsonUpdateCall(fn string, field string, pathAndValuePair ...interface{}) Comparable {
if len(pathAndValuePair) == 0 || len(pathAndValuePair)%2 != 0 {
return rawSql{sqlCond: field, values: nil}
}
val := make([]interface{}, 0, len(pathAndValuePair)/2)
var buf strings.Builder
buf.WriteString(field)
buf.WriteByte('=')
buf.WriteString(fn + "(")
buf.WriteString(field)
for i := 0; i < len(pathAndValuePair); i += 2 {
buf.WriteString(",'")
buf.WriteString(pathAndValuePair[i].(string))
buf.WriteString("',")

jsonSql, jsonVals := genJsonObj(pathAndValuePair[i+1])
buf.WriteString(jsonSql)
val = append(val, jsonVals...)
}
buf.WriteByte(')')

return rawSql{
sqlCond: buf.String(),
values: val,
}
}

// genJsonObj build MySQL JSON object using JSON_ARRAY, JSON_OBJECT or ?; return sql string and args
func genJsonObj(obj interface{}) (string, []interface{}) {
if obj == nil {
return "null", nil
}
rValue := reflect.Indirect(reflect.ValueOf(obj))
rType := rValue.Kind()
var s []string
var vals []interface{}
switch rType {
case reflect.Array, reflect.Slice:
s = append(s, "JSON_ARRAY(")
length := rValue.Len()
for i := 0; i < length; i++ {
subS, subVals := genJsonObj(rValue.Index(i).Interface())
s = append(s, subS, ",")
vals = append(vals, subVals...)
}

if s[len(s)-1] == "," {
s[len(s)-1] = ")"
} else { // empty slice
s = append(s, ")")
}
case reflect.Map:
s = append(s, "JSON_OBJECT(")
// sort keys in map to keep generate result same.
keys := rValue.MapKeys()
sort.Slice(keys, func(i, j int) bool {
return keys[i].String() < keys[j].String()
})
length := rValue.Len()
for i := 0; i < length; i++ {
k := keys[i]
v := rValue.MapIndex(k)
subS, subVals := genJsonObj(v.Interface())
s = append(s, "?,", subS, ",")
vals = append(vals, k.String())
vals = append(vals, subVals...)
}

if s[len(s)-1] == "," {
s[len(s)-1] = ")"
} else { // empty map
s = append(s, ")")
}

case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64,
reflect.String:
return "?", []interface{}{rValue.Interface()}
case reflect.Bool:
if rValue.Bool() {
return "true", nil
}
return "false", nil
default:
panic("genJsonObj not support type: " + rType.String())
}
return strings.Join(s, ""), vals
}
Loading

0 comments on commit 68c18e9

Please sign in to comment.