Skip to content

Commit

Permalink
Use dialect.name in custom SA types (#33503)
Browse files Browse the repository at this point in the history
* Use `dialect.name` in custom SA types

* Fix removed import
  • Loading branch information
Taragolis authored Aug 18, 2023
1 parent d1e6a5c commit 46aa429
Showing 1 changed file with 15 additions and 24 deletions.
39 changes: 15 additions & 24 deletions airflow/utils/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,22 @@
from sqlalchemy.dialects import mssql, mysql
from sqlalchemy.exc import OperationalError
from sqlalchemy.sql import ColumnElement, Select
from sqlalchemy.sql.expression import ColumnOperators
from sqlalchemy.types import JSON, Text, TypeDecorator, TypeEngine, UnicodeText

from airflow import settings
from airflow.configuration import conf
from airflow.serialization.enums import Encoding
from airflow.utils.timezone import make_naive

if TYPE_CHECKING:
from kubernetes.client.models.v1_pod import V1Pod
from sqlalchemy.orm import Query, Session
from sqlalchemy.sql.expression import ColumnOperators

log = logging.getLogger(__name__)

utc = pendulum.tz.timezone("UTC")

using_mysql = conf.get_mandatory_value("database", "sql_alchemy_conn").lower().startswith("mysql")


class UtcDateTime(TypeDecorator):
"""
Expand All @@ -67,22 +66,18 @@ class UtcDateTime(TypeDecorator):
cache_ok = True

def process_bind_param(self, value, dialect):
if value is not None:
if not isinstance(value, datetime.datetime):
raise TypeError("expected datetime.datetime, not " + repr(value))
elif value.tzinfo is None:
raise ValueError("naive datetime is disallowed")
if not isinstance(value, datetime.datetime):
if value is None:
return None
raise TypeError("expected datetime.datetime, not " + repr(value))
elif value.tzinfo is None:
raise ValueError("naive datetime is disallowed")
elif dialect.name == "mysql":
# For mysql we should store timestamps as naive values
# Timestamp in MYSQL is not timezone aware. In MySQL 5.6
# timezone added at the end is ignored but in MySQL 5.7
# inserting timezone value fails with 'invalid-date'
# In MySQL 5.7 inserting timezone value fails with 'invalid-date'
# See https://issues.apache.org/jira/browse/AIRFLOW-7001
if using_mysql:
from airflow.utils.timezone import make_naive

return make_naive(value, timezone=utc)
return value.astimezone(utc)
return None
return make_naive(value, timezone=utc)
return value.astimezone(utc)

def process_result_value(self, value, dialect):
"""
Expand Down Expand Up @@ -119,12 +114,8 @@ class ExtendedJSON(TypeDecorator):

cache_ok = True

def db_supports_json(self):
"""Check if the database supports JSON (i.e. is NOT MSSQL)."""
return not conf.get("database", "sql_alchemy_conn").startswith("mssql")

def load_dialect_impl(self, dialect) -> TypeEngine:
if self.db_supports_json():
if dialect.name != "mssql":
return dialect.type_descriptor(JSON)
return dialect.type_descriptor(UnicodeText)

Expand All @@ -138,7 +129,7 @@ def process_bind_param(self, value, dialect):
value = BaseSerialization.serialize(value)

# Then, if the database does not have native JSON support, encode it again as a string
if not self.db_supports_json():
if dialect.name == "mssql":
value = json.dumps(value)

return value
Expand All @@ -150,7 +141,7 @@ def process_result_value(self, value, dialect):
return None

# Deserialize from a string first if needed
if not self.db_supports_json():
if dialect.name == "mssql":
value = json.loads(value)

return BaseSerialization.deserialize(value)
Expand Down

0 comments on commit 46aa429

Please sign in to comment.