diff --git a/airflow/providers/presto/hooks/presto.py b/airflow/providers/presto/hooks/presto.py index 419b571c9f92b..97f6ec566eea2 100644 --- a/airflow/providers/presto/hooks/presto.py +++ b/airflow/providers/presto/hooks/presto.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import json import os import warnings from typing import Any, Callable, Iterable, Optional, overload @@ -27,6 +28,34 @@ from airflow.configuration import conf from airflow.hooks.dbapi import DbApiHook from airflow.models import Connection +from airflow.utils.operator_helpers import AIRFLOW_VAR_NAME_FORMAT_MAPPING + +try: + from airflow.utils.operator_helpers import DEFAULT_FORMAT_PREFIX +except ImportError: + # This is from airflow.utils.operator_helpers, + # For the sake of provider backward compatibility, this is hardcoded if import fails + # https://github.com/apache/airflow/pull/22416#issuecomment-1075531290 + DEFAULT_FORMAT_PREFIX = 'airflow.ctx.' 
def generate_presto_client_info() -> str:
    """Return a JSON string describing the currently running Airflow task.

    The task context is read from the environment variables named by the
    ``env_var_format`` entry of each item in ``AIRFLOW_VAR_NAME_FORMAT_MAPPING``.
    The resulting JSON object carries ``dag_id``, ``task_id``, ``execution_date``,
    ``try_number``, ``dag_run_id`` and ``dag_owner``. A value that is not
    available (environment variable unset, or no mapping entry for it) is
    reported as an empty string.

    :return: JSON object serialized with sorted keys.
    """
    # Short context name (mapping default with the prefix removed, e.g. ``dag_id``)
    # -> value of the matching environment variable, '' when it is not set.
    context_var = {
        format_map['default'].replace(DEFAULT_FORMAT_PREFIX, ''): os.environ.get(
            format_map['env_var_format'], ''
        )
        for format_map in AIRFLOW_VAR_NAME_FORMAT_MAPPING.values()
    }
    # Use .get() rather than indexing: the mapping comes from Airflow core, and the
    # installed core may not define every entry requested here. A missing entry is
    # treated exactly like an unset environment variable instead of raising KeyError.
    reported_keys = ('dag_id', 'task_id', 'execution_date', 'try_number', 'dag_run_id', 'dag_owner')
    task_info = {key: context_var.get(key, '') for key in reported_keys}
    return json.dumps(task_info, sort_keys=True)
"execution_date": "2022-01-01T00:00:00", + "task_id": "task_id", + "try_number": "1", + "dag_run_id": "dag_run_id", + "dag_owner": "dag_owner", + }, + sort_keys=True, + ) + with patch.dict('os.environ', env_vars): + assert generate_presto_client_info() == expected class TestPrestoHookConn(unittest.TestCase): @@ -45,6 +69,7 @@ def test_get_conn_basic_auth(self, mock_get_connection, mock_connect, mock_basic catalog='hive', host='host', port=None, + http_headers=mock.ANY, http_scheme='http', schema='hive', source='airflow', @@ -98,6 +123,7 @@ def test_get_conn_kerberos_auth(self, mock_get_connection, mock_connect, mock_au catalog='hive', host='host', port=None, + http_headers=mock.ANY, http_scheme='http', schema='hive', source='airflow', @@ -118,6 +144,51 @@ def test_get_conn_kerberos_auth(self, mock_get_connection, mock_connect, mock_au ) assert mock_connect.return_value == conn + @patch('airflow.providers.presto.hooks.presto.generate_presto_client_info') + @patch('airflow.providers.presto.hooks.presto.prestodb.auth.BasicAuthentication') + @patch('airflow.providers.presto.hooks.presto.prestodb.dbapi.connect') + @patch('airflow.providers.presto.hooks.presto.PrestoHook.get_connection') + def test_http_headers( + self, + mock_get_connection, + mock_connect, + mock_basic_auth, + mocked_generate_airflow_presto_client_info_header, + ): + mock_get_connection.return_value = Connection( + login='login', password='password', host='host', schema='hive' + ) + client = json.dumps( + { + "dag_id": "dag-id", + "execution_date": "2022-01-01T00:00:00", + "task_id": "task-id", + "try_number": "1", + "dag_run_id": "dag-run-id", + "dag_owner": "dag-owner", + }, + sort_keys=True, + ) + http_headers = {'X-Presto-Client-Info': client} + + mocked_generate_airflow_presto_client_info_header.return_value = http_headers['X-Presto-Client-Info'] + + conn = PrestoHook().get_conn() + mock_connect.assert_called_once_with( + catalog='hive', + host='host', + port=None, + http_headers=http_headers, + 
http_scheme='http', + schema='hive', + source='airflow', + user='login', + isolation_level=0, + auth=mock_basic_auth.return_value, + ) + mock_basic_auth.assert_called_once_with('login', 'password') + assert mock_connect.return_value == conn + @parameterized.expand( [ ('False', False),