diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py
index 3510dea7fe92..0153a6370343 100644
--- a/python-package/lightgbm/dask.py
+++ b/python-package/lightgbm/dask.py
@@ -449,7 +449,7 @@ def _predict_part(
 
     # dask.DataFrame.map_partitions() expects each call to return a pandas DataFrame or Series
     if isinstance(part, pd_DataFrame):
-        if pred_proba or pred_contrib or pred_leaf:
+        if len(result.shape) == 2:
             result = pd_DataFrame(result, index=part.index)
         else:
             result = pd_Series(result, index=part.index, name='predictions')
@@ -510,10 +510,6 @@ def _predict(
             **kwargs
         ).values
     elif isinstance(data, dask_Array):
-        if pred_proba:
-            kwargs['chunks'] = (data.chunks[0], (model.n_classes_,))
-        else:
-            kwargs['drop_axis'] = 1
         return data.map_blocks(
             _predict_part,
             model=model,
@@ -522,7 +518,11 @@ def _predict(
             pred_leaf=pred_leaf,
             pred_contrib=pred_contrib,
             dtype=dtype,
+            # NOTE(review): axis 1 is dropped unconditionally, but _predict_part may return 2-D
+            # blocks (it checks len(result.shape) == 2) - confirm the Dask array's shape metadata
+            drop_axis=1,
+            # keep forwarding the caller's extra predict kwargs, as the DataFrame branch does
             **kwargs
         )
     else:
         raise TypeError('Data must be either Dask Array or Dask DataFrame. Got %s.' % str(type(data)))
diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py
index d4434d68d503..b2d915e55987 100644
--- a/tests/python_package_test/test_dask.py
+++ b/tests/python_package_test/test_dask.py
@@ -1265,3 +1265,48 @@ def test_parameters_default_constructible(estimator):
     else:
         Estimator = estimator.__class__
     sklearn_checks.check_parameters_default_constructible(name, Estimator)
+
+
+@pytest.mark.parametrize('task', tasks)
+@pytest.mark.parametrize('output', data_output)
+def test_predict_with_raw_score(task, output, client):
+    """Check that raw-score predictions are exactly the leaf values of the fitted trees."""
+    if task == 'ranking' and output == 'scipy_csr_matrix':
+        pytest.skip('LGBMRanker is not currently tested on sparse matrices')
+
+    _, _, _, _, dX, dy, _, dg = _create_data(
+        objective=task,
+        output=output,
+        group=None
+    )
+
+    model_factory = task_to_dask_factory[task]
+    # a single boosting round of 2-leaf trees keeps the set of possible raw scores tiny
+    params = {
+        'client': client,
+        'n_estimators': 1,
+        'num_leaves': 2,
+        'time_out': 5,
+        'min_sum_hessian': 0
+    }
+    model = model_factory(**params)
+    model.fit(dX, dy, group=dg)
+    raw_predictions = model.predict(dX, raw_score=True).compute()
+
+    trees_df = model.booster_.trees_to_dataframe()
+    # rows at depth 2 are taken to be the leaves (a 2-leaf tree has a single split)
+    leaves_df = trees_df[trees_df.node_depth == 2]
+    if task == 'multiclass-classification':
+        # column i of the raw scores is compared against the leaves of the tree with index i
+        for i in range(model.n_classes_):
+            class_df = leaves_df[leaves_df.tree_index == i]
+            assert set(raw_predictions[:, i]) == set(class_df['value'])
+    else:
+        assert set(raw_predictions) == set(leaves_df['value'])
+
+    if task.endswith('classification'):
+        # predict_proba(raw_score=True) is expected to match predict(raw_score=True)
+        pred_proba_raw = model.predict_proba(dX, raw_score=True).compute()
+        assert_eq(raw_predictions, pred_proba_raw)
+
+    client.close(timeout=CLIENT_CLOSE_TIMEOUT)