Skip to content

Commit

Permalink
first pass at migration to pydantic > 2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
Vectorrent committed Jun 4, 2024
1 parent d20e810 commit 24aabf8
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 40 deletions.
8 changes: 4 additions & 4 deletions examples/albert/utils.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
from typing import Dict, List, Tuple

from pydantic import BaseModel, StrictFloat, confloat, conint
from pydantic import StrictFloat, confloat, conint

from hivemind.dht.crypto import RSASignatureValidator
from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
from hivemind.dht.schema import BytesWithPublicKey, ExtendedBaseModel, SchemaValidator
from hivemind.dht.validation import RecordValidatorBase
from hivemind.utils.logging import get_logger

logger = get_logger(__name__)


class LocalMetrics(BaseModel):
class LocalMetrics(ExtendedBaseModel):
step: conint(ge=0, strict=True)
samples_per_second: confloat(ge=0.0, strict=True)
samples_accumulated: conint(ge=0, strict=True)
loss: StrictFloat
mini_steps: conint(ge=0, strict=True)


class MetricSchema(BaseModel):
class MetricSchema(ExtendedBaseModel):
metrics: Dict[BytesWithPublicKey, LocalMetrics]


Expand Down
41 changes: 22 additions & 19 deletions hivemind/dht/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Dict, Optional, Type

import pydantic
from pydantic_core import CoreSchema, core_schema

from hivemind.dht.crypto import RSASignatureValidator
from hivemind.dht.protocol import DHTProtocol
Expand All @@ -12,13 +13,18 @@
logger = get_logger(__name__)


class ExtendedBaseModel(pydantic.BaseModel):
    """
    Common base class for the Pydantic models used as DHT schemas.

    Enables ``arbitrary_types_allowed`` so that fields may be annotated with types
    that Pydantic cannot build a validation schema for on its own.
    """

    # Pydantic 2.x configures models through ``model_config``; the nested ``class Config``
    # is deprecated there and emits a deprecation warning on class creation.
    model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)


class SchemaValidator(RecordValidatorBase):
"""
Restricts specified DHT keys to match a Pydantic schema.
This allows enforcing types and min/max values, requiring a subkey to contain a public key, etc.
"""

def __init__(self, schema: Type[pydantic.BaseModel], allow_extra_keys: bool = True, prefix: Optional[str] = None):
def __init__(self, schema: Type[ExtendedBaseModel], allow_extra_keys: bool = True, prefix: Optional[str] = None):
"""
:param schema: The Pydantic model (a subclass of pydantic.BaseModel).
Expand All @@ -37,22 +43,23 @@ def __init__(self, schema: Type[pydantic.BaseModel], allow_extra_keys: bool = Tr
:param prefix: (optional) Add ``prefix + '_'`` to the names of all schema fields.
"""

self._patch_schema(schema)
# self._patch_schema(schema)
self._schemas = [schema]

self._key_id_to_field_name = {}
for field in schema.__fields__.values():
raw_key = f"{prefix}_{field.name}" if prefix is not None else field.name
self._key_id_to_field_name[DHTID.generate(source=raw_key).to_bytes()] = field.name
for field_name in schema.__fields__.keys():
raw_key = f"{prefix}_{field_name}" if prefix is not None else field_name
self._key_id_to_field_name[DHTID.generate(source=raw_key).to_bytes()] = field_name
self._allow_extra_keys = allow_extra_keys

@staticmethod
def _patch_schema(schema: pydantic.BaseModel):
# We set required=False because the validate() interface provides only one key at a time
for field in schema.__fields__.values():
field.required = False
# NOTE: in pydantic 2.0, `required` was changed to `is_required` and made read-only
# @staticmethod
# def _patch_schema(schema: ExtendedBaseModel):
# # We set required=False because the validate() interface provides only one key at a time
# for field in schema.__fields__.values():
# field.required = False

schema.Config.extra = pydantic.Extra.forbid
# schema.Config.extra = pydantic.Extra.forbid

