Skip to content

Commit

Permalink
feature(xjx): cli in new pipeline (#160)
Browse files Browse the repository at this point in the history
* Cli ditask

* Import ditask in init

* Add current path as default package path

* Fix style

* Add topology on ditask
  • Loading branch information
sailxjx authored Dec 23, 2021
1 parent 92d973c commit 954d310
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 11 deletions.
1 change: 1 addition & 0 deletions ding/entry/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .cli import cli
from .cli_ditask import cli_ditask
from .serial_entry import serial_pipeline
from .serial_entry_onpolicy import serial_pipeline_onpolicy
from .serial_entry_offline import serial_pipeline_offline
Expand Down
98 changes: 98 additions & 0 deletions ding/entry/cli_ditask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import click
import os
import sys
import importlib
import importlib.util
from click.core import Context, Option

from ding import __TITLE__, __VERSION__, __AUTHOR__, __AUTHOR_EMAIL__
from ding.framework import Parallel


def print_version(ctx: Context, param: Option, value: bool) -> None:
    """
    Overview:
        Click option callback that prints the package title, version and author \
        information and then exits the command.
    Arguments:
        - ctx (:obj:`Context`): Current click context, used to check ``resilient_parsing`` and to exit.
        - param (:obj:`Option`): The option that triggered this callback (unused).
        - value (:obj:`bool`): Flag value; nothing is printed when it is falsy.
    """
    # Stay silent when the flag is absent or click is parsing in resilient mode.
    wants_version = value and not ctx.resilient_parsing
    if wants_version:
        version_line = '{title}, version {version}.'.format(title=__TITLE__, version=__VERSION__)
        author_line = 'Developed by {author}, {email}.'.format(author=__AUTHOR__, email=__AUTHOR_EMAIL__)
        click.echo(version_line)
        click.echo(author_line)
        ctx.exit()


# Click context settings shared by the command below: accept `-h` as well as `--help`.
CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])


def _split_csv(raw: str, cast=str) -> list:
    """
    Overview:
        Split a comma separated option value into a list of converted items.
    Arguments:
        - raw (:obj:`str`): Raw option string, e.g. ``"50515, 50516"``.
        - cast (:obj:`Callable`): Conversion applied to every whitespace-stripped item, default ``str``.
    Returns:
        - items (:obj:`list`): Converted items in their original order.
    """
    return [cast(item.strip()) for item in raw.split(",")]


# `ditask` command line entry: imports the user's entry function and hands it to `Parallel.runner`
# together with the networking options parsed from the command line.
# NOTE: this is documented with comments rather than a docstring on purpose, because click would
# print a function docstring as the command's `--help` description.
@click.command(context_settings=CONTEXT_SETTINGS)
@click.option(
    '-v',
    '--version',
    is_flag=True,
    callback=print_version,
    expose_value=False,
    is_eager=True,
    help="Show package's version information."
)
@click.option('-p', '--package', type=str, help="Your code package path, could be a directory or a zip file.")
@click.option('--parallel-workers', type=int, default=1, help="Parallel worker number, default: 1")
@click.option(
    '--protocol',
    type=click.Choice(["tcp", "ipc"]),
    default="tcp",
    help="Network protocol in parallel mode, default: tcp"
)
@click.option(
    "--ports",
    type=str,
    default="50515",
    help="The port addresses that the tasks listen to, e.g. 50515,50516, default: 50515"
)
@click.option("--attach-to", type=str, help="The addresses to connect to.")
@click.option("--address", type=str, help="The address to listen to (without port).")
@click.option("--labels", type=str, help="Labels.")
@click.option("--node-ids", type=str, help="Candidate node ids.")
@click.option(
    "--topology",
    type=click.Choice(["alone", "mesh", "star"]),
    default="alone",
    help="Network topology, default: alone."
)
@click.option("-m", "--main", type=str, help="Main function of entry module.")
def cli_ditask(
    package: str, main: str, parallel_workers: int, protocol: str, ports: str, attach_to: str, address: str,
    labels: str, node_ids: str, topology: str
):
    # Parse entry point
    if not package:
        package = os.getcwd()
    sys.path.append(package)
    if main is None:
        # Without --main, the module is named after the package (directory or zip) and the
        # entry function defaults to `main`. normpath drops a trailing path separator, which
        # would otherwise make basename() return an empty module name.
        mod_name = os.path.basename(os.path.normpath(package))
        mod_name, _ = os.path.splitext(mod_name)
        func_name = "main"
    else:
        # --main is a dotted path whose last component is the function name.
        mod_name, sep, func_name = main.rpartition(".")
        if not (sep and mod_name and func_name):
            raise click.BadParameter(
                "expected a dotted path like '<module>.<function>', got {!r}".format(main), param_hint="--main"
            )
        root_mod_name = mod_name.split(".", 1)[0]
        sys.path.append(os.path.join(package, root_mod_name))
    mod = importlib.import_module(mod_name)
    main_func = getattr(mod, func_name)
    # Parse arguments
    ports = _split_csv(ports, int)
    # A single port is passed as a plain int rather than a one-element list.
    ports = ports[0] if len(ports) == 1 else ports
    if attach_to:
        attach_to = _split_csv(attach_to)
    if labels:
        labels = set(_split_csv(labels))
    if node_ids:
        node_ids = _split_csv(node_ids, int)
    Parallel.runner(
        n_parallel_workers=parallel_workers,
        ports=ports,
        protocol=protocol,
        topology=topology,
        attach_to=attach_to,
        address=address,
        labels=labels,
        node_ids=node_ids
    )(main_func)
