Skip to content

Commit

Permalink
fix(mysql): escape backslash char in strings
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Nov 11, 2021
1 parent e92035d commit fb32029
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 71 deletions.
46 changes: 0 additions & 46 deletions dialect/append.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package dialect

import (
"encoding/hex"
"math"
"strconv"
"unicode/utf8"

"github.com/uptrace/bun/internal"
)
Expand Down Expand Up @@ -48,50 +46,6 @@ func appendFloat(b []byte, v float64, bitSize int) []byte {
}
}

func AppendString(b []byte, s string) []byte {
b = append(b, '\'')
for _, r := range s {
if r == '\000' {
continue
}

if r == '\'' {
b = append(b, '\'', '\'')
continue
}

if r < utf8.RuneSelf {
b = append(b, byte(r))
continue
}

l := len(b)
if cap(b)-l < utf8.UTFMax {
b = append(b, make([]byte, utf8.UTFMax)...)
}
n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r)
b = b[:l+n]
}
b = append(b, '\'')
return b
}

func AppendBytes(b, bs []byte) []byte {
if bs == nil {
return AppendNull(b)
}

b = append(b, `'\x`...)

s := len(b)
b = append(b, make([]byte, hex.EncodedLen(len(bs)))...)
hex.Encode(b[s:], bs)

b = append(b, '\'')

return b
}

//------------------------------------------------------------------------------

