Skip to content

Commit

Permalink
Change wrapper signature so that the connection string is optional
Browse files Browse the repository at this point in the history
  • Loading branch information
aaron-congo authored and karenc-bq committed Sep 18, 2023
1 parent 3d8b7c1 commit 6907ba0
Show file tree
Hide file tree
Showing 12 changed files with 151 additions and 145 deletions.
6 changes: 3 additions & 3 deletions aws_wrapper/pep249.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable, Union

if TYPE_CHECKING:
from types import TracebackType
Expand Down Expand Up @@ -94,9 +94,9 @@ class Connection:

@staticmethod
def connect(
target: Union[str, Callable],
conninfo: str = "",
**kwargs
) -> Any:
**kwargs: Any) -> Connection:
...

def close(self) -> None:
Expand Down
1 change: 1 addition & 0 deletions aws_wrapper/resources/messages.properties
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ TargetDriverDialectManager.UseDialect=[TargetDriverDialectManager] Target driver

Testing.CantParse=[Testing] Can't parse {}.
Testing.EnvVarRequired=[Testing] Environment variable {} is required.
Testing.FunctionNotImplementedForDriver=[Testing] Function '{}' has no implementation for the passed in driver: '{}'.
Testing.InstanceNotFound=[Testing] Instance {} not found.
Testing.ProxyNotFound=[Testing] Proxy for {} is not found.
Testing.RequiredTestDriver=[Testing] testDriver is required.
Expand Down
4 changes: 2 additions & 2 deletions aws_wrapper/utils/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Optional, Union
from typing import Any, Dict, Optional

from aws_wrapper.errors import AwsWrapperError
from aws_wrapper.utils.messages import Messages
Expand Down Expand Up @@ -197,7 +197,7 @@ class WrapperProperties:
class PropertiesUtils:

@staticmethod
def parse_properties(conn_info: str, **kwargs: Union[None, int, str]) -> Properties:
def parse_properties(conn_info: str, **kwargs: Any) -> Properties:
props: Properties
if conn_info == "":
props = Properties()
Expand Down
19 changes: 10 additions & 9 deletions aws_wrapper/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from logging import getLogger
from typing import Any, Callable, Iterator, List, Optional, Union

Expand Down Expand Up @@ -76,11 +78,10 @@ def autocommit(self, autocommit: bool):

@staticmethod
def connect(
target: Union[str, Callable],
conninfo: str = "",
target: Union[None, str, Callable] = None,
**kwargs: Union[None, int, str]
) -> "AwsWrapperConnection":
if not target:
**kwargs: Any) -> AwsWrapperConnection:
if target is None or target == "":
raise Error(Messages.get("Wrapper.RequiredTargetDriver"))

# TODO: fix target str parsing functionality
Expand Down Expand Up @@ -122,7 +123,7 @@ def close(self) -> None:
self._plugin_manager.execute(self.target_connection, "Connection.close",
lambda: self.target_connection.close())

def cursor(self, **kwargs: Union[None, int, str]) -> "AwsWrapperCursor":
def cursor(self, **kwargs: Any) -> AwsWrapperCursor:
_cursor = self._plugin_manager.execute(self.target_connection, "Connection.cursor",
lambda: self.target_connection.cursor(**kwargs),
kwargs)
Expand Down Expand Up @@ -164,7 +165,7 @@ def release_resources(self):
def __del__(self):
self.release_resources()

def __enter__(self: "AwsWrapperConnection") -> "AwsWrapperConnection":
def __enter__(self: AwsWrapperConnection) -> AwsWrapperConnection:
return self

def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
Expand Down Expand Up @@ -211,7 +212,7 @@ def close(self) -> None:
self._plugin_manager.execute(self.target_cursor, "Cursor.close",
lambda: self.target_cursor.close())

def callproc(self, **kwargs: Union[None, int, str]):
def callproc(self, **kwargs: Any):
return self._plugin_manager.execute(self.target_cursor, "Cursor.callproc",
lambda: self.target_cursor.callproc(**kwargs), kwargs)

