diff --git a/engine_test.go b/engine_test.go index 0e0873246..a5dfb8c17 100644 --- a/engine_test.go +++ b/engine_test.go @@ -418,6 +418,10 @@ var queries = []struct { {"c", int32(0)}, }, }, + { + `SELECT -1`, + []sql.Row{{int64(-1)}}, + }, } func TestQueries(t *testing.T) { diff --git a/sql/expression/arithmetic.go b/sql/expression/arithmetic.go index 21ef18009..005bd6260 100644 --- a/sql/expression/arithmetic.go +++ b/sql/expression/arithmetic.go @@ -2,6 +2,7 @@ package expression import ( "fmt" + "reflect" errors "gopkg.in/src-d/go-errors.v1" "gopkg.in/src-d/go-vitess.v1/vt/sqlparser" @@ -407,3 +408,81 @@ func mod(lval, rval interface{}) (interface{}, error) { return nil, errUnableToCast.New(lval, rval) } + +// UnaryMinus is an unary minus operator. +type UnaryMinus struct { + UnaryExpression +} + +// NewUnaryMinus creates a new UnaryMinus expression node. +func NewUnaryMinus(child sql.Expression) *UnaryMinus { + return &UnaryMinus{UnaryExpression{Child: child}} +} + +// Eval implements the sql.Expression interface. +func (e *UnaryMinus) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + child, err := e.Child.Eval(ctx, row) + if err != nil { + return nil, err + } + + if child == nil { + return nil, nil + } + + if !sql.IsNumber(e.Child.Type()) { + child, err = sql.Float64.Convert(child) + if err != nil { + child = 0.0 + } + } + + switch n := child.(type) { + case float64: + return -n, nil + case float32: + return -n, nil + case int64: + return -n, nil + case uint64: + return -int64(n), nil + case int32: + return -n, nil + case uint32: + return -int32(n), nil + default: + return nil, sql.ErrInvalidType.New(reflect.TypeOf(n)) + } +} + +// Type implements the sql.Expression interface. +func (e *UnaryMinus) Type() sql.Type { + typ := e.Child.Type() + if !sql.IsNumber(typ) { + return sql.Float64 + } + + if typ == sql.Uint32 { + return sql.Int32 + } + + if typ == sql.Uint64 { + return sql.Int64 + } + + return e.Child.Type() +} + +func (e *UnaryMinus) String() string { + return fmt.Sprintf("-%s", e.Child) +} + +// TransformUp implements the sql.Expression interface. +func (e *UnaryMinus) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { + c, err := e.Child.TransformUp(f) + if err != nil { + return nil, err + } + + return f(NewUnaryMinus(c)) +} diff --git a/sql/expression/arithmetic_test.go b/sql/expression/arithmetic_test.go index 02809687d..d56047504 100644 --- a/sql/expression/arithmetic_test.go +++ b/sql/expression/arithmetic_test.go @@ -360,3 +360,31 @@ func TestAllInt64(t *testing.T) { }) } } + +func TestUnaryMinus(t *testing.T) { + testCases := []struct { + name string + input interface{} + typ sql.Type + expected interface{} + }{ + {"int32", int32(1), sql.Int32, int32(-1)}, + {"uint32", uint32(1), sql.Uint32, int32(-1)}, + {"int64", int64(1), sql.Int64, int64(-1)}, + {"uint64", uint64(1), sql.Uint64, int64(-1)}, + {"float32", float32(1), sql.Float32, float32(-1)}, + {"float64", float64(1), sql.Float64, float64(-1)}, + {"int text", "1", sql.Text, float64(-1)}, + {"float text", "1.2", sql.Text, float64(-1.2)}, + {"nil", nil, sql.Text, nil}, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + f := NewUnaryMinus(NewLiteral(tt.input, tt.typ)) + result, err := f.Eval(sql.NewEmptyContext(), nil) + require.NoError(t, err) + require.Equal(t, tt.expected, result) + }) + } +} diff --git a/sql/parse/parse.go b/sql/parse/parse.go index 3d38c3d90..738ab59fb 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -737,6 +737,8 @@ func exprToExpression(e sqlparser.Expr) (sql.Expression, error) { case *sqlparser.BinaryExpr: return binaryExprToExpression(v) + case *sqlparser.UnaryExpr: + return unaryExprToExpression(v) } } @@ -882,6 +884,21 @@ func selectExprToExpression(se sqlparser.SelectExpr) (sql.Expression, error) { } } +func unaryExprToExpression(e *sqlparser.UnaryExpr) (sql.Expression, error) { + switch e.Operator { + case sqlparser.MinusStr: + expr, err := exprToExpression(e.Expr) + if err != nil { + return nil, err + } + + return expression.NewUnaryMinus(expr), nil + + default: + return nil, ErrUnsupportedFeature.New("unary operator: " + e.Operator) + } +} + func binaryExprToExpression(be *sqlparser.BinaryExpr) (sql.Expression, error) { switch be.Operator { case