[OSCP] Implement the random forest algorithm using SPU #752

Closed
wants to merge 3 commits into from

@xbw886 xbw886 commented Jul 1, 2024

Pull Request

What problem does this PR solve?

Issue Number: Fixed #254

Possible side effects?

  • Performance:

  • Backward compatibility:

github-actions bot commented Jul 1, 2024

All contributors have signed the CLA ✍️ ✅
Posted by the CLA Assistant Lite bot.

@anakinxc anakinxc requested a review from deadlywing July 1, 2024 11:32
@xbw886 xbw886 changed the title from "Xbw" to "random forest(issue 254)" Jul 1, 2024
@xbw886 xbw886 marked this pull request as draft July 1, 2024 13:06
@xbw886 xbw886 marked this pull request as ready for review July 1, 2024 13:15
xbw886 commented Jul 1, 2024

I have read the CLA Document and I hereby sign the CLA

@xbw886 xbw886 changed the title from "random forest(issue 254)" to "[OSCP] Implement the random forest algorithm using SPU" Jul 2, 2024
@@ -0,0 +1,22 @@
# Copyright 2023 Ant Group Co., Ltd.

Please move forest.py into the ensemble directory, and move the corresponding tests, emulation, and build files along with it.

self,
n_estimators,
max_features,
n_features,

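I suggest obtaining n_features from the dataset inside fit; the validation of max_features, or the computation of its concrete value, can also be moved into fit.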

self.splitter = splitter
self.max_depth = max_depth
self.bootstrap = bootstrap
self.max_samples = max_samples

Same as for n_features and max_features: defer the validation and the computation of the concrete value to fit.
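
A minimal sketch of what deferring these checks to fit could look like. The helper names `resolve_max_features` and `resolve_max_samples` are hypothetical, and the accepted value forms (None, an int count, a float fraction, or the strings "sqrt" and "log2") are an assumption modeled on the common scikit-learn convention, not something this PR specifies.

```python
import math


def resolve_max_features(max_features, n_features):
    """Hypothetical helper: turn the raw max_features into a concrete int.

    Meant to be called from fit, once n_features is known from X.shape[1].
    """
    if max_features is None:
        return n_features
    if isinstance(max_features, bool):
        # bool is a subclass of int, so reject it explicitly.
        raise TypeError("max_features must not be a bool")
    if isinstance(max_features, str):
        if max_features == "sqrt":
            return max(1, int(math.sqrt(n_features)))
        if max_features == "log2":
            return max(1, int(math.log2(n_features)))
        raise ValueError(f"unknown max_features string: {max_features!r}")
    if isinstance(max_features, int):
        if not 1 <= max_features <= n_features:
            raise ValueError("int max_features must be in [1, n_features]")
        return max_features
    if isinstance(max_features, float):
        if not 0.0 < max_features <= 1.0:
            raise ValueError("float max_features must be in (0, 1]")
        return max(1, int(max_features * n_features))
    raise TypeError(f"unsupported max_features type: {type(max_features)}")


def resolve_max_samples(max_samples, n_samples):
    """Hypothetical helper: same idea for max_samples, using X.shape[0]."""
    if max_samples is None:
        return n_samples
    if isinstance(max_samples, bool):
        raise TypeError("max_samples must not be a bool")
    if isinstance(max_samples, int):
        if not 1 <= max_samples <= n_samples:
            raise ValueError("int max_samples must be in [1, n_samples]")
        return max_samples
    if isinstance(max_samples, float):
        if not 0.0 < max_samples <= 1.0:
            raise ValueError("float max_samples must be in (0, 1]")
        return max(1, int(max_samples * n_samples))
    raise TypeError(f"unsupported max_samples type: {type(max_samples)}")
```

With helpers like these, `__init__` would only store the raw hyperparameters, and fit would start with `n_samples, n_features = X.shape` before resolving both values.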

bootstrap,
max_samples,
n_labels,
seed,

The seed parameter is not needed.

X_sample, y_sample = self._bootstrap_sample(X, y)
features = self._select_features()
# selected_indices = self._shuffle_indices(n_features)
print(y_sample)

Please do not use print.

for i, tree in enumerate(self.trees):
features = self.features_indices[i]
print(features)
tree_predictions = tree_predictions.at[:, i].set(

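Try to avoid .set where possible. This kind of JAX update method copies all of the data again; instead, compute the values first and then build the array in one step with jnp.array.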
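
To illustrate the suggestion, here is a small sketch that contrasts the per-column `.at[...].set(...)` pattern with collecting the columns first and assembling the array once. `predict_tree_stub` is a hypothetical stand-in for a real tree's predict call, and the function names are illustrative only; the example is plain JAX and has not been checked under SPU.

```python
import jax.numpy as jnp


def predict_tree_stub(tree_id, X):
    # Hypothetical stand-in for tree.predict(X[:, features]).
    return (X[:, 0] > tree_id).astype(jnp.int32)


def collect_with_set(X, n_trees):
    # Pattern from the PR: each .at[...].set(...) is a functional update
    # that produces a new array for every tree.
    out = jnp.zeros((X.shape[0], n_trees), dtype=jnp.int32)
    for i in range(n_trees):
        out = out.at[:, i].set(predict_tree_stub(i, X))
    return out


def collect_with_stack(X, n_trees):
    # Suggested pattern: compute every column first, then build the
    # (n_samples, n_trees) array in a single step.
    columns = [predict_tree_stub(i, X) for i in range(n_trees)]
    return jnp.stack(columns, axis=1)
```

`jnp.array(columns).T` would give the same result as `jnp.stack(columns, axis=1)`; either way the array is materialized only once.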

return y_pred.ravel()


def jax_mode_row(data):

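This does not need to be so complicated. Labels take the values 0, 1, 2, ... (a requirement of the decision tree). For binary classification, for example, just count how many trees predict ==0 and how many predict ==1, and return whichever has the larger count. (Please avoid loops where possible and make good use of vectorization.)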
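
One possible vectorized form of this counting approach is sketched below. The function name `majority_vote` and the `(n_samples, n_trees)` layout of `tree_predictions` are assumptions for illustration; the sketch relies only on labels being the integers 0 to n_labels - 1, as the comment states.

```python
import jax.numpy as jnp


def majority_vote(tree_predictions, n_labels):
    """Return the most frequent label per row, without Python loops.

    tree_predictions: int array of shape (n_samples, n_trees) whose
    entries lie in {0, ..., n_labels - 1}.
    """
    labels = jnp.arange(n_labels)
    # Compare every prediction against every label:
    # result shape is (n_samples, n_trees, n_labels).
    votes = tree_predictions[:, :, None] == labels[None, None, :]
    # Sum over the tree axis to get per-label counts: (n_samples, n_labels).
    counts = votes.sum(axis=1)
    # argmax picks the label with the most votes; on a tie it returns
    # the smallest label index.
    return jnp.argmax(counts, axis=1)
```

For the binary case this reduces to comparing the count of ==0 against the count of ==1, exactly as described above.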


class RandomForestClassifier:
"""A random forest classifier."""


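The hyperparameters in __init__ need to be explained;
the requirements on the data format also need to be documented (you can refer to the comments in the decision tree model).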

from sml.tree.tree import DecisionTreeClassifier as sml_dtc

# from functools import partial
# from jax import jit

Unrelated comments can be cleaned up.

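Description : The functions are basically written and the test results so far are basically correct; emul and test still need to be completed.
bootstrap has a problem: after bootstrap, predict does not output 1, and the bootstrap sample contains no 1 (because the jax.random API is not supported)

! Final: the bootstrap parameter is unusable: bootstrap sampling is correct in plaintext, but in forest_test.py label 1 is never sampled,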

This descriptive text can also be cleaned up.

@xbw886 xbw886 closed this Jul 27, 2024
@github-actions github-actions bot locked and limited conversation to collaborators Jul 27, 2024
2 participants