Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sql alchemy orm compatability #54

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
13 changes: 0 additions & 13 deletions setup.cfg

This file was deleted.

64 changes: 41 additions & 23 deletions sqlacodegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
_flask_prepend = 'db.'

_dataclass = False
_sqla_orm = False


class _DummyInflectEngine(object):
Expand Down Expand Up @@ -168,7 +169,11 @@ def _render_column(column, show_name):
server_default = 'server_default=' + _flask_prepend + 'FetchedValue()'

comment = getattr(column, 'comment', None)
return _flask_prepend + 'Column({0})'.format(', '.join(
if _sqla_orm:
column_string = 'mapped_column'
else:
column_string = 'Column'
return _flask_prepend + '{0}({1})'.format(column_string, ', '.join(
([repr(column.name)] if show_name else []) +
([_render_column_type(column.type)] if render_coltype else []) +
[_render_constraint(x) for x in dedicated_fks] +
Expand Down Expand Up @@ -227,7 +232,8 @@ def _render_index(index):
class ImportCollector(OrderedDict):
def add_import(self, obj):
type_ = type(obj) if not isinstance(obj, type) else obj
pkgname = 'sqlalchemy' if type_.__name__ in sqlalchemy.__all__ else type_.__module__ # @UndefinedVariable
pkgname = 'sqlalchemy' if hasattr(sqlalchemy, type_.__name__) else type_.__module__ # @UndefinedVariable
# pkgname = 'sqlalchemy' if type_.__name__ in sqlalchemy.__all__ else type_.__module__ # @UndefinedVariable
self.add_literal_import(pkgname, type_.__name__)

def add_literal_import(self, pkgname, name):
Expand Down Expand Up @@ -326,7 +332,7 @@ def __init__(self, table, association_tables, inflect_engine, detect_joined, col
if _dataclass:
if column.type.python_type.__module__ != 'builtins':
collector.add_literal_import(column.type.python_type.__module__, column.type.python_type.__name__)


# Add many-to-one relationships
pk_column_names = set(col.name for col in table.primary_key.columns)
Expand Down Expand Up @@ -373,13 +379,12 @@ def add_imports(self, collector):
child.add_imports(collector)

def render(self):
global _dataclass

global _dataclass
text = 'class {0}({1}):\n'.format(self.name, self.parent_name)

if _dataclass:
text = '@dataclass\n' + text

text += ' __tablename__ = {0!r}\n'.format(self.table.name)

# Render constraints and indexes as __table_args__
Expand All @@ -397,7 +402,6 @@ def render(self):
table_kwargs = {}
if self.schema:
table_kwargs['schema'] = self.schema

kwargs_items = ', '.join('{0!r}: {1!r}'.format(key, table_kwargs[key]) for key in table_kwargs)
kwargs_items = '{{{0}}}'.format(kwargs_items) if kwargs_items else None
if table_kwargs and not table_args:
Expand All @@ -414,9 +418,12 @@ def render(self):
for attr, column in self.attributes.items():
if isinstance(column, Column):
show_name = attr != column.name
if _dataclass:
text += ' ' + attr + ' : ' + column.type.python_type.__name__ + '\n'

if _dataclass:
if _sqla_orm:
text += ' ' + attr + ' : ' + 'Mapped[{0}]\n'.format(column.type.python_type.__name__)
else:
text += ' ' + attr + ' : ' + column.type.python_type.__name__ + '\n'

text += ' {0} = {1}\n'.format(attr, _render_column(column, show_name))

# Render relationships
Expand Down Expand Up @@ -452,7 +459,7 @@ def render(self):
delimiter, end = ', ', ')'

args.extend([key + '=' + value for key, value in self.kwargs.items()])

return _re_invalid_relationship.sub('_', text + delimiter.join(args) + end)

def make_backref(self, relationships, classes):
Expand Down Expand Up @@ -509,7 +516,7 @@ def __init__(self, source_cls, target_cls, constraint, inflect_engine):
# common_fk_constraints = _get_common_fk_constraints(constraint.table, constraint.elements[0].column.table)
# if len(common_fk_constraints) > 1:
# self.kwargs['primaryjoin'] = "'{0}.{1} == {2}.{3}'".format(source_cls, constraint.columns[0], target_cls, constraint.elements[0].column.name)
if len(constraint.elements) > 1: # or
if len(constraint.elements) > 1: # or
self.kwargs['primaryjoin'] = "'and_({0})'".format(', '.join(['{0}.{1} == {2}.{3}'.format(source_cls, k.parent.name, target_cls, k.column.name)
for k in constraint.elements]))
else:
Expand Down Expand Up @@ -550,9 +557,8 @@ class CodeGenerator(object):

def __init__(self, metadata, noindexes=False, noconstraints=False,
nojoined=False, noinflect=False, nobackrefs=False,
flask=False, ignore_cols=None, noclasses=False, nocomments=False, notables=False, dataclass=False):
flask=False, ignore_cols=None, noclasses=False, nocomments=False, notables=False, dataclass=False, sqla_orm=False):
super(CodeGenerator, self).__init__()

if noinflect:
inflect_engine = _DummyInflectEngine()
else:
Expand All @@ -561,19 +567,23 @@ def __init__(self, metadata, noindexes=False, noconstraints=False,

# exclude these column names from consideration when generating association tables
_ignore_columns = ignore_cols or []

self.flask = flask
if not self.flask:
global _flask_prepend
_flask_prepend = ''

self.nocomments = nocomments

self.dataclass = dataclass
if self.dataclass:
global _dataclass
_dataclass = True

self.sqla_orm = sqla_orm
global _sqla_orm
_sqla_orm = sqla_orm

# Pick association tables from the metadata into their own set, don't process them normally
links = defaultdict(lambda: [])
association_tables = set()
Expand Down Expand Up @@ -671,15 +681,20 @@ def __init__(self, metadata, noindexes=False, noconstraints=False,
if model.parent_name == 'Base':
model.parent_name = parent_name
else:
self.collector.add_literal_import('sqlalchemy.ext.declarative', 'declarative_base')
self.collector.add_literal_import('sqlalchemy', 'MetaData')


if self.sqla_orm:
self.collector.add_literal_import('sqlalchemy.orm', 'DeclarativeBase')
self.collector.add_literal_import('sqlalchemy.orm', 'Mapped')
self.collector.add_literal_import('sqlalchemy.orm', 'mapped_column')
else:
self.collector.add_literal_import('sqlalchemy.ext.declarative', 'declarative_base')
self.collector.add_literal_import('sqlalchemy', 'MetaData')


if self.dataclass:
self.collector.add_literal_import('dataclasses', 'dataclass')

def render(self, outfile=sys.stdout):

print(self.header, file=outfile)

# Render the collected imports
Expand All @@ -689,7 +704,10 @@ def render(self, outfile=sys.stdout):
print('db = SQLAlchemy()', file=outfile)
else:
if any(isinstance(model, ModelClass) for model in self.models):
print('Base = declarative_base()\nmetadata = Base.metadata', file=outfile)
if self.sqla_orm:
print('class Base(DeclarativeBase):\n pass', file=outfile)
else:
print('Base = declarative_base()\nmetadata = Base.metadata', file=outfile)
else:
print('metadata = MetaData()', file=outfile)

Expand Down
20 changes: 13 additions & 7 deletions sqlacodegen/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from sqlalchemy.engine import create_engine
from sqlalchemy.schema import MetaData

from sqlacodegen.codegen import CodeGenerator
# from sqlacodegen.codegen import CodeGenerator
from codegen import CodeGenerator
import sqlacodegen
import sqlacodegen.dialects

Expand All @@ -25,7 +26,7 @@ def main():
parser = argparse.ArgumentParser(description='Generates SQLAlchemy model code from an existing database.')
parser.add_argument('url', nargs='?', help='SQLAlchemy url to the database')
parser.add_argument('--version', action='store_true', help="print the version number and exit")
parser.add_argument('--schema', help='load tables from an alternate schema')
parser.add_argument('--schema', help='alternate schemas to load in addition to local schema (comma-separated)')
parser.add_argument('--default-schema', help='default schema name for local schema object')
parser.add_argument('--tables', help='tables to process (comma-separated, default: all)')
parser.add_argument('--noviews', action='store_true', help="ignore views")
Expand All @@ -41,6 +42,8 @@ def main():
parser.add_argument('--ignore-cols', help="Don't check foreign key constraints on specified columns (comma-separated)")
parser.add_argument('--nocomments', action='store_true', help="don't render column comments")
parser.add_argument('--dataclass', action='store_true', help="add dataclass decorators for JSON serialization")
parser.add_argument('--sqlalchemyorm', action='store_true', help="use SQLAlchemy.orm module")

args = parser.parse_args()

if args.version:
Expand All @@ -52,20 +55,23 @@ def main():
return
default_schema = args.default_schema
if not default_schema:
default_schema = None
default_schema = None

engine = create_engine(args.url)
import_dialect_specificities(engine)
metadata = MetaData()
metadata = MetaData(schema=default_schema)
tables = args.tables.split(',') if args.tables else None
ignore_cols = args.ignore_cols.split(',') if args.ignore_cols else None
metadata.reflect(engine, args.schema, not args.noviews, tables)
metadata.reflect(engine, views=not args.noviews, only=tables)
for schema in args.schema.split(','):
metadata.reflect(engine, schema, not args.noviews, tables)

outfile = codecs.open(args.outfile, 'w', encoding='utf-8') if args.outfile else sys.stdout
generator = CodeGenerator(metadata, args.noindexes, args.noconstraints,
args.nojoined, args.noinflect, args.nobackrefs,
args.flask, ignore_cols, args.noclasses, args.nocomments, args.notables, args.dataclass)
args.flask, ignore_cols, args.noclasses, args.nocomments, args.notables, args.dataclass, args.sqlalchemyorm)
generator.render(outfile)


if __name__ == '__main__':
main()
main()