Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add extract_sequences, rfft and stft to ops.math #717

Merged
merged 15 commits into from
Aug 16, 2023

Conversation

james77777778
Copy link
Contributor

@james77777778 james77777778 commented Aug 14, 2023

Fixes #697

This PR adds frame, rfft and stft to ops.math with the corresponding docstrings and unit tests
The CPU version of tf needs atol=1e-5, rtol=1e-5 to pass rfft and stft tests
(note: the GPU version is good with the default 1e-6)

The signatures are following:

def frame(x, frame_length, frame_step):
    ...

def rfft(x, fft_length=None):
    ...

# a mixture of tf and librosa/torch
def stft(x, frame_length, frame_step, fft_length, window="hann", center=True):
    ...

The visualization of STFT using keras-core:

import librosa as lr
import matplotlib.pyplot as plt
import numpy as np

from keras_core.ops.math import stft

# load waveform from librosa
waveform, samplerate = lr.load(lr.util.example("trumpet"))
# stft using keras_core
real_out, imag_out = stft(
    waveform, frame_length=2048, frame_step=512, fft_length=2048
)
# visualize with librosa
real_out = np.transpose(real_out)
s_db = lr.amplitude_to_db(np.abs(real_out), ref=np.max)
fig, ax = plt.subplots()
img = lr.display.specshow(s_db, x_axis="time", y_axis="log", ax=ax)
ax.set(title='Using a logarithmic frequency axis')
fig.colorbar(img, ax=ax, format="%+2.f dB")
plt.savefig("output.png")

output

The output.png is almost the same as the example shown in librosa's doc
https://librosa.org/doc/main/auto_examples/plot_display.html#changing-axis-scales

@james77777778
Copy link
Contributor Author

Furthermore, we can include following ops in ops.math:

  • ifft
  • ifft2
  • irfft
  • rfft2 / irfft2
  • istft

These might be considered as a follow-up?

Copy link
Contributor

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR!

@@ -1,5 +1,7 @@
"""Commonly used math operations not included in NumPy."""

import scipy.signal
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should not depend on scipy if at all possible.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have refactored the implementation of get_window

The biggest change involves transferring the logic of stft from keras_core.ops.math to keras_core.backend.*.math

  • tf: tf.signal.hann_window and tf.signal.hamming_window
  • torch: torch.hann_window and torch.hamming_window
  • jax & numpy: scipy.signal.get_window

@keras_core_export("keras_core.ops.rfft")
def rfft(x, n=None):
"""Computes the real-valued fast Fourier transform along the last axis of
the input.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The one-line summaries should fit on a single line.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. I have enhanced the docstring for the 3 operators (taken from tf)


>>> x = keras_core.ops.convert_to_tensor([0.0, 1.0, 2.0, 3.0, 4.0])
>>> rfft(x)
(array([10. , -2.5, -2.5]), array([0. , 3.4409548 , 0.81229924]))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trim whitespace

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

specified, uses the smallest power of 2 enclosing `frame_length`.
window: A string or the tensor of the window. If `window` is a string,
it is passed to `scipy.signal.get_window` to generate the window
values, which are DFT-even by default. See
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring should be standalone and should not refer to other sources to explain what argument values are valid. They should be listed here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have updated the docstring to guide users to select either "hann" or "hamming"



@keras_core_export("keras_core.ops.rfft")
def rfft(x, n=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Surely n should be something like fft_length? Do we have any equivalent argument anywhere else in the API today?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.
I used n due to the naming convention from numpy but I think fft_length should be more clearer.

@james77777778 james77777778 requested a review from fchollet August 15, 2023 03:07


@keras_core_export("keras_core.ops.frame")
def frame(x, frame_length, frame_step):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not 100% sure we should export frame to the public API as an op (the function here + the Frame class). How important is it compared to the other too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have discovered that it is used to build LongFormer (citation> 2k) in huggingface implementation
https://github.com/huggingface/transformers/blob/eec5841e9f440c795fb9292d009675d97a14f983/src/transformers/models/longformer/modeling_tf_longformer.py#L1107

There are about 1.6k search results for tf.signal.frame
https://github.com/search?q=tf.signal.frame&type=code

I believe that exporting it holds value. However, I agree that we should find a more suitable place for it.

Refs:

  • tf: tf.signal
  • torch: torch.Tensor.unfold
  • jax & numpy: no direct api

Is a new namespace required for it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the info. OK to export. It is like a 1D equivalent of extract_patches, right? In this case perhaps we can standardize the API on extract_frames or extract_windows (which is also more self-explanatory to a user who does not know what it does).

We might also want to standardize the terminology, e.g. step -> stride. We also need to decide whether to refer to frames as "frames", "windows", or "sequences". A very similar utility exists for tf.data (timeseries_dataset_from_array) which uses "sequence".

What do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can adopt sequences and stride as they have been commonly used in keras before.
So I'm going to update the PR:

def extract_sequences(x, sequence_length, sequence_stride):
    ...

def stft(x, sequence_length, sequence_stride, fft_length, window="hann", center=True):
    ...

Is it ok?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good to me!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated.

@@ -1,7 +1,10 @@
import numpy as np
import scipy.signal
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When using the numpy backend, scipy might not be installed. Prefer using keras.utils.module_utils to create a lazily imported package.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

import jax
import jax.numpy as jnp
import scipy.signal
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@james77777778 james77777778 requested a review from fchollet August 16, 2023 01:29
@james77777778 james77777778 changed the title Add frame, rfft and stft to ops.math Add extract_sequences, rfft and stft to ops.math Aug 16, 2023
Copy link
Contributor

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thank you for the contribution!

@fchollet fchollet merged commit ef50328 into keras-team:main Aug 16, 2023
@james77777778 james77777778 deleted the add-stft branch August 16, 2023 04:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add stft to keras_core.ops
2 participants