Remove tf.keras dependency from discretization layer

fchollet committed Aug 20, 2023
1 parent 3ea1060 commit 7adbaf7

Showing 4 changed files with 312 additions and 61 deletions.
2 changes: 1 addition & 1 deletion keras_core/backend/jax/distribution.py
@@ -8,9 +8,9 @@
 
 import contextlib
 
-from absl import logging
 import jax
 import numpy as np
+from absl import logging
 
 from keras_core.backend.common import global_state
 
29 changes: 28 additions & 1 deletion keras_core/backend/tensorflow/numpy.py
@@ -11,6 +11,23 @@ def add(x1, x2):
 def bincount(x, weights=None, minlength=None):
     if minlength is not None:
         x = tf.cast(x, tf.int32)
+    if isinstance(x, tf.SparseTensor):
+        result = tf.sparse.bincount(
+            x,
+            weights=weights,
+            minlength=minlength,
+            axis=-1,
+        )
+        if x.shape.rank == 1:
+            output_shape = (minlength,)
+        else:
+            batch_size = tf.shape(result)[0]
+            output_shape = (batch_size, minlength)
+        return tf.SparseTensor(
+            indices=result.indices,
+            values=result.values,
+            dense_shape=output_shape,
+        )
     return tf.math.bincount(x, weights=weights, minlength=minlength, axis=-1)
 
 
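As a quick illustration of the sparse branch above (a standalone sketch, not part of the commit): `tf.sparse.bincount` with `axis=-1` produces per-row counts, which the new code then pads to a fixed `(batch_size, minlength)` shape.

import tensorflow as tf

# Illustrative sketch, not part of this commit.
# Two rows of integer ids stored sparsely: row 0 holds [0, 2], row 1 holds [2].
x = tf.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0]],
    values=tf.constant([0, 2, 2], dtype=tf.int32),
    dense_shape=[2, 3],
)
counts = tf.sparse.bincount(x, minlength=5, axis=-1)
print(tf.sparse.to_dense(counts))
# [[1 0 1 0 0]
#  [0 0 1 0 0]]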
@@ -227,8 +244,18 @@ def diagonal(x, offset=0, axis1=0, axis2=1):
 
 
 def digitize(x, bins):
-    x = convert_to_tensor(x)
     bins = list(bins)
+    if isinstance(x, tf.RaggedTensor):
+        return tf.ragged.map_flat_values(
+            lambda y: tf.raw_ops.Bucketize(input=y, boundaries=bins), x
+        )
+    elif isinstance(x, tf.SparseTensor):
+        return tf.SparseTensor(
+            indices=tf.identity(x.indices),
+            values=tf.raw_ops.Bucketize(input=x.values, boundaries=bins),
+            dense_shape=tf.identity(x.dense_shape),
+        )
+    x = convert_to_tensor(x)
    return tf.raw_ops.Bucketize(input=x, boundaries=bins)
 
 
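For context, a standalone sketch (not part of the commit) of how `tf.raw_ops.Bucketize` assigns bucket indices, and how `tf.ragged.map_flat_values` extends it to ragged input as in the new branch:

import tensorflow as tf

# Illustrative sketch, not part of this commit.
bins = [0.0, 1.0, 2.0]  # three boundaries -> four buckets (indices 0..3)

# Dense input: each element maps to the index of its bucket.
x = tf.constant([[-0.5, 0.5], [1.5, 2.5]])
print(tf.raw_ops.Bucketize(input=x, boundaries=bins))
# [[0 1]
#  [2 3]]

# Ragged input: apply Bucketize to the flat values, keep the row splits.
r = tf.ragged.constant([[-0.5, 0.5, 1.5], [2.5]])
print(tf.ragged.map_flat_values(
    lambda y: tf.raw_ops.Bucketize(input=y, boundaries=bins), r
))
# <tf.RaggedTensor [[0, 1, 2], [3]]>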
258 changes: 207 additions & 51 deletions keras_core/layers/preprocessing/discretization.py
@@ -3,7 +3,9 @@
 from keras_core import backend
 from keras_core.api_export import keras_core_export
 from keras_core.layers.layer import Layer
+from keras_core.utils import argument_validation
 from keras_core.utils import backend_utils
+from keras_core.utils import tf_utils
 from keras_core.utils.module_utils import tensorflow as tf
 
 
@@ -15,7 +17,7 @@ class Discretization(Layer):
     contiguous ranges and output an integer index indicating which range each
     element was placed in.
 
-    **Note:** This layer wraps `tf.keras.layers.Discretization`. It cannot
+    **Note:** This layer uses TensorFlow internally. It cannot
     be used as part of the compiled computation graph of a model with
     any backend other than TensorFlow.
     It can however be used with any backend when running eagerly.
@@ -99,50 +101,76 @@ def __init__(
         epsilon=0.01,
         output_mode="int",
         sparse=False,
-        name=None,
         dtype=None,
-        **kwargs,
+        name=None,
     ):
-        if not tf.available:
-            raise ImportError(
-                "Layer Discretization requires TensorFlow. "
-                "Install it via `pip install tensorflow`."
-            )
         if dtype is None:
             dtype = "int64" if output_mode == "int" else backend.floatx()
-        super().__init__(name=name)
+
+        super().__init__(name=name, dtype=dtype)
         if sparse and backend.backend() != "tensorflow":
             raise ValueError(
                 "`sparse` can only be set to True with the "
                 "TensorFlow backend."
             )
-        self.layer = tf.keras.layers.Discretization(
-            bin_boundaries=bin_boundaries,
-            num_bins=num_bins,
-            epsilon=epsilon,
-            output_mode=output_mode,
-            sparse=sparse,
-            name=name,
-            dtype=dtype,
-            **kwargs,
-        )
+        if sparse and output_mode == "int":
+            raise ValueError(
+                "`sparse` may only be true if `output_mode` is "
+                "`'one_hot'`, `'multi_hot'`, or `'count'`. "
+                f"Received: sparse={sparse} and "
+                f"output_mode={output_mode}"
+            )
+
+        argument_validation.validate_string_arg(
+            output_mode,
+            allowable_strings=(
+                "int",
+                "one_hot",
+                "multi_hot",
+                "count",
+            ),
+            caller_name=self.__class__.__name__,
+            arg_name="output_mode",
+        )
-        self.bin_boundaries = bin_boundaries
-        if self.bin_boundaries:
-            self.built = True
-        self._convert_input_args = False
-        self._allow_non_tensor_positional_args = True
+
+        if num_bins is not None and num_bins < 0:
+            raise ValueError(
+                "`num_bins` must be greater than or equal to 0. "
+                f"Received: `num_bins={num_bins}`"
+            )
+        if num_bins is not None and bin_boundaries is not None:
+            if len(bin_boundaries) != num_bins - 1:
+                raise ValueError(
+                    "Both `num_bins` and `bin_boundaries` should not be "
+                    f"set. Received: `num_bins={num_bins}` and "
+                    f"`bin_boundaries={bin_boundaries}`"
+                )
+
+        self.input_bin_boundaries = bin_boundaries
+        self.bin_boundaries = (
+            bin_boundaries if bin_boundaries is not None else []
+        )
+        self.num_bins = num_bins
+        self.epsilon = epsilon
+        self.output_mode = output_mode
+        self.sparse = sparse
+        self.supports_jit = False
 
-    def build(self, input_shape):
-        self.layer.build(input_shape)
+        if self.bin_boundaries:
+            self.built = True
+            self.summary = None
+        else:
+            self.summary = np.array([[], []], dtype="float32")
+        self._convert_input_args = False
+        self._allow_non_tensor_positional_args = True
+
+    def build(self, input_shape=None):
         self.built = True
 
-    # We override this method solely to generate a docstring.
-    def adapt(self, data, batch_size=None, steps=None):
+    @property
+    def input_dtype(self):
+        return backend.floatx()
+
+    def adapt(self, data, steps=None):
         """Computes bin boundaries from quantiles in an input dataset.
 
         Calling `adapt()` on a `Discretization` layer is an alternative to
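A short usage sketch of the two construction modes the rewritten `__init__` supports (illustrative only, assuming a working keras_core install; with `output_mode="int"` the output dtype is `int64`):

import numpy as np
from keras_core import layers

# Illustrative sketch, not part of this commit.
# Fixed boundaries: the layer is built immediately, no adapt() needed.
layer = layers.Discretization(bin_boundaries=[0.0, 1.0, 2.0])
print(layer(np.array([[-0.5, 0.5], [1.5, 2.5]])))  # bucket indices [[0 1], [2 3]]

# Learned boundaries: pass num_bins and call adapt() on sample data.
layer = layers.Discretization(num_bins=4, epsilon=0.01)
layer.adapt(np.random.uniform(size=(1000,)).astype("float32"))
print(layer.bin_boundaries)  # three boundaries near the quartiles, e.g. ~[0.25, 0.5, 0.75]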
@@ -159,47 +187,81 @@ def adapt(self, data, batch_size=None, steps=None):
             data: The data to train on. It can be passed either as a
                 batched `tf.data.Dataset`,
                 or as a NumPy array.
-            batch_size: Integer or `None`.
-                Number of samples per state update.
-                If unspecified, `batch_size` will default to 32.
-                Do not specify the `batch_size` if your data is in the
-                form of a `tf.data.Dataset`
-                (it is expected to be already batched).
             steps: Integer or `None`.
-                Total number of steps (batches of samples)
-                When training with input tensors such as
-                the default `None` is equal to
-                the number of samples in your dataset divided by
-                the batch size, or 1 if that cannot be determined.
+                Total number of steps (batches of samples) to process.
+                If `data` is a `tf.data.Dataset`, and `steps` is `None`,
+                `adapt()` will run until the input dataset is exhausted.
+                When passing an infinitely
+                repeating dataset, you must specify the `steps` argument. This
+                argument is not supported with array inputs or list inputs.
         """
-        self.layer.adapt(data, batch_size=batch_size, steps=steps)
+        if self.input_bin_boundaries is not None:
+            raise ValueError(
+                "Cannot adapt a Discretization layer that has been initialized "
+                "with `bin_boundaries`, use `num_bins` instead."
+            )
+        self.reset_state()
+        if isinstance(data, tf.data.Dataset):
+            if steps is not None:
+                data = data.take(steps)
+            for batch in data:
+                self.update_state(batch)
+        else:
+            self.update_state(data)
+        self.finalize_state()
 
     def update_state(self, data):
-        self.layer.update_state(data)
+        data = np.array(data).astype("float32")
+        summary = summarize(data, self.epsilon)
+        self.summary = merge_summaries(summary, self.summary, self.epsilon)
 
     def finalize_state(self):
-        self.layer.finalize_state()
+        if self.input_bin_boundaries is not None:
+            return
+        self.bin_boundaries = get_bin_boundaries(
+            self.summary, self.num_bins
+        ).tolist()
 
     def reset_state(self):
-        self.layer.reset_state()
+        if self.input_bin_boundaries is not None:
+            return
+        self.summary = np.array([[], []], dtype="float32")
 
     def compute_output_spec(self, inputs):
-        return backend.KerasTensor(shape=inputs.shape, dtype="int32")
+        return backend.KerasTensor(shape=inputs.shape, dtype=self.compute_dtype)
 
-    def __call__(self, inputs):
-        if not isinstance(inputs, (tf.Tensor, np.ndarray, backend.KerasTensor)):
-            inputs = tf.convert_to_tensor(backend.convert_to_numpy(inputs))
-        if not self.built:
-            self.build(inputs.shape)
-        return super().__call__(inputs)
+    def load_own_variables(self, store):
+        if len(store) == 1:
+            # Legacy format case
+            self.summary = store["0"]
+            return
 
     def call(self, inputs):
-        outputs = self.layer.call(inputs)
+        if not isinstance(
+            inputs,
+            (
+                tf.Tensor,
+                tf.SparseTensor,
+                tf.RaggedTensor,
+                np.ndarray,
+                backend.KerasTensor,
+            ),
+        ):
+            inputs = tf.convert_to_tensor(
+                backend.convert_to_numpy(inputs), dtype=self.input_dtype
+            )
+
+        from keras_core.backend.tensorflow.numpy import digitize
+
+        indices = digitize(inputs, self.bin_boundaries)
+
+        outputs = tf_utils.encode_categorical_inputs(
+            indices,
+            output_mode=self.output_mode,
+            depth=len(self.bin_boundaries) + 1,
+            sparse=self.sparse,
+            dtype=self.compute_dtype,
+        )
         if (
             backend.backend() != "tensorflow"
             and not backend_utils.in_tf_graph()
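The `call()` path above first digitizes inputs into `len(bin_boundaries) + 1` buckets and then lets `tf_utils.encode_categorical_inputs` expand the indices for the non-`'int'` output modes. A rough standalone equivalent with plain TensorFlow ops (illustrative only; the actual helper also handles sparse outputs and other ranks):

import tensorflow as tf

# Illustrative sketch, not part of this commit.
bin_boundaries = [0.0, 1.0, 2.0]
depth = len(bin_boundaries) + 1  # number of buckets

indices = tf.constant([[0, 1], [2, 3]])  # as produced by digitize()

one_hot = tf.one_hot(indices, depth)         # output_mode="one_hot"
multi_hot = tf.reduce_max(one_hot, axis=-2)  # output_mode="multi_hot"
count = tf.reduce_sum(one_hot, axis=-2)      # output_mode="count"
print(multi_hot.numpy())  # [[1. 1. 0. 0.], [0. 0. 1. 1.]]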
@@ -217,3 +279,97 @@ def get_config(self):
             "name": self.name,
             "dtype": self.dtype,
         }
+
+
+def summarize(values, epsilon):
+    """Reduce a 1D sequence of values to a summary.
+
+    This algorithm is based on numpy.quantiles but modified to allow for
+    intermediate steps between multiple data sets. It first finds the target
+    number of bins as the reciprocal of epsilon and then takes the individual
+    values spaced at appropriate intervals to arrive at that target.
+    The final step is to return the corresponding counts between those values.
+    If the target num_bins is larger than the size of values, the whole array
+    is returned (with weights of 1).
+
+    Args:
+        values: 1D `np.ndarray` to be summarized.
+        epsilon: A `'float32'` that determines the approximate desired
+            precision.
+
+    Returns:
+        A 2D `np.ndarray` that is a summary of the inputs. First column is the
+        interpolated partition values, the second is the weights (counts).
+    """
+    values = np.reshape(values, [-1])
+    values = np.sort(values)
+    elements = np.size(values)
+    num_buckets = 1.0 / epsilon
+    increment = elements / num_buckets
+    start = increment
+    step = max(increment, 1)
+    boundaries = values[int(start) :: int(step)]
+    weights = np.ones_like(boundaries)
+    weights = weights * step
+    return np.stack([boundaries, weights])
+
+
+def merge_summaries(prev_summary, next_summary, epsilon):
+    """Weighted merge sort of summaries.
+
+    Given two summaries of distinct data, this function merges (and compresses)
+    them to stay within `epsilon` error tolerance.
+
+    Args:
+        prev_summary: 2D `np.ndarray` summary to be merged with `next_summary`.
+        next_summary: 2D `np.ndarray` summary to be merged with `prev_summary`.
+        epsilon: A float that determines the approximate desired precision.
+
+    Returns:
+        A 2D `np.ndarray` that is a merged summary. First column is the
+        interpolated partition values, the second is the weights (counts).
+    """
+    merged = np.concatenate((prev_summary, next_summary), axis=1)
+    merged = np.take(merged, np.argsort(merged[0]), axis=1)
+    return compress_summary(merged, epsilon)
+
+
+def get_bin_boundaries(summary, num_bins):
+    return compress_summary(summary, 1.0 / num_bins)[0, :-1]
+
+
+def compress_summary(summary, epsilon):
+    """Compress a summary to within `epsilon` accuracy.
+
+    The compression step is needed to keep the summary sizes small after
+    merging, and also used to return the final target boundaries. It finds the
+    new bins based on interpolating cumulative weight percentages from the
+    large summary. Taking the difference of the cumulative weights from the
+    previous bin's cumulative weight will give the new weight for that bin.
+
+    Args:
+        summary: 2D `np.ndarray` summary to be compressed.
+        epsilon: A `'float32'` that determines the approximate desired
+            precision.
+
+    Returns:
+        A 2D `np.ndarray` that is a compressed summary. First column is the
+        interpolated partition values, the second is the weights (counts).
+    """
+    if summary.shape[1] * epsilon < 1:
+        return summary
+
+    percents = epsilon + np.arange(0.0, 1.0, epsilon)
+    cum_weights = summary[1].cumsum()
+    cum_weight_percents = cum_weights / cum_weights[-1]
+    new_bins = np.interp(percents, cum_weight_percents, summary[0])
+    cum_weights = np.interp(percents, cum_weight_percents, cum_weights)
+    new_weights = cum_weights - np.concatenate(
+        (np.array([0]), cum_weights[:-1])
+    )
+    summary = np.stack((new_bins, new_weights))
+    return summary.astype("float32")
+
+
+def bucketize(inputs, boundaries):
+    return tf.raw_ops.Bucketize(input=inputs, boundaries=boundaries)
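A small worked example of the quantile-summary pipeline added above (a sketch assuming the helpers are in scope, e.g. imported from `keras_core.layers.preprocessing.discretization`; results are approximate by design):

import numpy as np

# Illustrative sketch, not part of this commit.
epsilon = 0.01  # each summary keeps roughly 1 / epsilon = 100 points

# Summarize two disjoint chunks of a stream, then merge the summaries.
part1 = summarize(np.arange(0.0, 100.0, dtype="float32"), epsilon)
part2 = summarize(np.arange(100.0, 200.0, dtype="float32"), epsilon)
merged = merge_summaries(part1, part2, epsilon)

# Three boundaries that split 0..199 into four roughly equal bins:
print(get_bin_boundaries(merged, num_bins=4))
# -> approximately [49.5, 99.0, 149.5]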