diff --git a/databuilder/extractor/sql_alchemy_extractor.py b/databuilder/extractor/sql_alchemy_extractor.py index 2c90a8fd3..a375781a5 100644 --- a/databuilder/extractor/sql_alchemy_extractor.py +++ b/databuilder/extractor/sql_alchemy_extractor.py @@ -14,6 +14,7 @@ class SQLAlchemyExtractor(Extractor): # Config keys CONN_STRING = 'conn_string' EXTRACT_SQL = 'extract_sql' + CONNECT_ARGS = 'connect_args' """ An Extractor that extracts records via SQLAlchemy. Database that supports SQLAlchemy can use this extractor """ @@ -25,6 +26,7 @@ def init(self, conf: ConfigTree) -> None: """ self.conf = conf self.conn_string = conf.get_string(SQLAlchemyExtractor.CONN_STRING) + self.connection = self._get_connection() self.extract_sql = conf.get_string(SQLAlchemyExtractor.EXTRACT_SQL) @@ -40,7 +42,13 @@ def _get_connection(self) -> Any: """ Create a SQLAlchemy connection to Database """ - engine = create_engine(self.conn_string) + connect_args = { + k: v + for k, v in self.conf.get_config( + self.CONNECT_ARGS, default=ConfigTree() + ).items() + } + engine = create_engine(self.conn_string, connect_args=connect_args) conn = engine.connect() return conn diff --git a/tests/unit/extractor/test_sql_alchemy_extractor.py b/tests/unit/extractor/test_sql_alchemy_extractor.py index b9bd967a6..f6644c086 100644 --- a/tests/unit/extractor/test_sql_alchemy_extractor.py +++ b/tests/unit/extractor/test_sql_alchemy_extractor.py @@ -94,6 +94,35 @@ def test_extraction_with_model_class(self: Any, mock_method: Any) -> None: self.assertIsInstance(result, TableMetadataResult) self.assertEqual(result.name, 'test_table') + @patch('databuilder.extractor.sql_alchemy_extractor.create_engine') + def test_get_connection(self: Any, mock_method: Any) -> None: + """ + Test that configs are passed through correctly to the _get_connection method + """ + extractor = SQLAlchemyExtractor() + config_dict = { + 'extractor.sqlalchemy.conn_string': 'TEST_CONNECTION', + 'extractor.sqlalchemy.extract_sql': 'SELECT 1 FROM TEST_TABLE;' + } + conf = ConfigFactory.from_dict(config_dict) + extractor.init(Scoped.get_scoped_conf(conf=conf, + scope=extractor.get_scope())) + extractor._get_connection() + mock_method.assert_called_with('TEST_CONNECTION', connect_args={}) + + extractor = SQLAlchemyExtractor() + config_dict = { + 'extractor.sqlalchemy.conn_string': 'TEST_CONNECTION', + 'extractor.sqlalchemy.extract_sql': 'SELECT 1 FROM TEST_TABLE;', + 'extractor.sqlalchemy.connect_args': {"protocol": "https"}, + } + conf = ConfigFactory.from_dict(config_dict) + extractor.init(Scoped.get_scoped_conf(conf=conf, + scope=extractor.get_scope())) + extractor._get_connection() + mock_method.assert_called_with('TEST_CONNECTION', connect_args={"protocol": "https"}) + + class TableMetadataResult: """