Skip to content

Commit

Permalink
[dask] include multiclass-classification task in tests (#4048)
Browse files Browse the repository at this point in the history
* include multiclass-classification task and task-to-model-factory dicts (task_to_dask_factory, task_to_local_factory)

* define center coordinates; flatten init_scores within each partition for multiclass-classification

* include issue comment and fix linting error
  • Loading branch information
jmoralez authored Mar 10, 2021
1 parent 13680d8 commit 1d7b54d
Showing 1 changed file with 54 additions and 60 deletions.
114 changes: 54 additions & 60 deletions tests/python_package_test/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,21 @@
# see https://distributed.dask.org/en/latest/api.html#distributed.Client.close
CLIENT_CLOSE_TIMEOUT = 120

tasks = ['classification', 'regression', 'ranking']
tasks = ['binary-classification', 'multiclass-classification', 'regression', 'ranking']
data_output = ['array', 'scipy_csr_matrix', 'dataframe', 'dataframe-with-categorical']
data_centers = [[[-4, -4], [4, 4]], [[-4, -4], [4, 4], [-4, 4]]]
group_sizes = [5, 5, 5, 10, 10, 10, 20, 20, 20, 50, 50]
# Maps each task name to the Dask estimator class from `lgb` that the tests
# build a model from (looked up as `task_to_dask_factory[task]`).
# Both classification tasks share the same classifier class.
task_to_dask_factory = {
'regression': lgb.DaskLGBMRegressor,
'binary-classification': lgb.DaskLGBMClassifier,
'multiclass-classification': lgb.DaskLGBMClassifier,
'ranking': lgb.DaskLGBMRanker
}
# Maps each task name to the corresponding `lgb.LGBM*` estimator class
# (looked up as `task_to_local_factory[task]`); keys mirror those of
# `task_to_dask_factory`. Both classification tasks share the same class.
task_to_local_factory = {
'regression': lgb.LGBMRegressor,
'binary-classification': lgb.LGBMClassifier,
'multiclass-classification': lgb.LGBMClassifier,
'ranking': lgb.LGBMRanker
}

pytestmark = [
pytest.mark.skipif(getenv('TASK', '') == 'mpi', reason='Fails to run with MPI interface'),
Expand Down Expand Up @@ -120,8 +131,14 @@ def _create_ranking_data(n_samples=100, output='array', chunk_size=50, **kwargs)
return X, y, w, g_rle, dX, dy, dw, dg


def _create_data(objective, n_samples=100, centers=2, output='array', chunk_size=50):
if objective == 'classification':
def _create_data(objective, n_samples=100, output='array', chunk_size=50):
if objective.endswith('classification'):
if objective == 'binary-classification':
centers = [[-4, -4], [4, 4]]
elif objective == 'multiclass-classification':
centers = [[-4, -4], [4, 4], [-4, 4]]
else:
raise ValueError(f"Unknown classification task '{objective}'")
X, y = make_blobs(n_samples=n_samples, centers=centers, random_state=42)
elif objective == 'regression':
X, y = make_regression(n_samples=n_samples, random_state=42)
Expand Down Expand Up @@ -206,12 +223,11 @@ def _unpickle(filepath, serializer):


@pytest.mark.parametrize('output', data_output)
@pytest.mark.parametrize('centers', data_centers)
def test_classifier(output, centers, client):
@pytest.mark.parametrize('task', ['binary-classification', 'multiclass-classification'])
def test_classifier(output, task, client):
X, y, w, dX, dy, dw = _create_data(
objective='classification',
output=output,
centers=centers
objective=task,
output=output
)

params = {
Expand Down Expand Up @@ -273,12 +289,11 @@ def test_classifier(output, centers, client):


@pytest.mark.parametrize('output', data_output)
@pytest.mark.parametrize('centers', data_centers)
def test_classifier_pred_contrib(output, centers, client):
@pytest.mark.parametrize('task', ['binary-classification', 'multiclass-classification'])
def test_classifier_pred_contrib(output, task, client):
X, y, w, dX, dy, dw = _create_data(
objective='classification',
output=output,
centers=centers
objective=task,
output=output
)

params = {
Expand Down Expand Up @@ -354,7 +369,7 @@ def test_find_random_open_port(client):


def test_training_does_not_fail_on_port_conflicts(client):
_, _, _, dX, dy, dw = _create_data('classification', output='array')
_, _, _, dX, dy, dw = _create_data('binary-classification', output='array')

lightgbm_default_port = 12400
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
Expand Down Expand Up @@ -640,17 +655,13 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, c
output='array',
group=None
)
model_factory = lgb.DaskLGBMRanker
else:
_, _, _, dX, dy, _ = _create_data(
objective=task,
output='array',
)
dg = None
if task == 'classification':
model_factory = lgb.DaskLGBMClassifier
elif task == 'regression':
model_factory = lgb.DaskLGBMRegressor
model_factory = task_to_dask_factory[task]

params = {
"time_out": 5,
Expand Down Expand Up @@ -744,12 +755,7 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
)
dg_2 = None

if task == 'ranking':
model_factory = lgb.DaskLGBMRanker
elif task == 'classification':
model_factory = lgb.DaskLGBMClassifier
elif task == 'regression':
model_factory = lgb.DaskLGBMRegressor
model_factory = task_to_dask_factory[task]

params = {
"time_out": 5,
Expand Down Expand Up @@ -970,21 +976,16 @@ def collection_to_single_partition(collection):
output=output,
group=None
)
dask_model_factory = lgb.DaskLGBMRanker
local_model_factory = lgb.LGBMRanker
else:
X, y, w, dX, dy, dw = _create_data(
objective=task,
output=output
)
g = None
dg = None
if task == 'classification':
dask_model_factory = lgb.DaskLGBMClassifier
local_model_factory = lgb.LGBMClassifier
elif task == 'regression':
dask_model_factory = lgb.DaskLGBMRegressor
local_model_factory = lgb.LGBMRegressor

dask_model_factory = task_to_dask_factory[task]
local_model_factory = task_to_local_factory[task]

dX = collection_to_single_partition(dX)
dy = collection_to_single_partition(dy)
Expand Down Expand Up @@ -1029,18 +1030,15 @@ def test_network_params_not_required_but_respected_if_given(client, task, output
group=None,
chunk_size=10,
)
dask_model_factory = lgb.DaskLGBMRanker
else:
_, _, _, dX, dy, _ = _create_data(
objective=task,
output=output,
chunk_size=10,
)
dg = None
if task == 'classification':
dask_model_factory = lgb.DaskLGBMClassifier
elif task == 'regression':
dask_model_factory = lgb.DaskLGBMRegressor

dask_model_factory = task_to_dask_factory[task]

# rebalance data to be sure that each worker has a piece of the data
if output == 'array':
Expand Down Expand Up @@ -1103,18 +1101,15 @@ def test_machines_should_be_used_if_provided(task, output):
group=None,
chunk_size=10,
)
dask_model_factory = lgb.DaskLGBMRanker
else:
_, _, _, dX, dy, _ = _create_data(
objective=task,
output=output,
chunk_size=10,
)
dg = None
if task == 'classification':
dask_model_factory = lgb.DaskLGBMClassifier
elif task == 'regression':
dask_model_factory = lgb.DaskLGBMRegressor

dask_model_factory = task_to_dask_factory[task]

# rebalance data to be sure that each worker has a piece of the data
if output == 'array':
Expand Down Expand Up @@ -1201,17 +1196,15 @@ def test_training_succeeds_when_data_is_dataframe_and_label_is_column_array(
output='dataframe',
group=None
)
model_factory = lgb.DaskLGBMRanker
else:
_, _, _, dX, dy, dw = _create_data(
objective=task,
output='dataframe',
)
dg = None
if task == 'classification':
model_factory = lgb.DaskLGBMClassifier
elif task == 'regression':
model_factory = lgb.DaskLGBMRegressor

model_factory = task_to_dask_factory[task]

dy = dy.to_dask_array(lengths=True)
dy_col_array = dy.reshape(-1, 1)
assert len(dy_col_array.shape) == 2 and dy_col_array.shape[1] == 1
Expand All @@ -1231,10 +1224,7 @@ def test_training_succeeds_when_data_is_dataframe_and_label_is_column_array(

@pytest.mark.parametrize('task', tasks)
@pytest.mark.parametrize('output', data_output)
def test_init_score(
task,
output,
client):
def test_init_score(task, output, client):
if task == 'ranking' and output == 'scipy_csr_matrix':
pytest.skip('LGBMRanker is not currently tested on sparse matrices')

Expand All @@ -1243,28 +1233,32 @@ def test_init_score(
output=output,
group=None
)
model_factory = lgb.DaskLGBMRanker
else:
_, _, _, dX, dy, dw = _create_data(
objective=task,
output=output,
)
dg = None
if task == 'classification':
model_factory = lgb.DaskLGBMClassifier
elif task == 'regression':
model_factory = lgb.DaskLGBMRegressor

model_factory = task_to_dask_factory[task]

params = {
'n_estimators': 1,
'num_leaves': 2,
'time_out': 5
}
init_score = random.random()
# init_scores must be a 1D array, even for multiclass classification
# where you need to provide 1 score per class for each row in X
# https://github.com/microsoft/LightGBM/issues/4046
size_factor = 1
if task == 'multiclass-classification':
size_factor = 3 # number of classes

if output.startswith('dataframe'):
init_scores = dy.map_partitions(lambda x: pd.Series([init_score] * x.size))
init_scores = dy.map_partitions(lambda x: pd.Series([init_score] * x.size * size_factor))
else:
init_scores = da.full_like(dy, fill_value=init_score, dtype=np.float64)
init_scores = dy.map_blocks(lambda x: np.repeat(init_score, x.size * size_factor))
model = model_factory(client=client, **params)
model.fit(dX, dy, sample_weight=dw, init_score=init_scores, group=dg)
# value of the root node is 0 when init_score is set
Expand Down

0 comments on commit 1d7b54d

Please sign in to comment.