Skip to content

Commit

Permalink
Tiny bit of battle testing function dict inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Nov 25, 2024
1 parent 1e8426b commit 72e9b9d
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 15 deletions.
44 changes: 43 additions & 1 deletion keras/src/models/functional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ def test_named_input_dict_io(self):
out_val = model(in_val)
self.assertEqual(out_val.shape, (2, 4))

# Two inputs
# ----
# Two inputs, input is list
input_a = Input(shape=(3,), batch_size=2, name="a")
input_b = Input(shape=(4,), batch_size=2, name="b")
a = layers.Dense(5)(input_a)
Expand All @@ -175,6 +176,46 @@ def test_named_input_dict_io(self):
out_val = model(in_val)
self.assertEqual(out_val.shape, (2, 4))

# ----
# Two inputs, input is dict
model = Functional({"a": input_a, "b": input_b}, outputs)

# Eager call
in_val = {"a": np.random.random((2, 3)), "b": np.random.random((2, 4))}
out_val = model(in_val)
self.assertEqual(out_val.shape, (2, 4))

# Symbolic call
input_a_2 = Input(shape=(3,), batch_size=2)
input_b_2 = Input(shape=(4,), batch_size=2)
in_val = {"a": input_a_2, "b": input_b_2}
out_val = model(in_val)
self.assertEqual(out_val.shape, (2, 4))

# ----
# Two inputs, input is dict with incorrect names
model = Functional({"c": input_a, "d": input_b}, outputs)

# Eager call
in_val = {"c": np.random.random((2, 3)), "d": np.random.random((2, 4))}
out_val = model(in_val)
self.assertEqual(out_val.shape, (2, 4))

# Symbolic call
input_a_2 = Input(shape=(3,), batch_size=2)
input_b_2 = Input(shape=(4,), batch_size=2)
in_val = {"c": input_a_2, "d": input_b_2}
out_val = model(in_val)
self.assertEqual(out_val.shape, (2, 4))

# Now we can't use the input names:
with self.assertRaises(ValueError):
in_val = {
"a": np.random.random((2, 3)),
"b": np.random.random((2, 4)),
}
out_val = model(in_val)

@pytest.mark.requires_trainable_backend
def test_input_dict_with_extra_field(self):
input_a = Input(shape=(3,), batch_size=2, name="a")
Expand Down Expand Up @@ -560,6 +601,7 @@ def test_add_loss(self):
# TODO
pass

@pytest.mark.requires_trainable_backend
def test_layers_setter(self):
inputs = Input(shape=(3,), batch_size=2, name="input")
outputs = layers.Dense(5)(inputs)
Expand Down
9 changes: 5 additions & 4 deletions keras/src/models/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,10 +661,11 @@ def test_functional_list_outputs_dict_losses_invalid_keys(self):
"output_c": "binary_crossentropy",
},
)

# Fit the model to make sure compile_metrics are built
with self.assertRaisesRegex(
KeyError,
"in the `loss` argument, can't be found in the model's output",
ValueError,
"Expected keys",
):
model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)

Expand All @@ -680,8 +681,8 @@ def test_functional_list_outputs_dict_losses_no_output_names(self):
)
# Fit the model to make sure compile_metrics are built
with self.assertRaisesRegex(
KeyError,
"in the `loss` argument, can't be found in the model's output",
ValueError,
"Expected keys",
):
model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)

Expand Down
24 changes: 14 additions & 10 deletions keras/src/trainers/compile_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,22 +565,26 @@ def key_check_fn(key, objs):
)

def build(self, y_true, y_pred):
loss = self._user_loss
loss_weights = self._user_loss_weights
flat_output_names = self.output_names
if (
self.output_names
and isinstance(self._user_loss, dict)
and not isinstance(y_pred, dict)
):
loss = [self._user_loss[name] for name in self.output_names]
if isinstance(self._user_loss_weights, dict):
loss_weights = [
self._user_loss_weights[name] for name in self.output_names
]
if set(self.output_names) == set(self._user_loss.keys()):
loss = [self._user_loss[name] for name in self.output_names]
if isinstance(self._user_loss_weights, dict):
loss_weights = [
self._user_loss_weights[name]
for name in self.output_names
]
else:
loss_weights = self._user_loss_weights
else:
loss = self._user_loss
loss_weights = self._user_loss_weights
flat_output_names = self.output_names
raise ValueError(
f"Expected keys {self.output_names} in loss dict, but found "
f"loss.keys()={list(self._user_loss.keys())}"
)

# Pytree leaf container
class WeightedLoss:
Expand Down

0 comments on commit 72e9b9d

Please sign in to comment.