diff --git a/airflow/operators/bash.py b/airflow/operators/bash.py index b07ec504aaea6..6c2714e8c3e8e 100644 --- a/airflow/operators/bash.py +++ b/airflow/operators/bash.py @@ -31,7 +31,7 @@ class BashOperator(BaseOperator): - r""" + """ Execute a Bash script, command or set of commands. .. seealso:: diff --git a/airflow/providers/snowflake/hooks/snowflake.py b/airflow/providers/snowflake/hooks/snowflake.py index d61bcc2fdac1f..ef6193db14c28 100644 --- a/airflow/providers/snowflake/hooks/snowflake.py +++ b/airflow/providers/snowflake/hooks/snowflake.py @@ -29,7 +29,12 @@ class SnowflakeHook(DbApiHook): """ - A client to interact with Snowflake + A client to interact with Snowflake. + + This hook requires the snowflake_conn_id connection. The snowflake host, login, + and, password field must be setup in the connection. Other inputs can be defined + in the connection or hook instantiation. If used with the S3ToSnowflakeOperator + add 'aws_access_key_id' and 'aws_secret_access_key' to extra field in the connection. :param account: snowflake account name :type account: Optional[str] @@ -72,9 +77,9 @@ class SnowflakeHook(DbApiHook): @staticmethod def get_connection_form_widgets() -> Dict[str, Any]: """Returns connection widgets to add to connection form""" - from flask_appbuilder.fieldwidgets import BS3TextFieldWidget + from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget from flask_babel import lazy_gettext - from wtforms import StringField + from wtforms import PasswordField, StringField return { "extra__snowflake__account": StringField(lazy_gettext('Account'), widget=BS3TextFieldWidget()), @@ -83,6 +88,12 @@ def get_connection_form_widgets() -> Dict[str, Any]: ), "extra__snowflake__database": StringField(lazy_gettext('Database'), widget=BS3TextFieldWidget()), "extra__snowflake__region": StringField(lazy_gettext('Region'), widget=BS3TextFieldWidget()), + "extra__snowflake__aws_access_key_id": StringField( + lazy_gettext('AWS Access Key'), widget=BS3TextFieldWidget() + ), + "extra__snowflake__aws_secret_access_key": PasswordField( + lazy_gettext('AWS Secret Key'), widget=BS3PasswordFieldWidget() + ), } @staticmethod @@ -100,8 +111,6 @@ def get_ui_field_behaviour() -> Dict: "authenticator": "snowflake oauth", "private_key_file": "private key", "session_parameters": "session parameters", - "aws_access_key_id": "aws access key", - "aws_secret_access_key": "aws secret key", }, indent=1, ), @@ -113,6 +122,8 @@ def get_ui_field_behaviour() -> Dict: 'extra__snowflake__warehouse': 'snowflake warehouse name', 'extra__snowflake__database': 'snowflake db name', 'extra__snowflake__region': 'snowflake hosted region', + 'extra__snowflake__aws_access_key_id': 'aws access key id (S3ToSnowflakeOperator)', + 'extra__snowflake__aws_secret_access_key': 'aws secret access key (S3ToSnowflakeOperator)', }, } @@ -135,10 +146,16 @@ def _get_conn_params(self) -> Dict[str, Optional[str]]: conn = self.get_connection( self.snowflake_conn_id # type: ignore[attr-defined] # pylint: disable=no-member ) - account = conn.extra_dejson.get('extra__snowflake__account', '') - warehouse = conn.extra_dejson.get('extra__snowflake__warehouse', '') - database = conn.extra_dejson.get('extra__snowflake__database', '') - region = conn.extra_dejson.get('extra__snowflake__region', '') + account = conn.extra_dejson.get('extra__snowflake__account', '') or conn.extra_dejson.get( + 'account', '' + ) + warehouse = conn.extra_dejson.get('extra__snowflake__warehouse', '') or conn.extra_dejson.get( + 'warehouse', '' + ) + database = conn.extra_dejson.get('extra__snowflake__database', '') or conn.extra_dejson.get( + 'database', '' + ) + region = conn.extra_dejson.get('extra__snowflake__region', '') or conn.extra_dejson.get('region', '') role = conn.extra_dejson.get('role', '') schema = conn.schema or '' authenticator = conn.extra_dejson.get('authenticator', 'snowflake') @@ -211,8 +228,12 @@ def _get_aws_credentials(self) -> Tuple[Optional[Any], Optional[Any]]: self.snowflake_conn_id # type: ignore[attr-defined] # pylint: disable=no-member ) if 'aws_secret_access_key' in connection_object.extra_dejson: - aws_access_key_id = connection_object.extra_dejson.get('aws_access_key_id') - aws_secret_access_key = connection_object.extra_dejson.get('aws_secret_access_key') + aws_access_key_id = connection_object.extra_dejson.get( + 'aws_access_key_id' + ) or connection_object.extra_dejson.get('aws_access_key_id') + aws_secret_access_key = connection_object.extra_dejson.get( + 'aws_secret_access_key' + ) or connection_object.extra_dejson.get('aws_secret_access_key') return aws_access_key_id, aws_secret_access_key def set_autocommit(self, conn, autocommit: Any) -> None: diff --git a/scripts/in_container/run_install_and_test_provider_packages.sh b/scripts/in_container/run_install_and_test_provider_packages.sh index e3eff202a1053..f8a14bc44fddd 100755 --- a/scripts/in_container/run_install_and_test_provider_packages.sh +++ b/scripts/in_container/run_install_and_test_provider_packages.sh @@ -157,7 +157,7 @@ function discover_all_connection_form_widgets() { COLUMNS=180 airflow providers widgets - local expected_number_of_widgets=19 + local expected_number_of_widgets=25 local actual_number_of_widgets actual_number_of_widgets=$(airflow providers widgets --output table | grep -c ^extra) if [[ ${actual_number_of_widgets} != "${expected_number_of_widgets}" ]]; then @@ -176,7 +176,7 @@ function discover_all_field_behaviours() { group_start "Listing connections with custom behaviours via 'airflow providers behaviours'" COLUMNS=180 airflow providers behaviours - local expected_number_of_connections_with_behaviours=11 + local expected_number_of_connections_with_behaviours=12 local actual_number_of_connections_with_behaviours actual_number_of_connections_with_behaviours=$(airflow providers behaviours --output table | grep -v "===" | \ grep -v field_behaviours | grep -cv "^ " | xargs) diff --git a/tests/core/test_providers_manager.py b/tests/core/test_providers_manager.py index e02fbef1a8cb4..20e00b1e356dc 100644 --- a/tests/core/test_providers_manager.py +++ b/tests/core/test_providers_manager.py @@ -78,8 +78,7 @@ 'apache-airflow-providers-sftp', 'apache-airflow-providers-singularity', 'apache-airflow-providers-slack', - # Uncomment when https://github.com/apache/airflow/issues/12881 is fixed - # 'apache-airflow-providers-snowflake', + 'apache-airflow-providers-snowflake', 'apache-airflow-providers-sqlite', 'apache-airflow-providers-ssh', 'apache-airflow-providers-tableau', @@ -139,8 +138,7 @@ 'samba', 'segment', 'sftp', - # Uncomment when https://github.com/apache/airflow/issues/12881 is fixed - # 'snowflake', + 'snowflake', 'spark', 'spark_jdbc', 'spark_sql', @@ -174,6 +172,13 @@ 'extra__yandexcloud__public_ssh_key', 'extra__yandexcloud__service_account_json', 'extra__yandexcloud__service_account_json_path', + 'extra__snowflake__account', + 'extra__snowflake__warehouse', + 'extra__snowflake__database', + 'extra__snowflake__region', + 'extra__snowflake__aws_access_key_id', + 'extra__snowflake__aws_secret_access_key', + ] CONNECTIONS_WITH_FIELD_BEHAVIOURS = [ @@ -188,6 +193,7 @@ 'spark', 'ssh', 'yandexcloud', + 'snowflake', ] EXTRA_LINKS = [ diff --git a/tests/providers/snowflake/hooks/test_snowflake.py b/tests/providers/snowflake/hooks/test_snowflake.py index d8beac4528e20..2ecf73223027f 100644 --- a/tests/providers/snowflake/hooks/test_snowflake.py +++ b/tests/providers/snowflake/hooks/test_snowflake.py @@ -39,10 +39,10 @@ def setUp(self): self.conn.password = 'pw' self.conn.schema = 'public' self.conn.extra_dejson = { - 'extra__snowflake__database': 'db', - 'extra__snowflake__account': 'airflow', - 'extra__snowflake__warehouse': 'af_wh', - 'extra__snowflake__region': 'af_region', + 'database': 'db', + 'account': 'airflow', + 'warehouse': 'af_wh', + 'region': 'af_region', 'role': 'af_role', }