Skip to content

Commit

Permalink
requested changes: docstrings, dask_ml, tuples for list_of_parts
Browse files Browse the repository at this point in the history
  • Loading branch information
ffineis committed Jan 16, 2021
1 parent 78358e2 commit d6e5209
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .ci/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ if [[ $TASK == "lint" ]]; then
"r-lintr>=2.0"
pip install --user cpplint
echo "Linting Python code"
pycodestyle --ignore=E402,E501,W503 --exclude=./compute,./eigen,./.nuget,./external_libs . || exit -1
pycodestyle --ignore=E501,W503 --exclude=./compute,./eigen,./.nuget,./external_libs . || exit -1
pydocstyle --convention=numpy --add-ignore=D105 --match-dir="^(?!^compute|^eigen|external_libs|test|example).*" --match="(?!^test_|setup).*\.py" . || exit -1
echo "Linting R code"
Rscript ${BUILD_DIRECTORY}/.ci/lint_r_code.R ${BUILD_DIRECTORY} || exit -1
Expand Down
40 changes: 22 additions & 18 deletions python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,17 +135,19 @@ def _train_part(params, model_factory, list_of_parts, worker_address_to_port, re
is_ranker = model_factory.__qualname__ == 'LGBMRanker'

# Concatenate many parts into one
data = _concat([d['X'] for d in list_of_parts])
label = _concat([d['y'] for d in list_of_parts])
weight = _concat([d['weight'] for d in list_of_parts]) if 'weight' in list_of_parts[0] else None
parts = tuple(zip(*list_of_parts))
data = _concat(parts[0])
label = _concat(parts[1])

try:
model = model_factory(**params)

if is_ranker:
group = _concat([d['group'] for d in list_of_parts])
group = _concat(parts[-1])
weight = _concat(parts[2]) if len(parts) == 4 else None
model.fit(data, y=label, sample_weight=weight, group=group, **kwargs)
else:
weight = _concat(parts[2]) if len(parts) == 3 else None
model.fit(data, y=label, sample_weight=weight, **kwargs)

finally:
Expand Down Expand Up @@ -176,24 +178,26 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
model_factory : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
sample_weight : array-like of shape = [n_samples] or None, optional (default=None)
Weights of training data.
group : array-like
Group/query data, used for ranking task. sum(group) = n_samples.
group : array-like where sum(group) = [n_samples] or None for non-ranking objectives (default=None)
Group/query data, only used for ranking task. sum(group) = n_samples. For example,
if you have a 100-record dataset with `group = [10, 20, 40, 10, 10]`, that means that you have
5 groups, where the first 10 records are in the first group, records 11-30 are the second group, etc.
"""
# Split arrays/dataframes into parts. Arrange parts into dicts to enforce co-locality
data_parts = _split_to_parts(data, is_matrix=True)
label_parts = _split_to_parts(label, is_matrix=False)
parts = [{'X': x, 'y': y} for (x, y) in zip(data_parts, label_parts)]

# append weight, group vectors to part dicts when needed.
if sample_weight is not None:
weight_parts = _split_to_parts(sample_weight, is_matrix=False)
for i, d in enumerate(parts):
parts[i] = {**d, 'weight': weight_parts[i]}

if group is not None:
group_parts = _split_to_parts(group, is_matrix=False)
for i, d in enumerate(parts):
parts[i] = {**d, 'group': group_parts[i]}
weight_parts = _split_to_parts(sample_weight, is_matrix=False) if sample_weight is not None else None
group_parts = _split_to_parts(group, is_matrix=False) if group is not None else None

# choose between four options of (sample_weight, group) being (un)specified
if weight_parts is None and group_parts is None:
parts = zip(data_parts, label_parts)
elif weight_parts is not None and group_parts is None:
parts = zip(data_parts, label_parts, weight_parts)
elif weight_parts is None and group_parts is not None:
parts = zip(data_parts, label_parts, group_parts)
else:
parts = zip(data_parts, label_parts, weight_parts, group_parts)

# Start computation in the background
parts = list(map(delayed, parts))
Expand Down
2 changes: 1 addition & 1 deletion tests/python_package_test/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def test_ranker(output, client, listen_port, group):

# difference between distributed ranker and local ranker spearman corr should be small.
lcor = spearmanr(rnkvec_local, y).correlation
assert np.abs(dcor - lcor) < 0.003
assert np.abs(dcor - lcor) < 0.03


@pytest.mark.parametrize('output', ['array', 'dataframe'])
Expand Down

0 comments on commit d6e5209

Please sign in to comment.