From 8838a11f584886d30e900f6e5bc8f6a7c9e5fea9 Mon Sep 17 00:00:00 2001 From: Bhargav Dodla Date: Wed, 7 Aug 2024 13:04:48 -0700 Subject: [PATCH] fix: Using get_type_hints instead of inspect for udf return type --- sdk/python/feast/on_demand_feature_view.py | 4 +- .../infra/scaffolding/test_repo_operations.py | 56 ++++++++++++++++++- 2 files changed, 57 insertions(+), 3 deletions(-) diff --git a/sdk/python/feast/on_demand_feature_view.py b/sdk/python/feast/on_demand_feature_view.py index 586f5d1bac9..aeb1cc207a1 100644 --- a/sdk/python/feast/on_demand_feature_view.py +++ b/sdk/python/feast/on_demand_feature_view.py @@ -3,7 +3,7 @@ import inspect import warnings from types import FunctionType -from typing import Any, Optional, Union +from typing import Any, Optional, Union, get_type_hints import dill import pandas as pd @@ -631,7 +631,7 @@ def mainify(obj) -> None: obj.__module__ = "__main__" def decorator(user_function): - return_annotation = inspect.signature(user_function).return_annotation + return_annotation = get_type_hints(user_function).get("return", inspect._empty) udf_string = dill.source.getsource(user_function) mainify(user_function) if mode == "pandas": diff --git a/sdk/python/tests/unit/infra/scaffolding/test_repo_operations.py b/sdk/python/tests/unit/infra/scaffolding/test_repo_operations.py index aa4ff1c40f7..2d4972080aa 100644 --- a/sdk/python/tests/unit/infra/scaffolding/test_repo_operations.py +++ b/sdk/python/tests/unit/infra/scaffolding/test_repo_operations.py @@ -1,3 +1,5 @@ +import os +import tempfile from contextlib import contextmanager from pathlib import Path from tempfile import TemporaryDirectory @@ -6,7 +8,13 @@ import assertpy -from feast.repo_operations import get_ignore_files, get_repo_files, read_feastignore +from feast.repo_operations import ( + get_ignore_files, + get_repo_files, + parse_repo, + read_feastignore, +) +from tests.utils.cli_repo_creator import CliRunner @contextmanager @@ -140,3 +148,49 @@ def test_feastignore_with_stars2(): (repo_root / "foo1/c.py").resolve(), ] ) + + +def test_parse_repo(): + "Test to ensure that the repo is parsed correctly" + runner = CliRunner() + with tempfile.TemporaryDirectory(dir=os.getcwd()) as temp_dir: + # Make sure the path is absolute by resolving any symlinks + temp_path = Path(temp_dir).resolve() + result = runner.run(["init", "my_project"], cwd=temp_path) + repo_path = Path(temp_path / "my_project" / "feature_repo") + assert result.returncode == 0 + + repo_contents = parse_repo(repo_path) + + assert len(repo_contents.data_sources) == 3 + assert len(repo_contents.feature_views) == 2 + assert len(repo_contents.on_demand_feature_views) == 2 + assert len(repo_contents.stream_feature_views) == 0 + assert len(repo_contents.entities) == 2 + assert len(repo_contents.feature_services) == 3 + + +def test_parse_repo_with_future_annotations(): + "Test to ensure that the repo is parsed correctly when using future annotations" + runner = CliRunner() + with tempfile.TemporaryDirectory(dir=os.getcwd()) as temp_dir: + # Make sure the path is absolute by resolving any symlinks + temp_path = Path(temp_dir).resolve() + result = runner.run(["init", "my_project"], cwd=temp_path) + repo_path = Path(temp_path / "my_project" / "feature_repo") + assert result.returncode == 0 + + with open(repo_path / "example_repo.py", "r") as f: + existing_content = f.read() + + with open(repo_path / "example_repo.py", "w") as f: + f.write("from __future__ import annotations" + "\n" + existing_content) + + repo_contents = parse_repo(repo_path) + + assert len(repo_contents.data_sources) == 3 + assert len(repo_contents.feature_views) == 2 + assert len(repo_contents.on_demand_feature_views) == 2 + assert len(repo_contents.stream_feature_views) == 0 + assert len(repo_contents.entities) == 2 + assert len(repo_contents.feature_services) == 3