# Provenance: contents of chatterbot/storage/sql_storage.py as added by commit
# 420af04098810cb5cbfdf276ebfb01cb0933920e ("Create tables if they do not
# exist", Gunther Cox, Sat, 10 Jun 2017 08:47:34 -0400).
"""
SQL storage adapter built on SQLAlchemy.

Defines two declarative tables (``StatementTable`` and ``ResponseTable``),
helpers that convert conversation objects into table rows, and
``SQLStorageAdapter`` which reads and writes those rows through SQLAlchemy
sessions.
"""
import json
import random

from chatterbot.storage import StorageAdapter
from chatterbot.conversation import Response
from chatterbot.conversation import Statement

Base = None


def _serialize_parameters(context):
    """
    Return the JSON encoding of the parameters of the statement being
    executed, leaving out the ``text_search`` column itself.

    A filtered copy is serialized so that ``context.current_parameters`` is
    not mutated and a missing ``text_search`` key is not an error.
    """
    parameters = context.current_parameters
    return json.dumps({
        key: value
        for key, value in parameters.items()
        if key != 'text_search'
    })


try:
    # SQLAlchemy is optional at import time. If it is missing, the table
    # classes below are simply not defined and ``Base`` stays ``None``.
    from sqlalchemy import Column, ForeignKey, Integer, PickleType, String
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import relationship

    Base = declarative_base()

    class StatementTable(Base):
        """
        Row representation of a statement.
        """

        __tablename__ = 'StatementTable'

        def get_statement(self):
            """
            Build a ``Statement`` from this row, attaching one response per
            related ``ResponseTable`` row.
            """
            # extra_data is a nullable column; treat NULL as "no extra data".
            stmt = Statement(self.text, **(self.extra_data or {}))
            for resp in self.in_response_to:
                stmt.add_response(resp.get_response())
            return stmt

        def get_statement_serialized(context):
            """
            Column default for ``text_search``: JSON of the other parameters.
            """
            return _serialize_parameters(context)

        id = Column(Integer)
        text = Column(String, primary_key=True)
        extra_data = Column(PickleType)
        # relationship:
        in_response_to = relationship("ResponseTable", back_populates="statement_table")
        text_search = Column(String, primary_key=True, default=get_statement_serialized)

    class ResponseTable(Base):
        """
        Row representation of a response attached to a statement.
        """

        __tablename__ = 'ResponseTable'

        def get_reponse_serialized(context):
            """
            Column default for ``text_search``: JSON of the other parameters.
            """
            return _serialize_parameters(context)

        id = Column(Integer)
        text = Column(String, primary_key=True)
        occurrence = Column(Integer)
        statement_text = Column(String, ForeignKey('StatementTable.text'))

        statement_table = relationship("StatementTable", back_populates="in_response_to", cascade="all", uselist=False)
        text_search = Column(String, primary_key=True, default=get_reponse_serialized)

        def get_response(self):
            """
            Build a ``Response`` carrying this row's text and occurrence.
            """
            return Response(text=self.text, occurrence=self.occurrence)

except ImportError:
    pass


def get_statement_table(statement):
    """
    Convert a statement into a ``StatementTable`` row, including one
    ``ResponseTable`` row per entry of ``statement.in_response_to``.
    """
    responses = [get_response_table(response) for response in statement.in_response_to]
    return StatementTable(text=statement.text, in_response_to=responses, extra_data=statement.extra_data)


def get_response_table(response):
    """
    Convert a response into a ``ResponseTable`` row.
    """
    return ResponseTable(text=response.text, occurrence=response.occurrence)


