From 4bf23c239139aaa90b6bd0e3d36c7c2168c516ce Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Tue, 26 Jan 2021 02:08:22 +0800
Subject: [PATCH] Specify shape in prediction contrib and interaction. (#6614)

---
 doc/tutorials/dask.rst         |  13 +++-
 python-package/xgboost/core.py |   4 +-
 python-package/xgboost/dask.py | 115 ++++++++++++++++++++++---------
 tests/python/test_with_dask.py | 120 +++++++++++++++++----------------
 4 files changed, 160 insertions(+), 92 deletions(-)

diff --git a/doc/tutorials/dask.rst b/doc/tutorials/dask.rst
index 80754baaca53..4254d698007b 100644
--- a/doc/tutorials/dask.rst
+++ b/doc/tutorials/dask.rst
@@ -95,14 +95,21 @@ For prediction, pass the ``output`` returned by ``train`` into ``xgb.dask.predic
 
 .. code-block:: python
 
     prediction = xgb.dask.predict(client, output, dtrain)
+    # Or equivalently, pass ``output['booster']``:
+    prediction = xgb.dask.predict(client, output['booster'], dtrain)
 
-Or equivalently, pass ``output['booster']``:
+It is also possible to eliminate the construction of ``DaskDMatrix``; this can make the
+computation a bit faster when meta information like ``base_margin`` is not needed:
 
 .. code-block:: python
 
-    prediction = xgb.dask.predict(client, output['booster'], dtrain)
+    prediction = xgb.dask.predict(client, output, X)
+    # Use the inplace version.
+    prediction = xgb.dask.inplace_predict(client, output, X)
 
-Here ``prediction`` is a dask ``Array`` object containing predictions from model.
+Here ``prediction`` is a dask ``Array`` object containing predictions from the model if
+the input is a ``DaskDMatrix`` or ``da.Array``.  For ``dd.DataFrame``, the return value
+is a ``dd.Series``.
 
 Alternatively, XGBoost also implements the Scikit-Learn interface with
 ``DaskXGBClassifier`` and ``DaskXGBRegressor``.  See ``xgboost/demo/dask`` for more
 examples.
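For reference, the prediction paths documented above can be exercised end to end with a small script. This is a minimal sketch, not part of the patch; the cluster size, data shapes, and training parameters are arbitrary placeholders:

.. code-block:: python

    from distributed import Client, LocalCluster
    import dask.array as da
    import xgboost as xgb

    if __name__ == "__main__":
        with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
            # Synthetic data; chunks=(20, -1) keeps rows partitioned, columns whole.
            X = da.random.random((100, 10), chunks=(20, -1))
            y = da.random.random(100, chunks=20)
            dtrain = xgb.dask.DaskDMatrix(client, X, y)
            output = xgb.dask.train(
                client, {"tree_method": "hist"}, dtrain, num_boost_round=4
            )
            # All three calls yield equivalent predictions.
            p1 = xgb.dask.predict(client, output, dtrain)        # via DaskDMatrix
            p2 = xgb.dask.predict(client, output["booster"], X)  # skip DaskDMatrix
            p3 = xgb.dask.inplace_predict(client, output, X)     # inplace version
            print(p1.compute()[:5], p2.compute()[:5], p3.compute()[:5])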
diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 20a4728e3bf4..b37a908a9f99 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -190,7 +190,7 @@ def _check_call(ret): raise XGBoostError(py_str(_LIB.XGBGetLastError())) -def ctypes2numpy(cptr, length, dtype): +def ctypes2numpy(cptr, length, dtype) -> np.ndarray: """Convert a ctypes pointer array to a numpy array.""" NUMPY_TO_CTYPES_MAPPING = { np.float32: ctypes.c_float, @@ -1553,7 +1553,7 @@ def predict(self, ctypes.byref(preds))) preds = ctypes2numpy(preds, length.value, np.float32) if pred_leaf: - preds = preds.astype(np.int32) + preds = preds.astype(np.int32, copy=False) nrow = data.num_row() if preds.size != nrow and preds.size % nrow == 0: chunk_size = int(preds.size / nrow) diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 64d13bd800ed..ee16467afba0 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -964,22 +964,21 @@ async def _predict_async( pred_contribs: bool, approx_contribs: bool, pred_interactions: bool, - validate_features: bool + validate_features: bool, ) -> _DaskCollection: if isinstance(model, Booster): booster = model elif isinstance(model, dict): - booster = model['booster'] + booster = model["booster"] else: raise TypeError(_expect([Booster, dict], type(model))) if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)): - raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame], - type(data))) + raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame], type(data))) def mapped_predict(partition: Any, is_df: bool) -> Any: worker = distributed.get_worker() with config.config_context(**global_config): - booster.set_param({'nthread': worker.nthreads}) + booster.set_param({"nthread": worker.nthreads}) m = DMatrix(data=partition, missing=missing, nthread=worker.nthreads) predt = booster.predict( data=m, @@ -988,15 +987,16 @@ def mapped_predict(partition: Any, is_df: bool) -> Any: pred_contribs=pred_contribs, approx_contribs=approx_contribs, pred_interactions=pred_interactions, - validate_features=validate_features + validate_features=validate_features, ) if is_df: - if lazy_isinstance(partition, 'cudf', 'core.dataframe.DataFrame'): + if lazy_isinstance(partition, "cudf", "core.dataframe.DataFrame"): import cudf - predt = cudf.DataFrame(predt, columns=['prediction']) + predt = cudf.DataFrame(predt, columns=["prediction"]) else: - predt = DataFrame(predt, columns=['prediction']) + predt = DataFrame(predt, columns=["prediction"]) return predt + # Predict on dask collection directly. 
if isinstance(data, (da.Array, dd.DataFrame)): return await _direct_predict_impl(client, data, mapped_predict) @@ -1010,16 +1010,16 @@ def mapped_predict(partition: Any, is_df: bool) -> Any: meta_names = data.meta_names def dispatched_predict( - worker_id: int, list_of_orders: List[int], list_of_parts: _DataParts - ) -> List[Tuple[Tuple["dask.delayed.Delayed", int], int]]: - '''Perform prediction on each worker.''' - LOGGER.debug('Predicting on %d', worker_id) + worker_id: int, list_of_orders: List[int], list_of_parts: _DataParts + ) -> List[Tuple[List[Union["dask.delayed.Delayed", int]], int]]: + """Perform prediction on each worker.""" + LOGGER.debug("Predicting on %d", worker_id) with config.config_context(**global_config): worker = distributed.get_worker() list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts) predictions = [] - booster.set_param({'nthread': worker.nthreads}) + booster.set_param({"nthread": worker.nthreads}) for i, parts in enumerate(list_of_parts): (data, _, _, base_margin, _, _, _) = parts order = list_of_orders[i] @@ -1029,7 +1029,7 @@ def dispatched_predict( feature_names=feature_names, feature_types=feature_types, missing=missing, - nthread=worker.nthreads + nthread=worker.nthreads, ) predt = booster.predict( data=local_part, @@ -1038,10 +1038,42 @@ def dispatched_predict( pred_contribs=pred_contribs, approx_contribs=approx_contribs, pred_interactions=pred_interactions, - validate_features=validate_features + validate_features=validate_features, ) - columns = 1 if len(predt.shape) == 1 else predt.shape[1] - ret = ((dask.delayed(predt), columns), order) # pylint: disable=no-member + if pred_contribs and predt.size != local_part.num_row(): + assert len(predt.shape) in (2, 3) + if len(predt.shape) == 2: + groups = 1 + columns = predt.shape[1] + else: + groups = predt.shape[1] + columns = predt.shape[2] + # pylint: disable=no-member + ret = ( + [dask.delayed(predt), groups, columns], + order, + ) + elif pred_interactions and predt.size != local_part.num_row(): + assert len(predt.shape) in (3, 4) + if len(predt.shape) == 3: + groups = 1 + columns = predt.shape[1] + else: + groups = predt.shape[1] + columns = predt.shape[2] + # pylint: disable=no-member + ret = ( + [dask.delayed(predt), groups, columns], + order, + ) + else: + assert len(predt.shape) == 1 or len(predt.shape) == 2 + columns = 1 if len(predt.shape) == 1 else predt.shape[1] + # pylint: disable=no-member + ret = ( + [dask.delayed(predt), columns], + order, + ) predictions.append(ret) return predictions @@ -1049,8 +1081,8 @@ def dispatched_predict( def dispatched_get_shape( worker_id: int, list_of_orders: List[int], list_of_parts: _DataParts ) -> List[Tuple[int, int]]: - '''Get shape of data in each worker.''' - LOGGER.debug('Get shape on %d', worker_id) + """Get shape of data in each worker.""" + LOGGER.debug("Get shape on %d", worker_id) list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts) shapes = [] for i, parts in enumerate(list_of_parts): @@ -1061,7 +1093,7 @@ def dispatched_get_shape( async def map_function( func: Callable[[int, List[int], _DataParts], Any] ) -> List[Any]: - '''Run function for each part of the data.''' + """Run function for each part of the data.""" futures = [] workers_address = list(worker_map.keys()) for wid, worker_addr in enumerate(workers_address): @@ -1069,10 +1101,14 @@ async def map_function( list_of_parts = worker_map[worker_addr] list_of_orders = [partition_order[part.key] for part in list_of_parts] - f = client.submit(func, worker_id=wid, - 
                list_of_orders=list_of_orders,
-                              list_of_parts=list_of_parts,
-                              pure=True, workers=[worker_addr])
+            f = client.submit(
+                func,
+                worker_id=wid,
+                list_of_orders=list_of_orders,
+                list_of_parts=list_of_parts,
+                pure=True,
+                workers=[worker_addr],
+            )
             assert isinstance(f, distributed.client.Future)
             futures.append(f)
         # Get delayed objects
@@ -1091,10 +1127,24 @@ async def map_function(
         # See https://docs.dask.org/en/latest/array-creation.html
         arrays = []
         for i, shape in enumerate(shapes):
-            arrays.append(da.from_delayed(
-                results[i][0], shape=(shape[0],)
-                if results[i][1] == 1 else (shape[0], results[i][1]),
-                dtype=numpy.float32))
+            if pred_contribs:
+                out_shape = (
+                    (shape[0], results[i][2])
+                    if results[i][1] == 1
+                    else (shape[0], results[i][1], results[i][2])
+                )
+            elif pred_interactions:
+                out_shape = (
+                    (shape[0], results[i][2], results[i][2])
+                    if results[i][1] == 1
+                    else (shape[0], results[i][1], results[i][2])
+                )
+            else:
+                out_shape = (shape[0],) if results[i][1] == 1 else (shape[0], results[i][1])
+            arrays.append(
+                da.from_delayed(results[i][0], shape=out_shape, dtype=numpy.float32)
+            )
+
     predictions = await da.concatenate(arrays, axis=0)
     return predictions
@@ -1115,7 +1165,9 @@ def predict(
 
     .. note::
 
-        Only default prediction mode is supported right now.
+        Using ``inplace_predict`` might be faster when meta information like
+        ``base_margin`` is not needed.  For other parameters, please see
+        ``Booster.predict``.
 
     .. versionadded:: 1.0.0
 
@@ -1136,6 +1188,9 @@
     Returns
     -------
     prediction: dask.array.Array/dask.dataframe.Series
+        When the input data is ``dask.array.Array`` or ``DaskDMatrix``, the return value
+        is an array.  When the input data is ``dask.dataframe.DataFrame``, the return
+        value is ``dask.dataframe.Series``.
 
     '''
     _assert_dask_support()
diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py
index 89f0997fab07..dedf6bfe7bff 100644
--- a/tests/python/test_with_dask.py
+++ b/tests/python/test_with_dask.py
@@ -24,6 +24,7 @@
 if tm.no_dask()['condition']:
     pytest.skip(msg=tm.no_dask()['reason'], allow_module_level=True)
 
+import distributed
 from distributed import LocalCluster, Client
 from distributed.utils_test import client, loop, cluster_fixture
 import dask.dataframe as dd
@@ -51,11 +52,12 @@ def generate_array(
     with_weights: bool = False
 ) -> Tuple[xgb.dask._DaskCollection, xgb.dask._DaskCollection,
            Optional[xgb.dask._DaskCollection]]:
-    partition_size = 20
-    X = da.random.random((kRows, kCols), partition_size)
-    y = da.random.random(kRows, partition_size)
+    chunk_size = 20
+    rng = da.random.RandomState(1994)
+    X = rng.random_sample((kRows, kCols), chunks=(chunk_size, -1))
+    y = rng.random_sample(kRows, chunks=chunk_size)
     if with_weights:
-        w = da.random.random(kRows, partition_size)
+        w = rng.random_sample(kRows, chunks=chunk_size)
         return X, y, w
     return X, y, None
@@ -175,55 +177,51 @@ def test_boost_from_prediction(tree_method: str, client: "Client") -> None:
     assert np.all(predictions_1.compute() == predictions_2.compute())
 
 
-def test_dask_missing_value_reg() -> None:
-    with LocalCluster(n_workers=kWorkers) as cluster:
-        with Client(cluster) as client:
-            X_0 = np.ones((20 // 2, kCols))
-            X_1 = np.zeros((20 // 2, kCols))
-            X = np.concatenate([X_0, X_1], axis=0)
-            np.random.shuffle(X)
-            X = da.from_array(X)
-            X = X.rechunk(20, 1)
-            y = da.random.randint(0, 3, size=20)
-            y.rechunk(20)
-            regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2,
-                                                  missing=0.0)
-            regressor.client = client
-            regressor.set_params(tree_method='hist')
-
regressor.fit(X, y, eval_set=[(X, y)]) - dd_predt = regressor.predict(X).compute() - - np_X = X.compute() - np_predt = regressor.get_booster().predict( - xgb.DMatrix(np_X, missing=0.0)) - np.testing.assert_allclose(np_predt, dd_predt) - - -def test_dask_missing_value_cls() -> None: - with LocalCluster() as cluster: - with Client(cluster) as client: - X_0 = np.ones((kRows // 2, kCols)) - X_1 = np.zeros((kRows // 2, kCols)) - X = np.concatenate([X_0, X_1], axis=0) - np.random.shuffle(X) - X = da.from_array(X) - X = X.rechunk(20, None) - y = da.random.randint(0, 3, size=kRows) - y = y.rechunk(20, 1) - cls = xgb.dask.DaskXGBClassifier(verbosity=1, n_estimators=2, - tree_method='hist', - missing=0.0) - cls.client = client - cls.fit(X, y, eval_set=[(X, y)]) - dd_pred_proba = cls.predict_proba(X).compute() - - np_X = X.compute() - np_pred_proba = cls.get_booster().predict( - xgb.DMatrix(np_X, missing=0.0)) - np.testing.assert_allclose(np_pred_proba, dd_pred_proba) - - cls = xgb.dask.DaskXGBClassifier() - assert hasattr(cls, 'missing') +def test_dask_missing_value_reg(client: "Client") -> None: + X_0 = np.ones((20 // 2, kCols)) + X_1 = np.zeros((20 // 2, kCols)) + X = np.concatenate([X_0, X_1], axis=0) + np.random.shuffle(X) + X = da.from_array(X) + X = X.rechunk(20, 1) + y = da.random.randint(0, 3, size=20) + y.rechunk(20) + regressor = xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2, + missing=0.0) + regressor.client = client + regressor.set_params(tree_method='hist') + regressor.fit(X, y, eval_set=[(X, y)]) + dd_predt = regressor.predict(X).compute() + + np_X = X.compute() + np_predt = regressor.get_booster().predict( + xgb.DMatrix(np_X, missing=0.0)) + np.testing.assert_allclose(np_predt, dd_predt) + + +def test_dask_missing_value_cls(client: "Client") -> None: + X_0 = np.ones((kRows // 2, kCols)) + X_1 = np.zeros((kRows // 2, kCols)) + X = np.concatenate([X_0, X_1], axis=0) + np.random.shuffle(X) + X = da.from_array(X) + X = X.rechunk(20, None) + y = da.random.randint(0, 3, size=kRows) + y = y.rechunk(20, 1) + cls = xgb.dask.DaskXGBClassifier(verbosity=1, n_estimators=2, + tree_method='hist', + missing=0.0) + cls.client = client + cls.fit(X, y, eval_set=[(X, y)]) + dd_pred_proba = cls.predict_proba(X).compute() + + np_X = X.compute() + np_pred_proba = cls.get_booster().predict( + xgb.DMatrix(np_X, missing=0.0)) + np.testing.assert_allclose(np_pred_proba, dd_pred_proba) + + cls = xgb.dask.DaskXGBClassifier() + assert hasattr(cls, 'missing') @pytest.mark.parametrize("model", ["boosting", "rf"]) @@ -998,8 +996,7 @@ def worker_fn(worker_addr: str, data_ref: Dict) -> None: assert cnt - n_workers == n_partitions def run_shap(self, X: Any, y: Any, params: Dict[str, Any], client: "Client") -> None: - X, y = da.from_array(X), da.from_array(y) - + X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32) Xy = xgb.dask.DaskDMatrix(client, X, y) booster = xgb.dask.train(client, params, Xy, num_boost_round=10)['booster'] @@ -1009,8 +1006,12 @@ def run_shap(self, X: Any, y: Any, params: Dict[str, Any], client: "Client") -> margin = xgb.dask.predict(client, booster, test_Xy, output_margin=True).compute() assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5) + shap = xgb.dask.predict(client, booster, X, pred_contribs=True).compute() + margin = xgb.dask.predict(client, booster, X, output_margin=True).compute() + assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5) + def run_shap_cls_sklearn(self, X: Any, y: Any, client: "Client") -> None: - 
X, y = da.from_array(X), da.from_array(y) + X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32) cls = xgb.dask.DaskXGBClassifier() cls.client = client cls.fit(X, y) @@ -1022,6 +1023,10 @@ def run_shap_cls_sklearn(self, X: Any, y: Any, client: "Client") -> None: margin = xgb.dask.predict(client, booster, test_Xy, output_margin=True).compute() assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5) + shap = xgb.dask.predict(client, booster, X, pred_contribs=True).compute() + margin = xgb.dask.predict(client, booster, X, output_margin=True).compute() + assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5) + def test_shap(self, client: "Client") -> None: from sklearn.datasets import load_boston, load_digits X, y = load_boston(return_X_y=True) @@ -1031,6 +1036,7 @@ def test_shap(self, client: "Client") -> None: X, y = load_digits(return_X_y=True) params = {'objective': 'multi:softmax', 'num_class': 10} self.run_shap(X, y, params, client) + params = {'objective': 'multi:softprob', 'num_class': 10} self.run_shap(X, y, params, client) @@ -1043,7 +1049,7 @@ def run_shap_interactions( params: Dict[str, Any], client: "Client" ) -> None: - X, y = da.from_array(X), da.from_array(y) + X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32) Xy = xgb.dask.DaskDMatrix(client, X, y) booster = xgb.dask.train(client, params, Xy, num_boost_round=10)['booster']
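The shape bookkeeping added to ``dask.py`` above mirrors what the single-node ``Booster.predict`` already returns for contribution and interaction predictions, and the sum-to-margin checks in ``run_shap`` / ``run_shap_interactions`` rest on the same invariants. A minimal single-node sketch of those invariants follows; the data and parameters here are arbitrary and are not taken from the test suite:

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    rng = np.random.RandomState(1994)
    X = rng.randn(64, 4)
    y = rng.randint(0, 3, size=64)
    Xy = xgb.DMatrix(X, y)
    booster = xgb.train(
        {"objective": "multi:softprob", "num_class": 3}, Xy, num_boost_round=10
    )

    margin = booster.predict(Xy, output_margin=True)             # (64, 3)
    contribs = booster.predict(Xy, pred_contribs=True)           # (64, 3, 5): rows, groups, cols + 1
    interactions = booster.predict(Xy, pred_interactions=True)   # (64, 3, 5, 5)

    # Contributions (including the trailing bias column) sum to the raw margin,
    # and interaction values sum over both feature axes to the same margin.
    np.testing.assert_allclose(contribs.sum(axis=-1), margin, rtol=1e-5, atol=1e-5)
    np.testing.assert_allclose(
        interactions.sum(axis=(-2, -1)), margin, rtol=1e-5, atol=1e-5
    )

With the per-partition shapes declared, ``xgb.dask.predict`` can return the same multi-dimensional arrays for ``pred_contribs`` and ``pred_interactions`` instead of a flattened result, which is what the new assertions in ``run_shap`` and ``run_shap_interactions`` exercise.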