Skip to content

Commit

Permalink
[dask] use random ports in network setup (#3823)
Browse files Browse the repository at this point in the history
* use socket.bind with port 0 and client.run to find random open ports

* include test for found ports

* find random open ports as default

* parametrize local_listen_port. type hint to _find_random_open_port. find open ports only on workers with data.

* make indentation consistent and pass list of workers to client.run

* remove socket import

* change random port implementation

* fix test
  • Loading branch information
jmoralez authored Feb 24, 2021
1 parent 7777852 commit 0e57657
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 109 deletions.
86 changes: 10 additions & 76 deletions python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,83 +45,18 @@ def _get_dask_client(client: Optional[Client]) -> Client:
return client


def _find_open_port(worker_ip: str, local_listen_port: int, ports_to_skip: Iterable[int]) -> int:
"""Find an open port.
This function tries to find a free port on the machine it's run on. It is intended to
be run once on each Dask worker, sequentially.
Parameters
----------
worker_ip : str
IP address for the Dask worker.
local_listen_port : int
First port to try when searching for open ports.
ports_to_skip: Iterable[int]
An iterable of integers referring to ports that should be skipped. Since multiple Dask
workers can run on the same physical machine, this method may be called multiple times
on the same machine. ``ports_to_skip`` is used to ensure that LightGBM doesn't try to use
the same port for two worker processes running on the same machine.
def _find_random_open_port() -> int:
"""Find a random open port on localhost.
Returns
-------
port : int
A free port on the machine referenced by ``worker_ip``.
"""
max_tries = 1000
found_port = False
for i in range(max_tries):
out_port = local_listen_port + i
if out_port in ports_to_skip:
continue
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind((worker_ip, out_port))
found_port = True
break
# if unavailable, you'll get OSError: Address already in use
except OSError:
continue
if not found_port:
msg = "LightGBM tried %s:%d-%d and could not create a connection. Try setting local_listen_port to a different value."
raise RuntimeError(msg % (worker_ip, local_listen_port, out_port))
return out_port


def _find_ports_for_workers(client: Client, worker_addresses: Iterable[str], local_listen_port: int) -> Dict[str, int]:
"""Find an open port on each worker.
LightGBM distributed training uses TCP sockets by default, and this method is used to
identify open ports on each worker so LightGBM can reliably create those sockets.
Parameters
----------
client : dask.distributed.Client
Dask client.
worker_addresses : Iterable[str]
An iterable of addresses for workers in the cluster. These are strings of the form ``<protocol>://<host>:port``.
local_listen_port : int
First port to try when searching for open ports.
Returns
-------
result : Dict[str, int]
Dictionary where keys are worker addresses and values are an open port for LightGBM to use.
A free port on localhost
"""
lightgbm_ports: Set[int] = set()
worker_ip_to_port = {}
for worker_address in worker_addresses:
port = client.submit(
func=_find_open_port,
workers=[worker_address],
worker_ip=urlparse(worker_address).hostname,
local_listen_port=local_listen_port,
ports_to_skip=lightgbm_ports
).result()
lightgbm_ports.add(port)
worker_ip_to_port[worker_address] = port

return worker_ip_to_port
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('', 0))
port = s.getsockname()[1]
return port


def _concat(seq: List[_DaskPart]) -> _DaskPart:
Expand Down Expand Up @@ -415,10 +350,9 @@ def _train(
}
else:
_log_info("Finding random open ports for workers")
worker_address_to_port = _find_ports_for_workers(
client=client,
worker_addresses=worker_addresses,
local_listen_port=local_listen_port
worker_address_to_port = client.run(
_find_random_open_port,
workers=list(worker_addresses)
)
machines = ','.join([
'%s:%d' % (urlparse(worker_address).hostname, port)
Expand Down
48 changes: 15 additions & 33 deletions tests/python_package_test/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,14 +174,6 @@ def _accuracy_score(dy_true, dy_pred):
return da.average(dy_true == dy_pred).compute()


def _find_random_open_port() -> int:
"""Find a random open port on localhost"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('', 0))
port = s.getsockname()[1]
return port


def _pickle(obj, filepath, serializer):
if serializer == 'pickle':
with open(filepath, 'wb') as f:
Expand Down Expand Up @@ -343,6 +335,19 @@ def test_classifier_pred_contrib(output, centers, client):
client.close(timeout=CLIENT_CLOSE_TIMEOUT)


def test_find_random_open_port(client):
    """Ports found on the workers must be pairwise distinct and bindable."""
    n_rounds = 5
    for _ in range(n_rounds):
        # run the port finder through the client and collect one port per result entry
        ports_by_worker = client.run(lgb.dask._find_random_open_port)
        ports = list(ports_by_worker.values())
        # check that found ports are different for same address (LocalCluster)
        assert len(ports) == len(set(ports))
        # check that the ports are indeed open: binding must not raise
        for candidate in ports:
            probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                probe.bind(('', candidate))
            finally:
                probe.close()
    client.close(timeout=CLIENT_CLOSE_TIMEOUT)


def test_training_does_not_fail_on_port_conflicts(client):
_, _, _, dX, dy, dw = _create_data('classification', output='array')

Expand Down Expand Up @@ -885,29 +890,6 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
assert_eq(preds_orig_local, preds_loaded_model_local)


def test_find_open_port_works(listen_port):
    """``_find_open_port`` must step past ports that are already bound."""
    worker_ip = '127.0.0.1'

    def _search_from_listen_port():
        # same arguments for both scenarios: start at listen_port, skip nothing explicitly
        return lgb.dask._find_open_port(
            worker_ip=worker_ip,
            local_listen_port=listen_port,
            ports_to_skip=set()
        )

    # scenario 1: only ``listen_port`` itself is occupied
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as holder:
        holder.bind((worker_ip, listen_port))
        found = _search_from_listen_port()
        assert listen_port < found < listen_port + 1000

    # scenario 2: ``listen_port`` and the port right after it are both occupied
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as holder_1:
        holder_1.bind((worker_ip, listen_port))
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as holder_2:
            holder_2.bind((worker_ip, listen_port + 1))
            found = _search_from_listen_port()
            assert listen_port + 1 < found < listen_port + 1000


def test_warns_and_continues_on_unrecognized_tree_learner(client):
X = da.random.random((1e3, 10))
y = da.random.random((1e3, 1))
Expand Down Expand Up @@ -1075,7 +1057,7 @@ def test_network_params_not_required_but_respected_if_given(client, task, output

# model 2 - machines given
n_workers = len(client.scheduler_info()['workers'])
open_ports = [_find_random_open_port() for _ in range(n_workers)]
open_ports = [lgb.dask._find_random_open_port() for _ in range(n_workers)]
dask_model2 = dask_model_factory(
n_estimators=5,
num_leaves=5,
Expand Down Expand Up @@ -1143,7 +1125,7 @@ def test_machines_should_be_used_if_provided(task, output):
client.rebalance()

n_workers = len(client.scheduler_info()['workers'])
open_ports = [_find_random_open_port() for _ in range(n_workers)]
open_ports = [lgb.dask._find_random_open_port() for _ in range(n_workers)]
dask_model = dask_model_factory(
n_estimators=5,
num_leaves=5,
Expand Down

0 comments on commit 0e57657

Please sign in to comment.