Skip to content

Commit

Permalink
fixed apply_mask bug in losses (#802)
Browse files Browse the repository at this point in the history
* fixed apply_mask bug in losses

* added rank 2 weight + mask unit test

* formatting and test fix

* excluded mask-based test using numpy
  • Loading branch information
jackd authored Aug 28, 2023
1 parent 174fee0 commit aea55a9
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
5 changes: 4 additions & 1 deletion keras_core/losses/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,10 @@ def apply_mask(sample_weight, mask, dtype, reduction):
# = sum(loss * sample_weight * total / valid) / total
# = sum(loss * sample_weight) / total * total / valid
# = sum(loss * sample_weight) / valid
total = ops.cast(ops.shape(mask)[0], dtype=dtype)
total = ops.cast(
ops.prod(ops.convert_to_tensor(ops.shape(mask), dtype="int32")),
dtype,
)
valid = ops.sum(mask) # May be 0!
mask *= total / (valid + backend.epsilon())

Expand Down
28 changes: 28 additions & 0 deletions keras_core/losses/loss_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,34 @@ def test_mask_and_sample_weight(self):
loss,
)

@pytest.mark.skipif(
backend.backend() == "numpy",
reason="Numpy backend does not support masking.",
)
def test_mask_and_sample_weight_rank2(self):
# check loss of inputs with duplicate rows doesn't change
sample_weight = np.array([0.4, 0.3, 0.2, 0.1])
y_true = np.array([1.0, 0.0, 1.0, 0.0])
y_pred = np.array([0.1, 0.2, 0.3, 0.4])
mask = np.array([True, False, True, True])

mask = ops.convert_to_tensor(mask)
y_true = ops.convert_to_tensor(y_true)
y_pred = ops.convert_to_tensor(y_pred)
y_pred._keras_mask = mask

loss_fn = ExampleLoss()
rank1_loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)

# duplicate rows
mask = ops.tile(ops.expand_dims(mask, axis=0), (2, 1))
y_true = ops.tile(ops.expand_dims(y_true, axis=0), (2, 1))
y_pred = ops.tile(ops.expand_dims(y_pred, axis=0), (2, 1))
sample_weight = ops.tile(ops.expand_dims(sample_weight, axis=0), (2, 1))
y_pred._keras_mask = mask
rank2_loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)
self.assertAllClose(rank1_loss, rank2_loss)

# @testing.parametrize(
# "uprank", ["mask", "sample_weight", "y_true", "y_pred"])
# TODO: use parameterization decorator
Expand Down

0 comments on commit aea55a9

Please sign in to comment.