Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Suport prefix url for nnimanager #3643

Merged
merged 5 commits into from
May 24, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion nni/tools/nnictl/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def __init__(self, home_dir=NNI_HOME_DIR):
self.experiments = self.read_file()

def add_experiment(self, expId, port, startTime, platform, experiment_name, endTime='N/A', status='INITIALIZED',
tag=[], pid=None, webuiUrl=[], logDir=''):
tag=[], pid=None, webuiUrl=[], logDir='', prefixUrl=None):
'''set {key:value} pairs to self.experiment'''
with self.lock:
self.experiments = self.read_file()
Expand All @@ -124,6 +124,7 @@ def add_experiment(self, expId, port, startTime, platform, experiment_name, endT
self.experiments[expId]['pid'] = pid
self.experiments[expId]['webuiUrl'] = webuiUrl
self.experiments[expId]['logDir'] = str(logDir)
self.experiments[expId]['prefixUrl'] = prefixUrl
self.write_file()

def update_experiment(self, expId, key, value):
Expand Down
90 changes: 48 additions & 42 deletions nni/tools/nnictl/launcher.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions nni/tools/nnictl/nnictl.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def parse_args():
parser_start.add_argument('--config', '-c', required=True, dest='config', help='the path of yaml config file')
parser_start.add_argument('--port', '-p', default=DEFAULT_REST_PORT, dest='port', type=int, help='the port of restful server')
parser_start.add_argument('--debug', '-d', action='store_true', help=' set debug mode')
parser_start.add_argument('--url_prefix', '-u', dest='url_prefix', help=' set prefix url')
parser_start.add_argument('--foreground', '-f', action='store_true', help=' set foreground mode, print log content to terminal')
parser_start.set_defaults(func=create_experiment)

Expand Down
49 changes: 28 additions & 21 deletions nni/tools/nnictl/nnictl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,17 @@
from .command_utils import check_output_command, kill_command
from .ssh_utils import create_ssh_sftp_client, remove_remote_directory

def get_experiment_time(port):
def get_experiment_time(port, prefixUrl):
'''get the startTime and endTime of an experiment'''
response = rest_get(experiment_url(port), REST_TIME_OUT)
response = rest_get(experiment_url(port, prefixUrl), REST_TIME_OUT)
if response and check_response(response):
content = json.loads(response.text)
return content.get('startTime'), content.get('endTime')
return None, None

def get_experiment_status(port):
def get_experiment_status(port, prefixUrl):
'''get the status of an experiment'''
result, response = check_rest_server_quick(port)
result, response = check_rest_server_quick(port, prefixUrl)
if result:
return json.loads(response.text).get('status')
return None
Expand Down Expand Up @@ -202,7 +202,8 @@ def check_rest(args):
experiments_config = Experiments()
experiments_dict = experiments_config.get_all_experiments()
rest_port = experiments_dict.get(get_config_filename(args)).get('port')
running, _ = check_rest_server_quick(rest_port)
prefix_url = experiments_dict.get(get_config_filename(args)).get('prefixUrl')
running, _ = check_rest_server_quick(rest_port, prefix_url)
if running:
print_normal('Restful server is running...')
else:
Expand Down Expand Up @@ -245,13 +246,14 @@ def final_metric_data_cmp(lhs, rhs):
experiments_dict = experiments_config.get_all_experiments()
experiment_id = get_config_filename(args)
rest_port = experiments_dict.get(experiment_id).get('port')
prefix_url = experiments_dict.get(experiment_id).get('prefixUrl')
rest_pid = experiments_dict.get(experiment_id).get('pid')
if not detect_process(rest_pid):
print_error('Experiment is not running...')
return
running, response = check_rest_server_quick(rest_port)
running, response = check_rest_server_quick(rest_port, prefix_url)
if running:
response = rest_get(trial_jobs_url(rest_port), REST_TIME_OUT)
response = rest_get(trial_jobs_url(rest_port, prefix_url), REST_TIME_OUT)
if response and check_response(response):
content = json.loads(response.text)
if args.head:
Expand All @@ -278,13 +280,14 @@ def trial_kill(args):
experiments_dict = experiments_config.get_all_experiments()
experiment_id = get_config_filename(args)
rest_port = experiments_dict.get(experiment_id).get('port')
prefix_url = experiments_dict.get(experiment_id).get('prefixUrl')
rest_pid = experiments_dict.get(experiment_id).get('pid')
if not detect_process(rest_pid):
print_error('Experiment is not running...')
return
running, _ = check_rest_server_quick(rest_port)
running, _ = check_rest_server_quick(rest_port, prefix_url)
if running:
response = rest_delete(trial_job_id_url(rest_port, args.trial_id), REST_TIME_OUT)
response = rest_delete(trial_job_id_url(rest_port, args.trial_id, prefix_url), REST_TIME_OUT)
if response and check_response(response):
print(response.text)
return True
Expand All @@ -311,13 +314,14 @@ def list_experiment(args):
experiments_dict = experiments_config.get_all_experiments()
experiment_id = get_config_filename(args)
rest_port = experiments_dict.get(experiment_id).get('port')
prefix_url = experiments_dict.get(experiment_id).get('prefixUrl')
rest_pid = experiments_dict.get(experiment_id).get('pid')
if not detect_process(rest_pid):
print_error('Experiment is not running...')
return
running, _ = check_rest_server_quick(rest_port)
running, _ = check_rest_server_quick(rest_port, prefix_url)
if running:
response = rest_get(experiment_url(rest_port), REST_TIME_OUT)
response = rest_get(experiment_url(rest_port, prefix_url), REST_TIME_OUT)
if response and check_response(response):
content = convert_time_stamp_to_date(json.loads(response.text))
print(json.dumps(content, indent=4, sort_keys=True, separators=(',', ':')))
Expand All @@ -333,7 +337,8 @@ def experiment_status(args):
experiments_config = Experiments()
experiments_dict = experiments_config.get_all_experiments()
rest_port = experiments_dict.get(get_config_filename(args)).get('port')
result, response = check_rest_server_quick(rest_port)
prefix_url = experiments_dict.get(get_config_filename(args)).get('prefixUrl')
result, response = check_rest_server_quick(rest_port, prefix_url)
if not result:
print_normal('Restful server is not running...')
else:
Expand Down Expand Up @@ -399,14 +404,15 @@ def log_trial(args):
experiments_dict = experiments_config.get_all_experiments()
experiment_id = get_config_filename(args)
rest_port = experiments_dict.get(experiment_id).get('port')
prefix_url = experiments_dict.get(experiment_id).get('prefixUrl')
rest_pid = experiments_dict.get(experiment_id).get('pid')
experiment_config = Config(experiment_id, experiments_dict.get(experiment_id).get('logDir')).get_config()
if not detect_process(rest_pid):
print_error('Experiment is not running...')
return
running, response = check_rest_server_quick(rest_port)
running, response = check_rest_server_quick(rest_port, prefix_url)
if running:
response = rest_get(trial_jobs_url(rest_port), REST_TIME_OUT)
response = rest_get(trial_jobs_url(rest_port, prefix_url), REST_TIME_OUT)
if response and check_response(response):
content = json.loads(response.text)
for trial in content:
Expand Down Expand Up @@ -661,9 +667,9 @@ def show_experiment_info():
experiments_dict[key].get('platform'), time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiments_dict[key]['startTime'] / 1000)) if isinstance(experiments_dict[key]['startTime'], int) else experiments_dict[key]['startTime'], \
get_time_interval(experiments_dict[key]['startTime'], experiments_dict[key]['endTime'])))
print(TRIAL_MONITOR_HEAD)
running, response = check_rest_server_quick(experiments_dict[key]['port'])
running, response = check_rest_server_quick(experiments_dict[key]['port'], experiments_dict[key]['prefixUrl'])
if running:
response = rest_get(trial_jobs_url(experiments_dict[key]['port']), REST_TIME_OUT)
response = rest_get(trial_jobs_url(experiments_dict[key]['port'], experiments_dict[key]['prefixUrl']), REST_TIME_OUT)
if response and check_response(response):
content = json.loads(response.text)
for index, value in enumerate(content):
Expand All @@ -672,7 +678,7 @@ def show_experiment_info():
content[index].get('endTime'), content[index].get('status')))
print(TRIAL_MONITOR_TAIL)

