
Can't Serialize because of functools.partial #6167

Closed
bacalfa opened this issue Sep 30, 2022 · 3 comments

Comments

@bacalfa

bacalfa commented Sep 30, 2022

Description of your problem

After upgrading from PyMC3 to PyMC (4.2.0), I can't use dill to pickle an object that has a pm.Model member.

Error:

AttributeError: 'functools.partial' object has no attribute '__name__'

After debugging, I realized this seems to be caused by the _repr_latex_ attribute set in class Distribution (https://www.pymc.io/projects/docs/en/latest/_modules/pymc/distributions/distribution.html), which wraps a functools.partial in types.MethodType:

rv_out._repr_latex_ = types.MethodType(
    functools.partial(str_for_dist, formatting="latex"), rv_out
)

Please provide a minimal, self-contained, and reproducible example.

I'm pickling the object that contains a pm.Model member as follows:

with bz2.BZ2File(file_path, "wb") as f:
    dill.dump(self, f)

Please provide the full traceback.

The traceback goes several levels down into dill.dump until it reaches a functools.partial object that can't be pickled.
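The same AttributeError can be reproduced with the standard library's pickle alone. This is a minimal sketch that assumes dill falls back on the standard reduction for bound methods here (not verified), with hypothetical RV and str_for_dist stand-ins for a TensorVariable and PyMC's printing helper. A bound method is reduced to its instance plus its function's __name__, and a functools.partial has no __name__:

```python
import functools
import pickle
import types


def str_for_dist(dist, formatting="plain"):
    # Hypothetical stand-in for PyMC's printing helper
    return f"{dist.name} [{formatting}]"


class RV:
    # Hypothetical stand-in for a random variable with an instance-level method,
    # mirroring how _repr_latex_ is attached in distribution.py
    def __init__(self, name):
        self.name = name
        self._repr_latex_ = types.MethodType(
            functools.partial(str_for_dist, formatting="latex"), self
        )


rv = RV("x")
print(rv._repr_latex_())  # x [latex]  (calling the method works fine)

try:
    pickle.dumps(rv)
    error = None
except AttributeError as exc:
    # Reducing the bound method looks up __func__.__name__, which a partial lacks
    error = exc
    print("pickle failed:", exc)
```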

Please provide any additional information below.

My (ugly) workaround was to overwrite the _repr_latex_ attribute with a lambda instead. Example:

my_rv = pm.Uniform("my_rv", lower=lower, upper=upper, testval=testval)
my_rv._repr_latex_ = types.MethodType(lambda v: str_for_dist(v, formatting="latex"), my_rv)

Or more generally:

import types

# Assumes PyMC 4.x, where these helpers live in pymc.printing
from pymc.printing import str_for_dist, str_for_model, str_for_potential_or_deterministic

partial_attrs = ("str_repr", "_repr_latex_")
model = self._model

# Replace partial-based methods on the model itself. The default arguments
# (formatting=formatting) freeze the loop values; a plain closure would make
# every lambda see the values from the last iteration.
for partial_attr in partial_attrs:
    formatting = "latex" if "latex" in partial_attr else "plain"
    if hasattr(model, partial_attr) and not hasattr(getattr(model, partial_attr), "__name__"):
        setattr(model, partial_attr,
                types.MethodType(lambda v, formatting=formatting: str_for_model(v, formatting=formatting), model))

# Do the same for every named variable in the model
for var in model.named_vars.values():
    for partial_attr in partial_attrs:
        if hasattr(var, partial_attr) and not hasattr(getattr(var, partial_attr), "__name__"):
            formatting = "latex" if "latex" in partial_attr else "plain"
            tag_str = str(var.tag.trace)
            if "Potential" in tag_str or "Deterministic" in tag_str:
                dist_name = "Potential" if "Potential" in tag_str else "Deterministic"
                method = lambda v, dist_name=dist_name, formatting=formatting: str_for_potential_or_deterministic(
                    v, dist_name=dist_name, formatting=formatting)
            else:
                method = lambda v, formatting=formatting: str_for_dist(v, formatting=formatting)
            setattr(var, partial_attr, types.MethodType(method, var))
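A pitfall to keep in mind with loops like the one above: a lambda created inside a loop looks up formatting (and dist_name) when it is called, not when it is created, so unless the value is frozen with a default argument every replaced method ends up using the last iteration's values. A small standalone sketch of the effect and the default-argument fix (the helper names are made up for illustration):

```python
import types


def make_methods_late_bound(obj, attrs):
    # Each lambda closes over the variable `formatting` itself, not its current value
    methods = {}
    for attr in attrs:
        formatting = "latex" if "latex" in attr else "plain"
        methods[attr] = types.MethodType(lambda v: formatting, obj)
    return methods


def make_methods_bound_now(obj, attrs):
    # A default argument is evaluated when the lambda is created, freezing the value
    methods = {}
    for attr in attrs:
        formatting = "latex" if "latex" in attr else "plain"
        methods[attr] = types.MethodType(lambda v, formatting=formatting: formatting, obj)
    return methods


attrs = ("str_repr", "_repr_latex_")
late = make_methods_late_bound(object(), attrs)
bound = make_methods_bound_now(object(), attrs)
print({k: m() for k, m in late.items()})   # both entries report 'latex'
print({k: m() for k, m in bound.items()})  # 'plain' for str_repr, 'latex' for _repr_latex_
```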

I had to do something similar for the pm.Model object, which also has a _repr_latex_ attribute. Note that functools.partial is used in other PyMC modules as well.

A possible alternative would be to explicitly set a __name__ attribute on the functools.partial object (see elastic/apm-agent-python#293), but I didn't try that since it would require changing PyMC's source code.
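Setting a name can be tried outside PyMC, since functools.partial objects accept arbitrary attributes. A minimal sketch with the standard library's pickle (assuming dill behaves the same way for bound methods, which I haven't checked), using hypothetical RV and str_for_dist stand-ins: functools.update_wrapper gives the partial a __name__, so dumping succeeds, but loading then fails with an error of the same form as the load error reported below, because a bound method is restored by looking its name up on the instance.

```python
import functools
import pickle
import types


def str_for_dist(dist, formatting="plain"):
    # Hypothetical stand-in for PyMC's printing helper
    return f"{dist.name} [{formatting}]"


class RV:
    # Hypothetical stand-in for a random variable with an instance-level method
    def __init__(self, name):
        self.name = name
        latex = functools.partial(str_for_dist, formatting="latex")
        # Copies __name__, __qualname__, __doc__, ... from str_for_dist onto the partial
        functools.update_wrapper(latex, str_for_dist)
        self._repr_latex_ = types.MethodType(latex, self)


rv = RV("x")
data = pickle.dumps(rv)  # succeeds now that the partial has a __name__

try:
    pickle.loads(data)
    error = None
except AttributeError as exc:
    # The method was stored as getattr(instance, "str_for_dist"),
    # and RV has no attribute with that name
    error = exc
    print("unpickling failed:", exc)
```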

The workaround above made it possible to dill.dump the object, but I get an error when loading the file with dill.load:

AttributeError: 'TensorVariable' object has no attribute 'str_for_dist'

with

with bz2.BZ2File(file_path, "rb") as f:
    data = dill.load(f)

So my problem is still not solved.

Versions and main components

  • PyMC/PyMC3 Version: pymc==4.2.0
  • Aesara/Theano Version: aesara==2.8.2
  • Python Version: 3.8.10
  • Operating system: Windows 10
  • How did you install PyMC/PyMC3: (conda/pip) pip
@ricardoV94
Member

ricardoV94 commented Sep 30, 2022

We have moved to cloudpickle, which can handle this sort of issue. It's a drop-in replacement for dill.
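As far as I can tell, the standard reduction of a bound method stores the instance and the function's __name__, then restores the method with getattr on load; that breaks for a functools.partial and for methods attached to an instance rather than defined on its class. Rebuilding the method from __func__ and __self__ avoids both. I assume this is roughly why cloudpickle copes, but that is not checked against its source. A stdlib-only sketch of the idea using a per-pickler dispatch_table, where RV, str_for_dist, and _rebuild_method are hypothetical:

```python
import copyreg
import functools
import io
import pickle
import types


def str_for_dist(dist, formatting="plain"):
    # Hypothetical stand-in for PyMC's printing helper
    return f"{dist.name} [{formatting}]"


class RV:
    # Hypothetical stand-in for a random variable with an instance-level method
    def __init__(self, name):
        self.name = name
        self._repr_latex_ = types.MethodType(
            functools.partial(str_for_dist, formatting="latex"), self
        )


def _rebuild_method(func, instance):
    # Module-level so pickle can store it by reference
    # (types.MethodType itself cannot be looked up by name in builtins)
    return types.MethodType(func, instance)


def reduce_method(method):
    # Store the function and instance instead of getattr(instance, func.__name__)
    return _rebuild_method, (method.__func__, method.__self__)


def dumps(obj):
    buf = io.BytesIO()
    pickler = pickle.Pickler(buf)
    # Per-pickler dispatch table, so the global copyreg table is left untouched
    pickler.dispatch_table = copyreg.dispatch_table.copy()
    pickler.dispatch_table[types.MethodType] = reduce_method
    pickler.dump(obj)
    return buf.getvalue()


restored = pickle.loads(dumps(RV("x")))
print(restored._repr_latex_())  # x [latex]
```

Plain pickle.loads is enough on the reading side, but _rebuild_method and str_for_dist must be importable there, since module-level functions are pickled by reference.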

@bacalfa
Author

bacalfa commented Sep 30, 2022

Let me try that!

@bacalfa
Author

bacalfa commented Sep 30, 2022

It seems to have worked! But I had to use protocol=4 to avoid a pickling error (lucianopaz/compress_pickle#23).

with bz2.BZ2File(file_path, "wb") as f:
    cloudpickle.dump(self, f, protocol=4)
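The protocol only matters when dumping: it is written into the first bytes of the stream, and load detects it automatically. A minimal stdlib-only sketch with bz2 and pickle (the temporary path and sample data are made up; cloudpickle.dump accepts the same protocol argument):

```python
import bz2
import os
import pickle
import tempfile

obj = {"a": [1, 2, 3]}  # any picklable object

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "obj.pkl.bz2")
    with bz2.BZ2File(path, "wb") as f:
        pickle.dump(obj, f, protocol=4)

    # The stream starts with the PROTO opcode (0x80) followed by the protocol number
    with bz2.BZ2File(path, "rb") as f:
        header = f.read(2)

    # No protocol argument is needed for loading
    with bz2.BZ2File(path, "rb") as f:
        restored = pickle.load(f)

print(header)           # b'\x80\x04'
print(restored == obj)  # True
```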
