Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: PythonVirtualenvOperator crashes if any python_callable function is defined in the same source as DAG #37165

Merged
merged 9 commits into from
Feb 7, 2024
12 changes: 7 additions & 5 deletions airflow/models/dagbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations

import hashlib
import importlib
import importlib.machinery
import importlib.util
Expand Down Expand Up @@ -48,7 +47,12 @@
from airflow.utils import timezone
from airflow.utils.dag_cycle_tester import check_cycle
from airflow.utils.docs import get_docs_url
from airflow.utils.file import correct_maybe_zipped, list_py_file_paths, might_contain_dag
from airflow.utils.file import (
correct_maybe_zipped,
get_unique_dag_module_name,
list_py_file_paths,
might_contain_dag,
)
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.retries import MAX_DB_RETRIES, run_with_db_retries
from airflow.utils.session import NEW_SESSION, provide_session
Expand Down Expand Up @@ -326,9 +330,7 @@ def _load_modules_from_file(self, filepath, safe_mode):
return []

self.log.debug("Importing %s", filepath)
path_hash = hashlib.sha1(filepath.encode("utf-8")).hexdigest()
org_mod_name = Path(filepath).stem
mod_name = f"unusual_prefix_{path_hash}_{org_mod_name}"
mod_name = get_unique_dag_module_name(filepath)

if mod_name in sys.modules:
del sys.modules[mod_name]
Expand Down
23 changes: 15 additions & 8 deletions airflow/operators/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from airflow.operators.branch import BranchMixIn
from airflow.utils import hashlib_wrapper
from airflow.utils.context import context_copy_partial, context_merge
from airflow.utils.file import get_unique_dag_module_name
from airflow.utils.operator_helpers import KeywordParameters
from airflow.utils.process_utils import execute_in_subprocess
from airflow.utils.python_virtualenv import prepare_virtualenv, write_python_script
Expand Down Expand Up @@ -437,15 +438,21 @@ def _execute_python_callable_in_subprocess(self, python_path: Path):

self._write_args(input_path)
self._write_string_args(string_args_path)

jinja_context = {
"op_args": self.op_args,
"op_kwargs": op_kwargs,
"expect_airflow": self.expect_airflow,
"pickling_library": self.pickling_library.__name__,
"python_callable": self.python_callable.__name__,
"python_callable_source": self.get_python_source(),
}

if inspect.getfile(self.python_callable) == self.dag.fileloc:
jinja_context["modified_dag_module_name"] = get_unique_dag_module_name(self.dag.fileloc)

write_python_script(
jinja_context={
"op_args": self.op_args,
"op_kwargs": op_kwargs,
"expect_airflow": self.expect_airflow,
"pickling_library": self.pickling_library.__name__,
"python_callable": self.python_callable.__name__,
"python_callable_source": self.get_python_source(),
},
jinja_context=jinja_context,
filename=os.fspath(script_path),
render_template_as_native_obj=self.dag.render_template_as_native_obj,
)
Expand Down
12 changes: 12 additions & 0 deletions airflow/utils/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

import ast
import hashlib
import logging
import os
import zipfile
Expand All @@ -33,6 +34,8 @@

log = logging.getLogger(__name__)

MODIFIED_DAG_MODULE_NAME = "unusual_prefix_{path_hash}_{module_name}"


class _IgnoreRule(Protocol):
"""Interface for ignore rules for structural subtyping."""
Expand Down Expand Up @@ -379,3 +382,12 @@ def iter_airflow_imports(file_path: str) -> Generator[str, None, None]:
for m in _find_imported_modules(parsed):
if m.startswith("airflow."):
yield m


def get_unique_dag_module_name(file_path: str) -> str:
"""Returns a unique module name in the format unusual_prefix_{sha1 of module's file path}_{original module name}."""
if isinstance(file_path, str):
path_hash = hashlib.sha1(file_path.encode("utf-8")).hexdigest()
org_mod_name = Path(file_path).stem
return MODIFIED_DAG_MODULE_NAME.format(path_hash=path_hash, module_name=org_mod_name)
raise ValueError("file_path should be a string to generate unique module name")
rawwar marked this conversation as resolved.
Show resolved Hide resolved
19 changes: 17 additions & 2 deletions airflow/utils/python_virtualenv_script.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,22 @@ if sys.version_info >= (3,6):
pass
{% endif %}

# Script
{{ python_callable_source }}

# monkey patching for the cases when python_callable is part of the dag module.
{% if modified_dag_module_name is defined %}

import types

{{ modified_dag_module_name }} = types.ModuleType("{{ modified_dag_module_name }}")

{{ modified_dag_module_name }}.{{ python_callable }} = {{ python_callable }}

sys.modules["{{modified_dag_module_name}}"] = {{modified_dag_module_name}}

{% endif%}

{% if op_args or op_kwargs %}
with open(sys.argv[1], "rb") as file:
arg_dict = {{ pickling_library }}.load(file)
Expand All @@ -47,8 +63,7 @@ with open(sys.argv[3], "r") as file:
virtualenv_string_args = list(map(lambda x: x.strip(), list(file)))
{% endif %}

# Script
{{ python_callable_source }}

try:
res = {{ python_callable }}(*arg_dict["args"], **arg_dict["kwargs"])
except Exception as e:
Expand Down