From e15efe2d6d103a32fc6c7eeff01fa41878e5235f Mon Sep 17 00:00:00 2001
From: Fabien Hertschuh <1091026+hertschuh@users.noreply.github.com>
Date: Tue, 5 Sep 2023 15:30:27 -0700
Subject: [PATCH] Add support for TensorFlow SparseTensors: core classes.

This adds:
- Support for specifying `sparse` in `KerasTensor` and `Input`.
- A boolean flag `backend.SUPPORTS_SPARSE_TENSORS`.
- Support for `tf.SparseTensor` in TensorFlow core ops.
---
 keras_core/backend/common/keras_tensor.py    | 12 +++++-
 .../backend/common/keras_tensor_test.py      |  3 +-
 keras_core/backend/jax/__init__.py           |  1 +
 keras_core/backend/jax/core.py               |  6 ++-
 keras_core/backend/numpy/__init__.py         |  1 +
 keras_core/backend/numpy/core.py             |  6 ++-
 keras_core/backend/tensorflow/__init__.py    |  1 +
 keras_core/backend/tensorflow/core.py        | 26 ++++++++++---
 keras_core/backend/torch/__init__.py         |  1 +
 keras_core/backend/torch/core.py             |  6 ++-
 keras_core/layers/core/input_layer.py        | 20 +++++++++-
 keras_core/layers/core/input_layer_test.py   | 32 ++++++++++++---
 keras_core/ops/core_test.py                  | 39 +++++++++++++++++++
 13 files changed, 136 insertions(+), 18 deletions(-)

diff --git a/keras_core/backend/common/keras_tensor.py b/keras_core/backend/common/keras_tensor.py
index e400bf208..9c7131253 100644
--- a/keras_core/backend/common/keras_tensor.py
+++ b/keras_core/backend/common/keras_tensor.py
@@ -28,11 +28,19 @@ class KerasTensor:
     dtype is called "static shape inference".
     """
 
-    def __init__(self, shape, dtype="float32", record_history=True, name=None):
+    def __init__(
+        self,
+        shape,
+        dtype="float32",
+        sparse=False,
+        record_history=True,
+        name=None,
+    ):
         from keras_core import backend
 
         self.shape = backend.standardize_shape(shape)
         self.dtype = backend.standardize_dtype(dtype)
+        self.sparse = sparse
         self.name = name or auto_name(self.__class__.__name__)
         self.record_history = record_history
 
@@ -106,7 +114,7 @@ def __tf_tensor__(self, dtype=None, name=None):
     def __repr__(self):
         return (
             f"<KerasTensor shape={self.shape}, dtype={self.dtype}, "
-            f"name={self.name}>"
+            f"sparse={self.sparse}, name={self.name}>"
         )
 
     def __iter__(self):
diff --git a/keras_core/backend/common/keras_tensor_test.py b/keras_core/backend/common/keras_tensor_test.py
index 9a35898e7..c2bd7bec3 100644
--- a/keras_core/backend/common/keras_tensor_test.py
+++ b/keras_core/backend/common/keras_tensor_test.py
@@ -9,9 +9,10 @@
 
 class KerasTensorTest(testing.TestCase):
     def test_attributes(self):
-        x = keras_tensor.KerasTensor(shape=(3,), dtype="float32")
+        x = keras_tensor.KerasTensor(shape=(3,), dtype="float32", sparse=True)
         self.assertEqual(x.dtype, "float32")
         self.assertEqual(x.shape, (3,))
+        self.assertEqual(x.sparse, True)
 
     def test_numpy_methods(self):
         x = keras_tensor.KerasTensor(shape=(3, 2), dtype="float32")
diff --git a/keras_core/backend/jax/__init__.py b/keras_core/backend/jax/__init__.py
index 15d93f4a5..260081beb 100644
--- a/keras_core/backend/jax/__init__.py
+++ b/keras_core/backend/jax/__init__.py
@@ -5,6 +5,7 @@
 from keras_core.backend.jax import nn
 from keras_core.backend.jax import numpy
 from keras_core.backend.jax import random
+from keras_core.backend.jax.core import SUPPORTS_SPARSE_TENSORS
 from keras_core.backend.jax.core import Variable
 from keras_core.backend.jax.core import cast
 from keras_core.backend.jax.core import compute_output_spec
diff --git a/keras_core/backend/jax/core.py b/keras_core/backend/jax/core.py
index 429cde3ca..cb51029ee 100644
--- a/keras_core/backend/jax/core.py
+++ b/keras_core/backend/jax/core.py
@@ -14,6 +14,8 @@
 from keras_core.backend.jax import distribution_lib
 from keras_core.utils.nest import pack_sequence_as
 
+SUPPORTS_SPARSE_TENSORS = False
+
 
 class Variable(KerasVariable):
     def _initialize(self, value):
@@ -44,7 +46,9 @@ def __jax_array__(self):
         return self.value
 
 
-def convert_to_tensor(x, dtype=None):
+def convert_to_tensor(x, dtype=None, sparse=False):
+    if sparse:
+        raise ValueError("`sparse=True` is not supported with jax backend")
     if dtype is not None:
         dtype = standardize_dtype(dtype)
     if isinstance(x, Variable):
diff --git a/keras_core/backend/numpy/__init__.py b/keras_core/backend/numpy/__init__.py
index a57833bb2..82d22bb26 100644
--- a/keras_core/backend/numpy/__init__.py
+++ b/keras_core/backend/numpy/__init__.py
@@ -4,6 +4,7 @@
 from keras_core.backend.numpy import nn
 from keras_core.backend.numpy import numpy
 from keras_core.backend.numpy import random
+from keras_core.backend.numpy.core import SUPPORTS_SPARSE_TENSORS
 from keras_core.backend.numpy.core import Variable
 from keras_core.backend.numpy.core import cast
 from keras_core.backend.numpy.core import compute_output_spec
diff --git a/keras_core/backend/numpy/core.py b/keras_core/backend/numpy/core.py
index d4948ce37..b125a35e3 100644
--- a/keras_core/backend/numpy/core.py
+++ b/keras_core/backend/numpy/core.py
@@ -7,6 +7,8 @@
 from keras_core.backend.common.stateless_scope import StatelessScope
 from keras_core.utils.nest import pack_sequence_as
 
+SUPPORTS_SPARSE_TENSORS = False
+
 
 class Variable(KerasVariable):
     def _initialize(self, value):
@@ -23,7 +25,9 @@ def __array__(self):
         return self.value
 
 
-def convert_to_tensor(x, dtype=None):
+def convert_to_tensor(x, dtype=None, sparse=False):
+    if sparse:
+        raise ValueError("`sparse=True` is not supported with numpy backend")
     if dtype is not None:
         dtype = standardize_dtype(dtype)
     if isinstance(x, Variable):
diff --git a/keras_core/backend/tensorflow/__init__.py b/keras_core/backend/tensorflow/__init__.py
index 31012a14e..f70cfdb43 100644
--- a/keras_core/backend/tensorflow/__init__.py
+++ b/keras_core/backend/tensorflow/__init__.py
@@ -5,6 +5,7 @@
 from keras_core.backend.tensorflow import numpy
 from keras_core.backend.tensorflow import random
 from keras_core.backend.tensorflow import tensorboard
+from keras_core.backend.tensorflow.core import SUPPORTS_SPARSE_TENSORS
 from keras_core.backend.tensorflow.core import Variable
 from keras_core.backend.tensorflow.core import cast
 from keras_core.backend.tensorflow.core import compute_output_spec
diff --git a/keras_core/backend/tensorflow/core.py b/keras_core/backend/tensorflow/core.py
index 27a4462ed..62df7f78a 100644
--- a/keras_core/backend/tensorflow/core.py
+++ b/keras_core/backend/tensorflow/core.py
@@ -12,6 +12,8 @@
 from keras_core.backend.common.stateless_scope import StatelessScope
 from keras_core.utils.naming import auto_name
 
+SUPPORTS_SPARSE_TENSORS = True
+
 
 class Variable(
     KerasVariable,
@@ -70,15 +72,28 @@ def _write_object_proto(self, proto, options):
         return self.value._write_object_proto(proto, options)
 
 
-def convert_to_tensor(x, dtype=None):
+def convert_to_tensor(x, dtype=None, sparse=True):
+    """Convert to a TensorFlow tensor.
+
+    `sparse=True` means that `tf.SparseTensor`s are returned as-is, which is the
+    default with the TensorFlow backend. An explicit `sparse=False` densifies
+    `tf.SparseTensor`s.
+    """
+    if isinstance(x, tf.SparseTensor) and not sparse:
+        x = tf.sparse.to_dense(x)
     if dtype is not None:
         dtype = standardize_dtype(dtype)
-    if tf.is_tensor(x):
-        return tf.cast(x, dtype=dtype)
-    return tf.convert_to_tensor(x, dtype=dtype)
+    if not tf.is_tensor(x):
+        return tf.convert_to_tensor(x, dtype=dtype)
+    elif dtype is not None:
+        return tf.cast(x, dtype=dtype)
+    else:
+        return x
 
 
 def convert_to_numpy(x):
+    if isinstance(x, tf.SparseTensor):
+        x = tf.sparse.to_dense(x)
     return np.array(x)
 
 
@@ -95,7 +110,8 @@ def shape(x):
     tensor values when the shape is unknown (this is tf specific, as dynamic
     shapes do not apply in other backends).
     """
-    x = tf.convert_to_tensor(x)
+    if not tf.is_tensor(x):
+        x = tf.convert_to_tensor(x)
     dynamic = tf.shape(x)
     if x.shape == tf.TensorShape(None):
         raise ValueError(
diff --git a/keras_core/backend/torch/__init__.py b/keras_core/backend/torch/__init__.py
index a26f5f647..2be9709a8 100644
--- a/keras_core/backend/torch/__init__.py
+++ b/keras_core/backend/torch/__init__.py
@@ -20,6 +20,7 @@
 from keras_core.backend.torch import nn
 from keras_core.backend.torch import numpy
 from keras_core.backend.torch import random
+from keras_core.backend.torch.core import SUPPORTS_SPARSE_TENSORS
 from keras_core.backend.torch.core import Variable
 from keras_core.backend.torch.core import cast
 from keras_core.backend.torch.core import compute_output_spec
diff --git a/keras_core/backend/torch/core.py b/keras_core/backend/torch/core.py
index 9e5532179..f9675a301 100644
--- a/keras_core/backend/torch/core.py
+++ b/keras_core/backend/torch/core.py
@@ -12,6 +12,8 @@
 from keras_core.backend.common.stateless_scope import StatelessScope
 from keras_core.utils.nest import pack_sequence_as
 
+SUPPORTS_SPARSE_TENSORS = False
+
 # Some operators such as 'aten::_foreach_mul_.Scalar'
 # are not currently implemented for the MPS device.
 # check https://github.com/pytorch/pytorch/issues/77764.
@@ -118,7 +120,9 @@ def __eq__(self, other):
         return False
 
 
-def convert_to_tensor(x, dtype=None):
+def convert_to_tensor(x, dtype=None, sparse=False):
+    if sparse:
+        raise ValueError("`sparse=True` is not supported with torch backend")
     if is_tensor(x):
         device = get_device()
         if x.device != device:
diff --git a/keras_core/layers/core/input_layer.py b/keras_core/layers/core/input_layer.py
index 3f2748c3f..af7200b15 100644
--- a/keras_core/layers/core/input_layer.py
+++ b/keras_core/layers/core/input_layer.py
@@ -13,12 +13,13 @@ def __init__(
         shape=None,
         batch_size=None,
         dtype=None,
+        sparse=None,
         batch_shape=None,
         input_tensor=None,
         name=None,
         **kwargs,
     ):
-        # TODO: support for sparse, ragged.
+        # TODO: support for ragged.
         super().__init__(name=name)
         if "input_shape" in kwargs:
             warnings.warn(
@@ -45,6 +46,13 @@ def __init__(
         self.batch_shape = tuple(batch_shape)
         self._dtype = backend.standardize_dtype(dtype)
 
+        self.sparse = bool(sparse)
+        if self.sparse and not backend.SUPPORTS_SPARSE_TENSORS:
+            raise ValueError(
+                "`sparse=True` is not supported with backend: "
+                f"{backend.backend()}"
+            )
+
         if input_tensor is not None:
             if not isinstance(input_tensor, backend.KerasTensor):
                 raise ValueError(
@@ -54,7 +62,7 @@ def __init__(
                 )
         else:
             input_tensor = backend.KerasTensor(
-                shape=batch_shape, dtype=dtype, name=name
+                shape=batch_shape, dtype=dtype, sparse=sparse, name=name
             )
         self._input_tensor = input_tensor
         Node(operation=self, call_args=(), call_kwargs={}, outputs=input_tensor)
@@ -71,6 +79,7 @@ def get_config(self):
         return {
             "batch_shape": self.batch_shape,
             "dtype": self.dtype,
+            "sparse": self.sparse,
             "name": self.name,
         }
 
@@ -80,6 +89,7 @@ def Input(
     shape=None,
     batch_size=None,
     dtype=None,
+    sparse=None,
    batch_shape=None,
     name=None,
     tensor=None,
@@ -104,6 +114,11 @@ def Input(
         batch_size: Optional static batch size (integer).
         dtype: The data type expected by the input, as a string
             (e.g. `"float32"`, `"int32"`...)
+        sparse: A boolean specifying whether the expected input will be sparse
+            tensors. Note that, if `sparse` is `False`, sparse tensors can still
+            be passed into the input - they will be densified with a default
+            value of 0. This feature is only supported with the TensorFlow
+            backend. Defaults to `False`.
         name: Optional name string for the layer.
             Should be unique in a model (do not reuse the same name twice).
             It will be autogenerated if it isn't provided.
@@ -127,6 +142,7 @@ def Input(
         shape=shape,
         batch_size=batch_size,
         dtype=dtype,
+        sparse=sparse,
         batch_shape=batch_shape,
         name=name,
         input_tensor=tensor,
diff --git a/keras_core/layers/core/input_layer_test.py b/keras_core/layers/core/input_layer_test.py
index 437864782..b6640c90f 100644
--- a/keras_core/layers/core/input_layer_test.py
+++ b/keras_core/layers/core/input_layer_test.py
@@ -1,29 +1,51 @@
 import numpy as np
+from absl.testing import parameterized
 
+from keras_core import backend
 from keras_core import testing
 from keras_core.backend import KerasTensor
 from keras_core.layers import InputLayer
 
 
-class InputLayerTest(testing.TestCase):
+class InputLayerTest(testing.TestCase, parameterized.TestCase):
     # Testing happy path for layer without input tensor
-    def test_input_basic(self):
+    @parameterized.named_parameters(
+        [
+            {"testcase_name": "dense", "sparse": False},
+            {"testcase_name": "sparse", "sparse": True},
+        ]
+    )
+    def test_input_basic(self, sparse):
         input_shape = (2, 3)
         batch_size = 4
         dtype = "float32"
         ndim = len(tuple((batch_size,) + input_shape))
 
-        values = InputLayer(
-            shape=input_shape, batch_size=batch_size, dtype=dtype
-        )
+        init_kwargs = {
+            "shape": input_shape,
+            "batch_size": batch_size,
+            "dtype": dtype,
+            "sparse": sparse,
+        }
+
+        if sparse and not backend.SUPPORTS_SPARSE_TENSORS:
+            with self.assertRaisesRegex(
+                ValueError, "`sparse=True` is not supported"
+            ):
+                InputLayer(**init_kwargs)
+            return
+
+        values = InputLayer(**init_kwargs)
         self.assertEqual(values.dtype, dtype)
         self.assertEqual(values.batch_shape[0], batch_size)
         self.assertEqual(values.batch_shape[1:], input_shape)
+        self.assertEqual(values.sparse, sparse)
         self.assertEqual(values.trainable, True)
         self.assertIsInstance(values.output, KerasTensor)
         self.assertEqual(values.output.ndim, ndim)
         self.assertEqual(values.output.dtype, dtype)
+        self.assertEqual(values.output.sparse, sparse)
 
     # Testing shape is not None and batch_shape is not None condition
     def test_input_error1(self):
diff --git a/keras_core/ops/core_test.py b/keras_core/ops/core_test.py
index 82c576a6b..228f5c886 100644
--- a/keras_core/ops/core_test.py
+++ b/keras_core/ops/core_test.py
@@ -1,6 +1,7 @@
 import numpy as np
 import pytest
 
+from keras_core import backend
 from keras_core import layers
 from keras_core import losses
 from keras_core import models
@@ -270,6 +271,18 @@ def test_shape(self):
         x = KerasTensor((None, 3, None, 1))
         self.assertAllEqual(core.shape(x), (None, 3, None, 1))
 
+    @pytest.mark.skipif(
+        not backend.SUPPORTS_SPARSE_TENSORS,
+        reason="Backend does not support sparse tensors.",
+    )
+    def test_shape_sparse(self):
+        import tensorflow as tf
+
+        x = tf.SparseTensor(
+            indices=[[0, 0], [1, 2]], values=[1.0, 2.0], dense_shape=(2, 3)
+        )
+        self.assertAllEqual(core.shape(x), (2, 3))
+
     def test_convert_to_tensor(self):
         x = np.ones((2,))
         x = ops.convert_to_tensor(x)
@@ -284,6 +297,32 @@ def test_convert_to_tensor(self):
         with self.assertRaises(ValueError):
             ops.convert_to_numpy(KerasTensor((2,)))
 
+    @pytest.mark.skipif(
+        not backend.SUPPORTS_SPARSE_TENSORS,
+        reason="Backend does not support sparse tensors.",
+    )
+    def test_convert_to_tensor_sparse(self):
+        import tensorflow as tf
+
+        x = tf.SparseTensor(
+            indices=[[0, 0], [1, 2]], values=[1.0, 2.0], dense_shape=(2, 3)
+        )
+
+        x_default = ops.convert_to_tensor(x)
+        self.assertIsInstance(x_default, tf.SparseTensor)
+        self.assertAllClose(x, x_default)
+        # Note that ops.convert_to_tensor does not expose the 'sparse' arg
+        x_sparse = backend.convert_to_tensor(x, sparse=True)
+        self.assertIsInstance(x_sparse, tf.SparseTensor)
+        self.assertAllClose(x, x_sparse)
+        x_dense = backend.convert_to_tensor(x, sparse=False)
+        self.assertNotIsInstance(x_dense, tf.SparseTensor)
+        self.assertAllClose(x, x_dense)
+
+        x_numpy = ops.convert_to_numpy(x)
+        self.assertIsInstance(x_numpy, np.ndarray)
+        self.assertAllClose(x_numpy, x_dense)
+
     def test_cond(self):
         t = ops.cond(True, lambda: 0, lambda: 1)
         self.assertEqual(t, 0)
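
Usage note (not part of the patch): the sketch below exercises the surface this change adds, assuming the TensorFlow backend is selected (for example via `KERAS_BACKEND=tensorflow`) and that `Input` is exported as `keras_core.layers.Input`; with the JAX, NumPy, or torch backends, `backend.SUPPORTS_SPARSE_TENSORS` is `False` and passing `sparse=True` raises a `ValueError`.

    # Minimal sketch of the new `sparse` support (TensorFlow backend assumed).
    import tensorflow as tf

    from keras_core import backend
    from keras_core import layers
    from keras_core import ops

    # `Input(sparse=True)` produces a KerasTensor whose `sparse` flag is set.
    x = layers.Input(shape=(3,), sparse=True)
    assert backend.SUPPORTS_SPARSE_TENSORS and x.sparse

    st = tf.SparseTensor(
        indices=[[0, 0], [1, 2]], values=[1.0, 2.0], dense_shape=(2, 3)
    )
    # Default (sparse=True): the tf.SparseTensor passes through unchanged.
    passthrough = backend.convert_to_tensor(st)
    # Explicit sparse=False densifies, filling missing entries with 0.
    dense = backend.convert_to_tensor(st, sparse=False)
    # convert_to_numpy always densifies before building the ndarray.
    arr = ops.convert_to_numpy(st)

The design keeps `tf.SparseTensor`s sparse by default on the TensorFlow backend and only densifies on an explicit `sparse=False`, while the other backends advertise `SUPPORTS_SPARSE_TENSORS = False` and reject `sparse=True` up front.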