Skip to content

Commit

Permalink
add kwargs in fit method
Browse files Browse the repository at this point in the history
  • Loading branch information
RektPunk committed Oct 5, 2024
1 parent 237dec9 commit 67d3f2b
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions mqboost/regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,16 @@ def fit(
self,
dataset: MQDataset,
eval_set: MQDataset | None = None,
**kwargs,
) -> None:
"""
Fit the regressor to the dataset.
Args:
dataset (MQDataset): The dataset to fit the model on.
eval_set (Optional[MQDataset]):
The validation dataset. If None, the dataset is used for evaluation.
**kwargs:
train parameters.
"""
if eval_set:
_eval_set = eval_set.dtrain
Expand Down Expand Up @@ -92,6 +95,7 @@ def fit(
params=params,
feval=self._MQObj.feval,
valid_sets=[_eval_set],
**kwargs,
)
elif self.__is_xgb:
self.model = xgb.train(
Expand All @@ -101,6 +105,7 @@ def fit(
obj=self._MQObj.fobj,
custom_metric=self._MQObj.feval,
evals=[(_eval_set, "eval")],
**kwargs,
)
self._colnames = dataset.columns.to_list()
self._fitted = True
Expand Down

0 comments on commit 67d3f2b

Please sign in to comment.