Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Fix bugs and lints in nnictl #3712

Merged
merged 2 commits into from
Jun 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion nni/tools/nnictl/algo_management.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import importlib
import json
from nni.tools.package_utils import read_registerd_algo_meta, get_registered_algo_meta, \
Expand Down
6 changes: 3 additions & 3 deletions nni/tools/nnictl/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
import socket
import string
import random
import yaml
import psutil
import filelock
import glob
from colorama import Fore
import filelock
import psutil
import yaml

from .constants import ERROR_INFO, NORMAL_INFO, WARNING_INFO

Expand Down
4 changes: 2 additions & 2 deletions nni/tools/nnictl/config_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
import os

import netifaces
from schema import And, Optional, Or, Regex, Schema, SchemaError
from nni.tools.package_utils import (
create_validator_instance,
get_all_builtin_names,
get_registered_algo_meta,
)
from schema import And, Optional, Or, Regex, Schema, SchemaError

from .common_utils import get_yml_content, print_warning
from .constants import SCHEMA_PATH_ERROR, SCHEMA_RANGE_ERROR, SCHEMA_TYPE_ERROR
Expand Down Expand Up @@ -625,7 +625,7 @@ def validate_frameworkcontroller_trial_config(self, experiment_config):
raise SchemaError("""If no taskRoles are specified a valid custom frameworkcontroller config should
be set using the configPath attribute in frameworkcontrollerConfig!""")
config_content = get_yml_content(experiment_config.get('frameworkcontrollerConfig').get('configPath'))
if not config_content.get('spec').get('taskRoles') or not len(config_content.get('spec').get('taskRoles')):
if not config_content.get('spec').get('taskRoles') or not config_content.get('spec').get('taskRoles'):
raise SchemaError('Invalid frameworkcontroller config! No taskRoles were specified!')
if not config_content.get('spec').get('taskRoles')[0].get('task'):
raise SchemaError('Invalid frameworkcontroller config! No task was specified for taskRole!')
Expand Down
5 changes: 1 addition & 4 deletions nni/tools/nnictl/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,9 @@
# Licensed under the MIT license.

import os
import json_tricks
import shutil
import sqlite3
import time
import json_tricks
from .constants import NNI_HOME_DIR
from .command_utils import print_error
from .common_utils import get_file_lock

def config_v0_to_v1(config: dict) -> dict:
Expand Down
19 changes: 11 additions & 8 deletions nni/tools/nnictl/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@
import nni_node # pylint: disable=import-error
from .launcher_utils import validate_all_content
from .rest_utils import rest_put, rest_post, check_rest_server, check_response
from .url_utils import cluster_metadata_url, experiment_url, get_local_urls, setPrefixUrl, formatURLPath
from .url_utils import cluster_metadata_url, experiment_url, get_local_urls, set_prefix_url
from .config_utils import Config, Experiments
from .common_utils import get_yml_content, get_json_content, print_error, print_normal, print_warning, \
detect_port, get_user
from .common_utils import get_yml_content, get_json_content, print_error, print_normal, detect_port, get_user

from .constants import NNI_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SUCCESS_INFO, LOG_HEADER
from .command_utils import check_output_command, kill_command
Expand Down Expand Up @@ -84,7 +83,7 @@ def start_rest_server(port, platform, mode, experiment_id, foreground=False, log
cmds += ['--foreground', 'true']
if url_prefix:
_validate_prefix_path(url_prefix)
setPrefixUrl(url_prefix)
set_prefix_url(url_prefix)
cmds += ['--url_prefix', url_prefix]

stdout_full_path, stderr_full_path = get_log_path(experiment_id)
Expand Down Expand Up @@ -167,7 +166,8 @@ def set_V1_common_config(experiment_config, port, config_file_name):
response = rest_put(cluster_metadata_url(port), json.dumps({'version_check': version_check}), REST_TIME_OUT)
validate_response(response, config_file_name)
if experiment_config.get('logCollection'):
response = rest_put(cluster_metadata_url(port), json.dumps({'log_collection': experiment_config.get('logCollection')}), REST_TIME_OUT)
data = json.dumps({'log_collection': experiment_config.get('logCollection')})
response = rest_put(cluster_metadata_url(port), data, REST_TIME_OUT)
validate_response(response, config_file_name)

def setNNIManagerIp(experiment_config, port, config_file_name):
Expand Down Expand Up @@ -229,7 +229,8 @@ def set_frameworkcontroller_config(experiment_config, port, config_file_name):

def set_shared_storage(experiment_config, port, config_file_name):
if 'sharedStorage' in experiment_config:
response = rest_put(cluster_metadata_url(port), json.dumps({'shared_storage_config': experiment_config['sharedStorage']}), REST_TIME_OUT)
data = json.dumps({'shared_storage_config': experiment_config['sharedStorage']})
response = rest_put(cluster_metadata_url(port), data, REST_TIME_OUT)
err_message = None
if not response or not response.status_code == 200:
if response is not None:
Expand Down Expand Up @@ -485,7 +486,10 @@ def _validate_v2(config, path):
print_error(f'Config V2 validation failed: {repr(e)}')

def _validate_prefix_path(path):
assert re.match("^[A-Za-z0-9_-]*$", path), "prefix url is invalid."
assert not path.startswith('/'), 'URL prefix should not start with "/".'
parts = path.split('/')
valid = all(re.match('^[A-Za-z0-9_-]*$', part) for part in parts)
assert valid, 'URL prefix should only contain letter, number, underscore, and hyphen.'

def create_experiment(args):
'''start a new experiment'''
Expand All @@ -504,7 +508,6 @@ def create_experiment(args):
config_v1 = config_yml
else:
schema = 2
from nni.experiment.config import convert
config_v2 = convert.to_v2(config_yml).json()
else:
config_v2 = _validate_v2(config_yml, config_path)
Expand Down
3 changes: 1 addition & 2 deletions nni/tools/nnictl/nnictl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import sys
import json
import time
import re
import shutil
import subprocess
from functools import cmp_to_key
Expand Down Expand Up @@ -528,7 +527,7 @@ def experiment_clean(args):
print_warning('platform {0} clean up not supported yet.'.format(platform))
exit(0)
# clean local data
local_base_dir = experiments_config[experiment_id]['logDir']
local_base_dir = experiments_config.experiments[experiment_id]['logDir']
if not local_base_dir:
local_base_dir = NNI_HOME_DIR
local_experiment_dir = os.path.join(local_base_dir, experiment_id)
Expand Down
8 changes: 4 additions & 4 deletions nni/tools/nnictl/url_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@

METRIC_DATA_API = '/metric-data'

def formatURLPath(path):
return API_ROOT_URL if path is None else '/{0}{1}'.format(path, API_ROOT_URL)
def format_url_path(path):
return API_ROOT_URL if path is None else f'/{path}{API_ROOT_URL}'

def setPrefixUrl(prefix_path):
def set_prefix_url(prefix_path):
global API_ROOT_URL
API_ROOT_URL = formatURLPath(prefix_path)
API_ROOT_URL = format_url_path(prefix_path)

def metric_data_url(port):
'''get metric_data url'''
Expand Down