Skip to content

Commit

Permalink
use toarray()
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb committed Jul 6, 2021
1 parent 466917e commit 57edeaa
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions tests/python_package_test/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,14 +409,14 @@ def test_classifier_pred_contrib(output, task, cluster):
assert computed_preds.shape == local_preds_with_contrib[i].shape
assert len(np.unique(computed_preds[:, -1])) == 1
# raw scores will probably be different, but at least check that all predicted classes are the same
pred_classes = np.argmax(np.asarray(computed_preds.todense()), axis=1)
local_pred_classes = np.argmax(np.asarray(local_preds_with_contrib[i].todense()), axis=1)
pred_classes = np.argmax(computed_preds.toarray(), axis=1)
local_pred_classes = np.argmax(local_preds_with_contrib[i].toarray(), axis=1)
np.testing.assert_array_equal(pred_classes, local_pred_classes)
return

preds_with_contrib = preds_with_contrib.compute()
if output.startswith('scipy'):
preds_with_contrib = np.asarray(preds_with_contrib.todense())
preds_with_contrib = preds_with_contrib.toarray()

# be sure LightGBM actually used at least one categorical column,
# and that it was correctly treated as a categorical feature
Expand All @@ -439,7 +439,7 @@ def test_classifier_pred_contrib(output, task, cluster):
assert preds_with_contrib.shape == local_preds_with_contrib.shape

if num_classes == 2:
assert np.unique(preds_with_contrib[:, num_features]).shape[0] == 1
assert len(np.unique(preds_with_contrib[:, num_features])) == 1
else:
for i in range(num_classes):
base_value_col = num_features * (i + 1) + i
Expand Down Expand Up @@ -621,7 +621,7 @@ def test_regressor_pred_contrib(output, cluster):
local_preds_with_contrib = local_regressor.predict(X, pred_contrib=True)

if output == "scipy_csr_matrix":
preds_with_contrib = np.asarray(preds_with_contrib.todense())
preds_with_contrib = preds_with_contrib.toarray()

# contrib outputs for distributed training are different than from local training, so we can just test
# that the output has the right shape and base values are in the right position
Expand Down

0 comments on commit 57edeaa

Please sign in to comment.