Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Tensorflow SparseTensors: core classes. #839

Merged
merged 1 commit into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions keras_core/backend/common/keras_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,19 @@ class KerasTensor:
dtype is called "static shape inference".
"""

def __init__(self, shape, dtype="float32", record_history=True, name=None):
def __init__(
self,
shape,
dtype="float32",
sparse=False,
record_history=True,
name=None,
):
from keras_core import backend

self.shape = backend.standardize_shape(shape)
self.dtype = backend.standardize_dtype(dtype)
self.sparse = sparse
self.name = name or auto_name(self.__class__.__name__)
self.record_history = record_history

Expand Down Expand Up @@ -106,7 +114,7 @@ def __tf_tensor__(self, dtype=None, name=None):
def __repr__(self):
return (
f"<KerasTensor shape={self.shape}, dtype={self.dtype}, "
f"name={self.name}>"
f"sparse={self.sparse}, name={self.name}>"
)

def __iter__(self):
Expand Down
3 changes: 2 additions & 1 deletion keras_core/backend/common/keras_tensor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@

class KerasTensorTest(testing.TestCase):
def test_attributes(self):
x = keras_tensor.KerasTensor(shape=(3,), dtype="float32")
x = keras_tensor.KerasTensor(shape=(3,), dtype="float32", sparse=True)
self.assertEqual(x.dtype, "float32")
self.assertEqual(x.shape, (3,))
self.assertEqual(x.sparse, True)

def test_numpy_methods(self):
x = keras_tensor.KerasTensor(shape=(3, 2), dtype="float32")
Expand Down
1 change: 1 addition & 0 deletions keras_core/backend/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from keras_core.backend.jax import nn
from keras_core.backend.jax import numpy
from keras_core.backend.jax import random
from keras_core.backend.jax.core import SUPPORTS_SPARSE_TENSORS
from keras_core.backend.jax.core import Variable
from keras_core.backend.jax.core import cast
from keras_core.backend.jax.core import compute_output_spec
Expand Down
6 changes: 5 additions & 1 deletion keras_core/backend/jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from keras_core.backend.jax import distribution_lib
from keras_core.utils.nest import pack_sequence_as

SUPPORTS_SPARSE_TENSORS = False


class Variable(KerasVariable):
def _initialize(self, value):
Expand Down Expand Up @@ -44,7 +46,9 @@ def __jax_array__(self):
return self.value


def convert_to_tensor(x, dtype=None):
def convert_to_tensor(x, dtype=None, sparse=False):
if sparse:
raise ValueError("`sparse=True` is not supported with jax backend")
if dtype is not None:
dtype = standardize_dtype(dtype)
if isinstance(x, Variable):
Expand Down
1 change: 1 addition & 0 deletions keras_core/backend/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from keras_core.backend.numpy import nn
from keras_core.backend.numpy import numpy
from keras_core.backend.numpy import random
from keras_core.backend.numpy.core import SUPPORTS_SPARSE_TENSORS
from keras_core.backend.numpy.core import Variable
from keras_core.backend.numpy.core import cast
from keras_core.backend.numpy.core import compute_output_spec
Expand Down
6 changes: 5 additions & 1 deletion keras_core/backend/numpy/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from keras_core.backend.common.stateless_scope import StatelessScope
from keras_core.utils.nest import pack_sequence_as

SUPPORTS_SPARSE_TENSORS = False


class Variable(KerasVariable):
def _initialize(self, value):
Expand All @@ -23,7 +25,9 @@ def __array__(self):
return self.value


def convert_to_tensor(x, dtype=None):
def convert_to_tensor(x, dtype=None, sparse=False):
if sparse:
raise ValueError("`sparse=True` is not supported with numpy backend")
if dtype is not None:
dtype = standardize_dtype(dtype)
if isinstance(x, Variable):
Expand Down
1 change: 1 addition & 0 deletions keras_core/backend/tensorflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from keras_core.backend.tensorflow import numpy
from keras_core.backend.tensorflow import random
from keras_core.backend.tensorflow import tensorboard
from keras_core.backend.tensorflow.core import SUPPORTS_SPARSE_TENSORS
from keras_core.backend.tensorflow.core import Variable
from keras_core.backend.tensorflow.core import cast
from keras_core.backend.tensorflow.core import compute_output_spec
Expand Down
26 changes: 21 additions & 5 deletions keras_core/backend/tensorflow/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from keras_core.backend.common.stateless_scope import StatelessScope
from keras_core.utils.naming import auto_name

SUPPORTS_SPARSE_TENSORS = True


class Variable(
KerasVariable,
Expand Down Expand Up @@ -70,15 +72,28 @@ def _write_object_proto(self, proto, options):
return self.value._write_object_proto(proto, options)


def convert_to_tensor(x, dtype=None):
def convert_to_tensor(x, dtype=None, sparse=True):
"""Convert to a TensorFlow tensor.

