Skip to content

Commit

Permalink
Merge pull request #9 from mscarey/pydantic2
Browse files Browse the repository at this point in the history
Pydantic2
  • Loading branch information
mscarey authored Dec 4, 2023
2 parents aa37239 + ac21b22 commit 23f6ec0
Show file tree
Hide file tree
Showing 17 changed files with 138 additions and 151 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9.7", "3.10.0"]
python-version: ["3.11", "3.12"]

steps:
- uses: actions/checkout@v2
Expand Down
4 changes: 0 additions & 4 deletions docs/api/quantities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@ Quantity Ranges
:members:
:special-members:

.. autoclass:: nettlesome.quantities.IntRange
:members:
:special-members:

.. autoclass:: nettlesome.quantities.DecimalRange
:members:
:special-members:
Expand Down
2 changes: 1 addition & 1 deletion nettlesome/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .entities import Entity
from .groups import FactorGroup
from .predicates import Predicate
from .quantities import Comparison, DateRange, IntRange, DecimalRange, UnitRange
from .quantities import Comparison, DateRange, DecimalRange, UnitRange
from .statements import Statement, Assertion

__version__ = "0.6.1"
17 changes: 9 additions & 8 deletions nettlesome/entities.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
""":class:`.Comparable` subclass for things that can be referenced in a Statement."""

from __future__ import annotations
from typing import ClassVar, Dict, Optional, Tuple
from typing import ClassVar, Optional, Tuple

from pydantic import BaseModel, Extra, root_validator
from pydantic import BaseModel, model_validator
from pydantic import ValidationError

from nettlesome.terms import Comparable, ContextRegister, Term


class Entity(Term, BaseModel, extra=Extra.forbid):
class Entity(Term, BaseModel, extra="forbid"):
r"""
Things that can be referenced in a Statement.
Expand Down Expand Up @@ -39,11 +39,12 @@ class Entity(Term, BaseModel, extra=Extra.forbid):
plural: bool = False
context_factor_names: ClassVar[Tuple[str, ...]] = ()

@root_validator(pre=True)
def validate_type(cls, values) -> Dict:
name = values.pop("type", None)
if name and name.lower() != cls.__name__.lower():
raise ValidationError(f"Expected type {cls.__name__}, not {name}")
@model_validator(mode="before")
def validate_type(cls, values) -> dict:
if isinstance(values, dict):
name = values.pop("type", None)
if name and name.lower() != cls.__name__.lower():
raise ValidationError(f"Expected type {cls.__name__}, not {name}")
return values

def __str__(self):
Expand Down
4 changes: 2 additions & 2 deletions nettlesome/predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing import Any, Dict, Mapping
from typing import List, Optional, Sequence, Set, Tuple

from pydantic import BaseModel, Extra
from pydantic import BaseModel

from nettlesome.terms import Comparable, TermSequence
from nettlesome.terms import Term
Expand Down Expand Up @@ -404,7 +404,7 @@ def _add_truth_to_content(self, content: str) -> str:
return f"{truth_prefix}{content}"


class Predicate(PhraseABC, BaseModel, extra=Extra.forbid):
class Predicate(PhraseABC, BaseModel, extra="forbid"):
r"""
A statement about real events or about a legal conclusion.
Expand Down
113 changes: 41 additions & 72 deletions nettlesome/quantities.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Any, ClassVar, Dict, Optional, Union

from pint import UnitRegistry, Quantity
from pydantic import BaseModel, root_validator, validator
from pydantic import BaseModel, field_validator, model_validator, validator
import sympy
from sympy import Eq, Interval, oo, S
from sympy.sets import EmptySet, FiniteSet
Expand Down Expand Up @@ -91,15 +91,15 @@ class QuantityRange(BaseModel):
}
normalized_comparisons: ClassVar[Dict[str, str]] = {"=": "==", "<>": "!="}

@validator("sign")
def _check_sign(cls, sign):
if sign in cls.normalized_comparisons:
sign = cls.normalized_comparisons[sign]
if sign not in cls.opposite_comparisons.keys():
@field_validator("sign", mode="after")
def check_sign(cls, v: str) -> str:
if v in cls.normalized_comparisons:
v = cls.normalized_comparisons[v]
if v not in cls.opposite_comparisons.keys():
raise ValueError(
f'"sign" string parameter must be one of {cls.opposite_comparisons.keys()}, not {sign}.'
f'"sign" string parameter must be one of {cls.opposite_comparisons.keys()}, not {v}.'
)
return sign
return v

