diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e080290..6f6989a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,23 +1,23 @@ repos: - repo: https://github.com/psf/black - rev: 22.6.0 + rev: 24.8.0 hooks: - id: black args: ["--target-version=py38", "--line-length=88"] - repo: https://github.com/pycqa/isort - rev: 5.10.1 + rev: 5.13.2 hooks: - id: isort args: ["--profile=black"] - - repo: https://gitlab.com/pycqa/flake8 - rev: 3.9.2 + - repo: https://github.com/pycqa/flake8 + rev: 7.1.1 hooks: - id: flake8 - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.1.0 + rev: v4.6.0 hooks: - id: check-merge-conflict - id: check-toml @@ -28,14 +28,14 @@ repos: - id: trailing-whitespace - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.931 + rev: v1.11.2 hooks: - id: mypy exclude: ^tests/ additional_dependencies: [ types-requests ] - repo: https://github.com/codespell-project/codespell - rev: v2.1.0 + rev: v2.3.0 hooks: - id: codespell name: Run codespell to check for common misspellings in files diff --git a/CHANGELOG.md b/CHANGELOG.md index 53db658..cbf59bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,17 @@ - ### Fixed -- + +## [0.1.10] - 2024-09-09 + +### Added +- Enhanced retry mechanism for polling project status +- New `max_poll_retries` and `poll_retry_delay` parameters for `HexRunProjectOperator` +- New `run_status_with_retries` method in `HexHook` +- New `poll_project_status` method in `HexHook` with improved error handling + +### Changed +- Improved error handling for API calls and status checks ## [0.1.9] - 2023-05-16 diff --git a/Makefile b/Makefile index c6d986e..a739488 100644 --- a/Makefile +++ b/Makefile @@ -19,4 +19,4 @@ clean: docker-compose -f dev/docker-compose.yaml down --volumes --remove-orphans init: - docker-compose up -f dev/docker-compose.yaml airflow-init + docker-compose -f dev/docker-compose.yaml up airflow-init diff --git a/README.md b/README.md index 514fd0d..ee2d39b 100644 --- a/README.md +++ b/README.md @@ -1,25 +1,36 @@ -# Hex Airflow Provider +# Airflow Provider for Hex -Provides an Airflow Operator and Hook to trigger Hex project runs. +[![PyPI version](https://badge.fury.io/py/airflow-provider-hex.svg)](https://badge.fury.io/py/airflow-provider-hex) -This [Airflow Provider Package](https://airflow.apache.org/docs/apache-airflow-providers/) -provides Hooks and Operators for interacting with the Hex API. +This [Airflow Provider Package](https://airflow.apache.org/docs/apache-airflow-providers/) provides Hooks and Operators for interacting with the Hex API, allowing you to trigger and manage Hex project runs in your Apache Airflow DAGs. + +## Table of Contents +- [Requirements](#requirements) +- [Installation](#installation) +- [Initial Setup](#initial-setup) +- [Operators](#operators) +- [Hooks](#hooks) +- [Examples](#examples) +- [Development](#development) +- [Changelog](#changelog) ## Requirements -* Airflow >=2.2 +* Apache Airflow >= 2.2.0 +* Python >= 3.7 * Hex API Token -## Initial Setup +## Installation -Install the package. +Install the package using pip: -``` +```bash pip install airflow-provider-hex ``` -After creating a Hex API token, set up your Airflow Connection Credentials in the Airflow -UI. +## Initial Setup + +After creating a Hex API token, set up your Airflow Connection Credentials in the Airflow UI: ![Connection Setup](https://raw.githubusercontent.com/hex-inc/airflow-provider-hex/main/docs/hex-connection-setup.png) @@ -30,41 +41,38 @@ UI. ## Operators -The [`airflow_provider_hex.operators.hex.HexRunProjectOperator`](/airflow_provider_hex/operators/hex.py) -Operator runs Hex Projects, either synchronously or asynchronously. - -In the synchronous mode, the Operator will start a Hex Project run and then -poll the run until either an error or success status is returned, or until -the poll timeout. If the timeout occurs, the default behaviour is to attempt to -cancel the run. +The [`HexRunProjectOperator`](/airflow_provider_hex/operators/hex.py) runs Hex Projects either synchronously or asynchronously. -In the asynchronous mode, the Operator will request that a Hex Project is run, -but will not poll for completion. This can be useful for long-running projects. +- In synchronous mode, the Operator starts a Hex Project run and polls until completion or timeout. +- In asynchronous mode, the Operator requests a Hex Project run without waiting for completion. -The operator accepts inputs in the form of a dictionary. These can be used to -override existing input elements in your Hex project. +The operator accepts inputs as a dictionary to override existing input elements in your Hex project. You can also include optional notifications for a run. -You may also optionally include notifications for a particular run. See -the [Hex API documentation](https://learn.hex.tech/docs/develop-logic/hex-api/api-reference#operation/RunProject) for details. +For more details, see the [Hex API documentation](https://learn.hex.tech/docs/develop-logic/hex-api/api-reference#operation/RunProject). ## Hooks -The [`airflow_provider_hex.hooks.hex.HexHook`](/airflow_provider_hex/hooks/hex.py) -provides a low-level interface to the Hex API. - -These can be useful for testing and development, as they provide both a generic -`run` method which sends an authenticated request to the Hex API, as well as -implementations of the `run` method that provide access to specific endpoints. - +The [`HexHook`](/airflow_provider_hex/hooks/hex.py) provides a low-level interface to the Hex API. It's useful for testing and development, offering both a generic `run` method for authenticated requests and specific endpoint implementations. ## Examples -A simplified example DAG demonstrates how to use the [Airflow Operator](/example_dags/example_hex.py) +Here's a simplified example DAG demonstrating how to use the HexRunProjectOperator: ```python +from airflow import DAG +from airflow.utils.dates import days_ago from airflow_provider_hex.operators.hex import HexRunProjectOperator +from airflow_provider_hex.types import NotificationDetails PROJ_ID = 'abcdef-ghijkl-mnopq' + +default_args = { + 'owner': 'airflow', + 'start_date': days_ago(1), +} + +dag = DAG('hex_example', default_args=default_args, schedule_interval=None) + notifications: list[NotificationDetails] = [ { "type": "SUCCESS", @@ -74,7 +82,7 @@ notifications: list[NotificationDetails] = [ "groupIds": [], } ] -... + sync_run = HexRunProjectOperator( task_id="run", hex_conn_id="hex_default", @@ -83,3 +91,29 @@ sync_run = HexRunProjectOperator( notifications=notifications ) ``` + +For more examples, check the [example_dags](/example_dags) directory. + +## Development + +To set up the development environment: + +1. Clone the repository +2. Install development dependencies: `pip install -e .[dev]` +3. Install pre-commit hooks: `pre-commit install` + +To run tests: + +```bash +make tests +``` + +To run linters: + +```bash +make lint +``` + +## Changelog + +See the [CHANGELOG.md](CHANGELOG.md) file for details on all changes and past releases. diff --git a/VERSION.txt b/VERSION.txt index 1a03094..9767cc9 100644 --- a/VERSION.txt +++ b/VERSION.txt @@ -1 +1 @@ -0.1.9 +0.1.10 diff --git a/airflow_provider_hex/__init__.py b/airflow_provider_hex/__init__.py index c12a94e..c68e1ff 100644 --- a/airflow_provider_hex/__init__.py +++ b/airflow_provider_hex/__init__.py @@ -1,4 +1,5 @@ """Version information for the package.""" + import os import sys diff --git a/airflow_provider_hex/hooks/hex.py b/airflow_provider_hex/hooks/hex.py index 6ecaf87..8ca3008 100644 --- a/airflow_provider_hex/hooks/hex.py +++ b/airflow_provider_hex/hooks/hex.py @@ -7,6 +7,8 @@ from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook from importlib_metadata import PackageNotFoundError, version +from requests.exceptions import RequestException +from tenacity import retry, stop_after_attempt, wait_fixed from airflow_provider_hex.types import NotificationDetails, RunResponse, StatusResponse @@ -151,52 +153,74 @@ def run_project( ), ) - def run_status(self, project_id, run_id) -> StatusResponse: + @retry(stop=stop_after_attempt(3), wait=wait_fixed(1)) + def run_status(self, project_id: str, run_id: str) -> StatusResponse: endpoint = f"api/v1/project/{project_id}/run/{run_id}" method = "GET" + try: + response = self.run(method=method, endpoint=endpoint, data=None) + return cast(StatusResponse, response) + except RequestException as e: + self.log.error(f"API call failed: {str(e)}") + raise - return cast( - StatusResponse, self.run(method=method, endpoint=endpoint, data=None) - ) - - def cancel_run(self, project_id, run_id) -> str: + def cancel_run(self, project_id: str, run_id: str) -> str: endpoint = f"api/v1/project/{project_id}/run/{run_id}" method = "DELETE" self.run(method=method, endpoint=endpoint) return run_id - def run_and_poll( + def run_status_with_retries( + self, project_id: str, run_id: str, max_retries: int = 3, retry_delay: int = 1 + ) -> StatusResponse: + @retry(stop=stop_after_attempt(max_retries), wait=wait_fixed(retry_delay)) + def _run_status(): + return self.run_status(project_id, run_id) + + return _run_status() + + def poll_project_status( self, project_id: str, - inputs: Optional[dict], - update_cache: bool = False, + run_id: str, poll_interval: int = 3, poll_timeout: int = 600, kill_on_timeout: bool = True, - notifications: List[NotificationDetails] = [], - ): - run_response = self.run_project(project_id, inputs, update_cache, notifications) - run_id = run_response["runId"] - + max_poll_retries: int = 3, + poll_retry_delay: int = 5, + ) -> StatusResponse: poll_start = datetime.datetime.now() while True: - run_status = self.run_status(project_id, run_id) + try: + run_status = self.run_status_with_retries( + project_id, run_id, max_poll_retries, poll_retry_delay + ) + except Exception as e: + self.log.error( + f"Failed to get run status after {max_poll_retries} " + f"attempts: {str(e)}" + ) + if kill_on_timeout: + self.cancel_run(project_id, run_id) + raise AirflowException( + "Failed to get run status for project " + f"{project_id} with run: {run_id}" + ) + project_status = run_status["status"] self.log.info( f"Polling Hex Project {project_id}. Status: {project_status}." ) - if project_status not in VALID_STATUSES: - raise AirflowException(f"Unhandled status: {project_status}") if project_status == COMPLETE: - break + return run_status if project_status in TERMINAL_STATUSES: raise AirflowException( f"Project Run failed with status {project_status}. " - f"See Run URL for more info {run_response['runUrl']}" + f"See Run URL for more info {run_status['runUrl']}" ) if ( @@ -217,4 +241,28 @@ def run_and_poll( ) time.sleep(poll_interval) - return run_status + + def run_and_poll( + self, + project_id: str, + inputs: Optional[dict], + update_cache: bool = False, + poll_interval: int = 3, + poll_timeout: int = 600, + kill_on_timeout: bool = True, + notifications: List[NotificationDetails] = [], + max_poll_retries: int = 3, + poll_retry_delay: int = 5, + ): + run_response = self.run_project(project_id, inputs, update_cache, notifications) + run_id = run_response["runId"] + + return self.poll_project_status( + project_id, + run_id, + poll_interval, + poll_timeout, + kill_on_timeout, + max_poll_retries, + poll_retry_delay, + ) diff --git a/airflow_provider_hex/operators/hex.py b/airflow_provider_hex/operators/hex.py index c0e439e..e14cca5 100644 --- a/airflow_provider_hex/operators/hex.py +++ b/airflow_provider_hex/operators/hex.py @@ -2,7 +2,6 @@ from airflow.models import BaseOperator from airflow.models.dag import Context -from airflow.utils.decorators import apply_defaults from airflow_provider_hex.hooks.hex import HexHook from airflow_provider_hex.types import NotificationDetails @@ -41,7 +40,6 @@ class HexRunProjectOperator(BaseOperator): template_fields = ["project_id", "input_parameters"] ui_color = "#F5C0C0" - @apply_defaults def __init__( self, project_id: str, @@ -53,6 +51,8 @@ def __init__( input_parameters: Optional[Dict[str, Any]] = None, update_cache: bool = False, notifications: List[NotificationDetails] = [], + max_poll_retries: int = 3, + poll_retry_delay: int = 5, # Change this to 5 **kwargs, ) -> None: super().__init__(**kwargs) @@ -65,6 +65,8 @@ def __init__( self.input_parameters = input_parameters self.update_cache = update_cache self.notifications = notifications + self.max_poll_retries = max_poll_retries + self.poll_retry_delay = poll_retry_delay def execute(self, context: Context) -> Any: hook = HexHook(self.hex_conn_id) @@ -79,6 +81,8 @@ def execute(self, context: Context) -> Any: poll_timeout=self.timeout, kill_on_timeout=self.kill_on_timeout, notifications=self.notifications, + max_poll_retries=self.max_poll_retries, + poll_retry_delay=self.poll_retry_delay, ) self.log.info("Hex Project completed successfully") diff --git a/setup.cfg b/setup.cfg index ea06af0..9ca07b1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,7 +26,7 @@ install_requires = apache-airflow>=2.2.0 requests>=2 importlib-metadata>=4.8.1 - typing-extensions>=3.10.0.2 + typing-extensions>=4 zip_safe = false [options.extras_require] diff --git a/tests/conftest.py b/tests/conftest.py index fcc8b13..507372d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,10 @@ import datetime +import os import pendulum import pytest from airflow import DAG +from airflow.utils.db import initdb from airflow_provider_hex.operators.hex import HexRunProjectOperator @@ -20,11 +22,17 @@ def sample_conn(mocker): ) +@pytest.fixture(scope="session", autouse=True) +def init_airflow_db(): + os.environ["AIRFLOW__CORE__UNIT_TEST_MODE"] = "True" + initdb() + + @pytest.fixture() def dag(): with DAG( dag_id=TEST_DAG_ID, - schedule_interval="@daily", + schedule="@daily", start_date=DATA_INTERVAL_START, ) as dag: HexRunProjectOperator( @@ -32,6 +40,8 @@ def dag(): hex_conn_id="hex_conn", project_id="ABC-123", input_parameters={"input_date": "{{ ds }}"}, + max_poll_retries=3, + poll_retry_delay=1, ) return dag @@ -48,5 +58,7 @@ def fake_dag(): hex_conn_id="hex_conn", project_id="ABC-123", input_parameters={"input_date": "{{ ds }}"}, + max_poll_retries=3, + poll_retry_delay=1, ) return dag diff --git a/tests/hooks/test_hex_hook.py b/tests/hooks/test_hex_hook.py index d4af454..fa0c0fe 100644 --- a/tests/hooks/test_hex_hook.py +++ b/tests/hooks/test_hex_hook.py @@ -1,13 +1,12 @@ import logging import pytest -from airflow import AirflowException +from airflow.exceptions import AirflowException from airflow_provider_hex.hooks.hex import HexHook log = logging.getLogger(__name__) - mock_run = { "projectId": "abc-123", "runId": "1", @@ -118,12 +117,15 @@ def test_run_poll_pending_and_error(self, requests_mock): requests_mock.post( "https://www.httpbin.org/api/v1/project/abc-123/run", headers={"Content-Type": "application/json"}, - json=mock_run, + json={"projectId": "abc-123", "runId": "1"}, ) mock_status = {"projectId": "abc-123", "status": "PENDING"} - - mock_status_2 = {"projectId": "abc-123", "status": "UNABLE_TO_ALLOCATE_KERNEL"} + mock_status_2 = { + "projectId": "abc-123", + "status": "UNABLE_TO_ALLOCATE_KERNEL", + "runUrl": "https://example.com/run/1", + } header = {"Content-Type": "application/json"} requests_mock.register_uri( @@ -138,4 +140,96 @@ def test_run_poll_pending_and_error(self, requests_mock): hook = HexHook(hex_conn_id="hex_conn") with pytest.raises(AirflowException, match=r"Project Run failed with status.*"): - hook.run_and_poll("abc-123", inputs=None, poll_interval=1) + hook.run_and_poll( + "abc-123", + inputs=None, + poll_interval=1, + max_poll_retries=3, + poll_retry_delay=1, + ) + + # Check if the status endpoint was called multiple times + assert requests_mock.call_count == 3 # 1 POST + 2 GET requests + + def test_run_status_with_retries(self, requests_mock): + mock_status_error = {"projectId": "abc-123", "error": "Internal Server Error"} + mock_status_success = {"projectId": "abc-123", "status": "RUNNING"} + + header = {"Content-Type": "application/json"} + requests_mock.register_uri( + "GET", + "https://www.httpbin.org/api/v1/project/abc-123/run/1", + [ + {"headers": header, "json": mock_status_error, "status_code": 500}, + {"headers": header, "json": mock_status_error, "status_code": 500}, + {"headers": header, "json": mock_status_success, "status_code": 200}, + ], + ) + + hook = HexHook(hex_conn_id="hex_conn") + response = hook.run_status_with_retries( + "abc-123", "1", max_retries=3, retry_delay=1 + ) + + assert response["status"] == "RUNNING" + assert requests_mock.call_count == 3 + + def test_poll_project_status(self, requests_mock): + mock_status_pending = {"projectId": "abc-123", "status": "PENDING"} + mock_status_running = {"projectId": "abc-123", "status": "RUNNING"} + mock_status_completed = { + "projectId": "abc-123", + "status": "COMPLETED", + "runUrl": "https://example.com/run/1", + } + + header = {"Content-Type": "application/json"} + requests_mock.register_uri( + "GET", + "https://www.httpbin.org/api/v1/project/abc-123/run/1", + [ + {"headers": header, "json": mock_status_pending}, + {"headers": header, "json": mock_status_running}, + {"headers": header, "json": mock_status_completed}, + ], + ) + + hook = HexHook(hex_conn_id="hex_conn") + response = hook.poll_project_status( + "abc-123", + "1", + poll_interval=1, + poll_timeout=10, + max_poll_retries=3, + poll_retry_delay=1, + ) + + assert response["status"] == "COMPLETED" + assert requests_mock.call_count == 3 + + def test_poll_project_status_error(self, requests_mock): + requests_mock.get( + "https://www.httpbin.org/api/v1/project/abc-123/run/1", + [{"status_code": 500}] * 9, # 3 retries * 3 attempts + ) + + requests_mock.delete( + "https://www.httpbin.org/api/v1/project/abc-123/run/1", status_code=200 + ) + + hook = HexHook(hex_conn_id="hex_conn") + + with pytest.raises( + AirflowException, match="Failed to get run status for project" + ): + hook.poll_project_status( + "abc-123", + "1", + poll_interval=1, + poll_timeout=10, + kill_on_timeout=True, + max_poll_retries=3, + poll_retry_delay=1, + ) + + assert requests_mock.call_count == 10 # 9 GET requests + 1 DELETE request diff --git a/tests/operators/test_hex_operator.py b/tests/operators/test_hex_operator.py index 6b4b74d..6c83330 100644 --- a/tests/operators/test_hex_operator.py +++ b/tests/operators/test_hex_operator.py @@ -17,12 +17,17 @@ def test_my_custom_operator_execute_no_trigger(dag, requests_mock): json={"projectId": "ABC-123", "runId": "1"}, ) - mock_status = {"projectId": "abc-123", "status": "COMPLETED"} + mock_status = { + "projectId": "ABC-123", + "status": "COMPLETED", + "runUrl": "https://example.com/run/1", + } requests_mock.get( "https://www.httpbin.org/api/v1/project/abc-123/run/1", headers={"Content-Type": "application/json"}, json=mock_status, ) + dagrun = dag.create_dagrun( state=DagRunState.RUNNING, execution_date=DATA_INTERVAL_START, @@ -33,6 +38,7 @@ def test_my_custom_operator_execute_no_trigger(dag, requests_mock): ti = dagrun.get_task_instance(task_id=TEST_TASK_ID) ti.task = dag.get_task(task_id=TEST_TASK_ID) ti.run(ignore_ti_state=True) + assert ti.state == TaskInstanceState.SUCCESS json = requests_mock.request_history[0].json() assert json["inputParams"]["input_date"][0:4] == str(DATA_INTERVAL_START.year)