Skip to content

Commit

Permalink
Merge branch 'divyashreepathihalli-tf_discretization'
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Sep 20, 2023
2 parents 1704ecf + 01566ef commit 5a4ee8e
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 46 deletions.
53 changes: 8 additions & 45 deletions keras_core/layers/preprocessing/discretization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,20 @@

from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
from keras_core.layers.preprocessing.tf_data_layer import TFDataLayer
from keras_core.utils import argument_validation
from keras_core.utils import backend_utils
from keras_core.utils import tf_utils
from keras_core.utils import numerical_utils
from keras_core.utils.module_utils import tensorflow as tf


@keras_core_export("keras_core.layers.Discretization")
class Discretization(Layer):
class Discretization(TFDataLayer):
"""A preprocessing layer which buckets continuous features by ranges.
This layer will place each element of its input data into one of several
contiguous ranges and output an integer index indicating which range each
element was placed in.
**Note:** This layer uses TensorFlow internally. It cannot
be used as part of the compiled computation graph of a model with
any backend other than TensorFlow.
It can however be used with any backend when running eagerly.
It can also always be used as part of an input preprocessing pipeline
with any backend (outside the model itself), which is how we recommend
to use this layer.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Expand Down Expand Up @@ -78,14 +69,14 @@ class Discretization(Layer):
Examples:
Bucketize float values based on provided buckets.
Discretize float values based on provided buckets.
>>> input = np.array([[-1.5, 1.0, 3.4, .5], [0.0, 3.0, 1.3, 0.0]])
>>> layer = Discretization(bin_boundaries=[0., 1., 2.])
>>> layer(input)
array([[0, 2, 3, 1],
[1, 3, 2, 1]])
Bucketize float values based on a number of buckets to compute.
Discretize float values based on a number of buckets to compute.
>>> input = np.array([[-1.5, 1.0, 3.4, .5], [0.0, 3.0, 1.3, 0.0]])
>>> layer = Discretization(num_bins=4, epsilon=0.01)
>>> layer.adapt(input)
Expand Down Expand Up @@ -161,8 +152,6 @@ def __init__(
self.summary = None
else:
self.summary = np.array([[], []], dtype="float32")
self._convert_input_args = False
self._allow_non_tensor_positional_args = True

def build(self, input_shape=None):
self.built = True
Expand Down Expand Up @@ -238,36 +227,14 @@ def load_own_variables(self, store):
return

def call(self, inputs):
if not isinstance(
inputs,
(
tf.Tensor,
tf.SparseTensor,
tf.RaggedTensor,
np.ndarray,
backend.KerasTensor,
),
):
inputs = tf.convert_to_tensor(
backend.convert_to_numpy(inputs), dtype=self.input_dtype
)

from keras_core.backend.tensorflow.numpy import digitize

indices = digitize(inputs, self.bin_boundaries)

outputs = tf_utils.encode_categorical_inputs(
indices = self.backend.numpy.digitize(inputs, self.bin_boundaries)
outputs = numerical_utils.encode_categorical_inputs(
indices,
output_mode=self.output_mode,
depth=len(self.bin_boundaries) + 1,
sparse=self.sparse,
dtype=self.compute_dtype,
backend_module=self.backend,
)
if (
backend.backend() != "tensorflow"
and not backend_utils.in_tf_graph()
):
outputs = backend.convert_to_tensor(outputs)
return outputs

def get_config(self):
Expand Down Expand Up @@ -370,7 +337,3 @@ def compress_summary(summary, epsilon):
)
summary = np.stack((new_bins, new_weights))
return summary.astype("float32")


def bucketize(inputs, boundaries):
    """Bucketize `inputs` with TensorFlow's raw `Bucketize` op.

    Thin wrapper that forwards both arguments to `tf.raw_ops.Bucketize`.

    Args:
        inputs: Values to bucketize; passed as the op's `input` argument.
        boundaries: Bucket boundaries; passed as the op's `boundaries`
            argument. NOTE(review): the raw op presumably expects a sorted
            list of floats -- confirm against the TensorFlow documentation.

    Returns:
        Whatever `tf.raw_ops.Bucketize` returns for the given arguments.
    """
    return tf.raw_ops.Bucketize(input=inputs, boundaries=boundaries)
4 changes: 3 additions & 1 deletion keras_core/layers/preprocessing/discretization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ def test_correctness(self):

def test_tf_data_compatibility(self):
# With fixed bins
layer = layers.Discretization(bin_boundaries=[0.0, 0.35, 0.5, 1.0])
layer = layers.Discretization(
bin_boundaries=[0.0, 0.35, 0.5, 1.0], dtype="float32"
)
x = np.array([[-1.0, 0.0, 0.1, 0.2, 0.4, 0.5, 1.0, 1.2, 0.98]])
self.assertAllClose(layer(x), np.array([[0, 1, 1, 1, 2, 3, 4, 4, 3]]))
ds = tf_data.Dataset.from_tensor_slices(x).batch(1).map(layer)
Expand Down
47 changes: 47 additions & 0 deletions keras_core/utils/numerical_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,50 @@ def to_categorical(x, num_classes=None):
output_shape = input_shape + (num_classes,)
categorical = np.reshape(categorical, output_shape)
return categorical


def encode_categorical_inputs(
    inputs,
    output_mode,
    depth,
    dtype="float32",
    count_weights=None,
    backend_module=None,
):
    """Encodes categorical inputs according to `output_mode`.

    Args:
        inputs: Tensor of category indices. Must expose a `.shape`
            attribute (used only for the error message).
        output_mode: `"int"` returns the inputs cast to `dtype`;
            `"one_hot"` and `"multi_hot"` produce a 0/1 encoding of width
            `depth`; any other value falls through to a bincount of the
            inputs (NOTE(review): callers presumably pass `"count"` here
            -- confirm).
        depth: Width of the encoded output (number of categories); used
            as the one-hot depth and as `minlength` for the bincount.
        dtype: Dtype the result is cast to. Defaults to `"float32"`.
        count_weights: Optional weights forwarded to `bincount`. Only
            used in the counting branch.
        backend_module: Backend namespace providing `cast`, `shape`,
            `numpy.*` and `nn.one_hot`. Defaults to the module-level
            `backend`.

    Returns:
        The encoded tensor, cast to `dtype`.

    Raises:
        ValueError: If `output_mode` is not `"int"` and the (possibly
            upranked) inputs have rank greater than 2.
    """
    backend_module = backend_module or backend

    if output_mode == "int":
        return backend_module.cast(inputs, dtype=dtype)

    original_shape = inputs.shape
    # In all cases, we should uprank scalar input to a single sample.
    if len(backend_module.shape(inputs)) == 0:
        inputs = backend_module.numpy.expand_dims(inputs, -1)
    # One hot will uprank only if the final input dimension is not already 1.
    if output_mode == "one_hot" and backend_module.shape(inputs)[-1] != 1:
        inputs = backend_module.numpy.expand_dims(inputs, -1)

    # Compute the rank through the backend's `shape` helper: `.shape.rank`
    # only exists on TensorFlow shapes, while other backends expose plain
    # tuples, which would turn this ValueError into an AttributeError.
    rank = len(backend_module.shape(inputs))
    if rank > 2:
        raise ValueError(
            "When output_mode is not `'int'`, maximum supported output rank "
            f"is 2. Received output_mode {output_mode} and input shape "
            f"{original_shape}, "
            f"which would result in output rank {rank}."
        )

    if output_mode in ("multi_hot", "one_hot"):
        # Binary encodings: one-hot every index, then collapse the
        # second-to-last axis so each category present is flagged once.
        one_hot_input = backend_module.nn.one_hot(inputs, depth)
        outputs = backend_module.numpy.where(
            backend_module.numpy.any(one_hot_input, axis=-2), 1, 0
        )
    else:
        # Counting mode: only here is the (weighted) bincount needed, so
        # it is no longer computed and discarded for the binary modes.
        outputs = backend_module.numpy.bincount(
            inputs,
            weights=count_weights,
            minlength=depth,
        )

    return backend_module.cast(outputs, dtype)

0 comments on commit 5a4ee8e

Please sign in to comment.