From 98e1535f0dde74c1cf6d2443b9b367cc05562348 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=BB=A7=E7=9B=9B?= <xuye.qin@alibaba-inc.com>
Date: Thu, 24 Mar 2022 13:52:39 +0800
Subject: [PATCH] Update tracker

---
 mars/learn/contrib/xgboost/start_tracker.py |   4 +-
 mars/learn/contrib/xgboost/tracker.py       | 382 ++++++++++++++------
 2 files changed, 265 insertions(+), 121 deletions(-)

diff --git a/mars/learn/contrib/xgboost/start_tracker.py b/mars/learn/contrib/xgboost/start_tracker.py
index 98ddf7d431..feadbcbdb3 100644
--- a/mars/learn/contrib/xgboost/start_tracker.py
+++ b/mars/learn/contrib/xgboost/start_tracker.py
@@ -47,9 +47,9 @@ def execute(cls, ctx, op):
 
         env = {"DMLC_NUM_WORKER": op.n_workers}
         rabit_context = RabitTracker(
-            hostIP=ctx.get_local_host_ip(), nslave=op.n_workers
+            host_ip=ctx.get_local_host_ip(), n_workers=op.n_workers
         )
-        env.update(rabit_context.slave_envs())
+        env.update(rabit_context.worker_envs())
 
         rabit_context.start(op.n_workers)
         thread = Thread(target=rabit_context.join)
diff --git a/mars/learn/contrib/xgboost/tracker.py b/mars/learn/contrib/xgboost/tracker.py
index cb1b77e929..1cd0afeef2 100644
--- a/mars/learn/contrib/xgboost/tracker.py
+++ b/mars/learn/contrib/xgboost/tracker.py
@@ -21,22 +21,29 @@
 # pylint: disable=too-many-branches, too-many-statements, too-many-instance-attributes
 import socket
 import struct
-import time
 import logging
 from threading import Thread
+import argparse
+import sys
+
+from typing import Dict, List, Tuple, Union, Optional, Set
+
+_RingMap = Dict[int, Tuple[int, int]]
+_TreeMap = Dict[int, List[int]]
 
 logger = logging.getLogger(__name__)
 
 
-class ExSocket(object):
+class ExSocket:
     """
     Extension of socket to handle recv and send of special data
     """
 
-    def __init__(self, sock):
+    def __init__(self, sock: socket.socket) -> None:
         self.sock = sock
 
-    def recvall(self, nbytes):
+    def recvall(self, nbytes: int) -> bytes:
+        """Receive number of bytes."""
         res = []
         nread = 0
         while nread < nbytes:
@@ -45,56 +52,82 @@ def recvall(self, nbytes):
             res.append(chunk)
         return b"".join(res)
 
-    def recvint(self):
+    def recvint(self) -> int:
+        """Receive an integer of 32 bytes"""
         return struct.unpack("@i", self.recvall(4))[0]
 
-    def sendint(self, n):
-        self.sock.sendall(struct.pack("@i", n))
+    def sendint(self, value: int) -> None:
+        """Send an integer of 32 bytes"""
+        self.sock.sendall(struct.pack("@i", value))
 
-    def sendstr(self, s):
-        self.sendint(len(s))
-        self.sock.sendall(s.encode())
+    def sendstr(self, value: str) -> None:
+        """Send a Python string"""
+        self.sendint(len(value))
+        self.sock.sendall(value.encode())
 
-    def recvstr(self):
+    def recvstr(self) -> str:
+        """Receive a Python string"""
         slen = self.recvint()
         return self.recvall(slen).decode()
 
 
 # magic number used to verify existence of data
-kMagic = 0xFF99
+MAGIC_NUM = 0xFF99
 
 
-def get_some_ip(host):
+def get_some_ip(host: str) -> str:
+    """Get ip from host"""
     return socket.getaddrinfo(host, None)[0][4][0]
 
 
-def get_family(addr):
+def get_family(addr: str) -> int:
+    """Get network family from address."""
     return socket.getaddrinfo(addr, None)[0][0]
 
 
-class SlaveEntry(object):
-    def __init__(self, sock, s_addr):
-        slave = ExSocket(sock)
-        self.sock = slave
+class WorkerEntry:
+    """Handler to each worker."""
+
+    def __init__(self, sock: socket.socket, s_addr: Tuple[str, int]):
+        worker = ExSocket(sock)
+        self.sock = worker
         self.host = get_some_ip(s_addr[0])