def set_monitor(auto_exit, time_interval, port=None, pid=None):
def set_monitor(auto_exit, time_interval, port=None, pid=None, prefixUrl=None):
'''set the experiment monitor engine'''
while True:
try:
Expand All @@ -683,7 +689,7 @@ def set_monitor(auto_exit, time_interval, port=None, pid=None):
update_experiment()
show_experiment_info()
if auto_exit:
status = get_experiment_status(port)
status = get_experiment_status(port, prefixUrl)
if status in ['DONE', 'ERROR', 'STOPPED']:
print_normal('Experiment status is {0}.'.format(status))
print_normal('Stopping experiment...')
Expand Down Expand Up @@ -724,20 +730,21 @@ def groupby_trial_id(intermediate_results):
experiments_dict = experiments_config.get_all_experiments()
experiment_id = get_config_filename(args)
rest_port = experiments_dict.get(experiment_id).get('port')
prefix_url = experiments_dict.get(experiment_id).get('prefixUrl')
rest_pid = experiments_dict.get(experiment_id).get('pid')

if not detect_process(rest_pid):
print_error('Experiment is not running...')
return
running, response = check_rest_server_quick(rest_port)
running, response = check_rest_server_quick(rest_port, prefix_url)
if not running:
print_error('Restful server is not running')
return
response = rest_get(export_data_url(rest_port), 20)
response = rest_get(export_data_url(rest_port, prefix_url), 20)
if response is not None and check_response(response):
content = json.loads(response.text)
if args.intermediate:
intermediate_results_response = rest_get(metric_data_url(rest_port), REST_TIME_OUT)
intermediate_results_response = rest_get(metric_data_url(rest_port, prefix_url), REST_TIME_OUT)
if not intermediate_results_response or not check_response(intermediate_results_response):
print_error('Error getting intermediate results.')
return
Expand Down
8 changes: 4 additions & 4 deletions nni/tools/nnictl/rest_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@ def rest_delete(url, timeout, show_error=False):
print_error(exception)
return None

