-
Notifications
You must be signed in to change notification settings - Fork 118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add extract_sequences
, rfft
and stft
to ops.math
#717
Conversation
Furthermore, we can include following ops in ops.math:
These might be considered as a follow-up? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR!
keras_core/ops/math.py
Outdated
@@ -1,5 +1,7 @@ | |||
"""Commonly used math operations not included in NumPy.""" | |||
|
|||
import scipy.signal |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should not depend on scipy if at all possible.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have refactored the implementation of get_window
The biggest change involves transferring the logic of stft
from keras_core.ops.math
to keras_core.backend.*.math
- tf:
tf.signal.hann_window
andtf.signal.hamming_window
- torch:
torch.hann_window
andtorch.hamming_window
- jax & numpy:
scipy.signal.get_window
keras_core/ops/math.py
Outdated
@keras_core_export("keras_core.ops.rfft") | ||
def rfft(x, n=None): | ||
"""Computes the real-valued fast Fourier transform along the last axis of | ||
the input. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The one-line summaries should fit on a single line.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed. I have enhanced the docstring for the 3 operators (taken from tf)
keras_core/ops/math.py
Outdated
|
||
>>> x = keras_core.ops.convert_to_tensor([0.0, 1.0, 2.0, 3.0, 4.0]) | ||
>>> rfft(x) | ||
(array([10. , -2.5, -2.5]), array([0. , 3.4409548 , 0.81229924])) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Trim whitespace
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
keras_core/ops/math.py
Outdated
specified, uses the smallest power of 2 enclosing `frame_length`. | ||
window: A string or the tensor of the window. If `window` is a string, | ||
it is passed to `scipy.signal.get_window` to generate the window | ||
values, which are DFT-even by default. See |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The docstring should be standalone and should not refer to other sources to explain what argument values are valid. They should be listed here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have updated the docstring to guide users to select either "hann"
or "hamming"
- tf & torch: only provides hann and hamming window
- numpy & jax: https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.get_window.html
keras_core/ops/math.py
Outdated
|
||
|
||
@keras_core_export("keras_core.ops.rfft") | ||
def rfft(x, n=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Surely n
should be something like fft_length
? Do we have any equivalent argument anywhere else in the API today?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
I used n
due to the naming convention from numpy but I think fft_length
should be more clearer.
keras_core/ops/math.py
Outdated
|
||
|
||
@keras_core_export("keras_core.ops.frame") | ||
def frame(x, frame_length, frame_step): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not 100% sure we should export frame
to the public API as an op (the function here + the Frame
class). How important is it compared to the other too?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have discovered that it is used to build LongFormer (citation> 2k) in huggingface implementation
https://github.com/huggingface/transformers/blob/eec5841e9f440c795fb9292d009675d97a14f983/src/transformers/models/longformer/modeling_tf_longformer.py#L1107
There are about 1.6k search results for tf.signal.frame
https://github.com/search?q=tf.signal.frame&type=code
I believe that exporting it holds value. However, I agree that we should find a more suitable place for it.
Refs:
- tf:
tf.signal
- torch:
torch.Tensor.unfold
- jax & numpy: no direct api
Is a new namespace required for it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the info. OK to export. It is like a 1D equivalent of extract_patches
, right? In this case perhaps we can standardize the API on extract_frames
or extract_windows
(which is also more self-explanatory to a user who does not know what it does).
We might also want to standardize the terminology, e.g. step
-> stride
. We also need to decide whether to refer to frames as "frames", "windows", or "sequences". A very similar utility exists for tf.data (timeseries_dataset_from_array
) which uses "sequence".
What do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can adopt sequences
and stride
as they have been commonly used in keras before.
So I'm going to update the PR:
def extract_sequences(x, sequence_length, sequence_stride):
...
def stft(x, sequence_length, sequence_stride, fft_length, window="hann", center=True):
...
Is it ok?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good to me!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated.
keras_core/backend/numpy/math.py
Outdated
@@ -1,7 +1,10 @@ | |||
import numpy as np | |||
import scipy.signal |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When using the numpy backend, scipy might not be installed. Prefer using keras.utils.module_utils
to create a lazily imported package.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added
keras_core/backend/jax/math.py
Outdated
import jax | ||
import jax.numpy as jnp | ||
import scipy.signal |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
frame
, rfft
and stft
to ops.mathextract_sequences
, rfft
and stft
to ops.math
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thank you for the contribution!
Fixes #697
This PR adds
frame
,rfft
andstft
to ops.math with the corresponding docstrings and unit testsThe CPU version of tf needs
atol=1e-5, rtol=1e-5
to passrfft
andstft
tests(note: the GPU version is good with the default
1e-6
)The signatures are following:
The visualization of STFT using keras-core:
The
output.png
is almost the same as the example shown in librosa's dochttps://librosa.org/doc/main/auto_examples/plot_display.html#changing-axis-scales