From a270e38bb6f797b57172db50454a6099d7f696a8 Mon Sep 17 00:00:00 2001
From: fis
Date: Thu, 5 Nov 2020 15:35:23 +0800
Subject: [PATCH] Use submit instead.

---
 python-package/xgboost/dask.py | 150 ++++++++++++++++-----------------
 1 file changed, 74 insertions(+), 76 deletions(-)

diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py
index 9791f9f30d35..f7b1c026d6ed 100644
--- a/python-package/xgboost/dask.py
+++ b/python-package/xgboost/dask.py
@@ -330,7 +330,7 @@ def append_meta(m_parts, name: str):
 
         return self
 
-    def create_fn_args(self):
+    def create_fn_args(self, worker_addr: str):
         '''Create a dictionary of objects that can be pickled for function
         arguments.
 
@@ -339,57 +339,56 @@ def create_fn_args(self):
                 'feature_types': self.feature_types,
                 'meta_names': self.meta_names,
                 'missing': self.missing,
-                'worker_map': self.worker_map,
+                'worker_map': self.worker_map.get(worker_addr, None),
                 'is_quantile': self.is_quantile}
 
 
-def _get_worker_parts_ordered(meta_names, worker_map, partition_order, worker):
-    list_of_parts: List[tuple] = worker_map[worker.address]
+def _get_worker_parts_ordered(meta_names, list_of_keys, list_of_parts, partition_order):
     # List of partitions like: [(x3, y3, w3, m3, ..), ..], order is not preserved.
     assert isinstance(list_of_parts, list)
 
-    with distributed.worker_client() as client:
-        list_of_parts_value = client.gather(list_of_parts)
-
-        result = []
-
-        for i, part in enumerate(list_of_parts):
-            data = list_of_parts_value[i][0]
-            labels = None
-            weights = None
-            base_margin = None
-            label_lower_bound = None
-            label_upper_bound = None
-            # Iterate through all possible meta info, brings small overhead as in xgboost
-            # there are constant number of meta info available.
-            for j, blob in enumerate(list_of_parts_value[i][1:]):
-                if meta_names[j] == 'labels':
-                    labels = blob
-                elif meta_names[j] == 'weights':
-                    weights = blob
-                elif meta_names[j] == 'base_margin':
-                    base_margin = blob
-                elif meta_names[j] == 'label_lower_bound':
-                    label_lower_bound = blob
-                elif meta_names[j] == 'label_upper_bound':
-                    label_upper_bound = blob
-                else:
-                    raise ValueError('Unknown metainfo:', meta_names[j])
-
-            if partition_order:
-                result.append((data, labels, weights, base_margin, label_lower_bound,
-                               label_upper_bound, partition_order[part.key]))
+    list_of_parts_value = list_of_parts
+
+    result = []
+
+    for i, part in enumerate(list_of_parts):
+        data = list_of_parts_value[i][0]
+        labels = None
+        weights = None
+        base_margin = None
+        label_lower_bound = None
+        label_upper_bound = None
+        # Iterate through all possible meta info, brings small overhead as in xgboost
+        # there are constant number of meta info available.
+        for j, blob in enumerate(list_of_parts_value[i][1:]):
+            if meta_names[j] == 'labels':
+                labels = blob
+            elif meta_names[j] == 'weights':
+                weights = blob
+            elif meta_names[j] == 'base_margin':
+                base_margin = blob
+            elif meta_names[j] == 'label_lower_bound':
+                label_lower_bound = blob
+            elif meta_names[j] == 'label_upper_bound':
+                label_upper_bound = blob
             else:
-                result.append((data, labels, weights, base_margin, label_lower_bound,
-                               label_upper_bound))
-    return result
+                raise ValueError('Unknown metainfo:', meta_names[j])
+
+        if partition_order:
+            result.append((data, labels, weights, base_margin, label_lower_bound,
+                           label_upper_bound, partition_order[list_of_keys[i]]))
+        else:
+            result.append((data, labels, weights, base_margin, label_lower_bound,
+                           label_upper_bound))
+    return result
 
 
 def _unzip(list_of_parts):
     return list(zip(*list_of_parts))
 
 
-def _get_worker_parts(worker_map, meta_names, worker):
-    partitions = _get_worker_parts_ordered(meta_names, worker_map, None, worker)
+def _get_worker_parts(worker_map, meta_names):
+    list_of_parts: List[tuple] = worker_map
+    partitions = _get_worker_parts_ordered(meta_names, None, list_of_parts, None)
     partitions = _unzip(partitions)
     return partitions
@@ -519,8 +518,8 @@ def __init__(self, client,
         self.max_bin = max_bin
         self.is_quantile = True
 
-    def create_fn_args(self):
-        args = super().create_fn_args()
+    def create_fn_args(self, worker_addr: str):
+        args = super().create_fn_args(worker_addr)
         args['max_bin'] = self.max_bin
         return args
 
@@ -529,11 +528,9 @@ def _create_device_quantile_dmatrix(feature_names, feature_types,
                                     meta_names, missing, worker_map,
                                     max_bin):
     worker = distributed.get_worker()
-    if worker.address not in set(worker_map.keys()):
-        msg = 'worker {address} has an empty DMatrix. ' \
-            'All workers associated with this DMatrix: {workers}'.format(
-                address=worker.address,
-                workers=set(worker_map.keys()))
+    if worker_map is None:
+        msg = 'worker {address} has an empty DMatrix. '.format(
+            address=worker.address)
         LOGGER.warning(msg)
         import cupy  # pylint: disable=import-error
         d = DeviceQuantileDMatrix(cupy.zeros((0, 0)),
@@ -544,7 +541,7 @@
 
     (data, labels, weights, base_margin,
      label_lower_bound, label_upper_bound) = _get_worker_parts(
-         worker_map, meta_names, worker)
+         worker_map, meta_names)
     it = DaskPartitionIter(data=data, label=labels, weight=weights,
                            base_margin=base_margin,
                            label_lower_bound=label_lower_bound,
@@ -569,11 +566,9 @@ def _create_dmatrix(feature_names, feature_types, meta_names, missing,
 
     '''
    worker = distributed.get_worker()
-    if worker.address not in set(worker_map.keys()):
-        msg = 'worker {address} has an empty DMatrix. ' \
-            'All workers associated with this DMatrix: {workers}'.format(
-                address=worker.address,
-                workers=set(worker_map.keys()))
+    list_of_parts = worker_map
+    if list_of_parts is None:
+        msg = 'worker {address} has an empty DMatrix. '.format(address=worker.address)
         LOGGER.warning(msg)
         d = DMatrix(numpy.empty((0, 0)),
                     feature_names=feature_names,
@@ -586,8 +581,7 @@ def concat_or_none(data):
         return concat(data)
 
     (data, labels, weights, base_margin,
-     label_lower_bound, label_upper_bound) = _get_worker_parts(
-         worker_map, meta_names, worker)
+     label_lower_bound, label_upper_bound) = _get_worker_parts(list_of_parts, meta_names)
 
     labels = concat_or_none(labels)
     weights = concat_or_none(weights)
@@ -640,8 +634,6 @@ def _get_workers_from_data(dtrain: DaskDMatrix, evals=()):
 
 async def _train_async(client, params, dtrain: DaskDMatrix, *args, evals=(),
                        early_stopping_rounds=None, **kwargs):
-    _assert_dask_support()
-    client: distributed.Client = _xgb_get_client(client)
     if 'evals_result' in kwargs.keys():
         raise ValueError(
             'evals_result is not supported in dask interface.',
@@ -700,13 +692,13 @@ def dispatched_train(worker_addr, rabit_args, dtrain_ref, evals_ref):
     # XGBoost is deterministic in most of the cases, which means train function is
     # supposed to be idempotent. One known exception is gblinear with shotgun updater.
     # We haven't been able to do a full verification so here we keep pure to be False.
-    futures = client.map(dispatched_train,
-                         workers,
-                         [_rabit_args] * len(workers),
-                         [dtrain.create_fn_args()] * len(workers),
-                         [evals] * len(workers),
-                         pure=False,
-                         workers=workers)
+    futures = []
+    for i in range(len(workers)):
+        evals_per_worker = [e.create_fn_args(workers[i]) for e, name in evals]
+        f = client.submit(dispatched_train, workers[i], _rabit_args,
+                          dtrain.create_fn_args(workers[i]), evals_per_worker)
+        futures.append(f)
+
     results = await client.gather(futures)
     return list(filter(lambda ret: ret is not None, results))[0]
 
@@ -802,14 +794,15 @@ def mapped_predict(partition, is_df):
 
     missing = data.missing
     meta_names = data.meta_names
 
-    def dispatched_predict(worker_id):
+    def dispatched_predict(worker_id, list_of_keys, list_of_parts):
         '''Perform prediction on each worker.'''
         LOGGER.info('Predicting on %d', worker_id)
-
+        list_of_keys = list_of_keys.compute()
         worker = distributed.get_worker()
         list_of_parts = _get_worker_parts_ordered(
-            meta_names, worker_map, partition_order, worker)
+            meta_names, list_of_keys, list_of_parts, partition_order)
         predictions = []
+        booster.set_param({'nthread': worker.nthreads})
         for parts in list_of_parts:
             (data, _, _, base_margin, _, _, order) = parts
@@ -828,17 +821,20 @@ def dispatched_predict(worker_id):
             columns = 1 if len(predt.shape) == 1 else predt.shape[1]
             ret = ((dask.delayed(predt), columns), order)
             predictions.append(ret)
+
         return predictions
 
-    def dispatched_get_shape(worker_id):
+    def dispatched_get_shape(worker_id, list_of_keys, list_of_parts):
         '''Get shape of data in each worker.'''
         LOGGER.info('Get shape on %d', worker_id)
-        worker = distributed.get_worker()
+        list_of_keys = list_of_keys.compute()
+        # worker = distributed.get_worker()
+        # list_of_parts = worker_map[worker.address]
         list_of_parts = _get_worker_parts_ordered(
             meta_names,
-            worker_map,
+            list_of_keys,
+            list_of_parts,
             partition_order,
-            worker
         )
         shapes = []
         for parts in list_of_parts:
@@ -850,12 +846,14 @@ async def map_function(func):
         '''Run function for each part of the data.'''
         futures = []
         for wid in range(len(worker_map)):
-            list_of_workers = [list(worker_map.keys())[wid]]
-            f = await client.submit(func, wid,
-                                    pure=False,
-                                    workers=list_of_workers)
+            worker_addr = list(worker_map.keys())[wid]
+            list_of_parts = worker_map[worker_addr]
+            list_of_keys = [part.key for part in list_of_parts]
+            f = await client.submit(func, worker_id=wid, list_of_keys=dask.delayed(list_of_keys),
+                                    list_of_parts=list_of_parts,
+                                    pure=False, workers=[worker_addr])
             futures.append(f)
-        # Get delayed objects
+        # # Get delayed objects
         results = await client.gather(futures)
         results = [t for l in results for t in l]  # flatten into 1 dim list
         # sort by order, l[0] is the delayed object, l[1] is its order
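
Note: the hunks above replace the single client.map() call, which shipped the same pickled
worker_map to every worker, with one client.submit() call per worker so that each worker only
receives the partitions keyed to its own address. The standalone sketch below illustrates that
dispatch pattern with dask.distributed only; it is not part of the patch, and the helper names
(run_on_worker, payload_for) are hypothetical.

from distributed import Client, LocalCluster


def run_on_worker(worker_id, payload):
    # Hypothetical per-worker task; in the patch this role is played by
    # dispatched_train / dispatched_predict consuming create_fn_args(worker_addr).
    return worker_id, len(payload)


def payload_for(addr, worker_map):
    # Only the partitions that live on `addr`, mirroring how
    # DaskDMatrix.create_fn_args(worker_addr) now returns worker_map.get(worker_addr).
    return worker_map.get(addr, [])


if __name__ == '__main__':
    client = Client(LocalCluster(n_workers=2, threads_per_worker=1))
    # Pretend every worker owns three partitions.
    worker_map = {addr: [0, 1, 2] for addr in client.scheduler_info()['workers']}

    futures = []
    for wid, addr in enumerate(worker_map):
        # pure=False: the task is not a pure function of its arguments.
        # workers=[addr]: pin the task to the worker that owns the data.
        fut = client.submit(run_on_worker, wid, payload_for(addr, worker_map),
                            pure=False, workers=[addr])
        futures.append(fut)

    print(client.gather(futures))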