diff --git a/chatterbot/storage/sqlalchemy_storage.py b/chatterbot/storage/sqlalchemy_storage.py index c9469b28d..1c2fcbeb7 100644 --- a/chatterbot/storage/sqlalchemy_storage.py +++ b/chatterbot/storage/sqlalchemy_storage.py @@ -6,60 +6,62 @@ from chatterbot.conversation import Statement _base = None + try: from sqlalchemy.ext.declarative import declarative_base _base = declarative_base() -except: - pass -class StatementTable(_base): - from sqlalchemy import Column, Integer, String, PickleType - from sqlalchemy.orm import relationship + class StatementTable(_base): + from sqlalchemy import Column, Integer, String, PickleType + from sqlalchemy.orm import relationship - __tablename__ = 'StatementTable' + __tablename__ = 'StatementTable' - def get_statement(self): - stmt = Statement(self.text, **self.extra_data) - for resp in self.in_response_to: - stmt.add_response(resp.get_response()) - return stmt + def get_statement(self): + stmt = Statement(self.text, **self.extra_data) + for resp in self.in_response_to: + stmt.add_response(resp.get_response()) + return stmt - def get_statement_serialized(context): - params = context.current_parameters - del (params['text_search']) - return json.dumps(params) + def get_statement_serialized(context): + params = context.current_parameters + del (params['text_search']) + return json.dumps(params) - id = Column(Integer) - text = Column(String, primary_key=True) - extra_data = Column(PickleType) - # relationship: - in_response_to = relationship("ResponseTable", back_populates="statement_table") - text_search = Column(String, primary_key=True, default=get_statement_serialized) + id = Column(Integer) + text = Column(String, primary_key=True) + extra_data = Column(PickleType) + # relationship: + in_response_to = relationship("ResponseTable", back_populates="statement_table") + text_search = Column(String, primary_key=True, default=get_statement_serialized) -class ResponseTable(_base): - from sqlalchemy import Column, Integer, String, ForeignKey - from sqlalchemy.orm import relationship - __tablename__ = 'ResponseTable' + class ResponseTable(_base): + from sqlalchemy import Column, Integer, String, ForeignKey + from sqlalchemy.orm import relationship + __tablename__ = 'ResponseTable' - def get_reponse_serialized(context): - params = context.current_parameters - del (params['text_search']) - return json.dumps(params) + def get_reponse_serialized(context): + params = context.current_parameters + del (params['text_search']) + return json.dumps(params) - id = Column(Integer) - text = Column(String, primary_key=True) - occurrence = Column(Integer) - statement_text = Column(String, ForeignKey('StatementTable.text')) + id = Column(Integer) + text = Column(String, primary_key=True) + occurrence = Column(Integer) + statement_text = Column(String, ForeignKey('StatementTable.text')) - statement_table = relationship("StatementTable", back_populates="in_response_to", cascade="all", uselist=False) - text_search = Column(String, primary_key=True, default=get_reponse_serialized) + statement_table = relationship("StatementTable", back_populates="in_response_to", cascade="all", uselist=False) + text_search = Column(String, primary_key=True, default=get_reponse_serialized) - def get_response(self): - occ = {"occurrence": self.occurrence} - return Response(text=self.text, **occ) + def get_response(self): + occ = {"occurrence": self.occurrence} + return Response(text=self.text, **occ) + +except ImportError: + pass def get_statement_table(statement): @@ -103,8 +105,8 @@ def __init__(self, **kwargs): ) if not self.read_only and self.drop_create: - Base.metadata.drop_all(self.engine) - Base.metadata.create_all(self.engine) + _base.metadata.drop_all(self.engine) + _base.metadata.create_all(self.engine) def count(self): """ @@ -258,7 +260,7 @@ def drop(self): """ Drop the database attached to a given adapter. """ - Base.metadata.drop_all(self.engine) + _base.metadata.drop_all(self.engine) def _session_finish(self, session, statement_text=None): from sqlalchemy.exc import DatabaseError