43 changes: 33 additions & 10 deletions ding/framework/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import tempfile
import socket
from os import path
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union, Set
from threading import Thread
from pynng.nng import Bus0, Socket
from ding.utils.design_helper import SingletonMetaclass
Expand All @@ -30,10 +30,18 @@ def __init__(self) -> None:
self.attach_to = None
self.finished = False
self.node_id = None
self.labels = set()

def run(self, node_id: int, listen_to: str, attach_to: List[str] = None) -> None:
def run(
self,
node_id: int,
listen_to: str,
attach_to: Optional[List[str]] = None,
labels: Optional[Set[str]] = None
) -> None:
self.node_id = node_id
self.attach_to = attach_to = attach_to or []
self.labels = labels or set()
self._listener = Thread(
target=self.listen,
kwargs={
Expand All @@ -52,7 +60,9 @@ def runner(
protocol: str = "ipc",
address: Optional[str] = None,
ports: Optional[List[int]] = None,
topology: str = "mesh"
topology: str = "mesh",
labels: Optional[Set[str]] = None,
node_ids: Optional[List[int]] = None
) -> Callable:
"""
Overview:
Expand All @@ -66,6 +76,9 @@ def runner(
- topology (:obj:`str`): Network topology, includes:
`mesh` (default): fully connected between each other;
`star`: only connect to the first node;
`alone`: do not connect to any node, except the node attached to;
- labels (:obj:`Optional[Set[str]]`): Labels.
- node_ids (:obj:`Optional[List[int]]`): Candidate node ids.
Returns:
- _runner (:obj:`Callable`): The wrapper function for main.
"""
Expand All @@ -91,21 +104,29 @@ def cleanup_nodes():

atexit.register(cleanup_nodes)

def topology_network(node_id: int) -> List[str]:
def topology_network(i: int) -> List[str]:
if topology == "mesh":
return nodes[:node_id] + attach_to
return nodes[:i] + attach_to
elif topology == "star":
return nodes[:min(1, node_id)]
return nodes[:min(1, i)] + attach_to
elif topology == "alone":
return attach_to
else:
raise ValueError("Unknown topology: {}".format(topology))

params_group = []
for node_id in range(n_parallel_workers):
candidate_node_ids = node_ids or range(n_parallel_workers)
assert len(candidate_node_ids) == n_parallel_workers, \
"The number of workers must be the same as the number of node_ids, \
now there are {} workers and {} nodes"\
.format(n_parallel_workers, len(candidate_node_ids))
for i in range(n_parallel_workers):
runner_args = []
runner_kwargs = {
"node_id": node_id,
"listen_to": nodes[node_id],
"attach_to": topology_network(node_id) + attach_to
"node_id": candidate_node_ids[i],
"listen_to": nodes[i],
"attach_to": topology_network(i) + attach_to,
"labels": labels
}
params = [(runner_args, runner_kwargs), (main_process, args, kwargs)]
params_group.append(params)
Expand Down Expand Up @@ -151,6 +172,8 @@ def get_node_addrs(
elif protocol == "tcp":
address = address or Parallel.get_ip()
ports = ports or range(50515, 50515 + n_workers)
if isinstance(ports, int):
ports = range(ports, ports + n_workers)
assert len(ports) == n_workers, "The number of ports must be the same as the number of workers, \
now there are {} ports and {} workers".format(len(ports), n_workers)
nodes = ["tcp://{}:{}".format(address, port) for port in ports]
Expand Down
2 changes: 2 additions & 0 deletions ding/framework/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ def init_labels(self):
if self.router.is_active:
self.labels.add("distributed")
self.labels.add("node.{}".format(self.router.node_id))
for label in self.router.labels:
self.labels.add(label)
else:
self.labels.add("standalone")

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@
'kubernetes',
]
},
entry_points={'console_scripts': ['ding=ding.entry.cli:cli']},
entry_points={'console_scripts': ['ding=ding.entry.cli:cli', 'ditask=ding.entry.cli_ditask:cli_ditask']},
classifiers=[
'Development Status :: 5 - Production/Stable',
"Intended Audience :: Science/Research",
Expand Down

0 comments on commit 954d310

Please sign in to comment.