@property
def _include_negatives(self) -> bool:
Expand Down Expand Up @@ -202,7 +202,7 @@ def reverse_meaning(self) -> None:
"""
Change self.sign in place to reverse the range of numbers covered.
>>> quantity_range = UnitRange(quantity="100 meters", sign=">")
>>> quantity_range = UnitRange(quantity_magnitude="100", quantity_units="meters", sign=">")
>>> str(quantity_range)
'greater than 100 meter'
>>> quantity_range.reverse_meaning()
Expand All @@ -220,31 +220,24 @@ def _quantity_string(self) -> str:
class UnitRange(QuantityRange, BaseModel):
"""A range defined relative to a pint Quantity."""

quantity: str
quantity_magnitude: Decimal
quantity_units: str
sign: str = "=="
include_negatives: Optional[bool] = None

@validator("quantity", pre=True)
def validate_quantity(cls, quantity: Union[Quantity, str]) -> str:
"""Validate that quantity is a pint Quantity."""
if isinstance(quantity, str):
quantity = Q_(quantity)
if not isinstance(quantity, Quantity):
raise TypeError(
f"quantity must be a pint Quantity or string, not {type(quantity)}"
)
return str(quantity)
@property
def q(self) -> Quantity:
return Quantity(self.quantity_magnitude, self.quantity_units)

@property
def quantity(self) -> Quantity:
return self.q

@property
def domain(self) -> sympy.Set:
"""Get the domain of the quantity range."""
return S.Reals

@property
def pint_quantity(self) -> Quantity:
"""Get the Quantity as a Pint object."""
return Q_(self.quantity)

@property
def magnitude(self) -> Union[int, float]:
"""Get magnitude of pint Quantity."""
Expand Down Expand Up @@ -315,11 +308,7 @@ def means(self, other: Any) -> bool:
return other_interval.is_subset(self.interval)

def _quantity_string(self) -> str:
return super()._quantity_string() + str(self.quantity)

@property
def q(self) -> Quantity:
return Q_(self.quantity)
return super()._quantity_string() + str(self.q)


class DateRange(QuantityRange, BaseModel):
Expand Down Expand Up @@ -347,31 +336,6 @@ def _quantity_string(self) -> str:
return str(self.quantity)


class IntRange(QuantityRange, BaseModel):
"""A range defined relative to an integer or float."""

quantity: int
sign: str = "=="
include_negatives: Optional[bool] = None

@property
def domain(self) -> sympy.Set:
"""Set domain as natural numbers."""
return S.Naturals0

@property
def magnitude(self) -> Union[int, float]:
"""Return quantity attribute."""
return self.quantity

def _quantity_string(self) -> str:
return str(self.quantity)

@property
def q(self) -> int:
return self.quantity


class DecimalRange(QuantityRange, BaseModel):
"""A range defined relative to an integer or float."""

Expand Down Expand Up @@ -459,14 +423,20 @@ class Comparison(BaseModel, PhraseABC):
"""

content: str
quantity_range: Union[DecimalRange, IntRange, UnitRange, DateRange]
quantity_range: Union[DecimalRange, UnitRange, DateRange]
truth: Optional[bool] = True

@root_validator(pre=True)
@model_validator(mode="before")
def set_quantity_range(cls, values):
"""Reverse the sign of a Comparison if necessary."""
if not values.get("quantity_range"):
quantity = cls.expression_to_quantity(values.pop("expression", None))
try:
quantity = cls.expression_to_quantity(values.pop("expression", None))
except AttributeError:
raise ValueError(
"A Comparison must have a quantity_range, "
"a quantity, or an expression."
)
sign = values.pop("sign", "==")
include_negatives = values.pop("include_negatives", None)
if isinstance(quantity, date):
Expand All @@ -476,15 +446,12 @@ def set_quantity_range(cls, values):
include_negatives=include_negatives,
)
elif isinstance(quantity, (str, Quantity)):
if isinstance(quantity, str):
quantity = Q_(quantity)
values["quantity_range"] = UnitRange(
sign=sign,
quantity=str(quantity),
include_negatives=include_negatives,
)
elif isinstance(quantity, int):
values["quantity_range"] = IntRange(
sign=sign,
quantity=quantity,
quantity_magnitude=Decimal(quantity.magnitude),
quantity_units=str(quantity.units),
include_negatives=include_negatives,
)
else:
Expand All @@ -498,7 +465,7 @@ def set_quantity_range(cls, values):
values["quantity_range"].reverse_meaning()
return values

