Skip to content

Commit

Permalink
[ci] remove output parametrization from two Dask tests (#4123)
Browse files: browse the repository at this point in the history
* Update test_dask.py

* Update test_dask.py
  • Loading branch information
StrikerRUS authored Mar 27, 2021
1 parent e98da99 commit d32ee23
Showing 1 changed file with 6 additions and 16 deletions.
22 changes (6 additions & 16 deletions) in tests/python_package_test/test_dask.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -994,25 +994,20 @@ def collection_to_single_partition(collection):


@pytest.mark.parametrize('task', tasks)
@pytest.mark.parametrize('output', data_output)
def test_network_params_not_required_but_respected_if_given(client, task, output, listen_port):
if task == 'ranking' and output == 'scipy_csr_matrix':
pytest.skip('LGBMRanker is not currently tested on sparse matrices')

def test_network_params_not_required_but_respected_if_given(client, task, listen_port):
client.wait_for_workers(2)

_, _, _, _, dX, dy, _, dg = _create_data(
objective=task,
output=output,
output='array',
chunk_size=10,
group=None
)

dask_model_factory = task_to_dask_factory[task]

# rebalance data to be sure that each worker has a piece of the data
if output == 'array':
client.rebalance()
client.rebalance()

# model 1 - no network parameters given
dask_model1 = dask_model_factory(
Expand Down Expand Up @@ -1059,24 +1054,19 @@ def test_network_params_not_required_but_respected_if_given(client, task, output


@pytest.mark.parametrize('task', tasks)
@pytest.mark.parametrize('output', data_output)
def test_machines_should_be_used_if_provided(task, output):
if task == 'ranking' and output == 'scipy_csr_matrix':
pytest.skip('LGBMRanker is not currently tested on sparse matrices')

def test_machines_should_be_used_if_provided(task):
with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
_, _, _, _, dX, dy, _, dg = _create_data(
objective=task,
output=output,
output='array',
chunk_size=10,
group=None
)

dask_model_factory = task_to_dask_factory[task]

# rebalance data to be sure that each worker has a piece of the data
if output == 'array':
client.rebalance()
client.rebalance()

n_workers = len(client.scheduler_info()['workers'])
assert n_workers > 1
Expand Down

0 comments on commit d32ee23

Please sign in to comment.