Operator torch._ops.aten.unfold.default is not Aten Canonical #5381

Open
ari-ruokamo opened this issue Sep 15, 2024 · 9 comments
Labels
module: exir Issues related to Export IR

Comments

@ari-ruokamo

ari-ruokamo commented Sep 15, 2024

🐛 Describe the bug

I'm experimenting with exporting various MSS models to ExecuTorch. Following the example export scenario in the ExecuTorch documentation, the export terminates with the error: 'Operator torch._ops.aten.unfold.default is not Aten Canonical'.

Sample reference code input:

import torch
from torch.export import export
from executorch.exir import to_edge

# Instantiate model & load weights
model = MyMSSModel(**config.model)
model = load(model, checkpoint_path)
model.eval()

# Dummy one-second tensor
model_args = (torch.rand(1, 2, 44100),)

# 1. torch.export: Defines the program with the ATen operator set.
aten_dialect = export(model, model_args)

# 2. to_edge: Make optimizations for Edge devices
edge_program = to_edge(aten_dialect)  ### <-- FAILS

# 3. to_executorch: Convert the graph to an ExecuTorch program
executorch_program = edge_program.to_executorch()

Export output:

Traceback (most recent call last):
  File "/home/arijr/miniconda3/envs/executorch/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/arijr/miniconda3/envs/executorch/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/arijr/projects/SCNet/scnet/export-executorch.py", line 29, in <module>
    edge_program = to_edge(aten_dialect)
  File "/home/arijr/miniconda3/envs/executorch/lib/python3.10/site-packages/executorch/exir/program/_program.py", line 1019, in to_edge
    edge_programs[name] = _generate_edge_program(name, config, program)
  File "/home/arijr/miniconda3/envs/executorch/lib/python3.10/site-packages/executorch/exir/program/_program.py", line 680, in _generate_edge_program
    EXIRATenDialectVerifier(ops_set_to_not_decompose)(program.graph_module)
  File "/home/arijr/miniconda3/envs/executorch/lib/python3.10/site-packages/executorch/exir/verification/verifier.py", line 74, in __call__
    return self._check_graph_module(*args, **kwargs)
  File "/home/arijr/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/_export/verifier.py", line 222, in _check_graph_module
    _check_valid_op(node.target)
  File "/home/arijr/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/_export/verifier.py", line 205, in _check_valid_op
    self.check_valid_op(op)
  File "/home/arijr/miniconda3/envs/executorch/lib/python3.10/site-packages/executorch/exir/verification/verifier.py", line 116, in check_valid_op
    raise SpecViolationError(
**torch._export.verifier.SpecViolationError: Operator torch._ops.aten.unfold.default is not Aten Canonical.**

Versions

Collecting environment information...
PyTorch version: 2.4.0+cpu
Is debug build: False
CUDA used to build PyTorch: Could not collect
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04 LTS (x86_64)
GCC version: (Ubuntu 13.2.0-23ubuntu4) 13.2.0
Clang version: Could not collect
CMake version: version 3.30.3
Libc version: glibc-2.39

Python version: 3.10.0 (default, Mar 3 2022, 09:58:08) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-6.8.0-40-generic-x86_64-with-glibc2.39
Is CUDA available: False
CUDA runtime version: 12.2.91
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 3090
GPU 1: NVIDIA GeForce RTX 3090
GPU 2: NVIDIA GeForce RTX 3090

Nvidia driver version: 535.183.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.3.3
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.3.3
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.3.3
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.3.3
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.3.3
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.3.3
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.3.3
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 32
On-line CPU(s) list: 0-31
Vendor ID: AuthenticAMD
Model name: AMD Ryzen 9 5950X 16-Core Processor
CPU family: 25
Model: 33
Thread(s) per core: 2
Core(s) per socket: 16
Socket(s): 1
Stepping: 2
Frequency boost: enabled
CPU(s) scaling MHz: 58%
CPU max MHz: 5083,3979
CPU min MHz: 2200,0000
BogoMIPS: 6799,81
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk clzero irperf xsaveerptr rdpru wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm debug_swap
L1d cache: 512 KiB (16 instances)
L1i cache: 512 KiB (16 instances)
L2 cache: 8 MiB (16 instances)
L3 cache: 64 MiB (2 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-31
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; Safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] executorch==0.3.0a0+7d77d78
[pip3] numpy==2.1.1
[pip3] torch==2.4.0+cpu
[pip3] torchaudio==2.4.0+cpu
[pip3] torchsr==1.0.4
[pip3] torchvision==0.19.0
[conda] executorch 0.3.0a0+7d77d78 pypi_0 pypi
[conda] numpy 2.1.1 pypi_0 pypi
[conda] torch 2.4.0+cpu pypi_0 pypi
[conda] torchaudio 2.4.0+cpu pypi_0 pypi
[conda] torchsr 1.0.4 pypi_0 pypi
[conda] torchvision 0.19.0 pypi_0 pypi

@guangy10
Contributor

guangy10 commented Sep 17, 2024

All existing core ATen ops can be found here: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
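For context on where the op comes from: `aten.unfold` extracts sliding windows along a dimension, and `torch.stft` frames its input signal this way before the FFT, which is why the op can appear even when a model never calls `Tensor.unfold` directly. A minimal pure-Python sketch of the framing semantics (illustrative only, not the ATen implementation):

```python
def unfold_1d(signal, size, step):
    """Sliding windows of length `size` taken every `step` samples,
    mirroring the semantics of aten.unfold along the last dimension."""
    n_frames = (len(signal) - size) // step + 1
    return [signal[i * step : i * step + size] for i in range(n_frames)]

# 10 samples, window 4, hop 2 -> 4 overlapping frames
frames = unfold_1d(list(range(10)), size=4, step=2)
```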

cc: @SS-JIA for core aten opset

@guangy10 guangy10 added the module: exir Issues related to Export IR label Sep 17, 2024
@ari-ruokamo
Author

I know this might not be the solution, but I tried disabling the dialect op check with:

edge_program = to_edge(aten_dialect, compile_config=EdgeCompileConfig(_check_ir_validity=False))

With that, the to_edge() phase completes, but the next phase terminates with an error:

Traceback (most recent call last):
  File "/home/arijr/miniconda3/envs/executorch/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/arijr/miniconda3/envs/executorch/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/arijr/projects/SCNet/scnet/export-executorch.py", line 33, in <module>
    executorch_program = edge_program.to_executorch()
  File "/home/arijr/miniconda3/envs/executorch/lib/python3.10/site-packages/executorch/exir/program/_program.py", line 1196, in to_executorch
    new_gm_res = p(new_gm)
  File "/home/arijr/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/fx/passes/infra/pass_base.py", line 41, in __call__
    res = self.call(graph_module)
  File "/home/arijr/miniconda3/envs/executorch/lib/python3.10/site-packages/executorch/exir/passes/__init__.py", line 423, in call
    raise RuntimeError(f"Missing out variants: {missing_out_vars}")
RuntimeError: Missing out variants: {'aten::_unsafe_index_put'}

Any tips or assistance on how to mitigate these issues would be highly appreciated - thanks!

@ari-ruokamo
Author

ari-ruokamo commented Sep 28, 2024

An example/dummy forward-function:

    def forward(self, x):
        L = x.shape[-1]
        x = x.reshape(-1, L)
        x = torch.stft(x, **self.stft_config, return_complex=True)

        # NOP

        x = torch.istft(x, **self.stft_config)

        return x

Export trials:

  1. Export with config (_check_ir_validity=True) makes the to_edge() compilation fail at the first stft function:

--> **torch._export.verifier.SpecViolationError: Operator torch._ops.aten.unfold.default is not Aten Canonical.**

  2. Export with config (_check_ir_validity=False) makes edge_program.to_executorch() fail at the subsequent istft function:

--> RuntimeError: Missing out variants: {'aten::_unsafe_index_put'}

@divideconcept

Does this mean it's impossible to export any model that uses stft/istft? Is support planned?

@ari-ruokamo
Author

ari-ruokamo commented Oct 9, 2024

Here's the exported graph of the dummy forward function. Where does the export compiler find torch._ops.aten.unfold.default in there?

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[1, 2, 895088]"):
            # File: /home/arijr/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/nn/functional.py:4552 in pad, code: return torch._C._nn.pad(input, pad, mode, value)
            pad: "f32[1, 2, 896000]" = torch.ops.aten.pad.default(x, [0, 912]);  x = None
            
            # File: /home/arijr/projects/Mss/Mss.py:331 in forward, code: x = x.reshape(-1, L)
            view: "f32[2, 896000]" = torch.ops.aten.view.default(pad, [-1, 896000]);  pad = None
            
            # File: /home/arijr/projects/Mss/Mss.py:333 in forward, code: x = torch.stft(x, **self.stft_config, return_complex=True)
            view_1: "f32[1, 2, 896000]" = torch.ops.aten.view.default(view, [1, 2, 896000]);  view = None
            pad_1: "f32[1, 2, 900096]" = torch.ops.aten.pad.default(view_1, [2048, 2048], 'reflect');  view_1 = None
            view_2: "f32[2, 900096]" = torch.ops.aten.view.default(pad_1, [2, 900096]);  pad_1 = None
            stft: "c64[2, 2049, 876]" = torch.ops.aten.stft.default(view_2, 4096, 1024, 4096, None, True, None, True);  view_2 = None
            
            # File: /home/arijr/projects/Mss/Mss.py:338 in forward, code: x = torch.istft(x, **self.stft_config)
            istft: "f32[2, 896000]" = torch.ops.aten.istft.default(stft, 4096, 1024, 4096, None, True, True);  stft = None
            return (istft,)
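The stft node's output shape in that graph is consistent with framing via a sliding window (i.e. unfold): with n_fft=4096 and hop=1024 on the reflect-padded signal, the usual STFT shape formulas give exactly the 2049 bins and 876 frames shown. A quick sanity check (my own arithmetic, not from the original thread):

```python
n_fft, hop = 4096, 1024
padded_len = 896000 + 2 * 2048               # reflect pad of 2048 per side, per the graph
n_frames = (padded_len - n_fft) // hop + 1   # sliding-window (unfold) frame count
n_bins = n_fft // 2 + 1                      # one-sided spectrum
print(n_bins, n_frames)  # 2049 876, matching the c64[2, 2049, 876] stft output
```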

@ari-ruokamo
Author

ari-ruokamo commented Oct 27, 2024

Update: moved to a macOS machine, where I now hit:

RuntimeError: view_as_complex_copy does not support automatic differentiation for outputs with complex dtype.

Configuration:

OS: macOS 14.6.1 (arm64)
Clang version: 16.0.0 (clang-1600.0.26.3)
CMake version: version 3.30.5

Python version: 3.10.0 (default, Mar  3 2022, 03:54:28) [Clang 12.0.0 ] (64-bit runtime)
Python platform: macOS-14.6.1-arm64-arm-64bit
Is XNNPACK available: True

CPU:
Apple M2 Max

Versions of relevant libraries:
[pip3] executorch==0.4.0a0+6a085ff
[pip3] numpy==1.21.3
[pip3] torch==2.5.0
[pip3] torchaudio==2.5.0
[pip3] torchsr==1.0.4
[pip3] torchvision==0.20.0
[conda] executorch                0.4.0a0+6a085ff          pypi_0    pypi
[conda] numpy                     1.21.3                   pypi_0    pypi
[conda] torch                     2.5.0                    pypi_0    pypi
[conda] torchaudio                2.5.0                    pypi_0    pypi
[conda] torchsr                   1.0.4                    pypi_0    pypi
[conda] torchvision               0.20.0                   pypi_0    pypi

Running the dummy model export on macOS completes successfully with:

  • Fixed model input tensor
  • EdgeCompileConfig(_check_ir_validity=False)

However, when exporting the real model with the exact same export script, the export terminates at to_edge():

export(...):

ExportedProgram:
    class GraphModule(torch.nn.Module):

<clip>

             # File: /Users/ariruokamo/projects/Mss/Mss.py:392 in forward, code: x = torch.view_as_complex(x.contiguous())
            clone_19: "f32[4, 2049, 880, 2]" = torch.ops.aten.clone.default(permute_19, memory_format = torch.contiguous_format);  permute_19 = None
            view_as_complex: "c64[4, 2049, 880]" = torch.ops.aten.view_as_complex.default(clone_19);  clone_19 = None

to_edge(...):

Traceback (most recent call last):
  File "/opt/miniconda3/envs/executorch/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/miniconda3/envs/executorch/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/Users/ariruokamo/projects/Mss/export-executorch.py", line 63, in <module>
    executorch_program = edge_program.to_executorch()
  File "/opt/miniconda3/envs/executorch/lib/python3.10/site-packages/executorch/exir/program/_program.py", line 1322, in to_executorch
    new_gm_res = p(new_gm)
  File "/opt/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/fx/passes/infra/pass_base.py", line 41, in __call__
    res = self.call(graph_module)
  File "/opt/miniconda3/envs/executorch/lib/python3.10/site-packages/executorch/exir/pass_base.py", line 572, in call
    result = self.call_submodule(graph_module, tuple(inputs))
  File "/opt/miniconda3/envs/executorch/lib/python3.10/site-packages/executorch/exir/pass_base.py", line 658, in call_submodule
    res = super().call_submodule(graph_module, inputs)
  File "/opt/miniconda3/envs/executorch/lib/python3.10/site-packages/executorch/exir/pass_base.py", line 535, in call_submodule
    interpreter.run(*inputs_data)
  File "/opt/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/fx/interpreter.py", line 146, in run
    self.env[node] = self.run_node(node)
  File "/opt/miniconda3/envs/executorch/lib/python3.10/site-packages/executorch/exir/pass_base.py", line 375, in run_node
    return super().run_node(n)
  File "/opt/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/fx/interpreter.py", line 203, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/opt/miniconda3/envs/executorch/lib/python3.10/site-packages/executorch/exir/pass_base.py", line 607, in call_function
    return self.callback.call_operator(
  File "/opt/miniconda3/envs/executorch/lib/python3.10/site-packages/executorch/exir/passes/spec_prop_pass.py", line 96, in call_operator
    meta["spec"] = pytree.tree_map(make_spec, op(*args_data, **kwargs_data))
  File "/opt/miniconda3/envs/executorch/lib/python3.10/site-packages/executorch/exir/dialects/edge/_ops.py", line 333, in __call__
    return self._op(*args, **kwargs)
  File "/opt/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/_ops.py", line 716, in __call__
    return self._op(*args, **kwargs)
RuntimeError: view_as_complex_copy does not support automatic differentiation for outputs with complex dtype.

While executing %aten_view_as_complex_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.view_as_complex_copy.default](args = (%aten_clone_default_79,), kwargs = {})
Original traceback:
  File "/Users/ariruokamo/projects/Mss/Mss.py", line 392, in forward
    x = torch.view_as_complex(x.contiguous())

This is a little confusing: view_as_complex() gets the exact same tensor as in the dummy version. Any help appreciated!

@valkjsaaa

I have the same issue. Thanks for posting this.

@slycheese

Adding to the list here... I'm having the exact same issue.

aten_dialect = export() works fine, but
edge_program = to_edge(aten_dialect)

gives the verification error:

'Operator torch.ops.aten.unfold.default is not Aten Canonical'.

The model I'm trying to export also uses torch.stft.

As noted above by @ari-ruokamo, it's possible (at least for my model) to force my way to a lowered model by disabling the verification when lowering to edge, but the subsequent model crashes when I try to execute it via executor_runner (unsurprisingly!).

This is probably not useful, but just in case, this is the output when the model fails to run:

I 00:00:00.001028 executorch:executor_runner.cpp:82] Model file ../tmp/compute_features.pte is loaded.
I 00:00:00.001050 executorch:executor_runner.cpp:91] Using method forward
I 00:00:00.001061 executorch:executor_runner.cpp:138] Setting up planned buffer 0, size 5165040.
E 00:00:00.002096 executorch:tensor_parser_portable.cpp:49] Invalid or unsupported ScalarType 9
E 00:00:00.002103 executorch:method.cpp:386] Failed parsing tensor at index 29: 0x23
F 00:00:00.002111 executorch:executor_runner.cpp:160] In function main(), assert failed (method.ok()): Loading of method forward failed with status 0x23
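For what it's worth, decoding that log: in PyTorch's c10 ScalarType enum, value 9 is ComplexFloat (c64) as far as I can tell, which would mean the portable runtime is rejecting the complex tensors the STFT produces. A sketch of the enum values as I understand them from c10/core/ScalarType.h (worth verifying against your PyTorch version):

```python
# First entries of c10::ScalarType, per my reading of the enum ordering (assumption)
SCALAR_TYPE_NAMES = {
    0: "Byte", 1: "Char", 2: "Short", 3: "Int", 4: "Long",
    5: "Half", 6: "Float", 7: "Double",
    8: "ComplexHalf", 9: "ComplexFloat", 10: "ComplexDouble", 11: "Bool",
}
print(SCALAR_TYPE_NAMES[9])  # the c64 dtype appearing in the exported graphs above
```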

@slycheese

slycheese commented Nov 14, 2024

#5381 (comment)

Here's the graph output of the dummy forward-function. Where does the export compiler find a torch._ops.aten.unfold.default in there?


FWIW, my forced-to-edge graph does have aten.unfold.default operations. Here's the edge graph I got:

graph():
    %c_l__args___0__modules__compute_stft___window : [num_users=1] = placeholder[target=c_l__args___0__modules__compute_stft___window]
    %c_l__args___0__modules__compute_fbanks___f_central : [num_users=1] = placeholder[target=c_l__args___0__modules__compute_fbanks___f_central]
    %c_l__args___0__modules__compute_fbanks___all_freqs_mat : [num_users=1] = placeholder[target=c_l__args___0__modules__compute_fbanks___all_freqs_mat]
    %c_l__args___0__modules__compute_fbanks___band : [num_users=1] = placeholder[target=c_l__args___0__modules__compute_fbanks___band]
    %_lifted_tensor_constant0 : [num_users=1] = placeholder[target=_lifted_tensor_constant0]
    %_lifted_tensor_constant1 : [num_users=1] = placeholder[target=_lifted_tensor_constant1]
    %_lifted_tensor_constant2 : [num_users=1] = placeholder[target=_lifted_tensor_constant2]
    %_lifted_tensor_constant3 : [num_users=1] = placeholder[target=_lifted_tensor_constant3]
    %_lifted_tensor_constant4 : [num_users=1] = placeholder[target=_lifted_tensor_constant4]
    %wav : [num_users=1] = placeholder[target=wav]
    %aten_view_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.view_copy.default](args = (%wav, [1, 1, 207952]), kwargs = {})
    %aten_constant_pad_nd_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.constant_pad_nd.default](args = (%aten_view_copy_default, [200, 200], 0.0), kwargs = {})
    %aten_view_copy_default_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.view_copy.default](args = (%aten_constant_pad_nd_default, [1, 208352]), kwargs = {})
    %aten_unfold_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.unfold_copy.default](args = (%aten_view_copy_default_1, -1, 400, 160), kwargs = {})
    %aten_mul_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.mul.Tensor](args = (%aten_unfold_copy_default, %c_l__args___0__modules__compute_stft___window), kwargs = {})
    %aten__fft_r2c_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._fft_r2c.default](args = (%aten_mul_tensor, [2], 0, True), kwargs = {})
    %aten_permute_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.permute_copy.default](args = (%aten__fft_r2c_default, [0, 2, 1]), kwargs = {})
    %aten_view_as_real_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.view_as_real_copy.default](args = (%aten_permute_copy_default,), kwargs = {})
    %aten_permute_copy_default_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.permute_copy.default](args = (%aten_view_as_real_copy_default, [0, 2, 1, 3]), kwargs = {})
    %aten_pow_tensor_scalar : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.pow.Tensor_Scalar](args = (%aten_permute_copy_default_1, 2), kwargs = {})
    %aten_sum_dim_int_list : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.sum.dim_IntList](args = (%aten_pow_tensor_scalar, [-1]), kwargs = {})
    %aten_pow_tensor_scalar_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.pow.Tensor_Scalar](args = (%aten_sum_dim_int_list, 1), kwargs = {})
    %aten_repeat_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.repeat.default](args = (%c_l__args___0__modules__compute_fbanks___f_central, [201, 1]), kwargs = {})
    %aten_permute_copy_default_2 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.permute_copy.default](args = (%aten_repeat_default, [1, 0]), kwargs = {})
    %aten_repeat_default_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.repeat.default](args = (%c_l__args___0__modules__compute_fbanks___band, [201, 1]), kwargs = {})
    %aten_permute_copy_default_3 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.permute_copy.default](args = (%aten_repeat_default_1, [1, 0]), kwargs = {})
    %aten_sub_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.sub.Tensor](args = (%c_l__args___0__modules__compute_fbanks___all_freqs_mat, %aten_permute_copy_default_2), kwargs = {})
    %aten_div_tensor : [num_users=2] = call_function[target=executorch.exir.dialects.edge._ops.aten.div.Tensor](args = (%aten_sub_tensor, %aten_permute_copy_default_3), kwargs = {})
    %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%aten_div_tensor, %_lifted_tensor_constant0), kwargs = {})
    %aten_neg_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.neg.default](args = (%aten_div_tensor,), kwargs = {})
    %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%aten_neg_default, %_lifted_tensor_constant1), kwargs = {})
    %aten_full_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.full.default](args = ([1], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu, pin_memory: False})
    %aten_minimum_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.minimum.default](args = (%aten_add_tensor, %aten_add_tensor_1), kwargs = {})
    %aten_maximum_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.maximum.default](args = (%aten_full_default, %aten_minimum_default), kwargs = {})
    %aten_permute_copy_default_4 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.permute_copy.default](args = (%aten_maximum_default, [1, 0]), kwargs = {})
    %aten_view_copy_default_2 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.view_copy.default](args = (%aten_pow_tensor_scalar_1, [1300, 201]), kwargs = {})
    %aten_mm_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.mm.default](args = (%aten_view_copy_default_2, %aten_permute_copy_default_4), kwargs = {})
    %aten_view_copy_default_3 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.view_copy.default](args = (%aten_mm_default, [1, 1300, 60]), kwargs = {})
    %aten_clamp_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.clamp.default](args = (%aten_view_copy_default_3, 1e-10), kwargs = {})
    %aten_log10_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.log10.default](args = (%aten_clamp_default,), kwargs = {})
    %aten__to_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%_lifted_tensor_constant2,), kwargs = {dtype: torch.float32})
    %aten_mul_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.mul.Tensor](args = (%aten_log10_default, %aten__to_copy_default), kwargs = {})
    %aten_sub_tensor_1 : [num_users=2] = call_function[target=executorch.exir.dialects.edge._ops.aten.sub.Tensor](args = (%aten_mul_tensor_1, %_lifted_tensor_constant3), kwargs = {})
    %aten_amax_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.amax.default](args = (%aten_sub_tensor_1, [-2, -1]), kwargs = {})
    %aten_sub_tensor_2 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.sub.Tensor](args = (%aten_amax_default, %_lifted_tensor_constant4), kwargs = {})
    %aten_view_copy_default_4 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.view_copy.default](args = (%aten_sub_tensor_2, [1, 1, 1]), kwargs = {})
    %aten_maximum_default_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.maximum.default](args = (%aten_sub_tensor_1, %aten_view_copy_default_4), kwargs = {})
    return (aten_maximum_default_1,)
