Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dask] Change document to avoid using default import. #9742

Merged
merged 1 commit into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions demo/dask/cpu_survival.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import dask.dataframe as dd
from dask.distributed import Client, LocalCluster

import xgboost as xgb
from xgboost import dask as dxgb
from xgboost.dask import DaskDMatrix


Expand Down Expand Up @@ -48,14 +48,14 @@ def main(client):
"lambda": 0.01,
"alpha": 0.02,
}
output = xgb.dask.train(
output = dxgb.train(
client, params, dtrain, num_boost_round=100, evals=[(dtrain, "train")]
)
bst = output["booster"]
history = output["history"]

# you can pass output directly into `predict` too.
prediction = xgb.dask.predict(client, bst, dtrain)
prediction = dxgb.predict(client, bst, dtrain)
print("Evaluation history: ", history)

# Uncomment the following line to save the model to the disk
Expand Down
6 changes: 3 additions & 3 deletions demo/dask/cpu_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dask import array as da
from dask.distributed import Client, LocalCluster

import xgboost as xgb
from xgboost import dask as dxgb
from xgboost.dask import DaskDMatrix


Expand All @@ -25,7 +25,7 @@ def main(client):
# distributed version of train returns a dictionary containing the
# resulting booster and evaluation history obtained from
# evaluation metrics.
output = xgb.dask.train(
output = dxgb.train(
client,
{"verbosity": 1, "tree_method": "hist"},
dtrain,
Expand All @@ -36,7 +36,7 @@ def main(client):
history = output["history"]

# you can pass output directly into `predict` too.
prediction = xgb.dask.predict(client, bst, dtrain)
prediction = dxgb.predict(client, bst, dtrain)
print("Evaluation history:", history)
return prediction

Expand Down
3 changes: 2 additions & 1 deletion demo/dask/dask_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from dask_ml.model_selection import train_test_split

import xgboost as xgb
import xgboost.dask as dxgb
from xgboost.dask import DaskDMatrix


Expand Down Expand Up @@ -61,7 +62,7 @@ def main(client):
dtrain = DaskDMatrix(client, X_train, y_train)
dtest = DaskDMatrix(client, X_test, y_test)

output = xgb.dask.train(
output = dxgb.train(
client,
{
"verbosity": 1,
Expand Down
9 changes: 4 additions & 5 deletions demo/dask/gpu_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from dask.distributed import Client
from dask_cuda import LocalCUDACluster

import xgboost as xgb
from xgboost import dask as dxgb
from xgboost.dask import DaskDMatrix

Expand All @@ -21,7 +20,7 @@ def using_dask_matrix(client: Client, X: da.Array, y: da.Array) -> da.Array:
# Use train method from xgboost.dask instead of xgboost. This distributed version
# of train returns a dictionary containing the resulting booster and evaluation
# history obtained from evaluation metrics.
output = xgb.dask.train(
output = dxgb.train(
client,
{
"verbosity": 2,
Expand All @@ -37,7 +36,7 @@ def using_dask_matrix(client: Client, X: da.Array, y: da.Array) -> da.Array:
history = output["history"]

# you can pass output directly into `predict` too.
prediction = xgb.dask.predict(client, bst, dtrain)
prediction = dxgb.predict(client, bst, dtrain)
print("Evaluation history:", history)
return prediction

Expand All @@ -56,14 +55,14 @@ def using_quantile_device_dmatrix(client: Client, X: da.Array, y: da.Array) -> d
# be used for anything else other than training unless a reference is specified. See
# the `ref` argument of `DaskQuantileDMatrix`.
dtrain = dxgb.DaskQuantileDMatrix(client, X, y)
output = xgb.dask.train(
output = dxgb.train(
client,
{"verbosity": 2, "tree_method": "hist", "device": "cuda"},
dtrain,
num_boost_round=4,
)

prediction = xgb.dask.predict(client, output, X)
prediction = dxgb.predict(client, output, X)
return prediction


Expand Down
4 changes: 2 additions & 2 deletions demo/dask/sklearn_cpu_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dask import array as da
from dask.distributed import Client, LocalCluster

import xgboost
from xgboost import dask as dxgb


def main(client):
Expand All @@ -16,7 +16,7 @@ def main(client):
X = da.random.random((m, n), partition_size)
y = da.random.random(m, partition_size)

regressor = xgboost.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
regressor = dxgb.DaskXGBRegressor(verbosity=1, n_estimators=2)
regressor.set_params(tree_method="hist")
# assigning client here is optional
regressor.client = client
Expand Down
4 changes: 2 additions & 2 deletions demo/dask/sklearn_gpu_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# It's recommended to use dask_cuda for GPU assignment
from dask_cuda import LocalCUDACluster

import xgboost
from xgboost import dask as dxgb


def main(client):
Expand All @@ -20,7 +20,7 @@ def main(client):
X = da.random.random((m, n), partition_size)
y = da.random.random(m, partition_size)

regressor = xgboost.dask.DaskXGBRegressor(verbosity=1)
regressor = dxgb.DaskXGBRegressor(verbosity=1)
# set the device to CUDA
regressor.set_params(tree_method="hist", device="cuda")
# assigning client here is optional
Expand Down
67 changes: 35 additions & 32 deletions doc/tutorials/dask.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ on a dask cluster:

.. code-block:: python

import xgboost as xgb
from xgboost import dask as dxgb

import dask.array as da
import dask.distributed

Expand All @@ -53,11 +54,11 @@ on a dask cluster:
X = da.random.random(size=(num_obs, num_features), chunks=(1000, num_features))
y = da.random.random(size=(num_obs, 1), chunks=(1000, 1))

dtrain = xgb.dask.DaskDMatrix(client, X, y)
dtrain = dxgb.DaskDMatrix(client, X, y)
# or
# dtrain = xgb.dask.DaskQuantileDMatrix(client, X, y)
# dtrain = dxgb.DaskQuantileDMatrix(client, X, y)

output = xgb.dask.train(
output = dxgb.train(
client,
{"verbosity": 2, "tree_method": "hist", "objective": "reg:squarederror"},
dtrain,
Expand Down Expand Up @@ -87,25 +88,27 @@ returns a model and the computation history as a Python dictionary:

.. code-block:: python

{'booster': Booster,
'history': dict}
{
"booster": Booster,
"history": dict,
}

For prediction, pass the ``output`` returned by ``train`` into :py:func:`xgboost.dask.predict`:

.. code-block:: python

prediction = xgb.dask.predict(client, output, dtrain)
prediction = dxgb.predict(client, output, dtrain)
# Or equivalently, pass ``output['booster']``:
prediction = xgb.dask.predict(client, output['booster'], dtrain)
prediction = dxgb.predict(client, output['booster'], dtrain)

Eliminating the construction of DaskDMatrix is also possible; this can make the
computation a bit faster when meta information like ``base_margin`` is not needed:

.. code-block:: python

prediction = xgb.dask.predict(client, output, X)
prediction = dxgb.predict(client, output, X)
# Use inplace version.
prediction = xgb.dask.inplace_predict(client, output, X)
prediction = dxgb.inplace_predict(client, output, X)

Here ``prediction`` is a dask ``Array`` object containing predictions from the model if the input
is a ``DaskDMatrix`` or ``da.Array``. When putting a dask collection directly into the
Expand Down Expand Up @@ -134,22 +137,22 @@ both memory usage and prediction time.
.. code-block:: python

# dtrain is the DaskDMatrix defined above.
prediction = xgb.dask.predict(client, booster, dtrain)
prediction = dxgb.predict(client, booster, dtrain)

or equivalently:

.. code-block:: python

# where X is a dask DataFrame or dask Array.
prediction = xgb.dask.predict(client, booster, X)
prediction = dxgb.predict(client, booster, X)

Also for inplace prediction:

.. code-block:: python

# where X is a dask DataFrame or dask Array backed by cupy or cuDF.
booster.set_param({"device": "cuda"})
prediction = xgb.dask.inplace_predict(client, booster, X)
prediction = dxgb.inplace_predict(client, booster, X)

When input is ``da.Array`` object, output is always ``da.Array``. However, if the input
type is ``dd.DataFrame``, output can be ``dd.Series``, ``dd.DataFrame`` or ``da.Array``,
Expand All @@ -174,7 +177,7 @@ One simple optimization for running consecutive predictions is using
futures = []
for X in dataset:
# Here we pass in a future instead of concrete booster
shap_f = xgb.dask.predict(client, booster_f, X, pred_contribs=True)
shap_f = dxgb.predict(client, booster_f, X, pred_contribs=True)
futures.append(shap_f)

results = client.gather(futures)
Expand All @@ -186,7 +189,7 @@ Scikit-Learn wrapper object:

.. code-block:: python

cls = xgb.dask.DaskXGBClassifier()
cls = dxgb.DaskXGBClassifier()
cls.fit(X, y)

booster = cls.get_booster()
Expand All @@ -207,12 +210,12 @@ collection.
.. code-block:: python

from distributed import LocalCluster, Client
import xgboost as xgb
from xgboost import dask as dxgb


def main(client: Client) -> None:
X, y = load_data()
clf = xgb.dask.DaskXGBClassifier(n_estimators=100, tree_method="hist")
clf = dxgb.DaskXGBClassifier(n_estimators=100, tree_method="hist")
clf.client = client # assign the client
clf.fit(X, y, eval_set=[(X, y)])
proba = clf.predict_proba(X)
Expand Down Expand Up @@ -242,7 +245,7 @@ In the example below, a ``KubeCluster`` is used for `deploying Dask on Kubernete

from dask_kubernetes import KubeCluster # Need to install the ``dask-kubernetes`` package
from dask.distributed import Client
import xgboost as xgb
from xgboost import dask as dxgb
import dask
import dask.array as da

Expand All @@ -265,7 +268,7 @@ In the example below, a ``KubeCluster`` is used for `deploying Dask on Kubernete
X = da.random.random(size=(m, n), chunks=100)
y = da.random.random(size=(m, ), chunks=100)

regressor = xgb.dask.DaskXGBRegressor(n_estimators=10, missing=0.0)
regressor = dxgb.DaskXGBRegressor(n_estimators=10, missing=0.0)
regressor.client = client
regressor.set_params(tree_method='hist', device="cuda")
regressor.fit(X, y, eval_set=[(X, y)])
Expand Down Expand Up @@ -298,7 +301,7 @@ threads in each process for training. But if ``nthread`` parameter is set:

.. code-block:: python

output = xgb.dask.train(
output = dxgb.train(
client,
{"verbosity": 1, "nthread": 8, "tree_method": "hist"},
dtrain,
Expand Down Expand Up @@ -330,12 +333,12 @@ Functional interface:

async with dask.distributed.Client(scheduler_address, asynchronous=True) as client:
X, y = generate_array()
m = await xgb.dask.DaskDMatrix(client, X, y)
output = await xgb.dask.train(client, {}, dtrain=m)
m = await dxgb.DaskDMatrix(client, X, y)
output = await dxgb.train(client, {}, dtrain=m)

with_m = await xgb.dask.predict(client, output, m)
with_X = await xgb.dask.predict(client, output, X)
inplace = await xgb.dask.inplace_predict(client, output, X)
with_m = await dxgb.predict(client, output, m)
with_X = await dxgb.predict(client, output, X)
inplace = await dxgb.inplace_predict(client, output, X)

# Use ``client.compute`` instead of the ``compute`` method from dask collection
print(await client.compute(with_m))
Expand All @@ -349,7 +352,7 @@ actual computation will return a coroutine and hence require awaiting:

async with dask.distributed.Client(scheduler_address, asynchronous=True) as client:
X, y = generate_array()
regressor = await xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
regressor = await dxgb.DaskXGBRegressor(verbosity=1, n_estimators=2)
regressor.set_params(tree_method='hist') # trivial method, synchronous operation
regressor.client = client # accessing attribute, synchronous operation
regressor = await regressor.fit(X, y, eval_set=[(X, y)])
Expand All @@ -371,7 +374,7 @@ To enable early stopping, pass one or more validation sets containing ``DaskDMat
.. code-block:: python

import dask.array as da
import xgboost as xgb
from xgboost import dask as dxgb

num_rows = 1e6
num_features = 100
Expand All @@ -398,19 +401,19 @@ To enable early stopping, pass one or more validation sets containing ``DaskDMat
chunks=(rows_per_chunk, 1)
)

dtrain = xgb.dask.DaskDMatrix(
dtrain = dxgb.DaskDMatrix(
client=client,
data=data,
label=labels
)

dvalid = xgb.dask.DaskDMatrix(
dvalid = dxgb.DaskDMatrix(
client=client,
data=X_eval,
label=y_eval
)

result = xgb.dask.train(
result = dxgb.train(
client=client,
params={
"objective": "reg:squarederror",
Expand All @@ -421,7 +424,7 @@ To enable early stopping, pass one or more validation sets containing ``DaskDMat
early_stopping_rounds=3
)

When validation sets are provided to ``xgb.dask.train()`` in this way, the model object returned by ``xgb.dask.train()`` contains a history of evaluation metrics for each validation set, across all boosting rounds.
When validation sets are provided to :py:func:`xgboost.dask.train` in this way, the model object returned by :py:func:`xgboost.dask.train` contains a history of evaluation metrics for each validation set, across all boosting rounds.

.. code-block:: python

Expand Down Expand Up @@ -463,7 +466,7 @@ interface, including callback functions, custom evaluation metric and objective:
save_best=True,
)

booster = xgb.dask.train(
booster = dxgb.train(
client,
params={
"objective": "binary:logistic",
Expand Down
Loading