
Rename symbols to support torch.einsum #41

Merged: 3 commits into dgasmith:master, Aug 18, 2018

Conversation

@fritzo (Contributor) commented Aug 17, 2018

Description

This renames symbols to support PyTorch 0.4.1, which supports only the symbols a-z. It is useful when the overall computation has more than 26 dimensions but any single contraction requires only 26 or fewer dimensions.
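As a concrete illustration of the renaming (a hypothetical equation; the wrapper code itself appears in the diff below, and chr(ord('a') + i) stands in for opt_einsum's get_symbol helper):

equation = 'aB,Bc->ac'                              # 'B' is rejected by torch.einsum in 0.4.1
symbols = sorted(set(equation) - set(',->'))        # ['B', 'a', 'c']
rename = {s: chr(ord('a') + i) for i, s in enumerate(symbols)}
print(''.join(rename.get(s, s) for s in equation))  # 'ba,ac->bc'

The letters are permuted, but the contraction structure is preserved and every symbol now falls in a-z.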

Questions

  • Do other backends work? I have only tested PyTorch. If other backends don't work, I'll simply limit the test to the PyTorch backend.

Status

  • Ready to merge as soon as tests pass

Tested

@dgasmith (Owner) left a comment:

Overall LGTM. It looks like CI missed a trigger (very odd); can you make another commit to trigger CI when you get the chance?

@@ -31,6 +31,12 @@ def transpose(a, axes):
 def einsum(equation, *operands):
     """Variadic version of torch.einsum to match numpy api.
     """
+    # rename symbols to support PyTorch 0.4.1 and earlier,
+    # which allow only symbols a-z.
+    symbols = sorted(set(equation) - set(',->'))
@dgasmith (Owner):

Extreme, but can we throw if len(symbols) > 26?

@fritzo (Contributor, Author) commented Aug 18, 2018:

The advantage of not throwing is that this will continue to work when PyTorch fixes their einsum to allow more symbols. I'd rather let their current implementation throw.

@dgasmith (Owner):

Makes sense. I am not intimately familiar with the new torch symbols set, but would opt_einsum.parser.convert_to_valid_einsum_chars work here?

+    # which allow only symbols a-z.
+    symbols = sorted(set(equation) - set(',->'))
+    rename = {s: get_symbol(i) for i, s in enumerate(symbols)}
+    equation = ''.join(rename.get(s, s) for s in equation)
@dgasmith (Owner):

@jcmgray We do this for einsum as well I believe. Is it time to have a single expression which converts from global to local symbols?

@@ -39,8 +45,6 @@ def tensordot(x, y, axes=2):
     """Simple translation of tensordot syntax to einsum.
     """
     # XXX: tensordot should be directly implemented in torch soon
-    torch, _ = _get_torch_and_device()
@dgasmith (Owner):

@jcmgray Can you look over this as well?

@fritzo (Contributor, Author):

(Note: the only reason I've removed this line is so I can reuse the einsum() wrapper in this file, rather than calling torch.einsum in two different places.)
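For context, the tensordot-to-einsum translation he refers to can be routed through the einsum() wrapper along the following lines. This is only a rough sketch, not the PR's actual code; get_symbol and einsum are the module-level helpers assumed from the diffs above.

def tensordot(x, y, axes=2):
    # Translate tensordot's axes specification into an einsum equation and
    # dispatch through the einsum() wrapper, so symbol renaming happens in
    # exactly one place.
    xnd, ynd = x.dim(), y.dim()
    if isinstance(axes, int):
        axes = (tuple(range(xnd - axes, xnd)), tuple(range(axes)))
    axes_x, axes_y = axes

    # one subscript per dimension of x; contracted dimensions of y reuse
    # the matching subscript from x
    x_sub = [get_symbol(i) for i in range(xnd)]
    y_sub = [get_symbol(xnd + i) for i in range(ynd)]
    for ax, ay in zip(axes_x, axes_y):
        y_sub[ay] = x_sub[ax]

    # output keeps the uncontracted dimensions of x, then those of y
    out_sub = ([s for i, s in enumerate(x_sub) if i not in axes_x]
               + [s for i, s in enumerate(y_sub) if i not in axes_y])
    equation = '{},{}->{}'.format(''.join(x_sub), ''.join(y_sub), ''.join(out_sub))
    return einsum(equation, x, y)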

@dgasmith dgasmith requested a review from jcmgray August 18, 2018 00:25
@jcmgray (Collaborator) commented Aug 18, 2018:

Do other backends work? I have only tested PyTorch. If other backends don't work, I'll simply limit the test to the PyTorch backend.

einsum in numpy, cupy, dask and tensorflow (the last only recently) all support upper case letters, so the tests should be fine. The ideal situation would be for pytorch to add support for upper case, but the implementation is C++, so it's not super obvious to me how easy that would be (@t-vi?). Probably easier to add it here anyway!


It might be cleaner, and more useful for potential future backends, to modify the existing machinery:

  • opt_einsum.parser.convert_to_valid_einsum_chars
  • opt_einsum.parser.has_valid_einsum_chars_only
  • opt_einsum.parser.is_valid_einsum_char

probably with an allow_uppercase keyword, which could just switch is_valid_einsum_char between checking against:

import string
string.ascii_lowercase  # or
string.ascii_letters

Also, for what it's worth, your strategy of just replacing all characters might be faster than replacing only the invalid ones, in which case do update convert_to_valid_einsum_chars with your snippet.
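For illustration, the allow_uppercase switch described above might look roughly like this (the keyword name and exact signatures are a sketch of the suggestion, not the existing API):

import string

def is_valid_einsum_char(char, allow_uppercase=False):
    # structural characters are always allowed; index characters are checked
    # against either lowercase only or all ASCII letters
    valid = string.ascii_letters if allow_uppercase else string.ascii_lowercase
    return char in valid or char in ',->.'

def has_valid_einsum_chars_only(einsum_str, allow_uppercase=False):
    # valid only if every character in the subscripts string passes the check
    return all(is_valid_einsum_char(c, allow_uppercase) for c in einsum_str)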

@t-vi commented Aug 18, 2018 via email

@fritzo (Contributor, Author) commented Aug 18, 2018:

@t-vi lowercase a-z is all anyone ever uses in the wild

We're using opt_einsum.contract() for hundreds of variables in Pyro. Each local contraction requires only a few variables (hence this PR), but we have one variable per time step in a Hidden Markov Model.

@jcmgray (Collaborator) commented Aug 18, 2018:

Yes, just to clarify: opt_einsum already maps pairwise contractions into the [a-zA-Z] range if necessary (since large contractions can have thousands of indices). The real problem is if you want 26+ indices in a single pairwise contraction, i.e. a tensor with 26+ dimensions. Sounds unusual, but this is actually quite a likely/necessary situation if, for example, you are simulating large quantum circuits.

Now, opt_einsum tries to call tensordot as much as possible, so will probably avoid this niche case (torch backend + non-tensordot-able contractions + >26 dimensions), but it's certainly possible!
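For readers unfamiliar with the symbol machinery: opt_einsum.parser.get_symbol is what generates the per-index characters, starting from the ASCII letters and moving on to other Unicode characters for larger indices. A small illustration (the exact characters produced beyond the lowercase range are an implementation detail):

from opt_einsum.parser import get_symbol

# the first 26 symbols are the lowercase letters torch.einsum 0.4.1 accepts
print(get_symbol(0), get_symbol(25))   # a z

# indices beyond that produce characters outside a-z (uppercase, then other
# Unicode), which is why the torch backend remaps each contraction's symbols
# back into a-z
print(get_symbol(26))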

@t-vi commented Aug 18, 2018 via email

@dgasmith dgasmith closed this Aug 18, 2018
@dgasmith dgasmith reopened this Aug 18, 2018
@dgasmith (Owner):

Closed/reopened to trigger Travis. Once that passes and opt_einsum.parser.convert_to_valid_einsum_chars is either used or updated, this is ready to go.

@codecov-io commented Aug 18, 2018:

Codecov Report

Merging #41 into master will decrease coverage by 0.01%.
The diff coverage is 100%.

@fritzo (Contributor, Author) commented Aug 18, 2018:

Thanks for the quick review @dgasmith! I've moved the convert-all-chars implementation up to convert_to_valid_einsum_chars and used it in the torch backend.
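Putting the pieces of the diff together, the updated parser helper amounts to roughly the following (a sketch assembled from the snippets above, not necessarily the exact merged code):

from opt_einsum.parser import get_symbol

def convert_to_valid_einsum_chars(einsum_str):
    # remap every index symbol onto the canonical symbol sequence, leaving
    # the structural characters ',', '-', '>' untouched
    symbols = sorted(set(einsum_str) - set(',->'))
    replacer = {s: get_symbol(i) for i, s in enumerate(symbols)}
    return ''.join(replacer.get(c, c) for c in einsum_str)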

@jcmgray (Collaborator) left a comment:

Thanks for this update, all looks good to me!

@dgasmith (Owner):

@fritzo Thanks for the PR, everything looks great and we will get this into the next release.

@dgasmith dgasmith merged commit 49e2a91 into dgasmith:master Aug 18, 2018
@dgasmith dgasmith mentioned this pull request Aug 25, 2018
@dgasmith dgasmith modified the milestones: v2.1, v2.2 Aug 25, 2018