Skip to content

Commit

Permalink
Merge pull request #80 from jinlow/feature/more-sklearn
Browse files Browse the repository at this point in the history
Feature/more sklearn
  • Loading branch information
jinlow authored Oct 9, 2023
2 parents 91318f1 + 38cb5f0 commit 60c3030
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "forust-ml"
version = "0.3.1"
version = "0.3.2"
edition = "2021"
authors = ["James Inlow <james.d.inlow@gmail.com>"]
homepage = "https://github.com/jinlow/forust"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pip install forust

To use in a rust project add the following to your Cargo.toml file.
```toml
forust-ml = "0.3.1"
forust-ml = "0.3.2"
```

## Usage
Expand Down
4 changes: 2 additions & 2 deletions py-forust/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "py-forust"
version = "0.3.1"
version = "0.3.2"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand All @@ -10,7 +10,7 @@ crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.19.0", features = ["extension-module"] }
forust-ml = { version = "0.3.1", path = "../" }
forust-ml = { version = "0.3.2", path = "../" }
numpy = "0.19.0"
ndarray = "0.15.1"
serde_plain = { version = "1.0" }
Expand Down
24 changes: 22 additions & 2 deletions py-forust/forust/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ def fit(
| list[
tuple[FrameLike, ArrayLike, ArrayLike] | tuple[FrameLike, ArrayLike]
] = None,
):
) -> GradientBooster:
"""Fit the gradient booster on a provided dataset.
Args:
Expand Down Expand Up @@ -507,6 +507,7 @@ def fit(
# Once it's been fit, reset the `base_score`
# this will account for the fact that it's adjusted after fit.
self.base_score = self.booster.base_score
return self

def _validate_features(self, features: list[str]):
if len(features) > 0 and hasattr(self, "feature_names_in_"):
Expand Down Expand Up @@ -934,6 +935,25 @@ def get_best_iteration(self) -> int | None:

# Functions for scikit-learn compatibility, will feel out adding these manually,
# and then if that feels too unwieldy will add scikit-learn as a dependency.
def get_params(self, deep=True) -> dict[str, Any]:
    """Get all of the parameters for the booster.

    Args:
        deep (bool, optional): This argument does nothing, and is simply here for
            scikit-learn compatibility. Defaults to True.

    Returns:
        dict[str, Any]: The parameters of the booster.
    """
    # The booster's constructor arguments are all keyword-only, so the
    # kwonlyargs of __init__ enumerate exactly the tunable parameters.
    args = inspect.getfullargspec(GradientBooster).kwonlyargs
    # Each parameter is stored on the instance under its own name.
    return {param: getattr(self, param) for param in args}

def set_params(self, **params: Any) -> GradientBooster:
    """Set the parameters of the booster.

    This has the same effect as re-instantiating the booster: the current
    parameters are merged with the overrides and ``__init__`` is run again.

    Returns:
        GradientBooster: Booster with new parameters.
    """
    # Start from the current configuration and overlay the requested changes.
    merged = {**self.get_params(), **params}
    # Re-run __init__ so every derived attribute is rebuilt consistently.
    GradientBooster.__init__(self, **merged)
    return self
33 changes: 33 additions & 0 deletions py-forust/tests/test_booster.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pytest
from sklearn.base import clone
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import GridSearchCV
from xgboost import XGBClassifier, XGBRegressor

import forust
Expand Down Expand Up @@ -1384,3 +1385,35 @@ def check_missing_is_average(tree: dict, n: int, learning_rate: float = 0.3):
with pytest.raises(AssertionError):
for tree in json.loads(fmod.json_dump())["trees"]:
check_missing_is_average(tree["nodes"], 0)


def test_get_params(X_y):
    """get_params reports the configured learning_rate before and after fit."""
    X, y = X_y
    lr = 0.00001
    booster = GradientBooster(learning_rate=lr)
    # The parameter is visible immediately after construction...
    assert booster.get_params()["learning_rate"] == lr
    booster.fit(X, y)
    # ...and is unchanged by fitting.
    assert booster.get_params()["learning_rate"] == lr


def test_set_params(X_y):
    """set_params overrides learning_rate and returns the booster itself."""
    X, y = X_y
    target = 0.00001
    booster = GradientBooster()
    # The default must differ from the value we are about to set.
    assert booster.get_params()["learning_rate"] != target
    # set_params is fluent: it returns the (truthy) booster instance.
    assert booster.set_params(learning_rate=target)
    assert booster.get_params()["learning_rate"] == target
    booster.fit(X, y)


def test_compat_gridsearch(X_y):
    """GradientBooster can be driven by scikit-learn's GridSearchCV."""
    X, y = X_y

    def auc_scorer(est, X, y):
        # Score each candidate by AUC of its predictions.
        return roc_auc_score(y, est.predict(X))

    grid = {"learning_rate": [0.1, 0.03], "subsample": [1.0, 0.8]}
    search = GridSearchCV(
        GradientBooster(),
        grid,
        scoring=auc_scorer,
    )
    search.fit(X, y)
    # Every parameter combination should have produced a test score.
    assert len(search.cv_results_["mean_test_score"]) > 0
2 changes: 1 addition & 1 deletion rs-example.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
To run this example, add the following code to your `Cargo.toml` file.
```toml
[dependencies]
forust-ml = "0.3.1"
forust-ml = "0.3.2"
polars = "0.28"
reqwest = { version = "0.11", features = ["blocking"] }
```
Expand Down

0 comments on commit 60c3030

Please sign in to comment.