diff --git a/sql/expression/function/coalesce.go b/sql/expression/function/coalesce.go new file mode 100644 index 000000000..1f0d75a68 --- /dev/null +++ b/sql/expression/function/coalesce.go @@ -0,0 +1,126 @@ +package function + +import ( + "fmt" + "strings" + + "gopkg.in/src-d/go-mysql-server.v0/sql" +) + +// Coalesce returns the first non-NULL value in the list, or NULL if there are no non-NULL values. +type Coalesce struct { + args []sql.Expression +} + +// NewCoalesce creates a new Coalesce sql.Expression. +func NewCoalesce(args ...sql.Expression) (sql.Expression, error) { + if len(args) == 0 { + return nil, sql.ErrInvalidArgumentNumber.New("1 or more", 0) + } + + return &Coalesce{args}, nil +} + +// Type implements the sql.Expression interface. +// The return type of Type() is the aggregated type of the argument types. +func (c *Coalesce) Type() sql.Type { + for _, arg := range c.args { + if arg == nil { + continue + } + t := arg.Type() + if t == nil { + continue + } + return t + } + + return nil +} + +// IsNullable implements the sql.Expression interface. +// Returns true if all arguments are nil +// or of the first non-nil argument is nullable, otherwise false. +func (c *Coalesce) IsNullable() bool { + for _, arg := range c.args { + if arg == nil { + continue + } + return arg.IsNullable() + } + return true +} + +func (c *Coalesce) String() string { + var args = make([]string, len(c.args)) + for i, arg := range c.args { + args[i] = arg.String() + } + return fmt.Sprintf("coalesce(%s)", strings.Join(args, ", ")) +} + +// TransformUp implements the sql.Expression interface. +func (c *Coalesce) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { + var ( + args = make([]sql.Expression, len(c.args)) + err error + ) + + for i, arg := range c.args { + if arg != nil { + arg, err = arg.TransformUp(fn) + if err != nil { + return nil, err + } + } + args[i] = arg + } + + expr, err := NewCoalesce(args...) + if err != nil { + return nil, err + } + + return fn(expr) +} + +// Resolved implements the sql.Expression interface. +// The function checks if first non-nil argument is resolved. +func (c *Coalesce) Resolved() bool { + for _, arg := range c.args { + if arg == nil { + continue + } + if !arg.Resolved() { + return false + } + } + return true +} + +// Children implements the sql.Expression interface. +func (c *Coalesce) Children() []sql.Expression { return c.args } + +// Eval implements the sql.Expression interface. +// The function evaluates the first non-nil argument. If the value is nil, +// then we keep going, otherwise we return the first non-nil value. +func (c *Coalesce) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + for _, arg := range c.args { + if arg == nil { + continue + } + + val, err := arg.Eval(ctx, row) + if err != nil { + return nil, err + } + + if val == nil { + continue + } + + return val, nil + } + + return nil, nil +} diff --git a/sql/expression/function/coalesce_test.go b/sql/expression/function/coalesce_test.go new file mode 100644 index 000000000..b2c81233f --- /dev/null +++ b/sql/expression/function/coalesce_test.go @@ -0,0 +1,64 @@ +package function + +import ( + "testing" + + "github.com/stretchr/testify/require" + "gopkg.in/src-d/go-mysql-server.v0/sql" + "gopkg.in/src-d/go-mysql-server.v0/sql/expression" +) + +func TestEmptyCoalesce(t *testing.T) { + _, err := NewCoalesce() + require.True(t, sql.ErrInvalidArgumentNumber.Is(err)) +} + +func TestCoalesce(t *testing.T) { + testCases := []struct { + name string + input []sql.Expression + expected interface{} + typ sql.Type + nullable bool + }{ + {"coalesce(1, 2, 3)", []sql.Expression{expression.NewLiteral(1, sql.Int32), expression.NewLiteral(2, sql.Int32), expression.NewLiteral(3, sql.Int32)}, 1, sql.Int32, false}, + {"coalesce(NULL, NULL, 3)", []sql.Expression{nil, nil, expression.NewLiteral(3, sql.Int32)}, 3, sql.Int32, false}, + {"coalesce(NULL, NULL, '3')", []sql.Expression{nil, nil, expression.NewLiteral("3", sql.Text)}, "3", sql.Text, false}, + {"coalesce(NULL, '2', 3)", []sql.Expression{nil, expression.NewLiteral("2", sql.Text), expression.NewLiteral(3, sql.Int32)}, "2", sql.Text, false}, + {"coalesce(NULL, NULL, NULL)", []sql.Expression{nil, nil, nil}, nil, nil, true}, + } + + for _, tt := range testCases { + c, err := NewCoalesce(tt.input...) + require.NoError(t, err) + + require.Equal(t, tt.typ, c.Type()) + require.Equal(t, tt.nullable, c.IsNullable()) + v, err := c.Eval(sql.NewEmptyContext(), nil) + require.NoError(t, err) + require.Equal(t, tt.expected, v) + } +} + +func TestComposeCoalasce(t *testing.T) { + c1, err := NewCoalesce(nil) + require.NoError(t, err) + require.Equal(t, nil, c1.Type()) + v, err := c1.Eval(sql.NewEmptyContext(), nil) + require.NoError(t, err) + require.Equal(t, nil, v) + + c2, err := NewCoalesce(nil, expression.NewLiteral(1, sql.Int32)) + require.NoError(t, err) + require.Equal(t, sql.Int32, c2.Type()) + v, err = c2.Eval(sql.NewEmptyContext(), nil) + require.NoError(t, err) + require.Equal(t, 1, v) + + c, err := NewCoalesce(nil, c1, c2) + require.NoError(t, err) + require.Equal(t, sql.Int32, c.Type()) + v, err = c.Eval(sql.NewEmptyContext(), nil) + require.NoError(t, err) + require.Equal(t, 1, v) +}