Skip to content

Commit

Permalink
fix: mysql support (#115)
Browse files Browse the repository at this point in the history
  • Loading branch information
karenc-bq committed Sep 22, 2023
1 parent 00f52e0 commit 834bd89
Show file tree
Hide file tree
Showing 17 changed files with 221 additions and 115 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/integration_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ jobs:
- name: 'Set up Temp AWS Credentials'
run: |
creds=($(aws sts get-session-token \
--duration-seconds 7200 \
--duration-seconds 21600 \
--query 'Credentials.[AccessKeyId, SecretAccessKey, SessionToken]' \
--output text \
| xargs));
Expand All @@ -54,13 +54,14 @@ jobs:
- name: 'Run Integration Tests'
run: |
./gradlew --no-parallel --no-daemon test-all-environments --info
./gradlew --no-parallel --no-daemon test-mysql-aurora --info
env:
AURORA_CLUSTER_DOMAIN: ${{ secrets.DB_CONN_SUFFIX }}
AURORA_DB_REGION: ${{ secrets.AWS_DEFAULT_REGION }}
AWS_ACCESS_KEY_ID: ${{ env.TEMP_AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ env.TEMP_AWS_SECRET_ACCESS_KEY }}
AWS_SESSION_TOKEN: ${{ env.TEMP_AWS_SESSION_TOKEN }}
NUM_INSTANCES: 5

- name: 'Archive results'
if: always()
Expand Down
6 changes: 3 additions & 3 deletions aws_wrapper/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
from aws_wrapper.utils.properties import (Properties, PropertiesUtils,
WrapperProperties)
from aws_wrapper.utils.rdsutils import RdsUtils
from .exceptions import ExceptionHandler, PgExceptionHandler
from .exceptions import (ExceptionHandler, MySQLExceptionHandler,
PgExceptionHandler)
from .target_driver_dialect import TargetDriverDialectCodes
from .utils.cache_map import CacheMap
from .utils.messages import Messages
Expand Down Expand Up @@ -149,8 +150,7 @@ def server_version_query(self) -> str:

@property
def exception_handler(self) -> Optional[ExceptionHandler]:
# TODO
return None
return MySQLExceptionHandler()

def is_dialect(self, conn: Connection) -> bool:
try:
Expand Down
73 changes: 68 additions & 5 deletions aws_wrapper/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

from typing import TYPE_CHECKING

import mysql

from aws_wrapper.errors import QueryTimeoutError

if TYPE_CHECKING:
Expand Down Expand Up @@ -61,10 +63,13 @@ class PgExceptionHandler(ExceptionHandler):
def is_network_exception(self, error: Optional[Exception] = None, sql_state: Optional[str] = None) -> bool:
if isinstance(error, QueryTimeoutError) or isinstance(error, ConnectionTimeout):
return True
if sql_state is None:
error_sql_state = getattr(error, "sqlstate")
if error_sql_state is not None:
sql_state = error_sql_state

if sql_state:
if sql_state in self._NETWORK_ERRORS:
return True
if sql_state is not None and sql_state in self._NETWORK_ERRORS:
return True

if isinstance(error, OperationalError):
if len(error.args) == 0:
Expand All @@ -81,6 +86,12 @@ def is_login_exception(self, error: Optional[Exception] = None, sql_state: Optio
if isinstance(error, InvalidAuthorizationSpecification) or isinstance(error, InvalidPassword):
return True

if sql_state is None and hasattr(error, "sqlstate") and error.sqlstate is not None:
sql_state = error.sqlstate

if sql_state is not None and sql_state in self._ACCESS_ERRORS:
return True

if isinstance(error, OperationalError):
if len(error.args) == 0:
return False
Expand All @@ -91,10 +102,62 @@ def is_login_exception(self, error: Optional[Exception] = None, sql_state: Optio
or self._PAM_AUTHENTICATION_FAILED_MSG in error_msg:
return True

if sql_state:
if sql_state in self._ACCESS_ERRORS:
return False


class MySQLExceptionHandler(ExceptionHandler):
_PAM_AUTHENTICATION_FAILED_MSG = "PAM authentication failed"
_UNAVAILABLE_CONNECTION = "MySQL Connection not available"

_NETWORK_ERRORS: List[int] = [
2001, # Can't create UNIX socket
2002, # Can't connect to local MySQL server through socket
2003, # Can't connect to MySQL server
2004, # Can't create TCP/IP socket
2006, # MySQL server has gone away
2012, # Error in server handshake
2013, # unexpected error
2026, # SSL connection error
2055, # Lost connection to MySQL server
]

def is_network_exception(self, error: Optional[Exception] = None, sql_state: Optional[str] = None) -> bool:
if isinstance(error, QueryTimeoutError):
return True

if sql_state is None:
if hasattr(error, "sqlstate"):
error_sql_state = getattr(error, "sqlstate")
if error_sql_state is not None:
sql_state = error_sql_state

if sql_state is not None:
if sql_state.startswith("08") or sql_state.startswith("HY"):
# Connection exceptions may also be returned as a generic error
# e.g. 2013 (HY000): Lost connection to MySQL server during query
return True

if isinstance(error, mysql.connector.errors.OperationalError):
if error.errno in self._NETWORK_ERRORS:
return True
if error.msg is not None and self._UNAVAILABLE_CONNECTION in error.msg:
return True

if len(error.args) == 1:
return self._UNAVAILABLE_CONNECTION in error.args[0]

return False

def is_login_exception(self, error: Optional[Exception] = None, sql_state: Optional[str] = None) -> bool:
if sql_state is None:
if hasattr(error, "sqlstate"):
error_sql_state = getattr(error, "sqlstate")
if error_sql_state is not None:
sql_state = error_sql_state

if "28000" == sql_state:
return True

return False


Expand Down
22 changes: 16 additions & 6 deletions aws_wrapper/failover_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
from logging import getLogger
from typing import Any, Callable, Dict, List, Optional, Set

from psycopg import OperationalError

from aws_wrapper.errors import (AwsWrapperError, FailoverFailedError,
FailoverSuccessError,
TransactionResolutionUnknownError)
Expand Down Expand Up @@ -55,6 +53,16 @@ class FailoverPlugin(Plugin):
"force_connect",
"notify_host_list_changed"}

_METHODS_REQUIRE_UPDATED_TOPOLOGY: Set[str] = {
"Connection.commit",
"Connection.autocommit",
"Connection.autocommit_setter",
"Connection.rollback",
"Connection.cursor",
"Cursor.callproc",
"Cursor.execute"
}

def __init__(self, plugin_service: PluginService, props: Properties):
self._plugin_service = plugin_service
self._properties = props
Expand Down Expand Up @@ -134,7 +142,8 @@ def execute(self, target: type, method_name: str, execute_func: Callable, *args:
self._invalid_invocation_on_closed_connection()

try:
self._update_topology(False)
if self._requires_update_topology(method_name):
self._update_topology(False)
return execute_func()
except Exception as ex:
msg = Messages.get_formatted("FailoverPlugin.DetectedException", str(ex))
Expand Down Expand Up @@ -375,9 +384,6 @@ def _should_exception_trigger_connection_switch(self, ex: Exception) -> bool:
logger.debug(Messages.get_formatted("FailoverPlugin.FailoverDisabled"))
return False

if isinstance(ex, OperationalError):
return True

return self._plugin_service.is_network_exception(ex)

@staticmethod
Expand All @@ -401,6 +407,7 @@ def _is_node_still_valid(node: str, changes: Dict[str, Set[HostEvent]]):
def _can_direct_execute(method_name):
# TODO: adjust method names to proper python method names
return method_name == "Connection.close" or \
method_name == "Cursor.close" or \
method_name == "Connection.abort" or \
method_name == "Connection.isClosed"

Expand All @@ -412,6 +419,9 @@ def _allowed_on_closed_connection(method_name: str):
method_name == "Connection.getSchema" or \
method_name == "Connection.getTransactionIsolation"

def _requires_update_topology(self, method_name: str):
return method_name in FailoverPlugin._METHODS_REQUIRE_UPDATED_TOPOLOGY


class FailoverPluginFactory(PluginFactory):
def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
Expand Down
10 changes: 9 additions & 1 deletion aws_wrapper/generic_target_driver_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from aws_wrapper.errors import UnsupportedOperationError
from aws_wrapper.target_driver_dialect_codes import TargetDriverDialectCodes
from aws_wrapper.utils.messages import Messages
from aws_wrapper.utils.properties import Properties, PropertiesUtils
from aws_wrapper.utils.properties import (Properties, PropertiesUtils,
WrapperProperties)

if TYPE_CHECKING:
from aws_wrapper.hostinfo import HostInfo
Expand Down Expand Up @@ -67,6 +68,10 @@ def supports_socket_timeout(self) -> bool:
def supports_tcp_keepalive(self) -> bool:
return False

@abstractmethod
def set_password(self, props: Properties, pwd: str):
pass

@abstractmethod
def is_dialect(self, connect_func: Callable) -> bool:
pass
Expand Down Expand Up @@ -120,6 +125,9 @@ def prepare_connect_info(self, host_info: HostInfo, props: Properties) -> Proper
PropertiesUtils.remove_wrapper_props(prop_copy)
return prop_copy

def set_password(self, props: Properties, pwd: str):
WrapperProperties.PASSWORD.set(props, pwd)

def is_closed(self, conn: Connection) -> bool:
raise UnsupportedOperationError(Messages.get_formatted("TargetDriverDialect.UnsupportedOperationError", self._driver_name, "is_closed"))

Expand Down
17 changes: 10 additions & 7 deletions aws_wrapper/iam_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,13 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl

token_info = IamAuthPlugin._token_cache.get(cache_key)

if token_info and not token_info.is_expired():
if token_info is not None and not token_info.is_expired():
logger.debug(Messages.get_formatted("IamAuthPlugin.UseCachedIamToken", token_info.token))
WrapperProperties.PASSWORD.set(props, token_info.token)
self._plugin_service.target_driver_dialect.set_password(props, token_info.token)
else:
token: str = self._generate_authentication_token(props, host, port, region)
logger.debug(Messages.get_formatted("IamAuthPlugin.GeneratedNewIamToken", token))
WrapperProperties.PASSWORD.set(props, token)
self._plugin_service.target_driver_dialect.set_password(props, token)
IamAuthPlugin._token_cache[cache_key] = TokenInfo(token, datetime.now() + timedelta(
seconds=token_expiration_sec))

Expand All @@ -115,7 +115,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
except Exception as e:
logger.debug(Messages.get_formatted("IamAuthPlugin.ConnectException", e))

is_cached_token = (token_info and not token_info.is_expired())
is_cached_token = (token_info is not None and not token_info.is_expired())
if not self._plugin_service.is_login_exception(error=e) or not is_cached_token:
raise AwsWrapperError(Messages.get_formatted("IamAuthPlugin.ConnectException", e)) from e

Expand All @@ -124,7 +124,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl

token = self._generate_authentication_token(props, host, port, region)
logger.debug(Messages.get_formatted("IamAuthPlugin.GeneratedNewIamToken", token))
WrapperProperties.PASSWORD.set(props, token)
self._plugin_service.target_driver_dialect.set_password(props, token)
IamAuthPlugin._token_cache[token] = TokenInfo(token, datetime.now() + timedelta(
seconds=token_expiration_sec))

Expand Down Expand Up @@ -179,8 +179,11 @@ def _get_port(self, props: Properties, host_info: HostInfo) -> int:

if host_info.is_port_specified():
return host_info.port
else:
return 5432 # TODO: update after implementing the dialect class

if self._plugin_service.dialect is not None:
return self._plugin_service.dialect.default_port

raise AwsWrapperError(Messages.get("IamAuthPlugin.NoValidPorts"))

def _get_rds_region(self, hostname: Optional[str]) -> str:
rds_region = self._rds_utils.get_rds_region(hostname) if hostname else None
Expand Down
12 changes: 11 additions & 1 deletion aws_wrapper/mysql_target_driver_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
class MySQLTargetDriverDialect(GenericTargetDriverDialect):
_driver_name = "MySQL Connector Python"
TARGET_DRIVER_CODE = "MySQL"
AUTH_PLUGIN_PARAM = "auth_plugin"
AUTH_METHOD = "mysql_clear_password"

_dialect_code: str = TargetDriverDialectCodes.MYSQL_CONNECTOR_PYTHON
_network_bound_methods: Set[str] = {
Expand Down Expand Up @@ -78,6 +80,10 @@ def set_autocommit(self, conn: Connection, autocommit: bool):
raise UnsupportedOperationError(
Messages.get_formatted("TargetDriverDialect.UnsupportedOperationError", self._driver_name, "autocommit"))

def set_password(self, props: Properties, pwd: str):
WrapperProperties.PASSWORD.set(props, pwd)
props[MySQLTargetDriverDialect.AUTH_PLUGIN_PARAM] = MySQLTargetDriverDialect.AUTH_METHOD

def abort_connection(self, conn: Connection):
raise UnsupportedOperationError(
Messages.get_formatted(
Expand All @@ -98,7 +104,11 @@ def get_connection_from_obj(self, obj: object) -> Any:
return obj

if isinstance(obj, CMySQLCursor):
return obj._cnx
try:
if isinstance(obj._cnx, CMySQLConnection) or isinstance(obj._cnx, MySQLConnection):
return obj._cnx
except ReferenceError:
return None

if isinstance(obj, MySQLCursor):
return obj._connection
Expand Down
1 change: 0 additions & 1 deletion aws_wrapper/pep249.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ def execute(
) -> Cursor:
...

# TODO: check parameters
def executemany(
self,
*args,
Expand Down
9 changes: 6 additions & 3 deletions aws_wrapper/plugin_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,9 +516,12 @@ def execute(self, target: object, method_name: str, target_driver_func: Callable
conn: Optional[Connection] = target_driver_dialect.get_connection_from_obj(target)
current_conn: Optional[Connection] = target_driver_dialect.unwrap_connection(plugin_service.current_connection)

if conn is not None and conn != current_conn and method_name != "Connection.close" and method_name != "Cursor.close":
msg = Messages.get_formatted("PluginManager.MethodInvokedAgainstOldConnection", target)
raise AwsWrapperError(msg)
if method_name != "Connection.close" and method_name != "Cursor.close" and conn is not None and conn != current_conn:
raise AwsWrapperError(Messages.get_formatted("PluginManager.MethodInvokedAgainstOldConnection", target))

if conn is None and (method_name == "Connection.close" or method_name == "Cursor.close"):
return

return self._execute_with_subscribed_plugins(
method_name,
# next_plugin_func is defined later in make_pipeline
Expand Down
15 changes: 14 additions & 1 deletion aws_wrapper/resources/messages.properties
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ HostSelector.NoHostsMatchingRole=[HostSelector] No hosts were found matching the
IamAuthPlugin.ConnectException=[IamAuthPlugin] Error occurred while opening a connection: {}
IamAuthPlugin.GeneratedNewIamToken=[IamAuthPlugin] Generated new IAM token = {}
IamAuthPlugin.InvalidPort=[IamAuthPlugin] Port number: {} is not valid. Port number should be greater than zero. Falling back to default port.
IamAuthPlugin.NoValidPort=[IamAuthPlugin] Unable to determine a valid port.
IamAuthPlugin.UnhandledException=[IamAuthPlugin] Unhandled exception: {}
IamAuthPlugin.UnsupportedHostname=[IamAuthPlugin] Unsupported AWS hostname {}. Amazon domain name in format *.AWS-Region.rds.amazonaws.com or *.rds.AWS-Region.amazonaws.com.cn is expected.
IamAuthPlugin.UseCachedIamToken=[IamAuthPlugin] Used cached IAM token = {}

IamPlugin.IsNullOrEmpty=[IamPlugin] Property "{}" is null or empty.

Monitor.NullContext=[Monitor] Parameter 'context' should not evaluate to None.
Expand Down Expand Up @@ -134,6 +134,19 @@ PluginManager.InvalidPlugin=[PluginManager] Invalid plugin requested: '{}'.
PluginManager.MethodInvokedAgainstOldConnection = [PluginManager] The internal connection has changed since '{}' was created. This is likely due to failover or read-write splitting functionality. To ensure you are using the updated connection, please re-create Cursor objects after failover and/or setting readonly.
PluginManager.NullPipeline=[PluginManager] A pipeline was requested but the created pipeline evaluated to None.

Failover.ConnectionChangedError=The active SQL connection has changed due to a connection failure. Please re-configure session state if required.
Failover.DetectedException=Detected an exception while executing a command: {}
Failover.EstablishedConnection=Connected to: {}
Failover.FailoverDisabled=Cluster-aware failover is disabled.
Failover.InvalidNode=Node is no longer available in the topology: {}
Failover.NoOperationsAfterConnectionClosed=No operations allowed after connection closed.
Failover.ParameterValue={}={}
Failover.StartReaderFailover=Starting reader failover procedure.
Failover.StartWriterFailover=Starting writer failover procedure.
Failover.TransactionResolutionUnknownError=Transaction resolution unknown. Please re-configure session state if required and try restarting the transaction.
Failover.UnableToConnectToReader=Unable to establish SQL connection to the reader instance.
Failover.UnableToConnectToWriter=Unable to establish SQL connection to the writer instance.

PluginServiceImpl.FailedToRetrieveHostPort=[PluginServiceImpl] Could not retrieve Host:Port for connection. {}
PluginServiceImpl.NonEmptyAliases=[PluginServiceImpl] fill_aliases called when HostInfo already contains the following aliases: {}.
PluginServiceImpl.UnableToUpdateTransactionStatus=[PluginServiceImpl] Unable to update transaction status, current connection is None.
Expand Down
7 changes: 5 additions & 2 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
[mypy]
exclude = (?x)(
debug_integration_.*\.py$
)
debug_integration_.*\.py$
)

[mypy-mysql]
ignore_missing_imports = True

[mypy-parameterized]
ignore_missing_imports = True
Expand Down
Loading

0 comments on commit 834bd89

Please sign in to comment.