From dd449c17a5bd5deaccdf317d25d456f8918e66a7 Mon Sep 17 00:00:00 2001 From: nihuini Date: Thu, 21 Nov 2024 11:19:59 +0800 Subject: [PATCH] pnnx convert torchaudio functional spectrogram --- tools/pnnx/src/CMakeLists.txt | 3 + tools/pnnx/src/ir.cpp | 1 + .../pass_level2/torchaudio_F_spectrogram.cpp | 397 ++++++++++++++++++ .../pass_ncnn/torchaudio_F_spectrogram.cpp | 233 ++++++++++ .../ncnn/test_torchaudio_F_spectrogram.py | 65 +++ .../tests/test_torchaudio_F_spectrogram.py | 64 +++ 6 files changed, 763 insertions(+) create mode 100644 tools/pnnx/src/pass_level2/torchaudio_F_spectrogram.cpp create mode 100644 tools/pnnx/src/pass_ncnn/torchaudio_F_spectrogram.cpp create mode 100644 tools/pnnx/tests/ncnn/test_torchaudio_F_spectrogram.py create mode 100644 tools/pnnx/tests/test_torchaudio_F_spectrogram.py diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index 615db2d05084..9f22ed90128b 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -306,6 +306,8 @@ set(pnnx_pass_level2_SRCS pass_level2/nn_quantized_FloatFunctional.cpp + pass_level2/torchaudio_F_spectrogram.cpp + pass_level2/nn_GRU.cpp pass_level2/nn_LSTM.cpp pass_level2/nn_RNN.cpp @@ -587,6 +589,7 @@ set(pnnx_pass_ncnn_SRCS pass_ncnn/torch_t.cpp pass_ncnn/torch_transpose.cpp pass_ncnn/torch_unsqueeze.cpp + pass_ncnn/torchaudio_F_spectrogram.cpp pass_ncnn/torchvision_DeformConv2d.cpp ) diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 9e616699a9e5..394754273b72 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -1458,6 +1458,7 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) fprintf(pyfp, "import torch.nn.functional as F\n"); fprintf(pyfp, "try:\n"); fprintf(pyfp, " import torchvision\n"); + fprintf(pyfp, " import torchaudio\n"); fprintf(pyfp, "except:\n"); fprintf(pyfp, " pass\n"); diff --git a/tools/pnnx/src/pass_level2/torchaudio_F_spectrogram.cpp b/tools/pnnx/src/pass_level2/torchaudio_F_spectrogram.cpp new file mode 100644 index 000000000000..a803de5a7075 --- /dev/null +++ b/tools/pnnx/src/pass_level2/torchaudio_F_spectrogram.cpp @@ -0,0 +1,397 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class torchaudio_F_spectrogram : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +27 26 +pnnx.Input input_0 0 1 waveform +pnnx.Input input_1 0 1 n_fft +pnnx.Input input_2 0 1 hop_length +pnnx.Input input_3 0 1 win_length +pnnx.Input input_4 0 1 window +pnnx.Input input_5 0 1 onesided +prim::Constant op_0 0 1 11 value=0 +aten::size op_1 2 1 waveform 11 12 +prim::NumToTensor op_2 1 1 12 13 +aten::Int op_3 1 1 13 18 +prim::Constant op_4 0 1 15 value=-1 +prim::ListConstruct op_5 2 1 15 18 19 +aten::reshape op_6 2 1 waveform 19 waveform.1 +prim::Constant op_7 0 1 normalized value=%normalized +prim::Constant op_8 0 1 return_complex value=True +aten::stft op_9 8 1 waveform.1 n_fft hop_length win_length window normalized onesided return_complex spec_f.1 +prim::Constant op_10 0 1 29 value=1 +aten::size op_11 2 1 spec_f.1 29 30 +prim::NumToTensor op_12 1 1 30 31 +aten::Int op_13 1 1 31 34 +prim::Constant op_14 0 1 36 value=2 +aten::size op_15 2 1 spec_f.1 36 37 +prim::NumToTensor op_16 1 1 37 38 +aten::Int op_17 1 1 38 43 +prim::ListConstruct op_18 2 1 34 43 44 +aten::reshape op_19 2 1 spec_f.1 44 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torchaudio.functional.spectrogram"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["pad"] = 0; + op->params["pad_mode"] = "reflect"; + op->params["center"] = false; + op->params["power"] = Parameter(); + if (captured_params.at("normalized").b) + { + op->params["normalized"] = "frame_length"; + } + else + { + op->params["normalized"] = false; + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram, 6) + +class torchaudio_F_spectrogram_0 : public torchaudio_F_spectrogram +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +31 30 +pnnx.Input input_0 0 1 waveform +pnnx.Input input_1 0 1 n_fft +pnnx.Input input_2 0 1 hop_length +pnnx.Input input_3 0 1 win_length +pnnx.Input input_4 0 1 window +pnnx.Input input_5 0 1 onesided +prim::Constant op_0 0 1 11 value=0 +aten::size op_1 2 1 waveform 11 12 +prim::NumToTensor op_2 1 1 12 13 +aten::Int op_3 1 1 13 16 +prim::Constant op_4 0 1 18 value=1 +aten::size op_5 2 1 waveform 18 19 +prim::NumToTensor op_6 1 1 19 20 +aten::Int op_7 1 1 20 25 +prim::Constant op_8 0 1 22 value=-1 +prim::ListConstruct op_9 2 1 22 25 26 +aten::reshape op_10 2 1 waveform 26 waveform.1 +prim::Constant op_11 0 1 normalized value=%normalized +prim::Constant op_12 0 1 return_complex value=True +aten::stft op_13 8 1 waveform.1 n_fft hop_length win_length window normalized onesided return_complex spec_f.1 +prim::Constant op_14 0 1 72 value=1 +aten::size op_15 2 1 spec_f.1 72 36 +prim::NumToTensor op_16 1 1 36 37 +aten::Int op_17 1 1 37 40 +prim::Constant op_18 0 1 42 value=2 +aten::size op_19 2 1 spec_f.1 42 43 +prim::NumToTensor op_20 1 1 43 44 +aten::Int op_21 1 1 44 50 +prim::ListConstruct op_22 3 1 16 40 50 51 +aten::reshape op_23 2 1 spec_f.1 51 out +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_0, 6) + +class torchaudio_F_spectrogram_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +58 57 +pnnx.Input input_0 0 1 waveform +pnnx.Input input_1 0 1 n_fft +pnnx.Input input_2 0 1 hop_length +pnnx.Input input_3 0 1 win_length +pnnx.Input input_4 0 1 window +pnnx.Input input_5 0 1 onesided +prim::Constant op_0 0 1 18 value=1 +aten::size op_1 2 1 waveform 18 19 +prim::NumToTensor op_2 1 1 19 20 +aten::Int op_3 1 1 20 25 +prim::Constant op_4 0 1 22 value=-1 +prim::ListConstruct op_5 2 1 22 25 26 +aten::reshape op_6 2 1 waveform 26 waveform.1 +prim::Constant op_7 0 1 106 value=0 +aten::size op_8 2 1 waveform.1 106 29 +prim::NumToTensor op_9 1 1 29 30 +aten::Int op_10 1 1 30 33 +prim::Constant op_11 0 1 107 value=1 +aten::size op_12 2 1 waveform.1 107 35 +prim::NumToTensor op_13 1 1 35 36 +aten::Int op_14 1 1 36 41 +prim::Constant op_15 0 1 108 value=1 +prim::ListConstruct op_16 3 1 108 33 41 42 +aten::view op_17 2 1 waveform.1 42 input0.1 +prim::Constant op_18 0 1 45 value=%pad_left +prim::Constant op_19 0 1 109 value=%pad_right +prim::ListConstruct op_20 2 1 45 109 46 +prim::Constant op_21 0 1 47 value=%pad_mode +prim::Constant op_22 0 1 110 value=None +aten::pad op_23 4 1 input0.1 46 47 110 input1.1 +prim::Constant op_24 0 1 111 value=1 +aten::size op_25 2 1 input1.1 111 51 +prim::NumToTensor op_26 1 1 51 52 +aten::Int op_27 1 1 52 55 +prim::Constant op_28 0 1 57 value=2 +aten::size op_29 2 1 input1.1 57 58 +prim::NumToTensor op_30 1 1 58 59 +aten::Int op_31 1 1 59 64 +prim::ListConstruct op_32 2 1 55 64 65 +aten::view op_33 2 1 input1.1 65 input2.1 +prim::Constant op_34 0 1 normalized value=%normalized +prim::Constant op_35 0 1 return_complex value=True +aten::stft op_36 8 1 input2.1 n_fft hop_length win_length window normalized onesided return_complex spec_f.1 +prim::Constant op_37 0 1 11 value=0 +aten::size op_38 2 1 waveform 11 12 +prim::NumToTensor op_39 1 1 12 13 +aten::Int op_40 1 1 13 16 +prim::Constant op_41 0 1 116 value=1 +aten::size op_42 2 1 spec_f.1 116 75 +prim::NumToTensor op_43 1 1 75 76 +aten::Int op_44 1 1 76 79 +prim::Constant op_45 0 1 117 value=2 +aten::size op_46 2 1 spec_f.1 117 81 +prim::NumToTensor op_47 1 1 81 82 +aten::Int op_48 1 1 82 88 +prim::ListConstruct op_49 3 1 16 79 88 89 +aten::reshape op_50 2 1 spec_f.1 89 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torchaudio.functional.spectrogram"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["pad"] = 0; + op->params["pad_mode"] = captured_params.at("pad_mode"); + op->params["center"] = true; + op->params["power"] = Parameter(); + if (captured_params.at("normalized").b) + { + op->params["normalized"] = "frame_length"; + } + else + { + op->params["normalized"] = false; + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_1, 6) + +class torchaudio_F_spectrogram_1_1 : public torchaudio_F_spectrogram_1 +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +63 62 +pnnx.Input input_0 0 1 waveform +pnnx.Input input_1 0 1 n_fft +pnnx.Input input_2 0 1 hop_length +pnnx.Input input_3 0 1 win_length +pnnx.Input input_4 0 1 window +pnnx.Input input_5 0 1 onesided +prim::Constant op_0 0 1 11 value=0 +aten::size op_1 2 1 waveform 11 12 +prim::NumToTensor op_2 1 1 12 13 +aten::Int op_3 1 1 13 18 +prim::Constant op_4 0 1 15 value=-1 +prim::ListConstruct op_5 2 1 15 18 19 +aten::reshape op_6 2 1 waveform 19 waveform.1 +prim::Constant op_7 0 1 108 value=0 +aten::size op_8 2 1 waveform.1 108 22 +prim::NumToTensor op_9 1 1 22 23 +aten::Int op_10 1 1 23 26 +prim::Constant op_11 0 1 28 value=1 +aten::size op_12 2 1 waveform.1 28 29 +prim::NumToTensor op_13 1 1 29 30 +aten::Int op_14 1 1 30 35 +prim::Constant op_15 0 1 109 value=1 +prim::ListConstruct op_16 3 1 109 26 35 36 +aten::view op_17 2 1 waveform.1 36 input0.1 +prim::Constant op_18 0 1 39 value=%pad_left +prim::Constant op_19 0 1 110 value=%pad_right +prim::ListConstruct op_20 2 1 39 110 40 +prim::Constant op_21 0 1 41 value=%pad_mode +prim::Constant op_22 0 1 111 value=None +aten::pad op_23 4 1 input0.1 40 41 111 input1.1 +prim::Constant op_24 0 1 112 value=1 +aten::size op_25 2 1 input1.1 112 45 +prim::NumToTensor op_26 1 1 45 46 +aten::Int op_27 1 1 46 49 +prim::Constant op_28 0 1 51 value=2 +aten::size op_29 2 1 input1.1 51 52 +prim::NumToTensor op_30 1 1 52 53 +aten::Int op_31 1 1 53 58 +prim::ListConstruct op_32 2 1 49 58 59 +aten::view op_33 2 1 input1.1 59 input2.1 +prim::Constant op_34 0 1 normalized value=%normalized +prim::Constant op_35 0 1 return_complex value=True +aten::stft op_36 8 1 input2.1 n_fft hop_length win_length window normalized onesided return_complex spec_f.1 +prim::Constant op_37 0 1 117 value=1 +aten::size op_38 2 1 spec_f.1 117 69 +prim::NumToTensor op_39 1 1 69 70 +aten::Int op_40 1 1 70 73 +prim::Constant op_50 0 1 118 value=2 +aten::size op_51 2 1 spec_f.1 118 75 +prim::NumToTensor op_52 1 1 75 76 +aten::Int op_53 1 1 76 81 +prim::ListConstruct op_54 2 1 73 81 82 +aten::reshape op_55 2 1 spec_f.1 82 out +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_1_1, 6) + +class torchaudio_F_spectrogram_2 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +14 13 +pnnx.Input input_0 0 1 waveform +pnnx.Input input_1 0 1 n_fft +pnnx.Input input_2 0 1 hop_length +pnnx.Input input_3 0 1 win_length +pnnx.Input input_4 0 1 window +pnnx.Input input_5 0 1 onesided +torchaudio.functional.spectrogram op_0 6 1 waveform n_fft hop_length win_length window onesided spec power=None normalized=False center=%center pad=%pad pad_mode=%pad_mode +prim::Constant op_1 0 1 92 value=2.000000e+00 +aten::pow op_2 2 1 window 92 93 +prim::Constant op_3 0 1 127 value=None +aten::sum op_4 2 1 93 127 95 +aten::sqrt op_5 1 1 95 96 +aten::div op_6 2 1 spec 96 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torchaudio.functional.spectrogram"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["pad"] = captured_params.at("pad"); + op->params["pad_mode"] = captured_params.at("pad_mode"); + op->params["center"] = captured_params.at("center"); + op->params["power"] = Parameter(); + op->params["normalized"] = "window"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_2, 7) + +class torchaudio_F_spectrogram_3 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +9 8 +pnnx.Input input_0 0 1 waveform +pnnx.Input input_1 0 1 n_fft +pnnx.Input input_2 0 1 hop_length +pnnx.Input input_3 0 1 win_length +pnnx.Input input_4 0 1 window +pnnx.Input input_5 0 1 onesided +torchaudio.functional.spectrogram op_0 6 1 waveform n_fft hop_length win_length window onesided spec power=None normalized=%normalized center=%center pad=%pad pad_mode=%pad_mode +aten::abs op_1 1 1 spec out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torchaudio.functional.spectrogram"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["pad"] = captured_params.at("pad"); + op->params["pad_mode"] = captured_params.at("pad_mode"); + op->params["center"] = captured_params.at("center"); + op->params["normalized"] = captured_params.at("normalized"); + op->params["power"] = 1; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_3, 8) + +class torchaudio_F_spectrogram_4 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +10 9 +pnnx.Input input_0 0 1 waveform +pnnx.Input input_1 0 1 n_fft +pnnx.Input input_2 0 1 hop_length +pnnx.Input input_3 0 1 win_length +pnnx.Input input_4 0 1 window +pnnx.Input input_5 0 1 onesided +torchaudio.functional.spectrogram op_0 6 1 waveform n_fft hop_length win_length window onesided spec power=1 normalized=%normalized center=%center pad=%pad pad_mode=%pad_mode +prim::Constant op_1 0 1 391 value=2 +aten::pow op_2 2 1 spec 391 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torchaudio.functional.spectrogram"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["pad"] = captured_params.at("pad"); + op->params["pad_mode"] = captured_params.at("pad_mode"); + op->params["center"] = captured_params.at("center"); + op->params["normalized"] = captured_params.at("normalized"); + op->params["power"] = 2; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_4, 9) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/torchaudio_F_spectrogram.cpp b/tools/pnnx/src/pass_ncnn/torchaudio_F_spectrogram.cpp new file mode 100644 index 000000000000..5c42dc191704 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/torchaudio_F_spectrogram.cpp @@ -0,0 +1,233 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +static bool NearlyEqual(float a, float b, float epsilon) +{ + if (a == b) + return true; + + float diff = (float)fabs(a - b); + if (diff <= epsilon) + return true; + + // relative error + return diff < epsilon * std::max(fabs(a), fabs(b)); +} + +static int detect_window_type(const std::vector& window_data) +{ + const int winlen = (int)window_data.size(); + + bool is_one = true; + bool is_hann = true; + bool is_hamming = true; + for (int i = 0; i < winlen; i++) + { + if (!NearlyEqual(window_data[i], 1.f, 0.001)) + is_one = false; + + if (!NearlyEqual(window_data[i], 0.5f * (1 - cos(2 * M_PI * i / winlen)), 0.001)) + is_hann = false; + + if (!NearlyEqual(window_data[i], 0.54f - 0.46f * cos(2 * M_PI * i / winlen), 0.001)) + is_hamming = false; + } + + if (is_one) + return 0; + if (is_hann) + return 1; + if (is_hamming) + return 2; + + return -1; +} + +class torchaudio_F_spectrogram : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +pnnx.Attribute op_0 0 1 window @data +torchaudio.functional.spectrogram op_1 2 1 input window a n_fft=%n_fft hop_length=%hop_length win_length=%win_length onesided=%onesided power=%power normalized=%normalized center=%center pad=%pad pad_mode=%pad_mode +torch.view_as_real op_2 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Spectrogram"; + } + + const char* name_str() const + { + return "spectrogram"; + } + + bool match(const std::map& captured_params, const std::map& captured_attrs) const + { + if (captured_params.at("power").type != 0) + return false; + + const std::vector window_data = captured_attrs.at("op_0.data").get_float32_data(); + const int window_type = detect_window_type(window_data); + return window_type != -1; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + const std::vector window_data = captured_attrs.at("op_0.data").get_float32_data(); + const int window_type = detect_window_type(window_data); + + const std::string& pad_mode = captured_params.at("pad_mode").s; + int pad_type = 2; + if (pad_mode == "constant") + pad_type = 0; + if (pad_mode == "replicate") + pad_type = 1; + if (pad_mode == "reflect") + pad_type = 2; + const int onesided = captured_params.at("onesided").type == 1 && captured_params.at("onesided").b == false ? 0 : 1; + int normalized = 0; + if (captured_params.at("normalized").type == 1) + { + normalized = captured_params.at("normalized").b ? 2 : 0; + } + if (captured_params.at("normalized").type == 4) + { + if (captured_params.at("normalized").s == "frame_length") + normalized = 1; + if (captured_params.at("normalized").s == "window") + normalized = 2; + } + + op->params["0"] = captured_params.at("n_fft"); + op->params["1"] = 0; // power + op->params["2"] = captured_params.at("hop_length"); + op->params["3"] = captured_params.at("win_length"); + op->params["4"] = window_type; + op->params["5"] = captured_params.at("center").type == 1 && captured_params.at("center").b ? 1 : 0; + op->params["6"] = pad_type; + op->params["7"] = normalized; + op->params["8"] = onesided; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram, 20) + +class torchaudio_F_spectrogram_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +pnnx.Attribute op_0 0 1 window @data +torchaudio.functional.spectrogram op_1 2 1 input window out n_fft=%n_fft hop_length=%hop_length win_length=%win_length onesided=%onesided power=%power normalized=%normalized center=%center pad=%pad pad_mode=%pad_mode +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Spectrogram"; + } + + const char* name_str() const + { + return "spectrogram"; + } + + bool match(const std::map& captured_params, const std::map& captured_attrs) const + { + if (captured_params.at("power").type == 0) + return false; + + const std::vector window_data = captured_attrs.at("op_0.data").get_float32_data(); + const int window_type = detect_window_type(window_data); + return window_type != -1; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + const std::vector window_data = captured_attrs.at("op_0.data").get_float32_data(); + const int window_type = detect_window_type(window_data); + + const std::string& pad_mode = captured_params.at("pad_mode").s; + int pad_type = 2; + if (pad_mode == "constant") + pad_type = 0; + if (pad_mode == "replicate") + pad_type = 1; + if (pad_mode == "reflect") + pad_type = 2; + const int onesided = captured_params.at("onesided").type == 1 && captured_params.at("onesided").b == false ? 0 : 1; + int normalized = 0; + if (captured_params.at("normalized").type == 1) + { + normalized = captured_params.at("normalized").b ? 2 : 0; + } + if (captured_params.at("normalized").type == 4) + { + if (captured_params.at("normalized").s == "frame_length") + normalized = 1; + if (captured_params.at("normalized").s == "window") + normalized = 2; + } + + int power = 0; + if (captured_params.at("power").type == 2) + { + power = captured_params.at("power").i; + if (power != 1 && power != 2) + fprintf(stderr, "unsupported spectrogram power %d\n", power); + } + if (captured_params.at("power").type == 3) + { + if (NearlyEqual(captured_params.at("power").f, 1.0, 0.0001)) + power = 1; + else if (NearlyEqual(captured_params.at("power").f, 2.0, 0.0001)) + power = 2; + else + fprintf(stderr, "unsupported spectrogram power %f\n", captured_params.at("power").f); + } + + op->params["0"] = captured_params.at("n_fft"); + op->params["1"] = power; + op->params["2"] = captured_params.at("hop_length"); + op->params["3"] = captured_params.at("win_length"); + op->params["4"] = window_type; + op->params["5"] = captured_params.at("center").type == 1 && captured_params.at("center").b ? 1 : 0; + op->params["6"] = pad_type; + op->params["7"] = normalized; + op->params["8"] = onesided; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_1, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/tests/ncnn/test_torchaudio_F_spectrogram.py b/tools/pnnx/tests/ncnn/test_torchaudio_F_spectrogram.py new file mode 100644 index 000000000000..1acd2009914b --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_torchaudio_F_spectrogram.py @@ -0,0 +1,65 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y): + out0 = torchaudio.functional.spectrogram(x, n_fft=64, window=torch.hann_window(64), win_length=64, hop_length=16, pad=0, center=True, normalized='window', power=1) + out1 = torchaudio.functional.spectrogram(x, n_fft=128, window=torch.hann_window(128), win_length=128, hop_length=3, pad=0, center=False, onesided=True, normalized=False, power=None) + out2 = torchaudio.functional.spectrogram(y, n_fft=512, window=torch.hamming_window(256), win_length=256, hop_length=128, pad=0, center=True, pad_mode='constant', onesided=True, normalized='frame_length', power=2) + out3 = torchaudio.functional.spectrogram(y, n_fft=512, window=torch.hamming_window(512), win_length=512, hop_length=128, pad=32, center=True, onesided=False, normalized=False, power=2) + out1 = torch.view_as_real(out1) + return out0, out1, out2, out3 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(2560) + y = torch.rand(1000) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_torchaudio_F_spectrogram.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_torchaudio_F_spectrogram.pt inputshape=[2560],[1000]") + + # ncnn inference + import test_torchaudio_F_spectrogram_ncnn + b = test_torchaudio_F_spectrogram_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + print(a0) + print(b0) + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_torchaudio_F_spectrogram.py b/tools/pnnx/tests/test_torchaudio_F_spectrogram.py new file mode 100644 index 000000000000..5bf05a61374b --- /dev/null +++ b/tools/pnnx/tests/test_torchaudio_F_spectrogram.py @@ -0,0 +1,64 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y): + out0 = torchaudio.functional.spectrogram(x, n_fft=64, window=torch.hann_window(64), win_length=64, hop_length=16, pad=0, center=True, normalized='window', power=1) + out1 = torchaudio.functional.spectrogram(x, n_fft=128, window=torch.hann_window(128), win_length=128, hop_length=3, pad=0, center=False, onesided=True, normalized=False, power=None) + out2 = torchaudio.functional.spectrogram(y, n_fft=512, window=torch.hamming_window(256), win_length=256, hop_length=128, pad=0, center=True, pad_mode='constant', onesided=True, normalized='frame_length', power=2) + out3 = torchaudio.functional.spectrogram(y, n_fft=512, window=torch.hamming_window(512), win_length=512, hop_length=128, pad=32, center=True, onesided=False, normalized=False, power=2) + return out0, out1, out2, out3 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(3, 2560) + y = torch.rand(1000) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_torchaudio_F_spectrogram.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torchaudio_F_spectrogram.pt inputshape=[3,2560],[1000]") + + # pnnx inference + import test_torchaudio_F_spectrogram_pnnx + b = test_torchaudio_F_spectrogram_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + print(a0) + print(b0) + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1)