Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass X-Presto-Client-Info in presto hook #22416

Merged
merged 1 commit into from
Mar 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions airflow/providers/presto/hooks/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import os
import warnings
from typing import Any, Callable, Iterable, Optional, overload
Expand All @@ -27,6 +28,34 @@
from airflow.configuration import conf
from airflow.hooks.dbapi import DbApiHook
from airflow.models import Connection
from airflow.utils.operator_helpers import AIRFLOW_VAR_NAME_FORMAT_MAPPING

try:
from airflow.utils.operator_helpers import DEFAULT_FORMAT_PREFIX
except ImportError:
# DEFAULT_FORMAT_PREFIX is defined in airflow.utils.operator_helpers.
# For provider backward compatibility, its value is hard-coded here when the import fails.
# https://github.com/apache/airflow/pull/22416#issuecomment-1075531290
DEFAULT_FORMAT_PREFIX = 'airflow.ctx.'


def generate_presto_client_info() -> str:
    """Return a JSON string describing the currently running Airflow task.

    Each entry of ``AIRFLOW_VAR_NAME_FORMAT_MAPPING`` names an environment
    variable (``env_var_format``) and a dotted context name (``default``).
    The environment variable's value is read, with an unset variable
    reported as an empty string, and stored under the context name with
    ``DEFAULT_FORMAT_PREFIX`` removed.

    :return: JSON object, keys sorted, with ``dag_id``, ``task_id``,
        ``execution_date``, ``try_number``, ``dag_run_id`` and ``dag_owner``.
    """
    client_info_keys = (
        'dag_id',
        'task_id',
        'execution_date',
        'try_number',
        'dag_run_id',
        'dag_owner',
    )
    context_var = {
        format_map['default'].replace(DEFAULT_FORMAT_PREFIX, ''): os.environ.get(
            format_map['env_var_format'], ''
        )
        for format_map in AIRFLOW_VAR_NAME_FORMAT_MAPPING.values()
    }
    # The mapping is supplied by Airflow core, which may be older than this
    # provider and may not define every key used here (presumably e.g.
    # ``try_number`` -- verify against the supported core versions). Fall back
    # to '' instead of raising KeyError so that opening a connection never
    # fails because of a missing piece of informational metadata.
    task_info = {key: context_var.get(key, '') for key in client_info_keys}
    return json.dumps(task_info, sort_keys=True)


class PrestoException(Exception):
Expand Down Expand Up @@ -83,11 +112,13 @@ def get_conn(self) -> Connection:
ca_bundle=extra.get('kerberos__ca_bundle'),
)

http_headers = {"X-Presto-Client-Info": generate_presto_client_info()}
presto_conn = prestodb.dbapi.connect(
host=db.host,
port=db.port,
user=db.login,
source=db.extra_dejson.get('source', 'airflow'),
http_headers=http_headers,
http_scheme=db.extra_dejson.get('protocol', 'http'),
catalog=db.extra_dejson.get('catalog', 'hive'),
schema=db.schema,
Expand Down
73 changes: 72 additions & 1 deletion tests/providers/presto/hooks/test_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,31 @@

from airflow import AirflowException
from airflow.models import Connection
from airflow.providers.presto.hooks.presto import PrestoHook
from airflow.providers.presto.hooks.presto import PrestoHook, generate_presto_client_info


def test_generate_airflow_presto_client_info_header():
    """Check that the AIRFLOW_CTX_* variables are serialized as key-sorted JSON."""
    expected_payload = {
        "dag_id": "dag_id",
        "execution_date": "2022-01-01T00:00:00",
        "task_id": "task_id",
        "try_number": "1",
        "dag_run_id": "dag_run_id",
        "dag_owner": "dag_owner",
    }
    airflow_ctx_env = dict(
        AIRFLOW_CTX_DAG_ID='dag_id',
        AIRFLOW_CTX_EXECUTION_DATE='2022-01-01T00:00:00',
        AIRFLOW_CTX_TASK_ID='task_id',
        AIRFLOW_CTX_TRY_NUMBER='1',
        AIRFLOW_CTX_DAG_RUN_ID='dag_run_id',
        AIRFLOW_CTX_DAG_OWNER='dag_owner',
    )

    # Only the call under test needs the patched environment.
    with patch.dict('os.environ', airflow_ctx_env):
        actual = generate_presto_client_info()

    assert actual == json.dumps(expected_payload, sort_keys=True)


class TestPrestoHookConn(unittest.TestCase):
Expand All @@ -45,6 +69,7 @@ def test_get_conn_basic_auth(self, mock_get_connection, mock_connect, mock_basic
catalog='hive',
host='host',
port=None,
http_headers=mock.ANY,
http_scheme='http',
schema='hive',
source='airflow',
Expand Down Expand Up @@ -98,6 +123,7 @@ def test_get_conn_kerberos_auth(self, mock_get_connection, mock_connect, mock_au
catalog='hive',
host='host',
port=None,
http_headers=mock.ANY,
http_scheme='http',
schema='hive',
source='airflow',
Expand All @@ -118,6 +144,51 @@ def test_get_conn_kerberos_auth(self, mock_get_connection, mock_connect, mock_au
)
assert mock_connect.return_value == conn

@patch('airflow.providers.presto.hooks.presto.generate_presto_client_info')
@patch('airflow.providers.presto.hooks.presto.prestodb.auth.BasicAuthentication')
@patch('airflow.providers.presto.hooks.presto.prestodb.dbapi.connect')
@patch('airflow.providers.presto.hooks.presto.PrestoHook.get_connection')
def test_http_headers(
    self,
    mock_get_connection,
    mock_connect,
    mock_basic_auth,
    mocked_generate_airflow_presto_client_info_header,
):
    """Check that get_conn() sends the generated client info as X-Presto-Client-Info."""
    # Stub the generator so the header value is fully controlled by the test.
    client_info = json.dumps(
        {
            "dag_id": "dag-id",
            "execution_date": "2022-01-01T00:00:00",
            "task_id": "task-id",
            "try_number": "1",
            "dag_run_id": "dag-run-id",
            "dag_owner": "dag-owner",
        },
        sort_keys=True,
    )
    mocked_generate_airflow_presto_client_info_header.return_value = client_info
    mock_get_connection.return_value = Connection(
        login='login', password='password', host='host', schema='hive'
    )

    conn = PrestoHook().get_conn()

    expected_connect_kwargs = dict(
        catalog='hive',
        host='host',
        port=None,
        http_headers={'X-Presto-Client-Info': client_info},
        http_scheme='http',
        schema='hive',
        source='airflow',
        user='login',
        isolation_level=0,
        auth=mock_basic_auth.return_value,
    )
    mock_connect.assert_called_once_with(**expected_connect_kwargs)
    mock_basic_auth.assert_called_once_with('login', 'password')
    assert mock_connect.return_value == conn

@parameterized.expand(
[
('False', False),
Expand Down