diff --git a/examples/albert/utils.py b/examples/albert/utils.py index 4d5d0a9b9..3aa2cccf1 100644 --- a/examples/albert/utils.py +++ b/examples/albert/utils.py @@ -1,16 +1,16 @@ from typing import Dict, List, Tuple -from pydantic import BaseModel, StrictFloat, confloat, conint +from pydantic import StrictFloat, confloat, conint from hivemind.dht.crypto import RSASignatureValidator -from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator +from hivemind.dht.schema import BytesWithPublicKey, ExtendedBaseModel, SchemaValidator from hivemind.dht.validation import RecordValidatorBase from hivemind.utils.logging import get_logger logger = get_logger(__name__) -class LocalMetrics(BaseModel): +class LocalMetrics(ExtendedBaseModel): step: conint(ge=0, strict=True) samples_per_second: confloat(ge=0.0, strict=True) samples_accumulated: conint(ge=0, strict=True) @@ -18,7 +18,7 @@ class LocalMetrics(BaseModel): mini_steps: conint(ge=0, strict=True) -class MetricSchema(BaseModel): +class MetricSchema(ExtendedBaseModel): metrics: Dict[BytesWithPublicKey, LocalMetrics] diff --git a/hivemind/dht/schema.py b/hivemind/dht/schema.py index 587afd1fb..4bb7f6c51 100644 --- a/hivemind/dht/schema.py +++ b/hivemind/dht/schema.py @@ -2,6 +2,7 @@ from typing import Any, Dict, Optional, Type import pydantic +from pydantic_core import CoreSchema, core_schema from hivemind.dht.crypto import RSASignatureValidator from hivemind.dht.protocol import DHTProtocol @@ -12,13 +13,18 @@ logger = get_logger(__name__) +class ExtendedBaseModel(pydantic.BaseModel): + class Config: + arbitrary_types_allowed = True + + class SchemaValidator(RecordValidatorBase): """ Restricts specified DHT keys to match a Pydantic schema. This allows to enforce types, min/max values, require a subkey to contain a public key, etc. """ - def __init__(self, schema: Type[pydantic.BaseModel], allow_extra_keys: bool = True, prefix: Optional[str] = None): + def __init__(self, schema: Type[ExtendedBaseModel], allow_extra_keys: bool = True, prefix: Optional[str] = None): """ :param schema: The Pydantic model (a subclass of pydantic.BaseModel). @@ -37,22 +43,23 @@ def __init__(self, schema: Type[pydantic.BaseModel], allow_extra_keys: bool = Tr :param prefix: (optional) Add ``prefix + '_'`` to the names of all schema fields. """ - self._patch_schema(schema) + # self._patch_schema(schema) self._schemas = [schema] self._key_id_to_field_name = {} - for field in schema.__fields__.values(): - raw_key = f"{prefix}_{field.name}" if prefix is not None else field.name - self._key_id_to_field_name[DHTID.generate(source=raw_key).to_bytes()] = field.name + for field_name in schema.__fields__.keys(): + raw_key = f"{prefix}_{field_name}" if prefix is not None else field_name + self._key_id_to_field_name[DHTID.generate(source=raw_key).to_bytes()] = field_name self._allow_extra_keys = allow_extra_keys - @staticmethod - def _patch_schema(schema: pydantic.BaseModel): - # We set required=False because the validate() interface provides only one key at a time - for field in schema.__fields__.values(): - field.required = False + # // required was changed to is_required in 2.0, and made read-only + # @staticmethod + # def _patch_schema(schema: ExtendedBaseModel): + # # We set required=False because the validate() interface provides only one key at a time + # for field in schema.__fields__.values(): + # field.required = False - schema.Config.extra = pydantic.Extra.forbid + # schema.Config.extra = pydantic.Extra.forbid def validate(self, record: DHTRecord) -> bool: """ @@ -151,11 +158,11 @@ def __setstate__(self, state): self.__dict__.update(state) # If unpickling happens in another process, the previous model modifications may be lost - for schema in self._schemas: - self._patch_schema(schema) + # for schema in self._schemas: + # self._patch_schema(schema) -def conbytes(*, regex: bytes = None, **kwargs) -> Type[pydantic.BaseModel]: +def conbytes(*, regex: bytes = None, **kwargs) -> Type[ExtendedBaseModel]: """ Extend pydantic.conbytes() to support ``regex`` constraints (like pydantic.constr() does). """ @@ -164,11 +171,7 @@ def conbytes(*, regex: bytes = None, **kwargs) -> Type[pydantic.BaseModel]: class ConstrainedBytesWithRegex(pydantic.conbytes(**kwargs)): @classmethod - def __get_validators__(cls): - yield from super().__get_validators__() - yield cls.match_regex - - @classmethod + @pydantic.validator("*") def match_regex(cls, value: bytes) -> bytes: if compiled_regex is not None and compiled_regex.match(value) is None: raise ValueError(f"Value `{value}` doesn't match regex `{regex}`") diff --git a/hivemind/optim/progress_tracker.py b/hivemind/optim/progress_tracker.py index 9a6ff66e7..63987a983 100644 --- a/hivemind/optim/progress_tracker.py +++ b/hivemind/optim/progress_tracker.py @@ -6,10 +6,10 @@ from typing import Dict, Optional import numpy as np -from pydantic import BaseModel, StrictBool, StrictFloat, confloat, conint +from pydantic import StrictBool, StrictFloat, confloat, conint from hivemind.dht import DHT -from hivemind.dht.schema import BytesWithPublicKey, RSASignatureValidator, SchemaValidator +from hivemind.dht.schema import BytesWithPublicKey, ExtendedBaseModel, RSASignatureValidator, SchemaValidator from hivemind.utils import DHTExpiration, ValueWithExpiration, enter_asynchronously, get_dht_time, get_logger from hivemind.utils.crypto import RSAPrivateKey from hivemind.utils.performance_ema import PerformanceEMA @@ -28,7 +28,7 @@ class GlobalTrainingProgress: next_fetch_time: float -class LocalTrainingProgress(BaseModel): +class LocalTrainingProgress(ExtendedBaseModel): peer_id: bytes epoch: conint(ge=0, strict=True) samples_accumulated: conint(ge=0, strict=True) @@ -37,7 +37,7 @@ class LocalTrainingProgress(BaseModel): client_mode: StrictBool -class TrainingProgressSchema(BaseModel): +class TrainingProgressSchema(ExtendedBaseModel): progress: Dict[BytesWithPublicKey, Optional[LocalTrainingProgress]] diff --git a/requirements.txt b/requirements.txt index df60317d8..8420ae128 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,5 +12,5 @@ configargparse>=1.2.3 py-multihash>=0.2.3 multiaddr @ git+https://github.com/multiformats/py-multiaddr.git@e01dbd38f2c0464c0f78b556691d655265018cce cryptography>=3.4.6 -pydantic>=1.8.1,<2.0 +pydantic>=2.7.3 packaging>=20.9 diff --git a/tests/test_dht_schema.py b/tests/test_dht_schema.py index 8c9fd4aeb..74afd365d 100644 --- a/tests/test_dht_schema.py +++ b/tests/test_dht_schema.py @@ -2,16 +2,16 @@ from typing import Dict import pytest -from pydantic import BaseModel, StrictInt, conint +from pydantic import StrictInt, conint import hivemind from hivemind.dht.node import DHTNode -from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator +from hivemind.dht.schema import BytesWithPublicKey, ExtendedBaseModel, SchemaValidator from hivemind.dht.validation import DHTRecord, RecordValidatorBase from hivemind.utils.timed_storage import get_dht_time -class SampleSchema(BaseModel): +class SampleSchema(ExtendedBaseModel): experiment_name: bytes n_batches: Dict[bytes, conint(ge=0, strict=True)] signed_data: Dict[BytesWithPublicKey, bytes] @@ -94,10 +94,10 @@ async def test_expecting_public_keys(dht_nodes_with_schema): @pytest.mark.forked @pytest.mark.asyncio async def test_keys_outside_schema(dht_nodes_with_schema): - class Schema(BaseModel): + class Schema(ExtendedBaseModel): some_field: StrictInt - class MergedSchema(BaseModel): + class MergedSchema(ExtendedBaseModel): another_field: StrictInt for allow_extra_keys in [False, True]: @@ -121,7 +121,7 @@ class MergedSchema(BaseModel): @pytest.mark.forked @pytest.mark.asyncio async def test_prefix(): - class Schema(BaseModel): + class Schema(ExtendedBaseModel): field: StrictInt validator = SchemaValidator(Schema, allow_extra_keys=False, prefix="prefix") @@ -153,11 +153,11 @@ def validate(self, record: DHTRecord) -> bool: # Can't merge with the validator of the different type assert not alice.protocol.record_validator.merge_with(second_validator) - class SecondSchema(BaseModel): + class SecondSchema(ExtendedBaseModel): some_field: StrictInt another_field: str - class ThirdSchema(BaseModel): + class ThirdSchema(ExtendedBaseModel): another_field: StrictInt # Allow it to be a StrictInt as well for schema in [SecondSchema, ThirdSchema]: diff --git a/tests/test_dht_validation.py b/tests/test_dht_validation.py index 56420d3d7..a0d77b5c4 100644 --- a/tests/test_dht_validation.py +++ b/tests/test_dht_validation.py @@ -2,21 +2,21 @@ from typing import Dict import pytest -from pydantic import BaseModel, StrictInt +from pydantic import StrictInt import hivemind from hivemind.dht.crypto import RSASignatureValidator from hivemind.dht.protocol import DHTProtocol from hivemind.dht.routing import DHTID -from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator +from hivemind.dht.schema import BytesWithPublicKey, ExtendedBaseModel, SchemaValidator from hivemind.dht.validation import CompositeValidator, DHTRecord -class SchemaA(BaseModel): +class SchemaA(ExtendedBaseModel): field_a: bytes -class SchemaB(BaseModel): +class SchemaB(ExtendedBaseModel): field_b: Dict[BytesWithPublicKey, StrictInt]