Skip to content

Commit

Permalink
spectrogram and inverse spectrogram (#5779)
Browse files Browse the repository at this point in the history
* only supports hann, hamming and all-one window
* inverse spectrogram does not support length parameter
* spectrogram always returns torch.view_as_real(out) as ncnn does not support complex typed mat yet
* inverse spectrogram always accepts torch.view_as_complex(in) as ncnn does not support complex typed mat yet
  • Loading branch information
nihui authored Nov 22, 2024
1 parent c043612 commit 0734b65
Show file tree
Hide file tree
Showing 33 changed files with 3,155 additions and 22 deletions.
14 changes: 13 additions & 1 deletion .ci/pnnx.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,39 +31,51 @@ jobs:
include:
- torch-version: 1.8.1
torchvision-version: 0.9.1
torchaudio-version: 0.8.1

- torch-version: 1.9.1
torchvision-version: 0.10.1
torchaudio-version: 0.9.1

- torch-version: 1.10.0
torchvision-version: 0.11.1
torchaudio-version: '0.10.0+cpu'

- torch-version: 1.11.0
torchvision-version: 0.12.0
torchaudio-version: '0.11.0+cpu'

- torch-version: 1.12.0
torchvision-version: 0.13.0
torchaudio-version: '0.12.0+cpu'

- torch-version: 1.13.0
torchvision-version: 0.14.0
torchaudio-version: '0.13.0+cpu'

- torch-version: 2.0.0
torchvision-version: 0.15.1
torchaudio-version: '2.0.0+cpu'

- torch-version: 2.1.0
torchvision-version: 0.16.0
torchaudio-version: '2.1.0+cpu'

- torch-version: 2.2.1
torchvision-version: 0.17.1
torchaudio-version: '2.2.1+cpu'

- torch-version: 2.3.0
torchvision-version: 0.18.0
torchaudio-version: '2.3.0+cpu'

- torch-version: 2.4.0
torchvision-version: 0.19.0
torchaudio-version: '2.4.0+cpu'

- torch-version: 2.5.0
torchvision-version: 0.20.0
torchaudio-version: '2.5.0+cpu'

runs-on:
pool-name: docker
Expand Down Expand Up @@ -169,7 +181,7 @@ jobs:
- name: setup-pytorch
run: |
export PYTHONUSERBASE=${{ci.workspace}}/torch-${{matrix.torch-version}}
pip3 install --user torch==${{matrix.torch-version}}+cpu torchvision==${{matrix.torchvision-version}}+cpu --index-url https://download.pytorch.org/whl/cpu
pip3 install --user torch==${{matrix.torch-version}}+cpu torchvision==${{matrix.torchvision-version}}+cpu torchaudio==${{matrix.torchaudio-version}} --index-url https://download.pytorch.org/whl/cpu
pip3 install --user onnx
pip3 install --user onnxscript
Expand Down
51 changes: 51 additions & 0 deletions docs/developer-guide/operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
* [Input](#input)
* [InstanceNorm](#instancenorm)
* [Interp](#interp)
* [InverseSpectrogram](#inversespectrogram)
* [LayerNorm](#layernorm)
* [Log](#log)
* [LRN](#lrn)
Expand Down Expand Up @@ -81,6 +82,7 @@
* [Slice](#slice)
* [Softmax](#softmax)
* [Softplus](#softplus)
* [Spectrogram](#spectrogram)
* [Split](#split)
* [Swish](#swish)
* [TanH](#tanh)
Expand Down Expand Up @@ -1141,6 +1143,30 @@ Resize type:
- 2 = Bilinear
- 3 = Bicubic

# InverseSpectrogram
```
x1 = x as complex
x1 = x1 * sqrt(norm) if normalized
y = istft(x1)
y1 = unpad(y) if center
if returns == 0 return y1 as complex
if returns == 1 return y1 real
if returns == 2 return y1 imag
```

* one_blob_only

| param id | name | type | default | description |
| --------- | ------------- | ----- | --------- | ----------------- |
| 0 | n_fft | int | 0 | |
| 1 | returns | int | 1 | |
| 2 | hoplen | int | n_fft / 4 | |
| 3 | winlen | int | n_fft | |
| 4 | window_type | int | 0 | 0=ones 1=hann 2=hamming |
| 5 | center | int | 1 | |
| 7 | normalized | int | 0 | 0=no 1=n_fft 2=window-l2-energy |

# LayerNorm
```
split x along outmost axis into part x0, x1 ...
Expand Down Expand Up @@ -1829,6 +1855,31 @@ y = log(exp(x) + 1)
* one_blob_only
* support_inplace

# Spectrogram
```
x1 = pad(x) if center
y = stft(x1)
y = y / sqrt(norm) if normalized
if power == 0 return y as real
if power == 1 return magnitude
if power == 2 return square of magnitude
```

* one_blob_only

| param id | name | type | default | description |
| --------- | ------------- | ----- | --------- | ----------------- |
| 0 | n_fft | int | 0 | |
| 1 | power | int | 0 | |
| 2 | hoplen | int | n_fft / 4 | |
| 3 | winlen | int | n_fft | |
| 4 | window_type | int | 0 | 0=ones 1=hann 2=hamming |
| 5 | center | int | 1 | |
| 6 | pad_type | int | 2 | 0=CONSTANT 1=REPLICATE 2=REFLECT |
| 7 | normalized | int | 0 | 0=no 1=n_fft 2=window-l2-energy |
| 8 | onesided | int | 1 | |

# Split
```
y0, y1 ... = x
Expand Down
2 changes: 2 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ ncnn_add_layer(Diag)
ncnn_add_layer(CELU)
ncnn_add_layer(Shrink)
ncnn_add_layer(RMSNorm)
ncnn_add_layer(Spectrogram)
ncnn_add_layer(InverseSpectrogram)

if(NCNN_VULKAN)
ncnn_add_shader(${CMAKE_CURRENT_SOURCE_DIR}/convert_ycbcr.comp)
Expand Down
238 changes: 238 additions & 0 deletions src/layer/inversespectrogram.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "inversespectrogram.h"

namespace ncnn {

InverseSpectrogram::InverseSpectrogram()
{
one_blob_only = true;
support_inplace = false;
}

int InverseSpectrogram::load_param(const ParamDict& pd)
{
n_fft = pd.get(0, 0);
returns = pd.get(1, 0);
hoplen = pd.get(2, n_fft / 4);
winlen = pd.get(3, n_fft);
window_type = pd.get(4, 0);
center = pd.get(5, 1);
normalized = pd.get(7, 0);

// assert winlen <= n_fft
// generate window
window_data.create(normalized == 2 ? n_fft + 1 : n_fft);
{
float* p = window_data;
for (int i = 0; i < (n_fft - winlen) / 2; i++)
{
*p++ = 0.f;
}
if (window_type == 0)
{
// all ones
for (int i = 0; i < winlen; i++)
{
*p++ = 1.f;
}
}
if (window_type == 1)
{
// hann window
for (int i = 0; i < winlen; i++)
{
*p++ = 0.5f * (1 - cosf(2 * 3.14159265358979323846 * i / winlen));
}
}
if (window_type == 2)
{
// hamming window
for (int i = 0; i < winlen; i++)
{
*p++ = 0.54f - 0.46f * cosf(2 * 3.14159265358979323846 * i / winlen);
}
}
for (int i = 0; i < n_fft - winlen - (n_fft - winlen) / 2; i++)
{
*p++ = 0.f;
}

// pre-calculated window norm factor
if (normalized == 2)
{
float sqsum = 0.f;
for (int i = 0; i < n_fft; i++)
{
sqsum += window_data[i] * window_data[i];
}
window_data[n_fft] = sqrt(sqsum);
}
}

return 0;
}

int InverseSpectrogram::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
// https://github.com/librosa/librosa/blob/main/librosa/core/spectrum.py#L630

// TODO custom window
// TODO output length

const int frames = bottom_blob.h;
const int freqs = bottom_blob.c;
// assert freqs == n_fft or freqs == n_fft / 2 + 1

const int onesided = freqs == n_fft / 2 + 1 ? 1 : 0;

const int outsize = center ? (frames - 1) * hoplen + (n_fft - n_fft / 2 * 2) : (frames - 1) * hoplen + n_fft;

const size_t elemsize = bottom_blob.elemsize;

if (returns == 0)
{
top_blob.create(2, outsize, elemsize, opt.blob_allocator);
}
else
{
top_blob.create(outsize, elemsize, opt.blob_allocator);
}
if (top_blob.empty())
return -100;

Mat window_sumsquare(outsize + n_fft, elemsize, opt.workspace_allocator);
if (window_sumsquare.empty())
return -100;

top_blob.fill(0.f);
window_sumsquare.fill(0.f);

for (int j = 0; j < frames; j++)
{
// collect complex
Mat sp(2, n_fft);
if (onesided == 1)
{
for (int k = 0; k < n_fft / 2 + 1; k++)
{
sp.row(k)[0] = bottom_blob.channel(k).row(j)[0];
sp.row(k)[1] = bottom_blob.channel(k).row(j)[1];
}
for (int k = n_fft / 2 + 1; k < n_fft; k++)
{
sp.row(k)[0] = bottom_blob.channel(n_fft - k).row(j)[0];
sp.row(k)[1] = -bottom_blob.channel(n_fft - k).row(j)[1];
}
}
else
{
for (int k = 0; k < n_fft; k++)
{
sp.row(k)[0] = bottom_blob.channel(k).row(j)[0];
sp.row(k)[1] = bottom_blob.channel(k).row(j)[1];
}
}

if (normalized == 1)
{
float norm = sqrt(n_fft);
for (int i = 0; i < 2 * n_fft; i++)
{
sp[i] *= norm;
}
}
if (normalized == 2)
{
float norm = window_data[n_fft];
for (int i = 0; i < 2 * n_fft; i++)
{
sp[i] *= norm;
}
}

#pragma omp parallel for num_threads(opt.num_threads)
for (int i = 0; i < n_fft; i++)
{
// inverse dft
float re = 0.f;
float im = 0.f;
for (int k = 0; k < n_fft; k++)
{
double angle = 2 * 3.14159265358979323846 * i * k / n_fft;

re += sp.row(k)[0] * cosf(angle) - sp.row(k)[1] * sinf(angle);
im += sp.row(k)[0] * sinf(angle) + sp.row(k)[1] * cosf(angle);
}

re /= n_fft;
im /= n_fft;

// apply window
re *= window_data[i];
im *= window_data[i];

int output_index = j * hoplen + i;
if (center == 1)
{
output_index -= n_fft / 2;
}
if (output_index >= 0 && output_index < outsize)
{
// square window
window_sumsquare[output_index] += window_data[i] * window_data[i];

if (returns == 0)
{
top_blob.row(output_index)[0] += re;
top_blob.row(output_index)[1] += im;
}
if (returns == 1)
{
top_blob[output_index] += re;
}
if (returns == 2)
{
top_blob[output_index] += im;
}
}
}
}

// square window norm
if (returns == 0)
{
for (int i = 0; i < outsize; i++)
{
if (window_sumsquare[i] != 0.f)
{
top_blob.row(i)[0] /= window_sumsquare[i];
top_blob.row(i)[1] /= window_sumsquare[i];
}
}
}
else
{
for (int i = 0; i < outsize; i++)
{
if (window_sumsquare[i] != 0.f)
top_blob[i] /= window_sumsquare[i];
}
}

return 0;
}

} // namespace ncnn
Loading

0 comments on commit 0734b65

Please sign in to comment.