-
Notifications
You must be signed in to change notification settings - Fork 373
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feature(xjx): cli in new pipeline (#160)
* Cli ditask * Import ditask in init * Add current path as default package path * Fix style * Add topology on ditask
- Loading branch information
Showing
5 changed files
with
135 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import click | ||
import os | ||
import sys | ||
import importlib | ||
import importlib.util | ||
from click.core import Context, Option | ||
|
||
from ding import __TITLE__, __VERSION__, __AUTHOR__, __AUTHOR_EMAIL__ | ||
from ding.framework import Parallel | ||
|
||
|
||
def print_version(ctx: Context, param: Option, value: bool) -> None: | ||
if not value or ctx.resilient_parsing: | ||
return | ||
click.echo('{title}, version {version}.'.format(title=__TITLE__, version=__VERSION__)) | ||
click.echo('Developed by {author}, {email}.'.format(author=__AUTHOR__, email=__AUTHOR_EMAIL__)) | ||
ctx.exit() | ||
|
||
|
||
CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help']) | ||
|
||
|
||
@click.command(context_settings=CONTEXT_SETTINGS) | ||
@click.option( | ||
'-v', | ||
'--version', | ||
is_flag=True, | ||
callback=print_version, | ||
expose_value=False, | ||
is_eager=True, | ||
help="Show package's version information." | ||
) | ||
@click.option('-p', '--package', type=str, help="Your code package path, could be a directory or a zip file.") | ||
@click.option('--parallel-workers', type=int, default=1, help="Parallel worker number, default: 1") | ||
@click.option( | ||
'--protocol', | ||
type=click.Choice(["tcp", "ipc"]), | ||
default="tcp", | ||
help="Network protocol in parallel mode, default: tcp" | ||
) | ||
@click.option( | ||
"--ports", | ||
type=str, | ||
default="50515", | ||
help="The port addresses that the tasks listen to, e.g. 50515,50516, default: 50515" | ||
) | ||
@click.option("--attach-to", type=str, help="The addresses to connect to.") | ||
@click.option("--address", type=str, help="The address to listen to (without port).") | ||
@click.option("--labels", type=str, help="Labels.") | ||
@click.option("--node-ids", type=str, help="Candidate node ids.") | ||
@click.option( | ||
"--topology", | ||
type=click.Choice(["alone", "mesh", "star"]), | ||
default="alone", | ||
help="Network topology, default: alone." | ||
) | ||
@click.option("-m", "--main", type=str, help="Main function of entry module.") | ||
def cli_ditask( | ||
package: str, main: str, parallel_workers: int, protocol: str, ports: str, attach_to: str, address: str, | ||
labels: str, node_ids: str, topology: str | ||
): | ||
# Parse entry point | ||
if not package: | ||
package = os.getcwd() | ||
sys.path.append(package) | ||
if main is None: | ||
mod_name = os.path.basename(package) | ||
mod_name, _ = os.path.splitext(mod_name) | ||
func_name = "main" | ||
else: | ||
mod_name, func_name = main.rsplit(".", 1) | ||
root_mod_name = mod_name.split(".", 1)[0] | ||
sys.path.append(os.path.join(package, root_mod_name)) | ||
mod = importlib.import_module(mod_name) | ||
main_func = getattr(mod, func_name) | ||
# Parse arguments | ||
ports = ports.split(",") | ||
ports = list(map(lambda i: int(i), ports)) | ||
ports = ports[0] if len(ports) == 1 else ports | ||
if attach_to: | ||
attach_to = attach_to.split(",") | ||
attach_to = list(map(lambda s: s.strip(), attach_to)) | ||
if labels: | ||
labels = labels.split(",") | ||
labels = set(map(lambda s: s.strip(), labels)) | ||
if node_ids: | ||
node_ids = node_ids.split(",") | ||
node_ids = list(map(lambda i: int(i), node_ids)) | ||
Parallel.runner( | ||
n_parallel_workers=parallel_workers, | ||
ports=ports, | ||
protocol=protocol, | ||
topology=topology, | ||
attach_to=attach_to, | ||
address=address, | ||
labels=labels, | ||
node_ids=node_ids | ||
)(main_func) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters