
Commit

No public description
PiperOrigin-RevId: 626305807
achoum authored and copybara-github committed Apr 19, 2024
1 parent e38179c commit c82f647
Showing 3 changed files with 209 additions and 1 deletion.
4 changes: 4 additions & 0 deletions yggdrasil_decision_forests/port/python/ydf/model/BUILD
@@ -292,5 +292,9 @@ py_test(
# absl/testing:parameterized dep,
# jax dep,
# numpy dep,
"@ydf_cc//yggdrasil_decision_forests/dataset:data_spec_py_proto",
"//ydf/dataset:dataspec",
"//ydf/learner:generic_learner",
"//ydf/learner:specialized_learners",
],
)
77 changes: 76 additions & 1 deletion yggdrasil_decision_forests/port/python/ydf/model/export_jax.py
@@ -14,7 +14,12 @@

"""Utilities to export JAX models."""

from typing import Any, Sequence
import dataclasses
from typing import Any, Sequence, Dict, Optional

from yggdrasil_decision_forests.dataset import data_spec_pb2 as ds_pb
from ydf.dataset import dataspec as dataspec_lib
from ydf.model import generic_model

# pytype: disable=import-error
# pylint: disable=g-import-not-at-top
@@ -61,3 +66,73 @@ def to_compact_jax_array(values: Sequence[int]) -> jax.Array:
"""Converts a list of integers to a compact Jax array."""

return jnp.asarray(values, dtype=compact_dtype(values))


@dataclasses.dataclass
class FeatureEncoding:
"""Utility to prepare feature values before being fed into the Jax model.
Does the following:
- Encodes categorical strings into categorical integers.
Attributes:
categorical: Mapping between categorical-string feature to the dictionary of
categorical-string value to categorical-integer value.
categorical_out_of_vocab_item: Integer value representing an out of
vocabulary item.
"""

categorical: Dict[str, Dict[str, int]]
categorical_out_of_vocab_item: int = 0

@classmethod
def build(
cls,
input_features: Sequence[generic_model.InputFeature],
dataspec: ds_pb.DataSpecification,
) -> Optional["FeatureEncoding"]:
"""Creates a FeatureEncoding object.
If the input feature does not require feature encoding, returns None.
Args:
input_features: All the input features of a model.
dataspec: Dataspec of the model.
Returns:
A FeatureEncoding or None.
"""

categorical = {}
for input_feature in input_features:
column_spec = dataspec.columns[input_feature.column_idx]
if (
input_feature.semantic
in [
dataspec_lib.Semantic.CATEGORICAL,
dataspec_lib.Semantic.CATEGORICAL_SET,
]
and not column_spec.categorical.is_already_integerized
):
categorical[input_feature.name] = {
key: item.index
for key, item in column_spec.categorical.items.items()
}
if not categorical:
return None
return FeatureEncoding(categorical=categorical)

def encode(self, feature_values: Dict[str, Any]) -> Dict[str, jax.Array]:
"""Encodes feature values for a model."""

def encode_item(key: str, value: Any) -> jax.Array:
categorical_map = self.categorical.get(key)
if categorical_map is not None:
# Categorical string encoding.
value = [
categorical_map.get(x, self.categorical_out_of_vocab_item)
for x in value
]
return jax.numpy.asarray(value)

return {k: encode_item(k, v) for k, v in feature_values.items()}
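
For reference, a minimal usage sketch of the new FeatureEncoding helper (not part of this commit; `model` and the feature names are illustrative, assuming an already-trained ydf model with a numerical feature "f1" and a categorical string feature "c1"):

from ydf.model import export_jax as to_jax

# `model` is a hypothetical, already-trained ydf model.
# Build the encoder from the model's input features and dataspec.
# Returns None when no categorical-string feature needs encoding.
encoding = to_jax.FeatureEncoding.build(model.input_features(), model.data_spec())

if encoding is not None:
  # Strings are mapped to their dataspec vocabulary indices; unknown values
  # fall back to categorical_out_of_vocab_item (0, the <OOD> index).
  jax_inputs = encoding.encode({
      "f1": [0.5, 1.2, 3.0],      # Numerical values pass through unchanged.
      "c1": ["x", "y", "other"],  # Becomes e.g. [1, 2, 0].
  })
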
129 changes: 129 additions & 0 deletions yggdrasil_decision_forests/port/python/ydf/model/jax_model_test.py
@@ -12,15 +12,47 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List
from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
import numpy as np
from yggdrasil_decision_forests.dataset import data_spec_pb2 as ds_pb
from ydf.dataset import dataspec as dataspec_lib
from ydf.learner import specialized_learners
from ydf.model import export_jax as to_jax
from ydf.model import generic_model


class JaxModelTest(parameterized.TestCase):

def create_dataset(self, columns: List[str]) -> Dict[str, Any]:
"""Creates a dataset with random values."""
data = {
# Single-dim features
"f1": np.random.random(size=100),
"f2": np.random.random(size=100),
"i1": np.random.randint(100, size=100),
"i2": np.random.randint(100, size=100),
"c1": np.random.choice(["x", "y", "z"], size=100, p=[0.6, 0.3, 0.1]),
"b1": np.random.randint(2, size=100).astype(np.bool_),
"b2": np.random.randint(2, size=100).astype(np.bool_),
# Cat-set features
"cs1": [[], ["a", "b", "c"], ["b", "c"], ["a"]] * 25,
# Multi-dim features
"multi_f1": np.random.random(size=(100, 5)),
"multi_f2": np.random.random(size=(100, 5)),
"multi_i1": np.random.randint(100, size=(100, 5)),
"multi_c1": np.random.choice(["x", "y", "z"], size=(100, 5)),
"multi_b1": np.random.randint(2, size=(100, 5)).astype(np.bool_),
# Labels
"label_class_binary": np.random.choice(["l1", "l2"], size=100),
"label_class_multi": np.random.choice(["l1", "l2", "l3"], size=100),
"label_regress": np.random.random(size=100),
}
return {k: data[k] for k in columns}

@parameterized.parameters(
((0,), jnp.int8),
((0, 1, -1), jnp.int8),
@@ -45,6 +77,103 @@ def test_compact_dtype_non_supported(self):
with self.assertRaisesRegex(ValueError, "No supported compact dtype"):
to_jax.compact_dtype((0x80000000,))

def test_feature_encoding_basic(self):
feature_encoding = to_jax.FeatureEncoding.build(
[
generic_model.InputFeature(
"f1", dataspec_lib.Semantic.NUMERICAL, 0
),
generic_model.InputFeature(
"f2", dataspec_lib.Semantic.CATEGORICAL, 1
),
generic_model.InputFeature(
"f3", dataspec_lib.Semantic.CATEGORICAL, 2
),
],
ds_pb.DataSpecification(
created_num_rows=3,
columns=(
ds_pb.Column(
name="f1",
type=ds_pb.ColumnType.NUMERICAL,
),
ds_pb.Column(
name="f2",
type=ds_pb.ColumnType.CATEGORICAL,
categorical=ds_pb.CategoricalSpec(
items={
"<OOD>": ds_pb.CategoricalSpec.VocabValue(index=0),
"A": ds_pb.CategoricalSpec.VocabValue(index=1),
"B": ds_pb.CategoricalSpec.VocabValue(index=2),
},
),
),
ds_pb.Column(
name="f3",
type=ds_pb.ColumnType.CATEGORICAL,
categorical=ds_pb.CategoricalSpec(
is_already_integerized=True,
),
),
ds_pb.Column(
name="f4",
type=ds_pb.ColumnType.CATEGORICAL,
categorical=ds_pb.CategoricalSpec(
items={
"<OOD>": ds_pb.CategoricalSpec.VocabValue(index=0),
"X": ds_pb.CategoricalSpec.VocabValue(index=1),
"Y": ds_pb.CategoricalSpec.VocabValue(index=2),
},
),
),
),
),
)
self.assertIsNotNone(feature_encoding)
self.assertDictEqual(
feature_encoding.categorical, {"f2": {"<OOD>": 0, "A": 1, "B": 2}}
)

def test_feature_encoding_on_model(self):
columns = ["f1", "i1", "c1", "b1", "cs1", "label_class_binary"]
model = specialized_learners.RandomForestLearner(
label="label_class_binary",
num_trees=2,
features=[("cs1", dataspec_lib.Semantic.CATEGORICAL_SET)],
include_all_columns=True,
).train(self.create_dataset(columns))
feature_encoding = to_jax.FeatureEncoding.build(
model.input_features(), model.data_spec()
)
self.assertIsNotNone(feature_encoding)
self.assertDictEqual(
feature_encoding.categorical,
{
"c1": {"<OOD>": 0, "x": 1, "y": 2, "z": 3},
"cs1": {"<OOD>": 0, "a": 1, "b": 2, "c": 3},
},
)

encoded_features = feature_encoding.encode(
{"f1": [1, 2, 3], "c1": ["x", "y", "other"]}
)
np.testing.assert_array_equal(
encoded_features["f1"], jax.numpy.asarray([1, 2, 3])
)
np.testing.assert_array_equal(
encoded_features["c1"], jax.numpy.asarray([1, 2, 0])
)

def test_feature_encoding_is_none(self):
columns = ["f1", "i1", "label_class_binary"]
model = specialized_learners.RandomForestLearner(
label="label_class_binary", num_trees=2
).train(self.create_dataset(columns))
feature_encoding = to_jax.FeatureEncoding.build(
model.input_features(), model.data_spec()
)
self.assertIsNone(feature_encoding)


if __name__ == "__main__":
absltest.main()
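
As a quick illustration of the dtype-compaction behavior the parameterized tests above exercise (a sketch, not part of this commit; behavior inferred from the test parameters shown, e.g. small values map to int8 and values outside the supported range raise an error):

import jax.numpy as jnp
from ydf.model import export_jax as to_jax

# Small integer ranges get the narrowest sufficient dtype.
assert to_jax.compact_dtype((0, 1, -1)) == jnp.int8

# to_compact_jax_array applies the same dtype selection when building the array.
assert to_jax.to_compact_jax_array([0, 1, -1]).dtype == jnp.int8

# Values with no supported compact dtype raise ValueError.
try:
  to_jax.compact_dtype((0x80000000,))
except ValueError:
  pass  # "No supported compact dtype"
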