Expand All @@ -233,7 +234,7 @@ def execute(
def executemany(
self,
query: str,
**kwargs: Union[None, int, str]
**kwargs: Any
) -> None:
self._plugin_manager.execute(self.target_cursor, "Cursor.executemany",
lambda: self.target_cursor.executemany(query, **kwargs), query, kwargs)
Expand Down Expand Up @@ -265,7 +266,7 @@ def setoutputsize(self, size: Any, column: Optional[int] = None) -> None:
return self._plugin_manager.execute(self.target_cursor, "Cursor.setoutputsize",
lambda: self.target_cursor.setoutputsize(size, column), size, column)

def __enter__(self: "AwsWrapperCursor") -> "AwsWrapperCursor":
def __enter__(self: AwsWrapperCursor) -> AwsWrapperCursor:
return self

def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
Expand Down
90 changes: 38 additions & 52 deletions tests/integration/container/test_aurora_failover.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from .utils.test_instance_info import TestInstanceInfo
from .utils.test_driver import TestDriver
from .utils.test_database_info import TestDatabaseInfo
from aws_wrapper.pep249 import Cursor


from logging import getLogger

Expand Down Expand Up @@ -61,14 +63,13 @@ def proxied_props(self, props):
props_copy.update({WrapperProperties.CLUSTER_INSTANCE_HOST_PATTERN.name: f"?.{endpoint_suffix}"})
return props_copy

def test_fail_from_writer_to_new_writer_fail_on_connection_invocation(self, test_environment: TestEnvironment,
test_driver: TestDriver, props,
conn_utils, aurora_utility):
def test_fail_from_writer_to_new_writer_fail_on_connection_invocation(
self, test_environment: TestEnvironment, test_driver: TestDriver, props, conn_utils, aurora_utility):
target_driver_connect = DriverHelper.get_connect_func(test_driver)
initial_writer_id = aurora_utility.get_cluster_writer_instance_id()

with AwsWrapperConnection.connect(self._init_default_props(test_environment), target_driver_connect,
**props) as aws_conn:
with AwsWrapperConnection.connect(
target_driver_connect, **conn_utils.get_connect_params(), **props) as aws_conn:
# Enable autocommit, otherwise each select statement will start a valid transaction.
aws_conn.autocommit = True

Expand All @@ -84,16 +85,13 @@ def test_fail_from_writer_to_new_writer_fail_on_connection_invocation(self, test
assert aurora_utility.is_db_instance_writer(current_connection_id) is True
assert current_connection_id != initial_writer_id

def test_fail_from_writer_to_new_writer_fail_on_connection_bound_object_invocation(self,
test_environment: TestEnvironment,
test_driver: TestDriver,
props, conn_utils,
aurora_utility):
def test_fail_from_writer_to_new_writer_fail_on_connection_bound_object_invocation(
self, test_environment: TestEnvironment, test_driver: TestDriver, props, conn_utils, aurora_utility):
target_driver_connect = DriverHelper.get_connect_func(test_driver)
initial_writer_id = aurora_utility.get_cluster_writer_instance_id()

with AwsWrapperConnection.connect(self._init_default_props(test_environment), target_driver_connect,
**props) as aws_conn:
with AwsWrapperConnection.connect(
target_driver_connect, **conn_utils.get_connect_params(), **props) as aws_conn:
# Enable autocommit, otherwise each select statement will start a valid transaction.
aws_conn.autocommit = True

Expand All @@ -109,16 +107,22 @@ def test_fail_from_writer_to_new_writer_fail_on_connection_bound_object_invocati
assert current_connection_id != initial_writer_id

@pytest.mark.skip
def test_fail_from_reader_to_writer(self, test_environment: TestEnvironment,
test_driver: TestDriver, conn_utils, proxied_props, aurora_utility):
def test_fail_from_reader_to_writer(
self,
test_environment: TestEnvironment,
test_driver: TestDriver,
conn_utils,
proxied_props,
aurora_utility):
target_driver_connect = DriverHelper.get_connect_func(test_driver)
instance: TestInstanceInfo = test_environment.get_proxy_instances()[1]
writer_id: str = test_environment.get_proxy_writer().get_instance_id()

proxied_props["plugins"] = "failover,host_monitoring"
with AwsWrapperConnection.connect(
conn_utils.get_proxy_conn_string(instance.get_host()),
target_driver_connect, **proxied_props) as aws_conn:
target_driver_connect,
**conn_utils.get_proxy_connect_params(instance.get_host()),
**proxied_props) as aws_conn:
# Enable autocommit, otherwise each select statement will start a valid transaction.
aws_conn.autocommit = True

Expand All @@ -132,15 +136,13 @@ def test_fail_from_reader_to_writer(self, test_environment: TestEnvironment,
assert writer_id == current_connection_id
assert aurora_utility.is_db_instance_writer(current_connection_id) is True

def test_writer_fail_within_transaction_set_autocommit_false(self, test_driver: TestDriver,
test_environment: TestEnvironment,
props, conn_utils,
aurora_utility):
def test_writer_fail_within_transaction_set_autocommit_false(
self, test_driver: TestDriver, test_environment: TestEnvironment, props, conn_utils, aurora_utility):
target_driver_connect = DriverHelper.get_connect_func(test_driver)
initial_writer_id = test_environment.get_writer().get_instance_id()

with AwsWrapperConnection.connect(self._init_default_props(test_environment), target_driver_connect,
**props) as conn, \
with AwsWrapperConnection.connect(
target_driver_connect, **conn_utils.get_connect_params(), **props) as conn, \
conn.cursor() as cursor_1:
cursor_1.execute("DROP TABLE IF EXISTS test3_2")
cursor_1.execute("CREATE TABLE test3_2 (id int not null primary key, test3_2_field varchar(255) not null)")
Expand Down Expand Up @@ -175,15 +177,13 @@ def test_writer_fail_within_transaction_set_autocommit_false(self, test_driver:
cursor_3.execute("DROP TABLE IF EXISTS test3_2")
conn.commit()

def test_writer_fail_within_transaction_start_transaction(self, test_driver: TestDriver,
test_environment: TestEnvironment,
props, conn_utils,
aurora_utility):
def test_writer_fail_within_transaction_start_transaction(
self, test_driver: TestDriver, test_environment: TestEnvironment, props, conn_utils, aurora_utility):
target_driver_connect = DriverHelper.get_connect_func(test_driver)
initial_writer_id = test_environment.get_writer().get_instance_id()

with AwsWrapperConnection.connect(self._init_default_props(test_environment), target_driver_connect,
**props) as conn:
with AwsWrapperConnection.connect(
target_driver_connect, **conn_utils.get_connect_params(), **props) as conn:
# Enable autocommit, otherwise each select statement will start a valid transaction.
conn.autocommit = True

Expand Down Expand Up @@ -222,9 +222,8 @@ def test_writer_fail_within_transaction_start_transaction(self, test_driver: Tes
cursor_3.execute("DROP TABLE IF EXISTS test3_3")
conn.commit()

def test_writer_failover_in_idle_connections(self, test_environment: TestEnvironment, test_driver: TestDriver,
props, conn_utils,
aurora_utility):
def test_writer_failover_in_idle_connections(
self, test_environment: TestEnvironment, test_driver: TestDriver, props, conn_utils, aurora_utility):
target_driver_connect = DriverHelper.get_connect_func(test_driver)
current_writer_id = aurora_utility.get_cluster_writer_instance_id()

Expand All @@ -233,11 +232,11 @@ def test_writer_failover_in_idle_connections(self, test_environment: TestEnviron

for i in range(self.IDLE_CONNECTIONS_NUM):
idle_connections.append(
AwsWrapperConnection.connect(self._init_default_props(test_environment), target_driver_connect,
**props))
AwsWrapperConnection.connect(
target_driver_connect, **conn_utils.get_connect_params(), **props))

with AwsWrapperConnection.connect(self._init_default_props(test_environment), target_driver_connect,
**props) as conn:
with AwsWrapperConnection.connect(
target_driver_connect, **conn_utils.get_connect_params(), **props) as conn:

# Enable autocommit, otherwise each select statement will start a valid transaction.
conn.autocommit = True
Expand All @@ -260,18 +259,16 @@ def test_writer_failover_in_idle_connections(self, test_environment: TestEnviron
for idle_connection in idle_connections:
assert idle_connection.is_closed is True

def test_basic_failover_with_efm(self, test_driver: TestDriver,
test_environment: TestEnvironment,
props, conn_utils,
aurora_utility):
def test_basic_failover_with_efm(
self, test_driver: TestDriver, test_environment: TestEnvironment, props, conn_utils, aurora_utility):
target_driver_connect = DriverHelper.get_connect_func(test_driver)
initial_writer_instance_info = test_environment.get_writer()
nominated_writer_instance_info = test_environment.get_instances()[1]
nominated_writer_id = nominated_writer_instance_info.get_instance_id()

props["plugins"] = "failover,host_monitoring"
with AwsWrapperConnection.connect(self._init_default_props(test_environment), target_driver_connect,
**props) as conn:
with AwsWrapperConnection.connect(
target_driver_connect, **conn_utils.get_connect_params(), **props) as conn:
# Enable autocommit, otherwise each select statement will start a valid transaction.
conn.autocommit = True

Expand All @@ -288,14 +285,3 @@ def test_basic_failover_with_efm(self, test_driver: TestDriver,

assert initial_writer_instance_info.get_instance_id() != current_connection_id
assert next_writer_id == current_connection_id

def _init_default_props(self, test_environment: TestEnvironment) -> str:
database_info: TestDatabaseInfo = test_environment.get_info().get_database_info()
instance: TestInstanceInfo = test_environment.get_writer()
db_name: str = database_info.get_default_db_name()
user: str = database_info.get_username()
password: str = database_info.get_password()
connect_params: str = "host={0} port={1} dbname={2} user={3} password={4}".format(
instance.get_host(), instance.get_port(), db_name, user, password)

return connect_params
16 changes: 8 additions & 8 deletions tests/integration/container/test_basic_connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ def aurora_utils(self):
return AuroraTestUtility(region)

@pytest.fixture(scope='class')
def props(self):
def efm_props(self):
return {"plugins": "host_monitoring", "connect_timeout": 10}

def test_direct_connection(self, test_environment: TestEnvironment, test_driver: TestDriver, conn_utils):
target_driver_connect = DriverHelper.get_connect_func(test_driver)
conn = target_driver_connect(conn_utils.get_conn_string())
conn = target_driver_connect(**conn_utils.get_connect_params())
cursor = conn.cursor()
cursor.execute("SELECT 1")
result = cursor.fetchone()
Expand All @@ -59,7 +59,7 @@ def test_direct_connection(self, test_environment: TestEnvironment, test_driver:

def test_wrapper_connection(self, test_environment: TestEnvironment, test_driver: TestDriver, conn_utils):
target_driver_connect = DriverHelper.get_connect_func(test_driver)
conn = AwsWrapperConnection.connect(conn_utils.get_conn_string(), target_driver_connect)
conn = AwsWrapperConnection.connect(target_driver_connect, **conn_utils.get_connect_params())
cursor = conn.cursor()
cursor.execute("SELECT 1")
result = cursor.fetchone()
Expand All @@ -70,7 +70,7 @@ def test_wrapper_connection(self, test_environment: TestEnvironment, test_driver
@enable_on_features([TestEnvironmentFeatures.NETWORK_OUTAGES_ENABLED])
def test_proxied_direct_connection(self, test_environment: TestEnvironment, test_driver: TestDriver, conn_utils):
target_driver_connect = DriverHelper.get_connect_func(test_driver)
conn = target_driver_connect(conn_utils.get_proxy_conn_string())
conn = target_driver_connect(**conn_utils.get_proxy_connect_params())
cursor = conn.cursor()
cursor.execute("SELECT 1")
result = cursor.fetchone()
Expand All @@ -81,7 +81,7 @@ def test_proxied_direct_connection(self, test_environment: TestEnvironment, test
@enable_on_features([TestEnvironmentFeatures.NETWORK_OUTAGES_ENABLED])
def test_proxied_wrapper_connection(self, test_environment: TestEnvironment, test_driver: TestDriver, conn_utils):
target_driver_connect = DriverHelper.get_connect_func(test_driver)
conn = AwsWrapperConnection.connect(conn_utils.get_proxy_conn_string(), target_driver_connect)
conn = AwsWrapperConnection.connect(target_driver_connect, **conn_utils.get_proxy_connect_params())
cursor = conn.cursor()
cursor.execute("SELECT 1")
result = cursor.fetchone()
Expand All @@ -98,7 +98,7 @@ def test_proxied_wrapper_connection_failed(
ProxyHelper.disable_connectivity(instance.get_instance_id())

try:
AwsWrapperConnection.connect(conn_utils.get_proxy_conn_string(), target_driver_connect)
AwsWrapperConnection.connect(target_driver_connect, **conn_utils.get_proxy_connect_params())

# Should not be here since proxy is blocking db connectivity
assert False
Expand All @@ -111,10 +111,10 @@ def test_proxied_wrapper_connection_failed(
@enable_on_deployment(DatabaseEngineDeployment.AURORA)
@disable_on_features([TestEnvironmentFeatures.PERFORMANCE])
def test_wrapper_connection_reader_cluster_with_efm_enabled(
self, test_driver: TestDriver, props, conn_utils):
self, test_driver: TestDriver, efm_props, conn_utils):
target_driver_connect = DriverHelper.get_connect_func(test_driver)
conn = AwsWrapperConnection.connect(
conn_utils.get_conn_string(conn_utils.reader_cluster_host), target_driver_connect, **props)
target_driver_connect, **conn_utils.get_connect_params(conn_utils.reader_cluster_host), **efm_props)
cursor = conn.cursor()
cursor.execute("SELECT 1")
result = cursor.fetchone()
Expand Down
Loading

0 comments on commit 6907ba0

Please sign in to comment.