Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: mysql support #115

Merged
merged 4 commits into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .github/workflows/integration_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ jobs:
- name: 'Set up Temp AWS Credentials'
run: |
creds=($(aws sts get-session-token \
--duration-seconds 7200 \
--duration-seconds 21600 \
--query 'Credentials.[AccessKeyId, SecretAccessKey, SessionToken]' \
--output text \
| xargs));
Expand All @@ -54,13 +54,14 @@ jobs:

- name: 'Run Integration Tests'
run: |
./gradlew --no-parallel --no-daemon test-all-environments --info
./gradlew --no-parallel --no-daemon test-mysql-aurora --info
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Temporary change to run integration tests

env:
AURORA_CLUSTER_DOMAIN: ${{ secrets.DB_CONN_SUFFIX }}
AURORA_DB_REGION: ${{ secrets.AWS_DEFAULT_REGION }}
AWS_ACCESS_KEY_ID: ${{ env.TEMP_AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ env.TEMP_AWS_SECRET_ACCESS_KEY }}
AWS_SESSION_TOKEN: ${{ env.TEMP_AWS_SESSION_TOKEN }}
NUM_INSTANCES: 5

- name: 'Archive results'
if: always()
Expand Down
6 changes: 3 additions & 3 deletions aws_wrapper/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
from aws_wrapper.utils.properties import (Properties, PropertiesUtils,
WrapperProperties)
from aws_wrapper.utils.rdsutils import RdsUtils
from .exceptions import ExceptionHandler, PgExceptionHandler
from .exceptions import (ExceptionHandler, MySQLExceptionHandler,
PgExceptionHandler)
from .target_driver_dialect import TargetDriverDialectCodes
from .utils.cache_map import CacheMap
from .utils.messages import Messages
Expand Down Expand Up @@ -149,8 +150,7 @@ def server_version_query(self) -> str:

@property
def exception_handler(self) -> Optional[ExceptionHandler]:
# TODO
return None
return MySQLExceptionHandler()

def is_dialect(self, conn: Connection) -> bool:
try:
Expand Down
73 changes: 68 additions & 5 deletions aws_wrapper/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

from typing import TYPE_CHECKING

import mysql

from aws_wrapper.errors import QueryTimeoutError

if TYPE_CHECKING:
Expand Down Expand Up @@ -61,10 +63,13 @@ class PgExceptionHandler(ExceptionHandler):
def is_network_exception(self, error: Optional[Exception] = None, sql_state: Optional[str] = None) -> bool:
if isinstance(error, QueryTimeoutError) or isinstance(error, ConnectionTimeout):
return True
if sql_state is None:
error_sql_state = getattr(error, "sqlstate")
if error_sql_state is not None:
sql_state = error_sql_state

if sql_state:
if sql_state in self._NETWORK_ERRORS:
return True
if sql_state is not None and sql_state in self._NETWORK_ERRORS:
return True

