From 98e1535f0dde74c1cf6d2443b9b367cc05562348 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BB=A7=E7=9B=9B?= <xuye.qin@alibaba-inc.com> Date: Thu, 24 Mar 2022 13:52:39 +0800 Subject: [PATCH] Update tracker --- mars/learn/contrib/xgboost/start_tracker.py | 4 +- mars/learn/contrib/xgboost/tracker.py | 382 ++++++++++++++------ 2 files changed, 265 insertions(+), 121 deletions(-) diff --git a/mars/learn/contrib/xgboost/start_tracker.py b/mars/learn/contrib/xgboost/start_tracker.py index 98ddf7d431..feadbcbdb3 100644 --- a/mars/learn/contrib/xgboost/start_tracker.py +++ b/mars/learn/contrib/xgboost/start_tracker.py @@ -47,9 +47,9 @@ def execute(cls, ctx, op): env = {"DMLC_NUM_WORKER": op.n_workers} rabit_context = RabitTracker( - hostIP=ctx.get_local_host_ip(), nslave=op.n_workers + host_ip=ctx.get_local_host_ip(), n_workers=op.n_workers ) - env.update(rabit_context.slave_envs()) + env.update(rabit_context.worker_envs()) rabit_context.start(op.n_workers) thread = Thread(target=rabit_context.join) diff --git a/mars/learn/contrib/xgboost/tracker.py b/mars/learn/contrib/xgboost/tracker.py index cb1b77e929..1cd0afeef2 100644 --- a/mars/learn/contrib/xgboost/tracker.py +++ b/mars/learn/contrib/xgboost/tracker.py @@ -21,22 +21,29 @@ # pylint: disable=too-many-branches, too-many-statements, too-many-instance-attributes import socket import struct -import time import logging from threading import Thread +import argparse +import sys + +from typing import Dict, List, Tuple, Union, Optional, Set + +_RingMap = Dict[int, Tuple[int, int]] +_TreeMap = Dict[int, List[int]] logger = logging.getLogger(__name__) -class ExSocket(object): +class ExSocket: """ Extension of socket to handle recv and send of special data """ - def __init__(self, sock): + def __init__(self, sock: socket.socket) -> None: self.sock = sock - def recvall(self, nbytes): + def recvall(self, nbytes: int) -> bytes: + """Receive number of bytes.""" res = [] nread = 0 while nread < nbytes: @@ -45,56 +52,82 @@ def recvall(self, nbytes): res.append(chunk) return b"".join(res) - def recvint(self): + def recvint(self) -> int: + """Receive an integer of 32 bytes""" return struct.unpack("@i", self.recvall(4))[0] - def sendint(self, n): - self.sock.sendall(struct.pack("@i", n)) + def sendint(self, value: int) -> None: + """Send an integer of 32 bytes""" + self.sock.sendall(struct.pack("@i", value)) - def sendstr(self, s): - self.sendint(len(s)) - self.sock.sendall(s.encode()) + def sendstr(self, value: str) -> None: + """Send a Python string""" + self.sendint(len(value)) + self.sock.sendall(value.encode()) - def recvstr(self): + def recvstr(self) -> str: + """Receive a Python string""" slen = self.recvint() return self.recvall(slen).decode() # magic number used to verify existence of data -kMagic = 0xFF99 +MAGIC_NUM = 0xFF99 -def get_some_ip(host): +def get_some_ip(host: str) -> str: + """Get ip from host""" return socket.getaddrinfo(host, None)[0][4][0] -def get_family(addr): +def get_family(addr: str) -> int: + """Get network family from address.""" return socket.getaddrinfo(addr, None)[0][0] -class SlaveEntry(object): - def __init__(self, sock, s_addr): - slave = ExSocket(sock) - self.sock = slave +class WorkerEntry: + """Handler to each worker.""" + + def __init__(self, sock: socket.socket, s_addr: Tuple[str, int]): + worker = ExSocket(sock) + self.sock = worker self.host = get_some_ip(s_addr[0]) - magic = slave.recvint() - assert magic == kMagic, f"invalid magic number={magic} from {self.host}" - slave.sendint(kMagic) - self.rank = slave.recvint() - self.world_size = slave.recvint() - self.jobid = slave.recvstr() - self.cmd = slave.recvstr() + magic = worker.recvint() + assert magic == MAGIC_NUM, f"invalid magic number={magic} from {self.host}" + worker.sendint(MAGIC_NUM) + self.rank = worker.recvint() + self.world_size = worker.recvint() + self.jobid = worker.recvstr() + self.cmd = worker.recvstr() self.wait_accept = 0 - self.port = None + self.port: Optional[int] = None + + def print(self, use_logger: bool) -> None: + """Execute the print command from worker.""" + msg = self.sock.recvstr() + # On dask we use print to avoid setting global verbosity. + if use_logger: + logger.info(msg.strip()) + else: + print(msg.strip(), flush=True) - def decide_rank(self, job_map): + def decide_rank(self, job_map: Dict[str, int]) -> int: + """Get the rank of current entry.""" if self.rank >= 0: return self.rank if self.jobid != "NULL" and self.jobid in job_map: return job_map[self.jobid] return -1 - def assign_rank(self, rank, wait_conn, tree_map, parent_map, ring_map): + def assign_rank( + self, + rank: int, + wait_conn: Dict[int, "WorkerEntry"], + tree_map: _TreeMap, + parent_map: Dict[int, int], + ring_map: _RingMap, + ) -> List[int]: + """Assign the rank for current entry.""" self.rank = rank nnset = set(tree_map[rank]) rprev, rnext = ring_map[rank] @@ -119,6 +152,12 @@ def assign_rank(self, rank, wait_conn, tree_map, parent_map, ring_map): self.sock.sendint(rnext) else: self.sock.sendint(-1) + + return self._get_remote(wait_conn, nnset) + + def _get_remote( + self, wait_conn: Dict[int, "WorkerEntry"], nnset: Set[int] + ) -> List[int]: while True: ngood = self.sock.recvint() goodset = set([]) @@ -134,7 +173,9 @@ def assign_rank(self, rank, wait_conn, tree_map, parent_map, ring_map): self.sock.sendint(len(badset) - len(conset)) for r in conset: self.sock.sendstr(wait_conn[r].host) - self.sock.sendint(wait_conn[r].port) + port = wait_conn[r].port + assert port is not None + self.sock.sendint(port) self.sock.sendint(r) nerr = self.sock.recvint() if nerr != 0: @@ -152,72 +193,75 @@ def assign_rank(self, rank, wait_conn, tree_map, parent_map, ring_map): return rmset -class RabitTracker(object): +class RabitTracker: """ tracker for rabit """ - def __init__(self, hostIP, nslave, port=9091, port_end=9999): - sock = socket.socket(get_family(hostIP), socket.SOCK_STREAM) - for _port in range(port, port_end): - try: - sock.bind((hostIP, _port)) - self.port = _port - break - except socket.error as e: - if e.errno in [98, 48]: - continue - else: - raise + def __init__( + self, host_ip: str, n_workers: int, port: int = 0, use_logger: bool = False + ) -> None: + """A Python implementation of RABIT tracker. + Parameters + .......... + use_logger: + Use logging.info for tracker print command. When set to False, Python print + function is used instead. + """ + sock = socket.socket(get_family(host_ip), socket.SOCK_STREAM) + sock.bind((host_ip, port)) + self.port = sock.getsockname()[1] sock.listen(256) self.sock = sock - self.hostIP = hostIP - self.thread = None - self.start_time = None - self.end_time = None - self.nslave = nslave - logger.info("start listen on %s:%d", hostIP, self.port) + self.host_ip = host_ip + self.thread: Optional[Thread] = None + self.n_workers = n_workers + self._use_logger = use_logger + logger.info("start listen on %s:%d", host_ip, self.port) - def __del__(self): - self.sock.close() + def __del__(self) -> None: + if hasattr(self, "sock"): + self.sock.close() @staticmethod - def get_neighbor(rank, nslave): + def _get_neighbor(rank: int, n_workers: int) -> List[int]: rank = rank + 1 ret = [] if rank > 1: ret.append(rank // 2 - 1) - if rank * 2 - 1 < nslave: + if rank * 2 - 1 < n_workers: ret.append(rank * 2 - 1) - if rank * 2 < nslave: + if rank * 2 < n_workers: ret.append(rank * 2) return ret - def slave_envs(self): + def worker_envs(self) -> Dict[str, Union[str, int]]: """ - get environment variables for slaves + get environment variables for workers can be passed in as args or envs """ - return {"DMLC_TRACKER_URI": self.hostIP, "DMLC_TRACKER_PORT": self.port} + return {"DMLC_TRACKER_URI": self.host_ip, "DMLC_TRACKER_PORT": self.port} - def get_tree(self, nslave): - tree_map = {} - parent_map = {} - for r in range(nslave): - tree_map[r] = self.get_neighbor(r, nslave) + def _get_tree(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int]]: + tree_map: _TreeMap = {} + parent_map: Dict[int, int] = {} + for r in range(n_workers): + tree_map[r] = self._get_neighbor(r, n_workers) parent_map[r] = (r + 1) // 2 - 1 return tree_map, parent_map - def find_share_ring(self, tree_map, parent_map, r): + def find_share_ring( + self, tree_map: _TreeMap, parent_map: Dict[int, int], rank: int + ) -> List[int]: """ get a ring structure that tends to share nodes with the tree - return a list starting from r + return a list starting from rank """ - nset = set(tree_map[r]) - cset = nset - set([parent_map[r]]) + nset = set(tree_map[rank]) + cset = nset - set([parent_map[rank]]) if not cset: - return [r] - rlst = [r] + return [rank] + rlst = [rank] cnt = 0 for v in cset: vlst = self.find_share_ring(tree_map, parent_map, v) @@ -227,84 +271,84 @@ def find_share_ring(self, tree_map, parent_map, r): rlst += vlst return rlst - def get_ring(self, tree_map, parent_map): + def get_ring(self, tree_map: _TreeMap, parent_map: Dict[int, int]) -> _RingMap: """ get a ring connection used to recover local data """ assert parent_map[0] == -1 rlst = self.find_share_ring(tree_map, parent_map, 0) assert len(rlst) == len(tree_map) - ring_map = {} - nslave = len(tree_map) - for r in range(nslave): - rprev = (r + nslave - 1) % nslave - rnext = (r + 1) % nslave + ring_map: _RingMap = {} + n_workers = len(tree_map) + for r in range(n_workers): + rprev = (r + n_workers - 1) % n_workers + rnext = (r + 1) % n_workers ring_map[rlst[r]] = (rlst[rprev], rlst[rnext]) return ring_map - def get_link_map(self, nslave): + def get_link_map(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int], _RingMap]: """ get the link map, this is a bit hacky, call for better algorithm to place similar nodes together """ - tree_map, parent_map = self.get_tree(nslave) + tree_map, parent_map = self._get_tree(n_workers) ring_map = self.get_ring(tree_map, parent_map) rmap = {0: 0} k = 0 - for i in range(nslave - 1): + for i in range(n_workers - 1): k = ring_map[k][1] rmap[k] = i + 1 - ring_map_ = {} - tree_map_ = {} - parent_map_ = {} + ring_map_: _RingMap = {} + tree_map_: _TreeMap = {} + parent_map_: Dict[int, int] = {} for k, v in ring_map.items(): ring_map_[rmap[k]] = (rmap[v[0]], rmap[v[1]]) - for k, v in tree_map.items(): - tree_map_[rmap[k]] = [rmap[x] for x in v] - for k, v in parent_map.items(): + for k, tree_nodes in tree_map.items(): + tree_map_[rmap[k]] = [rmap[x] for x in tree_nodes] + for k, parent in parent_map.items(): if k != 0: - parent_map_[rmap[k]] = rmap[v] + parent_map_[rmap[k]] = rmap[parent] else: parent_map_[rmap[k]] = -1 return tree_map_, parent_map_, ring_map_ - def accept_slaves(self, nslave): + def accept_workers(self, n_workers: int) -> None: + """Wait for all workers to connect to the tracker.""" # set of nodes that finishes the job - shutdown = {} + shutdown: Dict[int, WorkerEntry] = {} # set of nodes that is waiting for connections - wait_conn = {} + wait_conn: Dict[int, WorkerEntry] = {} # maps job id to rank - job_map = {} + job_map: Dict[str, int] = {} # list of workers that is pending to be assigned rank - pending = [] + pending: List[WorkerEntry] = [] # lazy initialize tree_map tree_map = None - while len(shutdown) != nslave: + while len(shutdown) != n_workers: fd, s_addr = self.sock.accept() - s = SlaveEntry(fd, s_addr) + s = WorkerEntry(fd, s_addr) if s.cmd == "print": - msg = s.sock.recvstr() - logger.info(msg.strip()) + s.print(self._use_logger) continue if s.cmd == "shutdown": assert s.rank >= 0 and s.rank not in shutdown assert s.rank not in wait_conn shutdown[s.rank] = s - logger.debug("Receive %s signal from %d", s.cmd, s.rank) + logger.debug("Received %s signal from %d", s.cmd, s.rank) continue - assert s.cmd == "start" or s.cmd == "recover" - # lazily initialize the slaves + assert s.cmd in ("start", "recover") + # lazily initialize the workers if tree_map is None: assert s.cmd == "start" if s.world_size > 0: - nslave = s.world_size - tree_map, parent_map, ring_map = self.get_link_map(nslave) + n_workers = s.world_size + tree_map, parent_map, ring_map = self.get_link_map(n_workers) # set of nodes that is pending for getting up - todo_nodes = list(range(nslave)) + todo_nodes = list(range(n_workers)) else: - assert s.world_size == -1 or s.world_size == nslave + assert s.world_size in (-1, n_workers) if s.cmd == "recover": assert s.rank >= 0 @@ -323,37 +367,137 @@ def accept_slaves(self, nslave): if s.wait_accept > 0: wait_conn[rank] = s logger.debug( - "Receive %s signal from %s; assign rank %d", + "Received %s signal from %s; assign rank %d", s.cmd, s.host, s.rank, ) if not todo_nodes: - logger.info("@tracker All of %d nodes getting started", nslave) - self.start_time = time.time() + logger.info("@tracker All of %d nodes getting started", n_workers) else: s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map) - logger.debug("Receive %s signal from %d", s.cmd, s.rank) + logger.debug("Received %s signal from %d", s.cmd, s.rank) if s.wait_accept > 0: wait_conn[rank] = s logger.info("@tracker All nodes finishes job") - self.end_time = time.time() - logger.info( - "@tracker %s secs between node start and job finish", - str(self.end_time - self.start_time), - ) - - def start(self, nslave): - def run(): - self.accept_slaves(nslave) - - self.thread = Thread(target=run, args=()) - self.thread.setDaemon(True) + + def start(self, n_workers: int) -> None: + """Start the tracker, it will wait for `n_workers` to connect.""" + + def run() -> None: + self.accept_workers(n_workers) + + self.thread = Thread(target=run, args=(), daemon=True) self.thread.start() - def join(self): - while self.thread.isAlive(): + def join(self) -> None: + """Wait for the tracker to finish.""" + while self.thread is not None and self.thread.is_alive(): self.thread.join(100) - def alive(self): - return self.thread.isAlive() + def alive(self) -> bool: + """Whether the tracker thread is alive""" + return self.thread is not None and self.thread.is_alive() + + +def get_host_ip(host_ip: Optional[str] = None) -> str: + """Get the IP address of current host. If `host_ip` is not none then it will be + returned as it's + """ + if host_ip is None or host_ip == "auto": + host_ip = "ip" + + if host_ip == "dns": + host_ip = socket.getfqdn() + elif host_ip == "ip": + from socket import gaierror + + try: + host_ip = socket.gethostbyname(socket.getfqdn()) + except gaierror: + logger.debug( + "gethostbyname(socket.getfqdn()) failed... trying on hostname()" + ) + host_ip = socket.gethostbyname(socket.gethostname()) + if host_ip.startswith("127."): + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + # doesn't have to be reachable + s.connect(("10.255.255.255", 1)) + host_ip = s.getsockname()[0] + + assert host_ip is not None + return host_ip + + +def start_rabit_tracker(args: argparse.Namespace) -> None: + """Standalone function to start rabit tracker. + Parameters + ---------- + args: arguments to start the rabit tracker. + """ + envs = {"DMLC_NUM_WORKER": args.num_workers, "DMLC_NUM_SERVER": args.num_servers} + rabit = RabitTracker( + host_ip=get_host_ip(args.host_ip), n_workers=args.num_workers, use_logger=True + ) + envs.update(rabit.worker_envs()) + rabit.start(args.num_workers) + sys.stdout.write("DMLC_TRACKER_ENV_START\n") + # simply write configuration to stdout + for k, v in envs.items(): + sys.stdout.write(f"{k}={v}\n") + sys.stdout.write("DMLC_TRACKER_ENV_END\n") + sys.stdout.flush() + rabit.join() + + +def main() -> None: + """Main function if tracker is executed in standalone mode.""" + parser = argparse.ArgumentParser(description="Rabit Tracker start.") + parser.add_argument( + "--num-workers", + required=True, + type=int, + help="Number of worker process to be launched.", + ) + parser.add_argument( + "--num-servers", + default=0, + type=int, + help="Number of server process to be launched. Only used in PS jobs.", + ) + parser.add_argument( + "--host-ip", + default=None, + type=str, + help=( + "Host IP addressed, this is only needed " + + "if the host IP cannot be automatically guessed." + ), + ) + parser.add_argument( + "--log-level", + default="INFO", + type=str, + choices=["INFO", "DEBUG"], + help="Logging level of the logger.", + ) + args = parser.parse_args() + + fmt = "%(asctime)s %(levelname)s %(message)s" + if args.log_level == "INFO": + level = logging.INFO + elif args.log_level == "DEBUG": + level = logging.DEBUG + else: + raise RuntimeError(f"Unknown logging level {args.log_level}") + + logging.basicConfig(format=fmt, level=level) + + if args.num_servers == 0: + start_rabit_tracker(args) + else: + raise RuntimeError("Do not yet support start ps tracker in standalone mode.") + + +if __name__ == "__main__": + main()