Skip to content

Commit

Permalink
Fix dtype designation for variable of CNTK and Add its tests
Browse files Browse the repository at this point in the history
  • Loading branch information
taehoonlee committed Apr 11, 2018
1 parent e73199d commit aaf92d2
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
7 changes: 6 additions & 1 deletion keras/backend/cntk_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,14 @@ def variable(value, dtype=None, name=None, constraint=None):
shape = value.shape if hasattr(value, 'shape') else ()
if hasattr(value, 'dtype') and value.dtype != dtype and len(shape) > 0:
value = value.astype(dtype)
# cntk will init type based on the value type

# TODO: remove the conversion when cntk supports int32, int64
# https://docs.microsoft.com/en-us/python/api/cntk.variables.parameter
dtype = 'float32' if 'int' in str(dtype) else dtype

v = C.parameter(shape=shape,
init=value,
dtype=dtype,
name=_prepare_name(name, 'variable'))
v._keras_shape = v.shape
v._uses_learning_phase = False
Expand Down
9 changes: 9 additions & 0 deletions tests/keras/backend/backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1807,6 +1807,15 @@ def test_set_floatx(self):
# Restore old value
set_floatx(old_floatx)

def test_dtype(self):
assert K.dtype(K.variable(1, dtype='float64')) == 'float64'
assert K.dtype(K.variable(1, dtype='float32')) == 'float32'
if K.backend() == 'cntk':
with pytest.raises(ValueError):
K.variable(1, dtype='float16')
else:
assert K.dtype(K.variable(1, dtype='float16')) == 'float16'

def test_variable_support_bool_dtype(self):
# Github issue: 7819
if K.backend() == 'tensorflow':
Expand Down

0 comments on commit aaf92d2

Please sign in to comment.