diff --git a/engine_test.go b/engine_test.go index 92c7d2040..c8cd01ec0 100644 --- a/engine_test.go +++ b/engine_test.go @@ -565,6 +565,42 @@ func TestQueries(t *testing.T) { }) } +func TestSessionSelectLimit(t *testing.T) { + ctx := newCtx() + ctx.Session.Set("sql_select_limit", sql.Int64, int64(1)) + + q := []struct { + query string + expected []sql.Row + }{ + { + "SELECT * FROM mytable ORDER BY i", + []sql.Row{{int64(1), "first row"}}, + }, + { + "SELECT * FROM mytable ORDER BY i LIMIT 2", + []sql.Row{ + {int64(1), "first row"}, + {int64(2), "second row"}, + }, + }, + { + "SELECT i FROM (SELECT i FROM mytable LIMIT 2) t ORDER BY i", + []sql.Row{{int64(1)}}, + }, + { + "SELECT i FROM (SELECT i FROM mytable) t ORDER BY i LIMIT 2", + []sql.Row{{int64(1)}}, + }, + } + e := newEngine(t) + t.Run("sql_select_limit", func(t *testing.T) { + for _, tt := range q { + testQueryWithContext(ctx, t, e, tt.query, tt.expected) + } + }) +} + func TestSessionDefaults(t *testing.T) { ctx := newCtx() ctx.Session.Set("auto_increment_increment", sql.Int64, 0) @@ -602,6 +638,7 @@ func TestSessionDefaults(t *testing.T) { require.Equal(defaults["ndbinfo_version"].Value, val) }) } + func TestWarnings(t *testing.T) { ctx := newCtx() ctx.Session.Warn(&sql.Warning{Code: 1}) diff --git a/sql/parse/parse.go b/sql/parse/parse.go index 76474496a..fb6f876f4 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -287,6 +287,9 @@ func convertSelect(ctx *sql.Context, s *sqlparser.Select) (sql.Node, error) { if err != nil { return nil, err } + } else if ok, val := sql.HasDefaultValue(ctx.Session, "sql_select_limit"); !ok { + limit := val.(int64) + node = plan.NewLimit(int64(limit), node) } if s.Limit != nil && s.Limit.Offset != nil { diff --git a/sql/session.go b/sql/session.go index 53ca7d246..f8b5ea26a 100644 --- a/sql/session.go +++ b/sql/session.go @@ -152,6 +152,15 @@ func DefaultSessionConfig() map[string]TypedValue { } } +// HasDefaultValue checks if session variable value is the default one. +func HasDefaultValue(s Session, key string) (bool, interface{}) { + typ, val := s.Get(key) + if cfg, ok := DefaultSessionConfig()[key]; ok { + return (cfg.Typ == typ && cfg.Value == val), val + } + return false, val +} + // NewSession creates a new session with data. func NewSession(address string, user string, id uint32) Session { return &BaseSession{ diff --git a/sql/session_test.go b/sql/session_test.go index 9240944e6..fe6dcb05f 100644 --- a/sql/session_test.go +++ b/sql/session_test.go @@ -12,7 +12,6 @@ func TestSessionConfig(t *testing.T) { require := require.New(t) sess := NewSession("foo", "bar", 1) - typ, v := sess.Get("foo") require.Equal(Null, typ) require.Equal(nil, v) @@ -34,7 +33,20 @@ func TestSessionConfig(t *testing.T) { require.Equal(3, sess.Warnings()[0].Code) require.Equal(2, sess.Warnings()[1].Code) require.Equal(1, sess.Warnings()[2].Code) +} + +func TestHasDefaultValue(t *testing.T) { + require := require.New(t) + sess := NewSession("foo", "bar", 1) + + for key := range DefaultSessionConfig() { + require.True(HasDefaultValue(sess, key)) + } + + sess.Set("auto_increment_increment", Int64, 123) + require.False(HasDefaultValue(sess, "auto_increment_increment")) + require.False(HasDefaultValue(sess, "non_existing_key")) } type testNode struct{}