diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py
index 66ab83591cdb..fb8b06077e70 100644
--- a/python-package/lightgbm/dask.py
+++ b/python-package/lightgbm/dask.py
@@ -5,7 +5,9 @@
 It is based on dask-xgboost package.
 """
 import logging
+import socket
 from collections import defaultdict
+from typing import Dict, Iterable
 from urllib.parse import urlparse
 
 import numpy as np
@@ -13,7 +15,7 @@
 from dask import array as da
 from dask import dataframe as dd
 from dask import delayed
-from dask.distributed import default_client, get_worker, wait
+from dask.distributed import Client, default_client, get_worker, wait
 
 from .basic import _LIB, _safe_call
 from .sklearn import LGBMClassifier, LGBMRegressor
@@ -23,33 +25,84 @@ logger = logging.getLogger(__name__)
 
 
-def _parse_host_port(address):
-    parsed = urlparse(address)
-    return parsed.hostname, parsed.port
+def _find_open_port(worker_ip: str, local_listen_port: int, ports_to_skip: Iterable[int]) -> int:
+    """Find an open port.
 
+    This function tries to find a free port on the machine it's run on. It is intended to
+    be run once on each Dask worker, sequentially.
 
-def _build_network_params(worker_addresses, local_worker_ip, local_listen_port, time_out):
-    """Build network parameters suitable for LightGBM C backend.
+    Parameters
+    ----------
+    worker_ip : str
+        IP address for the Dask worker.
+    local_listen_port : int
+        First port to try when searching for open ports.
+    ports_to_skip : Iterable[int]
+        An iterable of integers referring to ports that should be skipped. Since multiple Dask
+        workers can run on the same physical machine, this method may be called multiple times
+        on the same machine. ``ports_to_skip`` is used to ensure that LightGBM doesn't try to use
+        the same port for two worker processes running on the same machine.
+
+    Returns
+    -------
+    result : int
+        A free port on the machine referenced by ``worker_ip``.
+    """
+    max_tries = 1000
+    out_port = None
+    found_port = False
+    for i in range(max_tries):
+        out_port = local_listen_port + i
+        if out_port in ports_to_skip:
+            continue
+        try:
+            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                s.bind((worker_ip, out_port))
+            found_port = True
+            break
+        # if unavailable, you'll get OSError: Address already in use
+        except OSError:
+            continue
+    if not found_port:
+        msg = "LightGBM tried %s:%d-%d and could not create a connection. Try setting local_listen_port to a different value."
+        raise RuntimeError(msg % (worker_ip, local_listen_port, out_port))
+    return out_port
+
+
+def _find_ports_for_workers(client: Client, worker_addresses: Iterable[str], local_listen_port: int) -> Dict[str, int]:
+    """Find an open port on each worker.
+
+    LightGBM distributed training uses TCP sockets by default, and this method is used to
+    identify open ports on each worker so LightGBM can reliably create those sockets.
 
     Parameters
     ----------
-    worker_addresses : iterable of str
-        collection of worker addresses in `<protocol>://<host>:port` format
-    local_worker_ip : str
+    client : dask.distributed.Client
+        Dask client.
+    worker_addresses : Iterable[str]
+        An iterable of addresses for workers in the cluster. These are strings of the form ``<protocol>://<host>:port``.
     local_listen_port : int
-    time_out : int
+        First port to try when searching for open ports.
 
     Returns
     -------
-    params: dict
+    result : Dict[str, int]
+        Dictionary where keys are worker addresses and values are an open port for LightGBM to use.
     """
-    addr_port_map = {addr: (local_listen_port + i) for i, addr in enumerate(worker_addresses)}
-    params = {
-        'machines': ','.join('%s:%d' % (_parse_host_port(addr)[0], port) for addr, port in addr_port_map.items()),
-        'local_listen_port': addr_port_map[local_worker_ip],
-        'time_out': time_out,
-        'num_machines': len(addr_port_map)
-    }
-    return params
+    lightgbm_ports = set()
+    worker_ip_to_port = {}
+    for worker_address in worker_addresses:
+        port = client.submit(
+            func=_find_open_port,
+            workers=[worker_address],
+            worker_ip=urlparse(worker_address).hostname,
+            local_listen_port=local_listen_port,
+            ports_to_skip=lightgbm_ports
+        ).result()
+        lightgbm_ports.add(port)
+        worker_ip_to_port[worker_address] = port
+
+    return worker_ip_to_port
 
 
 def _concat(seq):
@@ -63,9 +116,20 @@ def _concat(seq):
         raise TypeError('Data must be one of: numpy arrays, pandas dataframes, sparse matrices (from scipy). Got %s.' % str(type(seq[0])))
 
 