func AppendIdent(b []byte, field string, quote byte) []byte {
Expand Down
14 changes: 0 additions & 14 deletions dialect/append_test.go

This file was deleted.

38 changes: 35 additions & 3 deletions dialect/mysqldialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"log"
"strings"
"time"
"unicode/utf8"

"golang.org/x/mod/semver"

Expand Down Expand Up @@ -82,14 +83,45 @@ func (d *Dialect) IdentQuote() byte {
return '`'
}

func (d *Dialect) AppendTime(b []byte, tm time.Time) []byte {
func (*Dialect) AppendTime(b []byte, tm time.Time) []byte {
b = append(b, '\'')
b = tm.AppendFormat(b, "2006-01-02 15:04:05.999999")
b = append(b, '\'')
return b
}

func (d *Dialect) AppendBytes(b []byte, bs []byte) []byte {
func (*Dialect) AppendString(b []byte, s string) []byte {
b = append(b, '\'')
loop:
for _, r := range s {
switch r {
case '\000':
continue loop
case '\'':
b = append(b, "''"...)
continue loop
case '\\':
b = append(b, '\\', '\\')
continue loop
}

if r < utf8.RuneSelf {
b = append(b, byte(r))
continue
}

l := len(b)
if cap(b)-l < utf8.UTFMax {
b = append(b, make([]byte, utf8.UTFMax)...)
}
n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r)
b = b[:l+n]
}
b = append(b, '\'')
return b
}

func (*Dialect) AppendBytes(b []byte, bs []byte) []byte {
if bs == nil {
return dialect.AppendNull(b)
}
Expand All @@ -105,7 +137,7 @@ func (d *Dialect) AppendBytes(b []byte, bs []byte) []byte {
return b
}

func (d *Dialect) AppendJSON(b, jsonb []byte) []byte {
func (*Dialect) AppendJSON(b, jsonb []byte) []byte {
b = append(b, '\'')

for _, c := range jsonb {
Expand Down
5 changes: 5 additions & 0 deletions internal/dbtest/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,11 @@ func testSelectScan(t *testing.T, db *bun.DB) {
err = db.NewSelect().TableExpr("(SELECT 10) AS t").Where("FALSE").Scan(ctx, &num)
require.Equal(t, sql.ErrNoRows, err)

var str string
err = db.NewSelect().ColumnExpr("?", "\\\"'hello\n%_").Scan(ctx, &str)
require.NoError(t, err)
require.Equal(t, "\\\"'hello\n%_", str)

var flag bool
err = db.NewSelect().
ColumnExpr("EXISTS (?)", db.NewSelect().ColumnExpr("1")).
Expand Down
3 changes: 1 addition & 2 deletions query_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"strings"
"sync"

"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema"
)
Expand Down Expand Up @@ -542,7 +541,7 @@ func (q *SelectQuery) appendColumns(fmter schema.Formatter, b []byte) (_ []byte,
if len(q.table.Fields) > 10 && fmter.IsNop() {
b = append(b, q.table.SQLAlias...)
b = append(b, '.')
b = dialect.AppendString(b, fmt.Sprintf("%d columns", len(q.table.Fields)))
b = fmter.Dialect().AppendString(b, fmt.Sprintf("%d columns", len(q.table.Fields)))
} else {
b = appendColumns(b, q.table.SQLAlias, q.table.Fields)
}
Expand Down
2 changes: 1 addition & 1 deletion schema/append.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func Append(fmter Formatter, b []byte, v interface{}) []byte {
case float64:
return dialect.AppendFloat64(b, v)
case string:
return dialect.AppendString(b, v)
return fmter.Dialect().AppendString(b, v)
case time.Time:
return fmter.Dialect().AppendTime(b, v)
case []byte:
Expand Down
8 changes: 4 additions & 4 deletions schema/append_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ func appendArrayBytesValue(fmter Formatter, b []byte, v reflect.Value) []byte {
}

func AppendStringValue(fmter Formatter, b []byte, v reflect.Value) []byte {
return dialect.AppendString(b, v.String())
return fmter.Dialect().AppendString(b, v.String())
}

func AppendJSONValue(fmter Formatter, b []byte, v reflect.Value) []byte {
Expand All @@ -217,20 +217,20 @@ func appendTimeValue(fmter Formatter, b []byte, v reflect.Value) []byte {

func appendIPValue(fmter Formatter, b []byte, v reflect.Value) []byte {
ip := v.Interface().(net.IP)
return dialect.AppendString(b, ip.String())
return fmter.Dialect().AppendString(b, ip.String())
}

func appendIPNetValue(fmter Formatter, b []byte, v reflect.Value) []byte {
ipnet := v.Interface().(net.IPNet)
return dialect.AppendString(b, ipnet.String())
return fmter.Dialect().AppendString(b, ipnet.String())
}

func appendJSONRawMessageValue(fmter Formatter, b []byte, v reflect.Value) []byte {
bytes := v.Bytes()
if bytes == nil {
return dialect.AppendNull(b)
}
return dialect.AppendString(b, internal.String(bytes))
return fmter.Dialect().AppendString(b, internal.String(bytes))
}

func appendQueryAppenderValue(fmter Formatter, b []byte, v reflect.Value) []byte {
Expand Down
45 changes: 44 additions & 1 deletion schema/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package schema

import (
"database/sql"
"encoding/hex"
"strconv"
"time"
"unicode/utf8"

"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/dialect/feature"
Expand All @@ -24,6 +26,7 @@ type Dialect interface {
AppendUint32(b []byte, n uint32) []byte
AppendUint64(b []byte, n uint64) []byte
AppendTime(b []byte, tm time.Time) []byte
AppendString(b []byte, s string) []byte
AppendBytes(b []byte, bs []byte) []byte
AppendJSON(b, jsonb []byte) []byte
}
Expand All @@ -47,8 +50,48 @@ func (BaseDialect) AppendTime(b []byte, tm time.Time) []byte {
return b
}

func (BaseDialect) AppendString(b []byte, s string) []byte {
b = append(b, '\'')
for _, r := range s {
if r == '\000' {
continue
}

if r == '\'' {
b = append(b, '\'', '\'')
continue
}

if r < utf8.RuneSelf {
b = append(b, byte(r))
continue
}

l := len(b)
if cap(b)-l < utf8.UTFMax {
b = append(b, make([]byte, utf8.UTFMax)...)
}
n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r)
b = b[:l+n]
}
b = append(b, '\'')
return b
}

func (BaseDialect) AppendBytes(b, bs []byte) []byte {
return dialect.AppendBytes(b, bs)
if bs == nil {
return dialect.AppendNull(b)
}

b = append(b, `'\x`...)

s := len(b)
b = append(b, make([]byte, hex.EncodedLen(len(bs)))...)
hex.Encode(b[s:], bs)

b = append(b, '\'')

return b
}

func (BaseDialect) AppendJSON(b, jsonb []byte) []byte {
Expand Down

0 comments on commit fb32029

Please sign in to comment.