diff --git a/pytest_factoryboy/fixture.py b/pytest_factoryboy/fixture.py index e94b039..e5f2a9a 100644 --- a/pytest_factoryboy/fixture.py +++ b/pytest_factoryboy/fixture.py @@ -19,7 +19,7 @@ import inflection from typing_extensions import ParamSpec, TypeAlias -from .compat import PostGenerationContext +from .compat import PostGenerationContext, PytestFixtureT from .fixturegen import create_fixture if TYPE_CHECKING: @@ -136,7 +136,7 @@ def generate_fixtures( factory_name: str, overrides: Mapping[str, Any], caller_locals: Box[Mapping[str, Any]], -) -> Iterable[tuple[str, Callable[..., Any]]]: +) -> Iterable[tuple[str, Callable[..., object]]]: """Generate all the FixtureDefs for the given factory class.""" related: list[str] = [] @@ -153,16 +153,18 @@ def generate_fixtures( ), ) + deps = get_deps(factory_class, model_name=model_name) if factory_name not in caller_locals.value: yield ( factory_name, create_fixture_with_related( name=factory_name, function=functools.partial(factory_fixture, factory_class=factory_class), + fixtures=deps, + # TODO: related too? ), ) - deps = get_deps(factory_class, model_name=model_name) yield ( model_name, create_fixture_with_related( @@ -176,10 +178,10 @@ def generate_fixtures( def create_fixture_with_related( name: str, - function: Callable[P, T], + function: Callable[..., object], fixtures: Collection[str] | None = None, related: Collection[str] | None = None, -) -> Callable[P, T]: +) -> PytestFixtureT: if related is None: related = [] fixture, fn = create_fixture(name=name, function=function, fixtures=fixtures) @@ -195,7 +197,7 @@ def make_declaration_fixturedef( value: Any, factory_class: FactoryType, related: list[str], -) -> Callable[..., Any]: +) -> Callable[..., object]: """Create the FixtureDef for a factory declaration.""" if isinstance(value, (factory.SubFactory, factory.RelatedFactory)): subfactory_class = value.get_factory() @@ -345,6 +347,7 @@ def model_fixture(request: SubRequest, factory_name: str) -> Any: fixture_name = request.fixturename prefix = "".join((fixture_name, SEPARATOR)) + # TODO: This should be a dependency of the current fixture (i.e. use `usefixtures`) factory_class: FactoryType = request.getfixturevalue(factory_name) # Create model fixture instance @@ -478,7 +481,20 @@ def deferred_impl(request: SubRequest) -> Any: def factory_fixture(request: SubRequest, factory_class: F) -> F: """Factory fixture implementation.""" - return factory_class + fixture_name = request.fixturename + # TODO: Not good to check the fixture name, we should know what to expect (via args?) + assert fixture_name.endswith("_factory") + fixture_name = fixture_name[: -len("_factory")] + prefix = "".join((fixture_name, SEPARATOR)) + + # TODO: copy-paste from model_fixture; refactor + kwargs = {} + for key in factory_class._meta.pre_declarations: + argname = "".join((prefix, key)) + if argname in request._fixturedef.argnames: + kwargs[key] = evaluate(request, request.getfixturevalue(argname)) + + return type(f"{factory_class.__name__}Fixture", (factory_class,), kwargs) def attr_fixture(request: SubRequest, value: T) -> T: diff --git a/tests/test_factory_fixtures.py b/tests/test_factory_fixtures.py index 202973d..3b7767e 100644 --- a/tests/test_factory_fixtures.py +++ b/tests/test_factory_fixtures.py @@ -111,7 +111,7 @@ class Meta: def test_factory(book_factory) -> None: """Test model factory fixture.""" - assert book_factory == BookFactory + assert issubclass(book_factory, BookFactory) def test_model(book: Book): diff --git a/tests/test_foo.py b/tests/test_foo.py new file mode 100644 index 0000000..657f9d2 --- /dev/null +++ b/tests/test_foo.py @@ -0,0 +1,30 @@ +# TODO: Improve tests +# TODO: Change test module + +from dataclasses import dataclass + +import factory +import pytest + +from pytest_factoryboy import register + + +@dataclass +class Book: + name: str + + +@register +class BookFactory(factory.Factory): + class Meta: + model = Book + + name = "foo" + + +@pytest.mark.parametrize("book__name", ["bar"]) +def test_book_initialise_later(book_factory, book__name, book): + assert book.name == "bar" + + book_f = book_factory() + assert book_f.name == "bar"