-def _train_part(params, model_factory, list_of_parts, worker_addresses, return_model, local_listen_port=12400,
+def _train_part(params, model_factory, list_of_parts, worker_address_to_port, return_model,
                 time_out=120, **kwargs):
-    network_params = _build_network_params(worker_addresses, get_worker().address, local_listen_port, time_out)
+    local_worker_address = get_worker().address
+    machine_list = ','.join([
+        '%s:%d' % (urlparse(worker_address).hostname, port)
+        for worker_address, port
+        in worker_address_to_port.items()
+    ])
+    network_params = {
+        'machines': machine_list,
+        'local_listen_port': worker_address_to_port[local_worker_address],
+        'time_out': time_out,
+        'num_machines': len(worker_address_to_port)
+    }
     params.update(network_params)
 
     # Concatenate many parts into one
@@ -138,13 +202,22 @@ def _train(client, data, label, params, model_factory, weight=None, **kwargs):
                        '(%s), using "data" as default', params.get("tree_learner", None))
         params['tree_learner'] = 'data'
 
+    # Find an open port on each worker. Note that multiple workers can run
+    # on the same machine, so this needs to ensure that each one gets its
+    # own port.
+    local_listen_port = params.get('local_listen_port', 12400)
+    worker_address_to_port = _find_ports_for_workers(
+        client=client,
+        worker_addresses=worker_map.keys(),
+        local_listen_port=local_listen_port
+    )
+
     # Tell each worker to train on the parts that it has locally
     futures_classifiers = [client.submit(_train_part,
                                          model_factory=model_factory,
                                          params={**params, 'num_threads': worker_ncores[worker]},
                                          list_of_parts=list_of_parts,
-                                         worker_addresses=list(worker_map.keys()),
-                                         local_listen_port=params.get('local_listen_port', 12400),
+                                         worker_address_to_port=worker_address_to_port,
                                          time_out=params.get('time_out', 120),
                                          return_model=(worker == master_worker),
                                          **kwargs)
diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py
index d512cfcbda63..e8498314e1fc 100644
--- a/tests/python_package_test/test_dask.py
+++ b/tests/python_package_test/test_dask.py
@@ -1,5 +1,6 @@
 # coding: utf-8
 import os
+import socket
 import sys
 
 import pytest
@@ -89,6 +90,26 @@ def test_classifier(output, centers, client, listen_port):
     assert_eq(y, p2)
 
 
+def test_training_does_not_fail_on_port_conflicts(client):
+    _, _, _, dX, dy, dw = _create_data('classification', output='array')
+
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind(('127.0.0.1', 12400))
+
+        dask_classifier = dlgbm.DaskLGBMClassifier(
+            time_out=5,
+            local_listen_port=12400
+        )
+        for i in range(5):
+            dask_classifier.fit(
+                X=dX,
+                y=dy,
+                sample_weight=dw,
+                client=client
+            )
+            assert dask_classifier.booster_
+
+
 @pytest.mark.parametrize('output', data_output)
 @pytest.mark.parametrize('centers', data_centers)
 def test_classifier_proba(output, centers, client, listen_port):
@@ -183,21 +204,27 @@ def test_regressor_local_predict(client, listen_port):
     assert_eq(s1, s2)
 
 
-def test_build_network_params():
-    workers_ips = [
-        'tcp://192.168.0.1:34545',
-        'tcp://192.168.0.2:34346',
-        'tcp://192.168.0.3:34347'
-    ]
-
-    params = dlgbm._build_network_params(workers_ips, 'tcp://192.168.0.2:34346', 12400, 120)
-    exp_params = {
-        'machines': '192.168.0.1:12400,192.168.0.2:12401,192.168.0.3:12402',
-        'local_listen_port': 12401,
-        'num_machines': len(workers_ips),
-        'time_out': 120
-    }
-    assert exp_params == params
+def test_find_open_port_works():
+    worker_ip = '127.0.0.1'
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind((worker_ip, 12400))
+        new_port = dlgbm._find_open_port(
+            worker_ip=worker_ip,
+            local_listen_port=12400,
+            ports_to_skip=set()
+        )
+        assert new_port == 12401
+
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s_1:
+        s_1.bind((worker_ip, 12400))
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s_2:
+            s_2.bind((worker_ip, 12401))
+            new_port = dlgbm._find_open_port(
+                worker_ip=worker_ip,
+                local_listen_port=12400,
+                ports_to_skip=set()
+            )
+            assert new_port == 12402
 
 
 @gen_cluster(client=True, timeout=None)