-        magic = slave.recvint()
-        assert magic == kMagic, f"invalid magic number={magic} from {self.host}"
-        slave.sendint(kMagic)
-        self.rank = slave.recvint()
-        self.world_size = slave.recvint()
-        self.jobid = slave.recvstr()
-        self.cmd = slave.recvstr()
+        magic = worker.recvint()
+        assert magic == MAGIC_NUM, f"invalid magic number={magic} from {self.host}"
+        worker.sendint(MAGIC_NUM)
+        self.rank = worker.recvint()
+        self.world_size = worker.recvint()
+        self.jobid = worker.recvstr()
+        self.cmd = worker.recvstr()
         self.wait_accept = 0
-        self.port = None
+        self.port: Optional[int] = None
+
+    def print(self, use_logger: bool) -> None:
+        """Execute the print command from worker."""
+        msg = self.sock.recvstr()
+        # On dask we use print to avoid setting global verbosity.
+        if use_logger:
+            logger.info(msg.strip())
+        else:
+            print(msg.strip(), flush=True)
 
-    def decide_rank(self, job_map):
+    def decide_rank(self, job_map: Dict[str, int]) -> int:
+        """Get the rank of current entry."""
         if self.rank >= 0:
             return self.rank
         if self.jobid != "NULL" and self.jobid in job_map:
             return job_map[self.jobid]
         return -1
 
-    def assign_rank(self, rank, wait_conn, tree_map, parent_map, ring_map):
+    def assign_rank(
+        self,
+        rank: int,
+        wait_conn: Dict[int, "WorkerEntry"],
+        tree_map: _TreeMap,
+        parent_map: Dict[int, int],
+        ring_map: _RingMap,
+    ) -> List[int]:
+        """Assign the rank for current entry."""
         self.rank = rank
         nnset = set(tree_map[rank])
         rprev, rnext = ring_map[rank]
@@ -119,6 +152,12 @@ def assign_rank(self, rank, wait_conn, tree_map, parent_map, ring_map):
             self.sock.sendint(rnext)
         else:
             self.sock.sendint(-1)
+
+        return self._get_remote(wait_conn, nnset)
+
+    def _get_remote(
+        self, wait_conn: Dict[int, "WorkerEntry"], nnset: Set[int]
+    ) -> List[int]:
         while True:
             ngood = self.sock.recvint()
             goodset = set([])
@@ -134,7 +173,9 @@ def assign_rank(self, rank, wait_conn, tree_map, parent_map, ring_map):
             self.sock.sendint(len(badset) - len(conset))
             for r in conset:
                 self.sock.sendstr(wait_conn[r].host)
-                self.sock.sendint(wait_conn[r].port)
+                port = wait_conn[r].port
+                assert port is not None
+                self.sock.sendint(port)
                 self.sock.sendint(r)
             nerr = self.sock.recvint()
             if nerr != 0:
@@ -152,72 +193,75 @@ def assign_rank(self, rank, wait_conn, tree_map, parent_map, ring_map):
             return rmset
 
 
-class RabitTracker(object):
+class RabitTracker:
     """
     tracker for rabit
     """
 
-    def __init__(self, hostIP, nslave, port=9091, port_end=9999):
-        sock = socket.socket(get_family(hostIP), socket.SOCK_STREAM)
-        for _port in range(port, port_end):
-            try:
-                sock.bind((hostIP, _port))
-                self.port = _port
-                break
-            except socket.error as e:
-                if e.errno in [98, 48]:
-                    continue
-                else:
-                    raise
+    def __init__(
+        self, host_ip: str, n_workers: int, port: int = 0, use_logger: bool = False
+    ) -> None:
+        """A Python implementation of RABIT tracker.
+        Parameters
+        ..........
+        use_logger:
+            Use logging.info for tracker print command.  When set to False, Python print
+            function is used instead.
+        """
+        sock = socket.socket(get_family(host_ip), socket.SOCK_STREAM)
+        sock.bind((host_ip, port))
+        self.port = sock.getsockname()[1]
         sock.listen(256)
         self.sock = sock
