From 7847cede440f9b3f45779b91fd48fa4cc431d58e Mon Sep 17 00:00:00 2001 From: "Xuye (Chris) Qin" Date: Sat, 9 Oct 2021 13:57:56 +0800 Subject: [PATCH] Refine MarsDMatrix & support more parameters for XGB classifier and regressor (#2498) (cherry picked from commit e367ea3c782660fbecec488c1c5551ebfbeeaa73) --- .github/workflows/core-ci.yml | 2 +- mars/learn/contrib/xgboost/classifier.py | 13 +- mars/learn/contrib/xgboost/core.py | 108 ++++-- mars/learn/contrib/xgboost/dmatrix.py | 343 +++++++++--------- mars/learn/contrib/xgboost/regressor.py | 13 +- .../contrib/xgboost/tests/test_classifier.py | 4 +- mars/learn/contrib/xgboost/tests/test_core.py | 44 +++ .../contrib/xgboost/tests/test_regressor.py | 2 +- .../learn/contrib/xgboost/tests/test_train.py | 22 ++ mars/learn/contrib/xgboost/train.py | 57 ++- 10 files changed, 372 insertions(+), 236 deletions(-) create mode 100644 mars/learn/contrib/xgboost/tests/test_core.py diff --git a/.github/workflows/core-ci.yml b/.github/workflows/core-ci.yml index d13ca4037d..f795219d54 100644 --- a/.github/workflows/core-ci.yml +++ b/.github/workflows/core-ci.yml @@ -39,7 +39,7 @@ jobs: source ./ci/reload-env.sh export DEFAULT_VENV=$VIRTUAL_ENV - if [[ ! "$PYTHON" =~ "3.9" ]]; then + if [[ ! "$PYTHON" =~ "3.6" ]]; then conda install -n test --quiet --yes -c conda-forge python=$PYTHON numba fi diff --git a/mars/learn/contrib/xgboost/classifier.py b/mars/learn/contrib/xgboost/classifier.py index 501ac47b2b..c291041dbf 100644 --- a/mars/learn/contrib/xgboost/classifier.py +++ b/mars/learn/contrib/xgboost/classifier.py @@ -21,8 +21,7 @@ from xgboost.sklearn import XGBClassifierBase from .... import tensor as mt - from .dmatrix import MarsDMatrix - from .core import evaluation_matrices + from .core import wrap_evaluation_matrices from .train import train from .predict import predict @@ -31,14 +30,16 @@ class XGBClassifier(XGBScikitLearnBase, XGBClassifierBase): Implementation of the scikit-learn API for XGBoost classification. """ - def fit(self, X, y, sample_weights=None, eval_set=None, sample_weight_eval_set=None, **kw): + def fit(self, X, y, sample_weight=None, base_margin=None, + eval_set=None, sample_weight_eval_set=None, base_margin_eval_set=None, **kw): session = kw.pop('session', None) run_kwargs = kw.pop('run_kwargs', dict()) if kw: raise TypeError(f"fit got an unexpected keyword argument '{next(iter(kw))}'") - dtrain = MarsDMatrix(X, label=y, weight=sample_weights, - session=session, run_kwargs=run_kwargs) + dtrain, evals = wrap_evaluation_matrices( + None, X, y, sample_weight, base_margin, eval_set, + sample_weight_eval_set, base_margin_eval_set) params = self.get_xgb_params() self.classes_ = mt.unique(y, aggregate_size=1).to_numpy(session=session, **run_kwargs) @@ -50,8 +51,6 @@ def fit(self, X, y, sample_weights=None, eval_set=None, sample_weight_eval_set=N else: params['objective'] = 'binary:logistic' - evals = evaluation_matrices(eval_set, sample_weight_eval_set, - session=session, run_kwargs=run_kwargs) self.evals_result_ = dict() result = train(params, dtrain, num_boost_round=self.get_num_boosting_rounds(), evals=evals, evals_result=self.evals_result_, diff --git a/mars/learn/contrib/xgboost/core.py b/mars/learn/contrib/xgboost/core.py index a105656918..8223b6bdea 100644 --- a/mars/learn/contrib/xgboost/core.py +++ b/mars/learn/contrib/xgboost/core.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Callable, List, Optional, Tuple + try: import xgboost except ImportError: @@ -61,34 +63,82 @@ def predict(self, data, **kw): """ raise NotImplementedError - def evaluation_matrices(validation_set, sample_weights, session=None, run_kwargs=None): - """ - Parameters - ---------- - validation_set: list of tuples - Each tuple contains a validation dataset including input X and label y. - E.g.: - .. code-block:: python - [(X_0, y_0), (X_1, y_1), ... ] - sample_weights: list of arrays - The weight vector for validation data. - session: - Session to run - run_kwargs: - kwargs for session.run - Returns - ------- - evals: list of validation MarsDMatrix + def wrap_evaluation_matrices( + missing: float, + X: Any, + y: Any, + sample_weight: Optional[Any], + base_margin: Optional[Any], + eval_set: Optional[List[Tuple[Any, Any]]], + sample_weight_eval_set: Optional[List[Any]], + base_margin_eval_set: Optional[List[Any]], + label_transform: Callable = lambda x: x, + ) -> Tuple[Any, Optional[List[Tuple[Any, str]]]]: + """Convert array_like evaluation matrices into DMatrix. Perform validation on the way. """ - evals = [] - if validation_set is not None: - assert isinstance(validation_set, list) - for i, e in enumerate(validation_set): - w = (sample_weights[i] - if sample_weights is not None else None) - dmat = MarsDMatrix(e[0], label=e[1], weight=w, - session=session, run_kwargs=run_kwargs) - evals.append((dmat, f'validation_{i}')) + train_dmatrix = MarsDMatrix( + data=X, + label=label_transform(y), + weight=sample_weight, + base_margin=base_margin, + missing=missing, + ) + + n_validation = 0 if eval_set is None else len(eval_set) + + def validate_or_none(meta: Optional[List], name: str) -> List: + if meta is None: + return [None] * n_validation + if len(meta) != n_validation: + raise ValueError( + f"{name}'s length does not equal `eval_set`'s length, " + + f"expecting {n_validation}, got {len(meta)}" + ) + return meta + + if eval_set is not None: + sample_weight_eval_set = validate_or_none( + sample_weight_eval_set, "sample_weight_eval_set" + ) + base_margin_eval_set = validate_or_none( + base_margin_eval_set, "base_margin_eval_set" + ) + + evals = [] + for i, (valid_X, valid_y) in enumerate(eval_set): + # Skip the duplicated entry. + if all( + ( + valid_X is X, valid_y is y, + sample_weight_eval_set[i] is sample_weight, + base_margin_eval_set[i] is base_margin, + ) + ): + evals.append(train_dmatrix) + else: + m = MarsDMatrix( + data=valid_X, + label=label_transform(valid_y), + weight=sample_weight_eval_set[i], + base_margin=base_margin_eval_set[i], + missing=missing, + ) + evals.append(m) + nevals = len(evals) + eval_names = [f"validation_{i}" for i in range(nevals)] + evals = list(zip(evals, eval_names)) else: - evals = None - return evals + if any( + meta is not None + for meta in [ + sample_weight_eval_set, + base_margin_eval_set, + ] + ): + raise ValueError( + "`eval_set` is not set but one of the other evaluation meta info is " + "not None." + ) + evals = [] + + return train_dmatrix, evals diff --git a/mars/learn/contrib/xgboost/dmatrix.py b/mars/learn/contrib/xgboost/dmatrix.py index e598bf5b15..956dadae7d 100644 --- a/mars/learn/contrib/xgboost/dmatrix.py +++ b/mars/learn/contrib/xgboost/dmatrix.py @@ -13,14 +13,19 @@ # limitations under the License. import itertools +from typing import List, Union + +import numpy as np from .... import opcodes as OperandDef -from ....core import ExecutableTuple, get_output_types, recursive_tile +from ....core import get_output_types, recursive_tile +from ....core.context import get_context, Context from ....dataframe.core import DATAFRAME_TYPE from ....serialization.serializables import KeyField, Float64Field, ListField, BoolField from ....tensor.core import TENSOR_TYPE, TENSOR_CHUNK_TYPE from ....tensor import tensor as astensor -from ....utils import has_unknown_shape, ensure_own_data +from ....typing import TileableType, ChunkType +from ....utils import has_unknown_shape, ensure_own_data, build_fetch from ...operands import LearnOperand, LearnOperandMixin from ...utils import convert_to_tensor_or_dataframe, concat_chunks @@ -28,64 +33,37 @@ class ToDMatrix(LearnOperand, LearnOperandMixin): _op_type_ = OperandDef.TO_DMATRIX - _data = KeyField('data') - _label = KeyField('label') - _missing = Float64Field('missing') - _weight = KeyField('weight') - _feature_names = ListField('feature_names') - _feature_types = ListField('feature_types') - _multi_output = BoolField('multi_output') - - def __init__(self, data=None, label=None, missing=None, weight=None, feature_names=None, - feature_types=None, multi_output=None, gpu=None, output_types=None, **kw): - super().__init__(_data=data, _label=label, _missing=missing, _weight=weight, - _feature_names=feature_names, _feature_types=feature_types, - _gpu=gpu, _multi_output=multi_output, _output_types=output_types, - **kw) + data = KeyField('data') + label = KeyField('label') + missing = Float64Field('missing') + weight = KeyField('weight') + base_margin = KeyField('base_margin') + feature_names = ListField('feature_names') + feature_types = ListField('feature_types') + # if to collocate the data, label and weight + _collocate = BoolField('collocate', default=False) @property def output_limit(self): - if self._multi_output: - return 1 + (self._label is not None) + (self._weight is not None) + if self._collocate: + return 1 + \ + (self.label is not None) + \ + (self.weight is not None) + \ + (self.base_margin is not None) return 1 - @property - def data(self): - return self._data - - @property - def label(self): - return self._label - - @property - def missing(self): - return self._missing - - @property - def weight(self): - return self._weight - - @property - def feature_names(self): - return self._feature_names - - @property - def feature_types(self): - return self._feature_types - - @property - def multi_output(self): - return self._multi_output - def _set_inputs(self, inputs): super()._set_inputs(inputs) - self._data = self._inputs[0] - has_label = self._label is not None + if self.data is not None: + self.data = self._inputs[0] + has_label = self.label is not None if has_label: - self._label = self._inputs[1] - if self._weight is not None: + self.label = self._inputs[1] + if self.weight is not None: i = 1 if not has_label else 2 - self._weight = self._inputs[i] + self.weight = self._inputs[i] + if self.base_margin is not None: + self.base_margin = self._inputs[-1] @staticmethod def _get_kw(obj): @@ -100,95 +78,118 @@ def _get_kw(obj): 'columns_value': obj.columns_value} def __call__(self): - inputs = [self._data] - kws = [] - kw = self._get_kw(self._data) - if not self._multi_output: - kw['type'] = 'data' - kws.append(kw) - if self._label is not None: - inputs.append(self._label) - if self._multi_output: - kw = self._get_kw(self._label) - kw['type'] = 'label' - kws.append(kw) - if self._weight is not None: - inputs.append(self._weight) - if self._multi_output: - kw = self._get_kw(self._weight) - kw['type'] = 'weight' - kws.append(kw) - if not self.output_types: - self.output_types = get_output_types(*inputs) + inputs = [self.data] + kw = self._get_kw(self.data) + if self.label is not None: + inputs.append(self.label) + if self.weight is not None: + inputs.append(self.weight) + if self.base_margin is not None: + inputs.append(self.base_margin) - tileables = self.new_tileables(inputs, kws=kws) - if not self._multi_output: - return tileables[0] - return tileables + return self.new_tileable(inputs, **kw) @classmethod - def _tile_multi_output(cls, op): - data, label, weight = op.data, op.label, op.weight - - if has_unknown_shape(data): - yield - - if data.chunk_shape[1] > 1: - # make sure data's second dimension has only 1 chunk - data = yield from recursive_tile(data.rechunk({1: data.shape[1]})) - + def _get_collocated(cls, + op: "ToDMatrix", + data: TileableType, + label: TileableType, + weight: TileableType, + base_margin: TileableType) -> List[TileableType]: + types = ['data', 'label', 'weight', 'base_margin'] nsplit = data.nsplits[0] - # rechunk label - if label is not None: - label = yield from recursive_tile(label.rechunk({0: nsplit})) - # rechunk weight - if weight is not None: - weight = yield from recursive_tile(weight.rechunk({0: nsplit})) - - out_chunkss = [[] for _ in range(op.output_limit)] + out_chunkss = [[] for _ in op.inputs] for i in range(len(nsplit)): data_chunk = data.cix[i, 0] inps = [data_chunk] kws = [] chunk_op = op.copy().reset_key() - chunk_op._data = data_chunk + chunk_op._collocate = True + chunk_op.data = data_chunk + output_types = [get_output_types(data)[0]] data_kw = cls._get_kw(data_chunk) data_kw['index'] = data_chunk.index kws.append(data_kw) - if label is not None: - label_chunk = chunk_op._label = label.cix[i, ] - inps.append(label_chunk) - kw = cls._get_kw(label_chunk) - kw['index'] = label_chunk.index - kw['type'] = 'label' - kws.append(kw) - if weight is not None: - weight_chunk = chunk_op._weight = weight.cix[i, ] - inps.append(weight_chunk) - kw = cls._get_kw(weight_chunk) - kw['index'] = weight_chunk.index - kw['type'] = 'weight' + for type_name, inp in zip(types[1:], [label, weight, base_margin]): + if inp is None: + continue + inp_chunk = inp.cix[i, ] + setattr(chunk_op, type_name, inp_chunk) + inps.append(inp_chunk) + kw = cls._get_kw(inp_chunk) + kw['index'] = inp_chunk.index + kw['type'] = type_name kws.append(kw) + output_types.append(get_output_types(inp)[0]) + chunk_op.output_types = output_types out_chunks = chunk_op.new_chunks(inps, kws=kws) for i, out_chunk in enumerate(out_chunks): out_chunkss[i].append(out_chunk) new_op = op.copy() - params = [out.params.copy() for out in op.outputs] - types = ['data', 'label', 'weight'] - for i, inp in enumerate([data, label, weight]): - if inp is None: + new_op._collocate = True + outs = [data, label, weight, base_margin] + params = [out.params.copy() for out in outs + if out is not None] + output_types = [] + j = 0 + for i, out in enumerate(outs): + if out is None: continue - params[i]['nsplits'] = inp.nsplits - params[i]['chunks'] = out_chunkss[i] - params[i]['type'] = types[i] + params[j]['nsplits'] = out.nsplits + params[j]['chunks'] = out_chunkss[j] + params[j]['type'] = types[i] + output_types.append(get_output_types(out)[0]) + j += 1 + new_op.output_types = output_types return new_op.new_tileables(op.inputs, kws=params) + @staticmethod + def _order_chunk_index(chunks: List[ChunkType]): + ndim = chunks[0].ndim + for i, c in enumerate(chunks): + if ndim == 2: + c._index = (i, 0) + else: + c._index = (i,) + return chunks + @classmethod - def _tile_single_output(cls, op): - from ....core.context import get_context + def tile(cls, op: "MarsDMatrix"): + data, label, weight, base_margin = op.data, op.label, op.weight, op.base_margin - data, label, weight = op.data, op.label, op.weight + if has_unknown_shape(data): + yield + if data.chunk_shape[1] > 1: + # make sure data's second dimension has only 1 chunk + data = yield from recursive_tile(data.rechunk({1: data.shape[1]})) + nsplit = data.nsplits[0] + # rechunk label + if label is not None: + label = yield from recursive_tile(label.rechunk({0: nsplit})) + # rechunk weight + if weight is not None: + weight = yield from recursive_tile(weight.rechunk({0: nsplit})) + # rechunk base_margin + if base_margin is not None: + base_margin = yield from recursive_tile(base_margin.rechunk({0: nsplit})) + + collocated = cls._get_collocated(op, data, label, weight, base_margin) + collocated_chunks = list(itertools.chain.from_iterable( + c.chunks for c in collocated)) + yield collocated_chunks + collocated + + data = build_fetch(collocated[0]) + has_label = False + if label is not None: + has_label = True + label = build_fetch(collocated[1]) + i_weight = -1 + if weight is not None: + i_weight = 1 if not has_label else 2 + weight = build_fetch(collocated[i_weight]) + if base_margin is not None: + base_margin = build_fetch(collocated[-1]) ctx = get_context() @@ -198,31 +199,37 @@ def _tile_single_output(cls, op): data_chunk_workers = [m['bands'][0][0] for m in data_chunk_metas] worker_to_chunks = dict() for i, worker in enumerate(data_chunk_workers): - size = 1 + (label is not None) + (weight is not None) + size = 1 + sum(it is not None for it in [label, weight, base_margin]) if worker not in worker_to_chunks: worker_to_chunks[worker] = [[] for _ in range(size)] worker_to_chunks[worker][0].append(data.chunks[i]) if label is not None: worker_to_chunks[worker][1].append(label.chunks[i]) if weight is not None: - worker_to_chunks[worker][-1].append(weight.chunks[i]) + worker_to_chunks[worker][i_weight].append(weight.chunks[i]) + if base_margin is not None: + worker_to_chunks[worker][-1].append(base_margin.chunks[i]) ind = itertools.count(0) out_chunks = [] for worker, chunks in worker_to_chunks.items(): - data_chunk = concat_chunks(chunks[0]) + data_chunk = concat_chunks(cls._order_chunk_index(chunks[0])) inps = [data_chunk] label_chunk = None if label is not None: - label_chunk = concat_chunks(chunks[1]) + label_chunk = concat_chunks(cls._order_chunk_index(chunks[1])) inps.append(label_chunk) weight_chunk = None if weight is not None: - weight_chunk = concat_chunks(chunks[2]) + weight_chunk = concat_chunks(cls._order_chunk_index(chunks[i_weight])) inps.append(weight_chunk) + base_margin_chunk = None + if base_margin is not None: + base_margin_chunk = concat_chunks(cls._order_chunk_index(chunks[-1])) + inps.append(base_margin_chunk) chunk_op = ToDMatrix(data=data_chunk, label=label_chunk, missing=op.missing, - weight=weight_chunk, feature_names=op.feature_names, - feature_types=op.feature_types, multi_output=False, - output_types=op.output_types) + weight=weight_chunk, base_margin=base_margin_chunk, + feature_names=op.feature_names, feature_types=op.feature_types, + _output_types=op.output_types) kws = data_chunk.params kws['index'] = (next(ind), 0) out_chunks.append(chunk_op.new_chunk(inps, **kws)) @@ -234,21 +241,14 @@ def _tile_single_output(cls, op): kw['nsplits'] = nsplits return new_op.new_tileables(op.inputs, kws=[kw]) - @classmethod - def tile(cls, op): - if op.multi_output: - return (yield from cls._tile_multi_output(op)) - else: - return cls._tile_single_output(op) - @staticmethod def get_xgb_dmatrix(tup): from xgboost import DMatrix - data, label, weight, missing, feature_names, feature_types = tup + data, label, weight, base_margin, missing, feature_names, feature_types = tup data = data.spmatrix if hasattr(data, 'spmatrix') else data - return DMatrix(ensure_own_data(data), label=ensure_own_data(label), - missing=missing, weight=ensure_own_data(weight), + return DMatrix(ensure_own_data(data), label=ensure_own_data(label), missing=missing, + weight=ensure_own_data(weight), base_margin=base_margin, feature_names=feature_names, feature_types=feature_types, nthread=-1) @@ -259,20 +259,31 @@ def _from_ctx_if_not_none(ctx, chunk): return ctx[chunk.key] @classmethod - def execute(cls, ctx, op): - if op.multi_output: + def execute(cls, + ctx: Union[dict, Context], + op: "ToDMatrix"): + if op._collocate: outs = op.outputs ctx[outs[0].key] = ctx[op.inputs[0].key] + has_label = False if op.label is not None: + has_label = True ctx[outs[1].key] = ctx[op.inputs[1].key] if op.weight is not None: + i_weight = 1 if not has_label else 2 + ctx[outs[i_weight].key] = ctx[op.inputs[i_weight].key] + if op.base_margin is not None: ctx[outs[-1].key] = ctx[op.inputs[-1].key] - return else: - ctx[op.outputs[0].key] = ( - cls._from_ctx_if_not_none(ctx, op.data), + out = op.outputs[0] + data = cls._from_ctx_if_not_none(ctx, op.data) + if data is None: + data = np.empty((0, out.shape[1])) + ctx[out.key] = ( + data, cls._from_ctx_if_not_none(ctx, op.label), cls._from_ctx_if_not_none(ctx, op.weight), + cls._from_ctx_if_not_none(ctx, op.base_margin), op.missing, op.feature_names, op.feature_types @@ -287,41 +298,31 @@ def check_data(data): return data -def to_dmatrix(data, label=None, missing=None, weight=None, - feature_names=None, feature_types=None, session=None, run_kwargs=None): +def check_array_like(y: TileableType, name: str) -> TileableType: + if y is None: + return + y = convert_to_tensor_or_dataframe(y) + if isinstance(y, DATAFRAME_TYPE): + y = y.iloc[:, 0] + y = astensor(y) + if y.ndim != 1: + raise ValueError(f'Expecting 1-d {name}, got: {y.ndim}-d') + return y + + +def to_dmatrix(data, label=None, missing=None, weight=None, base_margin=None, + feature_names=None, feature_types=None): data = check_data(data) - if label is not None: - label = convert_to_tensor_or_dataframe(label) - if isinstance(label, DATAFRAME_TYPE): - label = label.iloc[:, 0] - label = astensor(label) - if label.ndim != 1: - raise ValueError(f'Expecting 1-d label, got: {label.ndim}-d') - if weight is not None: - weight = convert_to_tensor_or_dataframe(weight) - if isinstance(weight, DATAFRAME_TYPE): - weight = weight.iloc[:, 0] - weight = astensor(weight) - if weight.ndim != 1: - raise ValueError(f'Expecting 1-d weight, got {weight.ndim}-d') - - op = ToDMatrix(data=data, label=label, missing=missing, weight=weight, - feature_names=feature_names, feature_types=feature_types, - gpu=data.op.gpu, multi_output=True, - output_types=get_output_types(data, label, weight)) - outs = ExecutableTuple(op()) - # Execute first, to make sure the counterpart chunks of data, label and weight are co-allocated - outs.execute(session=session, **(run_kwargs or dict())) - - data = outs[0] - label = None if op.label is None else outs[1] - weight = None if op.weight is None else outs[-1] + label = check_array_like(label, 'label') + weight = check_array_like(weight, 'weight') + base_margin = check_array_like(base_margin, 'base_margin') + # If not multiple outputs, try to collect the chunks on same worker into one # to feed the data into XGBoost for training. - op = ToDMatrix(data=data, label=label, missing=missing, weight=weight, + op = ToDMatrix(data=data, label=label, missing=missing, + weight=weight, base_margin=base_margin, feature_names=feature_names, feature_types=feature_types, - gpu=data.op.gpu, multi_output=False, - output_types=get_output_types(data)) + gpu=data.op.gpu, _output_types=get_output_types(data)) return op() diff --git a/mars/learn/contrib/xgboost/regressor.py b/mars/learn/contrib/xgboost/regressor.py index bb2d99e63a..51d0ab6623 100644 --- a/mars/learn/contrib/xgboost/regressor.py +++ b/mars/learn/contrib/xgboost/regressor.py @@ -19,8 +19,7 @@ XGBRegressor = make_import_error_func('xgboost') if xgboost: - from .dmatrix import MarsDMatrix - from .core import evaluation_matrices + from .core import wrap_evaluation_matrices from .train import train from .predict import predict @@ -29,17 +28,17 @@ class XGBRegressor(XGBScikitLearnBase): Implementation of the scikit-learn API for XGBoost regressor. """ - def fit(self, X, y, sample_weights=None, eval_set=None, sample_weight_eval_set=None, **kw): + def fit(self, X, y, sample_weight=None, base_margin=None, + eval_set=None, sample_weight_eval_set=None, base_margin_eval_set=None, **kw): session = kw.pop('session', None) run_kwargs = kw.pop('run_kwargs', dict()) if kw: raise TypeError(f"fit got an unexpected keyword argument '{next(iter(kw))}'") - dtrain = MarsDMatrix(X, label=y, weight=sample_weights, - session=session, run_kwargs=run_kwargs) + dtrain, evals = wrap_evaluation_matrices( + None, X, y, sample_weight, base_margin, eval_set, + sample_weight_eval_set, base_margin_eval_set) params = self.get_xgb_params() - evals = evaluation_matrices(eval_set, sample_weight_eval_set, - session=session, run_kwargs=run_kwargs) self.evals_result_ = dict() result = train(params, dtrain, num_boost_round=self.get_num_boosting_rounds(), evals=evals, evals_result=self.evals_result_, diff --git a/mars/learn/contrib/xgboost/tests/test_classifier.py b/mars/learn/contrib/xgboost/tests/test_classifier.py index 5d37508067..6ef21af5b7 100644 --- a/mars/learn/contrib/xgboost/tests/test_classifier.py +++ b/mars/learn/contrib/xgboost/tests/test_classifier.py @@ -76,7 +76,7 @@ def test_local_classifier(setup): y_df = md.DataFrame(y) for weight in weights: classifier = XGBClassifier(verbosity=1, n_estimators=2) - classifier.fit(X_raw, y_df, sample_weights=weight) + classifier.fit(X_raw, y_df, sample_weight=weight) prediction = classifier.predict(X_raw) assert prediction.ndim == 1 @@ -85,7 +85,7 @@ def test_local_classifier(setup): # should raise error if weight.ndim > 1 with pytest.raises(ValueError): XGBClassifier(verbosity=1, n_estimators=2).fit( - X_raw, y_df, sample_weights=mt.random.rand(1, 1)) + X_raw, y_df, sample_weight=mt.random.rand(1, 1)) # test binary classifier new_y = (y > 0.5).astype(mt.int32) diff --git a/mars/learn/contrib/xgboost/tests/test_core.py b/mars/learn/contrib/xgboost/tests/test_core.py new file mode 100644 index 0000000000..e1e124b81c --- /dev/null +++ b/mars/learn/contrib/xgboost/tests/test_core.py @@ -0,0 +1,44 @@ +# Copyright 1999-2021 Alibaba Group Holding Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +try: + import xgboost +except ImportError: + xgboost = None + + +from ..... import tensor as mt +if xgboost: + from ..core import wrap_evaluation_matrices + + +@pytest.mark.skipif(xgboost is None, reason='XGBoost not installed') +def test_wrap_evaluation_matrices(): + X = mt.random.rand(100, 3) + y = mt.random.randint(3, size=(100,)) + + eval_set = [(mt.random.rand(10, 3), mt.random.randint(3, size=10))] + with pytest.raises(ValueError): + # sample_weight_eval_set size wrong + wrap_evaluation_matrices(0.0, X, y, None, None, + eval_set, [], None) + + with pytest.raises(ValueError): + wrap_evaluation_matrices(0.0, X, y, None, None, + None, eval_set, None) + + evals = wrap_evaluation_matrices(0.0, X, y, None, None, + eval_set, None, None)[1] + assert len(evals) > 0 diff --git a/mars/learn/contrib/xgboost/tests/test_regressor.py b/mars/learn/contrib/xgboost/tests/test_regressor.py index 343df2577c..93a1a850d0 100644 --- a/mars/learn/contrib/xgboost/tests/test_regressor.py +++ b/mars/learn/contrib/xgboost/tests/test_regressor.py @@ -51,7 +51,7 @@ def test_local_regressor(setup): weight = mt.random.rand(X.shape[0]) classifier = XGBRegressor(verbosity=1, n_estimators=2) regressor.set_params(tree_method='hist') - classifier.fit(X, y, sample_weights=weight) + classifier.fit(X, y, sample_weight=weight) prediction = classifier.predict(X) assert prediction.ndim == 1 diff --git a/mars/learn/contrib/xgboost/tests/test_train.py b/mars/learn/contrib/xgboost/tests/test_train.py index 71a47fa1d5..8166264f71 100644 --- a/mars/learn/contrib/xgboost/tests/test_train.py +++ b/mars/learn/contrib/xgboost/tests/test_train.py @@ -56,3 +56,25 @@ def test_local_train_dataframe(setup): dtrain = MarsDMatrix(X_df, y_series) booster = train({}, dtrain, num_boost_round=2) assert isinstance(booster, Booster) + + +@pytest.mark.skipif(xgboost is None, reason='XGBoost not installed') +def test_train_evals(setup_cluster): + rs = mt.random.RandomState(0) + # keep 1 chunk for X and y + X = rs.rand(n_rows, n_columns, chunk_size=(n_rows, n_columns // 2)) + y = rs.rand(n_rows, chunk_size=n_rows) + base_margin = rs.rand(n_rows, chunk_size=n_rows) + dtrain = MarsDMatrix(X, y, base_margin=base_margin) + eval_x = MarsDMatrix(rs.rand(n_rows, n_columns, chunk_size=n_rows // 5), + rs.rand(n_rows, chunk_size=n_rows // 5)) + evals = [(eval_x, 'eval_x')] + eval_result = dict() + booster = train({}, dtrain, num_boost_round=2, evals=evals, + evals_result=eval_result) + assert isinstance(booster, Booster) + assert len(eval_result) > 0 + + with pytest.raises(TypeError): + train({}, dtrain, num_boost_round=2, evals=[('eval_x', eval_x)], + evals_result=eval_result) diff --git a/mars/learn/contrib/xgboost/train.py b/mars/learn/contrib/xgboost/train.py index 8437b84adc..0527d68a5d 100644 --- a/mars/learn/contrib/xgboost/train.py +++ b/mars/learn/contrib/xgboost/train.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools import pickle from collections import OrderedDict, defaultdict @@ -24,7 +25,7 @@ from ....serialization.serializables import FieldTypes, DictField, KeyField, ListField from ....utils import ensure_own_data from .start_tracker import StartTracker -from .dmatrix import ToDMatrix +from .dmatrix import ToDMatrix, to_dmatrix def _on_serialize_evals(evals_val): @@ -96,16 +97,6 @@ def _get_dmatrix_chunks_workers(ctx, dmatrix): [c.inputs[0].inputs[0].key for c in dmatrix.chunks], fields=['bands']) return [m['bands'][0][0] for m in metas] - @staticmethod - def _get_dmatrix_worker_to_chunk(dmatrix, workers, ctx): - worker_to_chunk = dict() - expect_workers = set(workers) - workers = XGBTrain._get_dmatrix_chunks_workers(ctx, dmatrix) - for w, c in zip(workers, dmatrix.chunks): - if w in expect_workers: - worker_to_chunk[w] = c - return worker_to_chunk - @classmethod def tile(cls, op): ctx = get_context() @@ -113,20 +104,39 @@ def tile(cls, op): inp = op.inputs[0] in_chunks = inp.chunks workers = cls._get_dmatrix_chunks_workers(ctx, inp) + worker_to_in_chunks = dict(zip(workers, in_chunks)) n_chunk = len(in_chunks) - tracker_chunk = StartTracker(n_workers=n_chunk, pure_depends=[True] * n_chunk)\ - .new_chunk(in_chunks, shape=()) out_chunks = [] worker_to_evals = defaultdict(list) if op.evals is not None: for dm, ev in op.evals: - worker_to_chunk = cls._get_dmatrix_worker_to_chunk(dm, workers, ctx) - for worker, chunk in worker_to_chunk.items(): - worker_to_evals[worker].append((chunk, ev)) - for in_chunk, worker in zip(in_chunks, workers): + ev_workers = cls._get_dmatrix_chunks_workers(ctx, dm) + for ev_worker, ev_chunk in zip(ev_workers, dm.chunks): + worker_to_evals[ev_worker].append((ev_chunk, ev)) + + all_workers = set(workers) + all_workers.update(worker_to_evals) + + i = itertools.count(n_chunk) + tracker_chunk = StartTracker( + n_workers=len(all_workers), + pure_depends=[True] * n_chunk).new_chunk(in_chunks, shape=()) + for worker in all_workers: chunk_op = op.copy().reset_key() chunk_op.expect_worker = worker chunk_op._tracker = tracker_chunk + if worker in worker_to_in_chunks: + in_chunk = worker_to_in_chunks[worker] + else: + in_chunk_op = ToDMatrix(data=None, label=None, weight=None, + base_margin=None, missing=inp.op.missing, + feature_names=inp.op.feature_names, + feature_types=inp.op.feature_types, + _output_types=inp.op.output_types) + params = inp.params.copy() + params['index'] = (next(i),) + params['shape'] = (0, inp.shape[1]) + in_chunk = in_chunk_op.new_chunk(None, kws=[params]) chunk_evals = list(worker_to_evals.get(worker, list())) chunk_op._evals = chunk_evals input_chunks = [in_chunk] + [pair[0] for pair in chunk_evals] + [tracker_chunk] @@ -194,7 +204,18 @@ def train(params, dtrain, evals=(), **kwargs): evals_result = kwargs.pop('evals_result', dict()) session = kwargs.pop('session', None) run_kwargs = kwargs.pop('run_kwargs', dict()) - op = XGBTrain(params=params, dtrain=dtrain, evals=evals, kwargs=kwargs) + + processed_evals = [] + if evals: + for eval_dmatrix, name in evals: + if not isinstance(name, str): + raise TypeError('evals must a list of pairs (DMatrix, string)') + if hasattr(eval_dmatrix, 'op') and isinstance(eval_dmatrix.op, ToDMatrix): + processed_evals.append((eval_dmatrix, name)) + else: + processed_evals.append((to_dmatrix(eval_dmatrix), name)) + + op = XGBTrain(params=params, dtrain=dtrain, evals=processed_evals, kwargs=kwargs) t = op() ret = t.execute(session=session, **run_kwargs).fetch(session=session) evals_result.update(ret['history'])