diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py
index ba84edca87c2..4251af78f8bc 100644
--- a/python-package/lightgbm/dask.py
+++ b/python-package/lightgbm/dask.py
@@ -21,7 +21,7 @@
 from dask.distributed import Client, default_client, get_worker, wait
 
 from .basic import _ConfigAliases, _LIB, _safe_call
-from .sklearn import LGBMClassifier, LGBMRegressor
+from .sklearn import LGBMClassifier, LGBMRegressor, LGBMRanker
 
 logger = logging.getLogger(__name__)
 
@@ -133,15 +133,24 @@ def _train_part(params, model_factory, list_of_parts, worker_address_to_port, re
     }
     params.update(network_params)
 
+    is_ranker = issubclass(model_factory, LGBMRanker)
+
     # Concatenate many parts into one
     parts = tuple(zip(*list_of_parts))
     data = _concat(parts[0])
     label = _concat(parts[1])
-    weight = _concat(parts[2]) if len(parts) == 3 else None
 
     try:
         model = model_factory(**params)
-        model.fit(data, label, sample_weight=weight, **kwargs)
+
+        if is_ranker:
+            group = _concat(parts[-1])
+            weight = _concat(parts[2]) if len(parts) == 4 else None
+            model.fit(data, y=label, sample_weight=weight, group=group, **kwargs)
+        else:
+            weight = _concat(parts[2]) if len(parts) == 3 else None
+            model.fit(data, y=label, sample_weight=weight, **kwargs)
+
     finally:
         _safe_call(_LIB.LGBM_NetworkFree())
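The layout `_train_part` relies on here is that every element of `list_of_parts` is a tuple of co-located pieces, with the group sizes always in the last position when a ranker is trained; that is why `parts[-1]` is safe for both the 3-tuple (data, label, group) and 4-tuple (data, label, weight, group) cases. A standalone sketch of the concatenation step, using plain NumPy stand-ins for the distributed parts (`_concat` reduces to `np.concatenate` for arrays):

```python
import numpy as np

# hypothetical ranker input: two co-located parts, each (data, label, weight, group)
list_of_parts = [
    (np.ones((3, 2)), np.array([2, 1, 0]), np.ones(3), np.array([3])),
    (np.ones((2, 2)), np.array([0, 1]), np.ones(2), np.array([2])),
]

# the same transpose-then-concatenate step as in _train_part
parts = tuple(zip(*list_of_parts))
data = np.concatenate(parts[0])    # (5, 2) feature matrix
label = np.concatenate(parts[1])   # 5 relevance labels
group = np.concatenate(parts[-1])  # [3, 2] -- group sizes are always the last field
weight = np.concatenate(parts[2]) if len(parts) == 4 else None
```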
@@ -156,7 +165,7 @@ def _split_to_parts(data, is_matrix):
     return parts
 
 
-def _train(client, data, label, params, model_factory, weight=None, **kwargs):
+def _train(client, data, label, params, model_factory, sample_weight=None, group=None, **kwargs):
     """Inner train routine.
 
     Parameters
@@ -167,22 +176,36 @@ def _train(client, data, label, params, model_factory, weight=None, **kwargs):
     y : dask array of shape = [n_samples]
         The target values (class labels in classification, real numbers in regression).
     params : dict
-    model_factory : lightgbm.LGBMClassifier or lightgbm.LGBMRegressor class
+    model_factory : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
     sample_weight : array-like of shape = [n_samples] or None, optional (default=None)
-        Weights of training data.
+        Weights of training data.
+    group : array-like or None, optional (default=None)
+        Group/query data.
+        Only used in the learning-to-rank task.
+        sum(group) = n_samples.
+        For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
+        where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.
     """
     params = deepcopy(params)
 
     # Split arrays/dataframes into parts. Arrange parts into tuples to enforce co-locality
     data_parts = _split_to_parts(data, is_matrix=True)
     label_parts = _split_to_parts(label, is_matrix=False)
-    if weight is None:
-        parts = list(map(delayed, zip(data_parts, label_parts)))
+    weight_parts = _split_to_parts(sample_weight, is_matrix=False) if sample_weight is not None else None
+    group_parts = _split_to_parts(group, is_matrix=False) if group is not None else None
+
+    # choose between four options of (sample_weight, group) being (un)specified
+    if weight_parts is None and group_parts is None:
+        parts = zip(data_parts, label_parts)
+    elif weight_parts is not None and group_parts is None:
+        parts = zip(data_parts, label_parts, weight_parts)
+    elif weight_parts is None and group_parts is not None:
+        parts = zip(data_parts, label_parts, group_parts)
     else:
-        weight_parts = _split_to_parts(weight, is_matrix=False)
-        parts = list(map(delayed, zip(data_parts, label_parts, weight_parts)))
+        parts = zip(data_parts, label_parts, weight_parts, group_parts)
 
     # Start computation in the background
+    parts = list(map(delayed, parts))
     parts = client.compute(parts)
     wait(parts)
@@ -281,13 +304,13 @@ def _predict(model, data, proba=False, dtype=np.float32, **kwargs):
 
     Parameters
     ----------
-    model :
+    model : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker
     data : dask array of shape = [n_samples, n_features]
         Input feature matrix.
     proba : bool
-        Should method return results of predict_proba (proba == True) or predict (proba == False)
+        Should method return results of predict_proba (proba == True) or predict (proba == False).
     dtype : np.dtype
-        Dtype of the output
+        Dtype of the output.
     kwargs : other parameters passed to predict or predict_proba method
     """
     if isinstance(data, dd._Frame):
@@ -304,13 +327,14 @@ def _predict(model, data, proba=False, dtype=np.float32, **kwargs):
 
 
 class _LGBMModel:
-    def _fit(self, model_factory, X, y=None, sample_weight=None, client=None, **kwargs):
+    def _fit(self, model_factory, X, y=None, sample_weight=None, group=None, client=None, **kwargs):
         """Docstring is inherited from the LGBMModel."""
         if client is None:
             client = default_client()
 
         params = self.get_params(True)
-        model = _train(client, X, y, params, model_factory, sample_weight, **kwargs)
+        model = _train(client, data=X, label=y, params=params, model_factory=model_factory,
+                       sample_weight=sample_weight, group=group, **kwargs)
 
         self.set_params(**model.get_params())
         self._copy_extra_params(model, self)
@@ -335,8 +359,8 @@ class DaskLGBMClassifier(_LGBMModel, LGBMClassifier):
     """Distributed version of lightgbm.LGBMClassifier."""
 
     def fit(self, X, y=None, sample_weight=None, client=None, **kwargs):
-        """Docstring is inherited from the LGBMModel."""
-        return self._fit(LGBMClassifier, X, y, sample_weight, client, **kwargs)
+        """Docstring is inherited from the lightgbm.LGBMClassifier.fit."""
+        return self._fit(LGBMClassifier, X=X, y=y, sample_weight=sample_weight, client=client, **kwargs)
     fit.__doc__ = LGBMClassifier.fit.__doc__
 
     def predict(self, X, **kwargs):
@@ -364,7 +388,7 @@ class DaskLGBMRegressor(_LGBMModel, LGBMRegressor):
 
     def fit(self, X, y=None, sample_weight=None, client=None, **kwargs):
         """Docstring is inherited from the lightgbm.LGBMRegressor.fit."""
-        return self._fit(LGBMRegressor, X, y, sample_weight, client, **kwargs)
+        return self._fit(LGBMRegressor, X=X, y=y, sample_weight=sample_weight, client=client, **kwargs)
     fit.__doc__ = LGBMRegressor.fit.__doc__
 
     def predict(self, X, **kwargs):
@@ -380,3 +404,29 @@ def to_local(self):
         model : lightgbm.LGBMRegressor
         """
         return self._to_local(LGBMRegressor)
+
+
+class DaskLGBMRanker(_LGBMModel, LGBMRanker):
+    """Distributed version of lightgbm.LGBMRanker."""
+
+    def fit(self, X, y=None, sample_weight=None, init_score=None, group=None, client=None, **kwargs):
+        """Docstring is inherited from the lightgbm.LGBMRanker.fit."""
+        if init_score is not None:
+            raise RuntimeError('init_score is not currently supported in lightgbm.dask')
+
+        return self._fit(LGBMRanker, X=X, y=y, sample_weight=sample_weight, group=group, client=client, **kwargs)
+    fit.__doc__ = LGBMRanker.fit.__doc__
+
+    def predict(self, X, **kwargs):
+        """Docstring is inherited from the lightgbm.LGBMRanker.predict."""
+        return _predict(self.to_local(), X, **kwargs)
+    predict.__doc__ = LGBMRanker.predict.__doc__
+
+    def to_local(self):
+        """Create regular version of lightgbm.LGBMRanker from the distributed version.
+
+        Returns
+        -------
+        model : lightgbm.LGBMRanker
+        """
+        return self._to_local(LGBMRanker)
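Usage of the new estimator mirrors DaskLGBMClassifier and DaskLGBMRegressor, with group sizes passed as a Dask collection whose chunks line up with the query groups in X. A minimal sketch, assuming a running Dask cluster and hypothetical toy data:

```python
import dask.array as da
import numpy as np
from distributed import Client

import lightgbm.dask as dlgbm

client = Client()  # assumes a Dask cluster is already available

# two query groups of 5 documents each; array chunks follow group boundaries
X = da.random.random((10, 4), chunks=(5, 4))
y = da.from_array(np.random.randint(0, 3, size=10), chunks=5)
g = da.from_array(np.array([5, 5]), chunks=1)  # one group size per chunk

ranker = dlgbm.DaskLGBMRanker(n_estimators=10)
ranker = ranker.fit(X, y, group=g, client=client)
preds = ranker.predict(X).compute()
```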
diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py
index e793872ee4fb..960e1a56da63 100644
--- a/tests/python_package_test/test_dask.py
+++ b/tests/python_package_test/test_dask.py
@@ -1,30 +1,37 @@
 # coding: utf-8
+"""Tests for lightgbm.dask module"""
+
+import itertools
 import os
 import socket
 import sys
 
 import pytest
-if not sys.platform.startswith("linux"):
-    pytest.skip("lightgbm.dask is currently supported in Linux environments", allow_module_level=True)
+if not sys.platform.startswith('linux'):
+    pytest.skip('lightgbm.dask is currently supported in Linux environments', allow_module_level=True)
 
 import dask.array as da
 import dask.dataframe as dd
 import numpy as np
 import pandas as pd
+from scipy.stats import spearmanr
 import scipy.sparse
 from dask.array.utils import assert_eq
 from dask_ml.metrics import accuracy_score, r2_score
 from distributed.utils_test import client, cluster_fixture, gen_cluster, loop
 from sklearn.datasets import make_blobs, make_regression
+from sklearn.utils import check_random_state
 
 import lightgbm
 import lightgbm.dask as dlgbm
 
 data_output = ['array', 'scipy_csr_matrix', 'dataframe']
 data_centers = [[[-4, -4], [4, 4]], [[-4, -4], [4, 4], [-4, 4]]]
+group_sizes = [5, 5, 5, 10, 10, 10, 20, 20, 20, 50, 50]
 
 pytestmark = [
-    pytest.mark.skipif(os.getenv("TASK", "") == "mpi", reason="Fails to run with MPI interface")
+    pytest.mark.skipif(os.getenv('TASK', '') == 'mpi', reason='Fails to run with MPI interface'),
+    pytest.mark.skipif(os.getenv('TASK', '') == 'gpu', reason='Fails to run with GPU interface')
 ]
@@ -37,6 +44,135 @@ def listen_port():
 
 listen_port.port = 13000
 
+
+def _make_ranking(n_samples=100, n_features=20, n_informative=5, gmax=2,
+                  group=None, random_gs=False, avg_gs=10, random_state=0):
+    """Generate a learning-to-rank dataset - feature vectors grouped together with
+    integer-valued graded relevance scores. Replace this with a sklearn.datasets function
+    if ranking objective becomes supported in sklearn.datasets module.
+
+    Parameters
+    ----------
+    n_samples : int, optional (default=100)
+        Total number of documents (records) in the dataset.
+    n_features : int, optional (default=20)
+        Total number of features in the dataset.
+    n_informative : int, optional (default=5)
+        Number of features that are "informative" for ranking, as they are bias + beta * y
+        where bias and beta are standard normal variates. If this is greater than n_features, the dataset will have
+        n_features features, all of which will be informative.
+    gmax : int, optional (default=2)
+        Maximum graded relevance value for creating relevance/target vector. If you set this to 2, for example, all
+        documents in a group will have relevance scores of either 0, 1, or 2.
+    group : array-like, optional (default=None)
+        1-d array or list of group sizes. When `group` is specified, this overrides n_samples, random_gs, and
+        avg_gs by simply creating groups with sizes group[0], ..., group[-1].
+    random_gs : bool, optional (default=False)
+        True will make group sizes ~ Poisson(avg_gs), False will make group sizes == avg_gs.
+    avg_gs : int, optional (default=10)
+        Average number of documents (records) in each group.
+
+    Returns
+    -------
+    X : 2-d np.ndarray of shape = [n_samples (or np.sum(group)), n_features]
+        Input feature matrix for ranking objective.
+    y : 1-d np.array of shape = [n_samples (or np.sum(group))]
+        Integer-graded relevance scores.
+    group_ids : 1-d np.array of shape = [n_samples (or np.sum(group))]
+        Array of group ids, each value indicates to which group each record belongs.
+    """
+    rnd_generator = check_random_state(random_state)
+
+    y_vec, group_id_vec = np.empty((0,), dtype=int), np.empty((0,), dtype=int)
+    gid = 0
+
+    # build target, group ID vectors.
+    relvalues = range(gmax + 1)
+
+    # build y/target and group-id vectors with user-specified group sizes.
+    if group is not None and hasattr(group, '__len__'):
+        n_samples = np.sum(group)
+
+        for i, gsize in enumerate(group):
+            y_vec = np.concatenate((y_vec, rnd_generator.choice(relvalues, size=gsize, replace=True)))
+            group_id_vec = np.concatenate((group_id_vec, [i] * gsize))
+
+    # build y/target and group-id vectors according to n_samples, avg_gs, and random_gs.
+    else:
+        while len(y_vec) < n_samples:
+            gsize = avg_gs if not random_gs else rnd_generator.poisson(avg_gs)
+
+            # groups should contain at least 1 document for the pairwise learning objective.
+            if gsize < 1:
+                continue
+
+            y_vec = np.append(y_vec, rnd_generator.choice(relvalues, size=gsize, replace=True))
+            group_id_vec = np.append(group_id_vec, [gid] * gsize)
+            gid += 1
+
+        y_vec, group_id_vec = y_vec[:n_samples], group_id_vec[:n_samples]
+
+    # build feature data, X. Transform first few into informative features.
+    n_informative = max(min(n_features, n_informative), 0)
+    X = rnd_generator.uniform(size=(n_samples, n_features))
+
+    for j in range(n_informative):
+        bias, coef = rnd_generator.normal(size=2)
+        X[:, j] = bias + coef * y_vec
+
+    return X, y_vec, group_id_vec
+
+
+def _create_ranking_data(n_samples=100, output='array', chunk_size=50, **kwargs):
+    X, y, g = _make_ranking(n_samples=n_samples, random_state=42, **kwargs)
+    rnd = np.random.RandomState(42)
+    w = rnd.rand(X.shape[0]) * 0.01
+    g_rle = np.array([len(list(grp)) for _, grp in itertools.groupby(g)])
+
+    if output == 'dataframe':
+
+        # add target, weight, and group to DataFrame so that partitions abide by group boundaries.
+        X_df = pd.DataFrame(X, columns=[f'feature_{i}' for i in range(X.shape[1])])
+        X = X_df.copy()
+        X_df = X_df.assign(y=y, g=g, w=w)
+
+        # set_index ensures partitions are based on group id.
+        # See https://stackoverflow.com/questions/49532824/dask-dataframe-split-partitions-based-on-a-column-or-function.
+        X_df.set_index('g', inplace=True)
+        dX = dd.from_pandas(X_df, chunksize=chunk_size)
+
+        # separate target, weight from features.
+        dy = dX['y']
+        dw = dX['w']
+        dX = dX.drop(columns=['y', 'w'])
+        dg = dX.index.to_series()
+
+        # encode group identifiers into run-length encoding, the format LGBMRanker is expecting
+        # so that within each partition, sum(g) = n_samples.
+        dg = dg.map_partitions(lambda p: p.groupby('g', sort=False).apply(lambda z: z.shape[0]))
+
+    elif output == 'array':
+
+        # ranking arrays: one chunk per group. Each chunk must include all columns.
+        p = X.shape[1]
+        dX, dy, dw, dg = [], [], [], []
+        for g_idx, rhs in enumerate(np.cumsum(g_rle)):
+            lhs = rhs - g_rle[g_idx]
+            dX.append(da.from_array(X[lhs:rhs, :], chunks=(rhs - lhs, p)))
+            dy.append(da.from_array(y[lhs:rhs]))
+            dw.append(da.from_array(w[lhs:rhs]))
+            dg.append(da.from_array(np.array([g_rle[g_idx]])))
+
+        dX = da.concatenate(dX, axis=0)
+        dy = da.concatenate(dy, axis=0)
+        dw = da.concatenate(dw, axis=0)
+        dg = da.concatenate(dg, axis=0)
+
+    else:
+        raise ValueError('Ranking data creation only supported for Dask arrays and dataframes')
+
+    return X, y, w, g_rle, dX, dy, dw, dg
+
+
 def _create_data(objective, n_samples=100, centers=2, output='array', chunk_size=50):
     if objective == 'classification':
         X, y = make_blobs(n_samples=n_samples, centers=centers, random_state=42)
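`_make_ranking` returns per-row group ids, while `LGBMRanker.fit` expects per-group sizes; the `itertools.groupby` run-length encoding in `_create_ranking_data` bridges the two. A quick illustration of that step in isolation:

```python
import itertools

import numpy as np

g = np.array([0, 0, 0, 1, 1, 2])  # sorted per-row group ids, as _make_ranking produces
g_rle = np.array([len(list(grp)) for _, grp in itertools.groupby(g)])
print(g_rle)  # [3 2 1] -- per-group sizes, usable as LGBMRanker's `group` argument
```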
@@ -96,6 +232,8 @@ def test_classifier(output, centers, client, listen_port):
     assert_eq(y, p2)
     assert_eq(p1_proba, p2_proba, atol=0.3)
 
+    client.close()
+
 
 def test_training_does_not_fail_on_port_conflicts(client):
     _, _, _, dX, dy, dw = _create_data('classification', output='array')
@@ -118,6 +256,8 @@ def test_training_does_not_fail_on_port_conflicts(client):
     )
     assert dask_classifier.booster_
 
+    client.close()
+
 
 def test_classifier_local_predict(client, listen_port):
     X, y, w, dX, dy, dw = _create_data('classification', output='array')
@@ -139,6 +279,8 @@ def test_classifier_local_predict(client, listen_port):
     assert_eq(y, p1)
     assert_eq(y, p2)
 
+    client.close()
+
 
 @pytest.mark.parametrize('output', data_output)
 def test_regressor(output, client, listen_port):
@@ -170,6 +312,8 @@ def test_regressor(output, client, listen_port):
     assert_eq(y, p1, rtol=1., atol=100.)
     assert_eq(y, p2, rtol=1., atol=50.)
 
+    client.close()
+
 
 @pytest.mark.parametrize('output', data_output)
 @pytest.mark.parametrize('alpha', [.1, .5, .9])
@@ -204,6 +348,8 @@ def test_regressor_quantile(output, client, listen_port, alpha):
     np.testing.assert_allclose(q1, alpha, atol=0.2)
     np.testing.assert_allclose(q2, alpha, atol=0.2)
 
+    client.close()
+
 
 def test_regressor_local_predict(client, listen_port):
     X, y, _, dX, dy, dw = _create_data('regression', output='array')
@@ -226,6 +372,54 @@ def test_regressor_local_predict(client, listen_port):
     assert_eq(p1, p2)
     assert_eq(s1, s2)
 
+    client.close()
+
+
+@pytest.mark.parametrize('output', ['array', 'dataframe'])
+@pytest.mark.parametrize('group', [None, group_sizes])
+def test_ranker(output, client, listen_port, group):
+
+    X, y, w, g, dX, dy, dw, dg = _create_ranking_data(output=output, group=group)
+
+    # use many trees + leaves to overfit, helping ensure that the dask data-parallel strategy matches that of
+    # the serial learner. See https://github.com/microsoft/LightGBM/issues/3292#issuecomment-671288210.
+    dask_ranker = dlgbm.DaskLGBMRanker(time_out=5, local_listen_port=listen_port, tree_learner_type='data_parallel',
+                                       n_estimators=50, num_leaves=20, seed=42, min_child_samples=1)
+    dask_ranker = dask_ranker.fit(dX, dy, sample_weight=dw, group=dg, client=client)
+    rnkvec_dask = dask_ranker.predict(dX)
+    rnkvec_dask = rnkvec_dask.compute()
+
+    local_ranker = lightgbm.LGBMRanker(n_estimators=50, num_leaves=20, seed=42, min_child_samples=1)
+    local_ranker.fit(X, y, sample_weight=w, group=g)
+    rnkvec_local = local_ranker.predict(X)
+
+    # distributed ranker should be able to rank decently well and should
+    # have high rank correlation with scores from serial ranker.
+    dcor = spearmanr(rnkvec_dask, y).correlation
+    assert dcor > 0.6
+    assert spearmanr(rnkvec_dask, rnkvec_local).correlation > 0.9
+
+    client.close()
+
+
+@pytest.mark.parametrize('output', ['array', 'dataframe'])
+@pytest.mark.parametrize('group', [None, group_sizes])
+def test_ranker_local_predict(output, client, listen_port, group):
+
+    X, y, w, g, dX, dy, dw, dg = _create_ranking_data(output=output, group=group)
+
+    dask_ranker = dlgbm.DaskLGBMRanker(time_out=5, local_listen_port=listen_port, tree_learner='data',
+                                       n_estimators=10, num_leaves=10, seed=42, min_child_samples=1)
+    dask_ranker = dask_ranker.fit(dX, dy, group=dg, client=client)
+    rnkvec_dask = dask_ranker.predict(dX)
+    rnkvec_dask = rnkvec_dask.compute()
+    rnkvec_local = dask_ranker.to_local().predict(X)
+
+    # distributed and to-local scores should be the same.
+    assert_eq(rnkvec_dask, rnkvec_local)
+
+    client.close()
+
 
 def test_find_open_port_works():
     worker_ip = '127.0.0.1'
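The dataframe code path in `_create_ranking_data` depends on one non-obvious trick: setting the group id as the index before `dd.from_pandas` keeps rows that share an index value, and therefore a query group, within the same partition. A self-contained sketch of that behavior, with hypothetical data:

```python
import dask.dataframe as dd
import numpy as np
import pandas as pd

# index by group id so dask's divisions cannot split a query group across partitions
df = pd.DataFrame({'feature_0': np.arange(6, dtype=float),
                   'g': [0, 0, 0, 1, 1, 2]}).set_index('g')
ddf = dd.from_pandas(df, npartitions=2)

# per-partition group sizes, computed the same way as in _create_ranking_data
sizes = ddf.index.to_series().map_partitions(
    lambda p: p.groupby('g', sort=False).apply(lambda z: z.shape[0]))
print(sizes.compute())  # 3, 2, 1 -- each group counted within a single partition
```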