Torch export can't support dynamic shapes for LSTM #5176

Open
ZORO-Q opened this issue Sep 9, 2024 · 1 comment

ZORO-Q commented Sep 9, 2024

### 🐛 Describe the bug

```python
import torch
from torch import nn
from torch.export import Dim, ExportedProgram, export


class Basic(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(320, 64, num_layers=1, bidirectional=False, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, _ = self.lstm(x)
        return x


example_args = (torch.randn(1, 100, 320),)
# Mark the sequence-length dimension (dim 1) as dynamic.
dim1_x = Dim("dim1_x", min=6, max=100000)
dynamic_shapes = {"x": {1: dim1_x}}
aten_dialect: ExportedProgram = export(
    Basic(), example_args, dynamic_shapes=dynamic_shapes
)
print(aten_dialect)
```

```
Traceback (most recent call last):
  File "<string>", line 14, in <module>
  File "/home/yaqiong.he/.conda/envs/executorch/lib/python3.10/site-packages/torch/export/__init__.py", line 174, in export
    return _export(
  File "/home/yaqiong.he/.conda/envs/executorch/lib/python3.10/site-packages/torch/export/_trace.py", line 945, in wrapper
    raise e
  File "/home/yaqiong.he/.conda/envs/executorch/lib/python3.10/site-packages/torch/export/_trace.py", line 928, in wrapper
    ep = fn(*args, **kwargs)
  File "/home/yaqiong.he/.conda/envs/executorch/lib/python3.10/site-packages/torch/export/exported_program.py", line 89, in wrapper
    return fn(*args, **kwargs)
  File "/home/yaqiong.he/.conda/envs/executorch/lib/python3.10/site-packages/torch/export/_trace.py", line 1455, in _export
    aten_export_artifact = export_func(
  File "/home/yaqiong.he/.conda/envs/executorch/lib/python3.10/site-packages/torch/export/_trace.py", line 1060, in _strict_export
    gm_torch_level = _export_to_torch_ir(
  File "/home/yaqiong.he/.conda/envs/executorch/lib/python3.10/site-packages/torch/export/_trace.py", line 529, in _export_to_torch_ir
    raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904
torch._dynamo.exc.UserError: Constraints violated (dim1_x)! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of dim1_x = L['x'].size()[1] in the specified range 6 <= dim1_x <= 100000 are valid because dim1_x was inferred to be a constant (100).

Suggested fixes:
  dim1_x = 100
```

### Versions

PyTorch version: 2.5.0.dev20240901
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.5 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.29.6
Libc version: N/A

Python version: 3.10.0 (default, Mar 3 2022, 03:54:28) [Clang 12.0.0 ] (64-bit runtime)
Python platform: macOS-14.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M3

Versions of relevant libraries:
[pip3] executorch==0.4.0a0+99fbca3
[pip3] numpy==1.21.3
[pip3] torch==2.5.0.dev20240901
[pip3] torchaudio==2.5.0.dev20240901
[pip3] torchsr==1.0.4
[pip3] torchvision==0.20.0.dev20240901
[conda] executorch 0.4.0a0+99fbca3 pypi_0 pypi
[conda] numpy 1.21.3 pypi_0 pypi
[conda] torch 2.5.0.dev20240901 pypi_0 pypi
[conda] torchaudio 2.5.0.dev20240901 pypi_0 pypi
[conda] torchsr 1.0.4 pypi_0 pypi
[conda] torchvision 0.20.0.dev20240901 pypi_0 pypi

davidlin54 commented Sep 9, 2024

@angelayi this seems to be related to pytorch/pytorch#115092. Are there any updates on fixing this bug?
