[pyspark] Re-work _fit function (#8630)
wbo4958 authored Jan 4, 2023
1 parent beefd28 commit d3ad052
Showing 1 changed file with 51 additions and 12 deletions.
python-package/xgboost/spark/core.py — 63 changes: 51 additions & 12 deletions
@@ -3,7 +3,8 @@
 # pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
 # pylint: disable=too-few-public-methods, too-many-lines, too-many-branches
 import json
-from typing import Iterator, Optional, Tuple
+from collections import namedtuple
+from typing import Iterator, List, Optional, Tuple

 import numpy as np
 import pandas as pd
@@ -21,6 +22,7 @@
     HasWeightCol,
 )
 from pyspark.ml.util import MLReadable, MLWritable
+from pyspark.sql import DataFrame
 from pyspark.sql.functions import col, countDistinct, pandas_udf, rand, struct
 from pyspark.sql.types import (
     ArrayType,
@@ -471,6 +473,12 @@ def _get_unwrapped_vec_cols(feature_col):
 ]


+FeatureProp = namedtuple(
+    "FeatureProp",
+    ("enable_sparse_data_optim", "has_validation_col", "features_cols_names"),
+)
+
+
 class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
     def __init__(self):
         super().__init__()
@@ -641,9 +649,9 @@ def _get_xgb_train_call_args(cls, train_params):
         }
         return booster_params, kwargs_params

-    def _fit(self, dataset):
-        # pylint: disable=too-many-statements, too-many-locals
-        self._validate_params()
+    def _prepare_input_columns_and_feature_prop(
+        self, dataset: DataFrame
+    ) -> Tuple[List[str], FeatureProp]:
         label_col = col(self.getOrDefault(self.labelCol)).alias(alias.label)

         select_cols = [label_col]
@@ -698,6 +706,18 @@ def _fit(self, dataset):
         if self.isDefined(self.qid_col) and self.getOrDefault(self.qid_col):
             select_cols.append(col(self.getOrDefault(self.qid_col)).alias(alias.qid))

+        feature_prop = FeatureProp(
+            enable_sparse_data_optim, has_validation_col, features_cols_names
+        )
+        return select_cols, feature_prop
+
+    def _prepare_input(self, dataset: DataFrame) -> Tuple[DataFrame, FeatureProp]:
+        """Prepare the input including column pruning, repartition and so on"""
+
+        select_cols, feature_prop = self._prepare_input_columns_and_feature_prop(
+            dataset
+        )
+
         dataset = dataset.select(*select_cols)

         num_workers = self.getOrDefault(self.num_workers)
@@ -732,11 +752,13 @@ def _fit(self, dataset):
             # XGBoost requires qid to be sorted for each partition
             dataset = dataset.sortWithinPartitions(alias.qid, ascending=True)

+        return dataset, feature_prop
+
+    def _get_xgb_parameters(self, dataset: DataFrame):
         train_params = self._get_distributed_train_params(dataset)
         booster_params, train_call_kwargs_params = self._get_xgb_train_call_args(
             train_params
         )
-
         cpu_per_task = int(
             _get_spark_session().sparkContext.getConf().get("spark.task.cpus", "1")
         )
@@ -749,9 +771,6 @@ def _fit(self, dataset):
             "missing": float(self.getOrDefault(self.missing)),
         }
         booster_params["nthread"] = cpu_per_task
-        use_gpu = self.getOrDefault(self.use_gpu)
-
-        is_local = _is_local(_get_spark_session().sparkContext)

         # Remove the parameters whose value is None
         booster_params = {k: v for k, v in booster_params.items() if v is not None}
@@ -760,7 +779,25 @@ def _fit(self, dataset):
         }
         dmatrix_kwargs = {k: v for k, v in dmatrix_kwargs.items() if v is not None}

-        use_hist = booster_params.get("tree_method", None) in ("hist", "gpu_hist")
+        return booster_params, train_call_kwargs_params, dmatrix_kwargs
+
+    def _fit(self, dataset):
+        # pylint: disable=too-many-statements, too-many-locals
+        self._validate_params()
+
+        dataset, feature_prop = self._prepare_input(dataset)
+
+        (
+            booster_params,
+            train_call_kwargs_params,
+            dmatrix_kwargs,
+        ) = self._get_xgb_parameters(dataset)
+
+        use_gpu = self.getOrDefault(self.use_gpu)
+
+        is_local = _is_local(_get_spark_session().sparkContext)
+
+        num_workers = self.getOrDefault(self.num_workers)

         def _train_booster(pandas_df_iter):
             """Takes in an RDD partition and outputs a booster for that partition after
@@ -772,6 +809,8 @@ def _train_booster(pandas_df_iter):
             context = BarrierTaskContext.get()

             gpu_id = None
+            use_hist = booster_params.get("tree_method", None) in ("hist", "gpu_hist")
+
             if use_gpu:
                 gpu_id = context.partitionId() if is_local else _get_gpu_id(context)
                 booster_params["gpu_id"] = gpu_id
@@ -814,12 +853,12 @@ def _train_booster(pandas_df_iter):
             with CommunicatorContext(context, **_rabit_args):
                 dtrain, dvalid = create_dmatrix_from_partitions(
                     pandas_df_iter,
-                    features_cols_names,
+                    feature_prop.features_cols_names,
                     gpu_id,
                     use_qdm,
                     dmatrix_kwargs,
-                    enable_sparse_data_optim=enable_sparse_data_optim,
-                    has_validation_col=has_validation_col,
+                    enable_sparse_data_optim=feature_prop.enable_sparse_data_optim,
+                    has_validation_col=feature_prop.has_validation_col,
                 )
                 if dvalid is not None:
                     dval = [(dtrain, "training"), (dvalid, "validation")]
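
For context: the rework keeps `_fit` as a thin orchestrator. `_prepare_input` handles column pruning, repartitioning, and per-partition qid sorting and returns the pruned DataFrame together with a `FeatureProp` record, while `_get_xgb_parameters` derives the booster, train-call, and DMatrix keyword arguments. Bundling `enable_sparse_data_optim`, `has_validation_col`, and `features_cols_names` into the `FeatureProp` namedtuple lets the `_train_booster` closure capture one immutable object instead of three loose locals. Below is a minimal usage sketch (not part of the commit) that exercises this path through the public estimator; it assumes a local SparkSession, and the column names and two-row dataset are illustrative only:

```python
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession
from xgboost.spark import SparkXGBClassifier

# Assumed setup: a local SparkSession with two cores so two training tasks fit.
spark = SparkSession.builder.master("local[2]").getOrCreate()

# Illustrative two-row dataset; any vector features column plus a numeric
# label column works here.
df = spark.createDataFrame(
    [(Vectors.dense(1.0, 2.0), 0), (Vectors.dense(3.0, 4.0), 1)],
    ["features", "label"],
)

# fit() dispatches to _SparkXGBEstimator._fit, which after this commit
# delegates to _prepare_input and _get_xgb_parameters before launching the
# barrier-mode training job.
clf = SparkXGBClassifier(num_workers=2, features_col="features", label_col="label")
model = clf.fit(df)
model.transform(df).show()
```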
