diff --git a/README.rst b/README.rst index af0c212..03baff3 100644 --- a/README.rst +++ b/README.rst @@ -15,7 +15,7 @@ Features U+0056, to keep the output as small as possible. * Uses the shortest escape sequence for each escaped character. * Encodes the JSON as UTF-8. -* Can encode ``frozendict`` immutable dictionaries. +* Can be configured to encode custom types unknown to the stdlib JSON encoder. Supports Python versions 3.7 and newer. @@ -59,3 +59,20 @@ The underlying JSON implementation can be chosen with the following: which uses the standard library json module). .. _simplejson: https://simplejson.readthedocs.io/ + +A preserialisation hook allows you to encode objects which aren't encodable by the +standard library ``JSONEncoder``. + +.. code:: python + + import canonicaljson + from typing import Dict + + class CustomType: + pass + + def callback(c: CustomType) -> Dict[str, str]: + return {"Hello": "world!"} + + canonicaljson.register_preserialisation_callback(CustomType, callback) + assert canonicaljson.encode_canonical_json(CustomType()) == b'{"Hello":"world!"}' diff --git a/setup.cfg b/setup.cfg index 4b707de..60417f4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,12 +34,6 @@ install_requires = typing_extensions>=4.0.0; python_version < '3.8' -[options.extras_require] -# frozendict support can be enabled using the `canonicaljson[frozendict]` syntax -frozendict = - frozendict>=1.0 - - [options.package_data] canonicaljson = py.typed diff --git a/src/canonicaljson/__init__.py b/src/canonicaljson/__init__.py index 2e33b66..24ed332 100644 --- a/src/canonicaljson/__init__.py +++ b/src/canonicaljson/__init__.py @@ -13,33 +13,56 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import functools import platform -from typing import Any, Generator, Iterator, Optional, Type +from typing import Any, Callable, Generator, Iterator, Type, TypeVar try: from typing import Protocol except ImportError: # pragma: no cover from typing_extensions import Protocol # type: ignore[assignment] -frozendict_type: Optional[Type[Any]] -try: - from frozendict import frozendict as frozendict_type -except ImportError: - frozendict_type = None # pragma: no cover __version__ = "1.6.5" -def _default(obj: object) -> object: # pragma: no cover - if type(obj) is frozendict_type: - # If frozendict is available and used, cast `obj` into a dict - return dict(obj) # type: ignore[call-overload] +@functools.singledispatch +def _preprocess_for_serialisation(obj: object) -> object: # pragma: no cover + """Transform an `obj` into something the JSON library knows how to encode. + + This is only called for types that the JSON library does not recognise. + """ raise TypeError( "Object of type %s is not JSON serializable" % obj.__class__.__name__ ) +T = TypeVar("T") + + +def register_preserialisation_callback( + data_type: Type[T], callback: Callable[[T], object] +) -> None: + """ + Register a `callback` to preprocess `data_type` objects unknown to the JSON encoder. + + When canonicaljson encodes an object `x` at runtime that its JSON library does not + know how to encode, it will + - select a `callback`, + - compute `y = callback(x)`, then + - JSON-encode `y` and return the result. + + The `callback` should return an object that is JSON-serialisable by the stdlib + json module. + + If this is called multiple times with the same `data_type`, the most recently + registered callback is used when serialising that `data_type`. + """ + if data_type is object: + raise ValueError("Cannot register callback for the `object` type") + _preprocess_for_serialisation.register(data_type, callback) + + class Encoder(Protocol): # pragma: no cover def encode(self, data: object) -> str: pass @@ -77,7 +100,7 @@ def set_json_library(json_lib: JsonLibrary) -> None: allow_nan=False, separators=(",", ":"), sort_keys=True, - default=_default, + default=_preprocess_for_serialisation, ) global _pretty_encoder @@ -86,7 +109,7 @@ def set_json_library(json_lib: JsonLibrary) -> None: allow_nan=False, indent=4, sort_keys=True, - default=_default, + default=_preprocess_for_serialisation, ) diff --git a/tests/test_canonicaljson.py b/tests/test_canonicaljson.py index f1fac9a..44422d5 100644 --- a/tests/test_canonicaljson.py +++ b/tests/test_canonicaljson.py @@ -13,16 +13,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import Mock from math import inf, nan from canonicaljson import ( encode_canonical_json, encode_pretty_printed_json, - frozendict_type, iterencode_canonical_json, iterencode_pretty_printed_json, set_json_library, + register_preserialisation_callback, ) import unittest @@ -107,22 +108,6 @@ def test_encode_pretty_printed(self) -> None: b'{\n "la merde amus\xc3\xa9e": "\xF0\x9F\x92\xA9"\n}', ) - @unittest.skipIf( - frozendict_type is None, - "If `frozendict` is not available, skip test", - ) - def test_frozen_dict(self) -> None: - # For mypy's benefit: - assert frozendict_type is not None - self.assertEqual( - encode_canonical_json(frozendict_type({"a": 1})), - b'{"a":1}', - ) - self.assertEqual( - encode_pretty_printed_json(frozendict_type({"a": 1})), - b'{\n "a": 1\n}', - ) - def test_unknown_type(self) -> None: class Unknown(object): pass @@ -167,3 +152,46 @@ def test_set_json(self) -> None: from canonicaljson import json # type: ignore[attr-defined] set_json_library(json) + + def test_encode_unknown_class_raises(self) -> None: + class C: + pass + + with self.assertRaises(Exception): + encode_canonical_json(C()) + + def test_preserialisation_callback(self) -> None: + class C: + pass + + # Naughty: this alters the global state of the module. However this + # `C` class is limited to this test only, so this shouldn't affect + # other types and other tests. + register_preserialisation_callback(C, lambda c: "I am a C instance") + + result = encode_canonical_json(C()) + self.assertEqual(result, b'"I am a C instance"') + + def test_cannot_register_preserialisation_callback_for_object(self) -> None: + with self.assertRaises(Exception): + register_preserialisation_callback( + object, lambda c: "shouldn't be able to do this" + ) + + def test_most_recent_preserialisation_callback_called(self) -> None: + class C: + pass + + callback1 = Mock(return_value="callback 1 was called") + callback2 = Mock(return_value="callback 2 was called") + + # Naughty: this alters the global state of the module. However this + # `C` class is limited to this test only, so this shouldn't affect + # other types and other tests. + register_preserialisation_callback(C, callback1) + register_preserialisation_callback(C, callback2) + + encode_canonical_json(C()) + + callback1.assert_not_called() + callback2.assert_called_once() diff --git a/tox.ini b/tox.ini index a893107..63b9d58 100644 --- a/tox.ini +++ b/tox.ini @@ -33,7 +33,6 @@ commands = python -m black --check --diff src tests [testenv:mypy] deps = mypy==1.0 - types-frozendict==2.0.8 types-simplejson==3.17.5 types-setuptools==57.4.14 commands = mypy src tests