From 9ead91ad5baa0daf5527bd9a13bac2447ec7f872 Mon Sep 17 00:00:00 2001 From: almas1992 Date: Tue, 24 Dec 2024 09:14:38 +0800 Subject: [PATCH] feat: Add artisan db:show command --- database/console/show_command.go | 181 ++++++++++++++++++++++++++ database/console/show_command_test.go | 135 +++++++++++++++++++ database/service_provider.go | 1 + 3 files changed, 317 insertions(+) create mode 100644 database/console/show_command.go create mode 100644 database/console/show_command_test.go diff --git a/database/console/show_command.go b/database/console/show_command.go new file mode 100644 index 000000000..81b0276d6 --- /dev/null +++ b/database/console/show_command.go @@ -0,0 +1,181 @@ +package console + +import ( + "fmt" + "strings" + + "github.com/goravel/framework/contracts/config" + "github.com/goravel/framework/contracts/console" + "github.com/goravel/framework/contracts/console/command" + "github.com/goravel/framework/contracts/database" + "github.com/goravel/framework/contracts/database/schema" + "github.com/goravel/framework/support/str" +) + +type ShowCommand struct { + config config.Config + schema schema.Schema +} + +type databaseInfo struct { + Name string + Version string + Database string + Host string + Port string + Username string + OpenConnections string + Tables []schema.Table `gorm:"-"` + Views []schema.View `gorm:"-"` +} + +type queryResult struct{ Value string } + +func NewShowCommand(config config.Config, schema schema.Schema) *ShowCommand { + return &ShowCommand{ + config: config, + schema: schema, + } +} + +// Signature The name and signature of the console command. +func (r *ShowCommand) Signature() string { + return "db:show" +} + +// Description The console command description. +func (r *ShowCommand) Description() string { + return "Display information about the given database" +} + +// Extend The console command extend. +func (r *ShowCommand) Extend() command.Extend { + return command.Extend{ + Category: "db", + Flags: []command.Flag{ + &command.StringFlag{ + Name: "database", + Usage: "The database connection", + }, + &command.BoolFlag{ + Name: "views", + Usage: "Show the database views", + }, + }, + } +} + +// Handle Execute the console command. +func (r *ShowCommand) Handle(ctx console.Context) error { + if got := ctx.Argument(0); len(got) > 0 { + ctx.Error(fmt.Sprintf("No arguments expected for '%s' command, got '%s'.", r.Signature(), got)) + return nil + } + r.schema = r.schema.Connection(ctx.Option("database")) + getConfigValue := func(k string) string { + return r.config.GetString("database.connections." + r.schema.GetConnection() + "." + k) + } + info := databaseInfo{ + Database: getConfigValue("database"), + Host: getConfigValue("host"), + Port: getConfigValue("port"), + Username: getConfigValue("username"), + } + info.Name, info.Version, info.OpenConnections = r.getDataBaseInfo() + var err error + if info.Tables, err = r.schema.GetTables(); err != nil { + ctx.Error(err.Error()) + return nil + } + if ctx.OptionBool("views") { + if info.Views, err = r.schema.GetViews(); err != nil { + ctx.Error(err.Error()) + return nil + } + } + r.display(ctx, info) + return nil +} + +func (r *ShowCommand) getDataBaseInfo() (name, version, openConnections string) { + var ( + result queryResult + drivers = map[database.Driver]struct { + name string + versionQuery string + openConnectionsQuery string + }{ + database.DriverSqlite: { + name: "SQLite", + versionQuery: "SELECT sqlite_version() AS value;", + }, + database.DriverMysql: { + name: "MySQL", + versionQuery: "SELECT VERSION() AS value;", + openConnectionsQuery: "SHOW status WHERE variable_name = 'threads_connected';", + }, + database.DriverPostgres: { + name: "PostgresSQL", + versionQuery: "SELECT current_setting('server_version') AS value;", + openConnectionsQuery: "SELECT COUNT(*) AS value FROM pg_stat_activity;", + }, + database.DriverSqlserver: { + name: "SQL Server", + versionQuery: "SELECT SERVERPROPERTY('productversion') AS value;", + openConnectionsQuery: "SELECT COUNT(*) Value FROM sys.dm_exec_sessions WHERE status = 'running';", + }, + } + ) + name = string(r.schema.Orm().Query().Driver()) + if driver, ok := drivers[r.schema.Orm().Query().Driver()]; ok { + name = driver.name + _ = r.schema.Orm().Query().Raw(driver.versionQuery).Scan(&result) + version = result.Value + if strings.Contains(version, "MariaDB") { + name = "MariaDB" + } + if len(driver.openConnectionsQuery) > 0 { + _ = r.schema.Orm().Query().Raw(driver.openConnectionsQuery).Scan(&result) + openConnections = result.Value + } + } + return +} + +func (r *ShowCommand) display(ctx console.Context, info databaseInfo) { + ctx.NewLine() + ctx.TwoColumnDetail(fmt.Sprintf("%s", info.Name), info.Version) + ctx.TwoColumnDetail("Database", info.Database) + ctx.TwoColumnDetail("Host", info.Host) + ctx.TwoColumnDetail("Port", info.Port) + ctx.TwoColumnDetail("Username", info.Username) + ctx.TwoColumnDetail("Open Connections", info.OpenConnections) + ctx.TwoColumnDetail("Tables", fmt.Sprintf("%d", len(info.Tables))) + if size := func() (size int) { + for i := range info.Tables { + size += info.Tables[i].Size + } + return + }(); size > 0 { + ctx.TwoColumnDetail("Total Size", fmt.Sprintf("%.3fMiB", float64(size)/1024/1024)) + } + ctx.NewLine() + if len(info.Tables) > 0 { + ctx.TwoColumnDetail("Tables", "Size (MiB)") + for i := range info.Tables { + ctx.TwoColumnDetail(info.Tables[i].Name, fmt.Sprintf("%.3f", float64(info.Tables[i].Size)/1024/1024)) + } + ctx.NewLine() + } + if len(info.Views) > 0 { + ctx.TwoColumnDetail("Views", "Rows") + for i := range info.Views { + if !str.Of(info.Views[i].Name).StartsWith("pg_catalog", "information_schema", "spt_") { + var rows int64 + _ = r.schema.Orm().Query().Table(info.Views[i].Name).Count(&rows) + ctx.TwoColumnDetail(info.Views[i].Name, fmt.Sprintf("%d", rows)) + } + } + ctx.NewLine() + } +} diff --git a/database/console/show_command_test.go b/database/console/show_command_test.go new file mode 100644 index 000000000..8c5cc15c9 --- /dev/null +++ b/database/console/show_command_test.go @@ -0,0 +1,135 @@ +package console + +import ( + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/goravel/framework/contracts/database" + "github.com/goravel/framework/contracts/database/schema" + mocksconfig "github.com/goravel/framework/mocks/config" + mocksconsole "github.com/goravel/framework/mocks/console" + mocksorm "github.com/goravel/framework/mocks/database/orm" + mocksschema "github.com/goravel/framework/mocks/database/schema" + "github.com/goravel/framework/support/color" +) + +func TestShowCommand(t *testing.T) { + var ( + mockContext *mocksconsole.Context + mockConfig *mocksconfig.Config + mockSchema *mocksschema.Schema + mockOrm *mocksorm.Orm + mockQuery *mocksorm.Query + ) + + beforeEach := func() { + mockContext = mocksconsole.NewContext(t) + mockConfig = mocksconfig.NewConfig(t) + mockSchema = mocksschema.NewSchema(t) + mockOrm = mocksorm.NewOrm(t) + mockQuery = mocksorm.NewQuery(t) + } + + tests := []struct { + name string + setup func() + expected string + }{ + { + name: "invalid argument", + setup: func() { + mockContext.EXPECT().Argument(0).Return("test") + mockContext.EXPECT().Error(mock.Anything).Run(func(message string) { + color.Errorln(message) + }) + }, + expected: "No arguments expected for 'db:show' command, got 'test'.", + }, + { + name: "get tables failed", + setup: func() { + mockContext.EXPECT().Argument(0).Return("") + mockContext.EXPECT().Option("database").Return("") + mockSchema.EXPECT().Connection(mock.Anything).Return(mockSchema) + mockSchema.EXPECT().GetConnection().Return("test") + mockConfig.EXPECT().GetString(mock.Anything).Return("test") + mockQuery.EXPECT().Driver().Return(database.DriverMysql) + mockOrm.EXPECT().Query().Return(mockQuery) + mockSchema.EXPECT().Orm().Return(mockOrm) + mockQuery.EXPECT().Raw(mock.Anything).Return(mockQuery) + mockQuery.EXPECT().Scan(mock.Anything).Return(nil) + mockSchema.EXPECT().GetTables().Return(nil, assert.AnError) + mockContext.EXPECT().Error(mock.Anything).Run(func(message string) { + color.Errorln(message) + }) + }, + expected: assert.AnError.Error(), + }, { + name: "get views failed", + setup: func() { + mockContext.EXPECT().Argument(0).Return("") + mockContext.EXPECT().Option("database").Return("") + mockSchema.EXPECT().Connection(mock.Anything).Return(mockSchema) + mockSchema.EXPECT().GetConnection().Return("test") + mockConfig.EXPECT().GetString(mock.Anything).Return("test") + mockQuery.EXPECT().Driver().Return(database.DriverMysql) + mockOrm.EXPECT().Query().Return(mockQuery) + mockSchema.EXPECT().Orm().Return(mockOrm) + mockQuery.EXPECT().Raw(mock.Anything).Return(mockQuery) + mockQuery.EXPECT().Scan(mock.Anything).Return(nil) + mockSchema.EXPECT().GetTables().Return(nil, nil) + mockContext.EXPECT().OptionBool("views").Return(true) + mockSchema.EXPECT().GetViews().Return(nil, assert.AnError) + mockContext.EXPECT().Error(mock.Anything).Run(func(message string) { + color.Errorln(message) + }) + }, + expected: assert.AnError.Error(), + }, { + name: "success", + setup: func() { + mockContext.EXPECT().Argument(0).Return("") + mockContext.EXPECT().Option("database").Return("") + mockSchema.EXPECT().Connection(mock.Anything).Return(mockSchema) + mockSchema.EXPECT().GetConnection().Return("test") + mockConfig.EXPECT().GetString(mock.Anything).Return("test") + mockQuery.EXPECT().Driver().Return(database.DriverMysql) + mockOrm.EXPECT().Query().Return(mockQuery) + mockSchema.EXPECT().Orm().Return(mockOrm) + mockQuery.EXPECT().Raw(mock.Anything).Return(mockQuery) + mockQuery.EXPECT().Scan(mock.Anything).RunAndReturn(func(dest interface{}) error { + if d, ok := dest.(*queryResult); ok { + d.Value = "MariaDB" + } + return nil + }) + mockSchema.EXPECT().GetTables().Return([]schema.Table{ + {Name: "test", Size: 100}, + }, nil) + mockContext.EXPECT().OptionBool("views").Return(true) + mockSchema.EXPECT().GetViews().Return([]schema.View{ + {Name: "test"}, + }, nil) + mockQuery.EXPECT().Table(mock.Anything).Return(mockQuery) + mockQuery.EXPECT().Count(mock.Anything).Return(nil) + mockContext.EXPECT().NewLine() + mockContext.EXPECT().TwoColumnDetail(mock.Anything, mock.Anything) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + beforeEach() + test.setup() + command := NewShowCommand(mockConfig, mockSchema) + assert.Contains(t, color.CaptureOutput(func(_ io.Writer) { + assert.NoError(t, command.Handle(mockContext)) + }), test.expected) + }) + } + +} diff --git a/database/service_provider.go b/database/service_provider.go index 1103e6d1e..111943043 100644 --- a/database/service_provider.go +++ b/database/service_provider.go @@ -112,6 +112,7 @@ func (r *ServiceProvider) registerCommands(app foundation.Application) { console.NewSeedCommand(config, seeder), console.NewSeederMakeCommand(), console.NewFactoryMakeCommand(), + console.NewShowCommand(config, schema), console.NewWipeCommand(config, schema), }) }