Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dask] Accept user defined tracker address. #5408

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions demo/dask/cpu_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from dask.distributed import Client
from dask.distributed import LocalCluster
from dask import array as da
import logging


def main(client):
# generate some random data for demonstration
logging.basicConfig(level=logging.INFO)
m = 100000
n = 100
X = da.random.random(size=(m, n), chunks=100)
Expand Down
2 changes: 1 addition & 1 deletion doc/parameter.rst
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ Parameters for Tree Booster
See tutorial for more information

Additional parameters for `hist` and `gpu_hist` tree method
================================================
===========================================================

* ``single_precision_histogram``, [default=``false``]

Expand Down
4 changes: 2 additions & 2 deletions doc/tutorials/saving_model.rst
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ or in R:

Will print out something similar to (not actual output as it's too long for demonstration):

.. code-block:: json
.. code-block:: js

{
"Learner": {
Expand Down Expand Up @@ -201,7 +201,7 @@ Difference between saving model and dumping model
XGBoost has a function called ``dump_model`` in the Booster object, which lets you export
the model in a readable format like ``text``, ``json`` or ``dot`` (graphviz). The primary
use case for it is for model interpretation or visualization, and is not supposed to be
loaded back to XGBoost. The JSON version has a `schema
loaded back to XGBoost. The JSON version has a `Schema
<https://github.com/dmlc/xgboost/blob/master/doc/dump.schema>`_. See next section for
more info.

Expand Down
47 changes: 39 additions & 8 deletions python-package/xgboost/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@
from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
from .sklearn import xgboost_model_doc

try:
from distributed import Client
except ImportError:
Client = None

# Current status is considered as initial support, many features are
# not properly supported yet.
#
Expand All @@ -47,10 +52,14 @@
LOGGER = logging.getLogger('[xgboost.dask]')


def _start_tracker(host, n_workers):
def _start_tracker(host, port, n_workers):
"""Start Rabit tracker """
env = {'DMLC_NUM_WORKER': n_workers}
rabit_context = RabitTracker(hostIP=host, nslave=n_workers)
if port:
rabit_context = RabitTracker(hostIP=host, port=port, port_end=port+1,
nslave=n_workers)
else:
rabit_context = RabitTracker(hostIP=host, nslave=n_workers)
env.update(rabit_context.slave_envs())

rabit_context.start(n_workers)
Expand Down Expand Up @@ -357,12 +366,20 @@ def get_worker_data_shape(self, worker):
return (rows, cols)


def _get_rabit_args(worker_map, client):
def _get_rabit_args(worker_map, client: Client, host_ip=None, port=None):
'''Get rabit context arguments from data distribution in DaskDMatrix.'''
host = distributed_comm.get_address_host(client.scheduler.address)
msg = 'Please provide both IP and port'
assert (host_ip and port) or (host_ip is None and port is None), msg

env = client.run_on_scheduler(_start_tracker, host.strip('/:'),
len(worker_map))
if host_ip:
LOGGER.info('Running tracker on: %s, %s', host_ip, str(port))
env = client.run_on_scheduler(_start_tracker, host_ip, port,
len(worker_map))
else:
host = distributed_comm.get_address_host(client.scheduler.address)
LOGGER.info('Running tracker on: %s', host.strip('/:'))
env = client.run_on_scheduler(_start_tracker, host.strip('/:'), port,
len(worker_map))
rabit_args = [('%s=%s' % item).encode() for item in env.items()]
return rabit_args

Expand All @@ -373,7 +390,8 @@ def _get_rabit_args(worker_map, client):
# evaluation history is instead returned.


def train(client, params, dtrain, *args, evals=(), **kwargs):
def train(client, params, dtrain, *args, evals=(), tracker_ip=None,
tracker_port=None, **kwargs):
'''Train XGBoost model.

.. versionadded:: 1.0.0
Expand All @@ -383,6 +401,19 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
client: dask.distributed.Client
Specify the dask client used for training. Use default client
returned from dask if it's set to None.

tracker_ip:
Address for rabit tracker that runs on dask scheduler. Use
`client.scheduler.address` if unspecified.

.. versionadded:: 1.2.0

tracker_port:
Port for the tracker. Search for available ports automatically if
unspecified.

.. versionadded:: 1.2.0

\\*\\*kwargs:
Other parameters are the same as `xgboost.train` except for
`evals_result`, which is returned as part of function return value
Expand Down Expand Up @@ -410,7 +441,7 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):

workers = list(_get_client_workers(client).keys())

rabit_args = _get_rabit_args(workers, client)
rabit_args = _get_rabit_args(workers, client, tracker_ip, tracker_port)

def dispatched_train(worker_addr):
'''Perform training on a single worker.'''
Expand Down
24 changes: 24 additions & 0 deletions tests/python/test_with_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import dask.dataframe as dd
import dask.array as da
from xgboost.dask import DaskDMatrix
from dask.distributed import comm
except ImportError:
LocalCluster = None
Client = None
Expand Down Expand Up @@ -286,3 +287,26 @@ def test_empty_dmatrix_approx():
with Client(cluster) as client:
parameters = {'tree_method': 'approx'}
run_empty_dmatrix(client, parameters)


def test_explicit_rabit_tracker():
    """Train with a user-supplied tracker address and verify predictions.

    The tracker host is taken from the scheduler address of a local cluster
    and passed to ``xgb.dask.train`` through ``tracker_ip`` /
    ``tracker_port``.  The distributed prediction is then compared with the
    prediction of the returned booster on the fully materialised data.
    """
    # Local import: the module's import header is outside this hunk.
    import socket

    with LocalCluster() as cluster:
        with Client(cluster) as client:
            X, y = generate_array()
            host = comm.get_address_host(client.scheduler.address)

            # Ask the OS for a currently unused port rather than hard-coding
            # one; a fixed port makes the test fail whenever something else
            # on the machine is already listening on it.
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                sock.bind(('', 0))
                port = sock.getsockname()[1]

            dtrain = xgb.dask.DaskDMatrix(client, X, y)

            out = xgb.dask.train(client, {'tree_method': 'hist'}, dtrain,
                                 tracker_ip=host, tracker_port=port)
            prediction = xgb.dask.predict(client, out, dtrain)
            assert prediction.shape[0] == kRows

            # The dask interface returns a lazy array; materialise it before
            # comparing against the single-node result.
            assert isinstance(prediction, da.Array)
            prediction = prediction.compute()

            booster = out['booster']
            single_node_predt = booster.predict(
                xgb.DMatrix(X.compute())
            )
            np.testing.assert_allclose(prediction, single_node_predt)