Skip to content

Commit

Permalink
[dask] Use client to persist collections (#6722)
Browse files Browse the repository at this point in the history

Co-authored-by: fis <jm.yuan@outlook.com>
  • Loading branch information
jose-moralez and trivialfis authored Feb 25, 2021
1 parent 9b530e5 commit b6167cd
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
4 changes: 2 additions & 2 deletions python-package/xgboost/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,11 +303,11 @@ def check_columns(parts: Any) -> None:
' of columns for your dask Array explicitly. e.g.' \
' chunks=(partition_size, X.shape[1])'

data = data.persist()
data = client.persist(data)
for meta in [label, weights, base_margin, label_lower_bound,
label_upper_bound]:
if meta is not None:
meta = meta.persist()
meta = client.persist(meta)
# Breaking data into partitions, a trick borrowed from dask_xgboost.

# `to_delayed` downgrades high-level objects into numpy or pandas
Expand Down
23 changes: 20 additions & 3 deletions tests/python/test_with_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,9 +563,7 @@ async def run_from_dask_array_asyncio(scheduler_address: str) -> xgb.dask.TrainR
await client.compute(with_X))
np.testing.assert_allclose(await client.compute(with_m),
await client.compute(inplace))

client.shutdown()
return output
return output


async def run_dask_regressor_asyncio(scheduler_address: str) -> None:
Expand Down Expand Up @@ -647,6 +645,25 @@ def test_with_asyncio() -> None:
asyncio.run(run_dask_classifier_asyncio(address))


async def generate_concurrent_trainings() -> None:
async def train():
async with LocalCluster(n_workers=2,
threads_per_worker=1,
asynchronous=True,
dashboard_address=0) as cluster:
async with Client(cluster, asynchronous=True) as client:
X, y, w = generate_array(with_weights=True)
dtrain = await DaskDMatrix(client, X, y, weight=w)
dvalid = await DaskDMatrix(client, X, y, weight=w)
output = await xgb.dask.train(client, {}, dtrain=dtrain)
await xgb.dask.predict(client, output, data=dvalid)
await asyncio.gather(train(), train())


def test_concurrent_trainings() -> None:
asyncio.run(generate_concurrent_trainings())


def test_predict(client: "Client") -> None:
X, y, _ = generate_array()
dtrain = DaskDMatrix(client, X, y)
Expand Down

0 comments on commit b6167cd

Please sign in to comment.