Add stft to keras_core.ops #697

Closed
james77777778 opened this issue Aug 11, 2023 · 2 comments · Fixed by #717

Comments

@james77777778
Contributor

james77777778 commented Aug 11, 2023

Hi keras-core team,

I'm trying to use keras-core to build an audio model. However, I've found that some operations needed for audio processing are missing, for example:

  • tf.signal.frame
  • tf.signal.rfft
  • tf.signal.stft
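For reference, tf.signal.frame (with its default pad_end=False) slices the last axis into overlapping windows of frame_length samples spaced frame_step apart, which gives 1 + (length - frame_length) // frame_step frames and drops any trailing samples that do not fill a frame. The following is a minimal NumPy sketch of those semantics using gather-style indexing. It is not TensorFlow's implementation, and the function name `frame` is only illustrative.

```python
import numpy as np


def frame(signal, frame_length, frame_step):
    """Split the last axis of `signal` into overlapping frames.

    Sketch of tf.signal.frame semantics with pad_end=False: trailing
    samples that do not fill a whole frame are dropped.
    """
    signal = np.asarray(signal)
    num_frames = 1 + (signal.shape[-1] - frame_length) // frame_step
    # Row i holds the indices [i * frame_step, i * frame_step + frame_length).
    starts = np.arange(num_frames)[:, None] * frame_step
    offsets = np.arange(frame_length)[None, :]
    # Indexing only the last axis carries any leading batch dimensions through.
    return signal[..., starts + offsets]


frames = frame(np.arange(10), frame_length=4, frame_step=3)
print(frames)
# [[0 1 2 3]
#  [3 4 5 6]
#  [6 7 8 9]]
```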

I have already implemented stft for all backends. Is it good to add to keras-core?
If that's the case, would it be better to separate rfft and frame from stft?

Here is a working example. The NumPy, TensorFlow, PyTorch and JAX outputs agree under np.testing.assert_allclose (rtol=1e-5, atol=1e-5).
(I have implemented stft with the signature from librosa because I think it is a more mature library for audio processing.)

import math

import jax
import librosa as lr
import numpy as np
import scipy.signal
import torch
import tensorflow as tf

from keras_core.backend.jax.core import (
    convert_to_tensor as jax_convert_to_tensor,
)
from keras_core.backend.numpy.core import (
    convert_to_tensor as np_convert_to_tensor,
)
from keras_core.backend.tensorflow.core import (
    convert_to_tensor as tf_convert_to_tensor,
)
from keras_core.backend.torch.core import (
    convert_to_tensor as torch_convert_to_tensor,
)


# librosa
def librosa_stft(signal, n_fft, hop_length, win_length, center, window="hann"):
    return lr.stft(
        y=signal,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=center,
    )


# tensorflow
def tf_stft(signal, n_fft, hop_length, win_length, center, window="hann"):
    signal = tf_convert_to_tensor(signal)
    rank = len(signal.shape)  # static Python int, safe in the `if` checks below
    if center:
        if rank == 1:
            padding = [[n_fft // 2, n_fft // 2]]
        else:
            padding = [[0, 0], [n_fft // 2, n_fft // 2]]
        signal = tf.pad(signal, padding, mode="REFLECT")
    if isinstance(window, str):
        if window == "hamming":
            win = tf.signal.hamming_window(win_length, periodic=True)
        elif window == "hann":
            win = tf.signal.hann_window(win_length, periodic=True)
        else:
            raise NotImplementedError()
    else:
        win = tf_convert_to_tensor(window, dtype=signal.dtype)
    left_pad = (n_fft - win_length) // 2
    right_pad = n_fft - win_length - left_pad
    win = tf.pad(win, [[left_pad, right_pad]])
    framed_signals = tf.signal.frame(
        signal, frame_length=n_fft, frame_step=hop_length
    )
    framed_signals *= win
    framed_signals = tf.signal.rfft(framed_signals, [n_fft])
    if rank == 1:
        framed_signals = tf.transpose(framed_signals, perm=[1, 0])
    else:
        framed_signals = tf.transpose(framed_signals, perm=[0, 2, 1])
    return framed_signals


# torch
def torch_stft(signal, n_fft, hop_length, win_length, center, window="hann"):
    signal = torch_convert_to_tensor(signal)
    if isinstance(window, str):
        if window == "hamming":
            win = torch.hamming_window(win_length, periodic=True)
        elif window == "hann":
            win = torch.hann_window(win_length, periodic=True)
        else:
            raise NotImplementedError()
    else:
        win = torch_convert_to_tensor(window)
    return torch.stft(
        signal,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=win,
        center=center,
        return_complex=True,
    )


# jax
def jax_stft(signal, n_fft, hop_length, win_length, center, window="hann"):
    signal = jax_convert_to_tensor(signal)
    rank = len(signal.shape)
    if center:
        padding = [(0, 0) for _ in range(rank)]
        padding[-1] = (n_fft // 2, n_fft // 2)
        signal = jax.numpy.pad(signal, padding, mode="reflect")
    if isinstance(window, str):
        win = jax.numpy.array(
            scipy.signal.get_window(window, win_length), dtype=signal.dtype
        )
    else:
        win = jax_convert_to_tensor(window)
    left_pad = (n_fft - win_length) // 2
    right_pad = n_fft - win_length - left_pad
    win = jax.numpy.pad(win, [[left_pad, right_pad]])

    # equivalent to tf.signal.frame
    # https://gist.github.com/sourabh2k15/80adbf1c5861e727f7698fd66e51be39#file-collect_trace-py-L226
    *batch_shape, signal_length = signal.shape
    batch_shape = list(batch_shape)
    signal = jax.numpy.reshape(
        signal, (math.prod(batch_shape), signal_length, 1)
    )
    framed_signals = jax.lax.conv_general_dilated_patches(
        signal,
        (n_fft,),
        (hop_length,),
        "VALID",
        dimension_numbers=("NTC", "OIT", "NTC"),
    )
    framed_signals = jax.numpy.reshape(
        framed_signals, (*batch_shape, *framed_signals.shape[-2:])
    )
    framed_signals *= win
    framed_signals = jax.numpy.fft.rfft(framed_signals, n_fft)
    if rank == 1:
        framed_signals = jax.numpy.transpose(framed_signals, axes=[1, 0])
    else:
        framed_signals = jax.numpy.transpose(framed_signals, axes=[0, 2, 1])
    return framed_signals


# numpy
def numpy_stft(signal, n_fft, hop_length, win_length, center, window="hann"):
    signal = np_convert_to_tensor(signal)
    rank = len(signal.shape)
    if center:
        padding = [(0, 0) for _ in range(rank)]
        padding[-1] = (n_fft // 2, n_fft // 2)
        signal = np.pad(signal, padding, mode="reflect")
    if isinstance(window, str):
        win = np.array(
            scipy.signal.get_window(window, win_length), dtype=signal.dtype
        )
    else:
        win = np_convert_to_tensor(window)
    left_pad = (n_fft - win_length) // 2
    right_pad = n_fft - win_length - left_pad
    win = np.pad(win, [[left_pad, right_pad]])

    # equivalent to tf.signal.frame
    # https://github.com/scipy/scipy/blob/v1.11.1/scipy/signal/_spectral_py.py#L1928
    *batch_shape, _ = signal.shape
    batch_shape = list(batch_shape)
    shape = signal.shape[:-1] + (
        (signal.shape[-1] - (n_fft - hop_length)) // hop_length,
        n_fft,
    )
    strides = signal.strides[:-1] + (
        hop_length * signal.strides[-1],
        signal.strides[-1],
    )
    framed_signals = np.lib.stride_tricks.as_strided(
        signal, shape=shape, strides=strides
    )
    framed_signals = np.reshape(
        framed_signals, (*batch_shape, *framed_signals.shape[-2:])
    )
    framed_signals = np.multiply(framed_signals, win)
    framed_signals = np.fft.rfft(framed_signals, n_fft)
    if rank == 1:
        framed_signals = np.transpose(framed_signals, axes=[1, 0])
    else:
        framed_signals = np.transpose(framed_signals, axes=[0, 2, 1])
    return framed_signals


waveform, samplerate = lr.load(
    lr.util.example("brahms", hq=True), duration=10.0
)
for n_fft, hop_length, win_length in [
    (2048, 256, 1024),
    (1024, 128, 512),
    (123, 50, 80),
]:
    for center in [True, False]:
        for window_name in ["hann", "hamming"]:
            window = scipy.signal.get_window(window_name, win_length)
            window = window.astype(waveform.dtype)
            args = (waveform, n_fft, hop_length, win_length, center, window)

            # run stft with different backends
            lr_out = librosa_stft(*args)
            tf_out = tf_stft(*args).numpy()
            torch_out = torch_stft(*args).numpy()
            jax_out = np.asarray(jax_stft(*args))
            np_out = numpy_stft(*args)

            # np.testing.assert_allclose(np_out, lr_out, rtol=1e-5, atol=1e-5)
            np.testing.assert_allclose(np_out, tf_out, rtol=1e-5, atol=1e-5)
            np.testing.assert_allclose(np_out, torch_out, rtol=1e-5, atol=1e-5)
            np.testing.assert_allclose(np_out, jax_out, rtol=1e-5, atol=1e-5)
            print(
                f"Passed! n_fft={n_fft}, hop_length={hop_length}, "
                f"win_length={win_length}, center={center}, "
                f"window={window_name}"
            )
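
Because all of the helpers above transpose to librosa's (frequency, frame) layout, the expected output shape is a quick sanity check. The sketch below uses a hypothetical helper, `expected_stft_shape`, and derives the shape from the same padding and framing arithmetic used in `numpy_stft`: `center=True` adds `n_fft // 2` samples on each side, and the onesided FFT keeps `n_fft // 2 + 1` bins.

```python
def expected_stft_shape(signal_length, n_fft, hop_length, center):
    """Return (num_bins, num_frames) for a 1-D input in librosa layout."""
    if center:
        # Reflect padding of n_fft // 2 samples on both sides.
        signal_length += 2 * (n_fft // 2)
    num_frames = 1 + (signal_length - n_fft) // hop_length
    num_bins = n_fft // 2 + 1  # onesided rfft output
    return num_bins, num_frames


print(expected_stft_shape(10000, 1024, 128, center=True))   # (513, 79)
print(expected_stft_shape(10000, 1024, 128, center=False))  # (513, 71)
print(expected_stft_shape(10000, 123, 50, center=True))     # (62, 200)
```

For odd n_fft (such as the 123 case in the test loop), centering adds only n_fft - 1 samples in total. In that case the frame count is not simply 1 + signal_length // hop_length (which would give 201 above).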

@fchollet
Contributor

> I have already implemented stft for all backends. Is it good to add to keras-core?

Sure, that sounds good. We already have fft and fft2, so it seems fine.
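
Since fft already exists, one possible route is to layer rfft on top of it. For a real input of length n, the FFT is conjugate-symmetric, so the onesided result is just the first n // 2 + 1 bins of the full transform. The NumPy sketch below illustrates that identity. It does not use keras_core's fft op, and `rfft_via_fft` is a hypothetical name.

```python
import numpy as np


def rfft_via_fft(x, n):
    """Onesided FFT of a real signal, computed from a full complex FFT."""
    full = np.fft.fft(x, n=n)  # zero-pads or truncates the last axis to n
    return full[..., : n // 2 + 1]


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 123))
for n in (123, 256):
    np.testing.assert_allclose(rfft_via_fft(x, n), np.fft.rfft(x, n=n), atol=1e-10)
print("rfft_via_fft matches np.fft.rfft")
```

A dedicated rfft kernel would still be cheaper, because it skips computing the redundant negative-frequency half.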

> If that's the case, would it be better to separate rfft and frame from stft?

What are our options?

> I have implemented stft with the signature from librosa

It's fine to follow the same argument semantics and order. However, we will still need to standardize argument names so that they are consistent with Keras argument naming conventions.

@james77777778
Contributor Author

> What are our options?

Actually, stft is a composite of frame and rfft (with a window applied in between), so there are two options:

  1. Implement stft with the logic of frame and rfft inside
  2. First implement frame and rfft, and then leverage them for stft
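
To make option 2 concrete, here is a rough NumPy sketch of stft assembled from separate frame and rfft steps. It follows the librosa-style arguments from the example above, uses `sliding_window_view` (NumPy 1.20+) for framing instead of `as_strided`, and assumes a real input whose last axis is time. The function names are illustrative only and are not a proposed keras_core API.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def frame(signal, frame_length, frame_step):
    # All length-`frame_length` windows, then keep every `frame_step`-th one.
    windows = sliding_window_view(signal, frame_length, axis=-1)
    return windows[..., ::frame_step, :]


def stft(signal, n_fft, hop_length, win_length, window, center=True):
    signal = np.asarray(signal)
    if center:
        pad = [(0, 0)] * (signal.ndim - 1) + [(n_fft // 2, n_fft // 2)]
        signal = np.pad(signal, pad, mode="reflect")
    # Zero-pad the window symmetrically from win_length to n_fft.
    left = (n_fft - win_length) // 2
    window = np.pad(window, (left, n_fft - win_length - left))
    frames = frame(signal, n_fft, hop_length) * window
    spec = np.fft.rfft(frames, n=n_fft)  # (..., frames, n_fft // 2 + 1)
    return np.swapaxes(spec, -1, -2)  # librosa layout: (..., bins, frames)


x = np.random.default_rng(0).standard_normal(1000)
win = np.hanning(81)[:-1]  # periodic Hann window of length 80
out = stft(x, n_fft=123, hop_length=50, win_length=80, window=win, center=False)
print(out.shape)  # (62, 18)
```

With frame and rfft available as standalone ops, each backend only needs to provide those two primitives, and stft itself becomes backend-agnostic glue.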

> It's fine to follow the same argument semantics and order. However, we will still need to standardize argument names so that they are consistent with Keras argument naming conventions.

Which one do you prefer?

  1. def stft(signals, n_fft, hop_length, win_length, center=True, window="hann") (same as librosa, torch)
  2. def stft(signals, frame_length, frame_step, fft_length, window="hann", center=True) (tf-like)
  3. def stft(signals, window="hann", nperseg=256, noverlap=None, nfft=None, boundary="zeros", padded=True) (scipy-like)
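
For comparison, the librosa-style and scipy-style framing arguments map onto each other fairly directly. The sketch below relies on scipy.signal.stft's documented defaults: noverlap=None means nperseg // 2, and nfft=None means nperseg. Only the framing parameters line up. scipy pads (boundary, padded) and scales its output differently, and it zero-pads each nperseg-long segment at the end when nfft > nperseg rather than centering the window, so the resulting spectrograms would not match without further adjustment. The helper name is hypothetical.

```python
def scipy_to_librosa_args(nperseg=256, noverlap=None, nfft=None):
    """Map scipy.signal.stft framing arguments to librosa-style names."""
    if noverlap is None:
        noverlap = nperseg // 2  # scipy's default overlap
    if nfft is None:
        nfft = nperseg  # scipy's default FFT length
    return {
        "n_fft": nfft,
        "hop_length": nperseg - noverlap,
        "win_length": nperseg,
    }


print(scipy_to_librosa_args())
# {'n_fft': 256, 'hop_length': 128, 'win_length': 256}
print(scipy_to_librosa_args(nperseg=80, noverlap=30, nfft=123))
# {'n_fft': 123, 'hop_length': 50, 'win_length': 80}
```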

Matching the full functionality of the scipy version on all backends would be challenging for certain options, such as detrend, return_onesided and axis.
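
Some of these options are cheaper to emulate than others. For instance, scipy's detrend="constant" subtracts each segment's mean before the window is applied. That is a single reduction over the framed axis, whereas detrend="linear" would need a per-frame least-squares fit. Below is a minimal NumPy sketch of the constant case, assuming frames shaped (..., num_frames, frame_length). The helper name is hypothetical.

```python
import numpy as np


def detrend_constant(frames):
    # Subtract each frame's mean along the sample axis (the last axis).
    return frames - frames.mean(axis=-1, keepdims=True)


frames = np.array([[1.0, 2.0, 3.0], [10.0, 10.0, 13.0]])
print(detrend_constant(frames))  # each row now has zero mean
```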
