From bd236154ab76454a23da48ac3e1d3fcb262a84c0 Mon Sep 17 00:00:00 2001 From: ALMAS Date: Sat, 28 Dec 2024 17:25:37 +0800 Subject: [PATCH] feat: Add artisan db:show command (#787) * feat: Add artisan db:show command * chore: Improve unit test and review issue --- database/console/show_command.go | 194 ++++++++++++++++++++++++++ database/console/show_command_test.go | 175 +++++++++++++++++++++++ database/service_provider.go | 1 + 3 files changed, 370 insertions(+) create mode 100644 database/console/show_command.go create mode 100644 database/console/show_command_test.go diff --git a/database/console/show_command.go b/database/console/show_command.go new file mode 100644 index 000000000..164d43dfa --- /dev/null +++ b/database/console/show_command.go @@ -0,0 +1,194 @@ +package console + +import ( + "fmt" + "strings" + + "github.com/goravel/framework/contracts/config" + "github.com/goravel/framework/contracts/console" + "github.com/goravel/framework/contracts/console/command" + "github.com/goravel/framework/contracts/database" + "github.com/goravel/framework/contracts/database/schema" + "github.com/goravel/framework/support/str" +) + +type ShowCommand struct { + config config.Config + schema schema.Schema +} + +type databaseInfo struct { + Name string + Version string + Database string + Host string + Port string + Username string + OpenConnections string + Tables []schema.Table `gorm:"-"` + Views []schema.View `gorm:"-"` +} + +type queryResult struct{ Value string } + +func NewShowCommand(config config.Config, schema schema.Schema) *ShowCommand { + return &ShowCommand{ + config: config, + schema: schema, + } +} + +// Signature The name and signature of the console command. +func (r *ShowCommand) Signature() string { + return "db:show" +} + +// Description The console command description. +func (r *ShowCommand) Description() string { + return "Display information about the given database" +} + +// Extend The console command extend. +func (r *ShowCommand) Extend() command.Extend { + return command.Extend{ + Category: "db", + Flags: []command.Flag{ + &command.StringFlag{ + Name: "database", + Aliases: []string{"d"}, + Usage: "The database connection", + }, + &command.BoolFlag{ + Name: "views", + Aliases: []string{"v"}, + Usage: "Show the database views", + }, + }, + } +} + +// Handle Execute the console command. +func (r *ShowCommand) Handle(ctx console.Context) error { + if got := ctx.Argument(0); len(got) > 0 { + ctx.Error(fmt.Sprintf("No arguments expected for '%s' command, got '%s'.", r.Signature(), got)) + return nil + } + r.schema = r.schema.Connection(ctx.Option("database")) + connection := r.schema.GetConnection() + getConfigValue := func(k string) string { + return r.config.GetString("database.connections." + connection + "." + k) + } + info := databaseInfo{ + Database: getConfigValue("database"), + Host: getConfigValue("host"), + Port: getConfigValue("port"), + Username: getConfigValue("username"), + } + var err error + info.Name, info.Version, info.OpenConnections, err = r.getDataBaseInfo() + if err != nil { + ctx.Error(err.Error()) + return nil + } + if info.Tables, err = r.schema.GetTables(); err != nil { + ctx.Error(err.Error()) + return nil + } + if ctx.OptionBool("views") { + if info.Views, err = r.schema.GetViews(); err != nil { + ctx.Error(err.Error()) + return nil + } + } + r.display(ctx, info) + return nil +} + +func (r *ShowCommand) getDataBaseInfo() (name, version, openConnections string, err error) { + var ( + drivers = map[database.Driver]struct { + name string + versionQuery string + openConnectionsQuery string + }{ + database.DriverSqlite: { + name: "SQLite", + versionQuery: "SELECT sqlite_version() AS value;", + }, + database.DriverMysql: { + name: "MySQL", + versionQuery: "SELECT VERSION() AS value;", + openConnectionsQuery: "SHOW status WHERE variable_name = 'threads_connected';", + }, + database.DriverPostgres: { + name: "PostgresSQL", + versionQuery: "SELECT current_setting('server_version') AS value;", + openConnectionsQuery: "SELECT COUNT(*) AS value FROM pg_stat_activity;", + }, + database.DriverSqlserver: { + name: "SQL Server", + versionQuery: "SELECT SERVERPROPERTY('productversion') AS value;", + openConnectionsQuery: "SELECT COUNT(*) Value FROM sys.dm_exec_sessions WHERE status = 'running';", + }, + } + ) + name = string(r.schema.Orm().Query().Driver()) + if driver, ok := drivers[r.schema.Orm().Query().Driver()]; ok { + name = driver.name + var versionResult queryResult + if err = r.schema.Orm().Query().Raw(driver.versionQuery).Scan(&versionResult); err == nil { + version = versionResult.Value + if strings.Contains(version, "MariaDB") { + name = "MariaDB" + } + if len(driver.openConnectionsQuery) > 0 { + var openConnectionsResult queryResult + if err = r.schema.Orm().Query().Raw(driver.openConnectionsQuery).Scan(&openConnectionsResult); err == nil { + openConnections = openConnectionsResult.Value + } + } + } + } + return +} + +func (r *ShowCommand) display(ctx console.Context, info databaseInfo) { + ctx.NewLine() + ctx.TwoColumnDetail(fmt.Sprintf("%s", info.Name), info.Version) + ctx.TwoColumnDetail("Database", info.Database) + ctx.TwoColumnDetail("Host", info.Host) + ctx.TwoColumnDetail("Port", info.Port) + ctx.TwoColumnDetail("Username", info.Username) + ctx.TwoColumnDetail("Open Connections", info.OpenConnections) + ctx.TwoColumnDetail("Tables", fmt.Sprintf("%d", len(info.Tables))) + if size := func() (size int) { + for i := range info.Tables { + size += info.Tables[i].Size + } + return + }(); size > 0 { + ctx.TwoColumnDetail("Total Size", fmt.Sprintf("%.3fMiB", float64(size)/1024/1024)) + } + ctx.NewLine() + if len(info.Tables) > 0 { + ctx.TwoColumnDetail("Tables", "Size (MiB)") + for i := range info.Tables { + ctx.TwoColumnDetail(info.Tables[i].Name, fmt.Sprintf("%.3f", float64(info.Tables[i].Size)/1024/1024)) + } + ctx.NewLine() + } + if len(info.Views) > 0 { + ctx.TwoColumnDetail("Views", "Rows") + for i := range info.Views { + if !str.Of(info.Views[i].Name).StartsWith("pg_catalog", "information_schema", "spt_") { + var rows int64 + if err := r.schema.Orm().Query().Table(info.Views[i].Name).Count(&rows); err != nil { + ctx.Error(err.Error()) + return + } + ctx.TwoColumnDetail(info.Views[i].Name, fmt.Sprintf("%d", rows)) + } + } + ctx.NewLine() + } +} diff --git a/database/console/show_command_test.go b/database/console/show_command_test.go new file mode 100644 index 000000000..d05ec011c --- /dev/null +++ b/database/console/show_command_test.go @@ -0,0 +1,175 @@ +package console + +import ( + "io" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/goravel/framework/contracts/database" + "github.com/goravel/framework/contracts/database/schema" + mocksconfig "github.com/goravel/framework/mocks/config" + mocksconsole "github.com/goravel/framework/mocks/console" + mocksorm "github.com/goravel/framework/mocks/database/orm" + mocksschema "github.com/goravel/framework/mocks/database/schema" + "github.com/goravel/framework/support/color" +) + +func TestShowCommand(t *testing.T) { + var ( + mockContext *mocksconsole.Context + mockConfig *mocksconfig.Config + mockSchema *mocksschema.Schema + mockOrm *mocksorm.Orm + mockQuery *mocksorm.Query + ) + + beforeEach := func() { + mockContext = mocksconsole.NewContext(t) + mockConfig = mocksconfig.NewConfig(t) + mockSchema = mocksschema.NewSchema(t) + mockOrm = mocksorm.NewOrm(t) + mockQuery = mocksorm.NewQuery(t) + } + successCaseExpected := [][2]string{ + {"MariaDB", "MariaDB"}, + {"Database", "db"}, + {"Host", "host"}, + {"Port", "port"}, + {"Username", "username"}, + {"Open Connections", "2"}, + {"Tables", "1"}, + {"Total Size", "0.000MiB"}, + {"Tables", "Size (MiB)"}, + {"test", "0.000"}, + {"Views", "Rows"}, + {"test", "0"}, + } + tests := []struct { + name string + setup func() + expected string + }{ + { + name: "invalid argument", + setup: func() { + mockContext.EXPECT().Argument(0).Return("test").Once() + mockContext.EXPECT().Error("No arguments expected for 'db:show' command, got 'test'.").Run(func(message string) { + color.Errorln(message) + }).Once() + }, + expected: "No arguments expected for 'db:show' command, got 'test'.", + }, + { + name: "get tables failed", + setup: func() { + mockContext.EXPECT().Argument(0).Return("").Once() + mockContext.EXPECT().Option("database").Return("test").Once() + mockSchema.EXPECT().Connection("test").Return(mockSchema).Once() + mockSchema.EXPECT().GetConnection().Return("test").Once() + mockConfig.EXPECT().GetString("database.connections.test.database").Return("db").Once() + mockConfig.EXPECT().GetString("database.connections.test.host").Return("host").Once() + mockConfig.EXPECT().GetString("database.connections.test.port").Return("port").Once() + mockConfig.EXPECT().GetString("database.connections.test.username").Return("username").Once() + mockQuery.EXPECT().Driver().Return(database.DriverMysql).Twice() + mockOrm.EXPECT().Query().Return(mockQuery).Times(4) + mockSchema.EXPECT().Orm().Return(mockOrm).Times(4) + mockQuery.EXPECT().Raw("SELECT VERSION() AS value;").Return(mockQuery).Once() + mockQuery.EXPECT().Raw("SHOW status WHERE variable_name = 'threads_connected';").Return(mockQuery).Once() + mockQuery.EXPECT().Scan(&queryResult{}).Return(nil).Twice() + mockSchema.EXPECT().GetTables().Return(nil, assert.AnError).Once() + mockContext.EXPECT().Error(assert.AnError.Error()).Run(func(message string) { + color.Errorln(message) + }).Once() + }, + expected: assert.AnError.Error(), + }, { + name: "get views failed", + setup: func() { + mockContext.EXPECT().Argument(0).Return("").Once() + mockContext.EXPECT().Option("database").Return("test").Once() + mockSchema.EXPECT().Connection("test").Return(mockSchema).Once() + mockSchema.EXPECT().GetConnection().Return("test").Once() + mockConfig.EXPECT().GetString("database.connections.test.database").Return("db").Once() + mockConfig.EXPECT().GetString("database.connections.test.host").Return("host").Once() + mockConfig.EXPECT().GetString("database.connections.test.port").Return("port").Once() + mockConfig.EXPECT().GetString("database.connections.test.username").Return("username").Once() + mockQuery.EXPECT().Driver().Return(database.DriverMysql).Twice() + mockOrm.EXPECT().Query().Return(mockQuery).Times(4) + mockSchema.EXPECT().Orm().Return(mockOrm).Times(4) + mockQuery.EXPECT().Raw("SELECT VERSION() AS value;").Return(mockQuery).Once() + mockQuery.EXPECT().Raw("SHOW status WHERE variable_name = 'threads_connected';").Return(mockQuery).Once() + mockQuery.EXPECT().Scan(&queryResult{}).Return(nil).Twice() + mockSchema.EXPECT().GetTables().Return(nil, nil).Once() + mockContext.EXPECT().OptionBool("views").Return(true).Once() + mockSchema.EXPECT().GetViews().Return(nil, assert.AnError) + mockContext.EXPECT().Error(assert.AnError.Error()).Run(func(message string) { + color.Errorln(message) + }) + }, + expected: assert.AnError.Error(), + }, { + name: "success", + setup: func() { + mockContext.EXPECT().Argument(0).Return("").Once() + mockContext.EXPECT().Option("database").Return("test").Once() + mockSchema.EXPECT().Connection("test").Return(mockSchema).Once() + mockSchema.EXPECT().GetConnection().Return("test").Once() + mockConfig.EXPECT().GetString("database.connections.test.database").Return("db").Once() + mockConfig.EXPECT().GetString("database.connections.test.host").Return("host").Once() + mockConfig.EXPECT().GetString("database.connections.test.port").Return("port").Once() + mockConfig.EXPECT().GetString("database.connections.test.username").Return("username").Once() + mockQuery.EXPECT().Driver().Return(database.DriverMysql).Twice() + mockOrm.EXPECT().Query().Return(mockQuery).Times(5) + mockSchema.EXPECT().Orm().Return(mockOrm).Times(5) + mockQuery.EXPECT().Raw("SELECT VERSION() AS value;").Return(mockQuery).Once() + mockQuery.EXPECT().Raw("SHOW status WHERE variable_name = 'threads_connected';").Return(mockQuery).Once() + mockQuery.EXPECT().Scan(&queryResult{}).Run(func(dest interface{}) { + if d, ok := dest.(*queryResult); ok { + d.Value = "MariaDB" + } + }).Return(nil).Once() + mockQuery.EXPECT().Scan(&queryResult{}).Run(func(dest interface{}) { + if d, ok := dest.(*queryResult); ok { + d.Value = "2" + } + }).Return(nil).Once() + mockSchema.EXPECT().GetTables().Return([]schema.Table{ + {Name: "test", Size: 100}, + }, nil).Once() + mockContext.EXPECT().OptionBool("views").Return(true).Once() + mockSchema.EXPECT().GetViews().Return([]schema.View{ + {Name: "test"}, + }, nil).Once() + mockQuery.EXPECT().Table("test").Return(mockQuery).Once() + var rows int64 + mockQuery.EXPECT().Count(&rows).Return(nil).Once() + mockContext.EXPECT().NewLine().Times(4) + for i := range successCaseExpected { + mockContext.EXPECT().TwoColumnDetail(successCaseExpected[i][0], successCaseExpected[i][1]).Run(func(first string, second string, filler ...rune) { + color.Default().Printf("%s %s\n", first, second) + }).Once() + } + }, + expected: func() string { + var result string + for i := range successCaseExpected { + result += color.Default().Sprintf("%s %s\n", successCaseExpected[i][0], successCaseExpected[i][1]) + } + return result + }(), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + beforeEach() + test.setup() + command := NewShowCommand(mockConfig, mockSchema) + assert.Contains(t, color.CaptureOutput(func(_ io.Writer) { + assert.NoError(t, command.Handle(mockContext)) + }), test.expected) + }) + } + +} diff --git a/database/service_provider.go b/database/service_provider.go index 1103e6d1e..111943043 100644 --- a/database/service_provider.go +++ b/database/service_provider.go @@ -112,6 +112,7 @@ func (r *ServiceProvider) registerCommands(app foundation.Application) { console.NewSeedCommand(config, seeder), console.NewSeederMakeCommand(), console.NewFactoryMakeCommand(), + console.NewShowCommand(config, schema), console.NewWipeCommand(config, schema), }) }