Commit 93c5981: Misc fixes

fchollet committed Sep 19, 2023
1 parent 1f9dd43 commit 93c5981
Showing 4 changed files with 20 additions and 24 deletions.
keras_core/backend/jax/numpy.py (2 changes: 0 additions & 2 deletions)

@@ -576,8 +576,6 @@ def vstack(xs):
 
 
 def where(condition, x1, x2):
-    x1 = convert_to_tensor(x1)
-    x2 = convert_to_tensor(x2)
     return jnp.where(condition, x1, x2)
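
Note: a plausible reading of this change (the commit message does not say) is that dropping the convert_to_tensor calls lets jnp.where receive Python scalars directly, so JAX's weak-type promotion decides the result dtype instead of the scalars being pinned to a concrete dtype first. A minimal sketch of that behavior, assuming a working JAX install:

    import jax.numpy as jnp

    x = jnp.ones((3,), dtype="float16")

    # Python scalars are weakly typed: the result follows x's dtype.
    jnp.where(x > 0, x, 0.0).dtype               # float16

    # Pre-converting pins the scalar to a concrete float32 first,
    # which would upcast the result.
    jnp.where(x > 0, x, jnp.asarray(0.0)).dtype  # float32
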
keras_core/losses/loss.py (30 changes: 9 additions & 21 deletions)

@@ -25,20 +25,20 @@ def call(self, y_true, y_pred):
     ```
     """
 
-    def __init__(self, name=None, reduction="sum_over_batch_size"):
+    def __init__(self, name=None, reduction="sum_over_batch_size", dtype=None):
         self.name = name or auto_name(self.__class__.__name__)
         self.reduction = standardize_reduction(reduction)
+        self.dtype = dtype or backend.floatx()
 
     def __call__(self, y_true, y_pred, sample_weight=None):
         in_mask = getattr(y_pred, "_keras_mask", None)
 
         with ops.name_scope(self.name):
-            dtype = backend.floatx()
             y_pred = tree.map_structure(
-                lambda x: ops.convert_to_tensor(x, dtype=dtype), y_pred
+                lambda x: ops.convert_to_tensor(x, dtype=self.dtype), y_pred
             )
             y_true = tree.map_structure(
-                lambda x: ops.convert_to_tensor(x, dtype=dtype), y_true
+                lambda x: ops.convert_to_tensor(x, dtype=self.dtype), y_true
             )
 
             losses = self.call(y_true, y_pred)

@@ -58,6 +58,7 @@ def __call__(self, y_true, y_pred, sample_weight=None):
                 sample_weight=sample_weight,
                 mask=mask,
                 reduction=self.reduction,
+                dtype=self.dtype,
             )
 
     def call(self, y_true, y_pred):

@@ -119,30 +120,21 @@ def reduce_weighted_values(
     sample_weight=None,
     mask=None,
     reduction="sum_over_batch_size",
+    dtype=None,
 ):
     reduction = standardize_reduction(reduction)
 
-    values = ops.convert_to_tensor(values)
+    values = ops.convert_to_tensor(values, dtype=dtype)
     if sample_weight is not None:
-        sample_weight = ops.convert_to_tensor(sample_weight, dtype=values.dtype)
+        sample_weight = ops.convert_to_tensor(sample_weight, dtype=dtype)
     if mask is not None:
-        mask = ops.convert_to_tensor(mask, dtype=values.dtype)
+        mask = ops.convert_to_tensor(mask, dtype=dtype)
 
     # Merge mask and sample weight into sample weight.
     sample_weight = apply_mask(
         sample_weight, mask, dtype=values.dtype, reduction=reduction
     )
 
-    # Convert any non float dtypes to floats, to avoid loss of precision
-    # for dtype like int or bool.
-    dtype = backend.standardize_dtype(values.dtype)
-    if not dtype_utils.is_float(dtype):
-        input_dtype = values.dtype
-        values = ops.cast(values, "float32")
-        input_casted = True
-    else:
-        input_casted = False
-
     if sample_weight is not None:
         sample_weight = ops.cast(sample_weight, values.dtype)
         # Update dimensions of `sample_weight` to match `losses`.

@@ -151,10 +143,6 @@ def reduce_weighted_values(
 
     # Apply reduction function to the individual weighted losses.
    loss = reduce_values(values, reduction)
-
-    if input_casted:
-        # Convert the result back to the input type.
-        loss = ops.cast(loss, input_dtype)
     return loss
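
Note: with this change, a loss's compute dtype becomes configurable per instance instead of always being backend.floatx(). A minimal usage sketch; ExampleLoss here is a stand-in subclass mirroring the one used in loss_test.py, not something added by this commit:

    from keras_core import losses, ops

    class ExampleLoss(losses.Loss):
        def call(self, y_true, y_pred):
            return ops.square(y_true - y_pred)

    # Inputs are cast to float16, and reduce_weighted_values now
    # receives dtype=self.dtype, so the reduction runs in float16 too.
    loss_fn = ExampleLoss(dtype="float16")
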
keras_core/losses/loss_test.py (10 changes: 10 additions & 0 deletions)

@@ -209,3 +209,13 @@ def test_get_method(self):
 
         with self.assertRaises(ValueError):
             losses_module.get("typo")
+
+    def test_dtype_arg(self):
+        y_true = np.array([1.0, 0.0, 1.0, 0.0], dtype="float32")
+        y_pred = np.array([0.1, 0.2, 0.3, 0.4], dtype="float32")
+
+        # Note: we use float16 and not float64 to test this because
+        # JAX will map float64 to float32.
+        loss_fn = ExampleLoss(dtype="float16")
+        loss = loss_fn(y_true, y_pred)
+        self.assertEqual(backend.standardize_dtype(loss.dtype), "float16")
keras_core/ops/image_test.py (2 changes: 1 addition & 1 deletion)

@@ -329,7 +329,7 @@ def test_affine_transform(self, interpolation, fill_mode, data_format):
         if data_format == "channels_first":
             ref_out = np.transpose(ref_out, (0, 3, 1, 2))
         self.assertEqual(tuple(out.shape), tuple(ref_out.shape))
-        self.assertAllClose(ref_out, out, atol=0.3)
+        self.assertAllClose(ref_out, out, atol=1e-3, rtol=1e-3)
 
     @parameterized.parameters(
         [
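
Note: assuming assertAllClose follows numpy's assert_allclose semantics, the check passes when |ref_out - out| <= atol + rtol * |out| holds elementwise. Replacing the loose atol=0.3 with atol=1e-3 and rtol=1e-3 therefore makes the comparison roughly two orders of magnitude stricter, and the rtol term scales the tolerance with the magnitude of the expected values.
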
