Skip to content

Commit

Permalink
Changes by code review: Clean up, removed unnecessary code comments a…
Browse files Browse the repository at this point in the history
…nd imports in requirements.txt, change imports in sqlalchemy and a little refactor.
  • Loading branch information
davizucon committed Apr 24, 2017
1 parent 3af0b41 commit 1c01393
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 130 deletions.
7 changes: 1 addition & 6 deletions chatterbot/storage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,4 @@
from .django_storage import DjangoStorageAdapter
from .jsonfile import JsonFileStorageAdapter
from .mongodb import MongoDatabaseAdapter

# FIXME Better way manage import
try:
from .sqlalchemy_storage import SQLAlchemyDatabaseAdapter
except ImportError:
pass
from .sqlalchemy_storage import SQLAlchemyDatabaseAdapter
154 changes: 31 additions & 123 deletions chatterbot/storage/sqlalchemy_storage.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
import json
import random

from sqlalchemy import Column, ForeignKey
from sqlalchemy import PickleType
from sqlalchemy import String
from sqlalchemy import Integer
from sqlalchemy import create_engine
from sqlalchemy.exc import DatabaseError
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship
from sqlalchemy.orm import sessionmaker

from chatterbot.storage import StorageAdapter
from chatterbot.conversation import Response
from chatterbot.conversation import Statement

Base = declarative_base()
_base = None
try:
from sqlalchemy.ext.declarative import declarative_base

_base = declarative_base()
except:
pass


class StatementTable(Base):
class StatementTable(_base):
from sqlalchemy import Column, Integer, String, PickleType
from sqlalchemy.orm import relationship

__tablename__ = 'StatementTable'

def get_statement(self):
Expand All @@ -40,7 +39,9 @@ def get_statement_serialized(context):
text_search = Column(String, primary_key=True, default=get_statement_serialized)


class ResponseTable(Base):
class ResponseTable(_base):
from sqlalchemy import Column, Integer, String, ForeignKey
from sqlalchemy.orm import relationship
__tablename__ = 'ResponseTable'

def get_reponse_serialized(context):
Expand All @@ -53,8 +54,6 @@ def get_reponse_serialized(context):
occurrence = Column(Integer)
statement_text = Column(String, ForeignKey('StatementTable.text'))

# Old: statement_table = relationship("StatementTable", backref=backref('in_response_to'), cascade="all, delete-orphan", single_parent=True)
# Test relationship:
statement_table = relationship("StatementTable", back_populates="in_response_to", cascade="all", uselist=False)
text_search = Column(String, primary_key=True, default=get_reponse_serialized)

Expand All @@ -63,14 +62,6 @@ def get_response(self):
return Response(text=self.text, **occ)


# # relational table TO REMOVE
# in_response_to = sqlalchemy.Table('in_responses_to',
# Base.metadata,
# sqlalchemy.Column('stmt_id', sqlalchemy.Integer, ForeignKey('StatementTable.id')),
# sqlalchemy.Column('resp_id', sqlalchemy.Integer,
# sqlalchemy.ForeignKey('ResponseTable.id')))


def get_statement_table(statement):
responses = list(map(get_response_table, statement.in_response_to))
return StatementTable(text=statement.text, in_response_to=responses, extra_data=statement.extra_data)
Expand All @@ -87,11 +78,13 @@ class SQLAlchemyDatabaseAdapter(StorageAdapter):
def __init__(self, **kwargs):
super(SQLAlchemyDatabaseAdapter, self).__init__(**kwargs)

from sqlalchemy import create_engine

self.database_name = self.kwargs.get(
"database", "chatterbot-database"
)

# if some annoing blank space wrong...
# if some annoying blank space wrong...
db_name = self.database_name.strip()

# default uses sqlite
Expand All @@ -101,32 +94,6 @@ def __init__(self, **kwargs):

self.engine = create_engine(self.database_uri)

# metadata = MetaData(self.engine)

