Skip to content

Commit

Permalink
feat: Add config key for connect_arg for SqlAlchemyExtractor (amundse…
Browse files Browse the repository at this point in the history
…n-io#434)

* General connect_args for SqlAlchemyExtractor

Signed-off-by: benrifkind <ben.rifkind@gmail.com>

* lint

Signed-off-by: benrifkind <ben.rifkind@gmail.com>

* more lint

Signed-off-by: benrifkind <ben.rifkind@gmail.com>
  • Loading branch information
benrifkind authored and Wonong committed Mar 4, 2021
1 parent f4cf2e6 commit 44347e8
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 2 deletions.
10 changes: 9 additions & 1 deletion databuilder/extractor/sql_alchemy_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class SQLAlchemyExtractor(Extractor):
# Config keys
CONN_STRING = 'conn_string'
EXTRACT_SQL = 'extract_sql'
CONNECT_ARGS = 'connect_args'
"""
An Extractor that extracts records via SQLAlchemy. Database that supports SQLAlchemy can use this extractor
"""
Expand All @@ -25,6 +26,7 @@ def init(self, conf: ConfigTree) -> None:
"""
self.conf = conf
self.conn_string = conf.get_string(SQLAlchemyExtractor.CONN_STRING)

self.connection = self._get_connection()

self.extract_sql = conf.get_string(SQLAlchemyExtractor.EXTRACT_SQL)
Expand All @@ -40,7 +42,13 @@ def _get_connection(self) -> Any:
"""
Create a SQLAlchemy connection to Database
"""
engine = create_engine(self.conn_string)
connect_args = {
k: v
for k, v in self.conf.get_config(
self.CONNECT_ARGS, default=ConfigTree()
).items()
}
engine = create_engine(self.conn_string, connect_args=connect_args)
conn = engine.connect()
return conn

Expand Down
30 changes: 29 additions & 1 deletion tests/unit/extractor/test_sql_alchemy_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import unittest
from typing import Any
from typing import Any, Dict

from mock import patch
from pyhocon import ConfigFactory
Expand Down Expand Up @@ -94,6 +94,34 @@ def test_extraction_with_model_class(self: Any, mock_method: Any) -> None:
self.assertIsInstance(result, TableMetadataResult)
self.assertEqual(result.name, 'test_table')

@patch('databuilder.extractor.sql_alchemy_extractor.create_engine')
def test_get_connection(self: Any, mock_method: Any) -> None:
"""
Test that configs are passed through correctly to the _get_connection method
"""
extractor = SQLAlchemyExtractor()
config_dict: Dict[str, Any] = {
'extractor.sqlalchemy.conn_string': 'TEST_CONNECTION',
'extractor.sqlalchemy.extract_sql': 'SELECT 1 FROM TEST_TABLE;'
}
conf = ConfigFactory.from_dict(config_dict)
extractor.init(Scoped.get_scoped_conf(conf=conf,
scope=extractor.get_scope()))
extractor._get_connection()
mock_method.assert_called_with('TEST_CONNECTION', connect_args={})

extractor = SQLAlchemyExtractor()
config_dict = {
'extractor.sqlalchemy.conn_string': 'TEST_CONNECTION',
'extractor.sqlalchemy.extract_sql': 'SELECT 1 FROM TEST_TABLE;',
'extractor.sqlalchemy.connect_args': {"protocol": "https"},
}
conf = ConfigFactory.from_dict(config_dict)
extractor.init(Scoped.get_scoped_conf(conf=conf,
scope=extractor.get_scope()))
extractor._get_connection()
mock_method.assert_called_with('TEST_CONNECTION', connect_args={"protocol": "https"})


class TableMetadataResult:
"""
Expand Down

0 comments on commit 44347e8

Please sign in to comment.