From aaf92d216be62770dbde9038d6f2c4004594f071 Mon Sep 17 00:00:00 2001 From: Taehoon Lee Date: Wed, 11 Apr 2018 16:59:12 +0900 Subject: [PATCH] Fix dtype designation for `variable` of CNTK and Add its tests --- keras/backend/cntk_backend.py | 7 ++++++- tests/keras/backend/backend_test.py | 9 +++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/keras/backend/cntk_backend.py b/keras/backend/cntk_backend.py index 2dbfd56a0839..662aeaa49ca5 100644 --- a/keras/backend/cntk_backend.py +++ b/keras/backend/cntk_backend.py @@ -156,9 +156,14 @@ def variable(value, dtype=None, name=None, constraint=None): shape = value.shape if hasattr(value, 'shape') else () if hasattr(value, 'dtype') and value.dtype != dtype and len(shape) > 0: value = value.astype(dtype) - # cntk will init type based on the value type + + # TODO: remove the conversion when cntk supports int32, int64 + # https://docs.microsoft.com/en-us/python/api/cntk.variables.parameter + dtype = 'float32' if 'int' in str(dtype) else dtype + v = C.parameter(shape=shape, init=value, + dtype=dtype, name=_prepare_name(name, 'variable')) v._keras_shape = v.shape v._uses_learning_phase = False diff --git a/tests/keras/backend/backend_test.py b/tests/keras/backend/backend_test.py index 037628cffd54..4dd365b71bf6 100644 --- a/tests/keras/backend/backend_test.py +++ b/tests/keras/backend/backend_test.py @@ -1807,6 +1807,15 @@ def test_set_floatx(self): # Restore old value set_floatx(old_floatx) + def test_dtype(self): + assert K.dtype(K.variable(1, dtype='float64')) == 'float64' + assert K.dtype(K.variable(1, dtype='float32')) == 'float32' + if K.backend() == 'cntk': + with pytest.raises(ValueError): + K.variable(1, dtype='float16') + else: + assert K.dtype(K.variable(1, dtype='float16')) == 'float16' + def test_variable_support_bool_dtype(self): # Github issue: 7819 if K.backend() == 'tensorflow':