diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml
index 6300b458e21c..1286b78b5274 100644
--- a/.github/workflows/scorecard.yml
+++ b/.github/workflows/scorecard.yml
@@ -56,6 +56,6 @@ jobs:
       # Upload the results to GitHub's code scanning dashboard.
       - name: "Upload to code-scanning"
-        uses: github/codeql-action/upload-sarif@8a470fddafa5cbb6266ee11b37ef4d8aae19c571 # v3.24.6
+        uses: github/codeql-action/upload-sarif@1b1aada464948af03b950897e5eb522f92603cc2 # v3.24.9
         with:
           sarif_file: results.sarif
diff --git a/check_fp8_mnist.py b/check_fp8_mnist.py
new file mode 100644
index 000000000000..3b0e80f32f3d
--- /dev/null
+++ b/check_fp8_mnist.py
@@ -0,0 +1,73 @@
+import argparse
+
+import numpy as np
+
+import keras
+from keras import layers
+from keras import models
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--fp8", action="store_true")
+    parser.add_argument("--einsum", action="store_true")
+    return parser.parse_args()
+
+
+class Classifier(models.Model):
+    def __init__(self, use_fp8=False):
+        super().__init__()
+        inputs = layers.Input(shape=[28, 28, 1])
+        x = layers.Flatten()(inputs)
+        x = layers.Dense(
+            64, activation="relu", use_bias=False, use_fp8=use_fp8
+        )(x)
+        x = layers.Dense(
+            64, activation="relu", use_bias=False, use_fp8=use_fp8
+        )(x)
+        outputs = layers.Dense(
+            10, activation="softmax", use_bias=False, use_fp8=use_fp8
+        )(x)
+        super().__init__(inputs, outputs)
+
+
+class Classifier2(models.Model):
+    def __init__(self, use_fp8=False):
+        super().__init__()
+        inputs = layers.Input(shape=[28, 28, 1])
+        x = layers.Flatten()(inputs)
+        x = layers.EinsumDense(
+            "ab,bc->ac", output_shape=[64], activation="relu", use_fp8=use_fp8
+        )(x)
+        x = layers.EinsumDense(
+            "ab,bc->ac", output_shape=[64], activation="relu", use_fp8=use_fp8
+        )(x)
+        outputs = layers.EinsumDense(
+            "ab,bc->ac",
+            output_shape=[10],
+            activation="softmax",
+            use_fp8=use_fp8,
+        )(x)
+        super().__init__(inputs, outputs)
+
+
+args = get_args()
+if args.einsum:
+    model = Classifier2(use_fp8=args.fp8)
+else:
+    model = Classifier(use_fp8=args.fp8)
+num_classes = 10
+(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
+x_train = x_train.astype("float32") / 255
+x_test = x_test.astype("float32") / 255
+x_train = np.expand_dims(x_train, -1)
+x_test = np.expand_dims(x_test, -1)
+y_train = keras.utils.to_categorical(y_train, num_classes)
+y_test = keras.utils.to_categorical(y_test, num_classes)
+
+model.compile(
+    loss="categorical_crossentropy",
+    optimizer="adam",
+    metrics=["accuracy"],
+)
+model.fit(x_train, y_train, batch_size=128, epochs=10, validation_split=0.1)
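
The script above is the end-to-end check for the new `use_fp8` flag. A smaller smoke test along the same lines, written as a sketch that assumes `Dense` accepts the same `use_fp8` argument the script passes, could look like:

    import numpy as np
    from keras import layers
    from keras import models

    # Minimal sketch: one FP8-enabled Dense layer in a tiny functional model,
    # called once on dummy data (shapes are illustrative only).
    inputs = layers.Input(shape=(784,))
    outputs = layers.Dense(10, use_bias=False, use_fp8=True)(inputs)
    model = models.Model(inputs, outputs)
    _ = model(np.zeros((2, 784), dtype="float32"))
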
diff --git a/keras/backend/common/dtypes.py b/keras/backend/common/dtypes.py
index 5a92821f21ec..dcb922db2156 100644
--- a/keras/backend/common/dtypes.py
+++ b/keras/backend/common/dtypes.py
@@ -1,17 +1,11 @@
 import functools
 
-from keras import backend
 from keras.api_export import keras_export
-from keras.backend.common.variables import ALLOWED_DTYPES
+from keras.backend import config
 from keras.backend.common.variables import standardize_dtype
 
-"""
-We adapted the type promotion lattice from JAX. Ref:
-https://github.com/google/jax/blob/main/jax/_src/dtypes.py
-"""
-
-BOOL_TYPES = ["bool"]
-INT_TYPES = [
+BOOL_TYPES = ("bool",)
+INT_TYPES = (
     "uint8",
     "uint16",
     "uint32",
@@ -20,9 +14,44 @@
     "int16",
     "int32",
     "int64",
-]
-FLOAT_TYPES = ["bfloat16", "float16", "float32", "float64"]
-WEAK_TYPES = ["int", "float"]
+)
+FLOAT_TYPES = ("bfloat16", "float16", "float32", "float64")
+WEAK_TYPES = ("int", "float")
+# We need to separate float8 from float because there are no implicit
+# conversions from float8 dtypes to other dtypes.
+# Ref: https://github.com/google/jax/issues/16705
+FLOAT8_TYPES = ("float8_e4m3fn", "float8_e5m2")
+
+# All supported dtypes in Keras
+ALLOWED_DTYPES = (
+    "float16",
+    "float32",
+    "float64",
+    "uint8",
+    "uint16",
+    "uint32",
+    "uint64",
+    "int8",
+    "int16",
+    "int32",
+    "int64",
+    "bfloat16",
+    "bool",
+    "string",
+    "float8_e4m3fn",
+    "float8_e5m2",
+)
+PYTHON_DTYPES_MAP = {
+    bool: "bool",
+    int: "int64" if config.backend() == "tensorflow" else "int32",
+    float: "float32",
+    str: "string",
+    # special case for string value
+    "int": "int64" if config.backend() == "tensorflow" else "int32",
+}
+
+# We adapted the type promotion lattice from JAX. Ref:
+# https://github.com/google/jax/blob/main/jax/_src/dtypes.py
 
 
 def _type_promotion_lattice():
@@ -168,7 +197,7 @@ def _respect_weak_type(dtype, weak_type):
 @functools.lru_cache(maxsize=None)
 def _resolve_weak_type(dtype, precision="32"):
     """Resolve weak type by the precision of `backend.floatx()`."""
-    extended_allowed_dtypes = ALLOWED_DTYPES.union(WEAK_TYPES)
+    extended_allowed_dtypes = set(ALLOWED_DTYPES).union(WEAK_TYPES)
     if dtype not in extended_allowed_dtypes:
         raise ValueError(
             "Invalid value for argument `dtype`. Expected one of "
@@ -234,7 +263,7 @@ def _lattice_result_type(*args):
         out_weak_type = any(out_dtype is t for t in WEAK_TYPES)
 
     out_weak_type = (out_dtype != "bool") and out_weak_type
-    precision = backend.floatx()[-2:]
+    precision = config.floatx()[-2:]
     if out_weak_type:
         out_dtype = _resolve_weak_type(out_dtype, precision=precision)
     return out_dtype
@@ -270,7 +299,13 @@ def result_type(*dtypes):
     if len(dtypes) == 0:
         # If no dtypes provided, default to floatx, this matches
         # `ops.convert_to_tensor([])`
-        return backend.floatx()
+        return config.floatx()
+    for dtype in dtypes:
+        if dtype in FLOAT8_TYPES:
+            raise ValueError(
+                "There are no implicit conversions from float8 dtypes to "
+                f"others. You must cast it explicitly. Received: {dtypes}"
+            )
     return _lattice_result_type(
-        *(backend.floatx() if arg is None else arg for arg in dtypes),
+        *(config.floatx() if arg is None else arg for arg in dtypes),
    )
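
For reference, a minimal sketch of how promotion behaves after this change; the float8 error text mirrors the message added above, and the import path is the internal module touched by this patch:

    from keras.backend.common import dtypes

    # Regular dtypes still promote through the JAX-style lattice.
    assert dtypes.result_type("int8", "float32") == "float32"

    # float8 dtypes are deliberately excluded from implicit promotion.
    try:
        dtypes.result_type("float8_e4m3fn", "bfloat16")
    except ValueError as e:
        print(e)
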
diff --git a/keras/backend/common/dtypes_test.py b/keras/backend/common/dtypes_test.py
index f2f91154da9f..200e4fa2b9e3 100644
--- a/keras/backend/common/dtypes_test.py
+++ b/keras/backend/common/dtypes_test.py
@@ -5,7 +5,6 @@
 from keras import backend
 from keras import ops
 from keras.backend.common import dtypes
-from keras.backend.common.variables import ALLOWED_DTYPES
 from keras.testing import test_case
 from keras.testing.test_utils import named_product
 
@@ -18,14 +17,18 @@ class DtypesTest(test_case.TestCase, parameterized.TestCase):
         # TODO: torch doesn't support uint64.
         ALL_DTYPES = []
-        for x in ALLOWED_DTYPES:
+        for x in dtypes.ALLOWED_DTYPES:
             if x not in ["string", "uint64"]:
                 x = str(to_torch_dtype(x)).split(".")[-1]
                 if x not in ALL_DTYPES:  # skip duplicates created by remapping
                     ALL_DTYPES.append(x)
         ALL_DTYPES += [None]
     else:
-        ALL_DTYPES = [x for x in ALLOWED_DTYPES if x != "string"] + [None]
+        ALL_DTYPES = [x for x in dtypes.ALLOWED_DTYPES if x != "string"] + [
+            None
+        ]
+    # Remove float8 dtypes for the following tests
+    ALL_DTYPES = [x for x in ALL_DTYPES if x not in dtypes.FLOAT8_TYPES]
 
     def setUp(self):
         from jax.experimental import enable_x64
@@ -217,3 +220,13 @@ def test_least_upper_bound_with_no_common_upper_bound(self):
             ValueError, "no available implicit dtype promotion path"
         ):
             dtypes._least_upper_bound("test_dtype1", "test_dtype2")
+
+    def test_invalid_float8_dtype(self):
+        with self.assertRaisesRegex(
+            ValueError, "There are no implicit conversions from float8 dtypes"
+        ):
+            dtypes.result_type("float8_e4m3fn", "bfloat16")
+        with self.assertRaisesRegex(
+            ValueError, "There are no implicit conversions from float8 dtypes"
+        ):
+            dtypes.result_type("float8_e5m2", "bfloat16")
diff --git a/keras/backend/common/variables.py b/keras/backend/common/variables.py
index 9787eeb2f83b..793ff1d8a5ad 100644
--- a/keras/backend/common/variables.py
+++ b/keras/backend/common/variables.py
@@ -2,6 +2,7 @@
 
 from keras.api_export import keras_export
 from keras.backend import config
+from keras.backend.common import dtypes
 from keras.backend.common import global_state
 from keras.backend.common.name_scope import current_path
 from keras.backend.common.stateless_scope import get_stateless_scope
@@ -397,40 +398,13 @@ def initialize_all_variables():
     global_state.set_global_attribute("uninitialized_variables", [])
 
 
-ALLOWED_DTYPES = {
-    "float16",
-    "float32",
-    "float64",
-    "uint8",
-    "uint16",
-    "uint32",
-    "uint64",
-    "int8",
-    "int16",
-    "int32",
-    "int64",
-    "bfloat16",
-    "bool",
-    "string",
-}
-
-PYTHON_DTYPES_MAP = {
-    bool: "bool",
-    int: "int64" if config.backend() == "tensorflow" else "int32",
-    float: "float32",
-    str: "string",
-    # special case for string value
-    "int": "int64" if config.backend() == "tensorflow" else "int32",
-}
-
-
 @keras_export(
     ["keras.utils.standardize_dtype", "keras.backend.standardize_dtype"]
 )
 def standardize_dtype(dtype):
     if dtype is None:
         return config.floatx()
-    dtype = PYTHON_DTYPES_MAP.get(dtype, dtype)
+    dtype = dtypes.PYTHON_DTYPES_MAP.get(dtype, dtype)
     if hasattr(dtype, "name"):
         dtype = dtype.name
     elif hasattr(dtype, "__str__") and (
@@ -440,7 +414,7 @@ def standardize_dtype(dtype):
     elif hasattr(dtype, "__name__"):
         dtype = dtype.__name__
 
-    if dtype not in ALLOWED_DTYPES:
+    if dtype not in dtypes.ALLOWED_DTYPES:
         raise ValueError(f"Invalid dtype: {dtype}")
     return dtype
 
diff --git a/keras/backend/common/variables_test.py b/keras/backend/common/variables_test.py
index 1df2597c3541..6bd21d37fbc2 100644
--- a/keras/backend/common/variables_test.py
+++ b/keras/backend/common/variables_test.py
@@ -4,7 +4,7 @@
 
 from keras import backend
 from keras import initializers
-from keras.backend.common.variables import ALLOWED_DTYPES
+from keras.backend.common import dtypes
 from keras.backend.common.variables import AutocastScope
 from keras.backend.common.variables import KerasVariable
 from keras.backend.common.variables import shape_equal
@@ -156,7 +156,7 @@ def test_autocasting(self):
         self.assertEqual(backend.standardize_dtype(v.value.dtype), "float32")
 
     @parameterized.parameters(
-        *((dtype for dtype in ALLOWED_DTYPES if dtype != "string"))
+        *((dtype for dtype in dtypes.ALLOWED_DTYPES if dtype != "string"))
     )
     def test_standardize_dtype(self, dtype):
         """Tests standardize_dtype for all ALLOWED_DTYPES except string."""
diff --git a/keras/backend/tensorflow/nn.py b/keras/backend/tensorflow/nn.py
index 4d6433aed51e..807f0206439a 100644
--- a/keras/backend/tensorflow/nn.py
+++ b/keras/backend/tensorflow/nn.py
@@ -40,8 +40,8 @@ def softsign(x):
     return tf.nn.softsign(x)
 
 
-def silu(x, beta=1.0):
-    return tf.nn.silu(x, beta=beta)
+def silu(x):
+    return tf.nn.silu(x)
 
 
 def log_sigmoid(x):
diff --git a/keras/backend/torch/core.py b/keras/backend/torch/core.py
index 4794fd9c9726..76dcc95db6f7 100644
--- a/keras/backend/torch/core.py
+++ b/keras/backend/torch/core.py
@@ -42,6 +42,8 @@
     "int64": torch.int64,
     "bfloat16": torch.bfloat16,
    "bool": torch.bool,
+    "float8_e4m3fn": torch.float8_e4m3fn,
+    "float8_e5m2": torch.float8_e5m2,
 }
 
 
diff --git a/keras/backend/torch/nn.py b/keras/backend/torch/nn.py
index 5d0ea6cc271d..af7aba02ddd0 100644
--- a/keras/backend/torch/nn.py
+++ b/keras/backend/torch/nn.py
@@ -47,9 +47,9 @@ def softsign(x):
     return tnn.softsign(x)
 
 
-def silu(x, beta=1.0):
+def silu(x):
     x = convert_to_tensor(x)
-    return x * sigmoid(beta * x)
+    return tnn.silu(x)
 
 
 def log_sigmoid(x):
diff --git a/keras/export/export_lib.py b/keras/export/export_lib.py
index e0ca00e7d852..6a4df72202cd 100644
--- a/keras/export/export_lib.py
+++ b/keras/export/export_lib.py
@@ -320,7 +320,10 @@ def stateless_fn(variables, *args, **kwargs):
 
         def stateful_fn(*args, **kwargs):
             return jax2tf_stateless_fn(
-                self._tf_trackable.variables, *args, **kwargs
+                # Change the trackable `ListWrapper` to a plain `list`
+                list(self._tf_trackable.variables),
+                *args,
+                **kwargs,
             )
 
         # Note: we truncate the number of parameters to what is
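
With the torch dtype map extended above, float8 values can be produced and consumed through explicit casts on every backend. A minimal sketch mirroring the new `test_cast_float8` cases later in this diff:

    from keras import ops

    # Round-trip a tensor through float8 storage; float8 values must be
    # cast back explicitly since there is no implicit promotion for them.
    x = ops.ones((2,), dtype="float32")
    y = ops.cast(x, "float8_e4m3fn")
    z = ops.cast(y, "float32")
    print(z)
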
diff --git a/keras/layers/core/einsum_dense.py b/keras/layers/core/einsum_dense.py
index e8022d915b5f..5b8032527a20 100644
--- a/keras/layers/core/einsum_dense.py
+++ b/keras/layers/core/einsum_dense.py
@@ -1,6 +1,7 @@
 import re
 import string
 
+import ml_dtypes
 import numpy as np
 
 from keras import activations
@@ -125,6 +126,8 @@ def __init__(
         kernel_constraint=None,
         bias_constraint=None,
         lora_rank=None,
+        amax_history_length=1024,
+        use_fp8=False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -142,6 +145,8 @@ def __init__(
         self.kernel_constraint = constraints.get(kernel_constraint)
         self.bias_constraint = constraints.get(bias_constraint)
         self.lora_rank = lora_rank
+        self.amax_history_length = amax_history_length
+        self.use_fp8 = use_fp8
         self.lora_enabled = False
 
     def build(self, input_shape):
@@ -184,6 +189,8 @@ def build(self, input_shape):
         self.built = True
         if self.lora_rank:
             self.enable_lora(self.lora_rank)
+        if self.use_fp8:
+            self.fp8_build(input_shape)
 
     @property
     def kernel(self):
@@ -550,6 +557,175 @@ def _get_kernel_with_merged_lora(self):
             kernel_scale = None
         return kernel_value, kernel_scale
 
+    """FP8-related methods"""
+
+    def fp8_build(self, input_shape):
+        amax_history_kwargs = {
+            "shape": (self.amax_history_length,),
+            "initializer": "zeros",
+            "trainable": False,
+            "autocast": False,
+        }
+        scale_kwargs = {
+            "shape": (),
+            "initializer": "ones",
+            "trainable": False,
+            "autocast": False,
+        }
+        self.inputs_amax_history = self.add_weight(
+            name="inputs_amax_history", **amax_history_kwargs
+        )
+        self.inputs_scale = self.add_weight(name="inputs_scale", **scale_kwargs)
+        self.kernel_amax_history = self.add_weight(
+            name="kernel_amax_history", **amax_history_kwargs
+        )
+        self.kernel_scale = self.add_weight(name="kernel_scale", **scale_kwargs)
+        self.outputs_grad_amax_history = self.add_weight(
+            name="outputs_grad_amax_history", **amax_history_kwargs
+        )
+        self.outputs_grad_scale = self.add_weight(
+            name="outputs_grad_scale", **scale_kwargs
+        )
+        if backend.backend() == "jax":
+            # For an unknown reason, these weights must be set as trainable to
+            # enable assignment in `outputs_qdq` on the jax backend.
+            self.outputs_grad_amax_history.trainable = True
+            self.outputs_grad_scale.trainable = True
+
+    def fp8_call(self, inputs):
+        if self.lora_enabled:
+            raise NotImplementedError(
+                "Currently, `fp8_call` doesn't support LoRA"
+            )
+
+        def compute_scale(amax, scale, dtype_max, margin=0):
+            """Default function to convert amax to scaling factor."""
+            exp = ops.floor(ops.log2(ops.divide(dtype_max, amax))) - margin
+            sf = ops.round(ops.power(2.0, ops.abs(exp)))
+            sf = ops.where(amax > 0.0, sf, scale)
+            sf = ops.where(ops.isfinite(amax), sf, scale)
+            sf = ops.where(exp < 0.0, ops.reciprocal(sf), sf)
+            # The scaling factor we need equals the "scale_inv" notion in
+            # TransformerEngine. So, we convert the sf to its reciprocal.
+            return ops.reciprocal(sf)
+
+        def compute_scale_and_amax_history(
+            x, scale, amax_history, quantized_dtype
+        ):
+            x = ops.convert_to_tensor(x)
+            scale = ops.convert_to_tensor(scale)
+            amax_history = ops.convert_to_tensor(amax_history)
+            quantized_dtype_max = float(ml_dtypes.finfo(quantized_dtype).max)
+
+            amax_update = ops.cast(ops.max(ops.abs(x)), scale.dtype)
+            amax_history_update = ops.scatter_update(
+                ops.roll(amax_history, shift=-1),
+                [[0]],
+                ops.reshape(amax_update, [1]),
+            )
+
+            amax_from_history = ops.max(amax_history)
+            scale_update = compute_scale(
+                amax_from_history, scale, quantized_dtype_max
+            )
+            return amax_history_update, scale_update
+
+        @ops.custom_gradient
+        def inputs_qdq(inputs):
+            """Quantize-dequantize the inputs but not its gradient."""
+            qdq_inputs = quantizers.quantize_and_dequantize(
+                inputs,
+                ops.convert_to_tensor(self.inputs_scale),
+                "float8_e4m3fn",
+                self.compute_dtype,
+            )
+
+            def grad(*args, upstream=None):
+                if upstream is None:
+                    (upstream,) = args
+                return upstream
+
+            return qdq_inputs, grad
+
+        @ops.custom_gradient
+        def kernel_qdq(kernel):
+            """Quantize-dequantize the kernel but not its gradient."""
+            qdq_kernel = quantizers.quantize_and_dequantize(
+                kernel,
+                ops.convert_to_tensor(self.kernel_scale),
+                "float8_e4m3fn",
+                self.compute_dtype,
+            )
+
+            def grad(*args, upstream=None, variables=None):
+                if upstream is None:
+                    (upstream,) = args
+                return upstream
+
+            return qdq_kernel, grad
+
+        @ops.custom_gradient
+        def outputs_qdq(outputs, scale, amax_history):
+            """Quantize-dequantize the output gradient but not the output."""
+
+            def grad(*args, upstream=None):
+                if upstream is None:
+                    (upstream,) = args
+                qdq_upstream = quantizers.quantize_and_dequantize(
+                    upstream, scale, "float8_e5m2", self.compute_dtype
+                )
+                amax_history_update, scale_update = (
+                    compute_scale_and_amax_history(
+                        upstream, scale, amax_history, "float8_e5m2"
+                    )
+                )
+                self.outputs_grad_scale.assign(scale_update)
+                self.outputs_grad_amax_history.assign(amax_history_update)
+                return qdq_upstream, None, None
+
+            return outputs, grad
+
+        x = ops.einsum(
+            self.equation,
+            inputs_qdq(inputs),
+            kernel_qdq(ops.convert_to_tensor(self._kernel)),
+        )
+        # `outputs_qdq` is placed immediately after `ops.matmul` for the sake
+        # of pattern matching in gemm_rewrite. That way, the qdq will be
+        # adjacent to the corresponding matmul_bprop in the bprop.
+        x = outputs_qdq(
+            x,
+            ops.convert_to_tensor(self.outputs_grad_scale),
+            ops.convert_to_tensor(self.outputs_grad_amax_history),
+        )
+        if self.bias is not None:
+            # Under non-mixed precision cases, F32 bias has to be converted to
+            # BF16 first to get the biasAdd fusion support. ref. PR
+            # https://github.com/tensorflow/tensorflow/pull/60306
+            bias = self.bias
+            if self.dtype_policy.compute_dtype == "float32":
+                bias_bf16 = ops.cast(bias, "bfloat16")
+                bias = ops.cast(bias_bf16, bias.dtype)
+            x = ops.add(x, bias)
+        if self.activation is not None:
+            x = self.activation(x)
+
+        # Update fp8 stats
+        amax_history_update, scale_update = compute_scale_and_amax_history(
+            inputs, self.inputs_scale, self.inputs_amax_history, "float8_e4m3fn"
+        )
+        self.inputs_scale.assign(scale_update)
+        self.inputs_amax_history.assign(amax_history_update)
+        amax_history_update, scale_update = compute_scale_and_amax_history(
+            self.kernel,
+            self.kernel_scale,
+            self.kernel_amax_history,
+            "float8_e4m3fn",
+        )
+        self.kernel_scale.assign(scale_update)
+        self.kernel_amax_history.assign(amax_history_update)
+        return x
+
 
 def _analyze_einsum_string(equation, bias_axes, input_shape, output_shape):
     """Analyzes an einsum string to determine the required weight shape."""
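
A minimal usage sketch of the flag added above; shapes and data are illustrative only, and `fp8_build` runs when the layer is built on its first call:

    import numpy as np
    from keras import layers

    # EinsumDense with the new `use_fp8` flag; `amax_history_length`
    # keeps its default of 1024.
    layer = layers.EinsumDense(
        "ab,bc->ac", output_shape=[16], activation="relu", use_fp8=True
    )
    y = layer(np.random.rand(4, 8).astype("float32"))
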
diff --git a/keras/losses/__init__.py b/keras/losses/__init__.py
index 15bfc50f5067..df63ca37773d 100644
--- a/keras/losses/__init__.py
+++ b/keras/losses/__init__.py
@@ -4,6 +4,7 @@
 from keras.losses.losses import CategoricalCrossentropy
 from keras.losses.losses import CategoricalHinge
 from keras.losses.losses import CosineSimilarity
+from keras.losses.losses import Dice
 from keras.losses.losses import Hinge
 from keras.losses.losses import Huber
 from keras.losses.losses import KLDivergence
@@ -21,6 +22,7 @@
 from keras.losses.losses import categorical_hinge
 from keras.losses.losses import cosine_similarity
 from keras.losses.losses import ctc
+from keras.losses.losses import dice
 from keras.losses.losses import hinge
 from keras.losses.losses import huber
 from keras.losses.losses import kl_divergence
@@ -56,6 +58,8 @@
     Hinge,
     SquaredHinge,
     CategoricalHinge,
+    # Image segmentation
+    Dice,
     # Probabilistic
     kl_divergence,
     poisson,
@@ -74,6 +78,8 @@
     hinge,
     squared_hinge,
     categorical_hinge,
+    # Image segmentation
+    dice,
 }
 
 ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}
diff --git a/keras/losses/losses.py b/keras/losses/losses.py
index 3319dd339127..af4e68a0846e 100644
--- a/keras/losses/losses.py
+++ b/keras/losses/losses.py
@@ -1934,3 +1934,69 @@ def ctc(y_true, y_pred):
     return ops.ctc_loss(
         y_true, y_pred, label_length, input_length, mask_index=0
     )
+
+
+@keras_export("keras.losses.Dice")
+class Dice(LossFunctionWrapper):
+    """Computes the Dice loss value between `y_true` and `y_pred`.
+
+    Formula:
+    ```python
+    loss = 1 - (2 * sum(y_true * y_pred)) / (sum(y_true) + sum(y_pred))
+    ```
+
+    Args:
+        reduction: Type of reduction to apply to the loss.
+        name: Optional name for the loss instance.
+
+    Returns:
+        Dice loss value.
+    """
+
+    def __init__(
+        self,
+        reduction="sum_over_batch_size",
+        name="dice",
+    ):
+        super().__init__(
+            dice,
+            name=name,
+            reduction=reduction,
+        )
+
+    def get_config(self):
+        return {
+            "name": self.name,
+            "reduction": self.reduction,
+        }
+
+
+@keras_export("keras.losses.dice")
+def dice(y_true, y_pred):
+    """Computes the Dice loss value between `y_true` and `y_pred`.
+
+    Formula:
+    ```python
+    loss = 1 - (2 * sum(y_true * y_pred)) / (sum(y_true) + sum(y_pred))
+    ```
+
+    Args:
+        y_true: tensor of true targets.
+        y_pred: tensor of predicted targets.
+
+    Returns:
+        Dice loss value.
+    """
+    y_pred = ops.convert_to_tensor(y_pred)
+    y_true = ops.cast(y_true, y_pred.dtype)
+
+    inputs = ops.reshape(y_true, [-1])
+    targets = ops.reshape(y_pred, [-1])
+
+    intersection = ops.sum(inputs * targets)
+    dice = ops.divide(
+        2.0 * intersection,
+        ops.sum(y_true) + ops.sum(y_pred) + backend.epsilon(),
+    )
+
+    return 1 - dice
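
A quick sanity check of the new loss on a tiny binary mask, in the spirit of `DiceTest.test_binary_segmentation` below (a sketch; no particular value is asserted):

    import numpy as np
    from keras import losses

    # The class wrapper and the functional form should agree on a scalar loss.
    y_true = np.array([[1.0, 0.0], [0.0, 1.0]])
    y_pred = np.array([[0.8, 0.2], [0.1, 0.9]])
    print(losses.Dice()(y_true, y_pred))
    print(losses.dice(y_true, y_pred))
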
diff --git a/keras/losses/losses_test.py b/keras/losses/losses_test.py
index 4d6c580cba82..c49b86b3bc4b 100644
--- a/keras/losses/losses_test.py
+++ b/keras/losses/losses_test.py
@@ -1388,3 +1388,24 @@ def test_correctness(self):
         y_true = np.array(([[1, 2, 1, 0], [1, 2, 0, 2]]))
         output = losses.CTC()(y_true, logits)
         self.assertAllClose(output, 4.389582)
+
+
+class DiceTest(testing.TestCase):
+    def test_config(self):
+        self.run_class_serialization_test(losses.Dice(name="mydice"))
+
+    def test_correctness(self):
+        y_true = np.array(([[1, 2], [1, 2]]))
+        y_pred = np.array(([[4, 1], [6, 1]]))
+        output = losses.Dice()(y_true, y_pred)
+        self.assertAllClose(output, -0.55555546)
+
+    def test_binary_segmentation(self):
+        y_true = np.array(
+            ([[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]])
+        )
+        y_pred = np.array(
+            ([[0, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1], [1, 0, 1, 1]])
+        )
+        output = losses.Dice()(y_true, y_pred)
+        self.assertAllClose(output, 0.77777773)
diff --git a/keras/models/functional.py b/keras/models/functional.py
index 1e24af9205c8..5fe9838cd11f 100644
--- a/keras/models/functional.py
+++ b/keras/models/functional.py
@@ -15,6 +15,7 @@
 from keras.ops.function import Function
 from keras.ops.function import _build_map
 from keras.ops.function import make_node_key
+from keras.ops.node import KerasHistory
 from keras.ops.node import Node
 from keras.saving import serialization_lib
 from keras.utils import tracking
@@ -391,7 +392,7 @@ def get_config(self):
                 if node_key in self._nodes:
                     # The node is relevant to the model:
                     # add to filtered_inbound_nodes.
-                    node_data = serialize_node(node, node_reindexing_map)
+                    node_data = serialize_node(node, own_nodes=self._nodes)
                     if node_data is not None:
                         filtered_inbound_nodes.append(node_data)
 
@@ -594,13 +595,33 @@ def unpack_singleton(x):
     return x
 
 
-def serialize_node(node, node_reindexing_map):
+def serialize_node(node, own_nodes=()):
     if not node.input_tensors:
         # Does not need to be serialized.
         return
 
+    def serialize_keras_tensor(x):
+        # Serialize KerasTensor while converting
+        # node indices to only include nodes relevant to `own_nodes`.
+        if isinstance(x, backend.KerasTensor):
+            operation, node_index, tensor_index = x._keras_history
+            irrelevant_node_count = 0
+            for node in operation._inbound_nodes[:node_index]:
+                if node not in own_nodes:
+                    irrelevant_node_count += 1
+            x._keras_history = KerasHistory(
+                operation, node_index - irrelevant_node_count, tensor_index
+            )
+            serialized = serialization_lib.serialize_keras_object(x)
+            x._keras_history = KerasHistory(operation, node_index, tensor_index)
+            return serialized
+        return x
+
     args = node.arguments.args
     kwargs = node.arguments.kwargs
+
+    args = tree.map_structure(serialize_keras_tensor, args)
+    kwargs = tree.map_structure(serialize_keras_tensor, kwargs)
     return {
         "args": serialization_lib.serialize_keras_object(args),
         "kwargs": serialization_lib.serialize_keras_object(kwargs),
diff --git a/keras/ops/core_test.py b/keras/ops/core_test.py
index 552089141930..1c09bd2f385d 100644
--- a/keras/ops/core_test.py
+++ b/keras/ops/core_test.py
@@ -12,8 +12,8 @@
 from keras import ops
 from keras import optimizers
 from keras import testing
+from keras.backend.common import dtypes
 from keras.backend.common.keras_tensor import KerasTensor
-from keras.backend.common.variables import ALLOWED_DTYPES
 from keras.ops import core
 from keras.utils import tree
 
@@ -465,6 +465,27 @@ def test_cast(self):
         self.assertEqual(x.shape, y.shape)
         self.assertTrue(hasattr(y, "_keras_history"))
 
+    @parameterized.named_parameters(
+        ("float8_e4m3fn", "float8_e4m3fn"), ("float8_e5m2", "float8_e5m2")
+    )
+    def test_cast_float8(self, float8_dtype):
+        # Cast to float8 and cast back
+        x = ops.ones((2,), dtype="float32")
+        y = ops.cast(x, float8_dtype)
+        self.assertIn(float8_dtype, str(y.dtype))
+        x = ops.cast(y, "float32")
+        self.assertIn("float32", str(x.dtype))
+
+        x = ops.KerasTensor((2,), dtype="float32")
+        y = ops.cast(x, float8_dtype)
+        self.assertEqual(float8_dtype, y.dtype)
+        self.assertEqual(x.shape, y.shape)
+        self.assertTrue(hasattr(y, "_keras_history"))
+        x = ops.cast(y, "float32")
+        self.assertEqual("float32", x.dtype)
+        self.assertEqual(x.shape, y.shape)
+        self.assertTrue(hasattr(x, "_keras_history"))
+
     def test_vectorized_map(self):
         def fn(x):
             return x + 1
@@ -555,7 +576,7 @@ class CoreOpsDtypeTest(testing.TestCase, parameterized.TestCase):
     # resulting in different behavior between JAX and Keras. Currently, we
     # are skipping the test for uint64
     ALL_DTYPES = [
-        x for x in ALLOWED_DTYPES if x not in ["string", "uint64"]
+        x for x in dtypes.ALLOWED_DTYPES if x not in ["string", "uint64"]
     ] + [None]
 
     if backend.backend() == "torch":
@@ -563,6 +584,8 @@ class CoreOpsDtypeTest(testing.TestCase, parameterized.TestCase):
         ALL_DTYPES = [
             x for x in ALL_DTYPES if x not in ["uint16", "uint32", "uint64"]
         ]
+        # Remove float8 dtypes for the following tests
+        ALL_DTYPES = [x for x in ALL_DTYPES if x not in dtypes.FLOAT8_TYPES]
 
     @parameterized.parameters(
         ((), None, backend.floatx()),
diff --git a/keras/ops/function.py b/keras/ops/function.py
index 37cc8a2281c7..48a21a8bd3b8 100644
--- a/keras/ops/function.py
+++ b/keras/ops/function.py
@@ -203,7 +203,7 @@ def map_graph(inputs, outputs):
 
     Returns:
         A tuple `(nodes, nodes_by_depth, operations, operations_by_depth)`.
-        - network_nodes: dict mapping unique node keys to the Node instances
+        - nodes: set of Node instances
        - nodes_by_depth: dict mapping ints (depth) to lists of node instances.
         - operations: list of Operation instances.
         - operations_by_depth: dict mapping ints (depth) to lists of Operation
diff --git a/keras/ops/math_test.py b/keras/ops/math_test.py
index 937b10095744..be952620e36b 100644
--- a/keras/ops/math_test.py
+++ b/keras/ops/math_test.py
@@ -7,8 +7,8 @@
 
 from keras import backend
 from keras import testing
+from keras.backend.common import dtypes
 from keras.backend.common.keras_tensor import KerasTensor
-from keras.backend.common.variables import ALLOWED_DTYPES
 from keras.ops import math as kmath
 
 
@@ -869,10 +869,10 @@ class MathDtypeTest(testing.TestCase, parameterized.TestCase):
     # resulting in different behavior between JAX and Keras. Currently, we
     # are skipping the test for uint64
     ALL_DTYPES = [
-        x for x in ALLOWED_DTYPES if x not in ["string", "uint64"]
+        x for x in dtypes.ALLOWED_DTYPES if x not in ["string", "uint64"]
     ] + [None]
-    INT_DTYPES = [x for x in ALLOWED_DTYPES if "int" in x and x != "uint64"]
-    FLOAT_DTYPES = [x for x in ALLOWED_DTYPES if "float" in x]
+    INT_DTYPES = [x for x in dtypes.INT_TYPES if x != "uint64"]
+    FLOAT_DTYPES = dtypes.FLOAT_TYPES
 
     if backend.backend() == "torch":
         # TODO: torch doesn't support uint16, uint32 and uint64
diff --git a/keras/ops/nn_test.py b/keras/ops/nn_test.py
index a048d242d5a8..4b3b71dd3895 100644
--- a/keras/ops/nn_test.py
+++ b/keras/ops/nn_test.py
@@ -9,9 +9,9 @@
 from keras import models
 from keras import ops
 from keras import testing
+from keras.backend.common import dtypes
 from keras.backend.common import standardize_dtype
 from keras.backend.common.keras_tensor import KerasTensor
-from keras.backend.common.variables import ALLOWED_DTYPES
 from keras.layers.convolutional.conv_test import np_conv1d
 from keras.layers.convolutional.conv_test import np_conv2d
 from keras.layers.convolutional.conv_test import np_conv3d
@@ -1949,7 +1949,7 @@ def test_logit_recovery_binary_crossentropy(self):
 class NNOpsDtypeTest(testing.TestCase, parameterized.TestCase):
     """Test the dtype to verify that the behavior matches JAX."""
 
-    FLOAT_DTYPES = [x for x in ALLOWED_DTYPES if "float" in x]
+    FLOAT_DTYPES = dtypes.FLOAT_TYPES
 
     def setUp(self):
         from jax.experimental import enable_x64
diff --git a/keras/ops/numpy.py b/keras/ops/numpy.py
index f98d2cc8815f..db3c12ebc8cf 100644
--- a/keras/ops/numpy.py
+++ b/keras/ops/numpy.py
@@ -2707,6 +2707,8 @@ def compute_output_spec(self, x, key):
             remaining_key = [key]
         elif isinstance(key, tuple):
             remaining_key = list(key)
+        elif isinstance(key, list):
+            remaining_key = key.copy()
         else:
             raise ValueError(
                 f"Unsupported key type for array slice. Recieved: `{key}`"
diff --git a/keras/ops/numpy_test.py b/keras/ops/numpy_test.py
index 2ecfe49a9923..50785139a3dc 100644
--- a/keras/ops/numpy_test.py
+++ b/keras/ops/numpy_test.py
@@ -11,9 +11,9 @@
 import keras
 from keras import backend
 from keras import testing
+from keras.backend.common import dtypes
 from keras.backend.common import standardize_dtype
 from keras.backend.common.keras_tensor import KerasTensor
-from keras.backend.common.variables import ALLOWED_DTYPES
 from keras.ops import numpy as knp
 from keras.testing.test_utils import named_product
 
@@ -4612,10 +4612,10 @@ class NumpyDtypeTest(testing.TestCase, parameterized.TestCase):
     # resulting in different behavior between JAX and Keras. Currently, we
     # are skipping the test for uint64
     ALL_DTYPES = [
-        x for x in ALLOWED_DTYPES if x not in ["string", "uint64"]
+        x for x in dtypes.ALLOWED_DTYPES if x not in ["string", "uint64"]
     ] + [None]
-    INT_DTYPES = [x for x in ALLOWED_DTYPES if "int" in x and x != "uint64"]
-    FLOAT_DTYPES = [x for x in ALLOWED_DTYPES if "float" in x]
+    INT_DTYPES = [x for x in dtypes.INT_TYPES if x != "uint64"]
+    FLOAT_DTYPES = dtypes.FLOAT_TYPES
 
     if backend.backend() == "torch":
         # TODO: torch doesn't support uint16, uint32 and uint64
@@ -4625,6 +4625,8 @@ class NumpyDtypeTest(testing.TestCase, parameterized.TestCase):
         INT_DTYPES = [
             x for x in INT_DTYPES if x not in ["uint16", "uint32", "uint64"]
         ]
+        # Remove float8 dtypes for the following tests
+        ALL_DTYPES = [x for x in ALL_DTYPES if x not in dtypes.FLOAT8_TYPES]
 
     def setUp(self):
         from jax.experimental import enable_x64
@@ -6247,7 +6249,7 @@ def test_less_equal(self, dtypes):
                 [np.array([0, 1], "float32"), np.array([10, 20], "float32")],
             ],
             num=[0, 1, 5],
-            dtype=FLOAT_DTYPES + [None],
+            dtype=FLOAT_DTYPES + (None,),
         )
     )
     def test_linspace(self, start_and_stop, num, dtype):
@@ -6371,7 +6373,7 @@ def test_logaddexp(self, dtypes):
                 [np.array([0, 1], "float32"), np.array([10, 20], "float32")],
             ],
             num=[0, 1, 5],
-            dtype=FLOAT_DTYPES + [None],
+            dtype=FLOAT_DTYPES + (None,),
         )
     )
     def test_logspace(self, start_and_stop, num, dtype):
diff --git a/keras/quantizers/quantizers.py b/keras/quantizers/quantizers.py
index cd15d2234bfe..22a01c79b7b8 100644
--- a/keras/quantizers/quantizers.py
+++ b/keras/quantizers/quantizers.py
@@ -110,23 +110,7 @@ def quantize_and_dequantize(inputs, scale, quantized_dtype, compute_dtype):
     quantized_dtype_max = float(ml_dtypes.finfo(quantized_dtype).max)
     x = ops.divide(inputs, ops.cast(scale, compute_dtype))
     x = ops.clip(x, quantized_dtype_min, quantized_dtype_max)
-
-    # TODO: Introduce float8 dtype
-    if backend.backend() == "tensorflow":
-        import tensorflow as tf
-
-        x = tf.cast(x, quantized_dtype)
-    elif backend.backend() == "jax":
-        x = x.astype(quantized_dtype)
-    elif backend.backend() == "torch":
-        import torch
-
-        if quantized_dtype == "float8_e5m2":
-            x = x.to(torch.float8_e5m2)
-        if quantized_dtype == "float8_e4m3fn":
-            x = x.to(torch.float8_e4m3fn)
-
-    # x = ops.cast(x, quantized_dtype)
+    x = ops.cast(x, quantized_dtype)
     # Dequantize
     x = ops.multiply(ops.cast(x, compute_dtype), ops.cast(scale, compute_dtype))
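
With `ops.cast` now accepting float8 dtypes, the quantize/dequantize helper becomes a single backend-agnostic path. A minimal sketch, assuming the helper is imported from the module path edited above and using an arbitrary scale value:

    from keras import ops
    from keras.quantizers.quantizers import quantize_and_dequantize

    # Quantize to float8_e4m3fn and immediately dequantize back to float32.
    x = ops.ones((4,), dtype="float32")
    scale = ops.convert_to_tensor(0.5)
    y = quantize_and_dequantize(x, scale, "float8_e4m3fn", "float32")
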
""" - temp_dir = None archive = None filepath = str(filepath) if filepath.endswith(".weights.h5"): @@ -262,8 +261,6 @@ def load_weights_only(model, filepath, skip_mismatch=False): error_msgs=error_msgs, ) weights_store.close() - if temp_dir and file_utils.exists(temp_dir): - file_utils.rmtree(temp_dir) if archive: archive.close() diff --git a/keras/saving/serialization_lib_test.py b/keras/saving/serialization_lib_test.py index 5021df8e6b53..701a903d661f 100644 --- a/keras/saving/serialization_lib_test.py +++ b/keras/saving/serialization_lib_test.py @@ -94,6 +94,19 @@ def test_builtin_layers(self): self.assertEqual(layer.trainable, restored.trainable) self.assertEqual(layer.compute_dtype, restored.compute_dtype) + def test_numpy_get_item_layer(self): + def tuples_to_lists_str(x): + return str(x).replace("(", "[").replace(")", "]") + + input = keras.layers.Input(shape=(2,)) + layer = input[:, 1] + model = keras.Model(input, layer) + serialized, _, reserialized = self.roundtrip(model) + # Anticipate JSON roundtrip mapping tuples to lists: + serialized_str = tuples_to_lists_str(serialized) + reserialized_str = tuples_to_lists_str(reserialized) + self.assertEqual(serialized_str, reserialized_str) + def test_tensors_and_shapes(self): x = ops.random.normal((2, 2), dtype="float64") obj = {"x": x} @@ -314,6 +327,18 @@ def from_config(cls, config): self.assertIs(layers[0].activation, layers[1].activation) self.assertIs(new_layers[0].activation, new_layers[1].activation) + def test_layer_sharing(self): + seq = keras.Sequential( + [ + keras.Input(shape=(3,)), + keras.layers.Dense(5), + keras.layers.Softmax(), + ], + ) + func = keras.Model(inputs=seq.inputs, outputs=seq.outputs) + serialized, deserialized, reserialized = self.roundtrip(func) + self.assertLen(deserialized.layers, 3) + @keras.saving.register_keras_serializable() class MyDense(keras.layers.Layer): diff --git a/keras/utils/summary_utils.py b/keras/utils/summary_utils.py index 8b0d591d1546..2fa9c7a0919f 100644 --- a/keras/utils/summary_utils.py +++ b/keras/utils/summary_utils.py @@ -360,7 +360,10 @@ def print_layer(layer, nested_level=0): # Output captured summary for non-interactive logging. if print_fn: - print_fn(console.end_capture(), line_break=False) + if print_fn is io_utils.print_msg: + print_fn(console.end_capture(), line_break=False) + else: + print_fn(console.end_capture()) def get_layer_index_bound_by_layer_name(layers, layer_range=None):