Skip to content

Commit

Permalink
fix(pickle): make Parameter instances pickleable (#9798)
Browse files Browse the repository at this point in the history
The pickles expect a specific constructor for `Parameter`, so remove our
custom constructor and provide a `classmethod` with the previous
behavior for convenience. Closes #9793.
  • Loading branch information
cpcloud authored Aug 8, 2024
1 parent 79cef68 commit d772c80
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 39 deletions.
32 changes: 18 additions & 14 deletions ibis/common/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from typing import TYPE_CHECKING
from typing import Any as AnyType

from typing_extensions import Self

from ibis.common.bases import Immutable, Slotted
from ibis.common.patterns import (
Any,
Expand Down Expand Up @@ -258,18 +260,6 @@ class Parameter(inspect.Parameter):

__slots__ = ()

def __init__(self, name, annotation):
if not isinstance(annotation, Argument):
raise TypeError(
f"annotation must be an instance of Argument, got {annotation}"
)
super().__init__(
name,
kind=annotation.kind,
default=annotation.default,
annotation=annotation,
)

def __str__(self):
formatted = self._name

Expand All @@ -290,6 +280,20 @@ def __str__(self):

return formatted

@classmethod
def from_argument(cls, name: str, annotation: Argument) -> Self:
"""Construct a Parameter from an Argument annotation."""
if not isinstance(annotation, Argument):
raise TypeError(
f"annotation must be an instance of Argument, got {annotation}"
)
return cls(
name,
kind=annotation.kind,
default=annotation.default,
annotation=annotation,
)


class Signature(inspect.Signature):
"""Validatable signature.
Expand Down Expand Up @@ -324,7 +328,7 @@ def merge(cls, *signatures, **annotations):

inherited = set(params.keys())
for name, annot in annotations.items():
params[name] = Parameter(name, annotation=annot)
params[name] = Parameter.from_argument(name, annotation=annot)

# mandatory fields without default values must precede the optional
# ones in the function signature, the partial ordering will be kept
Expand Down Expand Up @@ -406,7 +410,7 @@ def from_callable(cls, fn, patterns=None, return_pattern=None):
else:
annot = Argument(pattern, kind=kind, default=default, typehint=typehint)

parameters.append(Parameter(param.name, annot))
parameters.append(Parameter.from_argument(param.name, annot))

if return_pattern is not None:
return_annotation = return_pattern
Expand Down
38 changes: 26 additions & 12 deletions ibis/common/tests/test_annotations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import inspect
import pickle
from typing import Annotated, Union

import pytest
Expand Down Expand Up @@ -140,20 +141,20 @@ def fn(x, this):
return int(x) + this["other"]

annot = argument(fn)
p = Parameter("test", annotation=annot)
p = Parameter.from_argument("test", annotation=annot)

assert p.annotation is annot
assert p.default is inspect.Parameter.empty
assert p.annotation.pattern.match("2", {"other": 1}) == 3

ofn = optional(fn)
op = Parameter("test", annotation=ofn)
op = Parameter.from_argument("test", annotation=ofn)
assert op.annotation.pattern == Option(fn, default=None)
assert op.default is None
assert op.annotation.pattern.match(None, {"other": 1}) is None

with pytest.raises(TypeError, match="annotation must be an instance of Argument"):
Parameter("wrong", annotation=Attribute(lambda x, context: x))
Parameter.from_argument("wrong", annotation=Attribute(lambda x, context: x))


def test_signature():
Expand All @@ -163,8 +164,8 @@ def to_int(x, this):
def add_other(x, this):
return int(x) + this["other"]

other = Parameter("other", annotation=Argument(to_int))
this = Parameter("this", annotation=Argument(add_other))
other = Parameter.from_argument("other", annotation=Argument(to_int))
this = Parameter.from_argument("this", annotation=Argument(add_other))

sig = Signature(parameters=[other, this])
assert sig.validate(None, args=(1, 2), kwargs={}) == {"other": 1, "this": 3}
Expand Down Expand Up @@ -275,8 +276,8 @@ def to_int(x, this):
def add_other(x, this):
return int(x) + this["other"]

other = Parameter("other", annotation=Argument(to_int))
this = Parameter("this", annotation=Argument(add_other))
other = Parameter.from_argument("other", annotation=Argument(to_int))
this = Parameter.from_argument("this", annotation=Argument(add_other))

sig = Signature(parameters=[other, this])
params = sig.validate(None, args=(1,), kwargs=dict(this=2))
Expand All @@ -286,14 +287,16 @@ def add_other(x, this):
assert kwargs == {}


a = Parameter("a", annotation=Argument(CoercedTo(float)))
b = Parameter("b", annotation=Argument(CoercedTo(float)))
c = Parameter("c", annotation=Argument(CoercedTo(float), default=0))
d = Parameter(
a = Parameter.from_argument("a", annotation=Argument(CoercedTo(float)))
b = Parameter.from_argument("b", annotation=Argument(CoercedTo(float)))
c = Parameter.from_argument("c", annotation=Argument(CoercedTo(float), default=0))
d = Parameter.from_argument(
"d",
annotation=Argument(TupleOf(CoercedTo(float)), default=()),
)
e = Parameter("e", annotation=Argument(Option(CoercedTo(float)), default=None))
e = Parameter.from_argument(
"e", annotation=Argument(Option(CoercedTo(float)), default=None)
)
sig = Signature(parameters=[a, b, c, d, e])


Expand Down Expand Up @@ -480,3 +483,14 @@ def test(a: float, b: float, *args: int, **kwargs: int): ...
test(1.0, 2.0, 3.0, 4, c=5.0, d=6)

assert len(excinfo.value.errors) == 2


def test_pickle():
a = Parameter.from_argument("a", annotation=Argument(int))
assert pickle.loads(pickle.dumps(a)) == a


def test_cloudpickle():
cloudpickle = pytest.importorskip("cloudpickle")
a = Parameter.from_argument("a", annotation=Argument(int))
assert cloudpickle.loads(cloudpickle.dumps(a)) == a
32 changes: 20 additions & 12 deletions ibis/common/tests/test_grounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,33 +449,41 @@ class IntAddClip(FloatAddClip, IntBinop):

assert IntBinop.__signature__ == Signature(
[
Parameter("left", annotation=Argument(is_int)),
Parameter("right", annotation=Argument(is_int)),
Parameter.from_argument("left", annotation=Argument(is_int)),
Parameter.from_argument("right", annotation=Argument(is_int)),
]
)

assert FloatAddRhs.__signature__ == Signature(
[
Parameter("left", annotation=Argument(is_int)),
Parameter("right", annotation=Argument(is_float)),
Parameter.from_argument("left", annotation=Argument(is_int)),
Parameter.from_argument("right", annotation=Argument(is_float)),
]
)

assert FloatAddClip.__signature__ == Signature(
[
Parameter("left", annotation=Argument(is_float)),
Parameter("right", annotation=Argument(is_float)),
Parameter("clip_lower", annotation=optional(is_int, default=0)),
Parameter("clip_upper", annotation=optional(is_int, default=10)),
Parameter.from_argument("left", annotation=Argument(is_float)),
Parameter.from_argument("right", annotation=Argument(is_float)),
Parameter.from_argument(
"clip_lower", annotation=optional(is_int, default=0)
),
Parameter.from_argument(
"clip_upper", annotation=optional(is_int, default=10)
),
]
)

assert IntAddClip.__signature__ == Signature(
[
Parameter("left", annotation=Argument(is_int)),
Parameter("right", annotation=Argument(is_int)),
Parameter("clip_lower", annotation=optional(is_int, default=0)),
Parameter("clip_upper", annotation=optional(is_int, default=10)),
Parameter.from_argument("left", annotation=Argument(is_int)),
Parameter.from_argument("right", annotation=Argument(is_int)),
Parameter.from_argument(
"clip_lower", annotation=optional(is_int, default=0)
),
Parameter.from_argument(
"clip_upper", annotation=optional(is_int, default=10)
),
]
)

Expand Down
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ ruff = ">=0.1.8"
tqdm = ">=4.66.1,<5"

[tool.poetry.group.test.dependencies]
cloudpickle = ">=3,<4"
filelock = ">=3.7.0,<4"
hypothesis = ">=6.58.0,<7"
packaging = ">=21.3,<25"
Expand Down

0 comments on commit d772c80

Please sign in to comment.