diff --git a/chatterbot/storage/sql_storage.py b/chatterbot/storage/sql_storage.py index d222bb852..4253a11ae 100644 --- a/chatterbot/storage/sql_storage.py +++ b/chatterbot/storage/sql_storage.py @@ -89,7 +89,7 @@ class SQLStorageAdapter(StorageAdapter): can be especified to choose database driver (database parameter will be igored). :type database_uri: str - :keyword read_only: False by default, makes all operations read only, has priority over all DB operations + :keyword read_only: False by default, makes all operations read only, has priority over all DB operations so, create, update, delete will NOT be executed :type read_only: bool @@ -104,20 +104,17 @@ def __init__(self, **kwargs): from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker - self.database_name = self.kwargs.get("database") - - if self.database_name: - - # Create a sqlite file if a database name is provided - self.database_uri = self.kwargs.get( - "database_uri", "sqlite:///" + self.database_name + ".db" - ) - - # The default uses sqlite in-memory database + # The default uses a sqlite in-memory database self.database_uri = self.kwargs.get( "database_uri", "sqlite://" ) + database_name = self.kwargs.get("database") + + # Create a sqlite file if a database name is provided + if database_name: + self.database_uri = "sqlite:///" + database_name + ".db" + self.engine = create_engine(self.database_uri) self.read_only = self.kwargs.get( diff --git a/tests/storage_adapter_tests/test_sqlalchemy_adapter.py b/tests/storage_adapter_tests/test_sqlalchemy_adapter.py index 15497f981..826da95b7 100644 --- a/tests/storage_adapter_tests/test_sqlalchemy_adapter.py +++ b/tests/storage_adapter_tests/test_sqlalchemy_adapter.py @@ -27,6 +27,18 @@ def tearDown(self): class SQLStorageAdapterTestCase(SQLAlchemyAdapterTestCase): + def test_set_database_name_none(self): + adapter = SQLStorageAdapter(database=None) + self.assertEqual(adapter.database_uri, 'sqlite://') + + def test_set_database_name(self): + adapter = SQLStorageAdapter(database='test') + self.assertEqual(adapter.database_uri, 'sqlite:///test.db') + + def test_set_database_uri(self): + adapter = SQLStorageAdapter(database_uri='sqlite:///db.sqlite3') + self.assertEqual(adapter.database_uri, 'sqlite:///db.sqlite3') + def test_count_returns_zero(self): """ The count method should return a value of 0