diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py
index 1b42411f6652..e5f8a68abd1b 100644
--- a/python-package/xgboost/spark/core.py
+++ b/python-package/xgboost/spark/core.py
@@ -3,7 +3,8 @@
 # pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
 # pylint: disable=too-few-public-methods, too-many-lines, too-many-branches
 import json
-from typing import Iterator, Optional, Tuple
+from collections import namedtuple
+from typing import Iterator, List, Optional, Tuple
 
 import numpy as np
 import pandas as pd
@@ -21,6 +22,7 @@
     HasWeightCol,
 )
 from pyspark.ml.util import MLReadable, MLWritable
+from pyspark.sql import DataFrame
 from pyspark.sql.functions import col, countDistinct, pandas_udf, rand, struct
 from pyspark.sql.types import (
     ArrayType,
@@ -471,6 +473,12 @@ def _get_unwrapped_vec_cols(feature_col):
     ]
 
 
+FeatureProp = namedtuple(
+    "FeatureProp",
+    ("enable_sparse_data_optim", "has_validation_col", "features_cols_names"),
+)
+
+
 class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
     def __init__(self):
         super().__init__()
@@ -641,9 +649,9 @@ def _get_xgb_train_call_args(cls, train_params):
         }
         return booster_params, kwargs_params
 
-    def _fit(self, dataset):
-        # pylint: disable=too-many-statements, too-many-locals
-        self._validate_params()
+    def _prepare_input_columns_and_feature_prop(
+        self, dataset: DataFrame
+    ) -> Tuple[List[str], FeatureProp]:
         label_col = col(self.getOrDefault(self.labelCol)).alias(alias.label)
         select_cols = [label_col]
 
@@ -698,6 +706,18 @@ def _fit(self, dataset):
         if self.isDefined(self.qid_col) and self.getOrDefault(self.qid_col):
             select_cols.append(col(self.getOrDefault(self.qid_col)).alias(alias.qid))
 
+        feature_prop = FeatureProp(
+            enable_sparse_data_optim, has_validation_col, features_cols_names
+        )
+        return select_cols, feature_prop
+
+    def _prepare_input(self, dataset: DataFrame) -> Tuple[DataFrame, FeatureProp]:
+        """Prepare the input including column pruning, repartition and so on"""
+
+        select_cols, feature_prop = self._prepare_input_columns_and_feature_prop(
+            dataset
+        )
+
         dataset = dataset.select(*select_cols)
 
         num_workers = self.getOrDefault(self.num_workers)
@@ -732,11 +752,13 @@ def _fit(self, dataset):
             # XGBoost requires qid to be sorted for each partition
             dataset = dataset.sortWithinPartitions(alias.qid, ascending=True)
 
+        return dataset, feature_prop
+
+    def _get_xgb_parameters(self, dataset: DataFrame):
         train_params = self._get_distributed_train_params(dataset)
         booster_params, train_call_kwargs_params = self._get_xgb_train_call_args(
             train_params
         )
-
         cpu_per_task = int(
             _get_spark_session().sparkContext.getConf().get("spark.task.cpus", "1")
         )
@@ -749,9 +771,6 @@ def _fit(self, dataset):
             "missing": float(self.getOrDefault(self.missing)),
         }
         booster_params["nthread"] = cpu_per_task
-        use_gpu = self.getOrDefault(self.use_gpu)
-
-        is_local = _is_local(_get_spark_session().sparkContext)
 
         # Remove the parameters whose value is None
         booster_params = {k: v for k, v in booster_params.items() if v is not None}
@@ -760,7 +779,25 @@ def _fit(self, dataset):
         }
         dmatrix_kwargs = {k: v for k, v in dmatrix_kwargs.items() if v is not None}
 
-        use_hist = booster_params.get("tree_method", None) in ("hist", "gpu_hist")
+        return booster_params, train_call_kwargs_params, dmatrix_kwargs
+
+    def _fit(self, dataset):
+        # pylint: disable=too-many-statements, too-many-locals
+        self._validate_params()
+
+        dataset, feature_prop = self._prepare_input(dataset)
+
+        (
+            booster_params,
+            train_call_kwargs_params,
+            dmatrix_kwargs,
+        ) = self._get_xgb_parameters(dataset)
+
+        use_gpu = self.getOrDefault(self.use_gpu)
+
+        is_local = _is_local(_get_spark_session().sparkContext)
+
+        num_workers = self.getOrDefault(self.num_workers)
 
         def _train_booster(pandas_df_iter):
             """Takes in an RDD partition and outputs a booster for that partition after
@@ -772,6 +809,8 @@ def _train_booster(pandas_df_iter):
             context = BarrierTaskContext.get()
 
             gpu_id = None
+            use_hist = booster_params.get("tree_method", None) in ("hist", "gpu_hist")
+
             if use_gpu:
                 gpu_id = context.partitionId() if is_local else _get_gpu_id(context)
                 booster_params["gpu_id"] = gpu_id
@@ -814,12 +853,12 @@ def _train_booster(pandas_df_iter):
             with CommunicatorContext(context, **_rabit_args):
                 dtrain, dvalid = create_dmatrix_from_partitions(
                     pandas_df_iter,
-                    features_cols_names,
+                    feature_prop.features_cols_names,
                     gpu_id,
                     use_qdm,
                     dmatrix_kwargs,
-                    enable_sparse_data_optim=enable_sparse_data_optim,
-                    has_validation_col=has_validation_col,
+                    enable_sparse_data_optim=feature_prop.enable_sparse_data_optim,
+                    has_validation_col=feature_prop.has_validation_col,
                 )
                 if dvalid is not None:
                     dval = [(dtrain, "training"), (dvalid, "validation")]
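
Note: the refactoring pattern in this patch — bundling the loose locals (`enable_sparse_data_optim`, `has_validation_col`, `features_cols_names`) into an immutable `FeatureProp` namedtuple so the training closure captures one record instead of several scattered variables — can be illustrated with a minimal standalone sketch. The `prepare` and `fit` functions below are hypothetical stand-ins for `_prepare_input_columns_and_feature_prop` and the slimmed-down `_fit`; they are not part of the xgboost.spark API.

```python
from collections import namedtuple
from typing import List, Tuple

# Mirrors the FeatureProp namedtuple added in this patch.
FeatureProp = namedtuple(
    "FeatureProp",
    ("enable_sparse_data_optim", "has_validation_col", "features_cols_names"),
)


def prepare(columns: List[str]) -> Tuple[List[str], FeatureProp]:
    """Hypothetical stand-in for _prepare_input_columns_and_feature_prop:
    derive per-dataset flags once and bundle them into a single record."""
    prop = FeatureProp(
        enable_sparse_data_optim=False,
        has_validation_col="validation" in columns,
        features_cols_names=[c for c in columns if c.startswith("f")],
    )
    return columns, prop


def fit(columns: List[str]) -> None:
    """Hypothetical stand-in for the slimmed-down _fit."""
    _, prop = prepare(columns)

    def train_partition() -> None:
        # The closure captures one immutable record instead of three
        # loose variables from the enclosing scope, as in _train_booster.
        print(prop.features_cols_names, prop.has_validation_col)

    train_partition()


fit(["f0", "f1", "label", "validation"])
```

A side benefit of the namedtuple over a plain tuple is that the closure reads `feature_prop.has_validation_col` by name, so reordering or extending the fields later does not silently break positional unpacking.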