Skip to content
This repository has been archived by the owner on Jan 28, 2021. It is now read-only.

Commit

Permalink
Merge pull request #456 from erizocosmico/feature/negative-numbers
Browse files Browse the repository at this point in the history
sql/(parse,expression): implement unary minus
  • Loading branch information
ajnavarro authored Oct 19, 2018
2 parents b1203b4 + 03bb18b commit 4095f3d
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 0 deletions.
4 changes: 4 additions & 0 deletions engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,10 @@ var queries = []struct {
"CREATE DATABASE `mydb` /*!40100 DEFAULT CHARACTER SET utf8mb4 COLLATE utf8_bin */",
}},
},
{
`SELECT -1`,
[]sql.Row{{int64(-1)}},
},
}

func TestQueries(t *testing.T) {
Expand Down
79 changes: 79 additions & 0 deletions sql/expression/arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package expression

import (
"fmt"
"reflect"

errors "gopkg.in/src-d/go-errors.v1"
"gopkg.in/src-d/go-vitess.v1/vt/sqlparser"
Expand Down Expand Up @@ -407,3 +408,81 @@ func mod(lval, rval interface{}) (interface{}, error) {

return nil, errUnableToCast.New(lval, rval)
}

// UnaryMinus is an unary minus operator.
type UnaryMinus struct {
UnaryExpression
}

// NewUnaryMinus creates a new UnaryMinus expression node.
func NewUnaryMinus(child sql.Expression) *UnaryMinus {
return &UnaryMinus{UnaryExpression{Child: child}}
}

// Eval implements the sql.Expression interface.
func (e *UnaryMinus) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
child, err := e.Child.Eval(ctx, row)
if err != nil {
return nil, err
}

if child == nil {
return nil, nil
}

if !sql.IsNumber(e.Child.Type()) {
child, err = sql.Float64.Convert(child)
if err != nil {
child = 0.0
}
}

switch n := child.(type) {
case float64:
return -n, nil
case float32:
return -n, nil
case int64:
return -n, nil
case uint64:
return -int64(n), nil
case int32:
return -n, nil
case uint32:
return -int32(n), nil
default:
return nil, sql.ErrInvalidType.New(reflect.TypeOf(n))
}
}

// Type implements the sql.Expression interface.
func (e *UnaryMinus) Type() sql.Type {
typ := e.Child.Type()
if !sql.IsNumber(typ) {
return sql.Float64
}

if typ == sql.Uint32 {
return sql.Int32
}

if typ == sql.Uint64 {
return sql.Int64
}

return e.Child.Type()
}

func (e *UnaryMinus) String() string {
return fmt.Sprintf("-%s", e.Child)
}

// TransformUp implements the sql.Expression interface.
func (e *UnaryMinus) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) {
c, err := e.Child.TransformUp(f)
if err != nil {
return nil, err
}

return f(NewUnaryMinus(c))
}
28 changes: 28 additions & 0 deletions sql/expression/arithmetic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,3 +360,31 @@ func TestAllInt64(t *testing.T) {
})
}
}

func TestUnaryMinus(t *testing.T) {
testCases := []struct {
name string
input interface{}
typ sql.Type
expected interface{}
}{
{"int32", int32(1), sql.Int32, int32(-1)},
{"uint32", uint32(1), sql.Uint32, int32(-1)},
{"int64", int64(1), sql.Int64, int64(-1)},
{"uint64", uint64(1), sql.Uint64, int64(-1)},
{"float32", float32(1), sql.Float32, float32(-1)},
{"float64", float64(1), sql.Float64, float64(-1)},
{"int text", "1", sql.Text, float64(-1)},
{"float text", "1.2", sql.Text, float64(-1.2)},
{"nil", nil, sql.Text, nil},
}

for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
f := NewUnaryMinus(NewLiteral(tt.input, tt.typ))
result, err := f.Eval(sql.NewEmptyContext(), nil)
require.NoError(t, err)
require.Equal(t, tt.expected, result)
})
}
}
17 changes: 17 additions & 0 deletions sql/parse/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,8 @@ func exprToExpression(e sqlparser.Expr) (sql.Expression, error) {

case *sqlparser.BinaryExpr:
return binaryExprToExpression(v)
case *sqlparser.UnaryExpr:
return unaryExprToExpression(v)
}
}

Expand Down Expand Up @@ -893,6 +895,21 @@ func selectExprToExpression(se sqlparser.SelectExpr) (sql.Expression, error) {
}
}

func unaryExprToExpression(e *sqlparser.UnaryExpr) (sql.Expression, error) {
switch e.Operator {
case sqlparser.MinusStr:
expr, err := exprToExpression(e.Expr)
if err != nil {
return nil, err
}

return expression.NewUnaryMinus(expr), nil

default:
return nil, ErrUnsupportedFeature.New("unary operator: " + e.Operator)
}
}

func binaryExprToExpression(be *sqlparser.BinaryExpr) (sql.Expression, error) {
switch be.Operator {
case
Expand Down
8 changes: 8 additions & 0 deletions sql/parse/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,14 @@ var fixtures = map[string]sql.Node{
`SHOW CREATE SCHEMA foo`: plan.NewShowCreateDatabase(sql.UnresolvedDatabase("foo"), false),
`SHOW CREATE DATABASE IF NOT EXISTS foo`: plan.NewShowCreateDatabase(sql.UnresolvedDatabase("foo"), true),
`SHOW CREATE SCHEMA IF NOT EXISTS foo`: plan.NewShowCreateDatabase(sql.UnresolvedDatabase("foo"), true),
`SELECT -i FROM mytable`: plan.NewProject(
[]sql.Expression{
expression.NewUnaryMinus(
expression.NewUnresolvedColumn("i"),
),
},
plan.NewUnresolvedTable("mytable", ""),
),
}

func TestParse(t *testing.T) {
Expand Down

0 comments on commit 4095f3d

Please sign in to comment.