Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve Amazon Hook's region_name and config in wrapper #25336

Merged
merged 3 commits into from
Aug 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 66 additions & 45 deletions airflow/providers/amazon/aws/hooks/base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@

from airflow.compat.functools import cached_property
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.hooks.base import BaseHook
from airflow.models.connection import Connection
from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
Expand All @@ -66,19 +66,24 @@ class BaseSessionFactory(LoggingMixin):
"""

def __init__(
self, conn: Union[Connection, AwsConnectionWrapper], region_name: Optional[str], config: Config
self,
conn: Optional[Union[Connection, AwsConnectionWrapper]],
region_name: Optional[str] = None,
config: Optional[Config] = None,
) -> None:
super().__init__()
self._conn = conn
self._region_name = region_name
self.config = config
self._config = config

@cached_property
def conn(self) -> AwsConnectionWrapper:
"""Cached AWS Connection Wrapper."""
if isinstance(self._conn, AwsConnectionWrapper):
return self._conn
return AwsConnectionWrapper(self._conn)
return AwsConnectionWrapper(
conn=self._conn,
region_name=self._region_name,
botocore_config=self._config,
)

@cached_property
def basic_session(self) -> boto3.session.Session:
Expand All @@ -92,21 +97,29 @@ def extra_config(self) -> Dict[str, Any]:

@property
def region_name(self) -> Optional[str]:
"""Resolve region name.
"""AWS Region Name read-only property."""
return self.conn.region_name

1. SessionFactory region_name
2. Connection region_name
"""
return self._region_name or self.conn.region_name
@property
def config(self) -> Optional[Config]:
"""Configuration for botocore client read-only property."""
return self.conn.botocore_config

@property
def role_arn(self) -> Optional[str]:
"""Assume Role ARN from AWS Connection"""
return self.conn.role_arn

def create_session(self) -> boto3.session.Session:
"""Create AWS session."""
if not self.role_arn:
"""Create boto3 Session from connection config."""
if not self.conn:
self.log.info(
"No connection ID provided. Fallback on boto3 credential strategy (region_name=%r). "
"See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html",
self.region_name,
)
return boto3.session.Session(region_name=self.region_name)
elif not self.role_arn:
return self.basic_session
return self._create_session_with_assume_role(session_kwargs=self.conn.session_kwargs)

Expand Down Expand Up @@ -381,45 +394,50 @@ def __init__(
self.verify = verify
self.client_type = client_type
self.resource_type = resource_type
self.region_name = region_name
self.config = config

def _get_credentials(self, region_name: Optional[str]) -> Tuple[boto3.session.Session, Optional[str]]:

if not self.aws_conn_id:
session = boto3.session.Session(region_name=region_name)
return session, None
self._region_name = region_name
self._config = config

self.log.debug("Airflow Connection: aws_conn_id=%s", self.aws_conn_id)
@cached_property
def conn_config(self) -> AwsConnectionWrapper:
"""Get the Airflow Connection object and wrap it in helper (cached)."""
connection = None
if self.aws_conn_id:
try:
connection = self.get_connection(self.aws_conn_id)
except AirflowNotFoundException:
warnings.warn(
f"Unable to find AWS Connection ID '{self.aws_conn_id}', switching to empty. "
"This behaviour is deprecated and will be removed in a future releases. "
"Please provide existed AWS connection ID or if required boto3 credential strategy "
"explicit set AWS Connection ID to None.",
DeprecationWarning,
stacklevel=2,
)

try:
# Fetch the Airflow connection object and wrap it in helper
connection_object = AwsConnectionWrapper(self.get_connection(self.aws_conn_id))
return AwsConnectionWrapper(
conn=connection or Connection(conn_id=None, conn_type="aws"),
region_name=self._region_name,
botocore_config=self._config,
)

if connection_object.botocore_config:
# For historical reason botocore.config.Config from connection overwrites
# config which explicitly set in Hook.
self.config = connection_object.botocore_config
@property
def region_name(self) -> Optional[str]:
"""AWS Region Name read-only property."""
return self.conn_config.region_name

session = SessionFactory(
conn=connection_object, region_name=region_name, config=self.config
).create_session()
@property
def config(self) -> Optional[Config]:
"""Configuration for botocore client read-only property."""
return self.conn_config.botocore_config

return session, connection_object.endpoint_url
def _get_credentials(self, region_name: Optional[str]) -> Tuple[boto3.session.Session, Optional[str]]:
self.log.debug("Airflow Connection: aws_conn_id=%s", self.aws_conn_id)

except AirflowException:
self.log.warning(
"Unable to use Airflow Connection for credentials. "
"Fallback on boto3 credential strategy. See: "
"https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html"
)
session = SessionFactory(
conn=self.conn_config, region_name=region_name, config=self.config
).create_session()

self.log.debug(
"Creating session using boto3 credential strategy region_name=%s",
region_name,
)
session = boto3.session.Session(region_name=region_name)
return session, None
return session, self.conn_config.endpoint_url

def get_client_type(
self,
Expand Down Expand Up @@ -491,17 +509,20 @@ def conn(self) -> Union[boto3.client, boto3.resource]:

@cached_property
def conn_client_meta(self) -> ClientMeta:
"""Get botocore client metadata from Hook connection (cached)."""
conn = self.conn
if isinstance(conn, botocore.client.BaseClient):
return conn.meta
return conn.meta.client.meta

@property
def conn_region_name(self) -> str:
"""Get actual AWS Region Name from Hook connection (cached)."""
return self.conn_client_meta.region_name

@property
def conn_partition(self) -> str:
"""Get associated AWS Region Partition from Hook connection (cached)."""
return self.conn_client_meta.partition

def get_conn(self) -> BaseAwsConnection:
Expand Down
64 changes: 53 additions & 11 deletions airflow/providers/amazon/aws/utils/connection_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

import warnings
from copy import deepcopy
from dataclasses import InitVar, dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
from dataclasses import MISSING, InitVar, dataclass, field, fields
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

from botocore.config import Config

Expand All @@ -35,25 +35,43 @@ class AwsConnectionWrapper(LoggingMixin):
"""
AWS Connection Wrapper class helper.
Use for validate and resolve AWS Connection parameters.

