Skip to content

Commit

Permalink
fixes #12: pass a default migrations directory in the Migrate constru…
Browse files Browse the repository at this point in the history
…ctor
  • Loading branch information
miguelgrinberg committed Dec 20, 2013
1 parent 3d364a8 commit 189dbb5
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 21 deletions.
56 changes: 36 additions & 20 deletions flask_migrate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,91 +1,107 @@
import os
from flask import current_app
from flask.ext.script import Manager
from alembic.config import Config as AlembicConfig
from alembic import command

class _MigrateConfig(object):
def __init__(self, db, directory):
self.db = db
self.directory = directory

@property
def metadata(self):
"""Backwards compatibility, in old releases app.extensions['migrate']
was set to db, and env.py accessed app.extensions['migrate'].metadata"""
return self.db.metadata

class Migrate(object):
def __init__(self, app = None, db = None):
def __init__(self, app = None, db = None, directory = 'migrations'):
if app is not None and db is not None:
self.init_app(app, db)
self.init_app(app, db, directory)

def init_app(self, app, db):
def init_app(self, app, db, directory = 'migrations'):
if not hasattr(app, 'extensions'):
app.extensions = {}
app.extensions['migrate'] = db
app.extensions['migrate'] = _MigrateConfig(db, directory)

class Config(AlembicConfig):
def get_template_directory(self):
package_dir = os.path.abspath(os.path.dirname(__file__))
return os.path.join(package_dir, 'templates')

def _get_config(directory):
if directory is None:
directory = current_app.extensions['migrate'].directory
config = Config(os.path.join(directory, 'alembic.ini'))
config.set_main_option('script_location', directory)
return config

MigrateCommand = Manager(usage = 'Perform database migrations')

@MigrateCommand.option('-d', '--directory', dest = 'directory', default = 'migrations', help = "migration script directory (default is 'migrations')")
def init(directory = 'migrations'):
@MigrateCommand.option('-d', '--directory', dest = 'directory', default = None, help = "migration script directory (default is 'migrations')")
def init(directory = None):
"Generates a new migration"
if directory is None:
directory = current_app.extensions['migrate'].directory
config = Config()
config.set_main_option('script_location', directory)
config.config_file_name = os.path.join(directory, 'alembic.ini')
command.init(config, directory, 'flask')

@MigrateCommand.option('-d', '--directory', dest = 'directory', default = 'migrations', help = "Migration script directory (default is 'migrations')")
def current(directory = 'migrations'):
@MigrateCommand.option('-d', '--directory', dest = 'directory', default = None, help = "Migration script directory (default is 'migrations')")
def current(directory = None):
"Display the current revision for each database."
config = _get_config(directory)
command.current(config)

@MigrateCommand.option('-r', '--rev-range', dest = 'rev_range', default = None, help = "Specify a revision range; format is [start]:[end]")
@MigrateCommand.option('-d', '--directory', dest = 'directory', default = 'migrations', help = "Migration script directory (default is 'migrations')")
def history(directory = 'migrations', rev_range = None):
@MigrateCommand.option('-d', '--directory', dest = 'directory', default = None, help = "Migration script directory (default is 'migrations')")
def history(directory = None, rev_range = None):
"List changeset scripts in chronological order."
config = _get_config(directory)
command.history(config, rev_range)

@MigrateCommand.option('--sql', dest = 'sql', action = 'store_true', default = False, help = "Don't emit SQL to database - dump to standard output instead")
@MigrateCommand.option('--autogenerate', dest = 'autogenerate', action = 'store_true', default = False, help = "Populate revision script with andidate migration operatons, based on comparison of database to model")
@MigrateCommand.option('-m', '--message', dest = 'message', default = None)
@MigrateCommand.option('-d', '--directory', dest = 'directory', default = 'migrations', help = "Migration script directory (default is 'migrations')")
def revision(directory = 'migrations', message = None, autogenerate = False, sql = False):
@MigrateCommand.option('-d', '--directory', dest = 'directory', default = None, help = "Migration script directory (default is 'migrations')")
def revision(directory = None, message = None, autogenerate = False, sql = False):
"Create a new revision file."
config = _get_config(directory)
command.revision(config, message, autogenerate = autogenerate, sql = sql)

@MigrateCommand.option('--sql', dest = 'sql', action = 'store_true', default = False, help = "Don't emit SQL to database - dump to standard output instead")
@MigrateCommand.option('-m', '--message', dest = 'message', default = None)
@MigrateCommand.option('-d', '--directory', dest = 'directory', default = 'migrations', help = "Migration script directory (default is 'migrations')")
def migrate(directory = 'migrations', message = None, sql = False):
@MigrateCommand.option('-d', '--directory', dest = 'directory', default = None, help = "Migration script directory (default is 'migrations')")
def migrate(directory = None, message = None, sql = False):
"Alias for 'revision --autogenerate'"
config = _get_config(directory)
command.revision(config, message, autogenerate = True, sql = sql)

