diff --git a/engine_test.go b/engine_test.go index bff0954d4..b44c8806b 100644 --- a/engine_test.go +++ b/engine_test.go @@ -506,6 +506,13 @@ var queries = []struct { `, []sql.Row{}, }, + { + `SHOW CREATE DATABASE mydb`, + []sql.Row{{ + "mydb", + "CREATE DATABASE `mydb` /*!40100 DEFAULT CHARACTER SET utf8mb4 COLLATE utf8_bin */", + }}, + }, } func TestQueries(t *testing.T) { diff --git a/sql/analyzer/resolve_database.go b/sql/analyzer/resolve_database.go index 172208471..3ec98aadc 100644 --- a/sql/analyzer/resolve_database.go +++ b/sql/analyzer/resolve_database.go @@ -45,6 +45,15 @@ func resolveDatabase(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error return nil, err } + nc := *v + nc.Database = db + return &nc, nil + case *plan.ShowCreateDatabase: + db, err := a.Catalog.Database(v.Database.Name()) + if err != nil { + return nil, err + } + nc := *v nc.Database = db return &nc, nil diff --git a/sql/parse/parse.go b/sql/parse/parse.go index b592faebf..49e377303 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -9,7 +9,7 @@ import ( "strconv" "strings" - "github.com/opentracing/opentracing-go" + opentracing "github.com/opentracing/opentracing-go" "github.com/sirupsen/logrus" "gopkg.in/src-d/go-errors.v1" "gopkg.in/src-d/go-mysql-server.v0/sql" @@ -34,14 +34,15 @@ var ( ) var ( - describeTablesRegex = regexp.MustCompile(`^describe\s+table\s+(.*)`) - createIndexRegex = regexp.MustCompile(`^create\s+index\s+`) - dropIndexRegex = regexp.MustCompile(`^drop\s+index\s+`) - showIndexRegex = regexp.MustCompile(`^show\s+(index|indexes|keys)\s+(from|in)\s+\S+\s*`) - showCreateRegex = regexp.MustCompile(`^show create\s+\S+\s*`) - showVariablesRegex = regexp.MustCompile(`^show\s+(.*)?variables\s*`) - describeRegex = regexp.MustCompile(`^(describe|desc|explain)\s+(.*)\s+`) - fullProcessListRegex = regexp.MustCompile(`^show\s+(full\s+)?processlist$`) + describeTablesRegex = regexp.MustCompile(`^describe\s+table\s+(.*)`) + createIndexRegex = regexp.MustCompile(`^create\s+index\s+`) + dropIndexRegex = regexp.MustCompile(`^drop\s+index\s+`) + showIndexRegex = regexp.MustCompile(`^show\s+(index|indexes|keys)\s+(from|in)\s+\S+\s*`) + showVariablesRegex = regexp.MustCompile(`^show\s+(.*)?variables\s*`) + describeRegex = regexp.MustCompile(`^(describe|desc|explain)\s+(.*)\s+`) + fullProcessListRegex = regexp.MustCompile(`^show\s+(full\s+)?processlist$`) + showCreateRegex = regexp.MustCompile(`^show create\s+\S+\s*`) + showCreateDatabaseRegex = regexp.MustCompile(`^show\s+create\s+(database|schema)`) ) // Parse parses the given SQL sentence and returns the corresponding node. @@ -78,6 +79,8 @@ func Parse(ctx *sql.Context, query string) (sql.Node, error) { return parseDescribeQuery(ctx, s) case fullProcessListRegex.MatchString(lowerQuery): return plan.NewShowProcessList(), nil + case showCreateDatabaseRegex.MatchString(lowerQuery): + return parseShowCreateDatabase(s) } stmt, err := sqlparser.Parse(s) @@ -1081,3 +1084,61 @@ func parseShowTableStatus(query string) (sql.Node, error) { return nil, errUnexpectedSyntax.New("one of: FROM, IN, LIKE or WHERE", clause) } } + +func parseShowCreateDatabase(query string) (sql.Node, error) { + buf := bufio.NewReader(strings.NewReader(query)) + steps := []parseFunc{ + expect("show"), + skipSpaces, + expect("create"), + skipSpaces, + oneOf("database", "schema"), + skipSpaces, + } + + for _, s := range steps { + if err := s(buf); err != nil { + return nil, err + } + } + + var next string + if err := readIdent(&next)(buf); err != nil { + return nil, err + } + + var ifNotExists bool + if next == "if" { + ifNotExists = true + steps := []parseFunc{ + skipSpaces, + expect("not"), + skipSpaces, + expect("exists"), + skipSpaces, + readIdent(&next), + } + + for _, s := range steps { + if err := s(buf); err != nil { + return nil, err + } + } + } + + steps = []parseFunc{ + skipSpaces, + checkEOF, + } + + for _, s := range steps { + if err := s(buf); err != nil { + return nil, err + } + } + + return plan.NewShowCreateDatabase( + sql.UnresolvedDatabase(next), + ifNotExists, + ), nil +} diff --git a/sql/parse/parse_test.go b/sql/parse/parse_test.go index 74f835c59..c74379389 100644 --- a/sql/parse/parse_test.go +++ b/sql/parse/parse_test.go @@ -793,6 +793,10 @@ var fixtures = map[string]sql.Node{ `SHOW SESSION VARIABLES`: plan.NewShowVariables(sql.NewEmptyContext().GetAll(), ""), `SHOW VARIABLES LIKE 'gtid_mode'`: plan.NewShowVariables(sql.NewEmptyContext().GetAll(), "gtid_mode"), `SHOW SESSION VARIABLES LIKE 'autocommit'`: plan.NewShowVariables(sql.NewEmptyContext().GetAll(), "autocommit"), + `SHOW CREATE DATABASE foo`: plan.NewShowCreateDatabase(sql.UnresolvedDatabase("foo"), false), + `SHOW CREATE SCHEMA foo`: plan.NewShowCreateDatabase(sql.UnresolvedDatabase("foo"), false), + `SHOW CREATE DATABASE IF NOT EXISTS foo`: plan.NewShowCreateDatabase(sql.UnresolvedDatabase("foo"), true), + `SHOW CREATE SCHEMA IF NOT EXISTS foo`: plan.NewShowCreateDatabase(sql.UnresolvedDatabase("foo"), true), } func TestParse(t *testing.T) { diff --git a/sql/plan/show_create_database.go b/sql/plan/show_create_database.go new file mode 100644 index 000000000..313e0e369 --- /dev/null +++ b/sql/plan/show_create_database.go @@ -0,0 +1,79 @@ +package plan + +import ( + "bytes" + "fmt" + + "gopkg.in/src-d/go-mysql-server.v0/sql" +) + +// ShowCreateDatabase returns the SQL for creating a database. +type ShowCreateDatabase struct { + Database sql.Database + IfNotExists bool +} + +const defaultCharacterSet = "utf8mb4" + +var showCreateDatabaseSchema = sql.Schema{ + {Name: "Database", Type: sql.Text}, + {Name: "Create Database", Type: sql.Text}, +} + +// NewShowCreateDatabase creates a new ShowCreateDatabase node. +func NewShowCreateDatabase(db sql.Database, ifNotExists bool) *ShowCreateDatabase { + return &ShowCreateDatabase{db, ifNotExists} +} + +// RowIter implements the sql.Node interface. +func (s *ShowCreateDatabase) RowIter(ctx *sql.Context) (sql.RowIter, error) { + var name = s.Database.Name() + + var buf bytes.Buffer + + buf.WriteString("CREATE DATABASE ") + if s.IfNotExists { + buf.WriteString("/*!32312 IF NOT EXISTS*/ ") + } + + buf.WriteRune('`') + buf.WriteString(name) + buf.WriteRune('`') + buf.WriteString(fmt.Sprintf( + " /*!40100 DEFAULT CHARACTER SET %s COLLATE %s */", + defaultCharacterSet, + defaultCollation, + )) + + return sql.RowsToRowIter( + sql.NewRow(name, buf.String()), + ), nil +} + +// Schema implements the sql.Node interface. +func (s *ShowCreateDatabase) Schema() sql.Schema { + return showCreateDatabaseSchema +} + +func (s *ShowCreateDatabase) String() string { + return fmt.Sprintf("SHOW CREATE DATABASE %s", s.Database.Name()) +} + +// Children implements the sql.Node interface. +func (s *ShowCreateDatabase) Children() []sql.Node { return nil } + +// Resolved implements the sql.Node interface. +func (s *ShowCreateDatabase) Resolved() bool { + _, ok := s.Database.(sql.UnresolvedDatabase) + return !ok +} + +// TransformExpressionsUp implements the sql.Node interface. +func (s *ShowCreateDatabase) TransformExpressionsUp(f sql.TransformExprFunc) (sql.Node, error) { + return s, nil +} + +// TransformUp implements the sql.Node interface. +func (s *ShowCreateDatabase) TransformUp(f sql.TransformNodeFunc) (sql.Node, error) { + return f(s) +} diff --git a/sql/plan/show_create_database_test.go b/sql/plan/show_create_database_test.go new file mode 100644 index 000000000..4f0da72f6 --- /dev/null +++ b/sql/plan/show_create_database_test.go @@ -0,0 +1,34 @@ +package plan + +import ( + "testing" + + "github.com/stretchr/testify/require" + "gopkg.in/src-d/go-mysql-server.v0/sql" +) + +func TestShowCreateDatabase(t *testing.T) { + require := require.New(t) + + node := NewShowCreateDatabase(sql.UnresolvedDatabase("foo"), true) + iter, err := node.RowIter(sql.NewEmptyContext()) + require.NoError(err) + + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + + require.Equal([]sql.Row{ + {"foo", "CREATE DATABASE /*!32312 IF NOT EXISTS*/ `foo` /*!40100 DEFAULT CHARACTER SET utf8mb4 COLLATE utf8_bin */"}, + }, rows) + + node = NewShowCreateDatabase(sql.UnresolvedDatabase("foo"), false) + iter, err = node.RowIter(sql.NewEmptyContext()) + require.NoError(err) + + rows, err = sql.RowIterToRows(iter) + require.NoError(err) + + require.Equal([]sql.Row{ + {"foo", "CREATE DATABASE `foo` /*!40100 DEFAULT CHARACTER SET utf8mb4 COLLATE utf8_bin */"}, + }, rows) +}