``conn`` is a reference to an Airflow Connection object or an AwsConnectionWrapper;
if it is set to ``None``, the default values are used.

The precedence rules for ``region_name``:
1. Explicitly set (in Hook) ``region_name``.
2. Airflow Connection Extra 'region_name'.

The precedence rules for ``botocore_config``:
1. Explicitly set (in Hook) ``botocore_config``.
2. Constructed from Airflow Connection Extra 'config_kwargs'.
3. The wrapper's default value.
"""

conn: InitVar[Optional["Connection"]]
conn: InitVar[Optional[Union["Connection", "AwsConnectionWrapper"]]]
region_name: Optional[str] = field(default=None)
botocore_config: Optional[Config] = field(default=None)

# Reference to Airflow Connection attributes
# ``extra_config`` contains original Airflow Connection Extra.
conn_id: Optional[str] = field(init=False, default=None)
conn_type: Optional[str] = field(init=False, default=None)
login: Optional[str] = field(init=False, repr=False, default=None)
password: Optional[str] = field(init=False, repr=False, default=None)
extra_config: Dict[str, Any] = field(init=False, repr=False, default_factory=dict)

aws_access_key_id: Optional[str] = field(init=False)
aws_secret_access_key: Optional[str] = field(init=False)
aws_session_token: Optional[str] = field(init=False)
# AWS Credentials from connection.
aws_access_key_id: Optional[str] = field(init=False, default=None)
aws_secret_access_key: Optional[str] = field(init=False, default=None)
aws_session_token: Optional[str] = field(init=False, default=None)

region_name: Optional[str] = field(init=False, default=None)
# Additional boto3.session.Session keyword arguments.
session_kwargs: Dict[str, Any] = field(init=False, default_factory=dict)
botocore_config: Optional[Config] = field(init=False, default=None)
# Custom endpoint_url for boto3.client and boto3.resource
endpoint_url: Optional[str] = field(init=False, default=None)

# Assume Role Configurations
role_arn: Optional[str] = field(init=False, default=None)
assume_role_method: Optional[str] = field(init=False, default=None)
assume_role_kwargs: Dict[str, Any] = field(init=False, default_factory=dict)
Expand All @@ -63,7 +81,30 @@ def conn_repr(self):
return f"AWS Connection (conn_id={self.conn_id!r}, conn_type={self.conn_type!r})"

def __post_init__(self, conn: "Connection"):
if not conn:
if isinstance(conn, type(self)):
# For every field with init=False we copy the value from the original wrapper.
# For every field with init=True we keep the value passed to __init__ unless it
# equals the field's default, in which case the original wrapper's value is used.
# We can't use ``dataclasses.replace`` (e.g. in a classmethod) because we are limited
# by the InitVar arguments, which are not stored on the object, and we also do not want
# to run ``__post_init__`` again, which would emit all logs/warnings a second time.
for fl in fields(conn):
value = getattr(conn, fl.name)
if not fl.init:
setattr(self, fl.name, value)
else:
if fl.default is not MISSING:
default = fl.default
elif fl.default_factory is not MISSING:
default = fl.default_factory() # zero-argument callable
else:
continue # Value mandatory, skip

orig_value = getattr(self, fl.name)
if orig_value == default:
# Only take the value from the original wrapper if the init value equals the default
setattr(self, fl.name, value)
return
elif not conn:
return

extra = deepcopy(conn.extra_dejson)
Expand All @@ -86,7 +127,7 @@ def __post_init__(self, conn: "Connection"):
init_credentials = self._get_credentials(**extra)
self.aws_access_key_id, self.aws_secret_access_key, self.aws_session_token = init_credentials

if "region_name" in extra:
if not self.region_name and "region_name" in extra:
self.region_name = extra["region_name"]
self.log.info("Retrieving region_name=%s from %s extra.", self.region_name, self.conn_repr)

Expand All @@ -106,7 +147,7 @@ def __post_init__(self, conn: "Connection"):
)

config_kwargs = extra.get("config_kwargs")
if config_kwargs:
if not self.botocore_config and config_kwargs:
# https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
self.log.info("Retrieving botocore config=%s from %s extra.", config_kwargs, self.conn_repr)
self.botocore_config = Config(**config_kwargs)
Expand All @@ -119,6 +160,7 @@ def __post_init__(self, conn: "Connection"):

@property
def extra_dejson(self):
"""Compatibility with `airflow.models.Connection.extra_dejson` property."""
return self.extra_config

def __bool__(self):
Expand Down
Loading