Skip to content

Commit

Permalink
sql/mysql: set mariadb as a drivernmae on creation
Browse files Browse the repository at this point in the history
  • Loading branch information
a8m committed Dec 18, 2024
1 parent 45b905e commit 337a219
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 28 deletions.
19 changes: 14 additions & 5 deletions cmd/atlas/internal/migrate/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"ariga.io/atlas/cmd/atlas/internal/migrate/ent"
"ariga.io/atlas/cmd/atlas/internal/migrate/ent/revision"
"ariga.io/atlas/sql/migrate"
"ariga.io/atlas/sql/mysql"
"ariga.io/atlas/sql/schema"
"ariga.io/atlas/sql/sqlclient"
"ariga.io/atlas/sql/sqltool"
Expand Down Expand Up @@ -50,6 +51,14 @@ type (
Option func(*EntRevisions) error
)

// Dialect returns the "ent dialect" of the Ent client.
func (r *EntRevisions) Dialect() string {
if r.ac.Name == mysql.DriverMaria {
return mysql.DriverName // Ent does not support "mariadb" as dialect.
}
return r.ac.Name
}

// RevisionsForClient creates a new RevisionReadWriter for the given sqlclient.Client.
func RevisionsForClient(ctx context.Context, ac *sqlclient.Client, schema string) (RevisionReadWriter, error) {
// If the driver supports the RevisionReadWriter interface, use it.
Expand Down Expand Up @@ -77,9 +86,9 @@ func NewEntRevisions(ctx context.Context, ac *sqlclient.Client, opts ...Option)
}
}
// Create the connection with the underlying migrate.Driver to have it inside a possible transaction.
entopts := []ent.Option{ent.Driver(sql.NewDriver(r.ac.Name, sql.Conn{ExecQuerier: r.ac.Driver}))}
entopts := []ent.Option{ent.Driver(sql.NewDriver(r.Dialect(), sql.Conn{ExecQuerier: r.ac.Driver}))}
// SQLite does not support multiple schema, therefore schema-config is only needed for other dialects.
if r.ac.Name != dialect.SQLite {
if r.Dialect() != dialect.SQLite {
// Make sure the schema to store the revisions table in does exist.
_, err := r.ac.InspectSchema(ctx, r.schema, &schema.InspectOptions{Mode: schema.InspectSchemas})
if err != nil && !schema.IsNotExistError(err) {
Expand Down Expand Up @@ -189,17 +198,17 @@ func (r *EntRevisions) DeleteRevision(ctx context.Context, v string) error {
// execution in a transaction and assumes the underlying connection is of type *sql.DB, which is not true for actually
// reading and writing revisions.
func (r *EntRevisions) Migrate(ctx context.Context) (err error) {
c := ent.NewClient(ent.Driver(sql.OpenDB(r.ac.Name, r.ac.DB)))
c := ent.NewClient(ent.Driver(sql.OpenDB(r.Dialect(), r.ac.DB)))
// Ensure the ent client is bound to the requested revision schema. Open a new connection, if not.
if r.ac.Name != dialect.SQLite && r.ac.URL.Schema != r.schema {
if r.Dialect() != dialect.SQLite && r.ac.URL.Schema != r.schema {
sc, err := sqlclient.OpenURL(ctx, r.ac.URL.URL, sqlclient.OpenSchema(r.schema))
if err != nil {
return err
}
defer sc.Close()
c = ent.NewClient(ent.Driver(sql.OpenDB(sc.Name, sc.DB)))
}
if r.ac.Name == dialect.SQLite {
if r.Dialect() == dialect.SQLite {
var on sql.NullBool
if err := r.ac.DB.QueryRowContext(ctx, "PRAGMA foreign_keys").Scan(&on); err != nil {
return err
Expand Down
52 changes: 29 additions & 23 deletions sql/mysql/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,21 +49,24 @@ var _ interface {
schema.TypeParseFormatter
} = (*Driver)(nil)

// DriverName holds the name used for registration.
const DriverName = "mysql"
// DriverName and DriverMaria holds the names used for registration.
const (
DriverName = "mysql"
DriverMaria = "mariadb"
)

func init() {
sqlclient.Register(
DriverName,
sqlclient.OpenerFunc(opener),
opener(DriverName),
sqlclient.RegisterDriverOpener(Open),
sqlclient.RegisterCodec(codec, codec),
sqlclient.RegisterFlavours("mysql+unix"),
sqlclient.RegisterURLParser(parser{}),
)
sqlclient.Register(
"mariadb",
sqlclient.OpenerFunc(opener),
DriverMaria,
opener(DriverMaria),
sqlclient.RegisterDriverOpener(Open),
sqlclient.RegisterCodec(mariaCodec, mariaCodec),
sqlclient.RegisterFlavours("mariadb+unix", "maria", "maria+unix"),
Expand Down Expand Up @@ -97,26 +100,29 @@ func Open(db schema.ExecQuerier) (migrate.Driver, error) {
}, nil
}

func opener(_ context.Context, u *url.URL) (*sqlclient.Client, error) {
ur := parser{}.ParseURL(u)
db, err := sql.Open(DriverName, ur.DSN)
if err != nil {
return nil, err
}
drv, err := Open(db)
if err != nil {
if cerr := db.Close(); cerr != nil {
err = fmt.Errorf("%w: %v", err, cerr)
// opener for the given driver name.
func opener(name string) sqlclient.OpenerFunc {
return func(_ context.Context, u *url.URL) (*sqlclient.Client, error) {
ur := parser{}.ParseURL(u)
db, err := sql.Open(DriverName, ur.DSN)
if err != nil {
return nil, err
}
return nil, err
drv, err := Open(db)
if err != nil {
if cerr := db.Close(); cerr != nil {
err = fmt.Errorf("%w: %v", err, cerr)
}
return nil, err
}
drv.(*Driver).schema = ur.Schema
return &sqlclient.Client{
Name: name,
DB: db,
URL: ur,
Driver: drv,
}, nil
}
drv.(*Driver).schema = ur.Schema
return &sqlclient.Client{
Name: DriverName,
DB: db,
URL: ur,
Driver: drv,
}, nil
}

// NormalizeRealm returns the normal representation of the given database.
Expand Down

0 comments on commit 337a219

Please sign in to comment.