From adbd5374e91b7c610923defb3853c4bf3934059c Mon Sep 17 00:00:00 2001 From: Bea Steers Date: Wed, 25 Sep 2019 18:55:16 -0400 Subject: [PATCH 1/2] added tensorflow.keras layers --- pumpp/feature/base.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/pumpp/feature/base.py b/pumpp/feature/base.py index 58ebba3..f997bc9 100644 --- a/pumpp/feature/base.py +++ b/pumpp/feature/base.py @@ -116,8 +116,10 @@ def layers(self, api='keras'): field keys. ''' - if api == 'keras': + if api in ('k', 'keras'): return self.layers_keras() + elif api in ('tf.keras', 'tensorflow.keras', 'tfk'): + return self.layers_tfkeras() elif api in ('tf', 'tensorflow'): return self.layers_tensorflow() else: @@ -144,6 +146,17 @@ def layers_keras(self): return L + def layers_tfkeras(self): + from tensorflow.keras.layers import Input + + L = dict() + for key in self.fields: + L[key] = Input(name=key, + shape=self.fields[key].shape, + dtype=np.dtype(self.fields[key].dtype).name) + + return L + def n_frames(self, duration): '''Get the number of frames for a given duration From 8db32e7318c472eee9b90a35be8faa56bf680894 Mon Sep 17 00:00:00 2001 From: Bea Steers Date: Tue, 8 Oct 2019 14:46:17 -0400 Subject: [PATCH 2/2] added instantiation tests to get covereddd --- tests/test_core.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_core.py b/tests/test_core.py index 5840822..261ce5d 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -213,6 +213,10 @@ def test_pump_layers(sr, hop_length): for d1, d2 in zip(L1[k].shape, L2[k].shape): assert str(d1) == str(d2) + # test other input layers + P.layers('tf.keras') + P.layers('tf') + def test_pump_str(sr, hop_length):