Skip to content

Commit

Permalink
feat: allow users to register custom encoders (#296)
Browse files Browse the repository at this point in the history
  • Loading branch information
frostming authored Jun 28, 2023
1 parent 5f0e954 commit a3cb8a2
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 1 deletion.
18 changes: 18 additions & 0 deletions tests/test_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,3 +946,21 @@ def test_copy_copy():
)
def test_escape_key(key_str, escaped):
assert api.key(key_str).as_string() == escaped


def test_custom_encoders():
import decimal

@api.register_encoder
def encode_decimal(obj):
if isinstance(obj, decimal.Decimal):
return api.float_(str(obj))
raise TypeError

assert api.item(decimal.Decimal("1.23")).as_string() == "1.23"

with pytest.raises(TypeError):
api.item(object())

assert api.dumps({"foo": decimal.Decimal("1.23")}) == "foo = 1.23\n"
api.unregister_encoder(encode_decimal)
4 changes: 4 additions & 0 deletions tomlkit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
from tomlkit.api import loads
from tomlkit.api import nl
from tomlkit.api import parse
from tomlkit.api import register_encoder
from tomlkit.api import string
from tomlkit.api import table
from tomlkit.api import time
from tomlkit.api import unregister_encoder
from tomlkit.api import value
from tomlkit.api import ws

Expand Down Expand Up @@ -52,4 +54,6 @@
"TOMLDocument",
"value",
"ws",
"register_encoder",
"unregister_encoder",
]
22 changes: 22 additions & 0 deletions tomlkit/api.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
from __future__ import annotations

import contextlib
import datetime as _datetime

from collections.abc import Mapping
from typing import IO
from typing import Iterable
from typing import TypeVar

from tomlkit._utils import parse_rfc3339
from tomlkit.container import Container
from tomlkit.exceptions import UnexpectedCharError
from tomlkit.items import CUSTOM_ENCODERS
from tomlkit.items import AoT
from tomlkit.items import Array
from tomlkit.items import Bool
from tomlkit.items import Comment
from tomlkit.items import Date
from tomlkit.items import DateTime
from tomlkit.items import DottedKey
from tomlkit.items import Encoder
from tomlkit.items import Float
from tomlkit.items import InlineTable
from tomlkit.items import Integer
Expand Down Expand Up @@ -284,3 +288,21 @@ def nl() -> Whitespace:
def comment(string: str) -> Comment:
"""Create a comment item."""
return Comment(Trivia(comment_ws=" ", comment="# " + string))


E = TypeVar("E", bound=Encoder)


def register_encoder(encoder: E) -> E:
"""Add a custom encoder, which should be a function that will be called
if the value can't otherwise be converted. It should takes a single value
and return a TOMLKit item or raise a ``TypeError``.
"""
CUSTOM_ENCODERS.append(encoder)
return encoder


def unregister_encoder(encoder: Encoder) -> None:
"""Unregister a custom encoder."""
with contextlib.suppress(ValueError):
CUSTOM_ENCODERS.remove(encoder)
24 changes: 23 additions & 1 deletion tomlkit/items.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from enum import Enum
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Collection
from typing import Iterable
from typing import Iterator
Expand Down Expand Up @@ -57,6 +58,15 @@ class _CustomDict(MutableMapping, dict):


ItemT = TypeVar("ItemT", bound="Item")
Encoder = Callable[[Any], "Item"]
CUSTOM_ENCODERS: list[Encoder] = []


class _ConvertError(TypeError, ValueError):
"""An internal error raised when item() fails to convert a value.
It should be a TypeError, but due to historical reasons
it needs to subclass ValueError as well.
"""


@overload
Expand Down Expand Up @@ -218,8 +228,20 @@ def item(value: Any, _parent: Item | None = None, _sort_keys: bool = False) -> I
Trivia(),
value.isoformat(),
)
else:
for encoder in CUSTOM_ENCODERS:
try:
rv = encoder(value)
except TypeError:
pass
else:
if not isinstance(rv, Item):
raise _ConvertError(
f"Custom encoder returned {type(rv)}, not a subclass of Item"
)
return rv

raise ValueError(f"Invalid type {type(value)}")
raise _ConvertError(f"Invalid type {type(value)}")


class StringType(Enum):
Expand Down

0 comments on commit a3cb8a2

Please sign in to comment.