def check_rest_server(rest_port):
def check_rest_server(rest_port, prefixUrl):
'''Check if restful server is ready'''
retry_count = 20
for _ in range(retry_count):
response = rest_get(check_status_url(rest_port), REST_TIME_OUT)
response = rest_get(check_status_url(rest_port, prefixUrl), REST_TIME_OUT)
if response:
if response.status_code == 200:
return True, response
Expand All @@ -63,9 +63,9 @@ def check_rest_server(rest_port):
time.sleep(1)
return False, response

def check_rest_server_quick(rest_port):
def check_rest_server_quick(rest_port, prefixUrl):
'''Check if restful server is ready, only check once'''
response = rest_get(check_status_url(rest_port), 5)
response = rest_get(check_status_url(rest_port, prefixUrl), 5)
if response and response.status_code == 200:
return True, response
return False, None
Expand Down
15 changes: 9 additions & 6 deletions nni/tools/nnictl/updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,14 @@ def update_experiment_profile(args, key, value):
experiments_config = Experiments()
experiments_dict = experiments_config.get_all_experiments()
rest_port = experiments_dict.get(get_config_filename(args)).get('port')
running, _ = check_rest_server_quick(rest_port)
prefix_url = experiments_dict.get(get_config_filename(args)).get('prefixUrl')
running, _ = check_rest_server_quick(rest_port, prefix_url)
if running:
response = rest_get(experiment_url(rest_port), REST_TIME_OUT)
response = rest_get(experiment_url(rest_port, prefix_url), REST_TIME_OUT)
if response and check_response(response):
experiment_profile = json.loads(response.text)
experiment_profile['params'][key] = value
response = rest_put(experiment_url(rest_port)+get_query_type(key), json.dumps(experiment_profile), REST_TIME_OUT)
response = rest_put(experiment_url(rest_port, prefix_url)+get_query_type(key), json.dumps(experiment_profile), REST_TIME_OUT)
if response and check_response(response):
return response
else:
Expand Down Expand Up @@ -121,11 +122,12 @@ def import_data(args):
experiments_dict = Experiments().get_all_experiments()
experiment_id = get_config_filename(args)
rest_port = experiments_dict.get(experiment_id).get('port')
prefix_url = experiments_dict.get(get_config_filename(args)).get('prefixUrl')
rest_pid = experiments_dict.get(experiment_id).get('pid')
if not detect_process(rest_pid):
print_error('Experiment is not running...')
return
running, _ = check_rest_server_quick(rest_port)
running, _ = check_rest_server_quick(rest_port, prefix_url)
if not running:
print_error('Restful server is not running')
return
Expand All @@ -141,9 +143,10 @@ def import_data_to_restful_server(args, content):
'''call restful server to import data to the experiment'''
experiments_dict = Experiments().get_all_experiments()
rest_port = experiments_dict.get(get_config_filename(args)).get('port')
running, _ = check_rest_server_quick(rest_port)
prefix_url = experiments_dict.get(get_config_filename(args)).get('prefixUrl')
running, _ = check_rest_server_quick(rest_port, prefix_url)
if running:
response = rest_post(import_data_url(rest_port), content, REST_TIME_OUT)
response = rest_post(import_data_url(rest_port, prefix_url), content, REST_TIME_OUT)
if response and check_response(response):
return response
else:
Expand Down
47 changes: 27 additions & 20 deletions nni/tools/nnictl/url_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import socket
import psutil
import re

