
fix _convert_simple_rnn #15723

Merged · 3 commits · Sep 13, 2023
Conversation

@haoyang9804 (Contributor) commented Sep 11, 2023

Fix this issue

Just as @echuraev guessed, _convert_simple_rnn has some logical errors. I'm not entirely sure I fixed it correctly, but after this fix, running the following bug-triggering script produces a good compilation result: InferType() successfully infers all types/shapes in the model, and the inference result is correct.

import tvm
import tvm.relay as relay
from tensorflow import keras
from tensorflow.keras import layers, models

input_shape = (2, 2, 2)
x = layers.Input(shape=input_shape[1:], dtype='float32')

layer = keras.layers.SimpleRNN(units=2)
layer.set_weights(layer.get_weights())

y = layer(x)
model = models.Model(x, y)
model.summary()
mod, params = relay.frontend.from_keras(model, {'input_1': input_shape})
mod = relay.transform.InferType()(mod)

print(mod)
with tvm.transform.PassContext(opt_level=3):
    model = relay.build_module.create_executor("vm", mod, tvm.cpu(0), 'llvm', params).evaluate()

The compilation result is

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 2, 2)]            0         
                                                                 
 simple_rnn (SimpleRNN)      (None, 2)                 10        
                                                                 
=================================================================
Total params: 10 (40.00 Byte)
Trainable params: 10 (40.00 Byte)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
def @main(%input_1: Tensor[(2, 2, 2), float32] /* ty=Tensor[(2, 2, 2), float32] */, %v_param_2: Tensor[(2, 4), float32] /* ty=Tensor[(2, 4), float32] */, %v_param_4: Tensor[(2), float32] /* ty=Tensor[(2), float32] */, %v_param_1: Tensor[(1, 2), float32] /* ty=Tensor[(1, 2), float32] */, %v_param_3: Tensor[(2, 2), float32] /* ty=Tensor[(2, 2), float32] */) -> Tensor[(2, 2), float32] {
  %0 = nn.batch_flatten(%input_1) /* ty=Tensor[(2, 4), float32] */;
  %1 = nn.dense(%0, %v_param_2, units=2) /* ty=Tensor[(2, 2), float32] */;
  %2 = nn.bias_add(%1, %v_param_4) /* ty=Tensor[(2, 2), float32] */;
  %3 = split(%2, indices_or_sections=[1], axis=1) /* ty=(Tensor[(2, 1), float32], Tensor[(2, 1), float32]) */;
  %4 = nn.batch_flatten(%v_param_1) /* ty=Tensor[(1, 2), float32] */;
  %5 = %3.0 /* ty=Tensor[(2, 1), float32] */;
  %6 = nn.dense(%4, %v_param_3, units=2) /* ty=Tensor[(1, 2), float32] */;
  %7 = add(%5, %6) /* ty=Tensor[(2, 2), float32] */;
  %8 = %3.0 /* ty=Tensor[(2, 1), float32] */;
  %9 = nn.dense(%7, %v_param_3, units=2) /* ty=Tensor[(2, 2), float32] */;
  add(%8, %9) /* ty=Tensor[(2, 2), float32] */
}

One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.

@haoyang9804 (Contributor, Author) commented:

cc @vvchernov @echuraev

@vvchernov (Contributor) commented Sep 11, 2023

Hello @haoyang9804! I've looked through your code, but I'm not so familiar with it. It's good that you added a test and it works successfully, but I should warn that NumPy does not support bfloat16, while TVM and other frameworks do.

@haoyang9804 (Contributor, Author) commented Sep 11, 2023

> Hello @haoyang9804! I've looked through your code, but I'm not so familiar with it. It's good that you added a test and it works successfully, but I should warn that NumPy does not support bfloat16, while TVM and other frameworks do.

Thanks for the reply, but I think my patch is not related to bfloat16.

weightList0 = weightList[0].transpose([1, 0])
assert len(in_data.type_annotation.shape) == 3
for i in range(in_data.type_annotation.shape[1].value - 1):
    weightList0 = np.hstack((weightList0, weightList[0].transpose([1, 0])))
Review comment (Contributor):

I'm aware of this line

Reply (Contributor, Author):

Could you please elaborate? I still cannot see any data type issue here.

Review comment (Contributor):

Could you check your code with bfloat16 weights? numpy.hstack has a dtype argument, and I guess it possibly checks it; if so, numpy fails when the dtype is bfloat16.

Reply (Contributor, Author):

Oh I see. I will check it later. Thx

Reply (Contributor, Author):

It's true that numpy.hstack does not support bfloat16. But to the best of my understanding, weightList[0] and weightList0 can never be bfloat16. These two variables come from weightList = keras_layer.get_weights(), which returns NumPy arrays. Since NumPy does not support bfloat16, the dtype of weightList[0] can never be bfloat16, so this worry seems unnecessary here.
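A quick way to confirm this reasoning: stock NumPy has no native bfloat16 dtype, so any array obtained from get_weights() cannot carry it. A minimal check (note that extensions such as ml_dtypes can register a bfloat16 dtype, but plain NumPy cannot construct one):

```python
import numpy as np

def numpy_supports(dtype_name: str) -> bool:
    """Return True if NumPy can construct the named dtype."""
    try:
        np.dtype(dtype_name)
        return True
    except TypeError:
        return False

print(numpy_supports("float32"))   # True
print(numpy_supports("bfloat16"))  # False with stock NumPy
```

So if np.hstack is handed get_weights() output, its inputs are regular NumPy dtypes (typically float32) and the bfloat16 failure mode cannot arise on this path.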

@vvchernov (Contributor) left a review:

LGTM. Please fix the lint and other CI issues.

@vvchernov (Contributor) commented:

cc @echuraev

@haoyang9804 (Contributor, Author) commented:

Everything seems fine except for the gpu/docs CI tests. May I ask how to fix them? I've never run into this issue before, @vvchernov.

@vvchernov (Contributor) commented Sep 12, 2023

Hello @echuraev! Looks like it is a problem with Jenkins:

[2023-09-12T17:54:08.315Z] + ./ci/scripts/jenkins/s3.py --action upload --bucket tvm-jenkins-artifacts-prod --prefix tvm/PR-15723/3/pytest-results/docs_GPU --items build/pytest-results

[2023-09-12T17:54:08.315Z] [ci/scripts/jenkins/s3.py:99 INFO] Namespace(action='upload', bucket='tvm-jenkins-artifacts-prod', items=['build/pytest-results'], prefix='tvm/PR-15723/3/pytest-results/docs_GPU')

[2023-09-12T17:54:08.315Z] [ci/scripts/jenkins/s3.py:109 INFO] Using s3 path: s3://tvm-jenkins-artifacts-prod/tvm/PR-15723/3/pytest-results/docs_GPU

[2023-09-12T17:54:08.315Z] Traceback (most recent call last):

[2023-09-12T17:54:08.315Z]   File "./ci/scripts/jenkins/s3.py", line 150, in <module>

[2023-09-12T17:54:08.315Z]     raise RuntimeError(f"Cannot upload empty folder with name: {item}")

[2023-09-12T17:54:08.315Z] RuntimeError: Cannot upload empty folder with name: build/pytest-results

Do you know who can help us?

@vvchernov (Contributor) commented Sep 12, 2023

Hello @haoyang9804! I think it is not your issue. I've rechecked other PRs: #15714 and #15709 have the same issue, but subsequent PRs do not. Probably the best way forward is to restart the CI.

@echuraev (Contributor) commented:

@tvm-bot rerun

@echuraev (Contributor) commented:

@haoyang9804, I have restarted CI. In case of further errors, please try to rebase your branch to the latest mainline.

@echuraev (Contributor) left a review:

LGTM. Thanks!