@MigrateCommand.option('--tag', dest = 'tag', default = None, help = "Arbitrary 'tag' name - can be used by custom env.py scripts")
@MigrateCommand.option('--sql', dest = 'sql', action = 'store_true', default = False, help = "Don't emit SQL to database - dump to standard output instead")
@MigrateCommand.option('revision', default = None, help = "revision identifier")
@MigrateCommand.option('-d', '--directory', dest = 'directory', default = 'migrations', help = "Migration script directory (default is 'migrations')")
def stamp(directory = 'migrations', revision = 'head', sql = False, tag = None):
@MigrateCommand.option('-d', '--directory', dest = 'directory', default = None, help = "Migration script directory (default is 'migrations')")
def stamp(directory = None, revision = 'head', sql = False, tag = None):
"'stamp' the revision table with the given revision; don't run any migrations"
config = _get_config(directory)
command.stamp(config, revision, sql = sql, tag = tag)

@MigrateCommand.option('--tag', dest = 'tag', default = None, help = "Arbitrary 'tag' name - can be used by custom env.py scripts")
@MigrateCommand.option('--sql', dest = 'sql', action = 'store_true', default = False, help = "Don't emit SQL to database - dump to standard output instead")
@MigrateCommand.option('revision', nargs = '?', default = 'head', help = "revision identifier")
@MigrateCommand.option('-d', '--directory', dest = 'directory', default = 'migrations', help = "Migration script directory (default is 'migrations')")
def upgrade(directory = 'migrations', revision = 'head', sql = False, tag = None):
@MigrateCommand.option('-d', '--directory', dest = 'directory', default = None, help = "Migration script directory (default is 'migrations')")
def upgrade(directory = None, revision = 'head', sql = False, tag = None):
"Upgrade to a later version"
config = _get_config(directory)
command.upgrade(config, revision, sql = sql, tag = tag)

@MigrateCommand.option('--tag', dest = 'tag', default = None, help = "Arbitrary 'tag' name - can be used by custom env.py scripts")
@MigrateCommand.option('--sql', dest = 'sql', action = 'store_true', default = False, help = "Don't emit SQL to database - dump to standard output instead")
@MigrateCommand.option('revision', nargs = '?', default = "-1", help = "revision identifier")
@MigrateCommand.option('-d', '--directory', dest = 'directory', default = 'migrations', help = "Migration script directory (default is 'migrations')")
def downgrade(directory = 'migrations', revision = '-1', sql = False, tag = None):
@MigrateCommand.option('-d', '--directory', dest = 'directory', default = None, help = "Migration script directory (default is 'migrations')")
def downgrade(directory = None, revision = '-1', sql = False, tag = None):
"Revert to a previous version"
config = _get_config(directory)
command.downgrade(config, revision, sql = sql, tag = tag)
2 changes: 1 addition & 1 deletion flask_migrate/templates/flask/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# target_metadata = mymodel.Base.metadata
from flask import current_app
config.set_main_option('sqlalchemy.url', current_app.config.get('SQLALCHEMY_DATABASE_URI'))
target_metadata = current_app.extensions['migrate'].metadata
target_metadata = current_app.extensions['migrate'].db.metadata

# other values from the config, defined by the needs of env.py,
# can be acquired:
Expand Down
21 changes: 21 additions & 0 deletions tests/app2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from flask import Flask
from flask.ext.sqlalchemy import SQLAlchemy
from flask.ext.script import Manager
from flask.ext.migrate import Migrate, MigrateCommand

app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///app2.db'

db = SQLAlchemy(app)
migrate = Migrate(app, db, directory = 'temp_folder/temp_migrations')

manager = Manager(app)
manager.add_command('db', MigrateCommand)

class User(db.Model):
id = db.Column(db.Integer, primary_key = True)
name = db.Column(db.String(128))

if __name__ == '__main__':
manager.run()

39 changes: 39 additions & 0 deletions tests/test_migrate_custom_directory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import os
import shutil
from exceptions import OSError
import unittest

class TestMigrate(unittest.TestCase):
def setUp(self):
os.chdir(os.path.split(os.path.abspath(__file__))[0])
try:
os.remove('app2.db')
except OSError:
pass
try:
shutil.rmtree('temp_folder')
except OSError:
pass

os.system('python app2.py db init')
os.system('python app2.py db migrate')
os.system('python app2.py db upgrade')

def tearDown(self):
try:
os.remove('app2.db')
except OSError:
pass
try:
shutil.rmtree('migrations')
except OSError:
pass

def test_migrate_upgrade(self):
from app2 import db, User
db.session.add(User(name = 'test'))
db.session.commit()

if __name__ == '__main__':
unittest.main()

0 comments on commit 189dbb5

Please sign in to comment.