I built a minimal model containing a single conv2d layer, set its weights and bias, created an input, and tried to compile it with torch-mlir. It fails when I set use_tracing=True. I then debugged it with torch-mlir-opt, and it seems that torch.tensor_static_info_cast does something I don't understand. I would like to know why compiling the model via tracing fails while compiling it via scripting succeeds. The code and error are shown below.
Code:

```python
import torch
import torch.nn as nn
import torch_mlir


class convtest(nn.Module):
    def __init__(self, n_rnn=2, leakyRelu=False):
        super(convtest, self).__init__()
        # Single conv layer: 1 input channel, 64 output channels, 3x3 kernel, stride 1, padding 1
        self.cnn = nn.Conv2d(1, 64, 3, 1, 1)

    def forward(self, input):
        conv = self.cnn(input)
        return conv


input = torch.ones(1, 1, 32, 100)
model = convtest()

# Overwrite the conv parameters with fixed values
weight = torch.ones(64, 1, 3, 3)
bias = torch.zeros(64)
for item in model.modules():
    if isinstance(item, nn.Conv2d):
        item.weight = nn.Parameter(weight)
        item.bias = nn.Parameter(bias)

# Fails with use_tracing=True; the scripting path (use_tracing=False) succeeds
module = torch_mlir.compile(model, input, output_type=torch_mlir.OutputType.TOSA, use_tracing=True)
print("convert-to-tosa")

# Dump the TOSA IR to a file
asm = module.operation.get_asm(large_elements_limit=10, enable_debug_info=True)
filename = "./crnn_tosa.mlir"
with open(filename, "w") as f:
    f.write(asm)
print("write tosa mlir to %s" % (filename))
```
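For reference, the static output shape that shape inference should be able to derive for this conv can be worked out by hand with the output-size formula from the torch.nn.Conv2d documentation. The helper below (`conv2d_output_shape` is a hypothetical name, not a torch-mlir or PyTorch function) is only a sketch of the kind of computation a shape transfer function performs, applied to the repro's input [1, 1, 32, 100], weight [64, 1, 3, 3], stride 1 and padding 1.

```python
def conv2d_output_shape(input_shape, weight_shape, stride=(1, 1), padding=(0, 0), dilation=(1, 1)):
    """Hypothetical helper: NCHW output shape of a 2-D convolution.

    Uses H_out = floor((H + 2*p - d*(k - 1) - 1) / s) + 1 (and likewise for W),
    the formula given in the torch.nn.Conv2d documentation.
    """
    n, _, h, w = input_shape
    c_out, _, kh, kw = weight_shape

    def out_dim(size, k, s, p, d):
        # Floor division matches the floor() in the formula for non-negative sizes
        return (size + 2 * p - d * (k - 1) - 1) // s + 1

    h_out = out_dim(h, kh, stride[0], padding[0], dilation[0])
    w_out = out_dim(w, kw, stride[1], padding[1], dilation[1])
    return [n, c_out, h_out, w_out]


# nn.Conv2d(1, 64, 3, 1, 1) applied to torch.ones(1, 1, 32, 100)
print(conv2d_output_shape([1, 1, 32, 100], [64, 1, 3, 3], stride=(1, 1), padding=(1, 1)))
# prints [1, 64, 32, 100]
```

With padding 1 and a 3x3 kernel the spatial size is preserved, so every dimension of the result is statically known from the input and weight shapes alone.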
Error:

```text
Traceback (most recent call last):
  File "/host/workspace/torch2tosa_examples/crnn2tosa/con2dtest.py", line 32, in <module>
    module = torch_mlir.compile(model, input, output_type=torch_mlir.OutputType.TOSA, use_tracing=True)
  File "/host/workspace/MTensorRT/third_party/torch-mlir/buildpurec/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/__init__.py", line 157, in compile
    run_pipeline_with_repro_report(
  File "/host/workspace/MTensorRT/third_party/torch-mlir/buildpurec/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/compiler_utils.py", line 49, in run_pipeline_with_repro_report
    raise Exception(f"""
Exception:
Lowering Torch Backend IR -> TOSA Backend IR failed with the following diagnostics:
error: unsupported by backend lowering: tensor with unknown rank or dtype
note: see current operation: %6 = "torch.tensor_static_info_cast"(%arg0) : (!torch.vtensor<[1,1,32,100],f32>) -> !torch.vtensor<*,f32>
note: this is likely due to a missing shape transfer function in shape_lib_gen.py

Error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='torch-backend-to-tosa-backend-pipeline' /tmp/convtest.mlir
Add '-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.
```
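The second note shows what the cast does to the static type: its operand is `!torch.vtensor<[1,1,32,100],f32>` (rank 4, every dimension known) and its result is `!torch.vtensor<*,f32>`, where `*` means the rank is unknown. torch.tensor_static_info_cast only adds or removes statically known type information; it does not touch the data. The backend then rejects the result, because, as the first error line says, it does not support tensors with unknown rank or dtype. The parser below is a hypothetical illustration (not part of torch-mlir) that reads the simple type strings from this log and applies a simplified version of that rank/dtype check; treating `?` as an unknown dimension size is an assumption about the printed form.

```python
import re
from typing import List, NamedTuple, Optional


class VTensorInfo(NamedTuple):
    shape: Optional[List[Optional[int]]]  # None = unknown rank; None entry = unknown dim size
    dtype: Optional[str]                  # None = unknown dtype


# Matches "!torch.vtensor", "!torch.vtensor<*,f32>" and "!torch.vtensor<[1,1,32,100],f32>"
_VTENSOR_RE = re.compile(r"^!torch\.vtensor(?:<(?P<shape>\*|\[[^\]]*\]),(?P<dtype>\w+)>)?$")


def parse_vtensor(type_str: str) -> VTensorInfo:
    """Sketch: parse a simple !torch.vtensor type string as printed in the log."""
    m = _VTENSOR_RE.match(type_str.strip())
    if m is None:
        raise ValueError(f"unrecognized type: {type_str}")
    shape_txt = m.group("shape")
    if shape_txt is None:
        # Bare "!torch.vtensor": neither shape nor dtype is known
        return VTensorInfo(shape=None, dtype=None)
    if shape_txt == "*":
        shape = None  # unknown rank
    else:
        inner = shape_txt[1:-1]
        # Assumption: "?" marks a dimension whose size is unknown
        shape = [None if d.strip() == "?" else int(d) for d in inner.split(",")] if inner else []
    return VTensorInfo(shape=shape, dtype=m.group("dtype"))


def backend_accepts(info: VTensorInfo) -> bool:
    """Simplified version of the check in the error: rank and dtype must both be known."""
    return info.shape is not None and info.dtype is not None


before = parse_vtensor("!torch.vtensor<[1,1,32,100],f32>")
after = parse_vtensor("!torch.vtensor<*,f32>")
print(before, backend_accepts(before))  # VTensorInfo(shape=[1, 1, 32, 100], dtype='f32') True
print(after, backend_accepts(after))    # VTensorInfo(shape=None, dtype='f32') False
```

The dtype survives the cast, so it is the loss of the rank (the `*`) that trips the "unknown rank or dtype" check here.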