
Add more manual implementations for TF's einsum #19368

Merged · 3 commits · Mar 24, 2024

Conversation

james77777778 (Contributor)

We need these new subscripts so that the backprop in #19356 can use the int8,int8->int32 einsum:

                inputs_grad = ops.einsum(
                    # self._custom_gradient_equation is '{output_spec},{weight_spec}->{input_spec}'
                    self._custom_gradient_equation, upstream, float_kernel
                )

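As a sanity check (mine, not part of the PR), the swapped equation can be verified with plain NumPy: for a forward equation like `abc,cd->abd`, the input gradient is obtained by contracting the upstream gradients with the kernel via `abd,cd->abc`, which matches the matmul identity `upstream @ kernel.T`. The shapes below are arbitrary illustration values:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4))         # inputs:  spec "abc"
w = rng.standard_normal((4, 5))            # kernel:  spec "cd"
upstream = rng.standard_normal((2, 3, 5))  # upstream grads: spec "abd"

# Forward: y = einsum("abc,cd->abd", x, w)
# Gradient equation follows "{output_spec},{weight_spec}->{input_spec}".
inputs_grad = np.einsum("abd,cd->abc", upstream, w)

# Same result as contracting upstream with the transposed kernel.
assert np.allclose(inputs_grad, upstream @ w.T)
```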
The following script finds all of these subscripts:

import re

from keras.backend.tensorflow.numpy import _normalize_einsum_subscripts

currently_supported_subscripts = [
    "a,b->ab",
    "ab,bc->ac",
    "abc,cd->abd",
    "abcd,abed->abce",
    "abcd,adbe->acbe",
    "abcd,aecd->acbe",
    "abcd,aecd->aceb",
    "abc,cde->abde",
    "abc,dce->abde",
    "abcd,cde->abe",
    "abcde,aebf->adbcf",
    "abcde,afce->acdbf",
]


def custom_gradient_subscripts(equation):
    # Replace ellipses so they don't silently slip through the regex below.
    dot_replaced_string = re.sub(r"\.\.\.", "0", equation)
    split_string = re.match(
        "([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)", dot_replaced_string
    )
    if split_string is None:
        raise ValueError(f"Cannot parse einsum equation: {equation}")
    input_spec = split_string.group(1)
    weight_spec = split_string.group(2)
    output_spec = split_string.group(3)
    # Gradient w.r.t. the inputs: contract upstream grads with the kernel.
    return f"{output_spec},{weight_spec}->{input_spec}"


possible_subscripts = []
for s in currently_supported_subscripts:
    t = _normalize_einsum_subscripts(custom_gradient_subscripts(s))
    possible_subscripts.append(t)

supported_subscripts = set(currently_supported_subscripts)
possible_subscripts = set(possible_subscripts)
missing_subscripts = possible_subscripts - supported_subscripts
print(sorted(missing_subscripts))

# ['ab,b->a', 'ab,cb->ac', 'abc,dc->abd', 'abc,dec->abde', 'abcd,abde->abce',
# 'abcd,acbe->adbe', 'abcd,ced->abe', 'abcd,ecd->abe']

@james77777778 james77777778 changed the title Add more manual implementations for tf's einsum Add more manual implementations for TF's einsum Mar 24, 2024
codecov-commenter commented Mar 24, 2024

Codecov Report

Attention: Patch coverage is 90.32258%, with 6 lines in your changes missing coverage. Please review.

Project coverage is 75.86%. Comparing base (aa3a61b) to head (edcaa97).
Report is 2 commits behind head on master.

Files Patch % Lines
keras/backend/tensorflow/numpy.py 90.32% 3 Missing and 3 partials ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##           master   #19368   +/-   ##
=======================================
  Coverage   75.86%   75.86%           
=======================================
  Files         366      366           
  Lines       40417    40479   +62     
  Branches     7855     7869   +14     
=======================================
+ Hits        30661    30711   +50     
- Misses       8061     8068    +7     
- Partials     1695     1700    +5     
Flag Coverage Δ
keras 75.72% <90.32%> (+<0.01%) ⬆️
keras-jax 60.02% <0.00%> (-0.10%) ⬇️
keras-numpy 54.29% <0.00%> (-0.09%) ⬇️
keras-tensorflow 61.32% <90.32%> (+0.04%) ⬆️
keras-torch 60.31% <0.00%> (-0.11%) ⬇️


@fchollet (Member) left a comment

LGTM, thank you.

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Mar 24, 2024
@fchollet fchollet merged commit 7d7917c into keras-team:master Mar 24, 2024
6 checks passed
@google-ml-butler google-ml-butler bot removed ready to pull Ready to be merged into the codebase kokoro:force-run labels Mar 24, 2024
@james77777778 james77777778 deleted the improve-tf-int8-einsum branch March 25, 2024 04:20