Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[backport] [pyspark] rework transform to reuse same code (#9292) #9558

Merged
merged 1 commit into from
Sep 7, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
252 changes: 123 additions & 129 deletions python-package/xgboost/spark/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
from xgboost.sklearn import DEFAULT_N_ESTIMATORS, XGBModel, _can_use_qdm
from xgboost.training import train as worker_train

from .._typing import ArrayLike
from .data import (
_read_csr_matrix_from_unwrapped_spark_vec,
alias,
Expand Down Expand Up @@ -1117,12 +1118,86 @@ def _get_feature_col(
)
return features_col, feature_col_names

def _get_pred_contrib_col_name(self) -> Optional[str]:
"""Return the pred_contrib_col col name"""
pred_contrib_col_name = None
if (
self.isDefined(self.pred_contrib_col)
and self.getOrDefault(self.pred_contrib_col) != ""
):
pred_contrib_col_name = self.getOrDefault(self.pred_contrib_col)

return pred_contrib_col_name

def _out_schema(self) -> Tuple[bool, str]:
"""Return the bool to indicate if it's a single prediction, true is single prediction,
and the returned type of the user-defined function. The value must
be a DDL-formatted type string."""

if self._get_pred_contrib_col_name() is not None:
return False, f"{pred.prediction} double, {pred.pred_contrib} array<double>"

return True, "double"

def _get_predict_func(self) -> Callable:
"""Return the true prediction function which will be running on the executor side"""

predict_params = self._gen_predict_params_dict()
pred_contrib_col_name = self._get_pred_contrib_col_name()

def _predict(
model: XGBModel, X: ArrayLike, base_margin: Optional[ArrayLike]
) -> Union[pd.DataFrame, pd.Series]:
data = {}
preds = model.predict(
X,
base_margin=base_margin,
validate_features=False,
**predict_params,
)
data[pred.prediction] = pd.Series(preds)

if pred_contrib_col_name is not None:
contribs = pred_contribs(model, X, base_margin)
data[pred.pred_contrib] = pd.Series(list(contribs))
return pd.DataFrame(data=data)

return data[pred.prediction]

return _predict

def _post_transform(self, dataset: DataFrame, pred_col: Column) -> DataFrame:
"""Post process of transform"""
prediction_col_name = self.getOrDefault(self.predictionCol)
single_pred, _ = self._out_schema()

if single_pred:
if prediction_col_name:
dataset = dataset.withColumn(prediction_col_name, pred_col)
else:
pred_struct_col = "_prediction_struct"
dataset = dataset.withColumn(pred_struct_col, pred_col)

if prediction_col_name:
dataset = dataset.withColumn(
prediction_col_name, getattr(col(pred_struct_col), pred.prediction)
)

pred_contrib_col_name = self._get_pred_contrib_col_name()
if pred_contrib_col_name is not None:
dataset = dataset.withColumn(
pred_contrib_col_name,
array_to_vector(getattr(col(pred_struct_col), pred.pred_contrib)),
)

dataset = dataset.drop(pred_struct_col)
return dataset

def _transform(self, dataset: DataFrame) -> DataFrame:
# pylint: disable=too-many-statements, too-many-locals
# Save xgb_sklearn_model and predict_params to be local variable
# to avoid the `self` object to be pickled to remote.
xgb_sklearn_model = self._xgb_sklearn_model
predict_params = self._gen_predict_params_dict()

has_base_margin = False
if (
Expand All @@ -1137,18 +1212,9 @@ def _transform(self, dataset: DataFrame) -> DataFrame:
features_col, feature_col_names = self._get_feature_col(dataset)
enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim)

pred_contrib_col_name = None
if (
self.isDefined(self.pred_contrib_col)
and self.getOrDefault(self.pred_contrib_col) != ""
):
pred_contrib_col_name = self.getOrDefault(self.pred_contrib_col)
predict_func = self._get_predict_func()

single_pred = True
schema = "double"
if pred_contrib_col_name:
single_pred = False
schema = f"{pred.prediction} double, {pred.pred_contrib} array<double>"
_, schema = self._out_schema()

@pandas_udf(schema) # type: ignore
def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
Expand All @@ -1168,48 +1234,14 @@ def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
else:
base_margin = None

data = {}
preds = model.predict(
X,
base_margin=base_margin,
validate_features=False,
**predict_params,
)
data[pred.prediction] = pd.Series(preds)

if pred_contrib_col_name:
contribs = pred_contribs(model, X, base_margin)
data[pred.pred_contrib] = pd.Series(list(contribs))
yield pd.DataFrame(data=data)
else:
yield data[pred.prediction]
yield predict_func(model, X, base_margin)

if has_base_margin:
pred_col = predict_udf(struct(*features_col, base_margin_col))
else:
pred_col = predict_udf(struct(*features_col))

prediction_col_name = self.getOrDefault(self.predictionCol)

if single_pred:
dataset = dataset.withColumn(prediction_col_name, pred_col)
else:
pred_struct_col = "_prediction_struct"
dataset = dataset.withColumn(pred_struct_col, pred_col)

dataset = dataset.withColumn(
prediction_col_name, getattr(col(pred_struct_col), pred.prediction)
)

if pred_contrib_col_name:
dataset = dataset.withColumn(
pred_contrib_col_name,
array_to_vector(getattr(col(pred_struct_col), pred.pred_contrib)),
)

dataset = dataset.drop(pred_struct_col)

return dataset
return self._post_transform(dataset, pred_col)


class _ClassificationModel( # pylint: disable=abstract-method
Expand All @@ -1221,22 +1253,21 @@ class _ClassificationModel( # pylint: disable=abstract-method
.. Note:: This API is experimental.
"""

def _transform(self, dataset: DataFrame) -> DataFrame:
# pylint: disable=too-many-statements, too-many-locals
# Save xgb_sklearn_model and predict_params to be local variable
# to avoid the `self` object to be pickled to remote.
xgb_sklearn_model = self._xgb_sklearn_model
predict_params = self._gen_predict_params_dict()
def _out_schema(self) -> Tuple[bool, str]:
schema = (
f"{pred.raw_prediction} array<double>, {pred.prediction} double,"
f" {pred.probability} array<double>"
)
if self._get_pred_contrib_col_name() is not None:
# We will force setting strict_shape to True when predicting contribs,
# So, it will also output 3-D shape result.
schema = f"{schema}, {pred.pred_contrib} array<array<double>>"

has_base_margin = False
if (
self.isDefined(self.base_margin_col)
and self.getOrDefault(self.base_margin_col) != ""
):
has_base_margin = True
base_margin_col = col(self.getOrDefault(self.base_margin_col)).alias(
alias.margin
)
return False, schema

def _get_predict_func(self) -> Callable:
predict_params = self._gen_predict_params_dict()
pred_contrib_col_name = self._get_pred_contrib_col_name()

def transform_margin(margins: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
if margins.ndim == 1:
Expand All @@ -1251,76 +1282,38 @@ def transform_margin(margins: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
class_probs = softmax(raw_preds, axis=1)
return raw_preds, class_probs

features_col, feature_col_names = self._get_feature_col(dataset)
enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim)

pred_contrib_col_name = None
if (
self.isDefined(self.pred_contrib_col)
and self.getOrDefault(self.pred_contrib_col) != ""
):
pred_contrib_col_name = self.getOrDefault(self.pred_contrib_col)

schema = (
f"{pred.raw_prediction} array<double>, {pred.prediction} double,"
f" {pred.probability} array<double>"
)
if pred_contrib_col_name:
# We will force setting strict_shape to True when predicting contribs,
# So, it will also output 3-D shape result.
schema = f"{schema}, {pred.pred_contrib} array<array<double>>"

@pandas_udf(schema) # type: ignore
def predict_udf(
iterator: Iterator[Tuple[pd.Series, ...]]
) -> Iterator[pd.DataFrame]:
assert xgb_sklearn_model is not None
model = xgb_sklearn_model
for data in iterator:
if enable_sparse_data_optim:
X = _read_csr_matrix_from_unwrapped_spark_vec(data)
else:
if feature_col_names is not None:
X = data[feature_col_names] # type: ignore
else:
X = stack_series(data[alias.data])

if has_base_margin:
base_margin = stack_series(data[alias.margin])
else:
base_margin = None

margins = model.predict(
X,
base_margin=base_margin,
output_margin=True,
validate_features=False,
**predict_params,
)
raw_preds, class_probs = transform_margin(margins)

# It seems that they use argmax of class probs,
# not of margin to get the prediction (Note: scala implementation)
preds = np.argmax(class_probs, axis=1)
result: Dict[str, pd.Series] = {
pred.raw_prediction: pd.Series(list(raw_preds)),
pred.prediction: pd.Series(preds),
pred.probability: pd.Series(list(class_probs)),
}
def _predict(
model: XGBModel, X: ArrayLike, base_margin: Optional[np.ndarray]
) -> Union[pd.DataFrame, pd.Series]:
margins = model.predict(
X,
base_margin=base_margin,
output_margin=True,
validate_features=False,
**predict_params,
)
raw_preds, class_probs = transform_margin(margins)

# It seems that they use argmax of class probs,
# not of margin to get the prediction (Note: scala implementation)
preds = np.argmax(class_probs, axis=1)
result: Dict[str, pd.Series] = {
pred.raw_prediction: pd.Series(list(raw_preds)),
pred.prediction: pd.Series(preds),
pred.probability: pd.Series(list(class_probs)),
}

if pred_contrib_col_name:
contribs = pred_contribs(model, X, base_margin, strict_shape=True)
result[pred.pred_contrib] = pd.Series(list(contribs.tolist()))
if pred_contrib_col_name is not None:
contribs = pred_contribs(model, X, base_margin, strict_shape=True)
result[pred.pred_contrib] = pd.Series(list(contribs.tolist()))

yield pd.DataFrame(data=result)
return pd.DataFrame(data=result)

if has_base_margin:
pred_struct = predict_udf(struct(*features_col, base_margin_col))
else:
pred_struct = predict_udf(struct(*features_col))
return _predict

def _post_transform(self, dataset: DataFrame, pred_col: Column) -> DataFrame:
pred_struct_col = "_prediction_struct"
dataset = dataset.withColumn(pred_struct_col, pred_struct)
dataset = dataset.withColumn(pred_struct_col, pred_col)

raw_prediction_col_name = self.getOrDefault(self.rawPredictionCol)
if raw_prediction_col_name:
Expand All @@ -1342,7 +1335,8 @@ def predict_udf(
array_to_vector(getattr(col(pred_struct_col), pred.probability)),
)

if pred_contrib_col_name:
pred_contrib_col_name = self._get_pred_contrib_col_name()
if pred_contrib_col_name is not None:
dataset = dataset.withColumn(
pred_contrib_col_name,
getattr(col(pred_struct_col), pred.pred_contrib),
Expand Down