# Recreate database
# metadata.drop_all()
#
# self.response = Table('response', metadata,
# Column('id', Integer, primary_key=True),
# Column('text', Text),
# Column('occurrence', Text),
# )
#
# self.statement = Table('statement', metadata,
# Column('id', Integer, primary_key=True),
# Column('text', Text),
# Column('in_response_to', Integer, ForeignKey('response.id')),
# Column('extra_data', Text)
# )
#
# # mapper(Response, self.response,
# # # non_primary=True,
# # properties={
# # 'statement': relationship(Statement, backref='response')
# # }, )
# mapper(Statement, self.statement)
# mapper(Response, self.response)

self.read_only = self.kwargs.get(
"read_only", False
)
Expand All @@ -150,6 +117,8 @@ def __get_session(self):
"""
:rtype: Session
"""
from sqlalchemy.orm import sessionmaker

Session = sessionmaker(bind=self.engine)
session = Session()
return session
Expand Down Expand Up @@ -185,13 +154,7 @@ def remove(self, statement_text):
record = query.first()
session.delete(record)

try:
if not self.read_only:
session.commit()
else:
session.rollback()
except DatabaseError as e:
self.logger.error(statement_text, str(e.orig))
self._session_finish(session, statement_text)

def filter(self, **kwargs):
"""
Expand Down Expand Up @@ -253,69 +216,11 @@ def filter(self, **kwargs):

return results

# def filter(self, **kwargs):
# """
# Returns a list of objects from the database.
# The kwargs parameter can contain any number
# of attributes. Only objects which contain
# all listed attributes and in which all values
# match for all listed attributes will be returned.
# """
#
# filter_parameters = kwargs.copy()
#
# session = self.__get_session()
# statements = []
# _response_query = None
#
# if len(filter_parameters) == 0:
# _response_query = session.query(StatementTable)
# statements.extend(_response_query.all())
# else:
# for fp in filter_parameters:
# _filter = filter_parameters[fp]
# if fp == 'in_response_to' or fp == 'in_response_to__contains':
# if _response_query:
# _response_query.join(StatementTable)
# else:
# _response_query = session.query(StatementTable)
# if isinstance(_filter, list):
# if len(_filter) == 0:
# query = _response_query.filter(StatementTable.in_response_to is None)
# else:
# # _in_response_tables = []
# # for respnse_table in _like:
# # _in_response_tables.append(get_response_table(respnse_table))
# query = _response_query.filter(StatementTable.in_response_to.contains(_filter))
# else:
# query = _response_query.filter(StatementTable.in_response_to.like('%' + _filter + '%'))
# else:
# if fp == 'text' or fp == 'text__contains': # Text always use like
# _response_query = session.query(ResponseTable)
# # if fp == 'text__contains':
# query = _response_query.filter(ResponseTable.text.like('%' + _filter + '%'))
# # if fp == 'text':
# # query = _response_query.filter(ResponseTable.text == _filter)
#
# statements.extend(query.all())
#
# results = []
# for statement in statements:
# if isinstance(statement, ResponseTable):
# if statement and statement.statement_table:
# results.append(statement.statement_table.get_statement())
# else:
# if statement:
# results.append(statement.get_statement())
#
# return results

def update(self, statement):
"""
Modifies an entry in the database.
Creates an entry if one does not exist.
"""

session = self.__get_session()
if statement:
query = self.__statement_filter(session, **{"text": statement.text})
Expand All @@ -333,14 +238,7 @@ def update(self, statement):
else:
session.add(get_statement_table(statement))

try:
if not self.read_only:
session.commit()
else:
session.rollback()
except DatabaseError as e:
pass
# self.logger.error(statement, str(e.orig))
self._session_finish(session)

def get_random(self):
"""
Expand All @@ -361,3 +259,13 @@ def drop(self):
Drop the database attached to a given adapter.
"""
Base.metadata.drop_all(self.engine)

def _session_finish(self, session, statement_text=None):
from sqlalchemy.exc import DatabaseError
try:
if not self.read_only:
session.commit()
else:
session.rollback()
except DatabaseError as e:
self.logger.error(statement_text, str(e.orig))
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,4 @@ jsondatabase>=0.1.7,<1.0.0
nltk>=3.2.0,<4.0.0
pymongo>=3.3.0,<4.0.0
python-twitter>=3.0.0,<4.0.0
textblob>=0.11.0,<0.12.0
SQLAlchemy==1.1.7

0 comments on commit 1c01393

Please sign in to comment.