BASE_URL = 'http://localhost'

Expand All @@ -24,55 +25,61 @@

METRIC_DATA_API = '/metric-data'

def metric_data_url(port):
def path_validation(path):
assert re.match("^[A-Za-z0-9_-]*$", path), "prefix url is invalid."

def formatURLPath(path):
return '' if path is None else '/{0}'.format(path)

def metric_data_url(port,prefix):
'''get metric_data url'''
return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, METRIC_DATA_API)
return '{0}:{1}{2}{3}{4}'.format(BASE_URL, port, API_ROOT_URL, formatURLPath(prefix), METRIC_DATA_API)

def check_status_url(port):
def check_status_url(port,prefix):
'''get check_status url'''
return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, CHECK_STATUS_API)
return '{0}:{1}{2}{3}{4}'.format(BASE_URL, port, API_ROOT_URL, formatURLPath(prefix), CHECK_STATUS_API)


def cluster_metadata_url(port):
def cluster_metadata_url(port,prefix):
'''get cluster_metadata_url'''
return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, CLUSTER_METADATA_API)
return '{0}:{1}{2}{3}{4}'.format(BASE_URL, port, API_ROOT_URL, formatURLPath(prefix), CLUSTER_METADATA_API)


def import_data_url(port):
def import_data_url(port,prefix):
'''get import_data_url'''
return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, IMPORT_DATA_API)
return '{0}:{1}{2}{3}{4}'.format(BASE_URL, port, API_ROOT_URL, formatURLPath(prefix), IMPORT_DATA_API)


def experiment_url(port):
def experiment_url(port,prefix):
'''get experiment_url'''
return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, EXPERIMENT_API)
return '{0}:{1}{2}{3}{4}'.format(BASE_URL, port, API_ROOT_URL, formatURLPath(prefix), EXPERIMENT_API)


def trial_jobs_url(port):
def trial_jobs_url(port,prefix):
'''get trial_jobs url'''
return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, TRIAL_JOBS_API)
return '{0}:{1}{2}{3}{4}'.format(BASE_URL, port, API_ROOT_URL, formatURLPath(prefix), TRIAL_JOBS_API)


def trial_job_id_url(port, job_id):
def trial_job_id_url(port, job_id,prefix):
'''get trial_jobs with id url'''
return '{0}:{1}{2}{3}/{4}'.format(BASE_URL, port, API_ROOT_URL, TRIAL_JOBS_API, job_id)
return '{0}:{1}{2}{3}{4}/{5}'.format(BASE_URL, port, API_ROOT_URL, formatURLPath(prefix), TRIAL_JOBS_API, job_id)


def export_data_url(port):
def export_data_url(port,prefix):
'''get export_data url'''
return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, EXPORT_DATA_API)
return '{0}:{1}{2}{3}{4}'.format(BASE_URL, port, API_ROOT_URL, formatURLPath(prefix), EXPORT_DATA_API)


def tensorboard_url(port):
def tensorboard_url(port,prefix):
'''get tensorboard url'''
return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, TENSORBOARD_API)
return '{0}:{1}{2}{3}{4}'.format(BASE_URL, port, API_ROOT_URL, formatURLPath(prefix), TENSORBOARD_API)


def get_local_urls(port):
def get_local_urls(port,prefix):
'''get urls of local machine'''
url_list = []
for _, info in psutil.net_if_addrs().items():
for addr in info:
if socket.AddressFamily.AF_INET == addr.family:
url_list.append('http://{}:{}'.format(addr.address, port))
url_list.append('http://{0}:{1}{2}'.format(addr.address, port, formatURLPath(prefix)))
return url_list
Loading