-        self.hostIP = hostIP
-        self.thread = None
-        self.start_time = None
-        self.end_time = None
-        self.nslave = nslave
-        logger.info("start listen on %s:%d", hostIP, self.port)
+        self.host_ip = host_ip
+        self.thread: Optional[Thread] = None
+        self.n_workers = n_workers
+        self._use_logger = use_logger
+        logger.info("start listen on %s:%d", host_ip, self.port)
 
-    def __del__(self):
-        self.sock.close()
+    def __del__(self) -> None:
+        if hasattr(self, "sock"):
+            self.sock.close()
 
     @staticmethod
-    def get_neighbor(rank, nslave):
+    def _get_neighbor(rank: int, n_workers: int) -> List[int]:
         rank = rank + 1
         ret = []
         if rank > 1:
             ret.append(rank // 2 - 1)
-        if rank * 2 - 1 < nslave:
+        if rank * 2 - 1 < n_workers:
             ret.append(rank * 2 - 1)
-        if rank * 2 < nslave:
+        if rank * 2 < n_workers:
             ret.append(rank * 2)
         return ret
 
-    def slave_envs(self):
+    def worker_envs(self) -> Dict[str, Union[str, int]]:
         """
-        get environment variables for slaves
+        get environment variables for workers
         can be passed in as args or envs
         """
-        return {"DMLC_TRACKER_URI": self.hostIP, "DMLC_TRACKER_PORT": self.port}
+        return {"DMLC_TRACKER_URI": self.host_ip, "DMLC_TRACKER_PORT": self.port}
 
-    def get_tree(self, nslave):
-        tree_map = {}
-        parent_map = {}
-        for r in range(nslave):
-            tree_map[r] = self.get_neighbor(r, nslave)
+    def _get_tree(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int]]:
+        tree_map: _TreeMap = {}
+        parent_map: Dict[int, int] = {}
+        for r in range(n_workers):
+            tree_map[r] = self._get_neighbor(r, n_workers)
             parent_map[r] = (r + 1) // 2 - 1
         return tree_map, parent_map
 
-    def find_share_ring(self, tree_map, parent_map, r):
+    def find_share_ring(
+        self, tree_map: _TreeMap, parent_map: Dict[int, int], rank: int
+    ) -> List[int]:
         """
         get a ring structure that tends to share nodes with the tree
-        return a list starting from r
+        return a list starting from rank
         """
-        nset = set(tree_map[r])
-        cset = nset - set([parent_map[r]])
+        nset = set(tree_map[rank])
+        cset = nset - set([parent_map[rank]])
         if not cset:
-            return [r]
-        rlst = [r]
+            return [rank]
+        rlst = [rank]
         cnt = 0
         for v in cset:
             vlst = self.find_share_ring(tree_map, parent_map, v)
@@ -227,84 +271,84 @@ def find_share_ring(self, tree_map, parent_map, r):
             rlst += vlst
         return rlst
 
-    def get_ring(self, tree_map, parent_map):
+    def get_ring(self, tree_map: _TreeMap, parent_map: Dict[int, int]) -> _RingMap:
         """
         get a ring connection used to recover local data
         """
         assert parent_map[0] == -1
         rlst = self.find_share_ring(tree_map, parent_map, 0)
         assert len(rlst) == len(tree_map)
-        ring_map = {}
-        nslave = len(tree_map)
-        for r in range(nslave):
-            rprev = (r + nslave - 1) % nslave
-            rnext = (r + 1) % nslave
+        ring_map: _RingMap = {}
+        n_workers = len(tree_map)
+        for r in range(n_workers):
+            rprev = (r + n_workers - 1) % n_workers
+            rnext = (r + 1) % n_workers
             ring_map[rlst[r]] = (rlst[rprev], rlst[rnext])
         return ring_map
 
-    def get_link_map(self, nslave):
+    def get_link_map(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int], _RingMap]:
         """
         get the link map, this is a bit hacky, call for better algorithm
         to place similar nodes together
         """
-        tree_map, parent_map = self.get_tree(nslave)
+        tree_map, parent_map = self._get_tree(n_workers)
         ring_map = self.get_ring(tree_map, parent_map)
         rmap = {0: 0}
         k = 0