@validator("content")
@field_validator("content")
def content_ends_with_was(cls, content: str) -> str:
"""Ensure content ends with 'was'."""
if content.endswith("were"):
Expand All @@ -514,7 +481,7 @@ def content_ends_with_was(cls, content: str) -> str:
@classmethod
def expression_to_quantity(
cls, value: Union[date, float, int, str]
) -> Union[date, float, int, Quantity]:
) -> Union[date, Decimal, str]:
r"""
Create numeric expression from text for Comparison class.
Expand All @@ -531,8 +498,10 @@ def expression_to_quantity(
"""
if isinstance(value, Quantity):
return str(value)
if isinstance(value, (int, float, date)):
if isinstance(value, date):
return value
if isinstance(value, (int, float)):
return Decimal(value)
quantity = value.strip()

try:
Expand All @@ -541,12 +510,12 @@ def expression_to_quantity(
pass

if quantity.isdigit():
return int(quantity)
return Decimal(quantity)
float_parts = quantity.split(".")
if len(float_parts) == 2 and all(
substring.isnumeric() for substring in float_parts
):
return float(quantity)
return Decimal(quantity)
return str(Q_(quantity))

@property
Expand All @@ -559,7 +528,7 @@ def interval(self) -> Union[FiniteSet, Interval, sympy.Union]:
... sign=">=",
... expression="10 grams")
>>> weight.interval
Interval(10, oo)
Interval(10.0000000000000, oo)
"""
return self.quantity_range.interval
Expand Down Expand Up @@ -696,4 +665,4 @@ def __str__(self):
return self._add_truth_to_content(self.content)


Comparison.update_forward_refs()
Comparison.model_rebuild()
38 changes: 24 additions & 14 deletions nettlesome/statements.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,15 @@
import operator

from typing import ClassVar, Dict, Iterator, List, Mapping
from typing import Optional, Sequence, Tuple, Union

from pydantic import BaseModel, validator, root_validator
from typing import Optional, Self, Sequence, Tuple, Union

from pydantic import (
BaseModel,
validator,
field_validator,
model_validator,
root_validator,
)
from slugify import slugify

from nettlesome.terms import (
Expand Down Expand Up @@ -63,9 +69,9 @@ class Statement(Factor, BaseModel):
absent: bool = False
generic: bool = False

@root_validator(pre=True)
@model_validator(mode="before")
def move_truth_to_predicate(cls, values):
if isinstance(values["predicate"], str):
if isinstance(values.get("predicate"), str):
values["predicate"] = Predicate(content=values["predicate"])
if "truth" in values:
values["predicate"].truth = values["truth"]
Expand All @@ -76,25 +82,29 @@ def move_truth_to_predicate(cls, values):
].template.get_term_sequence_from_mapping(values["terms"])
if not values.get("terms"):
values["terms"] = []
elif isinstance(values["terms"], Term):
elif isinstance(values.get("terms"), Term):
values["terms"] = [values["terms"]]
return values

@validator("terms")
def _validate_terms(cls, v, values, **kwargs):
@field_validator("terms")
@classmethod
def validate_terms(cls, v):
"""Normalize ``terms`` to initialize Statement."""

# make TermSequence for validation, then ignore it
TermSequence.validate_terms(v)
return v

if len(v) != len(values["predicate"]):
@model_validator(mode="after")
def validate_terms_for_predicate(self) -> Self:
if self.predicate and len(self.terms) != len(self.predicate):
message = (
"The number of items in 'terms' must be "
+ f"{len(values['predicate'])}, not {len(v)}, "
+ f"to match predicate.context_slots for '{values['predicate']}'"
+ f"{len(self.predicate)}, not {len(self.terms)}, "
+ f"to match predicate.context_slots for '{self.predicate}'"
)
raise ValueError(message)
return v
return self

@property
def term_sequence(self) -> TermSequence:
Expand Down Expand Up @@ -285,5 +295,5 @@ def __str__(self):
return formatted


Statement.update_forward_refs()
Assertion.update_forward_refs()
Statement.model_rebuild()
Assertion.model_rebuild()
Loading

0 comments on commit 23f6ec0

Please sign in to comment.