if isinstance(error, OperationalError):
if len(error.args) == 0:
Expand All @@ -81,6 +86,12 @@ def is_login_exception(self, error: Optional[Exception] = None, sql_state: Optio
if isinstance(error, InvalidAuthorizationSpecification) or isinstance(error, InvalidPassword):
return True

if sql_state is None and hasattr(error, "sqlstate") and error.sqlstate is not None:
sql_state = error.sqlstate

if sql_state is not None and sql_state in self._ACCESS_ERRORS:
return True

if isinstance(error, OperationalError):
if len(error.args) == 0:
return False
Expand All @@ -91,10 +102,62 @@ def is_login_exception(self, error: Optional[Exception] = None, sql_state: Optio
or self._PAM_AUTHENTICATION_FAILED_MSG in error_msg:
return True

if sql_state:
if sql_state in self._ACCESS_ERRORS:
return False


class MySQLExceptionHandler(ExceptionHandler):
_PAM_AUTHENTICATION_FAILED_MSG = "PAM authentication failed"
_UNAVAILABLE_CONNECTION = "MySQL Connection not available"

_NETWORK_ERRORS: List[int] = [
2001, # Can't create UNIX socket
2002, # Can't connect to local MySQL server through socket
2003, # Can't connect to MySQL server
2004, # Can't create TCP/IP socket
2006, # MySQL server has gone away
2012, # Error in server handshake
2013, # unexpected error
2026, # SSL connection error
2055, # Lost connection to MySQL server
]

def is_network_exception(self, error: Optional[Exception] = None, sql_state: Optional[str] = None) -> bool:
if isinstance(error, QueryTimeoutError):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does MySQL have a ConnectionTimeoutError that we should check for like PG? Or is the connection timeout caught in the logic below?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return True

if sql_state is None:
if hasattr(error, "sqlstate"):
error_sql_state = getattr(error, "sqlstate")
if error_sql_state is not None:
sql_state = error_sql_state

if sql_state is not None:
if sql_state.startswith("08") or sql_state.startswith("HY"):
# Connection exceptions may also be returned as a generic error
# e.g. 2013 (HY000): Lost connection to MySQL server during query
return True

if isinstance(error, mysql.connector.errors.OperationalError):
if error.errno in self._NETWORK_ERRORS:
return True
if error.msg is not None and self._UNAVAILABLE_CONNECTION in error.msg:
return True

if len(error.args) == 1:
return self._UNAVAILABLE_CONNECTION in error.args[0]

return False

def is_login_exception(self, error: Optional[Exception] = None, sql_state: Optional[str] = None) -> bool:
if sql_state is None:
if hasattr(error, "sqlstate"):
error_sql_state = getattr(error, "sqlstate")
if error_sql_state is not None:
sql_state = error_sql_state

if "28000" == sql_state:
return True

return False


Expand Down
22 changes: 16 additions & 6 deletions aws_wrapper/failover_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
from logging import getLogger
from typing import Any, Callable, Dict, List, Optional, Set

from psycopg import OperationalError

from aws_wrapper.errors import (AwsWrapperError, FailoverFailedError,
FailoverSuccessError,
TransactionResolutionUnknownError)
Expand Down Expand Up @@ -55,6 +53,16 @@ class FailoverPlugin(Plugin):
"force_connect",
"notify_host_list_changed"}

_METHODS_REQUIRE_UPDATED_TOPOLOGY: Set[str] = {
"Connection.commit",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how did we come up with this method list, and why is it necessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mysql specific thing awslabs/aws-mysql-jdbc#364

"Connection.autocommit",
"Connection.autocommit_setter",
"Connection.rollback",
"Connection.cursor",
"Cursor.callproc",
"Cursor.execute"
}

def __init__(self, plugin_service: PluginService, props: Properties):
self._plugin_service = plugin_service
self._properties = props
Expand Down Expand Up @@ -134,7 +142,8 @@ def execute(self, target: type, method_name: str, execute_func: Callable, *args:
self._invalid_invocation_on_closed_connection()

try:
self._update_topology(False)
if self._requires_update_topology(method_name):
self._update_topology(False)
return execute_func()
except Exception as ex:
msg = Messages.get_formatted("FailoverPlugin.DetectedException", str(ex))
Expand Down Expand Up @@ -375,9 +384,6 @@ def _should_exception_trigger_connection_switch(self, ex: Exception) -> bool:
logger.debug(Messages.get_formatted("FailoverPlugin.FailoverDisabled"))
return False

if isinstance(ex, OperationalError):
return True

return self._plugin_service.is_network_exception(ex)

@staticmethod
Expand All @@ -401,6 +407,7 @@ def _is_node_still_valid(node: str, changes: Dict[str, Set[HostEvent]]):
def _can_direct_execute(method_name):
# TODO: adjust method names to proper python method names
return method_name == "Connection.close" or \
method_name == "Cursor.close" or \
method_name == "Connection.abort" or \
method_name == "Connection.isClosed"

Expand All @@ -412,6 +419,9 @@ def _allowed_on_closed_connection(method_name: str):
method_name == "Connection.getSchema" or \
method_name == "Connection.getTransactionIsolation"

def _requires_update_topology(self, method_name: str):
return method_name in FailoverPlugin._METHODS_REQUIRE_UPDATED_TOPOLOGY


class FailoverPluginFactory(PluginFactory):
def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
Expand Down
10 changes: 9 additions & 1 deletion aws_wrapper/generic_target_driver_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from aws_wrapper.errors import UnsupportedOperationError
from aws_wrapper.target_driver_dialect_codes import TargetDriverDialectCodes
from aws_wrapper.utils.messages import Messages
from aws_wrapper.utils.properties import Properties, PropertiesUtils
from aws_wrapper.utils.properties import (Properties, PropertiesUtils,
WrapperProperties)

if TYPE_CHECKING:
from aws_wrapper.hostinfo import HostInfo
Expand Down Expand Up @@ -67,6 +68,10 @@ def supports_socket_timeout(self) -> bool:
def supports_tcp_keepalive(self) -> bool:
return False

@abstractmethod
def set_password(self, props: Properties, pwd: str):
pass

@abstractmethod
def is_dialect(self, connect_func: Callable) -> bool:
pass
Expand Down Expand Up @@ -120,6 +125,9 @@ def prepare_connect_info(self, host_info: HostInfo, props: Properties) -> Proper
PropertiesUtils.remove_wrapper_props(prop_copy)
return prop_copy

def set_password(self, props: Properties, pwd: str):
WrapperProperties.PASSWORD.set(props, pwd)

def is_closed(self, conn: Connection) -> bool:
raise UnsupportedOperationError(Messages.get_formatted("TargetDriverDialect.UnsupportedOperationError", self._driver_name, "is_closed"))

Expand Down
17 changes: 10 additions & 7 deletions aws_wrapper/iam_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,13 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl

token_info = IamAuthPlugin._token_cache.get(cache_key)

if token_info and not token_info.is_expired():
if token_info is not None and not token_info.is_expired():
logger.debug(Messages.get_formatted("IamAuthPlugin.UseCachedIamToken", token_info.token))
WrapperProperties.PASSWORD.set(props, token_info.token)
self._plugin_service.target_driver_dialect.set_password(props, token_info.token)
else:
token: str = self._generate_authentication_token(props, host, port, region)
logger.debug(Messages.get_formatted("IamAuthPlugin.GeneratedNewIamToken", token))
WrapperProperties.PASSWORD.set(props, token)
self._plugin_service.target_driver_dialect.set_password(props, token)
IamAuthPlugin._token_cache[cache_key] = TokenInfo(token, datetime.now() + timedelta(
seconds=token_expiration_sec))

Expand All @@ -115,7 +115,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
except Exception as e:
logger.debug(Messages.get_formatted("IamAuthPlugin.ConnectException", e))

is_cached_token = (token_info and not token_info.is_expired())
is_cached_token = (token_info is not None and not token_info.is_expired())
if not self._plugin_service.is_login_exception(error=e) or not is_cached_token:
raise AwsWrapperError(Messages.get_formatted("IamAuthPlugin.ConnectException", e)) from e

Expand All @@ -124,7 +124,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl

token = self._generate_authentication_token(props, host, port, region)
logger.debug(Messages.get_formatted("IamAuthPlugin.GeneratedNewIamToken", token))
WrapperProperties.PASSWORD.set(props, token)
self._plugin_service.target_driver_dialect.set_password(props, token)
IamAuthPlugin._token_cache[token] = TokenInfo(token, datetime.now() + timedelta(
seconds=token_expiration_sec))

Expand Down Expand Up @@ -179,8 +179,11 @@ def _get_port(self, props: Properties, host_info: HostInfo) -> int:

if host_info.is_port_specified():
return host_info.port
else:
return 5432 # TODO: update after implementing the dialect class

if self._plugin_service.dialect is not None:
return self._plugin_service.dialect.default_port

raise AwsWrapperError(Messages.get("IamAuthPlugin.NoValidPorts"))

def _get_rds_region(self, hostname: Optional[str]) -> str:
rds_region = self._rds_utils.get_rds_region(hostname) if hostname else None
Expand Down
12 changes: 11 additions & 1 deletion aws_wrapper/mysql_target_driver_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
class MySQLTargetDriverDialect(GenericTargetDriverDialect):
_driver_name = "MySQL Connector Python"
TARGET_DRIVER_CODE = "MySQL"
AUTH_PLUGIN_PARAM = "auth_plugin"
AUTH_METHOD = "mysql_clear_password"

_dialect_code: str = TargetDriverDialectCodes.MYSQL_CONNECTOR_PYTHON
_network_bound_methods: Set[str] = {
Expand Down Expand Up @@ -78,6 +80,10 @@ def set_autocommit(self, conn: Connection, autocommit: bool):
raise UnsupportedOperationError(
Messages.get_formatted("TargetDriverDialect.UnsupportedOperationError", self._driver_name, "autocommit"))

def set_password(self, props: Properties, pwd: str):
WrapperProperties.PASSWORD.set(props, pwd)
props[MySQLTargetDriverDialect.AUTH_PLUGIN_PARAM] = MySQLTargetDriverDialect.AUTH_METHOD

def abort_connection(self, conn: Connection):
raise UnsupportedOperationError(
Messages.get_formatted(
Expand All @@ -98,7 +104,11 @@ def get_connection_from_obj(self, obj: object) -> Any:
return obj

if isinstance(obj, CMySQLCursor):
return obj._cnx
try:
if isinstance(obj._cnx, CMySQLConnection) or isinstance(obj._cnx, MySQLConnection):
return obj._cnx
except ReferenceError:
return None

if isinstance(obj, MySQLCursor):
return obj._connection
Expand Down
3 changes: 0 additions & 3 deletions aws_wrapper/pep249.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ def connect(
def close(self) -> None:
...

# TODO: check parameters
def cursor(self, **kwargs) -> Cursor:
...

Expand Down Expand Up @@ -161,15 +160,13 @@ def close(self) -> None:
def callproc(self, **kwargs):
...

# TODO: check parameters
def execute(
self,
query: str,
**kwargs
) -> Cursor:
...

# TODO: check parameters
def executemany(
self,
query: str,
Expand Down
8 changes: 5 additions & 3 deletions aws_wrapper/plugin_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,9 +516,11 @@ def execute(self, target: object, method_name: str, target_driver_func: Callable
conn: Optional[Connection] = target_driver_dialect.get_connection_from_obj(target)
current_conn: Optional[Connection] = target_driver_dialect.unwrap_connection(plugin_service.current_connection)

if conn is not None and conn != current_conn and method_name != "Connection.close" and method_name != "Cursor.close":
msg = Messages.get_formatted("PluginManager.MethodInvokedAgainstOldConnection", target)
raise AwsWrapperError(msg)
if method_name != "Connection.close" and method_name != "Cursor.close" and conn is not None and conn != current_conn:
raise AwsWrapperError(Messages.get_formatted("PluginManager.MethodInvokedAgainstOldConnection", target))

if conn is None and (method_name == "Connection.close" or method_name == "Cursor.close"):
return

return self._execute_with_subscribed_plugins(
method_name,
Expand Down
15 changes: 14 additions & 1 deletion aws_wrapper/resources/messages.properties
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ HostSelector.NoHostsMatchingRole=[HostSelector] No hosts were found matching the
IamAuthPlugin.ConnectException=[IamAuthPlugin] Error occurred while opening a connection: {}
IamAuthPlugin.GeneratedNewIamToken=[IamAuthPlugin] Generated new IAM token = {}
IamAuthPlugin.InvalidPort=[IamAuthPlugin] Port number: {} is not valid. Port number should be greater than zero. Falling back to default port.
IamAuthPlugin.NoValidPort=[IamAuthPlugin] Unable to determine a valid port.
IamAuthPlugin.UnhandledException=[IamAuthPlugin] Unhandled exception: {}
IamAuthPlugin.UnsupportedHostname=[IamAuthPlugin] Unsupported AWS hostname {}. Amazon domain name in format *.AWS-Region.rds.amazonaws.com or *.rds.AWS-Region.amazonaws.com.cn is expected.
IamAuthPlugin.UseCachedIamToken=[IamAuthPlugin] Used cached IAM token = {}

IamPlugin.IsNullOrEmpty=[IamPlugin] Property "{}" is null or empty.

Monitor.NullContext=[Monitor] Parameter 'context' should not evaluate to None.
Expand Down Expand Up @@ -134,6 +134,19 @@ PluginManager.InvalidPlugin=[PluginManager] Invalid plugin requested: '{}'.
PluginManager.MethodInvokedAgainstOldConnection = [PluginManager] The internal connection has changed since '{}' was created. This is likely due to failover or read-write splitting functionality. To ensure you are using the updated connection, please re-create Cursor objects after failover and/or setting readonly.
PluginManager.NullPipeline=[PluginManager] A pipeline was requested but the created pipeline evaluated to None.

Failover.ConnectionChangedError=The active SQL connection has changed due to a connection failure. Please re-configure session state if required.
Failover.DetectedException=Detected an exception while executing a command: {}
Failover.EstablishedConnection=Connected to: {}
Failover.FailoverDisabled=Cluster-aware failover is disabled.
Failover.InvalidNode=Node is no longer available in the topology: {}
Failover.NoOperationsAfterConnectionClosed=No operations allowed after connection closed.
Failover.ParameterValue={}={}
Failover.StartReaderFailover=Starting reader failover procedure.
Failover.StartWriterFailover=Starting writer failover procedure.
Failover.TransactionResolutionUnknownError=Transaction resolution unknown. Please re-configure session state if required and try restarting the transaction.
Failover.UnableToConnectToReader=Unable to establish SQL connection to the reader instance.
Failover.UnableToConnectToWriter=Unable to establish SQL connection to the writer instance.

PluginServiceImpl.FailedToRetrieveHostPort=[PluginServiceImpl] Could not retrieve Host:Port for connection. {}
PluginServiceImpl.NonEmptyAliases=[PluginServiceImpl] fill_aliases called when HostInfo already contains the following aliases: {}.
PluginServiceImpl.UnableToUpdateTransactionStatus=[PluginServiceImpl] Unable to update transaction status, current connection is None.
Expand Down
Loading