Skip to content

Commit

Permalink
Kapre 0.3.6 (#132)
Browse files Browse the repository at this point in the history
* Bugfix/pad end tflite (#131)

* update tensorflow requirements to have optional gpu version

* Update README.md

* Update README.md

* update CI

* remove extras requirement

* Add test for pad_end
Fix bug in tflite pad_end logic

* add missing arg

* comment

* better comment

* revert readme to master

* update travis

* Revert readme

* better readme

* test pad end

* Add tflite documentation

* with

* Add comment and formatting

Co-authored-by: Paul Kendrick <paul.kendrick@musictribe.com>

* bump version to 0.3.6

* reformat

* add release note

Co-authored-by: Paul Kendrick <kenders2000@gmail.com>
Co-authored-by: Paul Kendrick <paul.kendrick@musictribe.com>
  • Loading branch information
3 people committed Nov 14, 2021
1 parent ff6fe77 commit 894cc66
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 8 deletions.
30 changes: 30 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,36 @@ model.fit(x, y)

* See the Jupyter notebook at the [example folder](https://github.com/keunwoochoi/kapre/tree/master/examples)

## Tflite compatbility

The `STFT` layer is not tflite compatible (due to `tf.signal.stft`). To create a tflite
compatible model, first train using the normal `kapre` layers then create a new
model replacing `STFT` and `Magnitude` with `STFTTflite`, `MagnitudeTflite`.
Tflite compatible layers are restricted to a batch size of 1 which prevents use
of them during training.

```python
# assumes you have run the one-shot example above.
from kapre import STFTTflite, MagnitudeTflite
model_tflite = Sequential()

model_tflite.add(STFTTflite(n_fft=2048, win_length=2018, hop_length=1024,
window_name=None, pad_end=False,
input_data_format='channels_last', output_data_format='channels_last',
input_shape=input_shape))
model_tflite.add(MagnitudeTflite())
model_tflite.add(MagnitudeToDecibel())
model_tflite.add(Conv2D(32, (3, 3), strides=(2, 2)))
model_tflite.add(BatchNormalization())
model_tflite.add(ReLU())
model_tflite.add(GlobalAveragePooling2D())
model_tflite.add(Dense(10))
model_tflite.add(Softmax())

# load the trained weights into the tflite compatible model.
model_tflite.set_weights(model.get_weights())
```

# Citation

Please cite this paper if you use Kapre for your work.
Expand Down
4 changes: 4 additions & 0 deletions docs/release_note.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
Release Note
^^^^^^^^^^^^

* 13 Nov 2021
- 0.3.6
- bugfix/pad end tflite #131

* 18 March 2021
- 0.3.5
- Add `kapre.time_frequency_tflite` which uses tflite for a faster CPU inference.
Expand Down
2 changes: 1 addition & 1 deletion kapre/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '0.3.5'
__version__ = '0.3.6'
VERSION = __version__

from . import composed
Expand Down
5 changes: 3 additions & 2 deletions kapre/tflite_compatible_stft.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,10 @@ def stft_tflite(signal, frame_length, frame_step, fft_length, window_fn, pad_end
signal = tf.cast(signal, tf.float32)
if pad_end:
# the number of whole frames
# (NOTE: kenders2000), padding is pre-calculated and thus fixed in graph
length_samples = signal.shape[-1]
num_steps_round_up = tf.math.ceil(length_samples / frame_step)
pad_amount = int((num_steps_round_up * frame_step) - length_samples)
num_steps_round_up = int(np.ceil(length_samples / frame_step))
pad_amount = (num_steps_round_up * frame_step + frame_length - frame_step) - length_samples
signal = tf.pad(signal, tf.constant([[0, 0], [0, 0], [0, pad_amount]]))
# Make the window be shape (1, frame_length) instead of just frame_length
# in an effort to help the tflite broadcast logic.
Expand Down
8 changes: 7 additions & 1 deletion kapre/time_frequency_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,13 @@ class STFTTflite(STFT):
Ues `stft_tflite` from tflite_compatible_stft.py, this contains a tflite
compatible stft (using a rdft), and `fixed_frame()` to window the audio.
Tflite does not cope with comple types so real and imaginary parts are stored in extra dim.
Ouput shape is now: (batch, channel, time, re/im) or (batch, time, channel, re/im)
Ouput shape is now: (batch, channel, time, re/im) or (batch, time, channel, re/im).
`MagnitudeTflite`, and `PhaseTflite` are versions of the `Magnitude` and `Phase`
layers that account for this extra dimensionality. Currently this layer is
restricted to a batch size of one, for training use the `STFT` layer, and
once complete transfer the weights to a new model, replacing the `STFT` layer
with the `STFTTflite` layer and `Magnitude` and `Phase` layers with
`MagnitudeTflite` and `PhaseTflite` layers.
Additionally, it reshapes the output to be a proper 2D batch.
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name='kapre',
version='0.3.5',
version='0.3.6',
description='Kapre: Keras Audio Preprocessors. Tensorflow.Keras layers for audio pre-processing in deep learning',
author='Keunwoo Choi',
url='http://github.com/keunwoochoi/kapre/',
Expand Down
7 changes: 4 additions & 3 deletions tests/test_time_frequency.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,9 @@ def _get_melgram_model(return_decibel, amin, dynamic_range, input_shape=None):
@pytest.mark.parametrize('data_format', ['default', 'channels_first', 'channels_last'])
@pytest.mark.parametrize('batch_size', [1, 2])
@pytest.mark.parametrize('win_length', [1000, 512])
@pytest.mark.parametrize('pad_end', [False, True])
def test_spectrogram_tflite_correctness(
n_fft, hop_length, n_ch, data_format, batch_size, win_length
n_fft, hop_length, n_ch, data_format, batch_size, win_length, pad_end
):
def _get_stft_model(following_layer=None, tflite_compatible=False):
# compute with kapre
Expand All @@ -282,7 +283,7 @@ def _get_stft_model(following_layer=None, tflite_compatible=False):
win_length=win_length,
hop_length=hop_length,
window_name=None,
pad_end=False,
pad_end=pad_end,
input_data_format=data_format,
output_data_format=data_format,
input_shape=input_shape,
Expand All @@ -296,7 +297,7 @@ def _get_stft_model(following_layer=None, tflite_compatible=False):
win_length=win_length,
hop_length=hop_length,
window_name=None,
pad_end=False,
pad_end=pad_end,
input_data_format=data_format,
output_data_format=data_format,
input_shape=input_shape,
Expand Down

0 comments on commit 894cc66

Please sign in to comment.