[dask] [python-package] use keyword args for internal function calls #3755

Merged (8 commits) on Jan 22, 2021.
Changes from 2 commits
129 changes: 101 additions & 28 deletions python-package/lightgbm/dask.py
@@ -1,8 +1,10 @@
# coding: utf-8
"""Distributed training with LightGBM and Dask.distributed.

This module enables you to perform distributed training with LightGBM on Dask.Array and Dask.DataFrame collections.
It is based on dask-xgboost package.
This module enables you to perform distributed training with LightGBM on
Dask.Array and Dask.DataFrame collections.

It is based on dask-lightgbm, which was based on dask-xgboost.
"""
import logging
from collections import defaultdict
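For orientation, here is a minimal usage sketch of the estimators this module provides. It assumes a running dask.distributed cluster and the import path `lightgbm.dask` implied by this PR's package layout; the data and parameter values are illustrative only, not from the diff.

```python
# Minimal usage sketch (assumptions: a local dask.distributed cluster,
# and the import path lightgbm.dask from this PR's package layout).
import dask.array as da
from dask.distributed import Client

from lightgbm.dask import DaskLGBMClassifier

client = Client()  # spins up a local cluster

# two dask collections with aligned chunking
X = da.random.random((1_000, 10), chunks=(100, 10))
y = da.random.randint(0, 2, size=(1_000,), chunks=(100,))

clf = DaskLGBMClassifier(time_out=5, local_listen_port=12400)
clf = clf.fit(X, y, client=client)

preds = clf.predict(X)       # lazy dask collection
print(preds.compute()[:10])  # materialize predictions
```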
@@ -42,9 +44,18 @@ def _build_network_params(worker_addresses, local_worker_ip, local_listen_port,
-------
params: dict
"""
addr_port_map = {addr: (local_listen_port + i) for i, addr in enumerate(worker_addresses)}
addr_port_map = {
addr: (local_listen_port + i)
for i, addr
in enumerate(worker_addresses)
}
machine_list = [
'%s:%d' % (_parse_host_port(addr)[0], port)
for addr, port
in addr_port_map.items()
]
params = {
'machines': ','.join('%s:%d' % (_parse_host_port(addr)[0], port) for addr, port in addr_port_map.items()),
'machines': ','.join(machine_list),
'local_listen_port': addr_port_map[local_worker_ip],
'time_out': time_out,
'num_machines': len(addr_port_map)
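As a standalone illustration of the mapping this function builds, here is a sketch with a simplified stand-in for `_parse_host_port` (whose exact behavior is assumed, not shown in this diff):

```python
# Standalone sketch of the machines-string construction above.
# parse_host is a simplified stand-in for the module's _parse_host_port.
def parse_host(addr):
    # 'tcp://192.168.0.1:34345' -> '192.168.0.1' (assumed behavior)
    return addr.replace('tcp://', '').rsplit(':', 1)[0]

worker_addresses = ['tcp://192.168.0.1:34345', 'tcp://192.168.0.2:34346']
local_listen_port = 12400

addr_port_map = {
    addr: (local_listen_port + i)
    for i, addr in enumerate(worker_addresses)
}
machine_list = [
    '%s:%d' % (parse_host(addr), port)
    for addr, port in addr_port_map.items()
]
print(','.join(machine_list))
# 192.168.0.1:12400,192.168.0.2:12401
```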
@@ -65,7 +76,12 @@ def _concat(seq):

def _train_part(params, model_factory, list_of_parts, worker_addresses, return_model, local_listen_port=12400,
time_out=120, **kwargs):
network_params = _build_network_params(worker_addresses, get_worker().address, local_listen_port, time_out)
network_params = _build_network_params(
worker_addresses=worker_addresses,
local_worker_ip=get_worker().address,
local_listen_port=local_listen_port,
time_out=time_out
)
params.update(network_params)

# Concatenate many parts into one
@@ -76,7 +92,12 @@ def _train_part(params, model_factory, list_of_parts, worker_addresses, return_model,

try:
model = model_factory(**params)
model.fit(data, label, sample_weight=weight, **kwargs)
model.fit(
X=data,
y=label,
sample_weight=weight,
**kwargs
)
finally:
_safe_call(_LIB.LGBM_NetworkFree())
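The point of this change, in a toy sketch (hypothetical functions, not from this module): positional call sites break silently when a callee's parameters are later reordered, while keyword call sites either keep working or fail loudly with a TypeError.

```python
# Toy illustration (hypothetical functions) of why the internal call
# sites in this PR switch to keyword arguments.
def train_v1(data, label, weight=None):
    return {'data': data, 'label': label, 'weight': weight}

def train_v2(label, data, weight=None):  # parameters reordered in a refactor
    return {'data': data, 'label': label, 'weight': weight}

print(train_v1([1, 2], [0, 1]))             # correct
print(train_v2([1, 2], [0, 1]))             # silently swapped!
print(train_v2(data=[1, 2], label=[0, 1]))  # still correct
```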

@@ -86,7 +107,10 @@ def _train_part(params, model_factory, list_of_parts, worker_addresses, return_model,
def _split_to_parts(data, is_matrix):
parts = data.to_delayed()
if isinstance(parts, np.ndarray):
assert (parts.shape[1] == 1) if is_matrix else (parts.ndim == 1 or parts.shape[1] == 1)
if is_matrix:
assert parts.shape[1] == 1
else:
assert parts.ndim == 1 or parts.shape[1] == 1
parts = parts.flatten().tolist()
return parts
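For context on the assertion above: `to_delayed()` on a dask collection returns one Delayed object per chunk, and for a 2-D array that grid has shape (row_chunks, column_chunks), hence the requirement of a single column of chunks before flattening. A small sketch:

```python
# Sketch: .to_delayed() returns a grid of Delayed objects, one per chunk.
import dask.array as da

matrix = da.random.random((100, 5), chunks=(25, 5))  # 4 row chunks, 1 col chunk
parts = matrix.to_delayed()
print(type(parts), parts.shape)  # <class 'numpy.ndarray'> (4, 1)

vector = da.random.random((100,), chunks=(25,))
print(vector.to_delayed().shape)  # (4,)

flat = parts.flatten().tolist()   # what _split_to_parts returns
print(len(flat))                  # 4
```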

@@ -107,12 +131,12 @@ def _train(client, data, label, params, model_factory, weight=None, **kwargs):
Weights of training data.
"""
# Split arrays/dataframes into parts. Arrange parts into tuples to enforce co-locality
data_parts = _split_to_parts(data, is_matrix=True)
label_parts = _split_to_parts(label, is_matrix=False)
data_parts = _split_to_parts(data=data, is_matrix=True)
label_parts = _split_to_parts(data=label, is_matrix=False)
if weight is None:
parts = list(map(delayed, zip(data_parts, label_parts)))
else:
weight_parts = _split_to_parts(weight, is_matrix=False)
weight_parts = _split_to_parts(data=weight, is_matrix=False)
parts = list(map(delayed, zip(data_parts, label_parts, weight_parts)))

# Start computation in the background
@@ -134,21 +158,28 @@ def _train(client, data, label, params, model_factory, weight=None, **kwargs):
worker_ncores = client.ncores()

if 'tree_learner' not in params or params['tree_learner'].lower() not in {'data', 'feature', 'voting'}:
logger.warning('Parameter tree_learner not set or set to incorrect value '
'(%s), using "data" as default', params.get("tree_learner", None))
logger.warning(
'Parameter tree_learner not set or set to incorrect value '
'(%s), using "data" as default',
params.get("tree_learner", None)
)
params['tree_learner'] = 'data'

# Tell each worker to train on the parts that it has locally
futures_classifiers = [client.submit(_train_part,
model_factory=model_factory,
params={**params, 'num_threads': worker_ncores[worker]},
list_of_parts=list_of_parts,
worker_addresses=list(worker_map.keys()),
local_listen_port=params.get('local_listen_port', 12400),
time_out=params.get('time_out', 120),
return_model=(worker == master_worker),
**kwargs)
for worker, list_of_parts in worker_map.items()]
futures_classifiers = [
client.submit(
_train_part,
model_factory=model_factory,
params={**params, 'num_threads': worker_ncores[worker]},
list_of_parts=list_of_parts,
worker_addresses=list(worker_map.keys()),
local_listen_port=params.get('local_listen_port', 12400),
time_out=params.get('time_out', 120),
return_model=(worker == master_worker),
**kwargs
)
for worker, list_of_parts in worker_map.items()
]

results = client.gather(futures_classifiers)
results = [v for v in results if v]
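The surrounding code follows the standard dask.distributed submit/gather pattern; a minimal sketch with a trivial task in place of `_train_part` (hypothetical, not LightGBM-specific):

```python
# Minimal sketch of the client.submit(...) / client.gather(...) pattern
# used above, with a trivial task instead of _train_part.
from dask.distributed import Client

client = Client()

def square(x, offset=0):
    return x * x + offset

futures = [client.submit(square, i, offset=1) for i in range(4)]
print(client.gather(futures))  # [1, 2, 5, 10]
```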
@@ -208,7 +239,16 @@ def _fit(self, model_factory, X, y=None, sample_weight=None, client=None, **kwargs):
client = default_client()

params = self.get_params(True)
model = _train(client, X, y, params, model_factory, sample_weight, **kwargs)

model = _train(
client=client,
data=X,
label=y,
params=params,
model_factory=model_factory,
weight=sample_weight,
**kwargs
)

self.set_params(**model.get_params())
self._copy_extra_params(model, self)
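For reference, the `get_params()`/`set_params()` round-trip used here is the standard scikit-learn estimator API; a toy sketch with a hypothetical estimator, not from the diff:

```python
# Toy sketch of the scikit-learn get_params()/set_params() round-trip
# that _fit uses to copy fitted-model parameters back onto self.
from sklearn.base import BaseEstimator

class ToyEstimator(BaseEstimator):
    def __init__(self, alpha=0.1, n_iter=10):
        self.alpha = alpha
        self.n_iter = n_iter

source = ToyEstimator(alpha=0.5)
target = ToyEstimator()
target.set_params(**source.get_params())
print(target.alpha, target.n_iter)  # 0.5 10
```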
@@ -234,17 +274,37 @@ class DaskLGBMClassifier(_LGBMModel, LGBMClassifier):

def fit(self, X, y=None, sample_weight=None, client=None, **kwargs):
"""Docstring is inherited from the LGBMModel."""
return self._fit(LGBMClassifier, X, y, sample_weight, client, **kwargs)
return self._fit(
model_factory=LGBMClassifier,
X=X,
y=y,
sample_weight=sample_weight,
client=client,
**kwargs
)

fit.__doc__ = LGBMClassifier.fit.__doc__

def predict(self, X, **kwargs):
"""Docstring is inherited from the lightgbm.LGBMClassifier.predict."""
return _predict(self.to_local(), X, dtype=self.classes_.dtype, **kwargs)
return _predict(
model=self.to_local(),
data=X,
dtype=self.classes_.dtype,
**kwargs
)

predict.__doc__ = LGBMClassifier.predict.__doc__

def predict_proba(self, X, **kwargs):
"""Docstring is inherited from the lightgbm.LGBMClassifier.predict_proba."""
return _predict(self.to_local(), X, proba=True, **kwargs)
return _predict(
model=self.to_local(),
data=X,
proba=True,
**kwargs
)

predict_proba.__doc__ = LGBMClassifier.predict_proba.__doc__

def to_local(self):
@@ -262,12 +322,25 @@ class DaskLGBMRegressor(_LGBMModel, LGBMRegressor):

def fit(self, X, y=None, sample_weight=None, client=None, **kwargs):
"""Docstring is inherited from the lightgbm.LGBMRegressor.fit."""
return self._fit(LGBMRegressor, X, y, sample_weight, client, **kwargs)
return self._fit(
model_factory=LGBMRegressor,
X=X,
y=y,
sample_weight=sample_weight,
client=client,
**kwargs
)

fit.__doc__ = LGBMRegressor.fit.__doc__

def predict(self, X, **kwargs):
"""Docstring is inherited from the lightgbm.LGBMRegressor.predict."""
return _predict(self.to_local(), X, **kwargs)
return _predict(
model=self.to_local(),
data=X,
**kwargs
)

predict.__doc__ = LGBMRegressor.predict.__doc__

def to_local(self):
99 changes: 77 additions & 22 deletions tests/python_package_test/test_dask.py
@@ -69,9 +69,16 @@ def _create_data(objective, n_samples=100, centers=2, output='array', chunk_size
@pytest.mark.parametrize('output', data_output)
@pytest.mark.parametrize('centers', data_centers)
def test_classifier(output, centers, client, listen_port):
X, y, w, dX, dy, dw = _create_data('classification', output=output, centers=centers)

dask_classifier = dlgbm.DaskLGBMClassifier(time_out=5, local_listen_port=listen_port)
X, y, w, dX, dy, dw = _create_data(
objective='classification',
output=output,
centers=centers
)

dask_classifier = dlgbm.DaskLGBMClassifier(
time_out=5,
local_listen_port=listen_port
)
dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client)
p1 = dask_classifier.predict(dX)
s1 = accuracy_score(dy, p1)
@@ -92,9 +99,16 @@ def test_classifier(output, centers, client, listen_port):
@pytest.mark.parametrize('output', data_output)
@pytest.mark.parametrize('centers', data_centers)
def test_classifier_proba(output, centers, client, listen_port):
X, y, w, dX, dy, dw = _create_data('classification', output=output, centers=centers)

dask_classifier = dlgbm.DaskLGBMClassifier(time_out=5, local_listen_port=listen_port)
X, y, w, dX, dy, dw = _create_data(
objective='classification',
output=output,
centers=centers
)

dask_classifier = dlgbm.DaskLGBMClassifier(
time_out=5,
local_listen_port=listen_port
)
dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client)
p1 = dask_classifier.predict_proba(dX)
p1 = p1.compute()
@@ -107,9 +121,15 @@ def test_classifier_proba(output, centers, client, listen_port):


def test_classifier_local_predict(client, listen_port):
X, y, w, dX, dy, dw = _create_data('classification', output='array')

dask_classifier = dlgbm.DaskLGBMClassifier(time_out=5, local_listen_port=listen_port)
X, y, w, dX, dy, dw = _create_data(
objective='classification',
output='array'
)

dask_classifier = dlgbm.DaskLGBMClassifier(
time_out=5,
local_listen_port=listen_port
)
dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw, client=client)
p1 = dask_classifier.to_local().predict(dX)

@@ -124,9 +144,16 @@

@pytest.mark.parametrize('output', data_output)
def test_regressor(output, client, listen_port):
X, y, w, dX, dy, dw = _create_data('regression', output=output)

dask_regressor = dlgbm.DaskLGBMRegressor(time_out=5, local_listen_port=listen_port, seed=42)
X, y, w, dX, dy, dw = _create_data(
objective='regression',
output=output
)

dask_regressor = dlgbm.DaskLGBMRegressor(
time_out=5,
local_listen_port=listen_port,
seed=42
)
dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw)
p1 = dask_regressor.predict(dX)
if output != 'dataframe':
@@ -150,14 +177,26 @@ def test_regressor(output, client, listen_port):
@pytest.mark.parametrize('output', data_output)
@pytest.mark.parametrize('alpha', [.1, .5, .9])
def test_regressor_quantile(output, client, listen_port, alpha):
X, y, w, dX, dy, dw = _create_data('regression', output=output)

dask_regressor = dlgbm.DaskLGBMRegressor(local_listen_port=listen_port, seed=42, objective='quantile', alpha=alpha)
X, y, w, dX, dy, dw = _create_data(
objective='regression',
output=output
)

dask_regressor = dlgbm.DaskLGBMRegressor(
local_listen_port=listen_port,
seed=42,
objective='quantile',
alpha=alpha
)
dask_regressor = dask_regressor.fit(dX, dy, client=client, sample_weight=dw)
p1 = dask_regressor.predict(dX).compute()
q1 = np.count_nonzero(y < p1) / y.shape[0]

local_regressor = lightgbm.LGBMRegressor(seed=42, objective='quantile', alpha=alpha)
local_regressor = lightgbm.LGBMRegressor(
seed=42,
objective='quantile',
alpha=alpha
)
local_regressor.fit(X, y, sample_weight=w)
p2 = local_regressor.predict(X)
q2 = np.count_nonzero(y < p2) / y.shape[0]
@@ -168,9 +207,15 @@ def test_regressor_quantile(output, client, listen_port, alpha):


def test_regressor_local_predict(client, listen_port):
X, y, w, dX, dy, dw = _create_data('regression', output='array')

dask_regressor = dlgbm.DaskLGBMRegressor(local_listen_port=listen_port, seed=42)
X, y, w, dX, dy, dw = _create_data(
objective='regression',
output='array'
)

dask_regressor = dlgbm.DaskLGBMRegressor(
local_listen_port=listen_port,
seed=42
)
dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw, client=client)
p1 = dask_regressor.predict(dX)
p2 = dask_regressor.to_local().predict(X)
@@ -189,8 +234,12 @@ def test_build_network_params():
'tcp://192.168.0.2:34346',
'tcp://192.168.0.3:34347'
]

params = dlgbm._build_network_params(workers_ips, 'tcp://192.168.0.2:34346', 12400, 120)
params = dlgbm._build_network_params(
worker_addresses=workers_ips,
local_worker_ip='tcp://192.168.0.2:34346',
local_listen_port=12400,
time_out=120
)
exp_params = {
'machines': '192.168.0.1:12400,192.168.0.2:12401,192.168.0.3:12402',
'local_listen_port': 12401,
@@ -208,5 +257,11 @@ def f(part):
df = dd.demo.make_timeseries()
df = df.map_partitions(f, meta=df._meta)
with pytest.raises(Exception) as info:
yield dlgbm._train(c, df, df.x, params={}, model_factory=lightgbm.LGBMClassifier)
yield dlgbm._train(
client=c,
data=df,
label=df.x,
params={},
model_factory=lightgbm.LGBMClassifier
)
assert 'foo' in str(info.value)