Skip to content

Commit

Permalink
Add support for pymongo bson ObjectId (#290)
Browse files Browse the repository at this point in the history
* Add support for pymongo bson ObjectId (#133)

* Fix py38 type hint

* Add test for json schema
  • Loading branch information
Ale-Cas authored Jan 30, 2025
1 parent b7ddcfa commit d8272c4
Show file tree
Hide file tree
Showing 6 changed files with 246 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ sources = pydantic_extra_types tests

.PHONY: install ## Install the package, dependencies, and pre-commit for local development
install: .uv
uv sync --frozen --group all --all-extras
uv sync --frozen --all-groups --all-extras
uv pip install pre-commit
pre-commit install --install-hooks

Expand Down
71 changes: 71 additions & 0 deletions pydantic_extra_types/mongo_object_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""
Validation for MongoDB ObjectId fields.
Ref: https://github.com/pydantic/pydantic-extra-types/issues/133
"""

from typing import Any

from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema

try:
from bson import ObjectId
except ModuleNotFoundError as e: # pragma: no cover
raise RuntimeError(
'The `mongo_object_id` module requires "pymongo" to be installed. You can install it with "pip install '
'pymongo".'
) from e


class MongoObjectId(str):
"""MongoObjectId parses and validates MongoDB bson.ObjectId.
```py
from pydantic import BaseModel
from pydantic_extra_types.mongo_object_id import MongoObjectId
class MongoDocument(BaseModel):
id: MongoObjectId
doc = MongoDocument(id='5f9f2f4b9d3c5a7b4c7e6c1d')
print(doc)
# > id='5f9f2f4b9d3c5a7b4c7e6c1d'
```
Raises:
PydanticCustomError: If the provided value is not a valid MongoDB ObjectId.
"""

OBJECT_ID_LENGTH = 24

@classmethod
def __get_pydantic_core_schema__(cls, _: Any, __: GetCoreSchemaHandler) -> core_schema.CoreSchema:
return core_schema.json_or_python_schema(
json_schema=core_schema.str_schema(min_length=cls.OBJECT_ID_LENGTH, max_length=cls.OBJECT_ID_LENGTH),
python_schema=core_schema.union_schema(
[
core_schema.is_instance_schema(ObjectId),
core_schema.chain_schema(
[
core_schema.str_schema(min_length=cls.OBJECT_ID_LENGTH, max_length=cls.OBJECT_ID_LENGTH),
core_schema.no_info_plain_validator_function(cls.validate),
]
),
]
),
serialization=core_schema.plain_serializer_function_ser_schema(lambda x: str(x)),
)

@classmethod
def validate(cls, value: str) -> ObjectId:
"""Validate the MongoObjectId str is a valid ObjectId instance."""
if not ObjectId.is_valid(value):
raise ValueError(
f"Invalid ObjectId {value} has to be 24 characters long and in the format '5f9f2f4b9d3c5a7b4c7e6c1d'."
)

return ObjectId(value)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ all = [
'python-ulid>=1,<2; python_version<"3.9"',
'python-ulid>=1,<4; python_version>="3.9"',
'pendulum>=3.0.0,<4.0.0',
'pymongo>=4.0.0,<5.0.0',
'pytz>=2024.1',
'semver~=3.0.2',
'tzdata>=2024.1',
Expand Down
23 changes: 21 additions & 2 deletions tests/test_json_schema.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union
from typing import Any, Dict, Union

import pycountry
import pytest
Expand All @@ -15,6 +15,7 @@
from pydantic_extra_types.isbn import ISBN
from pydantic_extra_types.language_code import ISO639_3, ISO639_5, LanguageAlpha2, LanguageName
from pydantic_extra_types.mac_address import MacAddress
from pydantic_extra_types.mongo_object_id import MongoObjectId
from pydantic_extra_types.payment import PaymentCardNumber
from pydantic_extra_types.pendulum_dt import DateTime
from pydantic_extra_types.phone_numbers import PhoneNumber, PhoneNumberValidator
Expand Down Expand Up @@ -494,9 +495,27 @@
],
},
),
(
MongoObjectId,
{
'title': 'Model',
'type': 'object',
'properties': {
'x': {
'maxLength': MongoObjectId.OBJECT_ID_LENGTH,
'minLength': MongoObjectId.OBJECT_ID_LENGTH,
'title': 'X',
'type': 'string',
},
},
'required': ['x'],
},
),
],
)
def test_json_schema(cls, expected):
def test_json_schema(cls: Any, expected: Dict[str, Any]) -> None:
"""Test the model_json_schema implementation for all extra types."""

class Model(BaseModel):
x: cls

Expand Down
71 changes: 71 additions & 0 deletions tests/test_mongo_object_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""Tests for the mongo_object_id module."""

import pytest
from pydantic import BaseModel, GetCoreSchemaHandler, ValidationError
from pydantic.json_schema import JsonSchemaMode

from pydantic_extra_types.mongo_object_id import MongoObjectId


class MongoDocument(BaseModel):
object_id: MongoObjectId


@pytest.mark.parametrize(
'object_id, result, valid',
[
# Valid ObjectId for str format
('611827f2878b88b49ebb69fc', '611827f2878b88b49ebb69fc', True),
('611827f2878b88b49ebb69fd', '611827f2878b88b49ebb69fd', True),
# Invalid ObjectId for str format
('611827f2878b88b49ebb69f', None, False), # Invalid ObjectId (short length)
('611827f2878b88b49ebb69fca', None, False), # Invalid ObjectId (long length)
# Valid ObjectId for bytes format
],
)
def test_format_for_object_id(object_id: str, result: str, valid: bool) -> None:
"""Test the MongoObjectId validation."""
if valid:
assert str(MongoDocument(object_id=object_id).object_id) == result
else:
with pytest.raises(ValidationError):
MongoDocument(object_id=object_id)
with pytest.raises(
ValueError,
match=f"Invalid ObjectId {object_id} has to be 24 characters long and in the format '5f9f2f4b9d3c5a7b4c7e6c1d'.",
):
MongoObjectId.validate(object_id)


@pytest.mark.parametrize(
'schema_mode',
[
'validation',
'serialization',
],
)
def test_json_schema(schema_mode: JsonSchemaMode) -> None:
"""Test the MongoObjectId model_json_schema implementation."""
expected_json_schema = {
'properties': {
'object_id': {
'maxLength': MongoObjectId.OBJECT_ID_LENGTH,
'minLength': MongoObjectId.OBJECT_ID_LENGTH,
'title': 'Object Id',
'type': 'string',
}
},
'required': ['object_id'],
'title': 'MongoDocument',
'type': 'object',
}
assert MongoDocument.model_json_schema(mode=schema_mode) == expected_json_schema


def test_get_pydantic_core_schema() -> None:
"""Test the __get_pydantic_core_schema__ method override."""
schema = MongoObjectId.__get_pydantic_core_schema__(MongoObjectId, GetCoreSchemaHandler())
assert isinstance(schema, dict)
assert 'json_schema' in schema
assert 'python_schema' in schema
assert schema['json_schema']['type'] == 'str'
Loading

0 comments on commit d8272c4

Please sign in to comment.