-
Notifications
You must be signed in to change notification settings - Fork 101
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fixed : #254
- Loading branch information
Showing
6 changed files
with
485 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
# Copyright 2024 Ant Group Co., Ltd. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import time | ||
|
||
import jax.numpy as jnp | ||
from sklearn.datasets import load_iris | ||
from sklearn.ensemble import RandomForestClassifier | ||
|
||
import sml.utils.emulation as emulation | ||
from sml.ensemble.forest import RandomForestClassifier as sml_rfc | ||
|
||
MAX_DEPTH = 3 | ||
CONFIG_FILE = emulation.CLUSTER_ABY3_3PC | ||
|
||
|
||
def emul_forest(mode=emulation.Mode.MULTIPROCESS): | ||
def proc_wrapper( | ||
n_estimators, | ||
max_features, | ||
criterion, | ||
splitter, | ||
max_depth, | ||
bootstrap, | ||
max_samples, | ||
n_labels, | ||
): | ||
rf_custom = sml_rfc( | ||
n_estimators, | ||
max_features, | ||
criterion, | ||
splitter, | ||
max_depth, | ||
bootstrap, | ||
max_samples, | ||
n_labels, | ||
) | ||
|
||
def proc(X, y): | ||
rf_custom_fit = rf_custom.fit(X, y) | ||
result = rf_custom_fit.predict(X) | ||
return result | ||
|
||
return proc | ||
|
||
def load_data(): | ||
iris = load_iris() | ||
iris_data, iris_label = jnp.array(iris.data), jnp.array(iris.target) | ||
# sorted_features: n_samples * n_features_in | ||
n_samples, n_features_in = iris_data.shape | ||
n_labels = len(jnp.unique(iris_label)) | ||
sorted_features = jnp.sort(iris_data, axis=0) | ||
new_threshold = (sorted_features[:-1, :] + sorted_features[1:, :]) / 2 | ||
new_features = jnp.greater_equal( | ||
iris_data[:, :], new_threshold[:, jnp.newaxis, :] | ||
) | ||
new_features = new_features.transpose([1, 0, 2]).reshape(n_samples, -1) | ||
|
||
X, y = new_features[:, ::3], iris_label[:] | ||
return X, y | ||
|
||
try: | ||
# bandwidth and latency only work for docker mode | ||
emulator = emulation.Emulator(CONFIG_FILE, mode, bandwidth=300, latency=20) | ||
emulator.up() | ||
|
||
# load mock data | ||
X, y = load_data() | ||
n_labels = jnp.unique(y).shape[0] | ||
|
||
# compare with sklearn | ||
rf = RandomForestClassifier( | ||
n_estimators=3, | ||
max_features=None, | ||
criterion='gini', | ||
max_depth=MAX_DEPTH, | ||
bootstrap=False, | ||
max_samples=None, | ||
) | ||
start = time.time() | ||
rf = rf.fit(X, y) | ||
score_plain = rf.score(X, y) | ||
end = time.time() | ||
print(f"Running time in SKlearn: {end - start:.2f}s") | ||
|
||
# mark these data to be protected in SPU | ||
X_spu, y_spu = emulator.seal(X, y) | ||
|
||
# run | ||
proc = proc_wrapper( | ||
n_estimators=3, | ||
max_features=0.7, | ||
criterion='gini', | ||
splitter='best', | ||
max_depth=3, | ||
bootstrap=False, | ||
max_samples=None, | ||
n_labels=n_labels, | ||
) | ||
start = time.time() | ||
result = emulator.run(proc)(X_spu, y_spu) | ||
end = time.time() | ||
score_encrpted = jnp.mean((result == y)) | ||
print(f"Running time in SPU: {end - start:.2f}s") | ||
|
||
# print acc | ||
print(f"Accuracy in SKlearn: {score_plain:.2f}") | ||
print(f"Accuracy in SPU: {score_encrpted:.2f}") | ||
|
||
finally: | ||
emulator.down() | ||
|
||
|
||
if __name__ == "__main__": | ||
emul_forest(emulation.Mode.MULTIPROCESS) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
# Copyright 2024 Ant Group Co., Ltd. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import math | ||
import random | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
from jax import lax | ||
|
||
from sml.tree.tree import DecisionTreeClassifier as sml_dtc | ||
|
||
|
||
class RandomForestClassifier: | ||
"""A random forest classifier based on DecisionTreeClassifier. | ||
Parameters | ||
---------- | ||
n_estimators : int | ||
The number of trees in the forest. Must specify an integer > 0. | ||
max_features : int, float, "auto", "sqrt", "log2", or None. | ||
The number of features to consider when looking for the best split. | ||
If it's an integer, must 0 < integer < n_features. | ||
If it's an float, must 0 < float <= 1. | ||
criterion : {"gini"}, default="gini" | ||
The function to measure the quality of a split. Supported criteria are | ||
"gini" for the Gini impurity. | ||
splitter : {"best"}, default="best" | ||
The strategy used to choose the split at each node. Supported | ||
strategies are "best" to choose the best split. | ||
max_depth : int | ||
The maximum depth of the tree. Must specify an integer > 0. | ||
bootstrap : bool | ||
Whether bootstrap samples are used when building trees. | ||
max_samples : int, float ,None, default=None | ||
The number of samples to draw from X to train each base estimator. | ||
This parameter is only valid if bootstrap is ture. | ||
If it's an integer, must 0 < integer < n_samples. | ||
If it's an float, must 0 < float <= 1. | ||
n_labels: int | ||
The max number of labels. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
n_estimators, | ||
max_features, | ||
criterion, | ||
splitter, | ||
max_depth, | ||
bootstrap, | ||
max_samples, | ||
n_labels, | ||
): | ||
assert criterion == "gini", "criteria other than gini is not supported." | ||
assert splitter == "best", "splitter other than best is not supported." | ||
assert ( | ||
n_estimators is not None and n_estimators > 0 | ||
), "n_estimators should not be None and must > 0." | ||
assert ( | ||
max_depth is not None and max_depth > 0 | ||
), "max_depth should not be None and must > 0." | ||
assert isinstance( | ||
bootstrap, bool | ||
), "bootstrap should be a boolean value (True or False)" | ||
|
||
self.n_estimators = n_estimators | ||
self.max_features = max_features | ||
self.criterion = criterion | ||
self.splitter = splitter | ||
self.max_depth = max_depth | ||
self.bootstrap = bootstrap | ||
self.max_samples = max_samples | ||
self.n_labels = n_labels | ||
|
||
self.trees = [] | ||
self.features_indices = [] | ||
|
||
def _calculate_max_samples(self, max_samples, n_samples): | ||
if isinstance(max_samples, int): | ||
assert ( | ||
max_samples <= n_samples | ||
), "max_samples should not exceed n_samples when it's an integer" | ||
return max_samples | ||
elif isinstance(max_samples, float): | ||
assert ( | ||
0 < max_samples <= 1 | ||
), "max_samples should be in the range (0, 1] when it's a float" | ||
return int(max_samples * n_samples) | ||
else: | ||
return n_samples | ||
|
||
def _bootstrap_sample(self, X, y): | ||
n_samples = X.shape[0] | ||
max_samples = self._calculate_max_samples(self.max_samples, n_samples) | ||
|
||
if not self.bootstrap: | ||
return X, y | ||
|
||
# 实现bootstrap | ||
population = range(n_samples) | ||
indices = random.sample(population, max_samples) | ||
|
||
indices = jnp.array(indices) | ||
return X[indices], y[indices] | ||
|
||
def _select_features(self, n, k): | ||
indices = range(n) | ||
selected_elements = random.sample(indices, k) | ||
return selected_elements | ||
|
||
def _calculate_max_features(self, max_features, n_features): | ||
if isinstance(max_features, int): | ||
assert ( | ||
0 < max_features <= n_features | ||
), "0 < max_features <= n_features when it's an integer" | ||
return max_features | ||
|
||
elif isinstance(max_features, float): | ||
assert ( | ||
0 < max_features <= 1 | ||
), "max_features should be in the range (0, 1] when it's a float" | ||
return int(max_features * n_features) | ||
|
||
elif isinstance(max_features, str): | ||
if max_features == 'sqrt': | ||
return int(math.sqrt(n_features)) | ||
elif max_features == 'log2': | ||
return int(math.log2(n_features)) | ||
else: | ||
return n_features | ||
else: | ||
return n_features | ||
|
||
def fit(self, X, y): | ||
n_samples, n_features = X.shape | ||
self.n_features = n_features | ||
self.max_features = self._calculate_max_features( | ||
self.max_features, self.n_features | ||
) | ||
self.label_list = jnp.arange(self.n_labels) | ||
|
||
self.trees = [] | ||
self.features_indices = [] | ||
|
||
for _ in range(self.n_estimators): | ||
X_sample, y_sample = self._bootstrap_sample(X, y) | ||
features = self._select_features(self.n_features, self.max_features) | ||
|
||
tree = sml_dtc(self.criterion, self.splitter, self.max_depth, self.n_labels) | ||
tree.fit(X_sample[:, features], y_sample) | ||
self.trees.append(tree) | ||
self.features_indices.append(features) | ||
|
||
return self | ||
|
||
def jax_mode_row_vectorized(self, data): | ||
label_list = jnp.array(self.label_list) | ||
|
||
data_expanded = jnp.expand_dims(data, axis=-1) | ||
label_expanded = jnp.expand_dims(label_list, axis=0) | ||
|
||
mask = (data_expanded == label_expanded).astype(jnp.int32) | ||
|
||
counts = jnp.sum(mask, axis=1) | ||
mode_indices = jnp.argmax(counts, axis=1) | ||
|
||
modes = label_list[mode_indices] | ||
return modes | ||
|
||
def predict(self, X): | ||
predictions_list = [] | ||
for i, tree in enumerate(self.trees): | ||
features = self.features_indices[i] | ||
predictions = tree.predict(X[:, features]) | ||
predictions_list.append(predictions) | ||
|
||
tree_predictions = jnp.array(predictions_list).T | ||
|
||
y_pred = self.jax_mode_row_vectorized(tree_predictions) | ||
|
||
return y_pred.ravel() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.