Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dask] [python-package] Search for available ports when setting up network (fixes #3753) #3766

Merged
merged 16 commits into from
Jan 15, 2021
Merged
122 changes: 100 additions & 22 deletions python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@
It is based on dask-xgboost package.
"""
import logging
import socket
from collections import defaultdict
from typing import Dict, Iterable
from urllib.parse import urlparse

import numpy as np
import pandas as pd
from dask import array as da
from dask import dataframe as dd
from dask import delayed
from dask.distributed import default_client, get_worker, wait
from dask.distributed import Client, default_client, get_worker, wait

from .basic import _LIB, _safe_call
from .sklearn import LGBMClassifier, LGBMRegressor
Expand All @@ -23,33 +25,89 @@
logger = logging.getLogger(__name__)


def _parse_host_port(address):
parsed = urlparse(address)
return parsed.hostname, parsed.port
def _find_open_port(worker_ip: str, local_listen_port: int, ports_to_skip: Iterable[int]) -> int:
"""Find an open port

This function tries to find a free port on the machine it's run on. It is intended to
be run once on each Dask worker, sequentially.

def _build_network_params(worker_addresses, local_worker_ip, local_listen_port, time_out):
"""Build network parameters suitable for LightGBM C backend.
Parameters
----------
worker_ip : str
IP address for the Dask worker.
local_listen_port : int
First port to try when searching for open ports.
ports_to_skip: Iterable[int]
An iterable of integers referring to ports that should be skipped. Since multiple Dask
workers can run on the same physical machine, this method may be called multiple times
on the same machine. ``ports_to_skip`` is used to ensure that LightGBM doesn't try to use
the same port for two worker processes running on the same machine.

Returns
-------
result : int
A free port on the machine referenced by ``worker_ip``.
"""
max_tries = 1000
out_port = None
found_port = False
for i in range(max_tries):
out_port = local_listen_port + i
try:
out_port = local_listen_port + i
StrikerRUS marked this conversation as resolved.
Show resolved Hide resolved
if out_port in ports_to_skip:
StrikerRUS marked this conversation as resolved.
Show resolved Hide resolved
continue
StrikerRUS marked this conversation as resolved.
Show resolved Hide resolved
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind((worker_ip, out_port))
s.listen()
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
found_port = True
break
# if unavailable, you'll get OSError: Address already in use
except OSError:
continue
if not found_port:
msg = "LightGBM tried %s:%d-%d and could not create a connection. Try setting local_listen_port to a different value."
raise RuntimeError(msg % (worker_ip, local_listen_port, out_port))
return out_port


def _find_ports_for_workers(client: Client, worker_addresses: Iterable[str], local_listen_port: int) -> Dict[str, int]:
"""Find an open port on each worker

LightGBM distributed training uses TCP sockets by default, and this method is used to
identify open ports on each worker so LightGBM can reliable create those sockets.

Returns a dictionary where keys are
worker addresses and values are an open port for LightGBM to use.
jameslamb marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
worker_addresses : iterable of str - collection of worker addresses in `<protocol>://<host>:port` format
local_worker_ip : str
client : dask.distributed.Client
Dask client.
worker_addresses : Iterable[str]
An iterable of addresses for workers in the cluster. These are strings of the form ```<protocol>://<host>:port```
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
local_listen_port : int
time_out : int
First port to try when searching for open ports.

Returns
-------
params: dict
result : Dict[str, int]
Dictionary where keys are worker addresses and values are an open port for LightGBM to use.
"""
addr_port_map = {addr: (local_listen_port + i) for i, addr in enumerate(worker_addresses)}
params = {
'machines': ','.join('%s:%d' % (_parse_host_port(addr)[0], port) for addr, port in addr_port_map.items()),
'local_listen_port': addr_port_map[local_worker_ip],
'time_out': time_out,
'num_machines': len(addr_port_map)
}
return params
lightgbm_ports = set()
worker_ip_to_port = {}
for worker_address in worker_addresses:
port = client.submit(
func=_find_open_port,
workers=[worker_address],
worker_ip=urlparse(worker_address).hostname,
local_listen_port=local_listen_port,
ports_to_skip=lightgbm_ports
).result()
lightgbm_ports.add(port)
worker_ip_to_port[worker_address] = port
Comment on lines +92 to +103
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to draw reviewers' attention to this block. This setup checks for ports to use sequentially. So it will first run _find_open_port() on worker 1's machine, then worker 2's machine, then work 3's, etc.

That means there is a small startup time introduced. I saw that this process took about 5 seconds per worker when I was running a cluster on AWS Fargate. So that means you might expect around 1 minute of startup time for a cluster with 20 workers.

I don't think this is too bad, and I think it's a fairly clean way to ensure that if you have multiple work processes on the same machine, they're assigned different ports in the network parameters

Copy link
Contributor

@ffineis ffineis Jan 15, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably overkill, but if you wanted to speed it up, maybe could group by IP address and apply the search in parallel across separate nodes (IP addresses)? So each the search for ports gathers as many ports as there are workers on that server.

def _find_ports_for_workers(client: Client, worker_addresses: Iterable[str], local_listen_port: int) -> Dict[str, int]:
    worker_ip_to_port = {}

    # group workers by ip addr
    worker_hostname_map = defaultdict(list())
    for worker_address in worker_addresses:
    	hostname = urlparse(worker_address).hostname
    	worker_hostname_map[hostname].append(worker_address)

    # run search for ports on groups of workers
   	lgbm_machines = list()
   	for hostname in worker_hostname_map:
   		machines_on_node = client.submit(
            func=_port_search_for_ip,
            worker_ip=hostname,
            desired_ports=len(worker_hostname_map[hostname]),
            local_listen_port=local_listen_port
        )

        # add more futures
        lgbm_machines.append(machines_on_node)

    # wait for search across nodes to complete.
    _ = wait(lgbm_machines)
    lgbm_machines = lgbm_machines.results()


def _port_search_for_ip(worker_ip: str, local_listen_port: int, n_desired_ports: int) -> Iterable[str]:
    max_tries = 1000
    out_ports = list()
    found_all_ports = False
    for i in range(max_tries):
        out_port = local_listen_port + i
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.bind((worker_ip, out_port))
                s.listen()

            out_ports.append(out_port)
            if len(out_ports) == n_desired_ports:
            	found_all_ports = True
            	break

        except OSError:
            continue

    if not found_all_ports:
        raise RuntimeError()

    return [hostname + ':' + str(x) for x in out_ports]

Otherwise, the contraints of lightgbm + dask are that: 1) lightgbm requires that each worker knows about all of the other workers while at the same time 2) prior to distributing work to each worker, we know which workers can listen for lightgbm traffic on which ports. Seems like your fix nails this!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably overkill, but if you wanted to speed it up, maybe could group by IP address and apply the search in parallel across separate nodes (IP addresses)? So each the search for ports gathers as many ports as there are workers on that server