class SQLStorageAdapter(StorageAdapter):
    """
    Storage adapter that persists statements through SQLAlchemy.

    Recognised keyword arguments (read from ``self.kwargs``):

    * ``database``: name used to build a sqlite file URI
      (``sqlite:///<database>.db``).
    * ``database_uri``: explicit database URI; takes precedence over
      ``database``.
    * ``read_only``: when true, changes are rolled back instead of committed.
    """

    def __init__(self, **kwargs):
        super(SQLStorageAdapter, self).__init__(**kwargs)

        from sqlalchemy import create_engine
        from sqlalchemy.orm import sessionmaker

        self.database_name = self.kwargs.get("database")

        if self.database_name:
            # Create a sqlite file if a database name is provided
            default_uri = "sqlite:///" + self.database_name + ".db"
        else:
            # The default uses sqlite in-memory database
            default_uri = "sqlite://"

        # An explicit database_uri always wins over either default.
        self.database_uri = self.kwargs.get("database_uri", default_uri)

        self.engine = create_engine(self.database_uri)

        self.read_only = self.kwargs.get(
            "read_only", False
        )

        if not self.engine.dialect.has_table(self.engine, 'StatementTable'):
            self.create()

        self.Session = sessionmaker(bind=self.engine, expire_on_commit=True)

    def count(self):
        """
        Return the number of entries in the database.
        """
        session = self.Session()
        try:
            return session.query(StatementTable).count()
        finally:
            session.close()

    def __statement_filter(self, session, **kwargs):
        """
        Apply filter operation on StatementTable

        rtype: query
        """
        return session.query(StatementTable).filter_by(**kwargs)

    def find(self, statement_text):
        """
        Returns a statement if it exists otherwise None
        """
        session = self.Session()
        try:
            record = self.__statement_filter(session, text=statement_text).first()
            if record is None:
                return None
            return record.get_statement()
        finally:
            session.close()

    def remove(self, statement_text):
        """
        Removes the statement that matches the input text.
        Does nothing when no statement matches.
        """
        session = self.Session()
        try:
            record = self.__statement_filter(session, text=statement_text).first()
        except Exception:
            session.close()
            raise

        if record is None:
            # Nothing to delete; just release the session.
            session.close()
            return

        session.delete(record)

        self._session_finish(session, statement_text)

    def filter(self, **kwargs):
        """
        Returns a list of statements from the database.

        With no keyword arguments every statement is returned.
        The keys ``in_response_to`` and ``in_response_to__contains``
        query StatementTable through its responses; any other key is
        matched as a substring of ResponseTable.text_search and the
        statements owning the matching responses are returned.
        """
        session = self.Session()
        try:
            if not kwargs:
                records = session.query(StatementTable).all()
            else:
                query = None
                for name, value in kwargs.items():
                    if name in ('in_response_to', 'in_response_to__contains'):
                        statement_query = session.query(StatementTable)
                        if isinstance(value, list):
                            if not value:
                                query = statement_query.filter(
                                    StatementTable.in_response_to == None)  # NOQA Here must use == instead of is
                            else:
                                # NOTE(review): each pass replaces the query,
                                # so only the last list element is applied.
                                # Kept as in the original - confirm intent.
                                for response in value:
                                    query = statement_query.filter(
                                        StatementTable.in_response_to.contains(get_response_table(response)))
                        elif name == 'in_response_to__contains':
                            query = statement_query.join(ResponseTable).filter(ResponseTable.text == value)
                        else:
                            query = statement_query.filter(StatementTable.in_response_to == None)  # NOQA
                    else:
                        if query is None:
                            query = session.query(ResponseTable)
                        query = query.filter(ResponseTable.text_search.like('%' + value + '%'))

                records = query.all()

            results = []
            for record in records:
                if isinstance(record, ResponseTable):
                    # Responses are mapped back to the statement they belong to.
                    if record.statement_table:
                        results.append(record.statement_table.get_statement())
                else:
                    results.append(record.get_statement())

            return results
        finally:
            session.close()

    def update(self, statement):
        """
        Modifies an entry in the database.
        Creates an entry if one does not exist.
        """
        if not statement:
            return

        session = self.Session()
        try:
            record = self.__statement_filter(session, text=statement.text).first()

            if record is not None:
                # update
                if statement.extra_data:
                    # Store a copy so the row does not share the caller's dict.
                    record.extra_data = dict(statement.extra_data)
                if statement.in_response_to:
                    record.in_response_to = [
                        get_response_table(response)
                        for response in statement.in_response_to
                    ]
                session.add(record)
            else:
                session.add(get_statement_table(statement))
        except Exception:
            session.close()
            raise

        self._session_finish(session, statement.text)

    def get_random(self):
        """
        Returns a random statement from the database
        """
        # Check for an empty database before opening a session so that
        # nothing is left open when the exception is raised.
        count = self.count()
        if count < 1:
            raise self.EmptyDatabaseException()

        session = self.Session()
        try:
            record = session.query(StatementTable)[random.randrange(count)]
            return record.get_statement()
        finally:
            session.close()

    def drop(self):
        """
        Drop the database attached to a given adapter.
        """
        Base.metadata.drop_all(self.engine)

    def create(self):
        """
        Populate the database with the tables.
        """
        Base.metadata.create_all(self.engine)

    def _session_finish(self, session, statement_text=None):
        """
        Commit the session (or roll it back in read-only mode) and close it.
        Database errors are logged rather than raised.
        """
        from sqlalchemy.exc import DatabaseError
        try:
            if self.read_only:
                session.rollback()
            else:
                session.commit()
        except DatabaseError as e:
            # statement_text is data, not a format string.
            self.logger.error("%s: %s", statement_text, e.orig)
        finally:
            session.close()


# Companion change carried by the same patch, in
# chatterbot/storage/sqlalchemy_storage.py (``__init__``, hunk at line 100):
# the opt-in check
#     create = self.kwargs.get("create", False)
#     if not self.read_only and create:
# was replaced by
#     if not self.engine.dialect.has_table(self.engine, 'StatementTable'):
# ahead of the existing ``self.create()`` call.