Skip to content

Commit

Permalink
example changes
Browse files Browse the repository at this point in the history
  • Loading branch information
justanr committed May 8, 2015
1 parent fe67c63 commit 03ea88c
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions flask_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@ def _make_table(*args, **kwargs):
return _make_table


def _set_default_query_class(d):
def _set_default_query_class(d, cls):
if 'query_class' not in d:
d['query_class'] = BaseQuery
d['query_class'] = cls


def _wrap_with_default_query_class(fn):
def _wrap_with_default_query_class(fn, cls):
@functools.wraps(fn)
def newfn(*args, **kwargs):
_set_default_query_class(kwargs)
_set_default_query_class(kwargs, cls)
if "backref" in kwargs:
backref = kwargs['backref']
if isinstance(backref, string_types):
Expand All @@ -86,16 +86,16 @@ def newfn(*args, **kwargs):
return newfn


def _include_sqlalchemy(obj):
def _include_sqlalchemy(obj, cls):
for module in sqlalchemy, sqlalchemy.orm:
for key in module.__all__:
if not hasattr(obj, key):
setattr(obj, key, getattr(module, key))
# Note: obj.Table does not attempt to be a SQLAlchemy Table class.
obj.Table = _make_table(obj)
obj.relationship = _wrap_with_default_query_class(obj.relationship)
obj.relation = _wrap_with_default_query_class(obj.relation)
obj.dynamic_loader = _wrap_with_default_query_class(obj.dynamic_loader)
obj.relationship = _wrap_with_default_query_class(obj.relationship, cls)
obj.relation = _wrap_with_default_query_class(obj.relation, cls)
obj.dynamic_loader = _wrap_with_default_query_class(obj.dynamic_loader, cls)
obj.event = event


Expand Down Expand Up @@ -725,19 +725,20 @@ class User(db.Model):
naming conventions among other, non-trivial things.
"""

def __init__(self, app=None, use_native_unicode=True, session_options=None, metadata=None):
def __init__(self, app=None, use_native_unicode=True, session_options=None,
metadata=None, query_class=BaseQuery, model_class=Model):

if session_options is None:
session_options = {}

session_options.setdefault('scopefunc', connection_stack.__ident_func__)
self.use_native_unicode = use_native_unicode
self.session = self.create_scoped_session(session_options)
self.Model = self.make_declarative_base(metadata)
self.Query = BaseQuery
self.Model = self.make_declarative_base(metadata, model_class)
self.Query = query_class
self._engine_lock = Lock()
self.app = app
_include_sqlalchemy(self)
_include_sqlalchemy(self, query_class)

if app is not None:
self.init_app(app)
Expand Down Expand Up @@ -765,9 +766,9 @@ def create_session(self, options):
"""
return SignallingSession(self, **options)

def make_declarative_base(self, metadata=None):
def make_declarative_base(self, metadata=None, model_class):
"""Creates the declarative base."""
base = declarative_base(cls=Model, name='Model',
base = declarative_base(cls=model_class, name='Model',
metadata=metadata,
metaclass=_BoundDeclarativeMeta)
base.query = _QueryProperty(self)
Expand Down

0 comments on commit 03ea88c

Please sign in to comment.