def validate(self, record: DHTRecord) -> bool:
"""
Expand Down Expand Up @@ -151,11 +158,11 @@ def __setstate__(self, state):
self.__dict__.update(state)

# If unpickling happens in another process, the previous model modifications may be lost
for schema in self._schemas:
self._patch_schema(schema)
# for schema in self._schemas:
# self._patch_schema(schema)


def conbytes(*, regex: bytes = None, **kwargs) -> Type[pydantic.BaseModel]:
def conbytes(*, regex: bytes = None, **kwargs) -> Type[ExtendedBaseModel]:
"""
Extend pydantic.conbytes() to support ``regex`` constraints (like pydantic.constr() does).
"""
Expand All @@ -164,11 +171,7 @@ def conbytes(*, regex: bytes = None, **kwargs) -> Type[pydantic.BaseModel]:

class ConstrainedBytesWithRegex(pydantic.conbytes(**kwargs)):
@classmethod
def __get_validators__(cls):
yield from super().__get_validators__()
yield cls.match_regex

@classmethod
@pydantic.validator("*")
def match_regex(cls, value: bytes) -> bytes:
if compiled_regex is not None and compiled_regex.match(value) is None:
raise ValueError(f"Value `{value}` doesn't match regex `{regex}`")
Expand Down
8 changes: 4 additions & 4 deletions hivemind/optim/progress_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from typing import Dict, Optional

import numpy as np
from pydantic import BaseModel, StrictBool, StrictFloat, confloat, conint
from pydantic import StrictBool, StrictFloat, confloat, conint

from hivemind.dht import DHT
from hivemind.dht.schema import BytesWithPublicKey, RSASignatureValidator, SchemaValidator
from hivemind.dht.schema import BytesWithPublicKey, ExtendedBaseModel, RSASignatureValidator, SchemaValidator
from hivemind.utils import DHTExpiration, ValueWithExpiration, enter_asynchronously, get_dht_time, get_logger
from hivemind.utils.crypto import RSAPrivateKey
from hivemind.utils.performance_ema import PerformanceEMA
Expand All @@ -28,7 +28,7 @@ class GlobalTrainingProgress:
next_fetch_time: float


class LocalTrainingProgress(BaseModel):
class LocalTrainingProgress(ExtendedBaseModel):
peer_id: bytes
epoch: conint(ge=0, strict=True)
samples_accumulated: conint(ge=0, strict=True)
Expand All @@ -37,7 +37,7 @@ class LocalTrainingProgress(BaseModel):
client_mode: StrictBool


class TrainingProgressSchema(BaseModel):
class TrainingProgressSchema(ExtendedBaseModel):
progress: Dict[BytesWithPublicKey, Optional[LocalTrainingProgress]]


Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ configargparse>=1.2.3
py-multihash>=0.2.3
multiaddr @ git+https://github.com/multiformats/py-multiaddr.git@e01dbd38f2c0464c0f78b556691d655265018cce
cryptography>=3.4.6
pydantic>=1.8.1,<2.0
pydantic>=2.7.3
packaging>=20.9
16 changes: 8 additions & 8 deletions tests/test_dht_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
from typing import Dict

import pytest
from pydantic import BaseModel, StrictInt, conint
from pydantic import StrictInt, conint

import hivemind
from hivemind.dht.node import DHTNode
from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
from hivemind.dht.schema import BytesWithPublicKey, ExtendedBaseModel, SchemaValidator
from hivemind.dht.validation import DHTRecord, RecordValidatorBase
from hivemind.utils.timed_storage import get_dht_time


class SampleSchema(BaseModel):
class SampleSchema(ExtendedBaseModel):
experiment_name: bytes
n_batches: Dict[bytes, conint(ge=0, strict=True)]
signed_data: Dict[BytesWithPublicKey, bytes]
Expand Down Expand Up @@ -94,10 +94,10 @@ async def test_expecting_public_keys(dht_nodes_with_schema):
@pytest.mark.forked
@pytest.mark.asyncio
async def test_keys_outside_schema(dht_nodes_with_schema):
class Schema(BaseModel):
class Schema(ExtendedBaseModel):
some_field: StrictInt

class MergedSchema(BaseModel):
class MergedSchema(ExtendedBaseModel):
another_field: StrictInt

for allow_extra_keys in [False, True]:
Expand All @@ -121,7 +121,7 @@ class MergedSchema(BaseModel):
@pytest.mark.forked
@pytest.mark.asyncio
async def test_prefix():
class Schema(BaseModel):
class Schema(ExtendedBaseModel):
field: StrictInt

validator = SchemaValidator(Schema, allow_extra_keys=False, prefix="prefix")
Expand Down Expand Up @@ -153,11 +153,11 @@ def validate(self, record: DHTRecord) -> bool:
# Can't merge with the validator of the different type
assert not alice.protocol.record_validator.merge_with(second_validator)

class SecondSchema(BaseModel):
class SecondSchema(ExtendedBaseModel):
some_field: StrictInt
another_field: str

class ThirdSchema(BaseModel):
class ThirdSchema(ExtendedBaseModel):
another_field: StrictInt # Allow it to be a StrictInt as well

for schema in [SecondSchema, ThirdSchema]:
Expand Down
8 changes: 4 additions & 4 deletions tests/test_dht_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,21 @@
from typing import Dict

import pytest
from pydantic import BaseModel, StrictInt
from pydantic import StrictInt

import hivemind
from hivemind.dht.crypto import RSASignatureValidator
from hivemind.dht.protocol import DHTProtocol
from hivemind.dht.routing import DHTID
from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
from hivemind.dht.schema import BytesWithPublicKey, ExtendedBaseModel, SchemaValidator
from hivemind.dht.validation import CompositeValidator, DHTRecord


class SchemaA(BaseModel):
class SchemaA(ExtendedBaseModel):
field_a: bytes


class SchemaB(BaseModel):
class SchemaB(ExtendedBaseModel):
field_b: Dict[BytesWithPublicKey, StrictInt]


Expand Down

0 comments on commit 24aabf8

Please sign in to comment.