diff --git a/src/cmap/_color.py b/src/cmap/_color.py index 3591dfcce..7ab614279 100644 --- a/src/cmap/_color.py +++ b/src/cmap/_color.py @@ -9,6 +9,7 @@ Any, Callable, Iterable, + Literal, NamedTuple, Sequence, SupportsFloat, @@ -34,6 +35,8 @@ from pydantic_core import CoreSchema from typing_extensions import TypeAlias + rgba = Literal["r", "g", "b", "a"] + # not used internally... but available for typing RGBTuple: TypeAlias = "tuple[int, int, int] | tuple[float, float, float]" RGBATuple: TypeAlias = ( @@ -291,8 +294,100 @@ def _norm_name(name: str) -> str: return delim.sub("", name).lower() +def _ensure_format(format: str) -> Sequence[rgba]: + _format = "".join(format).lower() + if not all(c in "rgba" for c in _format): + raise ValueError("Format must be composed of 'r', 'g', 'b', and 'a'") + return _format # type: ignore [return-value] + + +def parse_int( + value: int, + format: str, + bits_per_component: int | Sequence[int] = 8, +) -> RGBA: + """Parse color from bit-shifted integer encoding. + + Parameters + ---------- + value : int + The integer value to parse. + format : str + The format of the integer value. Must be a string composed only of + the characters 'r', 'g', 'b', and 'a'. + bits_per_component : int | Sequence[int] | None + The number of bits used to represent each color component. If a single + integer is provided, it is used for all components. If a sequence of + integers is provided, the length must match the length of `format`. + """ + fmt = _ensure_format(format) + if isinstance(bits_per_component, int): + bits_per_component = [bits_per_component] * len(fmt) + elif len(bits_per_component) != len(fmt): # pragma: no cover + raise ValueError("Length of 'bits_per_component' must match 'format'") + + components: dict[str, float] = {"r": 0, "g": 0, "b": 0, "a": 1} + shift = 0 + + # Calculate the starting shift amount + for bits in reversed(bits_per_component): + shift += bits + + # Parse each component from the integer value + for i, comp in enumerate(fmt): + shift -= bits_per_component[i] + mask = (1 << bits_per_component[i]) - 1 + components[comp] = ((value >> shift) & mask) / mask + + return RGBA(**components) + + +def to_int( + color: RGBA, + format: str, + bits_per_component: int | Sequence[int] = 8, +) -> int: + """Convert color to bit-shifted integer encoding. + + Parameters + ---------- + color : RGBA + The color to convert. + format : str + The format of the integer value. Must be a string composed only of + the characters 'r', 'g', 'b', and 'a'. + bits_per_component : int | Sequence[int] | None + The number of bits used to represent each color component. If a single + integer is provided, it is used for all components. If a sequence of + integers is provided, the length must match the length of `format`. + """ + fmt = _ensure_format(format) + if isinstance(bits_per_component, int): + bits_per_component = [bits_per_component] * len(fmt) + elif len(bits_per_component) != len(fmt): # pragma: no cover + raise ValueError("Length of 'bits_per_component' must match 'format'") + + value = 0 + shift = 0 + + # Calculate the starting shift amount + for bits in reversed(bits_per_component): + shift += bits + + # Parse each component from the integer value + for i, comp in enumerate(fmt): + shift -= bits_per_component[i] + mask = (1 << bits_per_component[i]) - 1 + value |= int(getattr(color, comp) * mask) << shift + + return value + + def parse_rgba(value: Any) -> RGBA: """Parse a color.""" + if isinstance(value, RGBA): + return value + # parse hex, rgb, rgba, hsl, hsla, and color name strings if isinstance(value, str): key = _norm_name(value) @@ -337,11 +432,8 @@ def parse_rgba(value: Any) -> RGBA: return value._rgba if isinstance(value, int): - # convert 24-bit integer to RGBA8 with bit shifting - r = (value >> 16) & 0xFF - g = (value >> 8) & 0xFF - b = value & 0xFF - return RGBA8(r, g, b).to_float() + # assume RGB24, use parse_int to explicitly pass format and bits_per_component + return parse_int(value, "rgb") # support for pydantic.color.Color for mod in ("pydantic", "pydantic_extra_types"): @@ -388,6 +480,49 @@ def __new__(cls, value: Any) -> Color: _COLOR_CACHE[rgba] = obj return _COLOR_CACHE[rgba] + @classmethod + def from_int( + cls, + value: int, + format: str, + bits_per_component: int | Sequence[int] = 8, + ) -> Color: + """Parse color from bit-shifted integer encoding. + + Parameters + ---------- + value : int + The integer value to parse. + format : str + The format of the integer value. Must be a string composed only of + the characters 'r', 'g', 'b', and 'a'. + bits_per_component : int | Sequence[int] | None + The number of bits used to represent each color component. If a single + integer is provided, it is used for all components. If a sequence of + integers is provided, the length must match the length of `format`. + """ + rgba = parse_int(value, format=format, bits_per_component=bits_per_component) + return cls(rgba) + + def to_int( + self, + format: str, + bits_per_component: int | Sequence[int] = 8, + ) -> int: + """Convert color to bit-shifted integer encoding. + + Parameters + ---------- + format : str + The format of the integer value. Must be a string composed only of + the characters 'r', 'g', 'b', and 'a'. + bits_per_component : int | Sequence[int] | None + The number of bits used to represent each color component. If a single + integer is provided, it is used for all components. If a sequence of + integers is provided, the length must match the length of `format`. + """ + return to_int(self._rgba, format=format, bits_per_component=bits_per_component) + # for mkdocstrings def __init__(self, value: ColorLike) -> None: pass diff --git a/tests/test_color.py b/tests/test_color.py index 3d5d1f60b..9c566c9fe 100644 --- a/tests/test_color.py +++ b/tests/test_color.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from cmap._color import RGBA, RGBA8, Color +from cmap._color import RGBA, RGBA8, Color, parse_int try: import colour @@ -171,3 +171,31 @@ def test_to_array_in_list() -> None: def test_hashable(): assert hash(Color("red")) == hash(Color("red")) assert hash(Color("red")) != hash(Color("blue")) + + +def test_parse_int(): + # Test parsing a 24-bit integer value + assert parse_int(0xFF00FF, "rgb") == RGBA(1.0, 0.0, 1.0, 1.0) + + # Test parsing a 32-bit integer value with alpha + assert parse_int(0x00FF00FF, "rgba") == RGBA(0.0, 1.0, 0.0, 1.0) + + # Test parsing a 16-bit integer value with custom format + assert parse_int(0x0FF, "bgr", bits_per_component=4) == RGBA(1.0, 1.0, 0.0, 1.0) + + expect = RGBA8(123, 255, 0) + assert parse_int(0x7FE0, "rgb", bits_per_component=[5, 6, 5]).to_8bit() == expect + + # # Test parsing an invalid format + with pytest.raises(ValueError): + parse_int(0x7FE0, "rgbx") + + # Test parsing an invalid number of bits per component + with pytest.raises(ValueError): + parse_int(0x7FE0, "rgb", bits_per_component=[5, 5]) + + +@pytest.mark.parametrize("input", [0xFF00FF, 0x00FF00FF, 0x0FF, 0x7FE0]) +@pytest.mark.parametrize("fmt", ["rgb", "rgba", "bgr"]) +def test_round_trip(input: int, fmt: str): + assert Color.from_int(input, fmt).to_int(fmt) == input