-        for i in range(nslave - 1):
+        for i in range(n_workers - 1):
             k = ring_map[k][1]
             rmap[k] = i + 1
 
-        ring_map_ = {}
-        tree_map_ = {}
-        parent_map_ = {}
+        ring_map_: _RingMap = {}
+        tree_map_: _TreeMap = {}
+        parent_map_: Dict[int, int] = {}
         for k, v in ring_map.items():
             ring_map_[rmap[k]] = (rmap[v[0]], rmap[v[1]])
-        for k, v in tree_map.items():
-            tree_map_[rmap[k]] = [rmap[x] for x in v]
-        for k, v in parent_map.items():
+        for k, tree_nodes in tree_map.items():
+            tree_map_[rmap[k]] = [rmap[x] for x in tree_nodes]
+        for k, parent in parent_map.items():
             if k != 0:
-                parent_map_[rmap[k]] = rmap[v]
+                parent_map_[rmap[k]] = rmap[parent]
             else:
                 parent_map_[rmap[k]] = -1
         return tree_map_, parent_map_, ring_map_
 
-    def accept_slaves(self, nslave):
+    def accept_workers(self, n_workers: int) -> None:
+        """Wait for all workers to connect to the tracker."""
         # set of nodes that finishes the job
-        shutdown = {}
+        shutdown: Dict[int, WorkerEntry] = {}
         # set of nodes that is waiting for connections
-        wait_conn = {}
+        wait_conn: Dict[int, WorkerEntry] = {}
         # maps job id to rank
-        job_map = {}
+        job_map: Dict[str, int] = {}
         # list of workers that is pending to be assigned rank
-        pending = []
+        pending: List[WorkerEntry] = []
         # lazy initialize tree_map
         tree_map = None
 
-        while len(shutdown) != nslave:
+        while len(shutdown) != n_workers:
             fd, s_addr = self.sock.accept()
-            s = SlaveEntry(fd, s_addr)
+            s = WorkerEntry(fd, s_addr)
             if s.cmd == "print":
-                msg = s.sock.recvstr()
-                logger.info(msg.strip())
+                s.print(self._use_logger)
                 continue
             if s.cmd == "shutdown":
                 assert s.rank >= 0 and s.rank not in shutdown
                 assert s.rank not in wait_conn
                 shutdown[s.rank] = s
-                logger.debug("Receive %s signal from %d", s.cmd, s.rank)
+                logger.debug("Received %s signal from %d", s.cmd, s.rank)
                 continue
-            assert s.cmd == "start" or s.cmd == "recover"
-            # lazily initialize the slaves
+            assert s.cmd in ("start", "recover")
+            # lazily initialize the workers
             if tree_map is None:
                 assert s.cmd == "start"
                 if s.world_size > 0:
-                    nslave = s.world_size
-                tree_map, parent_map, ring_map = self.get_link_map(nslave)
+                    n_workers = s.world_size
+                tree_map, parent_map, ring_map = self.get_link_map(n_workers)
                 # set of nodes that is pending for getting up
-                todo_nodes = list(range(nslave))
+                todo_nodes = list(range(n_workers))
             else:
-                assert s.world_size == -1 or s.world_size == nslave
+                assert s.world_size in (-1, n_workers)
             if s.cmd == "recover":
                 assert s.rank >= 0
 
@@ -323,37 +367,137 @@ def accept_slaves(self, nslave):
                         if s.wait_accept > 0:
                             wait_conn[rank] = s
                         logger.debug(
-                            "Receive %s signal from %s; assign rank %d",
+                            "Received %s signal from %s; assign rank %d",
                             s.cmd,
                             s.host,
                             s.rank,
                         )
                 if not todo_nodes:
-                    logger.info("@tracker All of %d nodes getting started", nslave)
-                    self.start_time = time.time()
+                    logger.info("@tracker All of %d nodes getting started", n_workers)
             else:
                 s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
-                logger.debug("Receive %s signal from %d", s.cmd, s.rank)
+                logger.debug("Received %s signal from %d", s.cmd, s.rank)
                 if s.wait_accept > 0:
                     wait_conn[rank] = s
         logger.info("@tracker All nodes finishes job")
