Now that PyTorch supports tensor subtyping and function overloading with __torch_function__, should we add __array_function__ and __torch_function__ methods to funsor.terms.Funsor to allow evaluation of (some) PyTorch/Numpy code on Funsors?
Here is the meat of a Funsor.__torch_function__ implementation, modulo handling of edge cases; __array_function__ for the Numpy backend would be very similar:
class Funsor:
    ...
    def __torch_function__(self, func, types, args=(), kwargs=None):
        # exploit our op registry: ops should know how to handle and convert their arguments
        if kwargs is None:
            kwargs = {}
        try:
            op = getattr(funsor.ops, func.__name__)
        except AttributeError:
            # handle e.g. nn.Module or dist.Transform instances
            op = funsor.ops.make_op(func)
        return op(*args, **kwargs)
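For readers unfamiliar with the dispatch pattern, the core of the sketch above can be demonstrated without PyTorch: look up an op in a registry by the intercepted function's `__name__`, and fall back to wrapping the callable itself. The names `ops`, `make_op`, and `dispatch` below are a hypothetical standalone mock of what `funsor.ops` does, not the actual funsor API:

```python
import operator
from types import SimpleNamespace

# Hypothetical stand-in for funsor.ops: a namespace of known ops.
ops = SimpleNamespace(add=operator.add, mul=operator.mul)

def make_op(func):
    # Fallback: treat an unknown callable (e.g. an nn.Module or
    # dist.Transform instance) as an op directly.
    return func

def dispatch(func, *args, **kwargs):
    # Mirrors the try/except in __torch_function__ above: prefer a
    # registered op with the same name, else wrap the callable.
    try:
        op = getattr(ops, func.__name__)
    except AttributeError:
        op = make_op(func)
    return op(*args, **kwargs)

def add(x, y):
    # Stands in for torch.add; never actually called, since dispatch
    # resolves the name "add" to ops.add.
    raise NotImplementedError

print(dispatch(add, 2, 3))  # → 5
```

In the real `__torch_function__`, `func` would be the intercepted torch function (e.g. `torch.add`), and the registered op would know how to build the corresponding lazy Funsor expression rather than computing eagerly.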
The motivating application is as a much simpler and more general alternative to the dimension tracking via effectful to_data/to_funsor primitives in pyro.contrib.funsor, which is somewhat confusing. This would also simplify @ordabayevy's work in #543 and elsewhere by removing the need for special torch.Tensor subclasses that duplicate Funsor broadcasting semantics.
I like the idea quite a lot! It might simplify things in funsor and make the code look cleaner. My current understanding is that __torch_function__ would replace all Funsor ops (such as Funsor.__add__, Funsor.sum, etc.)? And contrib.funsor would compute everything as Funsors during model execution instead of delegating it to TraceMessenger and converting at the last moment?
I don't think it would replace the basic Python operator overloads, but array-specific methods like sum() could probably be removed in favor of these generic methods.