From 041fcf952c0721804e47c0a8ae4999de5eb683cc Mon Sep 17 00:00:00 2001 From: Anton Agestam Date: Sun, 18 Aug 2024 13:34:24 +0200 Subject: [PATCH] fix: Support extracting generic bound from TypeVar (#88) Adds support for Pydantic models that are generic with a currency type. Fixes #87. --- goose.yaml | 14 +-- src/immoney/_pydantic.py | 11 ++- tests/test_pydantic.py | 196 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 213 insertions(+), 8 deletions(-) diff --git a/goose.yaml b/goose.yaml index 265b71c..beed38c 100644 --- a/goose.yaml +++ b/goose.yaml @@ -28,12 +28,14 @@ environments: - prettier hooks: - - id: check-manifest - environment: check-manifest - command: check-manifest - parameterize: false - read_only: true - args: [--no-build-isolation] + # Commented out until there's a fix for pinning setuptools. + # https://github.com/antonagestam/goose/issues/30 + # - id: check-manifest + # environment: check-manifest + # command: check-manifest + # parameterize: false + # read_only: true + # args: [--no-build-isolation] - id: prettier environment: node diff --git a/src/immoney/_pydantic.py b/src/immoney/_pydantic.py index 05d65f4..41808b1 100644 --- a/src/immoney/_pydantic.py +++ b/src/immoney/_pydantic.py @@ -39,9 +39,16 @@ class OverdraftDict(TypedDict): def extract_currency_type_arg(source_type: type) -> type[Currency]: match get_args(source_type): - case (type() as currency_type,): - assert issubclass(currency_type, Currency) + case (type() as currency_type,) if issubclass(currency_type, Currency): return currency_type + # TypeVar with Currency bound. + case (TypeVar(__bound__=type() as currency_type),) if issubclass( + currency_type, Currency + ): + return currency_type + # TypeVar without bound. + case (TypeVar(__bound__=None),): + return Currency case invalid: # pragma: no cover raise TypeError(f"Invalid type args: {invalid!r}.") diff --git a/tests/test_pydantic.py b/tests/test_pydantic.py index 086d69f..a98e1c4 100644 --- a/tests/test_pydantic.py +++ b/tests/test_pydantic.py @@ -1,5 +1,7 @@ import json from fractions import Fraction +from typing import Generic +from typing import TypeVar import pytest from pydantic import BaseModel @@ -263,6 +265,200 @@ def test_can_generate_schema(self) -> None: } +C_bounded = TypeVar("C_bounded", bound=Currency) + + +class BoundedGenericMoneyModel(BaseModel, Generic[C_bounded]): + money: Money[C_bounded] + + +class TestBoundedGenericMoneyModel: + @pytest.mark.parametrize( + ("subunits", "currency_code", "expected"), + ( + (4990, "USD", USD("49.90")), + (4990, "EUR", EUR("49.90")), + (0, "NOK", NOK(0)), + ), + ) + def test_can_roundtrip_valid_data( + self, + subunits: int, + currency_code: str, + expected: Money[C_bounded], + ) -> None: + data = { + "money": { + "subunits": subunits, + "currency": currency_code, + } + } + + instance = BoundedGenericMoneyModel[C_bounded].model_validate(data) + assert instance.money == expected + assert json.loads(instance.model_dump_json()) == data + + def test_parsing_raises_validation_error_for_negative_value(self) -> None: + with pytest.raises( + ValidationError, + match=r"Input should be greater than or equal to 0", + ): + BoundedGenericMoneyModel.model_validate( + { + "money": { + "currency": "EUR", + "subunits": -1, + }, + } + ) + + def test_parsing_raises_validation_error_for_invalid_currency(self) -> None: + with pytest.raises( + ValidationError, + match=r"Input should be.*\[type=literal_error", + ): + BoundedGenericMoneyModel.model_validate( + { + "money": { + "currency": "JCN", + "subunits": 4990, + }, + } + ) + + def test_can_instantiate_valid_value(self) -> None: + instance = BoundedGenericMoneyModel(money=USD("49.90")) + assert instance.money == USD("49.90") + + def test_instantiation_raises_validation_error_for_invalid_currency(self) -> None: + with pytest.raises(ValidationError, match=r"Currency is not registered"): + BoundedGenericMoneyModel(money=JCN(1)) + + def test_can_generate_schema(self) -> None: + assert BoundedGenericMoneyModel.model_json_schema() == { + "properties": { + "money": { + "properties": { + "currency": { + "enum": sorted_items_equal(default_registry.keys()), + "title": "Currency", + "type": "string", + }, + "subunits": { + "minimum": 0, + "title": "Subunits", + "type": "integer", + }, + }, + "required": sorted_items_equal(["subunits", "currency"]), + "title": "Money", + "type": "object", + }, + }, + "required": ["money"], + "title": BoundedGenericMoneyModel.__name__, + "type": "object", + } + + +C_unbound = TypeVar("C_unbound") + + +class UnboundGenericMoneyModel(BaseModel, Generic[C_unbound]): + # mypy rightfully errors here, demanding that the type var is bounded to + # Currency, but we still want to test this case. + money: Money[C_unbound] # type: ignore[type-var] + + +class TestUnboundGenericMoneyModel: + @pytest.mark.parametrize( + ("subunits", "currency_code", "expected"), + ( + (4990, "USD", USD("49.90")), + (4990, "EUR", EUR("49.90")), + (0, "NOK", NOK(0)), + ), + ) + def test_can_roundtrip_valid_data( + self, + subunits: int, + currency_code: str, + expected: Money[C_unbound], # type: ignore[type-var] + ) -> None: + data = { + "money": { + "subunits": subunits, + "currency": currency_code, + } + } + + instance = UnboundGenericMoneyModel[C_unbound].model_validate(data) + assert instance.money == expected + assert json.loads(instance.model_dump_json()) == data + + def test_parsing_raises_validation_error_for_negative_value(self) -> None: + with pytest.raises( + ValidationError, + match=r"Input should be greater than or equal to 0", + ): + UnboundGenericMoneyModel.model_validate( + { + "money": { + "currency": "EUR", + "subunits": -1, + }, + } + ) + + def test_parsing_raises_validation_error_for_invalid_currency(self) -> None: + with pytest.raises( + ValidationError, + match=r"Input should be.*\[type=literal_error", + ): + UnboundGenericMoneyModel.model_validate( + { + "money": { + "currency": "JCN", + "subunits": 4990, + }, + } + ) + + def test_can_instantiate_valid_value(self) -> None: + instance = UnboundGenericMoneyModel(money=USD("49.90")) + assert instance.money == USD("49.90") + + def test_instantiation_raises_validation_error_for_invalid_currency(self) -> None: + with pytest.raises(ValidationError, match=r"Currency is not registered"): + UnboundGenericMoneyModel(money=JCN(1)) + + def test_can_generate_schema(self) -> None: + assert UnboundGenericMoneyModel.model_json_schema() == { + "properties": { + "money": { + "properties": { + "currency": { + "enum": sorted_items_equal(default_registry.keys()), + "title": "Currency", + "type": "string", + }, + "subunits": { + "minimum": 0, + "title": "Subunits", + "type": "integer", + }, + }, + "required": sorted_items_equal(["subunits", "currency"]), + "title": "Money", + "type": "object", + }, + }, + "required": ["money"], + "title": UnboundGenericMoneyModel.__name__, + "type": "object", + } + + class SpecializedMoneyModel(BaseModel): money: Money[USDType]