From b6167cd2ff2284d69af82601d863e0e69b849abf Mon Sep 17 00:00:00 2001 From: capybara <61287624+jose-moralez@users.noreply.github.com> Date: Thu, 25 Feb 2021 02:40:38 -0600 Subject: [PATCH] [dask] Use client to persist collections (#6722) Co-authored-by: fis --- python-package/xgboost/dask.py | 4 ++-- tests/python/test_with_dask.py | 23 ++++++++++++++++++++--- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 26bec393f613..60c7ae29057d 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -303,11 +303,11 @@ def check_columns(parts: Any) -> None: ' of columns for your dask Array explicitly. e.g.' \ ' chunks=(partition_size, X.shape[1])' - data = data.persist() + data = client.persist(data) for meta in [label, weights, base_margin, label_lower_bound, label_upper_bound]: if meta is not None: - meta = meta.persist() + meta = client.persist(meta) # Breaking data into partitions, a trick borrowed from dask_xgboost. # `to_delayed` downgrades high-level objects into numpy or pandas diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index 1d531b08537d..dec7bb75c52a 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -563,9 +563,7 @@ async def run_from_dask_array_asyncio(scheduler_address: str) -> xgb.dask.TrainR await client.compute(with_X)) np.testing.assert_allclose(await client.compute(with_m), await client.compute(inplace)) - - client.shutdown() - return output + return output async def run_dask_regressor_asyncio(scheduler_address: str) -> None: @@ -647,6 +645,25 @@ def test_with_asyncio() -> None: asyncio.run(run_dask_classifier_asyncio(address)) +async def generate_concurrent_trainings() -> None: + async def train(): + async with LocalCluster(n_workers=2, + threads_per_worker=1, + asynchronous=True, + dashboard_address=0) as cluster: + async with Client(cluster, asynchronous=True) as client: + X, y, w = generate_array(with_weights=True) + dtrain = await DaskDMatrix(client, X, y, weight=w) + dvalid = await DaskDMatrix(client, X, y, weight=w) + output = await xgb.dask.train(client, {}, dtrain=dtrain) + await xgb.dask.predict(client, output, data=dvalid) + await asyncio.gather(train(), train()) + + +def test_concurrent_trainings() -> None: + asyncio.run(generate_concurrent_trainings()) + + def test_predict(client: "Client") -> None: X, y, _ = generate_array() dtrain = DaskDMatrix(client, X, y)