Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[sml] add LabelBinarizer, Binarizer, Normalizer in jax #470

Merged
merged 12 commits into from
Jan 15, 2024
22 changes: 22 additions & 0 deletions sml/preprocessing/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_python//python:defs.bzl", "py_library")

package(default_visibility = ["//visibility:public"])

py_library(
name = "preprocessing",
srcs = ["preprocessing.py"],
)
26 changes: 26 additions & 0 deletions sml/preprocessing/emulations/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_python//python:defs.bzl", "py_binary")

package(default_visibility = ["//visibility:public"])

py_binary(
name = "preprocessing_emul",
srcs = ["preprocessing_emul.py"],
deps = [
"//sml/preprocessing",
"//sml/utils:emulation",
],
)
166 changes: 166 additions & 0 deletions sml/preprocessing/emulations/preprocessing_emul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax.numpy as jnp
import numpy as np
from sklearn import preprocessing

import sml.utils.emulation as emulation
from sml.preprocessing.preprocessing import Binarizer, LabelBinarizer, Normalizer


def emul_labelbinarizer():
def labelbinarize(X, Y):
transformer = LabelBinarizer(neg_label=-2, pos_label=3)
transformer.fit(X, n_classes=4)
transformed = transformer.transform(Y)
inv_transformed = transformer.inverse_transform(transformed)
return transformed, inv_transformed

X = jnp.array([1, 2, 4, 6])
Y = jnp.array([1, 6])

transformer = preprocessing.LabelBinarizer(neg_label=-2, pos_label=3)
transformer.fit(X)
sk_transformed = transformer.transform(Y)
sk_inv_transformed = transformer.inverse_transform(sk_transformed)
# print("sklearn:\n", sk_transformed)
# print("sklearn:\n", sk_inv_transformed)

X, Y = emulator.seal(X, Y)
spu_transformed, spu_inv_transformed = emulator.run(labelbinarize)(X, Y)
# print("spu:\n", spu_transformed)
# print("spu:\n", spu_inv_transformed)

np.testing.assert_allclose(sk_transformed, spu_transformed, rtol=0, atol=0)
np.testing.assert_allclose(sk_inv_transformed, spu_inv_transformed, rtol=0, atol=0)


def emul_labelbinarizer_binary():
def labelbinarize(X):
transformer = LabelBinarizer()
transformed = transformer.fit_transform(X, n_classes=2, unique=False)
inv_transformed = transformer.inverse_transform(transformed)
return transformed, inv_transformed

X = jnp.array([1, -1, -1, 1])
transformer = preprocessing.LabelBinarizer()
sk_transformed = transformer.fit_transform(X)
sk_inv_transformed = transformer.inverse_transform(sk_transformed)
# print("sklearn:\n", sk_transformed)
# print("sklearn:\n", sk_inv_transformed)

X = emulator.seal(X)
spu_transformed, spu_inv_transformed = emulator.run(labelbinarize)(X)
# print("spu:\n", spu_transformed)
# print("spu:\n", spu_inv_transformed)

np.testing.assert_allclose(sk_transformed, spu_transformed, rtol=0, atol=0)
np.testing.assert_allclose(sk_inv_transformed, spu_inv_transformed, rtol=0, atol=0)


def emul_labelbinarizer_unseen():
def labelbinarize(X, Y):
transformer = LabelBinarizer()
transformer.fit(X, n_classes=3)
return transformer.transform(Y)

X = jnp.array([2, 4, 5])
Y = jnp.array([1, 2, 3, 4, 5, 6])

transformer = preprocessing.LabelBinarizer()
transformer.fit(X)
sk_result = transformer.transform(Y)
# print("sklearn:\n", sk_result)

X, Y = emulator.seal(X, Y)
spu_result = emulator.run(labelbinarize)(X, Y)
# print("spu:\n", spu_result)

np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=0)


def emul_binarizer():
def binarize(X):
transformer = Binarizer()
return transformer.transform(X)

X = jnp.array([[1.0, -1.0, 2.0], [2.0, 0.0, 0.0], [0.0, 1.0, -1.0]])

transformer = preprocessing.Binarizer()
sk_result = transformer.transform(X)
# print("sklearn:\n", sk_result)

X = emulator.seal(X)
spu_result = emulator.run(binarize)(X)
# print("spu:\n", spu_result)

np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=0)


def emul_normalizer():
def normalize_l1(X):
transformer = Normalizer(norm="l1")
return transformer.transform(X)

def normalize_l2(X):
transformer = Normalizer()
return transformer.transform(X)

def normalize_max(X):
transformer = Normalizer(norm="max")
return transformer.transform(X)

X = jnp.array([[4, 1, 2, 2], [1, 3, 9, 3], [5, 7, 5, 1]])

transformer_l1 = preprocessing.Normalizer(norm="l1")
sk_result_l1 = transformer_l1.transform(X)
transformer_l2 = preprocessing.Normalizer()
sk_result_l2 = transformer_l2.transform(X)
transformer_max = preprocessing.Normalizer(norm="max")
sk_result_max = transformer_max.transform(X)
# print("sklearn:\n", sk_result_l1)
# print("sklearn:\n", sk_result_l2)
# print("sklearn:\n", sk_result_max)

X = emulator.seal(X)
spu_result_l1 = emulator.run(normalize_l1)(X)
spu_result_l2 = emulator.run(normalize_l2)(X)
spu_result_max = emulator.run(normalize_max)(X)
# print("spu:\n", spu_result_l1)
# print("spu:\n", spu_result_l2)
# print("spu:\n", spu_result_max)

np.testing.assert_allclose(sk_result_l1, spu_result_l1, rtol=0, atol=1e-4)
np.testing.assert_allclose(sk_result_l2, spu_result_l2, rtol=0, atol=1e-4)
np.testing.assert_allclose(sk_result_max, spu_result_max, rtol=0, atol=1e-4)


if __name__ == "__main__":
try:
# bandwidth and latency only work for docker mode
emulator = emulation.Emulator(
emulation.CLUSTER_ABY3_3PC,
emulation.Mode.MULTIPROCESS,
bandwidth=300,
latency=20,
)
emulator.up()
emul_labelbinarizer()
emul_labelbinarizer_binary()
emul_labelbinarizer_unseen()
emul_binarizer()
emul_normalizer()
finally:
emulator.down()
Loading