diff --git a/include/xgboost/global_config.h b/include/xgboost/global_config.h index b8a2ceb5438e..70eca5c23108 100644 --- a/include/xgboost/global_config.h +++ b/include/xgboost/global_config.h @@ -15,7 +15,7 @@ namespace xgboost { class Json; struct GlobalConfiguration : public XGBoostParameter { - int verbosity; + int verbosity { 1 }; DMLC_DECLARE_PARAMETER(GlobalConfiguration) { DMLC_DECLARE_FIELD(verbosity) .set_range(0, 3) diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 67c4d6332c26..c449c0e75665 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -626,6 +626,7 @@ def _get_workers_from_data(dtrain: DaskDMatrix, evals=()): async def _train_async(client, + global_config, params, dtrain: DaskDMatrix, *args, @@ -639,7 +640,6 @@ async def _train_async(client, workers = list(_get_workers_from_data(dtrain, evals)) _rabit_args = await _get_rabit_args(len(workers), client) - _global_config = config.get_config() def dispatched_train(worker_addr, rabit_args, dtrain_ref, dtrain_idt, evals_ref): '''Perform training on a single worker. A local function prevents pickling. @@ -647,7 +647,7 @@ def dispatched_train(worker_addr, rabit_args, dtrain_ref, dtrain_idt, evals_ref) ''' LOGGER.info('Training on %s', str(worker_addr)) worker = distributed.get_worker() - with RabitContext(rabit_args), config.config_context(**_global_config): + with RabitContext(rabit_args), config.config_context(**global_config): local_dtrain = _dmatrix_from_list_of_parts(**dtrain_ref) local_evals = [] if evals_ref: @@ -735,8 +735,11 @@ def train(client, params, dtrain, *args, evals=(), early_stopping_rounds=None, ''' _assert_dask_support() client = _xgb_get_client(client) + # Get global configuration before transferring computation to another thread or + # process. + global_config = config.get_config() return client.sync( - _train_async, client, params, dtrain=dtrain, *args, evals=evals, + _train_async, client, global_config, params, dtrain=dtrain, *args, evals=evals, early_stopping_rounds=early_stopping_rounds, **kwargs) @@ -760,7 +763,7 @@ async def _direct_predict_impl(client, data, predict_fn): # pylint: disable=too-many-statements -async def _predict_async(client, model, data, missing=numpy.nan, **kwargs): +async def _predict_async(client, global_config, model, data, missing=numpy.nan, **kwargs): if isinstance(model, Booster): booster = model elif isinstance(model, dict): @@ -771,11 +774,9 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs): raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame], type(data))) - _global_config = config.get_config() - def mapped_predict(partition, is_df): worker = distributed.get_worker() - with config.config_context(**_global_config): + with config.config_context(**global_config): booster.set_param({'nthread': worker.nthreads}) m = DMatrix(partition, missing=missing, nthread=worker.nthreads) predt = booster.predict(m, validate_features=False, **kwargs) @@ -801,7 +802,7 @@ def mapped_predict(partition, is_df): def dispatched_predict(worker_id, list_of_orders, list_of_parts): '''Perform prediction on each worker.''' LOGGER.info('Predicting on %d', worker_id) - with config.config_context(**_global_config): + with config.config_context(**global_config): worker = distributed.get_worker() list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts) predictions = [] @@ -907,11 +908,12 @@ def predict(client, model, data, missing=numpy.nan, **kwargs): ''' _assert_dask_support() client = _xgb_get_client(client) - return client.sync(_predict_async, client, model, data, + global_config = config.get_config() + return client.sync(_predict_async, client, global_config, model, data, missing=missing, **kwargs) -async def _inplace_predict_async(client, model, data, +async def _inplace_predict_async(client, global_config, model, data, iteration_range=(0, 0), predict_type='value', missing=numpy.nan): @@ -927,6 +929,7 @@ async def _inplace_predict_async(client, model, data, def mapped_predict(data, is_df): worker = distributed.get_worker() + config.set_config(**global_config) booster.set_param({'nthread': worker.nthreads}) prediction = booster.inplace_predict( data, @@ -976,7 +979,9 @@ def inplace_predict(client, model, data, ''' _assert_dask_support() client = _xgb_get_client(client) - return client.sync(_inplace_predict_async, client, model=model, data=data, + global_config = config.get_config() + return client.sync(_inplace_predict_async, client, global_config, model=model, + data=data, iteration_range=iteration_range, predict_type=predict_type, missing=missing)