From e8830c6548c702b64a2cef7ac7e17d739cc1fb30 Mon Sep 17 00:00:00 2001 From: Nikita Titov Date: Sat, 27 Mar 2021 23:12:52 +0300 Subject: [PATCH 1/2] Update test_dask.py --- tests/python_package_test/test_dask.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index b2d915e55987..3343e4563677 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -994,16 +994,12 @@ def collection_to_single_partition(collection): @pytest.mark.parametrize('task', tasks) -@pytest.mark.parametrize('output', data_output) -def test_network_params_not_required_but_respected_if_given(client, task, output, listen_port): - if task == 'ranking' and output == 'scipy_csr_matrix': - pytest.skip('LGBMRanker is not currently tested on sparse matrices') - +def test_network_params_not_required_but_respected_if_given(client, task, listen_port): client.wait_for_workers(2) _, _, _, _, dX, dy, _, dg = _create_data( objective=task, - output=output, + output='array', chunk_size=10, group=None ) @@ -1011,8 +1007,7 @@ def test_network_params_not_required_but_respected_if_given(client, task, output dask_model_factory = task_to_dask_factory[task] # rebalance data to be sure that each worker has a piece of the data - if output == 'array': - client.rebalance() + client.rebalance() # model 1 - no network parameters given dask_model1 = dask_model_factory( From 04bb1786ba1bd3e70a0845e6cff883988405f64b Mon Sep 17 00:00:00 2001 From: Nikita Titov Date: Sat, 27 Mar 2021 23:14:30 +0300 Subject: [PATCH 2/2] Update test_dask.py --- tests/python_package_test/test_dask.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 3343e4563677..b1f7ff5605a9 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -1054,15 +1054,11 @@ def test_network_params_not_required_but_respected_if_given(client, task, listen @pytest.mark.parametrize('task', tasks) -@pytest.mark.parametrize('output', data_output) -def test_machines_should_be_used_if_provided(task, output): - if task == 'ranking' and output == 'scipy_csr_matrix': - pytest.skip('LGBMRanker is not currently tested on sparse matrices') - +def test_machines_should_be_used_if_provided(task): with LocalCluster(n_workers=2) as cluster, Client(cluster) as client: _, _, _, _, dX, dy, _, dg = _create_data( objective=task, - output=output, + output='array', chunk_size=10, group=None ) @@ -1070,8 +1066,7 @@ def test_machines_should_be_used_if_provided(task, output): dask_model_factory = task_to_dask_factory[task] # rebalance data to be sure that each worker has a piece of the data - if output == 'array': - client.rebalance() + client.rebalance() n_workers = len(client.scheduler_info()['workers']) assert n_workers > 1