Add tracker info to training api.
trivialfis committed Jul 2, 2020
1 parent eb067c1 commit 94569e6
Showing 5 changed files with 64 additions and 12 deletions.
2 changes: 2 additions & 0 deletions demo/dask/cpu_training.py
@@ -3,10 +3,12 @@
 from dask.distributed import Client
 from dask.distributed import LocalCluster
 from dask import array as da
+import logging


 def main(client):
     # generate some random data for demonstration
+    logging.basicConfig(level=logging.INFO)
     m = 100000
     n = 100
     X = da.random.random(size=(m, n), chunks=100)
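The demo's only functional change is turning on INFO logging, which is what makes the new tracker message from the `[xgboost.dask]` logger visible. A minimal sketch of the same setup on its own:

```python
import logging

# INFO level is enough to surface the '[xgboost.dask]' logger's
# 'Running tracker on: ...' message added by this commit.
logging.basicConfig(level=logging.INFO)
```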
2 changes: 1 addition & 1 deletion doc/parameter.rst
@@ -226,7 +226,7 @@ Parameters for Tree Booster
 See tutorial for more information

 Additional parameters for ``hist`` and ``gpu_hist`` tree method
-================================================
+===========================================================

 * ``single_precision_histogram``, [default=``false``]

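For context, a hedged sketch of passing the parameter documented above; the data here is made up for illustration:

```python
import numpy as np
import xgboost as xgb

X, y = np.random.rand(500, 10), np.random.rand(500)
dtrain = xgb.DMatrix(X, label=y)
# single_precision_histogram only affects the hist/gpu_hist tree methods
params = {'tree_method': 'hist', 'single_precision_histogram': True}
booster = xgb.train(params, dtrain, num_boost_round=10)
```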
4 changes: 2 additions & 2 deletions doc/tutorials/saving_model.rst
@@ -121,7 +121,7 @@ or in R:
 Will print out something similar to (not actual output as it's too long for demonstration):

-.. code-block:: json
+.. code-block:: js

   {
     "Learner": {
@@ -201,7 +201,7 @@ Difference between saving model and dumping model
 XGBoost has a function called ``dump_model`` in the Booster object, which lets you export
 the model in a readable format like ``text``, ``json`` or ``dot`` (graphviz). The primary
 use case for it is model interpretation or visualization, and it is not supposed to be
-loaded back into XGBoost. The JSON version has a `schema
+loaded back into XGBoost. The JSON version has a `Schema
 <https://github.com/dmlc/xgboost/blob/master/doc/dump.schema>`_. See next section for
 more info.

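A short sketch of the distinction drawn above, with placeholder file names; `save_model` output can be loaded back, while the `dump_model` output is for inspection only:

```python
import numpy as np
import xgboost as xgb

X, y = np.random.rand(500, 10), np.random.rand(500)
booster = xgb.train({'tree_method': 'hist'}, xgb.DMatrix(X, label=y),
                    num_boost_round=5)

booster.save_model('model.json')                     # complete model, reloadable
booster.dump_model('dump.json', dump_format='json')  # readable dump, not for loading
```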
44 changes: 35 additions & 9 deletions python-package/xgboost/dask.py
@@ -47,10 +47,14 @@
 LOGGER = logging.getLogger('[xgboost.dask]')


-def _start_tracker(host, n_workers):
+def _start_tracker(host, port, n_workers):
     """Start Rabit tracker """
     env = {'DMLC_NUM_WORKER': n_workers}
-    rabit_context = RabitTracker(hostIP=host, nslave=n_workers)
+    if port:
+        rabit_context = RabitTracker(hostIP=host, port=port, port_end=port+1,
+                                     nslave=n_workers)
+    else:
+        rabit_context = RabitTracker(hostIP=host, nslave=n_workers)
     env.update(rabit_context.slave_envs())

     rabit_context.start(n_workers)
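A hedged sketch of what the two branches above amount to; the `RabitTracker` import path is an assumption, and `slave_envs()` is what supplies the `DMLC_TRACKER_URI`/`DMLC_TRACKER_PORT` entries the workers need:

```python
from xgboost.tracker import RabitTracker  # import path is an assumption

n_workers = 2
# explicit port: port_end=port + 1 leaves exactly one port to try
fixed = RabitTracker(hostIP='127.0.0.1', port=9091, port_end=9092,
                     nslave=n_workers)
# no port given: the tracker searches its default port range instead
auto = RabitTracker(hostIP='127.0.0.1', nslave=n_workers)

env = {'DMLC_NUM_WORKER': n_workers}
env.update(fixed.slave_envs())  # adds DMLC_TRACKER_URI and DMLC_TRACKER_PORT
```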
@@ -356,13 +360,21 @@ def get_worker_data_shape(self, worker):
             cols = c
         return (rows, cols)


-def _get_rabit_args(worker_map, client):
+from distributed import Client
+def _get_rabit_args(worker_map, client: Client, host_ip=None, port=None):
     '''Get rabit context arguments from data distribution in DaskDMatrix.'''
-    host = distributed_comm.get_address_host(client.scheduler.address)
+    msg = 'Please provide both IP and port'
+    assert (host_ip and port) or (host_ip is None and port is None), msg

-    env = client.run_on_scheduler(_start_tracker, host.strip('/:'),
-                                  len(worker_map))
+    if host_ip:
+        LOGGER.info('Running tracker on: %s, %s', host_ip, str(port))
+        env = client.run_on_scheduler(_start_tracker, host_ip, port,
+                                      len(worker_map))
+    else:
+        host = distributed_comm.get_address_host(client.scheduler.address)
+        LOGGER.info('Running tracker on: %s', host.strip('/:'))
+        env = client.run_on_scheduler(_start_tracker, host.strip('/:'), port,
+                                      len(worker_map))
     rabit_args = [('%s=%s' % item).encode() for item in env.items()]
     return rabit_args
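For reference, `rabit_args` is just the tracker environment flattened into `KEY=value` byte strings; with an assumed two-worker setup on a fixed port, the comprehension above yields:

```python
env = {'DMLC_NUM_WORKER': 2,
       'DMLC_TRACKER_URI': '127.0.0.1',
       'DMLC_TRACKER_PORT': 9091}
rabit_args = [('%s=%s' % item).encode() for item in env.items()]
# -> [b'DMLC_NUM_WORKER=2', b'DMLC_TRACKER_URI=127.0.0.1',
#     b'DMLC_TRACKER_PORT=9091']
```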

@@ -373,7 +385,8 @@ def _get_rabit_args(worker_map, client):
 # evaluation history is instead returned.


-def train(client, params, dtrain, *args, evals=(), **kwargs):
+def train(client, params, dtrain, *args, evals=(), tracker_ip=None,
+          tracker_port=None, **kwargs):
     '''Train XGBoost model.

     .. versionadded:: 1.0.0
@@ -383,6 +396,19 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
     client: dask.distributed.Client
         Specify the dask client used for training. Use default client
         returned from dask if it's set to None.
+    tracker_ip:
+        Address for rabit tracker that runs on dask scheduler. Use
+        `client.scheduler.address` if unspecified.
+
+        .. versionadded:: 1.2.0
+
+    tracker_port:
+        Port for the tracker. Search for available ports automatically if
+        unspecified.
+
+        .. versionadded:: 1.2.0
+
     \*\*kwargs:
         Other parameters are the same as `xgboost.train` except for
         `evals_result`, which is returned as part of function return value
@@ -410,7 +436,7 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
     workers = list(_get_client_workers(client).keys())

-    rabit_args = _get_rabit_args(workers, client)
+    rabit_args = _get_rabit_args(workers, client, tracker_ip, tracker_port)

     def dispatched_train(worker_addr):
         '''Perform training on a single worker.'''
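Putting the new parameters together, a minimal usage sketch (cluster size, data shapes, and the port are arbitrary; omitting `tracker_ip` and `tracker_port` keeps the previous automatic behaviour):

```python
import xgboost as xgb
from dask import array as da
from dask.distributed import Client, LocalCluster

with LocalCluster(n_workers=2) as cluster:
    with Client(cluster) as client:
        X = da.random.random((10000, 20), chunks=1000)
        y = da.random.random(10000, chunks=1000)
        dtrain = xgb.dask.DaskDMatrix(client, X, y)
        # pin the rabit tracker to an explicit address and port
        output = xgb.dask.train(client, {'tree_method': 'hist'}, dtrain,
                                num_boost_round=10,
                                tracker_ip='127.0.0.1', tracker_port=9091)
        booster, history = output['booster'], output['history']
```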
24 changes: 24 additions & 0 deletions tests/python/test_with_dask.py
@@ -15,6 +15,7 @@
     import dask.dataframe as dd
     import dask.array as da
     from xgboost.dask import DaskDMatrix
+    from dask.distributed import comm
 except ImportError:
     LocalCluster = None
     Client = None
@@ -286,3 +287,26 @@ def test_empty_dmatrix_approx():
         with Client(cluster) as client:
             parameters = {'tree_method': 'approx'}
             run_empty_dmatrix(client, parameters)
+
+
+def test_explicit_rabit_tracker():
+    with LocalCluster() as cluster:
+        with Client(cluster) as client:
+            X, y = generate_array()
+            host = comm.get_address_host(client.scheduler.address)
+            port = 9091
+            dtrain = xgb.dask.DaskDMatrix(client, X, y)
+
+            out = xgb.dask.train(client, {'tree_method': 'hist'}, dtrain,
+                                 tracker_ip=host, tracker_port=port)
+            prediction = xgb.dask.predict(client, out, dtrain)
+            assert prediction.shape[0] == kRows
+
+            assert isinstance(prediction, da.Array)
+            prediction = prediction.compute()
+
+            booster = out['booster']
+            single_node_predt = booster.predict(
+                xgb.DMatrix(X.compute())
+            )
+            np.testing.assert_allclose(prediction, single_node_predt)
