Rename symbols to support torch.einsum #41
Conversation
Overall LGTM. It looks like CI missed a trigger (very odd); can you make another commit to trigger CI when you get the chance?
opt_einsum/backends/torch.py (Outdated)

```diff
@@ -31,6 +31,12 @@ def transpose(a, axes):
 def einsum(equation, *operands):
     """Variadic version of torch.einsum to match numpy api.
     """
+    # rename symbols to support PyTorch 0.4.1 and earlier,
+    # which allow only symbols a-z.
+    symbols = sorted(set(equation) - set(',->'))
```
Extreme, but can we throw if `len(symbols) > 26`?
The advantage of not throwing is that this will continue to work when PyTorch fixes their einsum
to allow more symbols. I'd rather let their current implementation throw.
Makes sense. I am not intimately familiar with the new torch symbols set, but would `opt_einsum.parser.convert_to_valid_einsum_chars` work here?
opt_einsum/backends/torch.py (Outdated)

```diff
     # which allow only symbols a-z.
     symbols = sorted(set(equation) - set(',->'))
+    rename = {s: get_symbol(i) for i, s in enumerate(symbols)}
+    equation = ''.join(rename.get(s, s) for s in equation)
```
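The renaming step in the hunk above can be sketched as a self-contained snippet. Note that `get_symbol` below is a simplified stand-in for `opt_einsum.get_symbol`, which draws from a much larger symbol pool:

```python
# Simplified stand-in for opt_einsum.get_symbol: map an index to a
# single einsum character, starting at 'a'. The real function draws
# from a larger pool once the plain letters are exhausted.
def get_symbol(i):
    return chr(ord('a') + i)

def rename_equation(equation):
    """Rewrite an einsum equation so its indices are drawn from a-z,
    as required by torch.einsum in PyTorch <= 0.4.1."""
    symbols = sorted(set(equation) - set(',->'))
    rename = {s: get_symbol(i) for i, s in enumerate(symbols)}
    return ''.join(rename.get(s, s) for s in equation)

# Unicode indices are mapped onto plain letters:
rename_equation('αβ,βγ->αγ')  # 'ab,bc->ac'
```

Because every symbol is replaced (not just the invalid ones), the mapping stays a single dictionary lookup per character of the equation.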
@jcmgray We do this for einsum as well, I believe. Is it time to have a single expression which converts from global to local symbols?
```diff
@@ -39,8 +45,6 @@ def tensordot(x, y, axes=2):
     """Simple translation of tensordot syntax to einsum.
     """
     # XXX: tensordot should be directly implemented in torch soon
-    torch, _ = _get_torch_and_device()
```
@jcmgray Can you look over this as well?
(Note: the only reason I've removed this line is so I can reuse the `einsum()` wrapper in this file, rather than calling `torch.einsum` in two different places.)
It might be cleaner, and more useful for potential future backends, to modify the existing machinery, probably with:

```python
import string
string.ascii_lowercase  # or
string.ascii_letters
```

Also, for what it's worth, your strategy of replacing all characters might be faster than replacing just the invalid ones, in which case do update …
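The suggestion above, combined with the earlier question about throwing when there are more than 26 indices, could look something like this (a hedged sketch; `rename_with_ascii` is a hypothetical helper, not opt_einsum API):

```python
import string

# Hypothetical helper: draw replacement symbols from the string module
# rather than computing them by character offset.
def rename_with_ascii(equation, alphabet=string.ascii_lowercase):
    symbols = sorted(set(equation) - set(',->'))
    if len(symbols) > len(alphabet):
        # The guard discussed above: fail early with a clear message
        # rather than letting the backend reject the equation.
        raise ValueError("too many distinct indices for this backend")
    rename = dict(zip(symbols, alphabet))
    return ''.join(rename.get(s, s) for s in equation)
```

Passing `string.ascii_letters` as the alphabet would extend the same helper to backends that also accept A-Z.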
PyTorch's einsum makes some use of the fact that valid letters are consecutive and from a small fixed set (based on review comments when I previously dynamically allocated indices). It's not hard to extend the set and deal with several ranges but my impression is that lowercase a-z is all anyone ever uses in the wild.
That said, I still plan on having an extended greedy optimisation in ATen c++, maybe in a month or so.
We're using …
Yes, just to clarify: … Now, …
If there is a use case, supporting A-Z is easy. One could also consider a "post-parsing interface" (i.e. taking preprocessed equations) if that is useful. I must admit I always thought of einsum as a convenient ad hoc interface, so I learned something new here (thanks!).
Closed/opened to trigger Travis. Once that passes and the …
Thanks for the quick review @dgasmith! I've moved the convert-all-chars implementation up to …
Thanks for this update, all looks good to me!
@fritzo Thanks for the PR, everything looks great and we will get this into the next release.
Description
This renames symbols to support PyTorch 0.4.1, which supports only the symbols a-z. This is useful when the overall computation involves more than 26 distinct indices but each individual contraction requires 26 or fewer.
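The use case above can be illustrated with a small sketch (`chain_equation` is a hypothetical helper for illustration, not part of opt_einsum): a chain of 30 matrix products needs 31 distinct indices overall, yet each pairwise contraction along the chain only ever touches three.

```python
# Hypothetical illustration (not opt_einsum code): a chain of n matrix
# products has n + 1 distinct indices overall, but each pairwise
# contraction in the path uses only 3 of them, e.g. 'ab,bc->ac'.
def chain_equation(n):
    """Einsum equation for multiplying n matrices in a chain."""
    syms = [chr(ord('a') + i) if i < 26 else chr(ord('A') + i - 26)
            for i in range(n + 1)]
    inputs = ','.join(syms[i] + syms[i + 1] for i in range(n))
    return inputs + '->' + syms[0] + syms[-1]

eq = chain_equation(30)
n_indices = len(set(eq) - set(',->'))  # 31 distinct indices overall
```

The full equation here cannot be handed to PyTorch 0.4.1 directly, but every pairwise step produced by the contraction path can, once its indices are renamed into a-z.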
Questions
Status
Tested