Skip to content

Commit

Permalink
[OSCP]使用SPU实现随机森林算法 (#792)
Browse files Browse the repository at this point in the history
fixed : #254
  • Loading branch information
xbw886 authored Aug 1, 2024
1 parent 3601e85 commit 854f3ef
Show file tree
Hide file tree
Showing 6 changed files with 485 additions and 3 deletions.
14 changes: 13 additions & 1 deletion sml/ensemble/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Ant Group Co., Ltd.
# Copyright 2024 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -11,3 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_python//python:defs.bzl", "py_library")

package(default_visibility = ["//visibility:public"])

py_library(
name = "forest",
srcs = ["forest.py"],
deps = [
"//sml/tree",
],
)
15 changes: 14 additions & 1 deletion sml/ensemble/emulations/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Ant Group Co., Ltd.
# Copyright 2024 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -11,3 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_python//python:defs.bzl", "py_binary")

package(default_visibility = ["//visibility:public"])

py_binary(
name = "forest_emul",
srcs = ["forest_emul.py"],
deps = [
"//sml/ensemble:forest",
"//sml/utils:emulation",
],
)
125 changes: 125 additions & 0 deletions sml/ensemble/emulations/forest_emul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright 2024 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time

import jax.numpy as jnp
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

import sml.utils.emulation as emulation
from sml.ensemble.forest import RandomForestClassifier as sml_rfc

MAX_DEPTH = 3
CONFIG_FILE = emulation.CLUSTER_ABY3_3PC


def emul_forest(mode=emulation.Mode.MULTIPROCESS):
def proc_wrapper(
n_estimators,
max_features,
criterion,
splitter,
max_depth,
bootstrap,
max_samples,
n_labels,
):
rf_custom = sml_rfc(
n_estimators,
max_features,
criterion,
splitter,
max_depth,
bootstrap,
max_samples,
n_labels,
)

def proc(X, y):
rf_custom_fit = rf_custom.fit(X, y)
result = rf_custom_fit.predict(X)
return result

return proc

def load_data():
iris = load_iris()
iris_data, iris_label = jnp.array(iris.data), jnp.array(iris.target)
# sorted_features: n_samples * n_features_in
n_samples, n_features_in = iris_data.shape
n_labels = len(jnp.unique(iris_label))
sorted_features = jnp.sort(iris_data, axis=0)
new_threshold = (sorted_features[:-1, :] + sorted_features[1:, :]) / 2
new_features = jnp.greater_equal(
iris_data[:, :], new_threshold[:, jnp.newaxis, :]
)
new_features = new_features.transpose([1, 0, 2]).reshape(n_samples, -1)

X, y = new_features[:, ::3], iris_label[:]
return X, y

try:
# bandwidth and latency only work for docker mode
emulator = emulation.Emulator(CONFIG_FILE, mode, bandwidth=300, latency=20)
emulator.up()

# load mock data
X, y = load_data()
n_labels = jnp.unique(y).shape[0]

# compare with sklearn
rf = RandomForestClassifier(
n_estimators=3,
max_features=None,
criterion='gini',
max_depth=MAX_DEPTH,
bootstrap=False,
max_samples=None,
)
start = time.time()
rf = rf.fit(X, y)
score_plain = rf.score(X, y)
end = time.time()
print(f"Running time in SKlearn: {end - start:.2f}s")

# mark these data to be protected in SPU
X_spu, y_spu = emulator.seal(X, y)

# run
proc = proc_wrapper(
n_estimators=3,
max_features=0.7,
criterion='gini',
splitter='best',
max_depth=3,
bootstrap=False,
max_samples=None,
n_labels=n_labels,
)
start = time.time()
result = emulator.run(proc)(X_spu, y_spu)
end = time.time()
score_encrpted = jnp.mean((result == y))
print(f"Running time in SPU: {end - start:.2f}s")

# print acc
print(f"Accuracy in SKlearn: {score_plain:.2f}")
print(f"Accuracy in SPU: {score_encrpted:.2f}")

finally:
emulator.down()


if __name__ == "__main__":
emul_forest(emulation.Mode.MULTIPROCESS)
201 changes: 201 additions & 0 deletions sml/ensemble/forest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
# Copyright 2024 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import random

import jax
import jax.numpy as jnp
from jax import lax

from sml.tree.tree import DecisionTreeClassifier as sml_dtc


class RandomForestClassifier:
"""A random forest classifier based on DecisionTreeClassifier.
Parameters
----------
n_estimators : int
The number of trees in the forest. Must specify an integer > 0.
max_features : int, float, "auto", "sqrt", "log2", or None.
The number of features to consider when looking for the best split.
If it's an integer, must 0 < integer < n_features.
If it's an float, must 0 < float <= 1.
criterion : {"gini"}, default="gini"
The function to measure the quality of a split. Supported criteria are
"gini" for the Gini impurity.
splitter : {"best"}, default="best"
The strategy used to choose the split at each node. Supported
strategies are "best" to choose the best split.
max_depth : int
The maximum depth of the tree. Must specify an integer > 0.
bootstrap : bool
Whether bootstrap samples are used when building trees.
max_samples : int, float ,None, default=None
The number of samples to draw from X to train each base estimator.
This parameter is only valid if bootstrap is ture.
If it's an integer, must 0 < integer < n_samples.
If it's an float, must 0 < float <= 1.
n_labels: int
The max number of labels.
"""

def __init__(
self,
n_estimators,
max_features,
criterion,
splitter,
max_depth,
bootstrap,
max_samples,
n_labels,
):
assert criterion == "gini", "criteria other than gini is not supported."
assert splitter == "best", "splitter other than best is not supported."
assert (
n_estimators is not None and n_estimators > 0
), "n_estimators should not be None and must > 0."
assert (
max_depth is not None and max_depth > 0
), "max_depth should not be None and must > 0."
assert isinstance(
bootstrap, bool
), "bootstrap should be a boolean value (True or False)"

self.n_estimators = n_estimators
self.max_features = max_features
self.criterion = criterion
self.splitter = splitter
self.max_depth = max_depth
self.bootstrap = bootstrap
self.max_samples = max_samples
self.n_labels = n_labels

self.trees = []
self.features_indices = []

def _calculate_max_samples(self, max_samples, n_samples):
if isinstance(max_samples, int):
assert (
max_samples <= n_samples
), "max_samples should not exceed n_samples when it's an integer"
return max_samples
elif isinstance(max_samples, float):
assert (
0 < max_samples <= 1
), "max_samples should be in the range (0, 1] when it's a float"
return int(max_samples * n_samples)
else:
return n_samples

def _bootstrap_sample(self, X, y):
n_samples = X.shape[0]
max_samples = self._calculate_max_samples(self.max_samples, n_samples)

if not self.bootstrap:
return X, y

# 实现bootstrap
population = range(n_samples)
indices = random.sample(population, max_samples)

indices = jnp.array(indices)
return X[indices], y[indices]

def _select_features(self, n, k):
indices = range(n)
selected_elements = random.sample(indices, k)
return selected_elements

def _calculate_max_features(self, max_features, n_features):
if isinstance(max_features, int):
assert (
0 < max_features <= n_features
), "0 < max_features <= n_features when it's an integer"
return max_features

elif isinstance(max_features, float):
assert (
0 < max_features <= 1
), "max_features should be in the range (0, 1] when it's a float"
return int(max_features * n_features)

elif isinstance(max_features, str):
if max_features == 'sqrt':
return int(math.sqrt(n_features))
elif max_features == 'log2':
return int(math.log2(n_features))
else:
return n_features
else:
return n_features

def fit(self, X, y):
n_samples, n_features = X.shape
self.n_features = n_features
self.max_features = self._calculate_max_features(
self.max_features, self.n_features
)
self.label_list = jnp.arange(self.n_labels)

self.trees = []
self.features_indices = []

for _ in range(self.n_estimators):
X_sample, y_sample = self._bootstrap_sample(X, y)
features = self._select_features(self.n_features, self.max_features)

tree = sml_dtc(self.criterion, self.splitter, self.max_depth, self.n_labels)
tree.fit(X_sample[:, features], y_sample)
self.trees.append(tree)
self.features_indices.append(features)

return self

def jax_mode_row_vectorized(self, data):
label_list = jnp.array(self.label_list)

data_expanded = jnp.expand_dims(data, axis=-1)
label_expanded = jnp.expand_dims(label_list, axis=0)

mask = (data_expanded == label_expanded).astype(jnp.int32)

counts = jnp.sum(mask, axis=1)
mode_indices = jnp.argmax(counts, axis=1)

modes = label_list[mode_indices]
return modes

def predict(self, X):
predictions_list = []
for i, tree in enumerate(self.trees):
features = self.features_indices[i]
predictions = tree.predict(X[:, features])
predictions_list.append(predictions)

tree_predictions = jnp.array(predictions_list).T

y_pred = self.jax_mode_row_vectorized(tree_predictions)

return y_pred.ravel()
16 changes: 15 additions & 1 deletion sml/ensemble/tests/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Ant Group Co., Ltd.
# Copyright 2024 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -11,3 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_python//python:defs.bzl", "py_test")

package(default_visibility = ["//visibility:public"])

py_test(
name = "forest_test",
srcs = ["forest_test.py"],
deps = [
"//sml/ensemble:forest",
"//spu:init",
"//spu/utils:simulation",
],
)
Loading

0 comments on commit 854f3ef

Please sign in to comment.