diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py index 38716d4eb5f55..b4b726ca327ab 100644 --- a/airflow/utils/sqlalchemy.py +++ b/airflow/utils/sqlalchemy.py @@ -30,23 +30,22 @@ from sqlalchemy.dialects import mssql, mysql from sqlalchemy.exc import OperationalError from sqlalchemy.sql import ColumnElement, Select -from sqlalchemy.sql.expression import ColumnOperators from sqlalchemy.types import JSON, Text, TypeDecorator, TypeEngine, UnicodeText from airflow import settings from airflow.configuration import conf from airflow.serialization.enums import Encoding +from airflow.utils.timezone import make_naive if TYPE_CHECKING: from kubernetes.client.models.v1_pod import V1Pod from sqlalchemy.orm import Query, Session + from sqlalchemy.sql.expression import ColumnOperators log = logging.getLogger(__name__) utc = pendulum.tz.timezone("UTC") -using_mysql = conf.get_mandatory_value("database", "sql_alchemy_conn").lower().startswith("mysql") - class UtcDateTime(TypeDecorator): """ @@ -67,22 +66,18 @@ class UtcDateTime(TypeDecorator): cache_ok = True def process_bind_param(self, value, dialect): - if value is not None: - if not isinstance(value, datetime.datetime): - raise TypeError("expected datetime.datetime, not " + repr(value)) - elif value.tzinfo is None: - raise ValueError("naive datetime is disallowed") + if not isinstance(value, datetime.datetime): + if value is None: + return None + raise TypeError("expected datetime.datetime, not " + repr(value)) + elif value.tzinfo is None: + raise ValueError("naive datetime is disallowed") + elif dialect.name == "mysql": # For mysql we should store timestamps as naive values - # Timestamp in MYSQL is not timezone aware. In MySQL 5.6 - # timezone added at the end is ignored but in MySQL 5.7 - # inserting timezone value fails with 'invalid-date' + # In MySQL 5.7 inserting timezone value fails with 'invalid-date' # See https://issues.apache.org/jira/browse/AIRFLOW-7001 - if using_mysql: - from airflow.utils.timezone import make_naive - - return make_naive(value, timezone=utc) - return value.astimezone(utc) - return None + return make_naive(value, timezone=utc) + return value.astimezone(utc) def process_result_value(self, value, dialect): """ @@ -119,12 +114,8 @@ class ExtendedJSON(TypeDecorator): cache_ok = True - def db_supports_json(self): - """Check if the database supports JSON (i.e. is NOT MSSQL).""" - return not conf.get("database", "sql_alchemy_conn").startswith("mssql") - def load_dialect_impl(self, dialect) -> TypeEngine: - if self.db_supports_json(): + if dialect.name != "mssql": return dialect.type_descriptor(JSON) return dialect.type_descriptor(UnicodeText) @@ -138,7 +129,7 @@ def process_bind_param(self, value, dialect): value = BaseSerialization.serialize(value) # Then, if the database does not have native JSON support, encode it again as a string - if not self.db_supports_json(): + if dialect.name == "mssql": value = json.dumps(value) return value @@ -150,7 +141,7 @@ def process_result_value(self, value, dialect): return None # Deserialize from a string first if needed - if not self.db_supports_json(): + if dialect.name == "mssql": value = json.loads(value) return BaseSerialization.deserialize(value)