From 1c01393b6c7bef184fd7a3fefd9fecbc0146b820 Mon Sep 17 00:00:00 2001 From: "Davi S. Zucon" Date: Mon, 24 Apr 2017 11:04:50 -0300 Subject: [PATCH] Changes by code review: Clean up, removed unnecessary code comments and imports in requirements.txt, change imports in sqlalchemy and a little refactor. --- chatterbot/storage/__init__.py | 7 +- chatterbot/storage/sqlalchemy_storage.py | 154 +++++------------------ requirements.txt | 1 - 3 files changed, 32 insertions(+), 130 deletions(-) diff --git a/chatterbot/storage/__init__.py b/chatterbot/storage/__init__.py index 644421cfc..2efe59951 100644 --- a/chatterbot/storage/__init__.py +++ b/chatterbot/storage/__init__.py @@ -2,9 +2,4 @@ from .django_storage import DjangoStorageAdapter from .jsonfile import JsonFileStorageAdapter from .mongodb import MongoDatabaseAdapter - -# FIXME Better way manage import -try: - from .sqlalchemy_storage import SQLAlchemyDatabaseAdapter -except ImportError: - pass +from .sqlalchemy_storage import SQLAlchemyDatabaseAdapter diff --git a/chatterbot/storage/sqlalchemy_storage.py b/chatterbot/storage/sqlalchemy_storage.py index 608dff71f..c9469b28d 100644 --- a/chatterbot/storage/sqlalchemy_storage.py +++ b/chatterbot/storage/sqlalchemy_storage.py @@ -1,24 +1,23 @@ import json import random -from sqlalchemy import Column, ForeignKey -from sqlalchemy import PickleType -from sqlalchemy import String -from sqlalchemy import Integer -from sqlalchemy import create_engine -from sqlalchemy.exc import DatabaseError -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import relationship -from sqlalchemy.orm import sessionmaker - from chatterbot.storage import StorageAdapter from chatterbot.conversation import Response from chatterbot.conversation import Statement -Base = declarative_base() +_base = None +try: + from sqlalchemy.ext.declarative import declarative_base + + _base = declarative_base() +except: + pass -class StatementTable(Base): +class StatementTable(_base): + from sqlalchemy import Column, Integer, String, PickleType + from sqlalchemy.orm import relationship + __tablename__ = 'StatementTable' def get_statement(self): @@ -40,7 +39,9 @@ def get_statement_serialized(context): text_search = Column(String, primary_key=True, default=get_statement_serialized) -class ResponseTable(Base): +class ResponseTable(_base): + from sqlalchemy import Column, Integer, String, ForeignKey + from sqlalchemy.orm import relationship __tablename__ = 'ResponseTable' def get_reponse_serialized(context): @@ -53,8 +54,6 @@ def get_reponse_serialized(context): occurrence = Column(Integer) statement_text = Column(String, ForeignKey('StatementTable.text')) - # Old: statement_table = relationship("StatementTable", backref=backref('in_response_to'), cascade="all, delete-orphan", single_parent=True) - # Test relationship: statement_table = relationship("StatementTable", back_populates="in_response_to", cascade="all", uselist=False) text_search = Column(String, primary_key=True, default=get_reponse_serialized) @@ -63,14 +62,6 @@ def get_response(self): return Response(text=self.text, **occ) -# # relational table TO REMOVE -# in_response_to = sqlalchemy.Table('in_responses_to', -# Base.metadata, -# sqlalchemy.Column('stmt_id', sqlalchemy.Integer, ForeignKey('StatementTable.id')), -# sqlalchemy.Column('resp_id', sqlalchemy.Integer, -# sqlalchemy.ForeignKey('ResponseTable.id'))) - - def get_statement_table(statement): responses = list(map(get_response_table, statement.in_response_to)) return StatementTable(text=statement.text, in_response_to=responses, extra_data=statement.extra_data) @@ -87,11 +78,13 @@ class SQLAlchemyDatabaseAdapter(StorageAdapter): def __init__(self, **kwargs): super(SQLAlchemyDatabaseAdapter, self).__init__(**kwargs) + from sqlalchemy import create_engine + self.database_name = self.kwargs.get( "database", "chatterbot-database" ) - # if some annoing blank space wrong... + # if some annoying blank space wrong... db_name = self.database_name.strip() # default uses sqlite @@ -101,32 +94,6 @@ def __init__(self, **kwargs): self.engine = create_engine(self.database_uri) - # metadata = MetaData(self.engine) - - # Recreate database - # metadata.drop_all() - # - # self.response = Table('response', metadata, - # Column('id', Integer, primary_key=True), - # Column('text', Text), - # Column('occurrence', Text), - # ) - # - # self.statement = Table('statement', metadata, - # Column('id', Integer, primary_key=True), - # Column('text', Text), - # Column('in_response_to', Integer, ForeignKey('response.id')), - # Column('extra_data', Text) - # ) - # - # # mapper(Response, self.response, - # # # non_primary=True, - # # properties={ - # # 'statement': relationship(Statement, backref='response') - # # }, ) - # mapper(Statement, self.statement) - # mapper(Response, self.response) - self.read_only = self.kwargs.get( "read_only", False ) @@ -150,6 +117,8 @@ def __get_session(self): """ :rtype: Session """ + from sqlalchemy.orm import sessionmaker + Session = sessionmaker(bind=self.engine) session = Session() return session @@ -185,13 +154,7 @@ def remove(self, statement_text): record = query.first() session.delete(record) - try: - if not self.read_only: - session.commit() - else: - session.rollback() - except DatabaseError as e: - self.logger.error(statement_text, str(e.orig)) + self._session_finish(session, statement_text) def filter(self, **kwargs): """ @@ -253,69 +216,11 @@ def filter(self, **kwargs): return results - # def filter(self, **kwargs): - # """ - # Returns a list of objects from the database. - # The kwargs parameter can contain any number - # of attributes. Only objects which contain - # all listed attributes and in which all values - # match for all listed attributes will be returned. - # """ - # - # filter_parameters = kwargs.copy() - # - # session = self.__get_session() - # statements = [] - # _response_query = None - # - # if len(filter_parameters) == 0: - # _response_query = session.query(StatementTable) - # statements.extend(_response_query.all()) - # else: - # for fp in filter_parameters: - # _filter = filter_parameters[fp] - # if fp == 'in_response_to' or fp == 'in_response_to__contains': - # if _response_query: - # _response_query.join(StatementTable) - # else: - # _response_query = session.query(StatementTable) - # if isinstance(_filter, list): - # if len(_filter) == 0: - # query = _response_query.filter(StatementTable.in_response_to is None) - # else: - # # _in_response_tables = [] - # # for respnse_table in _like: - # # _in_response_tables.append(get_response_table(respnse_table)) - # query = _response_query.filter(StatementTable.in_response_to.contains(_filter)) - # else: - # query = _response_query.filter(StatementTable.in_response_to.like('%' + _filter + '%')) - # else: - # if fp == 'text' or fp == 'text__contains': # Text always use like - # _response_query = session.query(ResponseTable) - # # if fp == 'text__contains': - # query = _response_query.filter(ResponseTable.text.like('%' + _filter + '%')) - # # if fp == 'text': - # # query = _response_query.filter(ResponseTable.text == _filter) - # - # statements.extend(query.all()) - # - # results = [] - # for statement in statements: - # if isinstance(statement, ResponseTable): - # if statement and statement.statement_table: - # results.append(statement.statement_table.get_statement()) - # else: - # if statement: - # results.append(statement.get_statement()) - # - # return results - def update(self, statement): """ Modifies an entry in the database. Creates an entry if one does not exist. """ - session = self.__get_session() if statement: query = self.__statement_filter(session, **{"text": statement.text}) @@ -333,14 +238,7 @@ def update(self, statement): else: session.add(get_statement_table(statement)) - try: - if not self.read_only: - session.commit() - else: - session.rollback() - except DatabaseError as e: - pass - # self.logger.error(statement, str(e.orig)) + self._session_finish(session) def get_random(self): """ @@ -361,3 +259,13 @@ def drop(self): Drop the database attached to a given adapter. """ Base.metadata.drop_all(self.engine) + + def _session_finish(self, session, statement_text=None): + from sqlalchemy.exc import DatabaseError + try: + if not self.read_only: + session.commit() + else: + session.rollback() + except DatabaseError as e: + self.logger.error(statement_text, str(e.orig)) diff --git a/requirements.txt b/requirements.txt index 5588f71bc..8f28c03b3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,5 +3,4 @@ jsondatabase>=0.1.7,<1.0.0 nltk>=3.2.0,<4.0.0 pymongo>=3.3.0,<4.0.0 python-twitter>=3.0.0,<4.0.0 -textblob>=0.11.0,<0.12.0 SQLAlchemy==1.1.7 \ No newline at end of file