That's a great idea, I like it! I think I'll push it to a later PR though, since it would take me some time to test and I really want to unblock you for #3708

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GH needs more emojis but if I could, I'd give you the prayer "bless up" hands

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

haha thank you

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just documented this in #3768, thanks for taking the time to type out that possible implementation. Super helpful.


return worker_ip_to_port


def _concat(seq):
Expand All @@ -63,9 +121,20 @@ def _concat(seq):
raise TypeError('Data must be one of: numpy arrays, pandas dataframes, sparse matrices (from scipy). Got %s.' % str(type(seq[0])))


def _train_part(params, model_factory, list_of_parts, worker_addresses, return_model, local_listen_port=12400,
def _train_part(params, model_factory, list_of_parts, worker_address_to_port, return_model,
time_out=120, **kwargs):
network_params = _build_network_params(worker_addresses, get_worker().address, local_listen_port, time_out)
local_worker_address = get_worker().address
machine_list = ','.join([
'%s:%d' % (urlparse(worker_address).hostname, port)
for worker_address, port
in worker_address_to_port.items()
])
network_params = {
'machines': machine_list,
'local_listen_port': worker_address_to_port[local_worker_address],
'time_out': time_out,
'num_machines': len(worker_address_to_port)
}
params.update(network_params)

# Concatenate many parts into one
Expand Down Expand Up @@ -138,13 +207,22 @@ def _train(client, data, label, params, model_factory, weight=None, **kwargs):
'(%s), using "data" as default', params.get("tree_learner", None))
params['tree_learner'] = 'data'

# find an open port on each worker. note that multiple workers can run
# on the same machine, so this needs to ensure that each one gets its
# own port
local_listen_port = params.get('local_listen_port', 12400)
worker_address_to_port = _find_ports_for_workers(
client=client,
worker_addresses=list(worker_map.keys()),
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
local_listen_port=local_listen_port
)

# Tell each worker to train on the parts that it has locally
futures_classifiers = [client.submit(_train_part,
model_factory=model_factory,
params={**params, 'num_threads': worker_ncores[worker]},
list_of_parts=list_of_parts,
worker_addresses=list(worker_map.keys()),
local_listen_port=params.get('local_listen_port', 12400),
worker_address_to_port=worker_address_to_port,
time_out=params.get('time_out', 120),
return_model=(worker == master_worker),
**kwargs)
Expand Down
61 changes: 46 additions & 15 deletions tests/python_package_test/test_dask.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# coding: utf-8
import os
import socket
import sys

import pytest
Expand Down Expand Up @@ -89,6 +90,27 @@ def test_classifier(output, centers, client, listen_port):
assert_eq(y, p2)


def test_training_does_not_fail_on_port_conflicts(client):
X, y, w, dX, dy, dw = _create_data('classification', output='array')
jameslamb marked this conversation as resolved.
Show resolved Hide resolved

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('127.0.0.1', 12400))
s.listen()

dask_classifier = dlgbm.DaskLGBMClassifier(
time_out=5,
local_listen_port=12400
)
for i in range(5):
dask_classifier.fit(
X=dX,
y=dy,
sample_weight=dw,
client=client
)
assert dask_classifier.booster_


@pytest.mark.parametrize('output', data_output)
@pytest.mark.parametrize('centers', data_centers)
def test_classifier_proba(output, centers, client, listen_port):
Expand Down Expand Up @@ -183,21 +205,30 @@ def test_regressor_local_predict(client, listen_port):
assert_eq(s1, s2)


def test_build_network_params():
workers_ips = [
'tcp://192.168.0.1:34545',
'tcp://192.168.0.2:34346',
'tcp://192.168.0.3:34347'
]

params = dlgbm._build_network_params(workers_ips, 'tcp://192.168.0.2:34346', 12400, 120)
exp_params = {
'machines': '192.168.0.1:12400,192.168.0.2:12401,192.168.0.3:12402',
'local_listen_port': 12401,
'num_machines': len(workers_ips),
'time_out': 120
}
assert exp_params == params
def test_find_open_port_works():
worker_ip = '127.0.0.1'
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind((worker_ip, 12400))
s.listen()
new_port = dlgbm._find_open_port(
worker_ip=worker_ip,
local_listen_port=12400,
ports_to_skip=set()
)
assert new_port == 12401

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind((worker_ip, 12400))
s.listen()
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind((worker_ip, 12401))
s.listen()
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
new_port = dlgbm._find_open_port(
worker_ip=worker_ip,
local_listen_port=12400,
ports_to_skip=set()
)
assert new_port == 12402


@gen_cluster(client=True, timeout=None)
Expand Down