`sparse=True` means that `tf.SparseTensor`s are returned as-is, which is the
default with the TensorFlow backend. An explicit `sparse=False` densifies
`tf.SparseTensor`s.
"""
if isinstance(x, tf.SparseTensor) and not sparse:
x = tf.sparse.to_dense(x)
if dtype is not None:
dtype = standardize_dtype(dtype)
if tf.is_tensor(x):
return tf.cast(x, dtype=dtype)
return tf.convert_to_tensor(x, dtype=dtype)
if not tf.is_tensor(x):
return tf.convert_to_tensor(x, dtype=dtype)
elif dtype is not None:
return tf.cast(x, dtype=dtype)
else:
return x


def convert_to_numpy(x):
if isinstance(x, tf.SparseTensor):
x = tf.sparse.to_dense(x)
return np.array(x)


Expand All @@ -95,7 +110,8 @@ def shape(x):
tensor values when the shape is unknown (this is tf specific, as dynamic
shapes do not apply in other backends).
"""
x = tf.convert_to_tensor(x)
if not tf.is_tensor(x):
x = tf.convert_to_tensor(x)
dynamic = tf.shape(x)
if x.shape == tf.TensorShape(None):
raise ValueError(
Expand Down
1 change: 1 addition & 0 deletions keras_core/backend/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from keras_core.backend.torch import nn
from keras_core.backend.torch import numpy
from keras_core.backend.torch import random
from keras_core.backend.torch.core import SUPPORTS_SPARSE_TENSORS
from keras_core.backend.torch.core import Variable
from keras_core.backend.torch.core import cast
from keras_core.backend.torch.core import compute_output_spec
Expand Down
6 changes: 5 additions & 1 deletion keras_core/backend/torch/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from keras_core.backend.common.stateless_scope import StatelessScope
from keras_core.utils.nest import pack_sequence_as

SUPPORTS_SPARSE_TENSORS = False

# Some operators such as 'aten::_foreach_mul_.Scalar'
# are not currently implemented for the MPS device.
# check https://github.com/pytorch/pytorch/issues/77764.
Expand Down Expand Up @@ -118,7 +120,9 @@ def __eq__(self, other):
return False


def convert_to_tensor(x, dtype=None):
def convert_to_tensor(x, dtype=None, sparse=False):
if sparse:
raise ValueError("`sparse=True` is not supported with torch backend")
if is_tensor(x):
device = get_device()
if x.device != device:
Expand Down
20 changes: 18 additions & 2 deletions keras_core/layers/core/input_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@ def __init__(
shape=None,
batch_size=None,
dtype=None,
sparse=None,
batch_shape=None,
input_tensor=None,
name=None,
**kwargs,
):
# TODO: support for sparse, ragged.
# TODO: support for ragged.
super().__init__(name=name)
if "input_shape" in kwargs:
warnings.warn(
Expand All @@ -45,6 +46,13 @@ def __init__(
self.batch_shape = tuple(batch_shape)
self._dtype = backend.standardize_dtype(dtype)

self.sparse = bool(sparse)
if self.sparse and not backend.SUPPORTS_SPARSE_TENSORS:
raise ValueError(
"`sparse=True` is not supported with backend: "
f"{backend.backend()}"
)

if input_tensor is not None:
if not isinstance(input_tensor, backend.KerasTensor):
raise ValueError(
Expand All @@ -54,7 +62,7 @@ def __init__(
)
else:
input_tensor = backend.KerasTensor(
shape=batch_shape, dtype=dtype, name=name
shape=batch_shape, dtype=dtype, sparse=sparse, name=name
)
self._input_tensor = input_tensor
Node(operation=self, call_args=(), call_kwargs={}, outputs=input_tensor)
Expand All @@ -71,6 +79,7 @@ def get_config(self):
return {
"batch_shape": self.batch_shape,
"dtype": self.dtype,
"sparse": self.sparse,
"name": self.name,
}

Expand All @@ -80,6 +89,7 @@ def Input(
shape=None,
batch_size=None,
dtype=None,
sparse=None,
batch_shape=None,
name=None,
tensor=None,
Expand All @@ -104,6 +114,11 @@ def Input(
batch_size: Optional static batch size (integer).
dtype: The data type expected by the input, as a string
(e.g. `"float32"`, `"int32"`...)
sparse: A boolean specifying whether the expected input will be sparse
tensors. Note that, if `sparse` is `False`, sparse tensors can still
be passed into the input - they will be densified with a default
value of 0. This feature is only supported with the TensorFlow
backend. Defaults to `False`.
name: Optional name string for the layer.
Should be unique in a model (do not reuse the same name twice).
It will be autogenerated if it isn't provided.
Expand All @@ -127,6 +142,7 @@ def Input(
shape=shape,
batch_size=batch_size,
dtype=dtype,
sparse=sparse,
batch_shape=batch_shape,
name=name,
input_tensor=tensor,
Expand Down
32 changes: 27 additions & 5 deletions keras_core/layers/core/input_layer_test.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,51 @@
import numpy as np
from absl.testing import parameterized

from keras_core import backend
from keras_core import testing
from keras_core.backend import KerasTensor
from keras_core.layers import InputLayer


class InputLayerTest(testing.TestCase):
class InputLayerTest(testing.TestCase, parameterized.TestCase):
# Testing happy path for layer without input tensor
def test_input_basic(self):
@parameterized.named_parameters(
[
{"testcase_name": "dense", "sparse": False},
{"testcase_name": "sparse", "sparse": True},
]
)
def test_input_basic(self, sparse):
input_shape = (2, 3)
batch_size = 4
dtype = "float32"
ndim = len(tuple((batch_size,) + input_shape))

values = InputLayer(
shape=input_shape, batch_size=batch_size, dtype=dtype
)
init_kwargs = {
"shape": input_shape,
"batch_size": batch_size,
"dtype": dtype,
"sparse": sparse,
}

if sparse and not backend.SUPPORTS_SPARSE_TENSORS:
with self.assertRaisesRegex(
ValueError, "`sparse=True` is not supported"
):
InputLayer(**init_kwargs)
return

values = InputLayer(**init_kwargs)

self.assertEqual(values.dtype, dtype)
self.assertEqual(values.batch_shape[0], batch_size)
self.assertEqual(values.batch_shape[1:], input_shape)
self.assertEqual(values.sparse, sparse)
self.assertEqual(values.trainable, True)
self.assertIsInstance(values.output, KerasTensor)
self.assertEqual(values.output.ndim, ndim)
self.assertEqual(values.output.dtype, dtype)
self.assertEqual(values.output.sparse, sparse)

# Testing shape is not None and batch_shape is not None condition
def test_input_error1(self):
Expand Down
39 changes: 39 additions & 0 deletions keras_core/ops/core_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pytest

from keras_core import backend
from keras_core import layers
from keras_core import losses
from keras_core import models
Expand Down Expand Up @@ -270,6 +271,18 @@ def test_shape(self):
x = KerasTensor((None, 3, None, 1))
self.assertAllEqual(core.shape(x), (None, 3, None, 1))

@pytest.mark.skipif(
not backend.SUPPORTS_SPARSE_TENSORS,
reason="Backend does not support sparse tensors.",
)
def test_shape_sparse(self):
import tensorflow as tf

x = tf.SparseTensor(
indices=[[0, 0], [1, 2]], values=[1.0, 2.0], dense_shape=(2, 3)
)
self.assertAllEqual(core.shape(x), (2, 3))

def test_convert_to_tensor(self):
x = np.ones((2,))
x = ops.convert_to_tensor(x)
Expand All @@ -284,6 +297,32 @@ def test_convert_to_tensor(self):
with self.assertRaises(ValueError):
ops.convert_to_numpy(KerasTensor((2,)))

@pytest.mark.skipif(
not backend.SUPPORTS_SPARSE_TENSORS,
reason="Backend does not support sparse tensors.",
)
def test_convert_to_tensor_sparse(self):
import tensorflow as tf

x = tf.SparseTensor(
indices=[[0, 0], [1, 2]], values=[1.0, 2.0], dense_shape=(2, 3)
)

x_default = ops.convert_to_tensor(x)
self.assertIsInstance(x_default, tf.SparseTensor)
self.assertAllClose(x, x_default)
# Note that ops.convert_to_tensor does not expose the 'sparse' arg
x_sparse = backend.convert_to_tensor(x, sparse=True)
self.assertIsInstance(x_sparse, tf.SparseTensor)
self.assertAllClose(x, x_sparse)
x_dense = backend.convert_to_tensor(x, sparse=False)
self.assertNotIsInstance(x_dense, tf.SparseTensor)
self.assertAllClose(x, x_dense)

x_numpy = ops.convert_to_numpy(x)
self.assertIsInstance(x_numpy, np.ndarray)
self.assertAllClose(x_numpy, x_dense)

def test_cond(self):
t = ops.cond(True, lambda: 0, lambda: 1)
self.assertEqual(t, 0)
Expand Down