From 7bf15b060cdefbfbbaafa4a3e597c7c71a6e4772 Mon Sep 17 00:00:00 2001 From: Gunther Cox Date: Sun, 16 Jul 2017 21:28:05 -0400 Subject: [PATCH] Add a base class for SQL Alchemy models --- chatterbot/ext/sqlalchemy_app/__init__.py | 0 chatterbot/ext/sqlalchemy_app/models.py | 26 +++++++++++++++++++++++ chatterbot/storage/sql_storage.py | 6 +----- setup.py | 1 + 4 files changed, 28 insertions(+), 5 deletions(-) create mode 100644 chatterbot/ext/sqlalchemy_app/__init__.py create mode 100644 chatterbot/ext/sqlalchemy_app/models.py diff --git a/chatterbot/ext/sqlalchemy_app/__init__.py b/chatterbot/ext/sqlalchemy_app/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/chatterbot/ext/sqlalchemy_app/models.py b/chatterbot/ext/sqlalchemy_app/models.py new file mode 100644 index 000000000..1b8b74c8b --- /dev/null +++ b/chatterbot/ext/sqlalchemy_app/models.py @@ -0,0 +1,26 @@ +from sqlalchemy import Column, Integer +from sqlalchemy.ext.declarative import ( + declared_attr, declarative_base +) + + +class ModelBase(object): + """ + An augmented base class for SqlAlchemy models. + """ + + @declared_attr + def __tablename__(cls): + """ + Return the lowercase class name as the name of the table. + """ + return cls.__name__.lower() + + id = Column( + Integer, + primary_key=True, + autoincrement=True + ) + + +Base = declarative_base(cls=ModelBase) diff --git a/chatterbot/storage/sql_storage.py b/chatterbot/storage/sql_storage.py index fc6747a9b..3421a2b06 100644 --- a/chatterbot/storage/sql_storage.py +++ b/chatterbot/storage/sql_storage.py @@ -8,9 +8,7 @@ Base = None try: - from sqlalchemy.ext.declarative import declarative_base - - Base = declarative_base() + from chatterbot.ext.sqlalchemy_app.models import Base class StatementTable(Base): """ @@ -32,7 +30,6 @@ def get_statement_serialized(context): del params['text_search'] return json.dumps(params) - id = Column(Integer, primary_key=True, autoincrement=True) text = Column(String, unique=True) extra_data = Column(PickleType) @@ -60,7 +57,6 @@ def get_reponse_serialized(context): del params['text_search'] return json.dumps(params) - id = Column(Integer, primary_key=True, autoincrement=True) text = Column(String) occurrence = Column(Integer, default=1) statement_text = Column(String, ForeignKey('StatementTable.text')) diff --git a/setup.py b/setup.py index 31434c398..65995f818 100644 --- a/setup.py +++ b/setup.py @@ -35,6 +35,7 @@ 'chatterbot.corpus', 'chatterbot.conversation', 'chatterbot.ext', + 'chatterbot.ext.sqlalchemy_app', 'chatterbot.ext.django_chatterbot', 'chatterbot.ext.django_chatterbot.migrations', 'chatterbot.ext.django_chatterbot.management',