Skip to content

Commit

Permalink
update docs
Browse files Browse the repository at this point in the history
  • Loading branch information
RektPunk committed Sep 10, 2024
1 parent e67d5e0 commit 48ccbc1
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 12 deletions.
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,21 @@
<a href="https://github.com/RektPunk/MQBoost/releases/latest">
<img alt="release" src="https://img.shields.io/github/v/release/RektPunk/mqboost.svg">
</a>
<a href="https://pypi.org/project/MQBoost">
<img alt="Pythonv" src="https://img.shields.io/pypi/pyversions/MQBoost.svg?logo=python&logoColor=white">
</a>
<a href="https://github.com/RektPunk/MQBoost/blob/main/LICENSE">
<img alt="License" src="https://img.shields.io/github/license/RektPunk/MQboost.svg">
</a>
<a href="https://github.com/RektPunk/MQBoost/actions/workflows/lint.yaml">
<img alt="Lint" src="https://github.com/RektPunk/MQBoost/actions/workflows/lint.yaml/badge.svg?branch=main">
</a>
<a href="https://github.com/RektPunk/MQBoost/actions/workflows/test.yaml">
<img alt="Test" src="https://github.com/RektPunk/MQBoost/actions/workflows/test.yaml/badge.svg?branch=main">
</a>
</p>


<!-- <a href="LICENSE">
<img alt="license" src="https://img.shields.io/badge/license-MIT-indigo.sv">
</a> -->
Expand Down
9 changes: 4 additions & 5 deletions mqboost/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,16 @@ class MQDataset:
It supports both LightGBM and XGBoost models, handling data preparation, validation, and conversion for training and prediction.
Attributes:
alphas (list[float]):
alphas (list[float] | float):
List of quantile levels.
Must be in ascending order and contain no duplicates.
data (pd.DataFrame): The input features.
label (pd.DataFrame): The target labels (if provided).
model (ModelName): The model type (LightGBM or XGBoost).
data (pd.DataFrame | pd.Series | np.ndarray): The input features.
label (pd.Series | np.ndarray): The target labels (if provided).
model (str): The model type (LightGBM or XGBoost).
Property:
train_dtype: Returns the data type function for training data.
predict_dtype: Returns the data type function for prediction data.
model: Returns the model type.
columns: Returns the column names of the input features.
nrow: Returns the number of rows in the dataset.
data: Returns the input features.
Expand Down
13 changes: 6 additions & 7 deletions mqboost/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ class MQOptimizer:
The objective function for the quantile regression ('check', 'huber', or 'phuber'). Default is 'check'.
delta (float): Delta parameter for the 'huber' objective function. Default is 0.01.
epsilon (float): Epsilon parameter for the 'apptox' objective function. Default is 1e-5.
get_params (Callable): Function to get hyperparameters for the model.
Methods:
optimize_params(dataset, n_trials, get_params_func, valid_set):
Expand Down Expand Up @@ -116,20 +115,20 @@ def optimize_params(
n_trials (int): The number of trials for the hyperparameter optimization.
get_params_func (Callable, optional): A custom function to get the parameters for the model.
For example,
def get_params(trial: Trial, model: ModelName):
def get_params(trial: Trial):
return {
"learning_rate": trial.suggest_float("learning_rate", 1e-2, 1.0, log=True),
"learning_rate": trial.suggest_float("learning_rate", 1e-2, 1.0),
"max_depth": trial.suggest_int("max_depth", 1, 10),
"lambda_l1": trial.suggest_float("lambda_l1", 1e-8, 10.0, log=True),
"lambda_l2": trial.suggest_float("lambda_l2", 1e-8, 10.0, log=True),
"lambda_l1": trial.suggest_float("lambda_l1", 1e-8, 10.0),
"lambda_l2": trial.suggest_float("lambda_l2", 1e-8, 10.0),
"num_leaves": trial.suggest_int("num_leaves", 2, 256),
"feature_fraction": trial.suggest_float("feature_fraction", 0.4, 1.0),
"bagging_fraction": trial.suggest_float("bagging_fraction", 0.4, 1.0),
"bagging_freq": trial.suggest_int("bagging_freq", 1, 7),
}
valid_set (Optional[MQDataset], optional): The validation dataset. Defaults to None.
valid_set (MQDataset, optional): The validation dataset. Defaults to None.
Returns:
Dict[str, Any]: The best hyperparameters found by the optimization process.
dict[str, Any]: The best hyperparameters found by the optimization process.
"""
self._dataset = dataset
self._MQObj = MQObjective(
Expand Down

0 comments on commit 48ccbc1

Please sign in to comment.