From d25a4df314f751a4529846f2580a99ca499d502a Mon Sep 17 00:00:00 2001 From: Taehoon Lee Date: Wed, 11 Apr 2018 16:36:54 +0900 Subject: [PATCH] Fix `in_test_phase` of CNTK and Add its tests --- keras/backend/cntk_backend.py | 11 ++--------- tests/keras/backend/backend_test.py | 9 +++++++++ 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/keras/backend/cntk_backend.py b/keras/backend/cntk_backend.py index 2dbfd56a083..932c7ab5097 100644 --- a/keras/backend/cntk_backend.py +++ b/keras/backend/cntk_backend.py @@ -88,15 +88,8 @@ def in_train_phase(x, alt, training=None): return result -def in_test_phase(x, alt): - global _LEARNING_PHASE - # Similar as in_train_phase, use element_select as workaround. - if callable(x) and isinstance(x, C.cntk_py.Function) is False: - x = x() - if callable(alt) and isinstance(alt, C.cntk_py.Function) is False: - alt = alt() - - return C.element_select(learning_phase(), x, alt) +def in_test_phase(x, alt, training=None): + return in_train_phase(alt, x, training=training) def _convert_string_dtype(dtype): diff --git a/tests/keras/backend/backend_test.py b/tests/keras/backend/backend_test.py index 037628cffd5..7d181ea09f3 100644 --- a/tests/keras/backend/backend_test.py +++ b/tests/keras/backend/backend_test.py @@ -1763,6 +1763,15 @@ def test_in_train_phase(self): for training in [True, False]: check_two_tensor_operation('in_train_phase', (3, 3), (2, 2), [KTH, KTF], training=training) + check_two_tensor_operation('in_train_phase', (2, 3), (2, 3), BACKENDS, + training=training) + + def test_in_test_phase(self): + for training in [True, False]: + check_two_tensor_operation('in_test_phase', (3, 3), (2, 2), [KTH, KTF], + training=training) + check_two_tensor_operation('in_test_phase', (2, 3), (2, 3), BACKENDS, + training=training) def test_setfloatx_incorrect_values(self): # Keep track of the old value