diff --git a/keras_core/layers/merging/base_merge.py b/keras_core/layers/merging/base_merge.py index 7c44c39e2..2837c2924 100644 --- a/keras_core/layers/merging/base_merge.py +++ b/keras_core/layers/merging/base_merge.py @@ -61,7 +61,7 @@ def _compute_elemwise_op_output_shape(self, shape1, shape2): def build(self, input_shape): # Used purely for shape validation. - if not isinstance(input_shape[0], tuple): + if not isinstance(input_shape[0], (tuple, list)): raise ValueError( "A merge layer should be called on a list of inputs. " f"Received: input_shape={input_shape} (not a list of shapes)" diff --git a/keras_core/layers/merging/concatenate.py b/keras_core/layers/merging/concatenate.py index 79e3ca0d8..5d32ebdbe 100644 --- a/keras_core/layers/merging/concatenate.py +++ b/keras_core/layers/merging/concatenate.py @@ -39,7 +39,9 @@ def __init__(self, axis=-1, **kwargs): def build(self, input_shape): # Used purely for shape validation. - if len(input_shape) < 1 or not isinstance(input_shape[0], tuple): + if len(input_shape) < 1 or not isinstance( + input_shape[0], (tuple, list) + ): raise ValueError( "A `Concatenate` layer should be called on a list of " f"at least 1 input. Received: input_shape={input_shape}"