Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS committed Jan 24, 2021
1 parent 96083fa commit acac78f
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def _predict(model, data, raw_score=False, pred_proba=False, pred_leaf=False, pr
class _LGBMModel:
def __init__(self):
if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
raise LightGBMError('Dask, Pandas and Scikit-learn are required for this module')
raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask')

def _fit(self, model_factory, X, y=None, sample_weight=None, group=None, client=None, **kwargs):
"""Docstring is inherited from the LGBMModel."""
Expand Down
4 changes: 2 additions & 2 deletions python-package/lightgbm/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,15 +335,15 @@ def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratifi
"xe_ndcg", "xe_ndcg_mart", "xendcg_mart"}
for obj_alias in _ConfigAliases.get("objective")):
if not SKLEARN_INSTALLED:
raise LightGBMError('Scikit-learn is required for ranking cv.')
raise LightGBMError('scikit-learn is required for ranking cv')
# ranking task, split according to groups
group_info = np.array(full_data.get_group(), dtype=np.int32, copy=False)
flatted_group = np.repeat(range(len(group_info)), repeats=group_info)
group_kfold = _LGBMGroupKFold(n_splits=nfold)
folds = group_kfold.split(X=np.zeros(num_data), groups=flatted_group)
elif stratified:
if not SKLEARN_INSTALLED:
raise LightGBMError('Scikit-learn is required for stratified cv.')
raise LightGBMError('scikit-learn is required for stratified cv')
skf = _LGBMStratifiedKFold(n_splits=nfold, shuffle=shuffle, random_state=seed)
folds = skf.split(X=np.zeros(num_data), y=full_data.get_label())
else:
Expand Down
2 changes: 1 addition & 1 deletion python-package/lightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
and you should group grad and hess in this way as well.
"""
if not SKLEARN_INSTALLED:
raise LightGBMError('Scikit-learn is required for this module')
raise LightGBMError('scikit-learn is required for lightgbm.sklearn')

self.boosting_type = boosting_type
self.objective = objective
Expand Down

0 comments on commit acac78f

Please sign in to comment.