Skip to content

Commit

Permalink
[dask] Include support for raw_score in predict (fixes #3793) (#4024)
Browse files Browse the repository at this point in the history
* include test for prediction with raw_score

* close client

* initial comments

* update data creation and include ranking task

* linting

* update _create_data

* compare unique raw_predictions with values in leaves_df
  • Loading branch information
jmoralez authored Mar 27, 2021
1 parent 8cc6eef commit fe1b80a
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 6 deletions.
8 changes: 2 additions & 6 deletions python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def _predict_part(

# dask.DataFrame.map_partitions() expects each call to return a pandas DataFrame or Series
if isinstance(part, pd_DataFrame):
if pred_proba or pred_contrib or pred_leaf:
if len(result.shape) == 2:
result = pd_DataFrame(result, index=part.index)
else:
result = pd_Series(result, index=part.index, name='predictions')
Expand Down Expand Up @@ -510,10 +510,6 @@ def _predict(
**kwargs
).values
elif isinstance(data, dask_Array):
if pred_proba:
kwargs['chunks'] = (data.chunks[0], (model.n_classes_,))
else:
kwargs['drop_axis'] = 1
return data.map_blocks(
_predict_part,
model=model,
Expand All @@ -522,7 +518,7 @@ def _predict(
pred_leaf=pred_leaf,
pred_contrib=pred_contrib,
dtype=dtype,
**kwargs
drop_axis=1
)
else:
raise TypeError('Data must be either Dask Array or Dask DataFrame. Got %s.' % str(type(data)))
Expand Down
40 changes: 40 additions & 0 deletions tests/python_package_test/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -1265,3 +1265,43 @@ def test_parameters_default_constructible(estimator):
else:
Estimator = estimator.__class__
sklearn_checks.check_parameters_default_constructible(name, Estimator)


@pytest.mark.parametrize('task', tasks)
@pytest.mark.parametrize('output', data_output)
def test_predict_with_raw_score(task, output, client):
    """Check that ``predict(raw_score=True)`` yields the fitted tree's raw values.

    A model restricted to one estimator and two leaves is trained on Dask
    collections. The distinct raw predictions are then compared with the
    ``value`` column of the rows of ``trees_to_dataframe()`` located at
    ``node_depth == 2``. For classification tasks, ``predict_proba`` with
    ``raw_score=True`` must additionally equal the ``predict`` output.
    """
    ranker_on_sparse = task == 'ranking' and output == 'scipy_csr_matrix'
    if ranker_on_sparse:
        pytest.skip('LGBMRanker is not currently tested on sparse matrices')

    # Only the Dask features, Dask labels and Dask groups are used here.
    _, _, _, _, dX, dy, _, dg = _create_data(objective=task, output=output, group=None)

    # One estimator with two leaves keeps the set of possible raw scores tiny.
    model = task_to_dask_factory[task](
        client=client,
        n_estimators=1,
        num_leaves=2,
        time_out=5,
        min_sum_hessian=0,
    )
    model.fit(dX, dy, group=dg)
    raw_predictions = model.predict(dX, raw_score=True).compute()

    tree_table = model.booster_.trees_to_dataframe()
    # Rows at depth 2 are treated as the leaves of each (two-leaf) tree.
    leaf_rows = tree_table[tree_table['node_depth'] == 2]
    if task != 'multiclass-classification':
        assert set(raw_predictions) == set(leaf_rows['value'])
    else:
        # Column i of the raw scores is compared with the tree whose index is i.
        for class_idx in range(model.n_classes_):
            class_leaf_values = leaf_rows[leaf_rows['tree_index'] == class_idx]['value']
            assert set(raw_predictions[:, class_idx]) == set(class_leaf_values)

    if task.endswith('classification'):
        raw_proba = model.predict_proba(dX, raw_score=True).compute()
        assert_eq(raw_predictions, raw_proba)

    client.close(timeout=CLIENT_CLOSE_TIMEOUT)

0 comments on commit fe1b80a

Please sign in to comment.