-        self.end_time = time.time()
-        logger.info(
-            "@tracker %s secs between node start and job finish",
-            str(self.end_time - self.start_time),
-        )
-
-    def start(self, nslave):
-        def run():
-            self.accept_slaves(nslave)
-
-        self.thread = Thread(target=run, args=())
-        self.thread.setDaemon(True)
+
+    def start(self, n_workers: int) -> None:
+        """Start the tracker, it will wait for `n_workers` to connect."""
+
+        def run() -> None:
+            self.accept_workers(n_workers)
+
+        self.thread = Thread(target=run, args=(), daemon=True)
         self.thread.start()
 
-    def join(self):
-        while self.thread.isAlive():
+    def join(self) -> None:
+        """Wait for the tracker to finish."""
+        while self.thread is not None and self.thread.is_alive():
             self.thread.join(100)
 
-    def alive(self):
-        return self.thread.isAlive()
+    def alive(self) -> bool:
+        """Whether the tracker thread is alive"""
+        return self.thread is not None and self.thread.is_alive()
+
+
+def get_host_ip(host_ip: Optional[str] = None) -> str:
+    """Get the IP address of current host.  If `host_ip` is not none then it will be
+    returned as it's
+    """
+    if host_ip is None or host_ip == "auto":
+        host_ip = "ip"
+
+    if host_ip == "dns":
+        host_ip = socket.getfqdn()
+    elif host_ip == "ip":
+        from socket import gaierror
+
+        try:
+            host_ip = socket.gethostbyname(socket.getfqdn())
+        except gaierror:
+            logger.debug(
+                "gethostbyname(socket.getfqdn()) failed... trying on hostname()"
+            )
+            host_ip = socket.gethostbyname(socket.gethostname())
+        if host_ip.startswith("127."):
+            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+            # doesn't have to be reachable
+            s.connect(("10.255.255.255", 1))
+            host_ip = s.getsockname()[0]
+
+    assert host_ip is not None
+    return host_ip
+
+
+def start_rabit_tracker(args: argparse.Namespace) -> None:
+    """Standalone function to start rabit tracker.
+    Parameters
+    ----------
+    args: arguments to start the rabit tracker.
+    """
+    envs = {"DMLC_NUM_WORKER": args.num_workers, "DMLC_NUM_SERVER": args.num_servers}
+    rabit = RabitTracker(
+        host_ip=get_host_ip(args.host_ip), n_workers=args.num_workers, use_logger=True
+    )
+    envs.update(rabit.worker_envs())
+    rabit.start(args.num_workers)
+    sys.stdout.write("DMLC_TRACKER_ENV_START\n")
+    # simply write configuration to stdout
+    for k, v in envs.items():
+        sys.stdout.write(f"{k}={v}\n")
+    sys.stdout.write("DMLC_TRACKER_ENV_END\n")
+    sys.stdout.flush()
+    rabit.join()
+
+
+def main() -> None:
+    """Main function if tracker is executed in standalone mode."""
+    parser = argparse.ArgumentParser(description="Rabit Tracker start.")
+    parser.add_argument(
+        "--num-workers",
+        required=True,
+        type=int,
+        help="Number of worker process to be launched.",
+    )
+    parser.add_argument(
+        "--num-servers",
+        default=0,
+        type=int,
+        help="Number of server process to be launched. Only used in PS jobs.",
+    )
+    parser.add_argument(
+        "--host-ip",
+        default=None,
+        type=str,
+        help=(
+            "Host IP addressed, this is only needed "
+            + "if the host IP cannot be automatically guessed."
+        ),
+    )
+    parser.add_argument(
+        "--log-level",
+        default="INFO",
+        type=str,
+        choices=["INFO", "DEBUG"],
+        help="Logging level of the logger.",
+    )
+    args = parser.parse_args()
+
+    fmt = "%(asctime)s %(levelname)s %(message)s"
+    if args.log_level == "INFO":
+        level = logging.INFO
+    elif args.log_level == "DEBUG":
+        level = logging.DEBUG
+    else:
+        raise RuntimeError(f"Unknown logging level {args.log_level}")
+
+    logging.basicConfig(format=fmt, level=level)
+
+    if args.num_servers == 0:
+        start_rabit_tracker(args)
+    else:
+        raise RuntimeError("Do not yet support start ps tracker in standalone mode.")
+
+
+if __name__ == "__main__":
+    main()