From a3a32d5bf718c189cc17b353a24e8af58a6f67d2 Mon Sep 17 00:00:00 2001 From: Shinai Yang Date: Wed, 18 Nov 2020 07:21:40 +0800 Subject: [PATCH 01/24] support heterogeneeous --- nni/tools/nnictl/config_schema.py | 9 +- nni/tools/nnictl/launcher.py | 41 +++++++ ts/nni_manager/main.ts | 9 +- .../rest_server/restValidationSchemas.ts | 3 + .../common/trialConfigMetadataKey.ts | 1 + .../reusable/channels/amlCommandChannel.ts | 2 +- .../channels/heterogenousCommandChannel.ts | 64 +++++++++++ .../reusable/channels/webCommandChannel.ts | 2 +- .../training_service/reusable/environment.ts | 5 +- .../environments/amlEnvironmentService.ts | 60 +++++----- .../heterogenousEnvironmentService.ts | 104 ++++++++++++++++++ .../environments/remoteEnvironmentService.ts | 2 +- .../reusable/routerTrainingService.ts | 9 +- 13 files changed, 274 insertions(+), 37 deletions(-) create mode 100644 ts/nni_manager/training_service/reusable/channels/heterogenousCommandChannel.ts create mode 100644 ts/nni_manager/training_service/reusable/environments/heterogenousEnvironmentService.ts diff --git a/nni/tools/nnictl/config_schema.py b/nni/tools/nnictl/config_schema.py index d320163595..9b0aeb6ad0 100644 --- a/nni/tools/nnictl/config_schema.py +++ b/nni/tools/nnictl/config_schema.py @@ -124,7 +124,7 @@ def validate(self, data): Optional('maxExecDuration'): And(Regex(r'^[1-9][0-9]*[s|m|h|d]$', error='ERROR: maxExecDuration format is [digit]{s,m,h,d}')), Optional('maxTrialNum'): setNumberRange('maxTrialNum', int, 1, 99999), 'trainingServicePlatform': setChoice( - 'trainingServicePlatform', 'remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml'), + 'trainingServicePlatform', 'remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml', 'heterogeneous'), Optional('searchSpacePath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'searchSpacePath'), Optional('multiPhase'): setType('multiPhase', bool), Optional('multiThread'): setType('multiThread', bool), @@ -262,6 +262,12 @@ def validate(self, data): } } +heterogeneous_config_schema = { + 'heterogeneousConfig': { + 'trainingServicePlatforms': setType('trainingServicePlatforms', str), + } +} + kubeflow_trial_schema = { 'trial': { 'codeDir': setPathCheck('codeDir'), @@ -412,6 +418,7 @@ def validate(self, data): 'frameworkcontroller': Schema({**common_schema, **frameworkcontroller_trial_schema, **frameworkcontroller_config_schema}), 'aml': Schema({**common_schema, **aml_trial_schema, **aml_config_schema}), 'dlts': Schema({**common_schema, **dlts_trial_schema, **dlts_config_schema}), + 'heterogeneous': Schema({**common_schema, **common_trial_schema, **heterogeneous_config_schema}), } diff --git a/nni/tools/nnictl/launcher.py b/nni/tools/nnictl/launcher.py index 576954e335..7ec0c77b15 100644 --- a/nni/tools/nnictl/launcher.py +++ b/nni/tools/nnictl/launcher.py @@ -297,6 +297,31 @@ def set_aml_config(experiment_config, port, config_file_name): #set trial_config return set_trial_config(experiment_config, port, config_file_name), err_message +def set_heterogeneous_config(experiment_config, port, config_file_name): + '''set heterogeneous configuration''' + heterogeneous_config_data = dict() + heterogeneous_config_data['heterogeneous_config'] = experiment_config['heterogeneousConfig'] + platform_list = experiment_config['heterogeneousConfig']['trainingServicePlatforms'].split(',') + for platform in platform_list: + if platform === 'aml': + heterogeneous_config_data['aml_config'] = experiment_config['amlConfig'] + elif platform === 'remote': + heterogeneous_config_data['remote_config'] = experiment_config['remoteConfig'] + response = rest_put(cluster_metadata_url(port), json.dumps(aml_config_data), REST_TIME_OUT) + err_message = None + if not response or not response.status_code == 200: + if response is not None: + err_message = response.text + _, stderr_full_path = get_log_path(config_file_name) + with open(stderr_full_path, 'a+') as fout: + fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':'))) + return False, err_message + result, message = setNNIManagerIp(experiment_config, port, config_file_name) + if not result: + return result, message + #set trial_config + return set_trial_config(experiment_config, port, config_file_name), err_message + def set_experiment(experiment_config, mode, port, config_file_name): '''Call startExperiment (rest POST /experiment) with yaml file content''' request_data = dict() @@ -378,6 +403,20 @@ def set_experiment(experiment_config, mode, port, config_file_name): {'key': 'aml_config', 'value': experiment_config['amlConfig']}) request_data['clusterMetaData'].append( {'key': 'trial_config', 'value': experiment_config['trial']}) + elif experiment_config['trainingServicePlatform'] == 'heterogeneous': + request_data['clusterMetaData'].append( + {'key': 'heterogeneous_config', 'value': experiment_config['heterogeneousConfig']}) + platform_list = experiment_config['heterogeneousConfig']['trainingServicePlatforms'].split(',') + for platform in platform_list: + if platform === 'aml': + request_data['clusterMetaData'].append( + {'key': 'aml_config', 'value': experiment_config['amlConfig']}) + elif platform === 'remote': + request_data['clusterMetaData'].append( + {'key': 'remote_config', 'value': experiment_config['remoteConfig']}) + response = rest_put(cluster_metadata_url(port), json.dumps(aml_config_data), REST_TIME_OUT) + request_data['clusterMetaData'].append( + {'key': 'trial_config', 'value': experiment_config['trial']}) response = rest_post(experiment_url(port), json.dumps(request_data), REST_TIME_OUT, show_error=True) if check_response(response): return response @@ -409,6 +448,8 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res config_result, err_msg = set_dlts_config(experiment_config, port, config_file_name) elif platform == 'aml': config_result, err_msg = set_aml_config(experiment_config, port, config_file_name) + elif platform == 'heterogeneous': + config_result, err_msg = set_heterogeneous_config(experiment_config, port, config_file_name) else: raise Exception(ERROR_INFO % 'Unsupported platform!') exit(1) diff --git a/ts/nni_manager/main.ts b/ts/nni_manager/main.ts index 86a7a2583a..ce7a5d90c7 100644 --- a/ts/nni_manager/main.ts +++ b/ts/nni_manager/main.ts @@ -25,6 +25,7 @@ import { RouterTrainingService } from './training_service/reusable/routerTrainin import { PAIYarnTrainingService } from './training_service/pai/paiYarn/paiYarnTrainingService'; import { DLTSTrainingService } from './training_service/dlts/dltsTrainingService'; + function initStartupInfo( startExpMode: string, resumeExperimentId: string, basePort: number, platform: string, logDirectory: string, experimentLogLevel: string, readonly: boolean): void { @@ -66,6 +67,10 @@ async function initContainer(foreground: boolean, platformMode: string, logFileN Container.bind(TrainingService) .to(RouterTrainingService) .scope(Scope.Singleton); + } else if (platformMode === 'heterogeneous') { + Container.bind(TrainingService) + .to(RouterTrainingService) + .scope(Scope.Singleton); } else { throw new Error(`Error: unsupported mode: ${platformMode}`); } @@ -94,7 +99,7 @@ async function initContainer(foreground: boolean, platformMode: string, logFileN function usage(): void { console.info('usage: node main.js --port --mode \ - --start_mode --experiment_id --foreground '); + --start_mode --experiment_id --foreground '); } const strPort: string = parseArg(['--port', '-p']); @@ -114,7 +119,7 @@ const foreground: boolean = foregroundArg.toLowerCase() === 'true' ? true : fals const port: number = parseInt(strPort, 10); const mode: string = parseArg(['--mode', '-m']); -if (!['local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml'].includes(mode)) { +if (!['local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml', 'heterogeneous'].includes(mode)) { console.log(`FATAL: unknown mode: ${mode}`); usage(); process.exit(1); diff --git a/ts/nni_manager/rest_server/restValidationSchemas.ts b/ts/nni_manager/rest_server/restValidationSchemas.ts index 1b33925a78..cac1bc44d2 100644 --- a/ts/nni_manager/rest_server/restValidationSchemas.ts +++ b/ts/nni_manager/rest_server/restValidationSchemas.ts @@ -164,6 +164,9 @@ export namespace ValidationSchemas { maxTrialNumPerGpu: joi.number(), useActiveGpu: joi.boolean() }), + heterogeneous_config: joi.object({ // eslint-disable-line @typescript-eslint/camelcase + trainingServicePlatforms: joi.string().min(1), + }), nni_manager_ip: joi.object({ // eslint-disable-line @typescript-eslint/camelcase nniManagerIp: joi.string().min(1) }), diff --git a/ts/nni_manager/training_service/common/trialConfigMetadataKey.ts b/ts/nni_manager/training_service/common/trialConfigMetadataKey.ts index 395acbeca4..ebf8ec366d 100644 --- a/ts/nni_manager/training_service/common/trialConfigMetadataKey.ts +++ b/ts/nni_manager/training_service/common/trialConfigMetadataKey.ts @@ -17,6 +17,7 @@ export enum TrialConfigMetadataKey { PAI_YARN_CLUSTER_CONFIG = 'pai_yarn_config', PAI_CLUSTER_CONFIG = 'pai_config', KUBEFLOW_CLUSTER_CONFIG = 'kubeflow_config', + HETEROGENOUS_CLUSTER_CONFIG = 'heterogenous_config', NNI_MANAGER_IP = 'nni_manager_ip', FRAMEWORKCONTROLLER_CLUSTER_CONFIG = 'frameworkcontroller_config', DLTS_CLUSTER_CONFIG = 'dlts_config', diff --git a/ts/nni_manager/training_service/reusable/channels/amlCommandChannel.ts b/ts/nni_manager/training_service/reusable/channels/amlCommandChannel.ts index 5816a9c780..531cc40417 100644 --- a/ts/nni_manager/training_service/reusable/channels/amlCommandChannel.ts +++ b/ts/nni_manager/training_service/reusable/channels/amlCommandChannel.ts @@ -39,7 +39,7 @@ export class AMLCommandChannel extends CommandChannel { ]); } - protected async sendCommandInternal(environment: EnvironmentInformation, message: string): Promise { + public async sendCommandInternal(environment: EnvironmentInformation, message: string): Promise { this.sendQueues.push([environment, message]); } diff --git a/ts/nni_manager/training_service/reusable/channels/heterogenousCommandChannel.ts b/ts/nni_manager/training_service/reusable/channels/heterogenousCommandChannel.ts new file mode 100644 index 0000000000..adacec1cc3 --- /dev/null +++ b/ts/nni_manager/training_service/reusable/channels/heterogenousCommandChannel.ts @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +'use strict'; + +import { EventEmitter } from "events"; +import { delay } from "../../../common/utils"; +import { AMLEnvironmentInformation } from '../aml/amlConfig'; +import { CommandChannel, RunnerConnection } from "../commandChannel"; +import { Channel, EnvironmentInformation } from "../environment"; +import { AMLCommandChannel } from "./amlCommandChannel"; +import { WebCommandChannel } from "./webCommandChannel"; + +class HeterogenousRunnerConnection extends RunnerConnection { +} + +export class HeterogenousCommandChannel extends CommandChannel{ + private stopping: boolean = false; + private amlCommandChannel: AMLCommandChannel | undefined; + private webCommandChannel: WebCommandChannel | undefined; + + public get channelName(): Channel { + return "heterogenous"; + } + + public constructor(commandEmitter: EventEmitter) { + super(commandEmitter); + } + + public async config(_key: string, _value: any): Promise { + // do nothing + } + + public async start(): Promise { + this.amlCommandChannel?.start(); + this.webCommandChannel?.start(); + } + + public async stop(): Promise { + this.stopping = true; + } + + public async run(): Promise { + this.amlCommandChannel?.start(); + this.webCommandChannel?.run(); + } + + protected async sendCommandInternal(environment: EnvironmentInformation, message: string): Promise { + switch (environment.platform) { + case 'aml': + this.amlCommandChannel?.sendCommandInternal(environment, message); + break; + case 'remote': + this.webCommandChannel?.sendCommandInternal(environment, message); + break; + default: + throw new Error(`Heterogenous not support platform: '${environment.platform}'`); + } + } + + protected createRunnerConnection(environment: EnvironmentInformation): RunnerConnection { + return new HeterogenousRunnerConnection(environment); + } +} diff --git a/ts/nni_manager/training_service/reusable/channels/webCommandChannel.ts b/ts/nni_manager/training_service/reusable/channels/webCommandChannel.ts index 3bd9c504aa..3fab37f491 100644 --- a/ts/nni_manager/training_service/reusable/channels/webCommandChannel.ts +++ b/ts/nni_manager/training_service/reusable/channels/webCommandChannel.ts @@ -70,7 +70,7 @@ export class WebCommandChannel extends CommandChannel { // do nothing } - protected async sendCommandInternal(environment: EnvironmentInformation, message: string): Promise { + public async sendCommandInternal(environment: EnvironmentInformation, message: string): Promise { if (this.webSocketServer === undefined) { throw new Error(`WebCommandChannel: uninitialized!`) } diff --git a/ts/nni_manager/training_service/reusable/environment.ts b/ts/nni_manager/training_service/reusable/environment.ts index b7ea4c33af..fdbb9fdb56 100644 --- a/ts/nni_manager/training_service/reusable/environment.ts +++ b/ts/nni_manager/training_service/reusable/environment.ts @@ -12,7 +12,7 @@ import { CommandChannel } from "./commandChannel"; export type EnvironmentStatus = 'UNKNOWN' | 'WAITING' | 'RUNNING' | 'SUCCEEDED' | 'FAILED' | 'USER_CANCELED'; -export type Channel = "web" | "file" | "aml" | "ut"; +export type Channel = "web" | "file" | "aml" | "ut" | "heterogenous"; export class TrialGpuSummary { @@ -74,6 +74,9 @@ export class EnvironmentInformation { // user can specify how to use GPU resource for an environment, like local and remote. public maxTrialNumberPerGpu?: number; public useActiveGpu?: boolean; + + // the running mode for trial jobs, including local, remote, aml, pai etc. + public platform: string = ""; constructor(id: string, name: string, envId?: string) { this.log = getLogger(); diff --git a/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts b/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts index 6f058f31af..05e4636552 100644 --- a/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts +++ b/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts @@ -69,37 +69,43 @@ export class AMLEnvironmentService extends EnvironmentService { this.log.debug(`AML not proccessed metadata key: '${key}', value: '${value}'`); } } + + public async refreshEnvironment(environment: EnvironmentInformation): Promise { + const amlClient = (environment as AMLEnvironmentInformation).amlClient; + if (!amlClient) { + return Promise.reject('AML client not initialized!'); + } + const newStatus = await amlClient.updateStatus(environment.status); + switch (newStatus.toUpperCase()) { + case 'WAITING': + case 'QUEUED': + environment.setStatus('WAITING'); + break; + case 'RUNNING': + environment.setStatus('RUNNING'); + break; + case 'COMPLETED': + case 'SUCCEEDED': + environment.setStatus('SUCCEEDED'); + break; + case 'FAILED': + environment.setStatus('FAILED'); + return Promise.reject(`AML: job ${environment.envId} is failed!`); + case 'STOPPED': + case 'STOPPING': + environment.setStatus('USER_CANCELED'); + break; + default: + environment.setStatus('UNKNOWN'); + } + } public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise { + const tasks: Promise[] = []; environments.forEach(async (environment) => { - const amlClient = (environment as AMLEnvironmentInformation).amlClient; - if (!amlClient) { - return Promise.reject('AML client not initialized!'); - } - const newStatus = await amlClient.updateStatus(environment.status); - switch (newStatus.toUpperCase()) { - case 'WAITING': - case 'QUEUED': - environment.setStatus('WAITING'); - break; - case 'RUNNING': - environment.setStatus('RUNNING'); - break; - case 'COMPLETED': - case 'SUCCEEDED': - environment.setStatus('SUCCEEDED'); - break; - case 'FAILED': - environment.setStatus('FAILED'); - return Promise.reject(`AML: job ${environment.envId} is failed!`); - case 'STOPPED': - case 'STOPPING': - environment.setStatus('USER_CANCELED'); - break; - default: - environment.setStatus('UNKNOWN'); - } + tasks.push(this.refreshEnvironment(environment)); }); + await Promise.all(tasks); } public async startEnvironment(environment: EnvironmentInformation): Promise { diff --git a/ts/nni_manager/training_service/reusable/environments/heterogenousEnvironmentService.ts b/ts/nni_manager/training_service/reusable/environments/heterogenousEnvironmentService.ts new file mode 100644 index 0000000000..e27c74d479 --- /dev/null +++ b/ts/nni_manager/training_service/reusable/environments/heterogenousEnvironmentService.ts @@ -0,0 +1,104 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +'use strict'; + +import { EventEmitter } from "events"; +import * as fs from 'fs'; +import * as path from 'path'; +import * as component from '../../../common/component'; +import { getLogger, Logger } from '../../../common/log'; +import { TrialConfigMetadataKey } from '../../common/trialConfigMetadataKey'; +import { HeterogenousCommandChannel } from '../channels/heterogenousCommandChannel'; +import { CommandChannel } from "../commandChannel"; +import { EnvironmentInformation, EnvironmentService } from '../environment'; +import { AMLEnvironmentService } from './amlEnvironmentService'; +import { RemoteEnvironmentService } from './remoteEnvironmentService'; +import { randomSelect } from '../../../common/utils'; + + +/** + * Collector PAI jobs info from PAI cluster, and update pai job status locally + */ +@component.Singleton +export class HeteroGenousEnvironmentService extends EnvironmentService { + + private amlEnvironmentService: AMLEnvironmentService; + private remoteEnvironmentService: RemoteEnvironmentService; + + private readonly log: Logger = getLogger(); + + constructor() { + super(); + this.amlEnvironmentService = new AMLEnvironmentService(); + this.remoteEnvironmentService = new RemoteEnvironmentService(); + } + + public get hasStorageService(): boolean { + return false; + } + + public createCommandChannel(commandEmitter: EventEmitter): CommandChannel { + return new HeterogenousCommandChannel(commandEmitter); + } + + public async config(key: string, value: string): Promise { + switch (key) { + case TrialConfigMetadataKey.AML_CLUSTER_CONFIG: + this.amlEnvironmentService.config(key, value); + break; + case TrialConfigMetadataKey.MACHINE_LIST: + this.remoteEnvironmentService.config(key, value); + break; + case TrialConfigMetadataKey.TRIAL_CONFIG: + this.amlEnvironmentService.config(key, value); + this.remoteEnvironmentService.config(key, value); + default: + this.log.debug(`Heterogenous not support metadata key: '${key}', value: '${value}'`); + } + } + + public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise { + const tasks: Promise[] = []; + environments.forEach(async (environment) => { + switch (environment.platform) { + case 'aml': + tasks.push(this.amlEnvironmentService.refreshEnvironment(environment)); + break; + case 'remote': + tasks.push(this.remoteEnvironmentService.refreshEnvironment(environment)); + break; + default: + throw new Error(`Heterogenous not support platform: '${environment.platform}'`); + } + }); + await Promise.all(tasks); + } + + public async startEnvironment(environment: EnvironmentInformation): Promise { + const number = randomSelect([0, 1]); + switch (number) { + case 0: + environment.platform = 'aml'; + this.amlEnvironmentService.startEnvironment(environment); + break; + case 1: + environment.platform = 'remote'; + this.remoteEnvironmentService.startEnvironment(environment); + break; + } + } + + public async stopEnvironment(environment: EnvironmentInformation): Promise { + switch (environment.platform) { + case 'aml': + this.amlEnvironmentService.stopEnvironment(environment); + break; + case 'remote': + this.remoteEnvironmentService.stopEnvironment(environment); + break; + default: + throw new Error(`Heterogenous not support platform '${environment.platform}'`); + } + } +} diff --git a/ts/nni_manager/training_service/reusable/environments/remoteEnvironmentService.ts b/ts/nni_manager/training_service/reusable/environments/remoteEnvironmentService.ts index fb5b3c789f..6526a7e7ab 100644 --- a/ts/nni_manager/training_service/reusable/environments/remoteEnvironmentService.ts +++ b/ts/nni_manager/training_service/reusable/environments/remoteEnvironmentService.ts @@ -135,7 +135,7 @@ export class RemoteEnvironmentService extends EnvironmentService { await executor.allowPermission(true, nniRootDir); } - private async refreshEnvironment(environment: EnvironmentInformation): Promise { + public async refreshEnvironment(environment: EnvironmentInformation): Promise { const executor = await this.getExecutor(environment.id); const jobpidPath: string = `${environment.runnerWorkingFolder}/pid`; const runnerReturnCodeFilePath: string = `${environment.runnerWorkingFolder}/code`; diff --git a/ts/nni_manager/training_service/reusable/routerTrainingService.ts b/ts/nni_manager/training_service/reusable/routerTrainingService.ts index 9a079f6765..3bf47dd2ce 100644 --- a/ts/nni_manager/training_service/reusable/routerTrainingService.ts +++ b/ts/nni_manager/training_service/reusable/routerTrainingService.ts @@ -18,6 +18,7 @@ import { OpenPaiEnvironmentService } from './environments/openPaiEnvironmentServ import { AMLEnvironmentService } from './environments/amlEnvironmentService'; import { RemoteEnvironmentService } from './environments/remoteEnvironmentService'; import { MountedStorageService } from './storages/mountedStorageService'; +import { HeteroGenousEnvironmentService } from './environments/heterogenousEnvironmentService'; import { StorageService } from './storageService'; import { TrialDispatcher } from './trialDispatcher'; import { RemoteConfig } from './remote/remoteConfig'; @@ -161,9 +162,11 @@ class RouterTrainingService implements TrainingService { this.log.debug(`caching metadata key:{} value:{}, as training service is not determined.`); this.internalTrainingService = component.get(RemoteMachineTrainingService); } - } else { - this.log.debug(`caching metadata key:{} value:{}, as training service is not determined.`); - this.metaDataCache.set(key, value); + } else if (key === TrialConfigMetadataKey.HETEROGENOUS_CLUSTER_CONFIG){ + this.internalTrainingService = component.get(TrialDispatcher); + Container.bind(EnvironmentService) + .to(HeteroGenousEnvironmentService) + .scope(Scope.Singleton); } } else { await this.internalTrainingService.setClusterMetadata(key, value); From 2fc266e0b95300fcc498edbee2e91e508cbc5efb Mon Sep 17 00:00:00 2001 From: Shinai Yang Date: Wed, 25 Nov 2020 00:44:21 +0800 Subject: [PATCH 02/24] add local and pai --- nni/runtime/platform/local.py | 5 +- nni/tools/nnictl/config_schema.py | 3 +- nni/tools/nnictl/launcher.py | 16 +- ts/nni_manager/main.ts | 6 +- .../rest_server/restValidationSchemas.ts | 3 +- .../training_service/common/util.ts | 10 ++ .../local/localTrainingService.ts | 6 +- .../channels/heterogenousCommandChannel.ts | 26 +++- .../environments/amlEnvironmentService.ts | 2 +- .../heterogenousEnvironmentService.ts | 37 ++++- .../environments/localEnvironmentService.ts | 138 ++++++++++++++++++ .../reusable/routerTrainingService.ts | 21 +++ 12 files changed, 244 insertions(+), 29 deletions(-) create mode 100644 ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts diff --git a/nni/runtime/platform/local.py b/nni/runtime/platform/local.py index 5d8124d3ff..b8e6ffad62 100644 --- a/nni/runtime/platform/local.py +++ b/nni/runtime/platform/local.py @@ -21,7 +21,8 @@ os.makedirs(_outputdir) _nni_platform = trial_env_vars.NNI_PLATFORM -if _nni_platform == 'local': +_nni_trial_job_id = trial_env_vars.NNI_TRIAL_JOB_ID +if _nni_platform == 'local' and _nni_trial_job_id != 'runner': _log_file_path = os.path.join(_outputdir, 'trial.log') init_logger(_log_file_path) @@ -62,7 +63,7 @@ def get_next_parameter(): return params def send_metric(string): - if _nni_platform != 'local': + if _nni_platform != 'local' or _nni_trial_job_id == 'runner': assert len(string) < 1000000, 'Metric too long' print("NNISDK_MEb'%s'" % (string), flush=True) else: diff --git a/nni/tools/nnictl/config_schema.py b/nni/tools/nnictl/config_schema.py index 9b0aeb6ad0..7d5ed2beb6 100644 --- a/nni/tools/nnictl/config_schema.py +++ b/nni/tools/nnictl/config_schema.py @@ -141,7 +141,8 @@ def validate(self, data): Optional('localConfig'): { Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'), Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int), - Optional('useActiveGpu'): setType('useActiveGpu', bool) + Optional('useActiveGpu'): setType('useActiveGpu', bool), + Optional('reuse'): setType('reuse', bool) } } diff --git a/nni/tools/nnictl/launcher.py b/nni/tools/nnictl/launcher.py index 7ec0c77b15..b66d6e24a0 100644 --- a/nni/tools/nnictl/launcher.py +++ b/nni/tools/nnictl/launcher.py @@ -115,15 +115,9 @@ def set_trial_config(experiment_config, port, config_file_name): def set_local_config(experiment_config, port, config_file_name): '''set local configuration''' request_data = dict() + request_data['local_config'] = {'reuse': False} if experiment_config.get('localConfig'): request_data['local_config'] = experiment_config['localConfig'] - if request_data['local_config']: - if request_data['local_config'].get('gpuIndices') and isinstance(request_data['local_config'].get('gpuIndices'), int): - request_data['local_config']['gpuIndices'] = str(request_data['local_config'].get('gpuIndices')) - if request_data['local_config'].get('maxTrialNumOnEachGpu'): - request_data['local_config']['maxTrialNumOnEachGpu'] = request_data['local_config'].get('maxTrialNumOnEachGpu') - if request_data['local_config'].get('useActiveGpu'): - request_data['local_config']['useActiveGpu'] = request_data['local_config'].get('useActiveGpu') response = rest_put(cluster_metadata_url(port), json.dumps(request_data), REST_TIME_OUT) err_message = '' if not response or not check_response(response): @@ -303,9 +297,9 @@ def set_heterogeneous_config(experiment_config, port, config_file_name): heterogeneous_config_data['heterogeneous_config'] = experiment_config['heterogeneousConfig'] platform_list = experiment_config['heterogeneousConfig']['trainingServicePlatforms'].split(',') for platform in platform_list: - if platform === 'aml': + if platform == 'aml': heterogeneous_config_data['aml_config'] = experiment_config['amlConfig'] - elif platform === 'remote': + elif platform == 'remote': heterogeneous_config_data['remote_config'] = experiment_config['remoteConfig'] response = rest_put(cluster_metadata_url(port), json.dumps(aml_config_data), REST_TIME_OUT) err_message = None @@ -408,10 +402,10 @@ def set_experiment(experiment_config, mode, port, config_file_name): {'key': 'heterogeneous_config', 'value': experiment_config['heterogeneousConfig']}) platform_list = experiment_config['heterogeneousConfig']['trainingServicePlatforms'].split(',') for platform in platform_list: - if platform === 'aml': + if platform == 'aml': request_data['clusterMetaData'].append( {'key': 'aml_config', 'value': experiment_config['amlConfig']}) - elif platform === 'remote': + elif platform == 'remote': request_data['clusterMetaData'].append( {'key': 'remote_config', 'value': experiment_config['remoteConfig']}) response = rest_put(cluster_metadata_url(port), json.dumps(aml_config_data), REST_TIME_OUT) diff --git a/ts/nni_manager/main.ts b/ts/nni_manager/main.ts index ce7a5d90c7..eb017246db 100644 --- a/ts/nni_manager/main.ts +++ b/ts/nni_manager/main.ts @@ -37,7 +37,7 @@ function initStartupInfo( async function initContainer(foreground: boolean, platformMode: string, logFileName?: string): Promise { if (platformMode === 'local') { Container.bind(TrainingService) - .to(LocalTrainingService) + .to(RouterTrainingService) .scope(Scope.Singleton); } else if (platformMode === 'remote') { Container.bind(TrainingService) @@ -69,8 +69,8 @@ async function initContainer(foreground: boolean, platformMode: string, logFileN .scope(Scope.Singleton); } else if (platformMode === 'heterogeneous') { Container.bind(TrainingService) - .to(RouterTrainingService) - .scope(Scope.Singleton); + .to(RouterTrainingService) + .scope(Scope.Singleton); } else { throw new Error(`Error: unsupported mode: ${platformMode}`); } diff --git a/ts/nni_manager/rest_server/restValidationSchemas.ts b/ts/nni_manager/rest_server/restValidationSchemas.ts index cac1bc44d2..98560fb29d 100644 --- a/ts/nni_manager/rest_server/restValidationSchemas.ts +++ b/ts/nni_manager/rest_server/restValidationSchemas.ts @@ -23,7 +23,8 @@ export namespace ValidationSchemas { local_config: joi.object({ // eslint-disable-line @typescript-eslint/camelcase gpuIndices: joi.string(), maxTrialNumPerGpu: joi.number(), - useActiveGpu: joi.boolean() + useActiveGpu: joi.boolean(), + reuse: joi.boolean() }), trial_config: joi.object({ // eslint-disable-line @typescript-eslint/camelcase image: joi.string().min(1), diff --git a/ts/nni_manager/training_service/common/util.ts b/ts/nni_manager/training_service/common/util.ts index 791d7dcebb..a1c6f7044f 100644 --- a/ts/nni_manager/training_service/common/util.ts +++ b/ts/nni_manager/training_service/common/util.ts @@ -95,6 +95,16 @@ export async function execMkdir(directory: string, share: boolean = false): Prom return Promise.resolve(); } +export async function fileExist(filePath: string): Promise { + let cmdresult: cpp.childProcessPromise.Result; + if (process.platform === 'win32') { + cmdresult = await cpp.exec(`powershell.exe Get-Content "${filePath}" -Tail 1`); + } else { + cmdresult = await cpp.exec(`test -e ${filePath} && echo True || echo False`); + } + return cmdresult.stdout !== undefined && cmdresult.stdout.trim() === 'True' +} + /** * copy files to the directory * @param source diff --git a/ts/nni_manager/training_service/local/localTrainingService.ts b/ts/nni_manager/training_service/local/localTrainingService.ts index ab9d037f99..6e29b0b8eb 100644 --- a/ts/nni_manager/training_service/local/localTrainingService.ts +++ b/ts/nni_manager/training_service/local/localTrainingService.ts @@ -78,11 +78,12 @@ class LocalTrialJobDetail implements TrialJobDetail { /** * Local training service config */ -class LocalConfig { +export class LocalConfig { public maxTrialNumPerGpu?: number; public gpuIndices?: string; public useActiveGpu?: boolean; - constructor(gpuIndices?: string, maxTrialNumPerGpu?: number, useActiveGpu?: boolean) { + public reuse?: boolean; + constructor(gpuIndices?: string, maxTrialNumPerGpu?: number, useActiveGpu?: boolean, reuse?: boolean) { if (gpuIndices !== undefined) { this.gpuIndices = gpuIndices; } @@ -92,6 +93,7 @@ class LocalConfig { if (useActiveGpu !== undefined) { this.useActiveGpu = useActiveGpu; } + this.reuse = reuse; } } diff --git a/ts/nni_manager/training_service/reusable/channels/heterogenousCommandChannel.ts b/ts/nni_manager/training_service/reusable/channels/heterogenousCommandChannel.ts index adacec1cc3..38e0f28256 100644 --- a/ts/nni_manager/training_service/reusable/channels/heterogenousCommandChannel.ts +++ b/ts/nni_manager/training_service/reusable/channels/heterogenousCommandChannel.ts @@ -32,8 +32,12 @@ export class HeterogenousCommandChannel extends CommandChannel{ } public async start(): Promise { - this.amlCommandChannel?.start(); - this.webCommandChannel?.start(); + if (this.amlCommandChannel) { + this.amlCommandChannel.start(); + } + if (this.webCommandChannel) { + this.webCommandChannel.start(); + } } public async stop(): Promise { @@ -41,17 +45,27 @@ export class HeterogenousCommandChannel extends CommandChannel{ } public async run(): Promise { - this.amlCommandChannel?.start(); - this.webCommandChannel?.run(); + if (this.amlCommandChannel) { + this.amlCommandChannel.run(); + } + if (this.webCommandChannel) { + this.webCommandChannel.run(); + } } protected async sendCommandInternal(environment: EnvironmentInformation, message: string): Promise { switch (environment.platform) { case 'aml': - this.amlCommandChannel?.sendCommandInternal(environment, message); + if (this.amlCommandChannel === undefined) { + throw new Error(`amlCommandChannel not initialezed!`); + } + this.amlCommandChannel.sendCommandInternal(environment, message); break; case 'remote': - this.webCommandChannel?.sendCommandInternal(environment, message); + if (this.webCommandChannel === undefined) { + throw new Error(`webCommandChannel not initialezed!`); + } + this.webCommandChannel.sendCommandInternal(environment, message); break; default: throw new Error(`Heterogenous not support platform: '${environment.platform}'`); diff --git a/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts b/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts index 05e4636552..8b0a8996d7 100644 --- a/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts +++ b/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts @@ -20,7 +20,7 @@ import { EnvironmentInformation, EnvironmentService } from '../environment'; /** - * Collector PAI jobs info from PAI cluster, and update pai job status locally + * Collector AML jobs info from AML cluster, and update pai job status locally */ @component.Singleton export class AMLEnvironmentService extends EnvironmentService { diff --git a/ts/nni_manager/training_service/reusable/environments/heterogenousEnvironmentService.ts b/ts/nni_manager/training_service/reusable/environments/heterogenousEnvironmentService.ts index e27c74d479..ff3074c256 100644 --- a/ts/nni_manager/training_service/reusable/environments/heterogenousEnvironmentService.ts +++ b/ts/nni_manager/training_service/reusable/environments/heterogenousEnvironmentService.ts @@ -14,6 +14,8 @@ import { CommandChannel } from "../commandChannel"; import { EnvironmentInformation, EnvironmentService } from '../environment'; import { AMLEnvironmentService } from './amlEnvironmentService'; import { RemoteEnvironmentService } from './remoteEnvironmentService'; +import { LocalEnvironmentService } from './localEnvironmentService'; +import { OpenPaiEnvironmentService } from './openPaiEnvironmentService'; import { randomSelect } from '../../../common/utils'; @@ -25,6 +27,8 @@ export class HeteroGenousEnvironmentService extends EnvironmentService { private amlEnvironmentService: AMLEnvironmentService; private remoteEnvironmentService: RemoteEnvironmentService; + private localEnvironmentService: LocalEnvironmentService; + private paiEnvironmentService: OpenPaiEnvironmentService; private readonly log: Logger = getLogger(); @@ -32,6 +36,8 @@ export class HeteroGenousEnvironmentService extends EnvironmentService { super(); this.amlEnvironmentService = new AMLEnvironmentService(); this.remoteEnvironmentService = new RemoteEnvironmentService(); + this.localEnvironmentService = new LocalEnvironmentService(); + this.paiEnvironmentService = new OpenPaiEnvironmentService(); } public get hasStorageService(): boolean { @@ -52,7 +58,16 @@ export class HeteroGenousEnvironmentService extends EnvironmentService { break; case TrialConfigMetadataKey.TRIAL_CONFIG: this.amlEnvironmentService.config(key, value); - this.remoteEnvironmentService.config(key, value); + this.remoteEnvironmentService.config(key, value); + this.paiEnvironmentService.config(key, value); + this.localEnvironmentService.config(key, value); + break; + case TrialConfigMetadataKey.PAI_CLUSTER_CONFIG: + this.paiEnvironmentService.config(key, value); + break; + case TrialConfigMetadataKey.LOCAL_CONFIG: + this.localEnvironmentService.config(key, value); + break; default: this.log.debug(`Heterogenous not support metadata key: '${key}', value: '${value}'`); } @@ -68,6 +83,10 @@ export class HeteroGenousEnvironmentService extends EnvironmentService { case 'remote': tasks.push(this.remoteEnvironmentService.refreshEnvironment(environment)); break; + case 'local': + tasks.push(this.localEnvironmentService.refreshEnvironment(environment)); + break; + // TODO: refresh pai default: throw new Error(`Heterogenous not support platform: '${environment.platform}'`); } @@ -76,7 +95,7 @@ export class HeteroGenousEnvironmentService extends EnvironmentService { } public async startEnvironment(environment: EnvironmentInformation): Promise { - const number = randomSelect([0, 1]); + const number = randomSelect([0, 1, 2, 3]); switch (number) { case 0: environment.platform = 'aml'; @@ -86,6 +105,14 @@ export class HeteroGenousEnvironmentService extends EnvironmentService { environment.platform = 'remote'; this.remoteEnvironmentService.startEnvironment(environment); break; + case 2: + environment.platform = 'local'; + this.localEnvironmentService.startEnvironment(environment); + break; + case 3: + environment.platform = 'pai'; + this.paiEnvironmentService.stopEnvironment(environment); + break; } } @@ -97,6 +124,12 @@ export class HeteroGenousEnvironmentService extends EnvironmentService { case 'remote': this.remoteEnvironmentService.stopEnvironment(environment); break; + case 'local': + this.localEnvironmentService.stopEnvironment(environment); + break; + case 'pai': + this.paiEnvironmentService.stopEnvironment(environment); + break; default: throw new Error(`Heterogenous not support platform '${environment.platform}'`); } diff --git a/ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts b/ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts new file mode 100644 index 0000000000..15a78215fe --- /dev/null +++ b/ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts @@ -0,0 +1,138 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +'use strict'; + +import * as cpp from 'child-process-promise'; +import * as cp from 'child_process'; +import * as fs from 'fs'; +import * as path from 'path'; +import * as yaml from 'js-yaml'; +import * as request from 'request'; +import { Deferred } from 'ts-deferred'; +import * as tkill from 'tree-kill'; +import * as component from '../../../common/component'; +import { getExperimentId } from '../../../common/experimentStartupInfo'; +import { getLogger, Logger } from '../../../common/log'; +import { TrialConfigMetadataKey } from '../../common/trialConfigMetadataKey'; +import { PAIClusterConfig } from '../../pai/paiConfig'; +import { NNIPAIK8STrialConfig } from '../../pai/paiK8S/paiK8SConfig'; +import { EnvironmentInformation, EnvironmentService } from '../environment'; +import { StorageService } from '../storageService'; +import { TrialConfig } from '../../common/trialConfig'; +import { getExperimentRootDir, isAlive } from '../../../common/utils'; +import { LocalConfig } from '../../local/localTrainingService'; +import { execMkdir, validateCodeDir, runScript, fileExist, execCopydir } from '../../common/util'; + + +@component.Singleton +export class LocalEnvironmentService extends EnvironmentService { + + private readonly log: Logger = getLogger(); + private localTrialConfig: TrialConfig | undefined; + private localConfig: LocalConfig | undefined; + private experimentRootDir: string; + private experimentId: string; + + constructor() { + super(); + this.experimentId = getExperimentId(); + this.experimentRootDir = getExperimentRootDir(); + } + + public get environmentMaintenceLoopInterval(): number { + return 5000; + } + + public get hasStorageService(): boolean { + return false; + } + + public async config(key: string, value: string): Promise { + switch (key) { + case TrialConfigMetadataKey.LOCAL_CONFIG: + this.localConfig = JSON.parse(value); + break; + case TrialConfigMetadataKey.TRIAL_CONFIG: + this.localTrialConfig = JSON.parse(value); + break; + default: + this.log.debug(`OpenPAI not proccessed metadata key: '${key}', value: '${value}'`); + } + } + + public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise { + const tasks: Promise[] = []; + environments.forEach(async (environment) => { + tasks.push(this.refreshEnvironment(environment)); + }); + await Promise.all(tasks); + } + + public async refreshEnvironment(environment: EnvironmentInformation): Promise { + const jobpidPath: string = `${environment.runnerWorkingFolder}/pid`; + const runnerReturnCodeFilePath: string = `${environment.runnerWorkingFolder}/code`; + /* eslint-disable require-atomic-updates */ + try { + // check if pid file exist + const pidExist = await fileExist(jobpidPath); + if (!pidExist) { + return; + } + const pid: string = await fs.promises.readFile(jobpidPath, 'utf8'); + const alive: boolean = await isAlive(pid); + environment.status = 'RUNNING'; + // if the process of jobpid is not alive any more + if (!alive) { + if (fs.existsSync(runnerReturnCodeFilePath)) { + const runnerReturnCode: string = await fs.promises.readFile(runnerReturnCodeFilePath, 'utf8'); + const match: RegExpMatchArray | null = runnerReturnCode.trim() + .match(/^-?(\d+)\s+(\d+)$/); + if (match !== null) { + const { 1: code } = match; + // Update trial job's status based on result code + if (parseInt(code, 10) === 0) { + environment.setStatus('SUCCEEDED'); + } else { + environment.setStatus('FAILED'); + } + } + } + } + } catch (error) { + this.log.error(`Update job status exception, error is ${error.message}`); + } + } + + public async startEnvironment(environment: EnvironmentInformation): Promise { + if (this.localTrialConfig === undefined) { + throw new Error('Local trial config is not initialized'); + } + if (this.localConfig === undefined) { + throw new Error('Local config is not initialized'); + } + const localEnvironment: LocalEnvironmentInformation = environment as LocalEnvironmentInformation; + // Need refactor, this temp folder path is not appropriate, there are two expId in this path + const localTempFolder: string = path.join(this.experimentRootDir, this.experimentId, + "environment-temp", "envs"); + const localEnvCodeFolder: string = path.join(this.experimentRootDir, "envs"); + localEnvironment.runnerWorkingFolder = path.join(localEnvCodeFolder, localEnvironment.id); + await execMkdir(localEnvironment.runnerWorkingFolder); + await execCopydir(localTempFolder, localEnvCodeFolder); + localEnvironment.command = `cd ${this.experimentRootDir} && \ +${localEnvironment.command} --job_pid_file ${localEnvironment.runnerWorkingFolder}/pid \ +1>${localEnvironment.runnerWorkingFolder}/trialrunner_stdout 2>${localEnvironment.runnerWorkingFolder}/trialrunner_stderr \ +&& echo $? \`date +%s%3N\` >${localEnvironment.runnerWorkingFolder}/code`; + await fs.promises.writeFile(path.join(localEnvCodeFolder, 'nni_run.sh'), + localEnvironment.command, { encoding: 'utf8', mode: 0o777 }), + // Execute command in local machine + runScript(path.join(localEnvCodeFolder, 'nni_run.sh')); + localEnvironment.trackingUrl = `${localEnvironment.runnerWorkingFolder}`; + } + + public async stopEnvironment(environment: EnvironmentInformation): Promise { + const jobpidPath: string = `${environment.runnerWorkingFolder}/pid`; + const pid: string = await fs.promises.readFile(jobpidPath, 'utf8'); + tkill(Number(pid), 'SIGKILL'); + } +} diff --git a/ts/nni_manager/training_service/reusable/routerTrainingService.ts b/ts/nni_manager/training_service/reusable/routerTrainingService.ts index 3bf47dd2ce..b828bcac49 100644 --- a/ts/nni_manager/training_service/reusable/routerTrainingService.ts +++ b/ts/nni_manager/training_service/reusable/routerTrainingService.ts @@ -15,6 +15,7 @@ import { PAIK8STrainingService } from '../pai/paiK8S/paiK8STrainingService'; import { RemoteMachineTrainingService } from '../remote_machine/remoteMachineTrainingService'; import { EnvironmentService } from './environment'; import { OpenPaiEnvironmentService } from './environments/openPaiEnvironmentService'; +import { LocalEnvironmentService } from './environments/localEnvironmentService'; import { AMLEnvironmentService } from './environments/amlEnvironmentService'; import { RemoteEnvironmentService } from './environments/remoteEnvironmentService'; import { MountedStorageService } from './storages/mountedStorageService'; @@ -22,6 +23,8 @@ import { HeteroGenousEnvironmentService } from './environments/heterogenousEnvir import { StorageService } from './storageService'; import { TrialDispatcher } from './trialDispatcher'; import { RemoteConfig } from './remote/remoteConfig'; +import { LocalConfig, LocalTrainingService } from '../local/localTrainingService'; +import { TrialConfig } from 'training_service/common/trialConfig'; /** @@ -100,6 +103,24 @@ class RouterTrainingService implements TrainingService { public async setClusterMetadata(key: string, value: string): Promise { if (this.internalTrainingService === undefined) { + if (key === TrialConfigMetadataKey.LOCAL_CONFIG) { + const config = JSON.parse(value); + if (config.reuse === true) { + this.log.info(`reuse flag enabled, use EnvironmentManager.`); + this.internalTrainingService = component.get(TrialDispatcher); + + // TODO to support other serivces later. + Container.bind(EnvironmentService) + .to(LocalEnvironmentService) + .scope(Scope.Singleton); + } else { + this.internalTrainingService = component.get(LocalTrainingService); + } + if (this.internalTrainingService === undefined) { + throw new Error("TrainingService is not assigned!"); + } + await this.internalTrainingService.setClusterMetadata(key, value); + } if (key === TrialConfigMetadataKey.PAI_CLUSTER_CONFIG) { const config = JSON.parse(value); if (config.reuse === true) { From ef4f561e4902ccdc8c483ef1aec507cd7adeead7 Mon Sep 17 00:00:00 2001 From: Shinai Yang Date: Fri, 27 Nov 2020 13:52:44 +0800 Subject: [PATCH 03/24] add doc and refactor code --- .../TrainingService/HeterogeneousMode.md | 57 +++++++++++++++++++ nni/runtime/platform/__init__.py | 2 +- nni/tools/nnictl/config_schema.py | 25 ++++++-- nni/tools/nnictl/launcher.py | 23 ++++++-- .../rest_server/restValidationSchemas.ts | 2 +- .../common/trialConfigMetadataKey.ts | 2 +- .../channels/heterogenousCommandChannel.ts | 8 +-- .../heterogenousEnvironmentService.ts | 22 ++++--- .../environments/localEnvironmentService.ts | 20 ++++--- .../heterogenous/heterogenousConfig.ts | 11 ++++ .../reusable/routerTrainingService.ts | 9 ++- 11 files changed, 144 insertions(+), 37 deletions(-) create mode 100644 docs/en_US/TrainingService/HeterogeneousMode.md create mode 100644 ts/nni_manager/training_service/reusable/heterogenous/heterogenousConfig.ts diff --git a/docs/en_US/TrainingService/HeterogeneousMode.md b/docs/en_US/TrainingService/HeterogeneousMode.md new file mode 100644 index 0000000000..c93c337c63 --- /dev/null +++ b/docs/en_US/TrainingService/HeterogeneousMode.md @@ -0,0 +1,57 @@ +**Run an Experiment on Heterogeneous Mode** +=== +Run NNI on heterogeneous mode means that NNI will run trials jobs in multiple kinds of training platforms. For example, NNI could submit trial jobs to remote machine and AML simultaneously。 + +## Setup environment +NNI has supported [local](./LocalMode.md), [remote](./RemoteMachineMode.md), [pai](./PaiMode.md) and [AML](./AMLMode.md) for heterogeneous training service. Before starting an experiment using these mode, users should setup the corresponding environment for the platforms. More details about the environment setup could be found in the corresponding docs. + + + +## Run an experiment +Use `examples/trials/mnist-tfv1` as an example. The NNI config YAML file's content is like: + +```yaml +authorName: default +experimentName: example_mnist +trialConcurrency: 1 +maxExecDuration: 1h +maxTrialNum: 10 +trainingServicePlatform: heterogeneous +searchSpacePath: search_space.json +#choice: true, false +useAnnotation: false +tuner: + #choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner, GPTuner + #SMAC (SMAC should be installed through nnictl) + builtinTunerName: TPE + classArgs: + #choice: maximize, minimize + optimize_mode: maximize +trial: + command: python3 mnist.py + codeDir: . + image: msranni/nni + gpuNum: 1 +heterogeneousConfig: + trainingServicePlatforms: + - local + - remote +localConfig: + reuse: true +remoteConfig: + reuse: true +machineList: + - ip: 10.1.1.1 + username: bob + passwd: bob123 + #port can be skip if using default ssh port 22 + #port: 22 +``` +Configurations for heterogeneous mode: + +heterogeneousConfig: +* trainingServicePlatforms. required key. This field specify the platforms used in heterogeneous mode, the values using yaml list format. NNI support setting `local`, `remote`, `aml`, `pai` in this field. + + +Note: + If setting a platform in trainingServicePlatforms mode, users should also set the corresponding configuration for the platform. For example, if set `remote` as one of the platform, should also set `machineList` and `remoteConfig` configuration. diff --git a/nni/runtime/platform/__init__.py b/nni/runtime/platform/__init__.py index 84f04a9862..1dc8aae5db 100644 --- a/nni/runtime/platform/__init__.py +++ b/nni/runtime/platform/__init__.py @@ -9,7 +9,7 @@ from .standalone import * elif trial_env_vars.NNI_PLATFORM == 'unittest': from .test import * -elif trial_env_vars.NNI_PLATFORM in ('local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml'): +elif trial_env_vars.NNI_PLATFORM in ('local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml', 'heterogeneous'): from .local import * else: raise RuntimeError('Unknown platform %s' % trial_env_vars.NNI_PLATFORM) diff --git a/nni/tools/nnictl/config_schema.py b/nni/tools/nnictl/config_schema.py index 7d5ed2beb6..effd42ecfc 100644 --- a/nni/tools/nnictl/config_schema.py +++ b/nni/tools/nnictl/config_schema.py @@ -209,7 +209,7 @@ def validate(self, data): } pai_config_schema = { - 'paiConfig': { + Optional('paiConfig'): { 'userName': setType('userName', str), Or('passWord', 'token', only_one=True): str, 'host': setType('host', str), @@ -253,7 +253,7 @@ def validate(self, data): } aml_config_schema = { - 'amlConfig': { + Optional('amlConfig'): { 'subscriptionId': setType('subscriptionId', str), 'resourceGroup': setType('resourceGroup', str), 'workspaceName': setType('workspaceName', str), @@ -265,7 +265,7 @@ def validate(self, data): heterogeneous_config_schema = { 'heterogeneousConfig': { - 'trainingServicePlatforms': setType('trainingServicePlatforms', str), + 'trainingServicePlatforms': ['local', 'remote', 'pai', 'aml'] } } @@ -386,7 +386,7 @@ def validate(self, data): } machine_list_schema = { - 'machineList': [Or( + Optional('machineList'): [Or( { 'ip': setType('ip', str), Optional('port'): setNumberRange('port', int, 1, 65535), @@ -419,7 +419,8 @@ def validate(self, data): 'frameworkcontroller': Schema({**common_schema, **frameworkcontroller_trial_schema, **frameworkcontroller_config_schema}), 'aml': Schema({**common_schema, **aml_trial_schema, **aml_config_schema}), 'dlts': Schema({**common_schema, **dlts_trial_schema, **dlts_config_schema}), - 'heterogeneous': Schema({**common_schema, **common_trial_schema, **heterogeneous_config_schema}), + 'heterogeneous': Schema({**common_schema, **common_trial_schema, **heterogeneous_config_schema, **machine_list_schema, + **pai_config_schema, **aml_config_schema, **remote_config_schema}), } @@ -436,6 +437,7 @@ def validate_extras(self, experiment_config): self.validate_pai_trial_conifg(experiment_config) self.validate_kubeflow_operators(experiment_config) self.validate_eth0_device(experiment_config) + self.validate_heterogeneous_platforms(experiment_config) def validate_tuner_adivosr_assessor(self, experiment_config): if experiment_config.get('advisor'): @@ -545,3 +547,16 @@ def validate_eth0_device(self, experiment_config): and not experiment_config.get('nniManagerIp') \ and 'eth0' not in netifaces.interfaces(): raise SchemaError('This machine does not contain eth0 network device, please set nniManagerIp in config file!') + + def validate_heterogeneous_platforms(self, experiment_config): + required_config_name_map = { + 'remote': 'machineList', + 'aml': 'amlConfig', + 'pai': 'paiConfig' + } + if experiment_config.get('trainingServicePlatform') == 'heterogeneous': + for platform in experiment_config['heterogeneousConfig']['trainingServicePlatforms']: + config_name = required_config_name_map.get(platform) + if config_name and not experiment_config.get(config_name): + raise SchemaError('Need to set {0} for {1} in heterogeneous mode!'.format(config_name, platform)) + \ No newline at end of file diff --git a/nni/tools/nnictl/launcher.py b/nni/tools/nnictl/launcher.py index b66d6e24a0..cb78b1ba3a 100644 --- a/nni/tools/nnictl/launcher.py +++ b/nni/tools/nnictl/launcher.py @@ -295,13 +295,19 @@ def set_heterogeneous_config(experiment_config, port, config_file_name): '''set heterogeneous configuration''' heterogeneous_config_data = dict() heterogeneous_config_data['heterogeneous_config'] = experiment_config['heterogeneousConfig'] - platform_list = experiment_config['heterogeneousConfig']['trainingServicePlatforms'].split(',') + platform_list = experiment_config['heterogeneousConfig']['trainingServicePlatforms'] for platform in platform_list: if platform == 'aml': heterogeneous_config_data['aml_config'] = experiment_config['amlConfig'] elif platform == 'remote': - heterogeneous_config_data['remote_config'] = experiment_config['remoteConfig'] - response = rest_put(cluster_metadata_url(port), json.dumps(aml_config_data), REST_TIME_OUT) + if experiment_config.get('remoteConfig'): + heterogeneous_config_data['remote_config'] = experiment_config['remoteConfig'] + heterogeneous_config_data['machine_list'] = experiment_config['machineList'] + elif platform == 'local': + heterogeneous_config_data['local_config'] = experiment_config['localConfig'] + elif platform == 'pai': + heterogeneous_config_data['pai_config'] = experiment_config['paiConfig'] + response = rest_put(cluster_metadata_url(port), json.dumps(heterogeneous_config_data), REST_TIME_OUT) err_message = None if not response or not response.status_code == 200: if response is not None: @@ -400,15 +406,20 @@ def set_experiment(experiment_config, mode, port, config_file_name): elif experiment_config['trainingServicePlatform'] == 'heterogeneous': request_data['clusterMetaData'].append( {'key': 'heterogeneous_config', 'value': experiment_config['heterogeneousConfig']}) - platform_list = experiment_config['heterogeneousConfig']['trainingServicePlatforms'].split(',') + platform_list = experiment_config['heterogeneousConfig']['trainingServicePlatforms'] for platform in platform_list: if platform == 'aml': request_data['clusterMetaData'].append( {'key': 'aml_config', 'value': experiment_config['amlConfig']}) elif platform == 'remote': request_data['clusterMetaData'].append( - {'key': 'remote_config', 'value': experiment_config['remoteConfig']}) - response = rest_put(cluster_metadata_url(port), json.dumps(aml_config_data), REST_TIME_OUT) + {'key': 'machine_list', 'value': experiment_config['machineList']}) + elif platform == 'local': + request_data['clusterMetaData'].append( + {'key': 'local_config', 'value': experiment_config['localConfig']}) + elif platform == 'pai': + request_data['clusterMetaData'].append( + {'key': 'pai_config', 'value': experiment_config['paiConfig']}) request_data['clusterMetaData'].append( {'key': 'trial_config', 'value': experiment_config['trial']}) response = rest_post(experiment_url(port), json.dumps(request_data), REST_TIME_OUT, show_error=True) diff --git a/ts/nni_manager/rest_server/restValidationSchemas.ts b/ts/nni_manager/rest_server/restValidationSchemas.ts index 98560fb29d..6a37a476ff 100644 --- a/ts/nni_manager/rest_server/restValidationSchemas.ts +++ b/ts/nni_manager/rest_server/restValidationSchemas.ts @@ -166,7 +166,7 @@ export namespace ValidationSchemas { useActiveGpu: joi.boolean() }), heterogeneous_config: joi.object({ // eslint-disable-line @typescript-eslint/camelcase - trainingServicePlatforms: joi.string().min(1), + trainingServicePlatforms: joi.array(), }), nni_manager_ip: joi.object({ // eslint-disable-line @typescript-eslint/camelcase nniManagerIp: joi.string().min(1) diff --git a/ts/nni_manager/training_service/common/trialConfigMetadataKey.ts b/ts/nni_manager/training_service/common/trialConfigMetadataKey.ts index ebf8ec366d..b94f562917 100644 --- a/ts/nni_manager/training_service/common/trialConfigMetadataKey.ts +++ b/ts/nni_manager/training_service/common/trialConfigMetadataKey.ts @@ -11,13 +11,13 @@ export enum TrialConfigMetadataKey { LOCAL_CONFIG = 'local_config', TRIAL_CONFIG = 'trial_config', REMOTE_CONFIG = 'remote_config', + HETEROGENOUS_CONFIG = 'heterogenous_config', EXPERIMENT_ID = 'experimentId', MULTI_PHASE = 'multiPhase', RANDOM_SCHEDULER = 'random_scheduler', PAI_YARN_CLUSTER_CONFIG = 'pai_yarn_config', PAI_CLUSTER_CONFIG = 'pai_config', KUBEFLOW_CLUSTER_CONFIG = 'kubeflow_config', - HETEROGENOUS_CLUSTER_CONFIG = 'heterogenous_config', NNI_MANAGER_IP = 'nni_manager_ip', FRAMEWORKCONTROLLER_CLUSTER_CONFIG = 'frameworkcontroller_config', DLTS_CLUSTER_CONFIG = 'dlts_config', diff --git a/ts/nni_manager/training_service/reusable/channels/heterogenousCommandChannel.ts b/ts/nni_manager/training_service/reusable/channels/heterogenousCommandChannel.ts index 38e0f28256..399158484d 100644 --- a/ts/nni_manager/training_service/reusable/channels/heterogenousCommandChannel.ts +++ b/ts/nni_manager/training_service/reusable/channels/heterogenousCommandChannel.ts @@ -5,15 +5,11 @@ import { EventEmitter } from "events"; import { delay } from "../../../common/utils"; -import { AMLEnvironmentInformation } from '../aml/amlConfig'; import { CommandChannel, RunnerConnection } from "../commandChannel"; import { Channel, EnvironmentInformation } from "../environment"; import { AMLCommandChannel } from "./amlCommandChannel"; import { WebCommandChannel } from "./webCommandChannel"; -class HeterogenousRunnerConnection extends RunnerConnection { -} - export class HeterogenousCommandChannel extends CommandChannel{ private stopping: boolean = false; private amlCommandChannel: AMLCommandChannel | undefined; @@ -61,6 +57,8 @@ export class HeterogenousCommandChannel extends CommandChannel{ } this.amlCommandChannel.sendCommandInternal(environment, message); break; + case 'local': + case 'pai': case 'remote': if (this.webCommandChannel === undefined) { throw new Error(`webCommandChannel not initialezed!`); @@ -73,6 +71,6 @@ export class HeterogenousCommandChannel extends CommandChannel{ } protected createRunnerConnection(environment: EnvironmentInformation): RunnerConnection { - return new HeterogenousRunnerConnection(environment); + return new RunnerConnection(environment); } } diff --git a/ts/nni_manager/training_service/reusable/environments/heterogenousEnvironmentService.ts b/ts/nni_manager/training_service/reusable/environments/heterogenousEnvironmentService.ts index ff3074c256..e9bfec1790 100644 --- a/ts/nni_manager/training_service/reusable/environments/heterogenousEnvironmentService.ts +++ b/ts/nni_manager/training_service/reusable/environments/heterogenousEnvironmentService.ts @@ -17,6 +17,7 @@ import { RemoteEnvironmentService } from './remoteEnvironmentService'; import { LocalEnvironmentService } from './localEnvironmentService'; import { OpenPaiEnvironmentService } from './openPaiEnvironmentService'; import { randomSelect } from '../../../common/utils'; +import { HeterogenousConfig } from '../heterogenous/heterogenousConfig'; /** @@ -29,6 +30,7 @@ export class HeteroGenousEnvironmentService extends EnvironmentService { private remoteEnvironmentService: RemoteEnvironmentService; private localEnvironmentService: LocalEnvironmentService; private paiEnvironmentService: OpenPaiEnvironmentService; + private heterogenousConfig?: HeterogenousConfig; private readonly log: Logger = getLogger(); @@ -68,6 +70,8 @@ export class HeteroGenousEnvironmentService extends EnvironmentService { case TrialConfigMetadataKey.LOCAL_CONFIG: this.localEnvironmentService.config(key, value); break; + case TrialConfigMetadataKey.HETEROGENOUS_CONFIG: + this.heterogenousConfig = JSON.parse(value); default: this.log.debug(`Heterogenous not support metadata key: '${key}', value: '${value}'`); } @@ -95,23 +99,27 @@ export class HeteroGenousEnvironmentService extends EnvironmentService { } public async startEnvironment(environment: EnvironmentInformation): Promise { - const number = randomSelect([0, 1, 2, 3]); - switch (number) { - case 0: + if (this.heterogenousConfig === undefined) { + throw new Error('heterogenousConfig not initialized!'); + } + this.heterogenousConfig.trainingServicePlatforms; + const platform = randomSelect(this.heterogenousConfig.trainingServicePlatforms); + switch (platform) { + case 'aml': environment.platform = 'aml'; this.amlEnvironmentService.startEnvironment(environment); break; - case 1: + case 'remote': environment.platform = 'remote'; this.remoteEnvironmentService.startEnvironment(environment); break; - case 2: + case 'local': environment.platform = 'local'; this.localEnvironmentService.startEnvironment(environment); break; - case 3: + case 'pai': environment.platform = 'pai'; - this.paiEnvironmentService.stopEnvironment(environment); + this.paiEnvironmentService.startEnvironment(environment); break; } } diff --git a/ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts b/ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts index 15a78215fe..e4f026ad2b 100644 --- a/ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts +++ b/ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts @@ -4,6 +4,7 @@ 'use strict'; import * as cpp from 'child-process-promise'; +import { EventEmitter } from "events"; import * as cp from 'child_process'; import * as fs from 'fs'; import * as path from 'path'; @@ -23,6 +24,8 @@ import { TrialConfig } from '../../common/trialConfig'; import { getExperimentRootDir, isAlive } from '../../../common/utils'; import { LocalConfig } from '../../local/localTrainingService'; import { execMkdir, validateCodeDir, runScript, fileExist, execCopydir } from '../../common/util'; +import { FileCommandChannel } from '../channels/fileCommandChannel'; +import { CommandChannel } from "../commandChannel"; @component.Singleton @@ -111,23 +114,22 @@ export class LocalEnvironmentService extends EnvironmentService { if (this.localConfig === undefined) { throw new Error('Local config is not initialized'); } - const localEnvironment: LocalEnvironmentInformation = environment as LocalEnvironmentInformation; // Need refactor, this temp folder path is not appropriate, there are two expId in this path const localTempFolder: string = path.join(this.experimentRootDir, this.experimentId, "environment-temp", "envs"); const localEnvCodeFolder: string = path.join(this.experimentRootDir, "envs"); - localEnvironment.runnerWorkingFolder = path.join(localEnvCodeFolder, localEnvironment.id); - await execMkdir(localEnvironment.runnerWorkingFolder); + environment.runnerWorkingFolder = path.join(localEnvCodeFolder, environment.id); + await execMkdir(environment.runnerWorkingFolder); await execCopydir(localTempFolder, localEnvCodeFolder); - localEnvironment.command = `cd ${this.experimentRootDir} && \ -${localEnvironment.command} --job_pid_file ${localEnvironment.runnerWorkingFolder}/pid \ -1>${localEnvironment.runnerWorkingFolder}/trialrunner_stdout 2>${localEnvironment.runnerWorkingFolder}/trialrunner_stderr \ -&& echo $? \`date +%s%3N\` >${localEnvironment.runnerWorkingFolder}/code`; + environment.command = `cd ${this.experimentRootDir} && \ +${environment.command} --job_pid_file ${environment.runnerWorkingFolder}/pid \ +1>${environment.runnerWorkingFolder}/trialrunner_stdout 2>${environment.runnerWorkingFolder}/trialrunner_stderr \ +&& echo $? \`date +%s%3N\` >${environment.runnerWorkingFolder}/code`; await fs.promises.writeFile(path.join(localEnvCodeFolder, 'nni_run.sh'), - localEnvironment.command, { encoding: 'utf8', mode: 0o777 }), + environment.command, { encoding: 'utf8', mode: 0o777 }), // Execute command in local machine runScript(path.join(localEnvCodeFolder, 'nni_run.sh')); - localEnvironment.trackingUrl = `${localEnvironment.runnerWorkingFolder}`; + environment.trackingUrl = `${environment.runnerWorkingFolder}`; } public async stopEnvironment(environment: EnvironmentInformation): Promise { diff --git a/ts/nni_manager/training_service/reusable/heterogenous/heterogenousConfig.ts b/ts/nni_manager/training_service/reusable/heterogenous/heterogenousConfig.ts new file mode 100644 index 0000000000..2c3012bb12 --- /dev/null +++ b/ts/nni_manager/training_service/reusable/heterogenous/heterogenousConfig.ts @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + + +export class HeterogenousConfig { + public readonly trainingServicePlatforms: string[]; + + constructor(trainingServicePlatforms: string[]) { + this.trainingServicePlatforms = trainingServicePlatforms; + } +} diff --git a/ts/nni_manager/training_service/reusable/routerTrainingService.ts b/ts/nni_manager/training_service/reusable/routerTrainingService.ts index b828bcac49..20c56d79ba 100644 --- a/ts/nni_manager/training_service/reusable/routerTrainingService.ts +++ b/ts/nni_manager/training_service/reusable/routerTrainingService.ts @@ -24,7 +24,6 @@ import { StorageService } from './storageService'; import { TrialDispatcher } from './trialDispatcher'; import { RemoteConfig } from './remote/remoteConfig'; import { LocalConfig, LocalTrainingService } from '../local/localTrainingService'; -import { TrialConfig } from 'training_service/common/trialConfig'; /** @@ -183,11 +182,17 @@ class RouterTrainingService implements TrainingService { this.log.debug(`caching metadata key:{} value:{}, as training service is not determined.`); this.internalTrainingService = component.get(RemoteMachineTrainingService); } - } else if (key === TrialConfigMetadataKey.HETEROGENOUS_CLUSTER_CONFIG){ + } else if (key === TrialConfigMetadataKey.HETEROGENOUS_CONFIG){ + console.log('-------------------------------186--------------') this.internalTrainingService = component.get(TrialDispatcher); Container.bind(EnvironmentService) .to(HeteroGenousEnvironmentService) .scope(Scope.Singleton); + + if (this.internalTrainingService === undefined) { + throw new Error("TrainingService is not assigned!"); + } + await this.internalTrainingService.setClusterMetadata(key, value); } } else { await this.internalTrainingService.setClusterMetadata(key, value); From aca3e2877145e7a3c8db271cf0ffd60c571ed679 Mon Sep 17 00:00:00 2001 From: Shinai Yang Date: Fri, 27 Nov 2020 13:57:23 +0800 Subject: [PATCH 04/24] remove unused console --- .../training_service/reusable/routerTrainingService.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/ts/nni_manager/training_service/reusable/routerTrainingService.ts b/ts/nni_manager/training_service/reusable/routerTrainingService.ts index 20c56d79ba..47c49a9a7f 100644 --- a/ts/nni_manager/training_service/reusable/routerTrainingService.ts +++ b/ts/nni_manager/training_service/reusable/routerTrainingService.ts @@ -183,7 +183,6 @@ class RouterTrainingService implements TrainingService { this.internalTrainingService = component.get(RemoteMachineTrainingService); } } else if (key === TrialConfigMetadataKey.HETEROGENOUS_CONFIG){ - console.log('-------------------------------186--------------') this.internalTrainingService = component.get(TrialDispatcher); Container.bind(EnvironmentService) .to(HeteroGenousEnvironmentService) From db90b8f2599fbfa0233991bd56353c23784ae948 Mon Sep 17 00:00:00 2001 From: Shinai Yang Date: Mon, 30 Nov 2020 05:16:44 +0800 Subject: [PATCH 05/24] refactor code --- nni/tools/nnictl/config_schema.py | 3 +- nni/tools/nnictl/launcher.py | 5 +- nni/tools/trial_tool/trial_runner.py | 5 +- ts/nni_manager/main.ts | 2 +- .../common/trialConfigMetadataKey.ts | 2 +- .../local/localTrainingService.ts | 4 +- .../channels/heterogeneousCommandChannel.ts | 115 ++++++++++++++++++ .../channels/heterogenousCommandChannel.ts | 76 ------------ .../reusable/channels/webCommandChannel.ts | 13 +- ....ts => heterogeneousEnvironmentService.ts} | 50 ++++---- .../environments/localEnvironmentService.ts | 9 +- .../reusable/routerTrainingService.ts | 23 ++-- 12 files changed, 168 insertions(+), 139 deletions(-) create mode 100644 ts/nni_manager/training_service/reusable/channels/heterogeneousCommandChannel.ts delete mode 100644 ts/nni_manager/training_service/reusable/channels/heterogenousCommandChannel.ts rename ts/nni_manager/training_service/reusable/environments/{heterogenousEnvironmentService.ts => heterogeneousEnvironmentService.ts} (71%) diff --git a/nni/tools/nnictl/config_schema.py b/nni/tools/nnictl/config_schema.py index effd42ecfc..b93fb34392 100644 --- a/nni/tools/nnictl/config_schema.py +++ b/nni/tools/nnictl/config_schema.py @@ -141,8 +141,7 @@ def validate(self, data): Optional('localConfig'): { Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'), Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int), - Optional('useActiveGpu'): setType('useActiveGpu', bool), - Optional('reuse'): setType('reuse', bool) + Optional('useActiveGpu'): setType('useActiveGpu', bool) } } diff --git a/nni/tools/nnictl/launcher.py b/nni/tools/nnictl/launcher.py index cb78b1ba3a..d0de8bb573 100644 --- a/nni/tools/nnictl/launcher.py +++ b/nni/tools/nnictl/launcher.py @@ -115,7 +115,6 @@ def set_trial_config(experiment_config, port, config_file_name): def set_local_config(experiment_config, port, config_file_name): '''set local configuration''' request_data = dict() - request_data['local_config'] = {'reuse': False} if experiment_config.get('localConfig'): request_data['local_config'] = experiment_config['localConfig'] response = rest_put(cluster_metadata_url(port), json.dumps(request_data), REST_TIME_OUT) @@ -303,7 +302,7 @@ def set_heterogeneous_config(experiment_config, port, config_file_name): if experiment_config.get('remoteConfig'): heterogeneous_config_data['remote_config'] = experiment_config['remoteConfig'] heterogeneous_config_data['machine_list'] = experiment_config['machineList'] - elif platform == 'local': + elif platform == 'local' and experiment_config.get('localConfig'): heterogeneous_config_data['local_config'] = experiment_config['localConfig'] elif platform == 'pai': heterogeneous_config_data['pai_config'] = experiment_config['paiConfig'] @@ -414,7 +413,7 @@ def set_experiment(experiment_config, mode, port, config_file_name): elif platform == 'remote': request_data['clusterMetaData'].append( {'key': 'machine_list', 'value': experiment_config['machineList']}) - elif platform == 'local': + elif platform == 'local' and experiment_config.get('localConfig'): request_data['clusterMetaData'].append( {'key': 'local_config', 'value': experiment_config['localConfig']}) elif platform == 'pai': diff --git a/nni/tools/trial_tool/trial_runner.py b/nni/tools/trial_tool/trial_runner.py index 8ee5c69bf7..b30a235932 100644 --- a/nni/tools/trial_tool/trial_runner.py +++ b/nni/tools/trial_tool/trial_runner.py @@ -25,7 +25,7 @@ def main_loop(args): '''main loop logic for trial runner''' idle_last_time = datetime.now() gpu_refresh_last_time = datetime.now() - timedelta(minutes=1) - + nni_log(LogType.Info, "--------------main loop-----------28-----------------") try: if args.job_pid_file: with open(args.job_pid_file, 'w') as job_file: @@ -215,7 +215,8 @@ def check_version(args): command_channel = None if args.command_channel == "file": command_channel = FileChannel(args) - elif args.command_channel == 'aml': + elif args.command_channel == 'aml' or \ + args.command_channel == 'heterogeneous' and args.platform == 'aml': from .aml_channel import AMLChannel command_channel = AMLChannel(args) else: diff --git a/ts/nni_manager/main.ts b/ts/nni_manager/main.ts index eb017246db..86400b32d7 100644 --- a/ts/nni_manager/main.ts +++ b/ts/nni_manager/main.ts @@ -37,7 +37,7 @@ function initStartupInfo( async function initContainer(foreground: boolean, platformMode: string, logFileName?: string): Promise { if (platformMode === 'local') { Container.bind(TrainingService) - .to(RouterTrainingService) + .to(LocalTrainingService) .scope(Scope.Singleton); } else if (platformMode === 'remote') { Container.bind(TrainingService) diff --git a/ts/nni_manager/training_service/common/trialConfigMetadataKey.ts b/ts/nni_manager/training_service/common/trialConfigMetadataKey.ts index b94f562917..38f547a0a9 100644 --- a/ts/nni_manager/training_service/common/trialConfigMetadataKey.ts +++ b/ts/nni_manager/training_service/common/trialConfigMetadataKey.ts @@ -11,7 +11,7 @@ export enum TrialConfigMetadataKey { LOCAL_CONFIG = 'local_config', TRIAL_CONFIG = 'trial_config', REMOTE_CONFIG = 'remote_config', - HETEROGENOUS_CONFIG = 'heterogenous_config', + HETEROGENEOUS_CONFIG = 'heterogeneous_config', EXPERIMENT_ID = 'experimentId', MULTI_PHASE = 'multiPhase', RANDOM_SCHEDULER = 'random_scheduler', diff --git a/ts/nni_manager/training_service/local/localTrainingService.ts b/ts/nni_manager/training_service/local/localTrainingService.ts index 6e29b0b8eb..58c5d7f764 100644 --- a/ts/nni_manager/training_service/local/localTrainingService.ts +++ b/ts/nni_manager/training_service/local/localTrainingService.ts @@ -82,8 +82,7 @@ export class LocalConfig { public maxTrialNumPerGpu?: number; public gpuIndices?: string; public useActiveGpu?: boolean; - public reuse?: boolean; - constructor(gpuIndices?: string, maxTrialNumPerGpu?: number, useActiveGpu?: boolean, reuse?: boolean) { + constructor(gpuIndices?: string, maxTrialNumPerGpu?: number, useActiveGpu?: boolean) { if (gpuIndices !== undefined) { this.gpuIndices = gpuIndices; } @@ -93,7 +92,6 @@ export class LocalConfig { if (useActiveGpu !== undefined) { this.useActiveGpu = useActiveGpu; } - this.reuse = reuse; } } diff --git a/ts/nni_manager/training_service/reusable/channels/heterogeneousCommandChannel.ts b/ts/nni_manager/training_service/reusable/channels/heterogeneousCommandChannel.ts new file mode 100644 index 0000000000..f890d48c23 --- /dev/null +++ b/ts/nni_manager/training_service/reusable/channels/heterogeneousCommandChannel.ts @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +'use strict'; + +import { EventEmitter } from "events"; +import { delay } from "../../../common/utils"; +import { CommandChannel, RunnerConnection } from "../commandChannel"; +import { Channel, EnvironmentInformation } from "../environment"; +import { AMLCommandChannel } from "./amlCommandChannel"; +import { WebCommandChannel, WebRunnerConnection } from "./webCommandChannel"; + + +export class HeterogenousCommandChannel extends CommandChannel{ + private stopping: boolean = false; + private amlCommandChannel: AMLCommandChannel | undefined; + private webCommandChannel: WebCommandChannel | undefined; + + public get channelName(): Channel { + return "web"; + } + + public constructor(commandEmitter: EventEmitter, platformsArray: string[]) { + super(commandEmitter); + console.log(platformsArray.includes('local')) + if (platformsArray.includes('local') || + platformsArray.includes('remote') || + platformsArray.includes('pai')) { + this.webCommandChannel = new WebCommandChannel(commandEmitter); + } + if (platformsArray.includes('aml')) { + this.amlCommandChannel = new AMLCommandChannel(commandEmitter); + } + } + + public async config(_key: string, _value: any): Promise { + // do nothing + } + + public async start(): Promise { + const tasks: Promise[] = []; + if (this.amlCommandChannel) { + tasks.push(this.amlCommandChannel.start()); + } + if (this.webCommandChannel) { + tasks.push(this.webCommandChannel.start()); + } + await Promise.all(tasks); + } + + public async stop(): Promise { + this.stopping = true; + } + + public async open(environment: EnvironmentInformation): Promise { + const tasks: Promise[] = []; + if (this.amlCommandChannel) { + tasks.push(this.amlCommandChannel.open(environment)); + } + if (this.webCommandChannel) { + tasks.push(this.webCommandChannel.open(environment)); + } + await Promise.all(tasks); + } + + public async close(environment: EnvironmentInformation): Promise { + const tasks: Promise[] = []; + if (this.amlCommandChannel) { + tasks.push(this.amlCommandChannel.close(environment)); + } + if (this.webCommandChannel) { + tasks.push(this.webCommandChannel.close(environment)); + } + await Promise.all(tasks); + } + + public async run(): Promise { + const tasks: Promise[] = []; + if (this.amlCommandChannel) { + tasks.push(this.amlCommandChannel.run()); + } + if (this.webCommandChannel) { + tasks.push(this.webCommandChannel.run()); + } + await Promise.all(tasks); + } + + protected async sendCommandInternal(environment: EnvironmentInformation, message: string): Promise { + switch (environment.platform) { + case 'aml': + if (this.amlCommandChannel === undefined) { + throw new Error(`amlCommandChannel not initialezed!`); + } + await this.amlCommandChannel.sendCommandInternal(environment, message); + break; + case 'local': + case 'pai': + case 'remote': + if (this.webCommandChannel === undefined) { + throw new Error(`webCommandChannel not initialezed!`); + } + await this.webCommandChannel.sendCommandInternal(environment, message); + break; + default: + throw new Error(`Heterogenous not support platform: '${environment.platform}'`); + } + } + + protected createRunnerConnection(environment: EnvironmentInformation): RunnerConnection { + if (this.webCommandChannel) { + return this.webCommandChannel.createRunnerConnection(environment); + } + return new WebRunnerConnection(environment); + } +} diff --git a/ts/nni_manager/training_service/reusable/channels/heterogenousCommandChannel.ts b/ts/nni_manager/training_service/reusable/channels/heterogenousCommandChannel.ts deleted file mode 100644 index 399158484d..0000000000 --- a/ts/nni_manager/training_service/reusable/channels/heterogenousCommandChannel.ts +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -'use strict'; - -import { EventEmitter } from "events"; -import { delay } from "../../../common/utils"; -import { CommandChannel, RunnerConnection } from "../commandChannel"; -import { Channel, EnvironmentInformation } from "../environment"; -import { AMLCommandChannel } from "./amlCommandChannel"; -import { WebCommandChannel } from "./webCommandChannel"; - -export class HeterogenousCommandChannel extends CommandChannel{ - private stopping: boolean = false; - private amlCommandChannel: AMLCommandChannel | undefined; - private webCommandChannel: WebCommandChannel | undefined; - - public get channelName(): Channel { - return "heterogenous"; - } - - public constructor(commandEmitter: EventEmitter) { - super(commandEmitter); - } - - public async config(_key: string, _value: any): Promise { - // do nothing - } - - public async start(): Promise { - if (this.amlCommandChannel) { - this.amlCommandChannel.start(); - } - if (this.webCommandChannel) { - this.webCommandChannel.start(); - } - } - - public async stop(): Promise { - this.stopping = true; - } - - public async run(): Promise { - if (this.amlCommandChannel) { - this.amlCommandChannel.run(); - } - if (this.webCommandChannel) { - this.webCommandChannel.run(); - } - } - - protected async sendCommandInternal(environment: EnvironmentInformation, message: string): Promise { - switch (environment.platform) { - case 'aml': - if (this.amlCommandChannel === undefined) { - throw new Error(`amlCommandChannel not initialezed!`); - } - this.amlCommandChannel.sendCommandInternal(environment, message); - break; - case 'local': - case 'pai': - case 'remote': - if (this.webCommandChannel === undefined) { - throw new Error(`webCommandChannel not initialezed!`); - } - this.webCommandChannel.sendCommandInternal(environment, message); - break; - default: - throw new Error(`Heterogenous not support platform: '${environment.platform}'`); - } - } - - protected createRunnerConnection(environment: EnvironmentInformation): RunnerConnection { - return new RunnerConnection(environment); - } -} diff --git a/ts/nni_manager/training_service/reusable/channels/webCommandChannel.ts b/ts/nni_manager/training_service/reusable/channels/webCommandChannel.ts index 3fab37f491..f9db475be4 100644 --- a/ts/nni_manager/training_service/reusable/channels/webCommandChannel.ts +++ b/ts/nni_manager/training_service/reusable/channels/webCommandChannel.ts @@ -9,7 +9,7 @@ import { INITIALIZED } from '../../../core/commands'; import { CommandChannel, RunnerConnection } from "../commandChannel"; import { Channel, EnvironmentInformation } from "../environment"; -class WebRunnerConnection extends RunnerConnection { +export class WebRunnerConnection extends RunnerConnection { public readonly clients: WebSocket[] = []; public async close(): Promise { @@ -46,11 +46,10 @@ export class WebCommandChannel extends CommandChannel { this.webSocketServer = new SocketServer({ port }); this.webSocketServer.on('connection', (client: WebSocket) => { - this.log.debug(`WebCommandChannel: received connection`); + this.log.info(`WebCommandChannel: received connection`); client.onerror = (event): void => { this.log.error(`error on client ${JSON.stringify(event)}`); } - this.clients.set(client, undefined); client.onmessage = (message): void => { this.receivedWebSocketMessage(client, message); @@ -84,20 +83,20 @@ export class WebCommandChannel extends CommandChannel { } } - protected createRunnerConnection(environment: EnvironmentInformation): RunnerConnection { + public createRunnerConnection(environment: EnvironmentInformation): RunnerConnection { return new WebRunnerConnection(environment); } private receivedWebSocketMessage(client: WebSocket, message: MessageEvent): void { let connection = this.clients.get(client) as WebRunnerConnection | undefined; const rawCommands = message.data.toString(); - if (connection === undefined) { // undefined means it's expecting initializing message. const commands = this.parseCommands(rawCommands); let isValid = false; - this.log.debug(`WebCommandChannel: received initialize message: ${JSON.stringify(rawCommands)}`); - + this.log.info(`WebCommandChannel: received initialize message: ${JSON.stringify(rawCommands)}`); + + if (commands.length > 0) { const commandType = commands[0][0]; const result = commands[0][1]; diff --git a/ts/nni_manager/training_service/reusable/environments/heterogenousEnvironmentService.ts b/ts/nni_manager/training_service/reusable/environments/heterogeneousEnvironmentService.ts similarity index 71% rename from ts/nni_manager/training_service/reusable/environments/heterogenousEnvironmentService.ts rename to ts/nni_manager/training_service/reusable/environments/heterogeneousEnvironmentService.ts index e9bfec1790..a644626db6 100644 --- a/ts/nni_manager/training_service/reusable/environments/heterogenousEnvironmentService.ts +++ b/ts/nni_manager/training_service/reusable/environments/heterogeneousEnvironmentService.ts @@ -9,7 +9,7 @@ import * as path from 'path'; import * as component from '../../../common/component'; import { getLogger, Logger } from '../../../common/log'; import { TrialConfigMetadataKey } from '../../common/trialConfigMetadataKey'; -import { HeterogenousCommandChannel } from '../channels/heterogenousCommandChannel'; +import { HeterogenousCommandChannel } from '../channels/heterogeneousCommandChannel'; import { CommandChannel } from "../commandChannel"; import { EnvironmentInformation, EnvironmentService } from '../environment'; import { AMLEnvironmentService } from './amlEnvironmentService'; @@ -18,13 +18,14 @@ import { LocalEnvironmentService } from './localEnvironmentService'; import { OpenPaiEnvironmentService } from './openPaiEnvironmentService'; import { randomSelect } from '../../../common/utils'; import { HeterogenousConfig } from '../heterogenous/heterogenousConfig'; +import { WebCommandChannel } from '../channels/webCommandChannel'; /** * Collector PAI jobs info from PAI cluster, and update pai job status locally */ @component.Singleton -export class HeteroGenousEnvironmentService extends EnvironmentService { +export class HeteroGeneousEnvironmentService extends EnvironmentService { private amlEnvironmentService: AMLEnvironmentService; private remoteEnvironmentService: RemoteEnvironmentService; @@ -47,31 +48,35 @@ export class HeteroGenousEnvironmentService extends EnvironmentService { } public createCommandChannel(commandEmitter: EventEmitter): CommandChannel { - return new HeterogenousCommandChannel(commandEmitter); + if (this.heterogenousConfig === undefined) { + throw new Error('heterogenousConfig not initialized!'); + } + return new HeterogenousCommandChannel(commandEmitter, this.heterogenousConfig.trainingServicePlatforms); } public async config(key: string, value: string): Promise { switch (key) { case TrialConfigMetadataKey.AML_CLUSTER_CONFIG: - this.amlEnvironmentService.config(key, value); + await this.amlEnvironmentService.config(key, value); break; case TrialConfigMetadataKey.MACHINE_LIST: - this.remoteEnvironmentService.config(key, value); + await this.remoteEnvironmentService.config(key, value); break; - case TrialConfigMetadataKey.TRIAL_CONFIG: - this.amlEnvironmentService.config(key, value); - this.remoteEnvironmentService.config(key, value); - this.paiEnvironmentService.config(key, value); - this.localEnvironmentService.config(key, value); + case TrialConfigMetadataKey.TRIAL_CONFIG: + await this.amlEnvironmentService.config(key, value); + await this.remoteEnvironmentService.config(key, value); + await this.paiEnvironmentService.config(key, value); + await this.localEnvironmentService.config(key, value); break; case TrialConfigMetadataKey.PAI_CLUSTER_CONFIG: - this.paiEnvironmentService.config(key, value); + await this.paiEnvironmentService.config(key, value); break; case TrialConfigMetadataKey.LOCAL_CONFIG: - this.localEnvironmentService.config(key, value); + await this.localEnvironmentService.config(key, value); break; - case TrialConfigMetadataKey.HETEROGENOUS_CONFIG: + case TrialConfigMetadataKey.HETEROGENEOUS_CONFIG: this.heterogenousConfig = JSON.parse(value); + break; default: this.log.debug(`Heterogenous not support metadata key: '${key}', value: '${value}'`); } @@ -102,24 +107,25 @@ export class HeteroGenousEnvironmentService extends EnvironmentService { if (this.heterogenousConfig === undefined) { throw new Error('heterogenousConfig not initialized!'); } - this.heterogenousConfig.trainingServicePlatforms; + console.log('traningServicePlatforms: ') + console.log(this.heterogenousConfig.trainingServicePlatforms) const platform = randomSelect(this.heterogenousConfig.trainingServicePlatforms); switch (platform) { case 'aml': environment.platform = 'aml'; - this.amlEnvironmentService.startEnvironment(environment); + await this.amlEnvironmentService.startEnvironment(environment); break; case 'remote': environment.platform = 'remote'; - this.remoteEnvironmentService.startEnvironment(environment); + await this.remoteEnvironmentService.startEnvironment(environment); break; case 'local': environment.platform = 'local'; - this.localEnvironmentService.startEnvironment(environment); + await this.localEnvironmentService.startEnvironment(environment); break; case 'pai': environment.platform = 'pai'; - this.paiEnvironmentService.startEnvironment(environment); + await this.paiEnvironmentService.startEnvironment(environment); break; } } @@ -127,16 +133,16 @@ export class HeteroGenousEnvironmentService extends EnvironmentService { public async stopEnvironment(environment: EnvironmentInformation): Promise { switch (environment.platform) { case 'aml': - this.amlEnvironmentService.stopEnvironment(environment); + await this.amlEnvironmentService.stopEnvironment(environment); break; case 'remote': - this.remoteEnvironmentService.stopEnvironment(environment); + await this.remoteEnvironmentService.stopEnvironment(environment); break; case 'local': - this.localEnvironmentService.stopEnvironment(environment); + await this.localEnvironmentService.stopEnvironment(environment); break; case 'pai': - this.paiEnvironmentService.stopEnvironment(environment); + await this.paiEnvironmentService.stopEnvironment(environment); break; default: throw new Error(`Heterogenous not support platform '${environment.platform}'`); diff --git a/ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts b/ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts index e4f026ad2b..007c87a9cb 100644 --- a/ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts +++ b/ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts @@ -22,7 +22,6 @@ import { EnvironmentInformation, EnvironmentService } from '../environment'; import { StorageService } from '../storageService'; import { TrialConfig } from '../../common/trialConfig'; import { getExperimentRootDir, isAlive } from '../../../common/utils'; -import { LocalConfig } from '../../local/localTrainingService'; import { execMkdir, validateCodeDir, runScript, fileExist, execCopydir } from '../../common/util'; import { FileCommandChannel } from '../channels/fileCommandChannel'; import { CommandChannel } from "../commandChannel"; @@ -33,7 +32,6 @@ export class LocalEnvironmentService extends EnvironmentService { private readonly log: Logger = getLogger(); private localTrialConfig: TrialConfig | undefined; - private localConfig: LocalConfig | undefined; private experimentRootDir: string; private experimentId: string; @@ -53,9 +51,6 @@ export class LocalEnvironmentService extends EnvironmentService { public async config(key: string, value: string): Promise { switch (key) { - case TrialConfigMetadataKey.LOCAL_CONFIG: - this.localConfig = JSON.parse(value); - break; case TrialConfigMetadataKey.TRIAL_CONFIG: this.localTrialConfig = JSON.parse(value); break; @@ -75,6 +70,7 @@ export class LocalEnvironmentService extends EnvironmentService { public async refreshEnvironment(environment: EnvironmentInformation): Promise { const jobpidPath: string = `${environment.runnerWorkingFolder}/pid`; const runnerReturnCodeFilePath: string = `${environment.runnerWorkingFolder}/code`; + console.log(jobpidPath) /* eslint-disable require-atomic-updates */ try { // check if pid file exist @@ -111,9 +107,6 @@ export class LocalEnvironmentService extends EnvironmentService { if (this.localTrialConfig === undefined) { throw new Error('Local trial config is not initialized'); } - if (this.localConfig === undefined) { - throw new Error('Local config is not initialized'); - } // Need refactor, this temp folder path is not appropriate, there are two expId in this path const localTempFolder: string = path.join(this.experimentRootDir, this.experimentId, "environment-temp", "envs"); diff --git a/ts/nni_manager/training_service/reusable/routerTrainingService.ts b/ts/nni_manager/training_service/reusable/routerTrainingService.ts index 47c49a9a7f..7246645dbc 100644 --- a/ts/nni_manager/training_service/reusable/routerTrainingService.ts +++ b/ts/nni_manager/training_service/reusable/routerTrainingService.ts @@ -19,7 +19,7 @@ import { LocalEnvironmentService } from './environments/localEnvironmentService' import { AMLEnvironmentService } from './environments/amlEnvironmentService'; import { RemoteEnvironmentService } from './environments/remoteEnvironmentService'; import { MountedStorageService } from './storages/mountedStorageService'; -import { HeteroGenousEnvironmentService } from './environments/heterogenousEnvironmentService'; +import { HeteroGeneousEnvironmentService } from './environments/heterogeneousEnvironmentService'; import { StorageService } from './storageService'; import { TrialDispatcher } from './trialDispatcher'; import { RemoteConfig } from './remote/remoteConfig'; @@ -103,18 +103,13 @@ class RouterTrainingService implements TrainingService { public async setClusterMetadata(key: string, value: string): Promise { if (this.internalTrainingService === undefined) { if (key === TrialConfigMetadataKey.LOCAL_CONFIG) { - const config = JSON.parse(value); - if (config.reuse === true) { - this.log.info(`reuse flag enabled, use EnvironmentManager.`); - this.internalTrainingService = component.get(TrialDispatcher); + this.log.info(`reuse flag enabled, use EnvironmentManager.`); + this.internalTrainingService = component.get(TrialDispatcher); - // TODO to support other serivces later. - Container.bind(EnvironmentService) - .to(LocalEnvironmentService) - .scope(Scope.Singleton); - } else { - this.internalTrainingService = component.get(LocalTrainingService); - } + // TODO to support other serivces later. + Container.bind(EnvironmentService) + .to(LocalEnvironmentService) + .scope(Scope.Singleton); if (this.internalTrainingService === undefined) { throw new Error("TrainingService is not assigned!"); } @@ -182,10 +177,10 @@ class RouterTrainingService implements TrainingService { this.log.debug(`caching metadata key:{} value:{}, as training service is not determined.`); this.internalTrainingService = component.get(RemoteMachineTrainingService); } - } else if (key === TrialConfigMetadataKey.HETEROGENOUS_CONFIG){ + } else if (key === TrialConfigMetadataKey.HETEROGENEOUS_CONFIG){ this.internalTrainingService = component.get(TrialDispatcher); Container.bind(EnvironmentService) - .to(HeteroGenousEnvironmentService) + .to(HeteroGeneousEnvironmentService) .scope(Scope.Singleton); if (this.internalTrainingService === undefined) { From d95d17be5a2ecf54780a1ec4211939ea15073173 Mon Sep 17 00:00:00 2001 From: Shinai Yang Date: Mon, 30 Nov 2020 05:58:35 +0800 Subject: [PATCH 06/24] fix comments --- nni/tools/nnictl/launcher.py | 20 ++++++++----------- ts/nni_manager/main.ts | 19 ++++-------------- .../training_service/common/util.ts | 10 ---------- .../heterogeneousEnvironmentService.ts | 10 +++++++++- .../environments/localEnvironmentService.ts | 9 ++++----- 5 files changed, 25 insertions(+), 43 deletions(-) diff --git a/nni/tools/nnictl/launcher.py b/nni/tools/nnictl/launcher.py index d0de8bb573..65aaf1629a 100644 --- a/nni/tools/nnictl/launcher.py +++ b/nni/tools/nnictl/launcher.py @@ -406,19 +406,15 @@ def set_experiment(experiment_config, mode, port, config_file_name): request_data['clusterMetaData'].append( {'key': 'heterogeneous_config', 'value': experiment_config['heterogeneousConfig']}) platform_list = experiment_config['heterogeneousConfig']['trainingServicePlatforms'] + request_dict = { + 'aml': {'key': 'aml_config', 'value': experiment_config.get('amlConfig')}, + 'remote': {'key': 'machine_list', 'value': experiment_config.get('machineList')}, + 'pai': {'key': 'pai_config', 'value': experiment_config.get('paiConfig')}, + 'local': {'key': 'local_config', 'value': experiment_config.get('localConfig')} + } for platform in platform_list: - if platform == 'aml': - request_data['clusterMetaData'].append( - {'key': 'aml_config', 'value': experiment_config['amlConfig']}) - elif platform == 'remote': - request_data['clusterMetaData'].append( - {'key': 'machine_list', 'value': experiment_config['machineList']}) - elif platform == 'local' and experiment_config.get('localConfig'): - request_data['clusterMetaData'].append( - {'key': 'local_config', 'value': experiment_config['localConfig']}) - elif platform == 'pai': - request_data['clusterMetaData'].append( - {'key': 'pai_config', 'value': experiment_config['paiConfig']}) + if request_dict.get(platform): + request_data['clusterMetaData'].append(request_dict[platform]) request_data['clusterMetaData'].append( {'key': 'trial_config', 'value': experiment_config['trial']}) response = rest_post(experiment_url(port), json.dumps(request_data), REST_TIME_OUT, show_error=True) diff --git a/ts/nni_manager/main.ts b/ts/nni_manager/main.ts index 86400b32d7..5b49f551ab 100644 --- a/ts/nni_manager/main.ts +++ b/ts/nni_manager/main.ts @@ -35,17 +35,14 @@ function initStartupInfo( } async function initContainer(foreground: boolean, platformMode: string, logFileName?: string): Promise { - if (platformMode === 'local') { - Container.bind(TrainingService) - .to(LocalTrainingService) - .scope(Scope.Singleton); - } else if (platformMode === 'remote') { + const routerPlatformMode = ['remote', 'pai', 'aml', 'heterogeneous']; + if (routerPlatformMode.includes(platformMode)) { Container.bind(TrainingService) .to(RouterTrainingService) .scope(Scope.Singleton); - } else if (platformMode === 'pai') { + } else if (platformMode === 'local') { Container.bind(TrainingService) - .to(RouterTrainingService) + .to(LocalTrainingService) .scope(Scope.Singleton); } else if (platformMode === 'paiYarn') { Container.bind(TrainingService) @@ -63,14 +60,6 @@ async function initContainer(foreground: boolean, platformMode: string, logFileN Container.bind(TrainingService) .to(DLTSTrainingService) .scope(Scope.Singleton); - } else if (platformMode === 'aml') { - Container.bind(TrainingService) - .to(RouterTrainingService) - .scope(Scope.Singleton); - } else if (platformMode === 'heterogeneous') { - Container.bind(TrainingService) - .to(RouterTrainingService) - .scope(Scope.Singleton); } else { throw new Error(`Error: unsupported mode: ${platformMode}`); } diff --git a/ts/nni_manager/training_service/common/util.ts b/ts/nni_manager/training_service/common/util.ts index a1c6f7044f..791d7dcebb 100644 --- a/ts/nni_manager/training_service/common/util.ts +++ b/ts/nni_manager/training_service/common/util.ts @@ -95,16 +95,6 @@ export async function execMkdir(directory: string, share: boolean = false): Prom return Promise.resolve(); } -export async function fileExist(filePath: string): Promise { - let cmdresult: cpp.childProcessPromise.Result; - if (process.platform === 'win32') { - cmdresult = await cpp.exec(`powershell.exe Get-Content "${filePath}" -Tail 1`); - } else { - cmdresult = await cpp.exec(`test -e ${filePath} && echo True || echo False`); - } - return cmdresult.stdout !== undefined && cmdresult.stdout.trim() === 'True' -} - /** * copy files to the directory * @param source diff --git a/ts/nni_manager/training_service/reusable/environments/heterogeneousEnvironmentService.ts b/ts/nni_manager/training_service/reusable/environments/heterogeneousEnvironmentService.ts index a644626db6..364dbb9617 100644 --- a/ts/nni_manager/training_service/reusable/environments/heterogeneousEnvironmentService.ts +++ b/ts/nni_manager/training_service/reusable/environments/heterogeneousEnvironmentService.ts @@ -84,6 +84,7 @@ export class HeteroGeneousEnvironmentService extends EnvironmentService { public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise { const tasks: Promise[] = []; + const openPaiEnvironments: EnvironmentInformation[] = []; environments.forEach(async (environment) => { switch (environment.platform) { case 'aml': @@ -95,11 +96,18 @@ export class HeteroGeneousEnvironmentService extends EnvironmentService { case 'local': tasks.push(this.localEnvironmentService.refreshEnvironment(environment)); break; - // TODO: refresh pai + case 'pai': + openPaiEnvironments.push(environment); + break; default: throw new Error(`Heterogenous not support platform: '${environment.platform}'`); } }); + // OpenPai only support refreshEnvironmentsStatus + if (openPaiEnvironments.length) { + tasks.push(this.paiEnvironmentService.refreshEnvironmentsStatus(openPaiEnvironments)); + } + await Promise.all(tasks); } diff --git a/ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts b/ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts index 007c87a9cb..cc35e0b640 100644 --- a/ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts +++ b/ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts @@ -22,7 +22,7 @@ import { EnvironmentInformation, EnvironmentService } from '../environment'; import { StorageService } from '../storageService'; import { TrialConfig } from '../../common/trialConfig'; import { getExperimentRootDir, isAlive } from '../../../common/utils'; -import { execMkdir, validateCodeDir, runScript, fileExist, execCopydir } from '../../common/util'; +import { execMkdir, validateCodeDir, runScript, execCopydir } from '../../common/util'; import { FileCommandChannel } from '../channels/fileCommandChannel'; import { CommandChannel } from "../commandChannel"; @@ -55,13 +55,13 @@ export class LocalEnvironmentService extends EnvironmentService { this.localTrialConfig = JSON.parse(value); break; default: - this.log.debug(`OpenPAI not proccessed metadata key: '${key}', value: '${value}'`); + this.log.debug(`Local mode does not proccess metadata key: '${key}', value: '${value}'`); } } public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise { const tasks: Promise[] = []; - environments.forEach(async (environment) => { + environments.forEach((environment) => { tasks.push(this.refreshEnvironment(environment)); }); await Promise.all(tasks); @@ -70,11 +70,10 @@ export class LocalEnvironmentService extends EnvironmentService { public async refreshEnvironment(environment: EnvironmentInformation): Promise { const jobpidPath: string = `${environment.runnerWorkingFolder}/pid`; const runnerReturnCodeFilePath: string = `${environment.runnerWorkingFolder}/code`; - console.log(jobpidPath) /* eslint-disable require-atomic-updates */ try { // check if pid file exist - const pidExist = await fileExist(jobpidPath); + const pidExist = await fs.existsSync(jobpidPath); if (!pidExist) { return; } From 4ead8a09181734647d4b052a156385b0a9812430 Mon Sep 17 00:00:00 2001 From: Shinai Yang Date: Mon, 30 Nov 2020 06:15:19 +0800 Subject: [PATCH 07/24] fix build --- ts/nni_manager/main.ts | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/ts/nni_manager/main.ts b/ts/nni_manager/main.ts index 84e8204d13..d55a6a5b4a 100644 --- a/ts/nni_manager/main.ts +++ b/ts/nni_manager/main.ts @@ -93,11 +93,7 @@ async function initContainer(foreground: boolean, platformMode: string, logFileN function usage(): void { console.info('usage: node main.js --port --mode \ -<<<<<<< HEAD - --start_mode --experiment_id --foreground '); -======= - --start_mode --experiment_id --foreground '); ->>>>>>> 765bc335375cb3d417d5a287a67c48fe8bef010b + --start_mode --experiment_id --foreground '); } const strPort: string = parseArg(['--port', '-p']); @@ -117,11 +113,7 @@ const foreground: boolean = foregroundArg.toLowerCase() === 'true' ? true : fals const port: number = parseInt(strPort, 10); const mode: string = parseArg(['--mode', '-m']); -<<<<<<< HEAD -if (!['local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml', 'heterogeneous'].includes(mode)) { -======= -if (!['adl', 'local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml'].includes(mode)) { ->>>>>>> 765bc335375cb3d417d5a287a67c48fe8bef010b +if (!['local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml', 'adl', 'heterogeneous'].includes(mode)) { console.log(`FATAL: unknown mode: ${mode}`); usage(); process.exit(1); From 6b42a4d195453da45a6181642a8f7062b60a52b8 Mon Sep 17 00:00:00 2001 From: Shinai Yang Date: Mon, 30 Nov 2020 06:27:21 +0800 Subject: [PATCH 08/24] fix comments --- nni/tools/trial_tool/trial_runner.py | 1 - .../reusable/environments/localEnvironmentService.ts | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/nni/tools/trial_tool/trial_runner.py b/nni/tools/trial_tool/trial_runner.py index b30a235932..d3c4ca0b22 100644 --- a/nni/tools/trial_tool/trial_runner.py +++ b/nni/tools/trial_tool/trial_runner.py @@ -25,7 +25,6 @@ def main_loop(args): '''main loop logic for trial runner''' idle_last_time = datetime.now() gpu_refresh_last_time = datetime.now() - timedelta(minutes=1) - nni_log(LogType.Info, "--------------main loop-----------28-----------------") try: if args.job_pid_file: with open(args.job_pid_file, 'w') as job_file: diff --git a/ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts b/ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts index cc35e0b640..5851e6d7a6 100644 --- a/ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts +++ b/ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts @@ -42,7 +42,7 @@ export class LocalEnvironmentService extends EnvironmentService { } public get environmentMaintenceLoopInterval(): number { - return 5000; + return 1000; } public get hasStorageService(): boolean { From c764277f1e857f00f8afb8ba027d15e39b16e23e Mon Sep 17 00:00:00 2001 From: Shinai Yang Date: Thu, 3 Dec 2020 11:39:41 +0800 Subject: [PATCH 09/24] fix comments --- .../trials/mnist-tfv1/config_kubeflow.yml | 7 + nni/tools/nnictl/config_schema.py | 19 +- .../common/trialConfigMetadataKey.ts | 5 +- .../remoteMachineTrainingService.ts | 4 + .../channels/heterogeneousCommandChannel.ts | 115 ------ .../training_service/reusable/environment.ts | 23 +- .../environments/amlEnvironmentService.ts | 17 +- .../heterogeneousEnvironmentService.ts | 159 -------- .../environments/localEnvironmentService.ts | 14 +- .../environments/openPaiEnvironmentService.ts | 75 ++-- .../environments/remoteEnvironmentService.ts | 14 +- .../reusable/routerTrainingService.ts | 93 ++--- .../reusable/test/utEnvironmentService.ts | 16 +- .../reusable/trialDispatcher.ts | 363 +++++++++++------- 14 files changed, 373 insertions(+), 551 deletions(-) delete mode 100644 ts/nni_manager/training_service/reusable/channels/heterogeneousCommandChannel.ts delete mode 100644 ts/nni_manager/training_service/reusable/environments/heterogeneousEnvironmentService.ts diff --git a/examples/trials/mnist-tfv1/config_kubeflow.yml b/examples/trials/mnist-tfv1/config_kubeflow.yml index f460b37cb6..d874223a2a 100644 --- a/examples/trials/mnist-tfv1/config_kubeflow.yml +++ b/examples/trials/mnist-tfv1/config_kubeflow.yml @@ -16,6 +16,13 @@ tuner: optimize_mode: maximize trial: codeDir: . + ps: + replicas: 1 + command: python3 mnist.py + gpuNum: 0 + cpuNum: 1 + memoryMB: 8192 + image: msranni/nni:latest worker: replicas: 1 command: python3 mnist.py diff --git a/nni/tools/nnictl/config_schema.py b/nni/tools/nnictl/config_schema.py index 7401ddfdef..ae528ca58a 100644 --- a/nni/tools/nnictl/config_schema.py +++ b/nni/tools/nnictl/config_schema.py @@ -262,6 +262,23 @@ def validate(self, data): } } +heterogeneous_trial_schema = { + 'trial': { + 'codeDir': setPathCheck('codeDir'), + Optional('nniManagerNFSMountPath'): setPathCheck('nniManagerNFSMountPath'), + Optional('containerNFSMountPath'): setType('containerNFSMountPath', str), + Optional('nasMode'): setChoice('nasMode', 'classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode'), + 'command': setType('command', str), + Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999), + Optional('cpuNum'): setNumberRange('cpuNum', int, 0, 99999), + Optional('memoryMB'): setType('memoryMB', int), + Optional('image'): setType('image', str), + Optional('virtualCluster'): setType('virtualCluster', str), + Optional('paiStorageConfigName'): setType('paiStorageConfigName', str), + Optional('paiConfigPath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'paiConfigPath') + } +} + heterogeneous_config_schema = { 'heterogeneousConfig': { 'trainingServicePlatforms': ['local', 'remote', 'pai', 'aml'] @@ -443,7 +460,7 @@ def validate(self, data): 'frameworkcontroller': Schema({**common_schema, **frameworkcontroller_trial_schema, **frameworkcontroller_config_schema}), 'aml': Schema({**common_schema, **aml_trial_schema, **aml_config_schema}), 'dlts': Schema({**common_schema, **dlts_trial_schema, **dlts_config_schema}), - 'heterogeneous': Schema({**common_schema, **common_trial_schema, **heterogeneous_config_schema, **machine_list_schema, + 'heterogeneous': Schema({**common_schema, **heterogeneous_trial_schema, **heterogeneous_config_schema, **machine_list_schema, **pai_config_schema, **aml_config_schema, **remote_config_schema}), } diff --git a/ts/nni_manager/training_service/common/trialConfigMetadataKey.ts b/ts/nni_manager/training_service/common/trialConfigMetadataKey.ts index 38f547a0a9..968a1efa13 100644 --- a/ts/nni_manager/training_service/common/trialConfigMetadataKey.ts +++ b/ts/nni_manager/training_service/common/trialConfigMetadataKey.ts @@ -23,5 +23,8 @@ export enum TrialConfigMetadataKey { DLTS_CLUSTER_CONFIG = 'dlts_config', AML_CLUSTER_CONFIG = 'aml_config', VERSION_CHECK = 'version_check', - LOG_COLLECTION = 'log_collection' + LOG_COLLECTION = 'log_collection', + // Used to set platform for heterogeneous in reuse mode, + // temproarily change and will refactor config schema in the future + PLATFORM_LIST = 'platform_list' } diff --git a/ts/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts b/ts/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts index 6dbb084ac7..174c6b26c2 100644 --- a/ts/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts +++ b/ts/nni_manager/training_service/remote_machine/remoteMachineTrainingService.ts @@ -358,6 +358,10 @@ class RemoteMachineTrainingService implements TrainingService { case TrialConfigMetadataKey.LOG_COLLECTION: this.logCollection = value; break; + case TrialConfigMetadataKey.REMOTE_CONFIG: + // Add remote_config in remoteEnvironmentService to set reuse mode, + // this config need to be catched here, otherwise will throw Unknown key exception here + break; default: //Reject for unknown keys throw new Error(`Uknown key: ${key}`); diff --git a/ts/nni_manager/training_service/reusable/channels/heterogeneousCommandChannel.ts b/ts/nni_manager/training_service/reusable/channels/heterogeneousCommandChannel.ts deleted file mode 100644 index f890d48c23..0000000000 --- a/ts/nni_manager/training_service/reusable/channels/heterogeneousCommandChannel.ts +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -'use strict'; - -import { EventEmitter } from "events"; -import { delay } from "../../../common/utils"; -import { CommandChannel, RunnerConnection } from "../commandChannel"; -import { Channel, EnvironmentInformation } from "../environment"; -import { AMLCommandChannel } from "./amlCommandChannel"; -import { WebCommandChannel, WebRunnerConnection } from "./webCommandChannel"; - - -export class HeterogenousCommandChannel extends CommandChannel{ - private stopping: boolean = false; - private amlCommandChannel: AMLCommandChannel | undefined; - private webCommandChannel: WebCommandChannel | undefined; - - public get channelName(): Channel { - return "web"; - } - - public constructor(commandEmitter: EventEmitter, platformsArray: string[]) { - super(commandEmitter); - console.log(platformsArray.includes('local')) - if (platformsArray.includes('local') || - platformsArray.includes('remote') || - platformsArray.includes('pai')) { - this.webCommandChannel = new WebCommandChannel(commandEmitter); - } - if (platformsArray.includes('aml')) { - this.amlCommandChannel = new AMLCommandChannel(commandEmitter); - } - } - - public async config(_key: string, _value: any): Promise { - // do nothing - } - - public async start(): Promise { - const tasks: Promise[] = []; - if (this.amlCommandChannel) { - tasks.push(this.amlCommandChannel.start()); - } - if (this.webCommandChannel) { - tasks.push(this.webCommandChannel.start()); - } - await Promise.all(tasks); - } - - public async stop(): Promise { - this.stopping = true; - } - - public async open(environment: EnvironmentInformation): Promise { - const tasks: Promise[] = []; - if (this.amlCommandChannel) { - tasks.push(this.amlCommandChannel.open(environment)); - } - if (this.webCommandChannel) { - tasks.push(this.webCommandChannel.open(environment)); - } - await Promise.all(tasks); - } - - public async close(environment: EnvironmentInformation): Promise { - const tasks: Promise[] = []; - if (this.amlCommandChannel) { - tasks.push(this.amlCommandChannel.close(environment)); - } - if (this.webCommandChannel) { - tasks.push(this.webCommandChannel.close(environment)); - } - await Promise.all(tasks); - } - - public async run(): Promise { - const tasks: Promise[] = []; - if (this.amlCommandChannel) { - tasks.push(this.amlCommandChannel.run()); - } - if (this.webCommandChannel) { - tasks.push(this.webCommandChannel.run()); - } - await Promise.all(tasks); - } - - protected async sendCommandInternal(environment: EnvironmentInformation, message: string): Promise { - switch (environment.platform) { - case 'aml': - if (this.amlCommandChannel === undefined) { - throw new Error(`amlCommandChannel not initialezed!`); - } - await this.amlCommandChannel.sendCommandInternal(environment, message); - break; - case 'local': - case 'pai': - case 'remote': - if (this.webCommandChannel === undefined) { - throw new Error(`webCommandChannel not initialezed!`); - } - await this.webCommandChannel.sendCommandInternal(environment, message); - break; - default: - throw new Error(`Heterogenous not support platform: '${environment.platform}'`); - } - } - - protected createRunnerConnection(environment: EnvironmentInformation): RunnerConnection { - if (this.webCommandChannel) { - return this.webCommandChannel.createRunnerConnection(environment); - } - return new WebRunnerConnection(environment); - } -} diff --git a/ts/nni_manager/training_service/reusable/environment.ts b/ts/nni_manager/training_service/reusable/environment.ts index fdbb9fdb56..032a896f69 100644 --- a/ts/nni_manager/training_service/reusable/environment.ts +++ b/ts/nni_manager/training_service/reusable/environment.ts @@ -4,6 +4,7 @@ 'use strict'; import { EventEmitter } from "events"; +import { runScript } from "training_service/common/util"; import { getLogger, Logger } from "../../common/log"; import { TrialJobStatus } from "../../common/trainingService"; import { GPUInfo } from "../../training_service/common/gpuData"; @@ -12,7 +13,7 @@ import { CommandChannel } from "./commandChannel"; export type EnvironmentStatus = 'UNKNOWN' | 'WAITING' | 'RUNNING' | 'SUCCEEDED' | 'FAILED' | 'USER_CANCELED'; -export type Channel = "web" | "file" | "aml" | "ut" | "heterogenous"; +export type Channel = "web" | "file" | "aml" | "ut"; export class TrialGpuSummary { @@ -74,9 +75,8 @@ export class EnvironmentInformation { // user can specify how to use GPU resource for an environment, like local and remote. public maxTrialNumberPerGpu?: number; public useActiveGpu?: boolean; - - // the running mode for trial jobs, including local, remote, aml, pai etc. - public platform: string = ""; + + public environmentService?: EnvironmentService; constructor(id: string, name: string, envId?: string) { this.log = getLogger(); @@ -124,10 +124,11 @@ export class EnvironmentInformation { } export abstract class EnvironmentService { - + + protected commandChannel: CommandChannel | undefined; public abstract get hasStorageService(): boolean; public abstract config(key: string, value: string): Promise; - public abstract refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise; + public abstract refreshEnvironmentStatus(environment: EnvironmentInformation): Promise; public abstract stopEnvironment(environment: EnvironmentInformation): Promise; public abstract startEnvironment(environment: EnvironmentInformation): Promise; @@ -137,6 +138,12 @@ export abstract class EnvironmentService { return 0; } + public abstract get getPlatform(): string; + + public get getCommandChanneName(): Channel { + return 'web'; + } + // It depends on environment pressure and settings // for example, OpenPAI relies on API calls, and there is an limitation for frequence, so it need to be bigger. public get environmentMaintenceLoopInterval(): number { @@ -150,10 +157,6 @@ export abstract class EnvironmentService { return true; } - public createCommandChannel(commandEmitter: EventEmitter): CommandChannel { - return new WebCommandChannel(commandEmitter); - } - public createEnvironmentInformation(envId: string, envName: string): EnvironmentInformation { return new EnvironmentInformation(envId, envName); } diff --git a/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts b/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts index 8b0a8996d7..b267303ca9 100644 --- a/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts +++ b/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts @@ -42,13 +42,18 @@ export class AMLEnvironmentService extends EnvironmentService { } public createCommandChannel(commandEmitter: EventEmitter): CommandChannel { - return new AMLCommandChannel(commandEmitter); + this.commandChannel = new AMLCommandChannel(commandEmitter); + return this.commandChannel; } public createEnvironmentInformation(envId: string, envName: string): EnvironmentInformation { return new AMLEnvironmentInformation(envId, envName); } + public get getPlatform(): string { + return 'aml'; + } + public async config(key: string, value: string): Promise { switch (key) { case TrialConfigMetadataKey.AML_CLUSTER_CONFIG: @@ -70,7 +75,7 @@ export class AMLEnvironmentService extends EnvironmentService { } } - public async refreshEnvironment(environment: EnvironmentInformation): Promise { + public async refreshEnvironmentStatus(environment: EnvironmentInformation): Promise { const amlClient = (environment as AMLEnvironmentInformation).amlClient; if (!amlClient) { return Promise.reject('AML client not initialized!'); @@ -100,14 +105,6 @@ export class AMLEnvironmentService extends EnvironmentService { } } - public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise { - const tasks: Promise[] = []; - environments.forEach(async (environment) => { - tasks.push(this.refreshEnvironment(environment)); - }); - await Promise.all(tasks); - } - public async startEnvironment(environment: EnvironmentInformation): Promise { if (this.amlClusterConfig === undefined) { throw new Error('AML Cluster config is not initialized'); diff --git a/ts/nni_manager/training_service/reusable/environments/heterogeneousEnvironmentService.ts b/ts/nni_manager/training_service/reusable/environments/heterogeneousEnvironmentService.ts deleted file mode 100644 index 364dbb9617..0000000000 --- a/ts/nni_manager/training_service/reusable/environments/heterogeneousEnvironmentService.ts +++ /dev/null @@ -1,159 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -'use strict'; - -import { EventEmitter } from "events"; -import * as fs from 'fs'; -import * as path from 'path'; -import * as component from '../../../common/component'; -import { getLogger, Logger } from '../../../common/log'; -import { TrialConfigMetadataKey } from '../../common/trialConfigMetadataKey'; -import { HeterogenousCommandChannel } from '../channels/heterogeneousCommandChannel'; -import { CommandChannel } from "../commandChannel"; -import { EnvironmentInformation, EnvironmentService } from '../environment'; -import { AMLEnvironmentService } from './amlEnvironmentService'; -import { RemoteEnvironmentService } from './remoteEnvironmentService'; -import { LocalEnvironmentService } from './localEnvironmentService'; -import { OpenPaiEnvironmentService } from './openPaiEnvironmentService'; -import { randomSelect } from '../../../common/utils'; -import { HeterogenousConfig } from '../heterogenous/heterogenousConfig'; -import { WebCommandChannel } from '../channels/webCommandChannel'; - - -/** - * Collector PAI jobs info from PAI cluster, and update pai job status locally - */ -@component.Singleton -export class HeteroGeneousEnvironmentService extends EnvironmentService { - - private amlEnvironmentService: AMLEnvironmentService; - private remoteEnvironmentService: RemoteEnvironmentService; - private localEnvironmentService: LocalEnvironmentService; - private paiEnvironmentService: OpenPaiEnvironmentService; - private heterogenousConfig?: HeterogenousConfig; - - private readonly log: Logger = getLogger(); - - constructor() { - super(); - this.amlEnvironmentService = new AMLEnvironmentService(); - this.remoteEnvironmentService = new RemoteEnvironmentService(); - this.localEnvironmentService = new LocalEnvironmentService(); - this.paiEnvironmentService = new OpenPaiEnvironmentService(); - } - - public get hasStorageService(): boolean { - return false; - } - - public createCommandChannel(commandEmitter: EventEmitter): CommandChannel { - if (this.heterogenousConfig === undefined) { - throw new Error('heterogenousConfig not initialized!'); - } - return new HeterogenousCommandChannel(commandEmitter, this.heterogenousConfig.trainingServicePlatforms); - } - - public async config(key: string, value: string): Promise { - switch (key) { - case TrialConfigMetadataKey.AML_CLUSTER_CONFIG: - await this.amlEnvironmentService.config(key, value); - break; - case TrialConfigMetadataKey.MACHINE_LIST: - await this.remoteEnvironmentService.config(key, value); - break; - case TrialConfigMetadataKey.TRIAL_CONFIG: - await this.amlEnvironmentService.config(key, value); - await this.remoteEnvironmentService.config(key, value); - await this.paiEnvironmentService.config(key, value); - await this.localEnvironmentService.config(key, value); - break; - case TrialConfigMetadataKey.PAI_CLUSTER_CONFIG: - await this.paiEnvironmentService.config(key, value); - break; - case TrialConfigMetadataKey.LOCAL_CONFIG: - await this.localEnvironmentService.config(key, value); - break; - case TrialConfigMetadataKey.HETEROGENEOUS_CONFIG: - this.heterogenousConfig = JSON.parse(value); - break; - default: - this.log.debug(`Heterogenous not support metadata key: '${key}', value: '${value}'`); - } - } - - public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise { - const tasks: Promise[] = []; - const openPaiEnvironments: EnvironmentInformation[] = []; - environments.forEach(async (environment) => { - switch (environment.platform) { - case 'aml': - tasks.push(this.amlEnvironmentService.refreshEnvironment(environment)); - break; - case 'remote': - tasks.push(this.remoteEnvironmentService.refreshEnvironment(environment)); - break; - case 'local': - tasks.push(this.localEnvironmentService.refreshEnvironment(environment)); - break; - case 'pai': - openPaiEnvironments.push(environment); - break; - default: - throw new Error(`Heterogenous not support platform: '${environment.platform}'`); - } - }); - // OpenPai only support refreshEnvironmentsStatus - if (openPaiEnvironments.length) { - tasks.push(this.paiEnvironmentService.refreshEnvironmentsStatus(openPaiEnvironments)); - } - - await Promise.all(tasks); - } - - public async startEnvironment(environment: EnvironmentInformation): Promise { - if (this.heterogenousConfig === undefined) { - throw new Error('heterogenousConfig not initialized!'); - } - console.log('traningServicePlatforms: ') - console.log(this.heterogenousConfig.trainingServicePlatforms) - const platform = randomSelect(this.heterogenousConfig.trainingServicePlatforms); - switch (platform) { - case 'aml': - environment.platform = 'aml'; - await this.amlEnvironmentService.startEnvironment(environment); - break; - case 'remote': - environment.platform = 'remote'; - await this.remoteEnvironmentService.startEnvironment(environment); - break; - case 'local': - environment.platform = 'local'; - await this.localEnvironmentService.startEnvironment(environment); - break; - case 'pai': - environment.platform = 'pai'; - await this.paiEnvironmentService.startEnvironment(environment); - break; - } - } - - public async stopEnvironment(environment: EnvironmentInformation): Promise { - switch (environment.platform) { - case 'aml': - await this.amlEnvironmentService.stopEnvironment(environment); - break; - case 'remote': - await this.remoteEnvironmentService.stopEnvironment(environment); - break; - case 'local': - await this.localEnvironmentService.stopEnvironment(environment); - break; - case 'pai': - await this.paiEnvironmentService.stopEnvironment(environment); - break; - default: - throw new Error(`Heterogenous not support platform '${environment.platform}'`); - } - } -} diff --git a/ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts b/ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts index 5851e6d7a6..40eabe229d 100644 --- a/ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts +++ b/ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts @@ -49,6 +49,10 @@ export class LocalEnvironmentService extends EnvironmentService { return false; } + public get getPlatform(): string { + return 'local'; + } + public async config(key: string, value: string): Promise { switch (key) { case TrialConfigMetadataKey.TRIAL_CONFIG: @@ -59,15 +63,7 @@ export class LocalEnvironmentService extends EnvironmentService { } } - public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise { - const tasks: Promise[] = []; - environments.forEach((environment) => { - tasks.push(this.refreshEnvironment(environment)); - }); - await Promise.all(tasks); - } - - public async refreshEnvironment(environment: EnvironmentInformation): Promise { + public async refreshEnvironmentStatus(environment: EnvironmentInformation): Promise { const jobpidPath: string = `${environment.runnerWorkingFolder}/pid`; const runnerReturnCodeFilePath: string = `${environment.runnerWorkingFolder}/code`; /* eslint-disable require-atomic-updates */ diff --git a/ts/nni_manager/training_service/reusable/environments/openPaiEnvironmentService.ts b/ts/nni_manager/training_service/reusable/environments/openPaiEnvironmentService.ts index 596c81dbe9..b4b071a089 100644 --- a/ts/nni_manager/training_service/reusable/environments/openPaiEnvironmentService.ts +++ b/ts/nni_manager/training_service/reusable/environments/openPaiEnvironmentService.ts @@ -45,6 +45,10 @@ export class OpenPaiEnvironmentService extends EnvironmentService { return true; } + public get getPlatform(): string { + return 'pai'; + } + public async config(key: string, value: string): Promise { switch (key) { case TrialConfigMetadataKey.PAI_CLUSTER_CONFIG: @@ -85,7 +89,7 @@ export class OpenPaiEnvironmentService extends EnvironmentService { } } - public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise { + public async refreshEnvironmentStatus(environment: EnvironmentInformation): Promise { const deferred: Deferred = new Deferred(); if (this.paiClusterConfig === undefined) { @@ -96,7 +100,7 @@ export class OpenPaiEnvironmentService extends EnvironmentService { } const getJobInfoRequest: request.Options = { - uri: `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v2/jobs?username=${this.paiClusterConfig.userName}`, + uri: `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v2/jobs/${this.paiClusterConfig.userName}~${environment.envId}`, method: 'GET', json: true, headers: { @@ -113,47 +117,34 @@ export class OpenPaiEnvironmentService extends EnvironmentService { this.log.error(`${errorMessage}`); deferred.reject(errorMessage); } else { - const jobInfos = new Map(); - body.forEach((jobInfo: any) => { - jobInfos.set(jobInfo.name, jobInfo); - }); - - environments.forEach((environment) => { - if (jobInfos.has(environment.envId)) { - const jobResponse = jobInfos.get(environment.envId); - if (jobResponse && jobResponse.state) { - const oldEnvironmentStatus = environment.status; - switch (jobResponse.state) { - case 'RUNNING': - case 'WAITING': - case 'SUCCEEDED': - environment.setStatus(jobResponse.state); - break; - case 'FAILED': - environment.setStatus(jobResponse.state); - deferred.reject(`OpenPAI: job ${environment.envId} is failed!`); - break; - case 'STOPPED': - case 'STOPPING': - environment.setStatus('USER_CANCELED'); - break; - default: - this.log.error(`OpenPAI: job ${environment.envId} returns unknown state ${jobResponse.state}.`); - environment.setStatus('UNKNOWN'); - } - if (oldEnvironmentStatus !== environment.status) { - this.log.debug(`OpenPAI: job ${environment.envId} change status ${oldEnvironmentStatus} to ${environment.status} due to job is ${jobResponse.state}.`) - } - } else { - this.log.error(`OpenPAI: job ${environment.envId} has no state returned. body:${JSON.stringify(jobResponse)}`); - // some error happens, and mark this environment - environment.status = 'FAILED'; - } - } else { - this.log.error(`OpenPAI job ${environment.envId} is not found in job list.`); - environment.status = 'UNKNOWN'; + if (body.jobStatus && body.jobStatus.state) { + const oldEnvironmentStatus = environment.status; + switch (body.jobStatus.state) { + case 'RUNNING': + case 'WAITING': + case 'SUCCEEDED': + environment.setStatus(body.jobStatus.state); + break; + case 'FAILED': + environment.setStatus(body.jobStatus.state); + deferred.reject(`OpenPAI: job ${environment.envId} is failed!`); + break; + case 'STOPPED': + case 'STOPPING': + environment.setStatus('USER_CANCELED'); + break; + default: + this.log.error(`OpenPAI: job ${environment.envId} returns unknown state ${body.jobStatus.state}.`); + environment.setStatus('UNKNOWN'); + } + if (oldEnvironmentStatus !== environment.status) { + this.log.debug(`OpenPAI: job ${environment.envId} change status ${oldEnvironmentStatus} to ${environment.status} due to job is ${body.jobStatus.state}.`) } - }); + } else { + this.log.error(`OpenPAI: job ${environment.envId} has no state returned. body:${JSON.stringify(body)}`); + // some error happens, and mark this environment + environment.status = 'FAILED'; + } deferred.resolve(); } }); diff --git a/ts/nni_manager/training_service/reusable/environments/remoteEnvironmentService.ts b/ts/nni_manager/training_service/reusable/environments/remoteEnvironmentService.ts index 6526a7e7ab..d898373fee 100644 --- a/ts/nni_manager/training_service/reusable/environments/remoteEnvironmentService.ts +++ b/ts/nni_manager/training_service/reusable/environments/remoteEnvironmentService.ts @@ -63,6 +63,10 @@ export class RemoteEnvironmentService extends EnvironmentService { return false; } + public get getPlatform(): string { + return 'remote'; + } + public async config(key: string, value: string): Promise { switch (key) { case TrialConfigMetadataKey.MACHINE_LIST: @@ -135,7 +139,7 @@ export class RemoteEnvironmentService extends EnvironmentService { await executor.allowPermission(true, nniRootDir); } - public async refreshEnvironment(environment: EnvironmentInformation): Promise { + public async refreshEnvironmentStatus(environment: EnvironmentInformation): Promise { const executor = await this.getExecutor(environment.id); const jobpidPath: string = `${environment.runnerWorkingFolder}/pid`; const runnerReturnCodeFilePath: string = `${environment.runnerWorkingFolder}/code`; @@ -176,14 +180,6 @@ export class RemoteEnvironmentService extends EnvironmentService { } } - public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise { - const tasks: Promise[] = []; - environments.forEach(async (environment) => { - tasks.push(this.refreshEnvironment(environment)); - }); - await Promise.all(tasks); - } - /** * If a environment is finished, release the connection resource * @param environment remote machine environment job detail diff --git a/ts/nni_manager/training_service/reusable/routerTrainingService.ts b/ts/nni_manager/training_service/reusable/routerTrainingService.ts index 7246645dbc..e4320faaea 100644 --- a/ts/nni_manager/training_service/reusable/routerTrainingService.ts +++ b/ts/nni_manager/training_service/reusable/routerTrainingService.ts @@ -19,10 +19,10 @@ import { LocalEnvironmentService } from './environments/localEnvironmentService' import { AMLEnvironmentService } from './environments/amlEnvironmentService'; import { RemoteEnvironmentService } from './environments/remoteEnvironmentService'; import { MountedStorageService } from './storages/mountedStorageService'; -import { HeteroGeneousEnvironmentService } from './environments/heterogeneousEnvironmentService'; import { StorageService } from './storageService'; import { TrialDispatcher } from './trialDispatcher'; import { RemoteConfig } from './remote/remoteConfig'; +import { HeterogenousConfig } from './heterogenous/heterogenousConfig'; import { LocalConfig, LocalTrainingService } from '../local/localTrainingService'; @@ -34,7 +34,6 @@ import { LocalConfig, LocalTrainingService } from '../local/localTrainingService class RouterTrainingService implements TrainingService { protected readonly log!: Logger; private internalTrainingService: TrainingService | undefined; - private metaDataCache: Map = new Map(); constructor() { this.log = getLogger(); @@ -102,95 +101,69 @@ class RouterTrainingService implements TrainingService { public async setClusterMetadata(key: string, value: string): Promise { if (this.internalTrainingService === undefined) { - if (key === TrialConfigMetadataKey.LOCAL_CONFIG) { - this.log.info(`reuse flag enabled, use EnvironmentManager.`); + if (key === TrialConfigMetadataKey.HETEROGENEOUS_CONFIG){ this.internalTrainingService = component.get(TrialDispatcher); - - // TODO to support other serivces later. - Container.bind(EnvironmentService) - .to(LocalEnvironmentService) + const heterogenousConfig: HeterogenousConfig = JSON.parse(value); + if (this.internalTrainingService === undefined) { + throw new Error("internalTrainingService not initialized!"); + } + // Initialize storageService for pai + if (heterogenousConfig.trainingServicePlatforms.includes('pai')) { + Container.bind(StorageService) + .to(MountedStorageService) .scope(Scope.Singleton); + } + await this.internalTrainingService.setClusterMetadata('platform_list', + heterogenousConfig.trainingServicePlatforms.join(',')); + } else if (key === TrialConfigMetadataKey.LOCAL_CONFIG) { + this.internalTrainingService = component.get(TrialDispatcher); if (this.internalTrainingService === undefined) { - throw new Error("TrainingService is not assigned!"); + throw new Error("internalTrainingService not initialized!"); } - await this.internalTrainingService.setClusterMetadata(key, value); - } - if (key === TrialConfigMetadataKey.PAI_CLUSTER_CONFIG) { + await this.internalTrainingService.setClusterMetadata('platform_list', 'local'); + } else if (key === TrialConfigMetadataKey.PAI_CLUSTER_CONFIG) { const config = JSON.parse(value); if (config.reuse === true) { this.log.info(`reuse flag enabled, use EnvironmentManager.`); this.internalTrainingService = component.get(TrialDispatcher); - - // TODO to support other serivces later. - Container.bind(EnvironmentService) - .to(OpenPaiEnvironmentService) - .scope(Scope.Singleton); // TODO to support other storages later. Container.bind(StorageService) .to(MountedStorageService) .scope(Scope.Singleton); + if (this.internalTrainingService === undefined) { + throw new Error("internalTrainingService not initialized!"); + } + await this.internalTrainingService.setClusterMetadata('platform_list', 'pai'); } else { this.log.debug(`caching metadata key:{} value:{}, as training service is not determined.`); this.internalTrainingService = component.get(PAIK8STrainingService); } - for (const [key, value] of this.metaDataCache) { - if (this.internalTrainingService === undefined) { - throw new Error("TrainingService is not assigned!"); - } - await this.internalTrainingService.setClusterMetadata(key, value); - } - - if (this.internalTrainingService === undefined) { - throw new Error("TrainingService is not assigned!"); - } - await this.internalTrainingService.setClusterMetadata(key, value); - - this.metaDataCache.clear(); } else if (key === TrialConfigMetadataKey.AML_CLUSTER_CONFIG) { this.internalTrainingService = component.get(TrialDispatcher); - - Container.bind(EnvironmentService) - .to(AMLEnvironmentService) - .scope(Scope.Singleton); - for (const [key, value] of this.metaDataCache) { - if (this.internalTrainingService === undefined) { - throw new Error("TrainingService is not assigned!"); - } - await this.internalTrainingService.setClusterMetadata(key, value); - } - if (this.internalTrainingService === undefined) { - throw new Error("TrainingService is not assigned!"); + throw new Error("internalTrainingService not initialized!"); } - await this.internalTrainingService.setClusterMetadata(key, value); - - this.metaDataCache.clear(); + await this.internalTrainingService.setClusterMetadata('platform_list', 'aml'); } else if (key === TrialConfigMetadataKey.REMOTE_CONFIG) { const config = JSON.parse(value); if (config.reuse === true) { this.log.info(`reuse flag enabled, use EnvironmentManager.`); this.internalTrainingService = component.get(TrialDispatcher); - Container.bind(EnvironmentService) - .to(RemoteEnvironmentService) - .scope(Scope.Singleton); + if (this.internalTrainingService === undefined) { + throw new Error("internalTrainingService not initialized!"); + } + await this.internalTrainingService.setClusterMetadata('platform_list', 'remote'); } else { this.log.debug(`caching metadata key:{} value:{}, as training service is not determined.`); this.internalTrainingService = component.get(RemoteMachineTrainingService); } - } else if (key === TrialConfigMetadataKey.HETEROGENEOUS_CONFIG){ - this.internalTrainingService = component.get(TrialDispatcher); - Container.bind(EnvironmentService) - .to(HeteroGeneousEnvironmentService) - .scope(Scope.Singleton); - - if (this.internalTrainingService === undefined) { - throw new Error("TrainingService is not assigned!"); - } - await this.internalTrainingService.setClusterMetadata(key, value); } - } else { - await this.internalTrainingService.setClusterMetadata(key, value); } + if (this.internalTrainingService === undefined) { + throw new Error("internalTrainingService not initialized!"); + } + await this.internalTrainingService.setClusterMetadata(key, value); + } public async getClusterMetadata(key: string): Promise { diff --git a/ts/nni_manager/training_service/reusable/test/utEnvironmentService.ts b/ts/nni_manager/training_service/reusable/test/utEnvironmentService.ts index d43bca5cfe..f4983c19ab 100644 --- a/ts/nni_manager/training_service/reusable/test/utEnvironmentService.ts +++ b/ts/nni_manager/training_service/reusable/test/utEnvironmentService.ts @@ -7,7 +7,7 @@ import { CommandChannel } from "../commandChannel"; import { UtCommandChannel } from "./utCommandChannel"; export class UtEnvironmentService extends EnvironmentService { - private commandChannel: UtCommandChannel | undefined; + private utCommandChannel: UtCommandChannel | undefined; private allEnvironments = new Map(); private hasMoreEnvironmentsInternal = true; @@ -23,6 +23,10 @@ export class UtEnvironmentService extends EnvironmentService { return 1; } + public get getPlatform(): string { + return 'ut'; + } + public testSetEnvironmentStatus(environment: EnvironmentInformation, newStatus: EnvironmentStatus): void { environment.status = newStatus; } @@ -36,10 +40,10 @@ export class UtEnvironmentService extends EnvironmentService { } public testGetCommandChannel(): UtCommandChannel { - if (this.commandChannel === undefined) { + if (this.utCommandChannel === undefined) { throw new Error(`command channel shouldn't be undefined.`); } - return this.commandChannel; + return this.utCommandChannel; } public testSetNoMoreEnvironment(hasMore: boolean): void { @@ -51,15 +55,15 @@ export class UtEnvironmentService extends EnvironmentService { } public createCommandChannel(commandEmitter: EventEmitter): CommandChannel { - this.commandChannel = new UtCommandChannel(commandEmitter) - return this.commandChannel; + this.utCommandChannel = new UtCommandChannel(commandEmitter) + return this.utCommandChannel; } public async config(_key: string, _value: string): Promise { // do nothing } - public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise { + public async refreshEnvironmentStatus(environment: EnvironmentInformation): Promise { // do nothing } diff --git a/ts/nni_manager/training_service/reusable/trialDispatcher.ts b/ts/nni_manager/training_service/reusable/trialDispatcher.ts index 79cac7c10e..ae8a1cb62b 100644 --- a/ts/nni_manager/training_service/reusable/trialDispatcher.ts +++ b/ts/nni_manager/training_service/reusable/trialDispatcher.ts @@ -13,7 +13,7 @@ import { NNIError, NNIErrorNames, MethodNotImplementedError } from '../../common import { getBasePort, getExperimentId, getPlatform } from '../../common/experimentStartupInfo'; import { getLogger, Logger } from '../../common/log'; import { NNIManagerIpConfig, TrainingService, TrialJobApplicationForm, TrialJobMetric, TrialJobStatus, LogType } from '../../common/trainingService'; -import { delay, getExperimentRootDir, getIPV4Address, getLogLevel, getVersion, mkDirPSync, uniqueString } from '../../common/utils'; +import { delay, getExperimentRootDir, getIPV4Address, getLogLevel, getVersion, mkDirPSync, randomSelect, uniqueString } from '../../common/utils'; import { GPU_INFO, INITIALIZED, KILL_TRIAL_JOB, NEW_TRIAL_JOB, REPORT_METRIC_DATA, SEND_TRIAL_JOB_PARAMETER, STDOUT, TRIAL_END, VERSION_CHECK } from '../../core/commands'; import { ScheduleResultType } from '../../training_service/common/gpuData'; import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData'; @@ -21,11 +21,19 @@ import { TrialConfig } from '../common/trialConfig'; import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey'; import { validateCodeDir } from '../common/util'; import { Command, CommandChannel } from './commandChannel'; -import { EnvironmentInformation, EnvironmentService, NodeInformation, RunnerSettings, TrialGpuSummary } from './environment'; +import { EnvironmentInformation, EnvironmentService, NodeInformation, RunnerSettings, TrialGpuSummary, Channel } from './environment'; import { GpuScheduler } from './gpuScheduler'; import { MountedStorageService } from './storages/mountedStorageService'; import { StorageService } from './storageService'; import { TrialDetail } from './trial'; +import { AMLEnvironmentService } from './environments/amlEnvironmentService'; +import { OpenPaiEnvironmentService } from './environments/openPaiEnvironmentService'; +import { LocalTrainingService } from 'training_service/local/localTrainingService'; +import { LocalEnvironmentService } from './environments/localEnvironmentService'; +import { RemoteEnvironmentService } from './environments/remoteEnvironmentService'; +import { AMLCommandChannel } from './channels/amlCommandChannel'; +import { WebCommandChannel } from './channels/webCommandChannel'; +import { FileCommandChannel } from './channels/fileCommandChannel'; /** @@ -45,14 +53,15 @@ class TrialDispatcher implements TrainingService { private enableVersionCheck: boolean = true; private trialConfig: TrialConfig | undefined; - private runnerSettings: RunnerSettings; - - private commandEmitter: EventEmitter | undefined; - private commandChannel: CommandChannel | undefined; private readonly trials: Map; private readonly environments: Map; + private environmentServiceList: EnvironmentService[] = []; + private commandChannelDict: Map; + private commandEmitter: EventEmitter; + private nniManagerIp: string; + // uses to accelerate trial manager loop // true means there is updates, and trial loop should run a cycle immediately. private shouldUpdateTrials: boolean = true; @@ -62,6 +71,8 @@ class TrialDispatcher implements TrainingService { private enableGpuScheduler: boolean = false; // uses to save if user like to reuse environment private reuseEnvironment: boolean = true; + private logCollection: string = ''; + private environmentMaintenceLoopInterval: number = 5000; private gpuScheduler: GpuScheduler; @@ -73,13 +84,11 @@ class TrialDispatcher implements TrainingService { this.log = getLogger(); this.trials = new Map(); this.environments = new Map(); + this.commandChannelDict = new Map(); this.metricsEmitter = new EventEmitter(); this.experimentId = getExperimentId(); this.experimentRootDir = getExperimentRootDir(); - - this.runnerSettings = new RunnerSettings(); - this.runnerSettings.experimentId = this.experimentId; - this.runnerSettings.platform = getPlatform(); + this.nniManagerIp = getIPV4Address(); const logLevel = getLogLevel(); this.log.debug(`current folder ${__dirname}`); @@ -89,6 +98,8 @@ class TrialDispatcher implements TrainingService { this.isDeveloping = true; } + this.commandEmitter = new EventEmitter(); + this.gpuScheduler = new GpuScheduler(); } @@ -122,12 +133,12 @@ class TrialDispatcher implements TrainingService { const trialId: string = uniqueString(5); - const environmentService = component.get(EnvironmentService); + // const environmentService = component.get(EnvironmentService); let trialWorkingFolder: string = ""; - if (environmentService.hasStorageService) { - const storageService = component.get(StorageService); - trialWorkingFolder = storageService.joinPath('trials', trialId); - } + // if (environmentService.hasStorageService) { + // const storageService = component.get(StorageService); + // trialWorkingFolder = storageService.joinPath('trials', trialId); + // } const trialJobDetail: TrialDetail = new TrialDetail(trialId, "WAITING", Date.now(), trialWorkingFolder, form); this.trials.set(trialId, trialJobDetail); @@ -142,23 +153,25 @@ class TrialDispatcher implements TrainingService { if (environment === undefined) { throw new Error(`TrialDispatcher: trial ${trialJobId}'s env shouldn't be undefined in updateTrialJob.`); } - if (this.commandChannel === undefined) { - throw new Error(`TrialDispatcher: commandChannel shouldn't be undefined in updateTrialJob.`); + if (environment.environmentService === undefined) { + throw new Error(`Environment ${environment.id} does not assigned environment service.`); } const message = { "trialId": trialJobId, "parameters": form.hyperParameters, } - await this.commandChannel.sendCommand(environment, SEND_TRIAL_JOB_PARAMETER, message); + const commandChannel: CommandChannel | undefined = + this.commandChannelDict.get(environment.environmentService.getCommandChanneName); + if (commandChannel === undefined) { + throw new Error(`${environment.environmentService.getCommandChanneName} command channel not initialized!`); + } + await commandChannel.sendCommand(environment, SEND_TRIAL_JOB_PARAMETER, message); return trialDetail; } public async cancelTrialJob(trialJobId: string, isEarlyStopped?: boolean | undefined): Promise { - if (this.commandChannel === undefined) { - throw new Error(`TrialDispatcher: commandChannel shouldn't be undefined in cancelTrialJob.`); - } const trial = await this.getTrialJob(trialJobId); switch (trial.status) { case "RUNNING": @@ -166,8 +179,13 @@ class TrialDispatcher implements TrainingService { case "UNKNOWN": { const environment = trial.environment; - if (environment) { - await this.commandChannel.sendCommand(environment, KILL_TRIAL_JOB, trial.id); + if (environment && environment.environmentService) { + const commandChannel: CommandChannel | undefined = + this.commandChannelDict.get(environment.environmentService.getCommandChanneName); + if (commandChannel === undefined) { + throw new Error(`${environment.environmentService.getCommandChanneName} command channel not initialized!`); + } + await commandChannel.sendCommand(environment, KILL_TRIAL_JOB, trial.id); trial.isEarlyStopped = isEarlyStopped; trial.status = trial.isEarlyStopped === true ? 'EARLY_STOPPED' : 'USER_CANCELED'; @@ -179,70 +197,76 @@ class TrialDispatcher implements TrainingService { } public async run(): Promise { - const environmentService = component.get(EnvironmentService); - - this.commandEmitter = new EventEmitter(); - this.commandChannel = environmentService.createCommandChannel(this.commandEmitter); - - // TODO it's a hard code of web channel, it needs to be improved. - if (this.runnerSettings.nniManagerIP === "" || this.runnerSettings.nniManagerIP === null) { - this.runnerSettings.nniManagerIP = getIPV4Address(); + if (this.trialConfig === undefined) { + throw new Error(`trial config shouldn't be undefined in run()`); + } + for(let environmentService of this.environmentServiceList) { + + const runnerSettings: RunnerSettings = new RunnerSettings(); + runnerSettings.nniManagerIP = this.nniManagerIp; + runnerSettings.nniManagerPort = getBasePort() + 1; + runnerSettings.commandChannel = environmentService.getCommandChanneName; + runnerSettings.enableGpuCollector = this.enableGpuScheduler; + runnerSettings.command = this.trialConfig.command; + runnerSettings.nniManagerVersion = this.enableVersionCheck ? await getVersion() : ''; + runnerSettings.logCollection = this.logCollection; + runnerSettings.platform = environmentService.getPlatform; + runnerSettings.experimentId = this.experimentId; + + const commandChannel: CommandChannel | undefined = + this.commandChannelDict.get(environmentService.getCommandChanneName); + if (commandChannel === undefined) { + throw new Error(`${environmentService.getCommandChanneName} command channel not initialized!`); + } + await commandChannel.start(); + this.log.info(`TrialDispatcher: started channel: ${commandChannel.constructor.name}`); + + this.log.info(`TrialDispatcher: copying code and settings.`); + let storageService: StorageService; + if (environmentService.hasStorageService) { + this.log.debug(`TrialDispatcher: use existing storage service.`); + storageService = component.get(StorageService); + } else { + this.log.debug(`TrialDispatcher: create temp storage service to temp folder.`); + storageService = new MountedStorageService(); + const environmentLocalTempFolder = path.join(this.experimentRootDir, this.experimentId, "environment-temp"); + storageService.initialize(this.trialConfig.codeDir, environmentLocalTempFolder); + } + // Copy the compressed file to remoteDirectory and delete it + const codeDir = path.resolve(this.trialConfig.codeDir); + const envDir = storageService.joinPath("envs"); + const codeFileName = await storageService.copyDirectory(codeDir, envDir, true); + storageService.rename(codeFileName, "nni-code.tar.gz"); + + const installFileName = storageService.joinPath(envDir, 'install_nni.sh'); + await storageService.save(CONTAINER_INSTALL_NNI_SHELL_FORMAT, installFileName); + + const runnerSettingsConfig = storageService.joinPath(envDir, "settings.json"); + await storageService.save(JSON.stringify(runnerSettings), runnerSettingsConfig); + + if (this.isDeveloping) { + let trialToolsPath = path.join(__dirname, "../../../../../tools/nni_trial_tool"); + if (false === fs.existsSync(trialToolsPath)) { + trialToolsPath = path.join(__dirname, "..\\..\\..\\..\\..\\tools\\nni_trial_tool"); + } + await storageService.copyDirectory(trialToolsPath, envDir, true); + } } - this.runnerSettings.nniManagerPort = getBasePort() + 1; - this.runnerSettings.commandChannel = this.commandChannel.channelName; - // start channel this.commandEmitter.on("command", (command: Command): void => { this.handleCommand(command).catch((err: Error) => { this.log.error(`TrialDispatcher: error on handle env ${command.environment.id} command: ${command.command}, data: ${command.data}, error: ${err}`); }) }); - await this.commandChannel.start(); - this.log.info(`TrialDispatcher: started channel: ${this.commandChannel.constructor.name}`); - - if (this.trialConfig === undefined) { - throw new Error(`trial config shouldn't be undefined in run()`); - } - - this.log.info(`TrialDispatcher: copying code and settings.`); - let storageService: StorageService; - if (environmentService.hasStorageService) { - this.log.debug(`TrialDispatcher: use existing storage service.`); - storageService = component.get(StorageService); - } else { - this.log.debug(`TrialDispatcher: create temp storage service to temp folder.`); - storageService = new MountedStorageService(); - const environmentLocalTempFolder = path.join(this.experimentRootDir, this.experimentId, "environment-temp"); - storageService.initialize(this.trialConfig.codeDir, environmentLocalTempFolder); - } - - // Copy the compressed file to remoteDirectory and delete it - const codeDir = path.resolve(this.trialConfig.codeDir); - const envDir = storageService.joinPath("envs"); - const codeFileName = await storageService.copyDirectory(codeDir, envDir, true); - storageService.rename(codeFileName, "nni-code.tar.gz"); - - const installFileName = storageService.joinPath(envDir, 'install_nni.sh'); - await storageService.save(CONTAINER_INSTALL_NNI_SHELL_FORMAT, installFileName); - - const runnerSettings = storageService.joinPath(envDir, "settings.json"); - await storageService.save(JSON.stringify(this.runnerSettings), runnerSettings); - - // FIXME: what the hell is this? - if (this.isDeveloping) { - let trialToolsPath = path.join(__dirname, "../../../../../tools/nni_trial_tool"); - if (false === fs.existsSync(trialToolsPath)) { - trialToolsPath = path.join(__dirname, "..\\..\\..\\..\\..\\tools\\nni_trial_tool"); - } - await storageService.copyDirectory(trialToolsPath, envDir, true); - } await this.prefetchEnvironments(); this.log.info(`TrialDispatcher: run loop started.`); - await Promise.all([ - this.environmentMaintenanceLoop(), - this.trialManagementLoop(), - this.commandChannel.run(), - ]); + const promiseList: Promise[] = []; + for(let commandChannel of this.commandChannelDict.values()) { + promiseList.push(commandChannel.run()); + } + promiseList.push(this.environmentMaintenanceLoop()); + promiseList.push(this.trialManagementLoop()); + await Promise.all(promiseList); } public addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void { @@ -260,14 +284,14 @@ class TrialDispatcher implements TrainingService { public async setClusterMetadata(key: string, value: string): Promise { switch (key) { case TrialConfigMetadataKey.NNI_MANAGER_IP: - this.runnerSettings.nniManagerIP = (JSON.parse(value)).nniManagerIp; + this.nniManagerIp = (JSON.parse(value)).nniManagerIp; break; case TrialConfigMetadataKey.VERSION_CHECK: this.enableVersionCheck = (value === 'true' || value === 'True'); - this.runnerSettings.nniManagerVersion = this.enableVersionCheck ? await getVersion() : ''; + break; case TrialConfigMetadataKey.LOG_COLLECTION: - this.runnerSettings.logCollection = value; + this.logCollection = value; break; case TrialConfigMetadataKey.TRIAL_CONFIG: this.trialConfig = JSON.parse(value); @@ -279,15 +303,52 @@ class TrialDispatcher implements TrainingService { this.log.info(`TrialDispatcher: GPU scheduler is enabled.`) this.enableGpuScheduler = true; } - this.runnerSettings.enableGpuCollector = this.enableGpuScheduler; - this.runnerSettings.command = this.trialConfig.command; // Validate to make sure codeDir doesn't have too many files await validateCodeDir(this.trialConfig.codeDir); break; + case TrialConfigMetadataKey.PLATFORM_LIST: + const platforms: string[] = value.split(","); + for(let platform of platforms) { + let environmentService: EnvironmentService; + switch(platform) { + case 'local': + environmentService = new LocalEnvironmentService(); + break; + case 'remote': + environmentService = new RemoteEnvironmentService(); + break; + case 'aml': + environmentService = new AMLEnvironmentService(); + break; + case 'pai': + environmentService = new OpenPaiEnvironmentService(); + break; + default: + throw new Error(`${platform} not supported!`); + } + if (!this.commandChannelDict.has(environmentService.getCommandChanneName)) { + switch(environmentService.getCommandChanneName) { + case 'aml': + this.commandChannelDict.set('aml', new AMLCommandChannel(this.commandEmitter)); + break; + case 'web': + this.commandChannelDict.set('web', new WebCommandChannel(this.commandEmitter)); + break; + case 'file': + this.commandChannelDict.set('file', new FileCommandChannel(this.commandEmitter)); + break; + default: + throw new Error(`Unsupported channel ${environmentService.getCommandChanneName}`); + } + } + this.environmentServiceList.push(environmentService); + } + + } + for(let environmentService of this.environmentServiceList) { + await environmentService.config(key, value); } - const environmentService = component.get(EnvironmentService); - await environmentService.config(key, value); } public getClusterMetadata(_key: string): Promise { @@ -295,48 +356,60 @@ class TrialDispatcher implements TrainingService { } public async cleanUp(): Promise { - if (this.commandChannel === undefined) { - throw new Error(`TrialDispatcher: commandChannel shouldn't be undefined in cleanUp.`); - } if (this.commandEmitter === undefined) { throw new Error(`TrialDispatcher: commandEmitter shouldn't be undefined in cleanUp.`); } this.stopping = true; this.shouldUpdateTrials = true; - const environmentService = component.get(EnvironmentService); const environments = [...this.environments.values()]; for (let index = 0; index < environments.length; index++) { const environment = environments[index]; if (environment.isAlive === true) { this.log.info(`stopping environment ${environment.id}...`); - await environmentService.stopEnvironment(environment); - await this.commandChannel.close(environment); + if (environment.environmentService === undefined) { + throw new Error(`${environment.id} do not have environmentService!`); + } + await environment.environmentService.stopEnvironment(environment); + const commandChannel: CommandChannel | undefined = + this.commandChannelDict.get(environment.environmentService.getCommandChanneName); + if (commandChannel === undefined) { + throw new Error(`${environment.environmentService.getCommandChanneName} command channel not initialized!`); + } this.log.info(`stopped environment ${environment.id}.`); } } this.commandEmitter.off("command", this.handleCommand); - await this.commandChannel.stop(); + for(let commandChannel of this.commandChannelDict.values()) { + commandChannel.stop(); + } } private async environmentMaintenanceLoop(): Promise { - if (this.commandChannel === undefined) { - throw new Error(`TrialDispatcher: commandChannel shouldn't be undefined in environmentMaintenanceLoop.`); - } - const environmentService = component.get(EnvironmentService); while (!this.stopping) { const environments: EnvironmentInformation[] = []; for (const environment of this.environments.values()) { if (environment.isAlive === true) { environments.push(environment); } else { - await this.commandChannel.close(environment); + if (environment.environmentService === undefined) { + throw new Error(`${environment.id} do not have environment service!`); + } + const commandChannel: CommandChannel | undefined = + this.commandChannelDict.get(environment.environmentService.getCommandChanneName); + if (commandChannel === undefined) { + throw new Error(`${environment.environmentService.getCommandChanneName} command channel not initialized!`); + } + await commandChannel.close(environment); } } - await environmentService.refreshEnvironmentsStatus(environments); - - environments.forEach((environment) => { + + for (let environment of environments) { + if (environment.environmentService === undefined) { + throw new Error(`${environment.id} do not have environment service!`); + } + await environment.environmentService.refreshEnvironmentStatus(environment); const oldIsAlive = environment.isAlive; switch (environment.status) { case 'WAITING': @@ -351,16 +424,13 @@ class TrialDispatcher implements TrainingService { if (oldIsAlive !== environment.isAlive) { this.log.debug(`set environment ${environment.id} isAlive from ${oldIsAlive} to ${environment.isAlive} due to status is ${environment.status}.`); } - }); + } this.shouldUpdateTrials = true; - await delay(environmentService.environmentMaintenceLoopInterval); + await delay(this.environmentMaintenceLoopInterval); } } private async trialManagementLoop(): Promise { - if (this.commandChannel === undefined) { - throw new Error(`TrialDispatcher: commandChannel shouldn't be undefined in trialManagementLoop.`); - } const interval = 1; while (!this.stopping) { @@ -400,6 +470,11 @@ class TrialDispatcher implements TrainingService { liveTrialsCount++; continue; } + + if (environment.environmentService === undefined) { + throw new Error(`${environment.id} does not has environment service!`); + } + trial.url = environment.trackingUrl; const environmentStatus = environment.status; @@ -414,7 +489,12 @@ class TrialDispatcher implements TrainingService { // for example, in horovod, it's just sleep command, has no impact on trial result. if (environment.nodeCount > completedCount) { this.log.info(`stop partial completed trial ${trial.id}`); - await this.commandChannel.sendCommand(environment, KILL_TRIAL_JOB, trial.id); + const commandChannel: CommandChannel | undefined = + this.commandChannelDict.get(environment.environmentService.getCommandChanneName); + if (commandChannel === undefined) { + throw new Error(`${environment.environmentService.getCommandChanneName} command channel not initialized!`); + } + await commandChannel.sendCommand(environment, KILL_TRIAL_JOB, trial.id); } for (const node of trial.nodes.values()) { if (node.status === "FAILED") { @@ -463,8 +543,10 @@ class TrialDispatcher implements TrainingService { false === this.reuseEnvironment && environment.assignedTrialCount > 0 ) { - const environmentService = component.get(EnvironmentService); - await environmentService.stopEnvironment(environment); + if (environment.environmentService === undefined) { + throw new Error(`${environment.id} does not has environment service!`); + } + await environment.environmentService.stopEnvironment(environment); continue; } @@ -556,11 +638,13 @@ class TrialDispatcher implements TrainingService { } if (neededEnvironmentCount > 0) { - const environmentService = component.get(EnvironmentService); let requestedCount = 0; + let hasMoreEnvironments = false; for (let index = 0; index < neededEnvironmentCount; index++) { - if (true === environmentService.hasMoreEnvironments) { - await this.requestEnvironment(); + let environmentService: EnvironmentService | undefined = this.selectEnvironmentService(); + if (environmentService !== undefined) { + hasMoreEnvironments = true; + await this.requestEnvironment(environmentService); requestedCount++; this.isLoggedNoMoreEnvironment = false; } else { @@ -570,7 +654,7 @@ class TrialDispatcher implements TrainingService { } } } - if (environmentService.hasMoreEnvironments === true || requestedCount > 0) { + if (hasMoreEnvironments === true || requestedCount > 0) { this.log.info(`requested new environment, live trials: ${liveTrialsCount}, ` + `live environments: ${liveEnvironmentsCount}, neededEnvironmentCount: ${neededEnvironmentCount}, ` + `requestedCount: ${requestedCount}`); @@ -580,24 +664,36 @@ class TrialDispatcher implements TrainingService { } } - private async prefetchEnvironments (): Promise { - const environmentService = component.get(EnvironmentService); - const number = environmentService.prefetchedEnvironmentCount; - this.log.info(`Initialize environments total number: ${number}`); - for (let index = 0; index < number; index++) { - await this.requestEnvironment(); + // Schedule a environment platform for environment + private selectEnvironmentService(): EnvironmentService | undefined { + const validEnvironmentServiceList = []; + for(let environmentService of this.environmentServiceList){ + if (environmentService.hasMoreEnvironments) { + validEnvironmentServiceList.push(environmentService); + } } + if (validEnvironmentServiceList.length === 0) { + return undefined; + } + // Random scheduler + return randomSelect(validEnvironmentServiceList); } - - private async requestEnvironment(): Promise { - if (this.commandChannel === undefined) { - throw new Error(`TrialDispatcher: commandChannel shouldn't be undefined in requestEnvironment.`); + + private async prefetchEnvironments (): Promise { + for (let environmentService of this.environmentServiceList) { + const number = environmentService.prefetchedEnvironmentCount; + this.log.info(`Initialize environments total number: ${number}`); + for (let index = 0; index < number; index++) { + await this.requestEnvironment(environmentService); + } } + } - const environmentService = component.get(EnvironmentService); + private async requestEnvironment(environmentService: EnvironmentService): Promise { const envId = uniqueString(5); const envName = `nni_exp_${this.experimentId}_env_${envId}`; const environment = environmentService.createEnvironmentInformation(envId, envName); + environment.environmentService = environmentService; environment.command = `sh ../install_nni.sh && python3 -m nni.tools.trial_tool.trial_runner`; @@ -616,15 +712,16 @@ class TrialDispatcher implements TrainingService { } else { environment.isAlive = true; } - - await this.commandChannel.open(environment); + const commandChannel: CommandChannel | undefined = + this.commandChannelDict.get(environment.environmentService.getCommandChanneName); + if (commandChannel === undefined) { + throw new Error(`${environment.environmentService.getCommandChanneName} command channel not initialized!`); + } + await commandChannel.open(environment); this.log.info(`requested environment ${environment.id} and job id is ${environment.envId}.`); } private async allocateEnvironment(trial: TrialDetail, environment: EnvironmentInformation): Promise { - if (this.commandChannel === undefined) { - throw new Error(`TrialDispatcher: commandChannel shouldn't be undefined in allocateEnvironment.`); - } if (this.trialConfig === undefined) { throw new Error(`TrialDispatcher: trialConfig shouldn't be undefined in allocateEnvironment.`); } @@ -661,7 +758,15 @@ class TrialDispatcher implements TrainingService { } trial.startTime = Date.now(); trial.status = "RUNNING"; - await this.commandChannel.sendCommand(trial.environment, NEW_TRIAL_JOB, trial.settings); + if (environment.environmentService === undefined) { + throw new Error(`${environment.id} does not have environment service!`); + } + const commandChannel: CommandChannel | undefined = + this.commandChannelDict.get(environment.environmentService.getCommandChanneName); + if (commandChannel === undefined) { + throw new Error(`${environment.environmentService.getCommandChanneName} command channel not initialized!`); + } + await commandChannel.sendCommand(trial.environment, NEW_TRIAL_JOB, trial.settings); } /** From eb4802c7d2ea27227002fc74cf4fe3c61358564c Mon Sep 17 00:00:00 2001 From: Shinai Yang Date: Thu, 3 Dec 2020 12:09:02 +0800 Subject: [PATCH 10/24] remove unused change --- examples/trials/mnist-tfv1/config_kubeflow.yml | 7 ------- nni/tools/trial_tool/trial_runner.py | 3 +-- .../reusable/channels/amlCommandChannel.ts | 2 +- .../reusable/channels/webCommandChannel.ts | 10 ++++------ .../reusable/environments/amlEnvironmentService.ts | 3 +-- .../training_service/reusable/trialDispatcher.ts | 5 ----- 6 files changed, 7 insertions(+), 23 deletions(-) diff --git a/examples/trials/mnist-tfv1/config_kubeflow.yml b/examples/trials/mnist-tfv1/config_kubeflow.yml index d874223a2a..f460b37cb6 100644 --- a/examples/trials/mnist-tfv1/config_kubeflow.yml +++ b/examples/trials/mnist-tfv1/config_kubeflow.yml @@ -16,13 +16,6 @@ tuner: optimize_mode: maximize trial: codeDir: . - ps: - replicas: 1 - command: python3 mnist.py - gpuNum: 0 - cpuNum: 1 - memoryMB: 8192 - image: msranni/nni:latest worker: replicas: 1 command: python3 mnist.py diff --git a/nni/tools/trial_tool/trial_runner.py b/nni/tools/trial_tool/trial_runner.py index d3c4ca0b22..f506cef0db 100644 --- a/nni/tools/trial_tool/trial_runner.py +++ b/nni/tools/trial_tool/trial_runner.py @@ -214,8 +214,7 @@ def check_version(args): command_channel = None if args.command_channel == "file": command_channel = FileChannel(args) - elif args.command_channel == 'aml' or \ - args.command_channel == 'heterogeneous' and args.platform == 'aml': + elif args.command_channel == 'aml': from .aml_channel import AMLChannel command_channel = AMLChannel(args) else: diff --git a/ts/nni_manager/training_service/reusable/channels/amlCommandChannel.ts b/ts/nni_manager/training_service/reusable/channels/amlCommandChannel.ts index 531cc40417..5816a9c780 100644 --- a/ts/nni_manager/training_service/reusable/channels/amlCommandChannel.ts +++ b/ts/nni_manager/training_service/reusable/channels/amlCommandChannel.ts @@ -39,7 +39,7 @@ export class AMLCommandChannel extends CommandChannel { ]); } - public async sendCommandInternal(environment: EnvironmentInformation, message: string): Promise { + protected async sendCommandInternal(environment: EnvironmentInformation, message: string): Promise { this.sendQueues.push([environment, message]); } diff --git a/ts/nni_manager/training_service/reusable/channels/webCommandChannel.ts b/ts/nni_manager/training_service/reusable/channels/webCommandChannel.ts index f9db475be4..0a30b853ec 100644 --- a/ts/nni_manager/training_service/reusable/channels/webCommandChannel.ts +++ b/ts/nni_manager/training_service/reusable/channels/webCommandChannel.ts @@ -46,7 +46,7 @@ export class WebCommandChannel extends CommandChannel { this.webSocketServer = new SocketServer({ port }); this.webSocketServer.on('connection', (client: WebSocket) => { - this.log.info(`WebCommandChannel: received connection`); + this.log.debug(`WebCommandChannel: received connection`); client.onerror = (event): void => { this.log.error(`error on client ${JSON.stringify(event)}`); } @@ -69,7 +69,7 @@ export class WebCommandChannel extends CommandChannel { // do nothing } - public async sendCommandInternal(environment: EnvironmentInformation, message: string): Promise { + protected async sendCommandInternal(environment: EnvironmentInformation, message: string): Promise { if (this.webSocketServer === undefined) { throw new Error(`WebCommandChannel: uninitialized!`) } @@ -83,7 +83,7 @@ export class WebCommandChannel extends CommandChannel { } } - public createRunnerConnection(environment: EnvironmentInformation): RunnerConnection { + protected createRunnerConnection(environment: EnvironmentInformation): RunnerConnection { return new WebRunnerConnection(environment); } @@ -94,9 +94,7 @@ export class WebCommandChannel extends CommandChannel { // undefined means it's expecting initializing message. const commands = this.parseCommands(rawCommands); let isValid = false; - this.log.info(`WebCommandChannel: received initialize message: ${JSON.stringify(rawCommands)}`); - - + this.log.debug(`WebCommandChannel: received initialize message: ${JSON.stringify(rawCommands)}`); if (commands.length > 0) { const commandType = commands[0][0]; const result = commands[0][1]; diff --git a/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts b/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts index b267303ca9..4098b28dba 100644 --- a/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts +++ b/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts @@ -42,8 +42,7 @@ export class AMLEnvironmentService extends EnvironmentService { } public createCommandChannel(commandEmitter: EventEmitter): CommandChannel { - this.commandChannel = new AMLCommandChannel(commandEmitter); - return this.commandChannel; + return new AMLCommandChannel(commandEmitter); } public createEnvironmentInformation(envId: string, envName: string): EnvironmentInformation { diff --git a/ts/nni_manager/training_service/reusable/trialDispatcher.ts b/ts/nni_manager/training_service/reusable/trialDispatcher.ts index ae8a1cb62b..f9181d3240 100644 --- a/ts/nni_manager/training_service/reusable/trialDispatcher.ts +++ b/ts/nni_manager/training_service/reusable/trialDispatcher.ts @@ -133,12 +133,7 @@ class TrialDispatcher implements TrainingService { const trialId: string = uniqueString(5); - // const environmentService = component.get(EnvironmentService); let trialWorkingFolder: string = ""; - // if (environmentService.hasStorageService) { - // const storageService = component.get(StorageService); - // trialWorkingFolder = storageService.joinPath('trials', trialId); - // } const trialJobDetail: TrialDetail = new TrialDetail(trialId, "WAITING", Date.now(), trialWorkingFolder, form); this.trials.set(trialId, trialJobDetail); From c245b6366988f4deb98880145ab9d4a56b015fec Mon Sep 17 00:00:00 2001 From: Shinai Yang Date: Thu, 3 Dec 2020 15:39:01 +0800 Subject: [PATCH 11/24] refactor code --- nni/runtime/env_vars.py | 3 ++- nni/runtime/platform/local.py | 5 +++-- nni/tools/trial_tool/trial_runner.py | 1 + .../training_service/reusable/routerTrainingService.ts | 3 ++- ts/nni_manager/training_service/reusable/trialDispatcher.ts | 2 +- 5 files changed, 9 insertions(+), 5 deletions(-) diff --git a/nni/runtime/env_vars.py b/nni/runtime/env_vars.py index 5227956012..810ab2f4f6 100644 --- a/nni/runtime/env_vars.py +++ b/nni/runtime/env_vars.py @@ -12,7 +12,8 @@ 'NNI_SYS_DIR', 'NNI_OUTPUT_DIR', 'NNI_TRIAL_SEQ_ID', - 'MULTI_PHASE' + 'MULTI_PHASE', + 'REUSE_MODE' ] _dispatcher_env_var_names = [ diff --git a/nni/runtime/platform/local.py b/nni/runtime/platform/local.py index b8e6ffad62..b19bea927c 100644 --- a/nni/runtime/platform/local.py +++ b/nni/runtime/platform/local.py @@ -20,9 +20,10 @@ if not os.path.exists(_outputdir): os.makedirs(_outputdir) +_reuse_mode = trial_env_vars.REUSE_MODE _nni_platform = trial_env_vars.NNI_PLATFORM _nni_trial_job_id = trial_env_vars.NNI_TRIAL_JOB_ID -if _nni_platform == 'local' and _nni_trial_job_id != 'runner': +if _nni_platform == 'local' and _reuse_mode not in ('true', 'True'): _log_file_path = os.path.join(_outputdir, 'trial.log') init_logger(_log_file_path) @@ -63,7 +64,7 @@ def get_next_parameter(): return params def send_metric(string): - if _nni_platform != 'local' or _nni_trial_job_id == 'runner': + if _nni_platform != 'local' or _reuse_mode in ('true', 'True'): assert len(string) < 1000000, 'Metric too long' print("NNISDK_MEb'%s'" % (string), flush=True) else: diff --git a/nni/tools/trial_tool/trial_runner.py b/nni/tools/trial_tool/trial_runner.py index f506cef0db..9f39fe74b5 100644 --- a/nni/tools/trial_tool/trial_runner.py +++ b/nni/tools/trial_tool/trial_runner.py @@ -187,6 +187,7 @@ def check_version(args): os.environ['NNI_EXP_ID'] = args.exp_id os.environ['MULTI_PHASE'] = "true" os.environ['NNI_TRIAL_JOB_ID'] = "runner" + os.environ['REUSE_MODE'] = "true" from .log_utils import LogType, RemoteLogger, StdOutputType, nni_log from .trial import Trial diff --git a/ts/nni_manager/training_service/reusable/routerTrainingService.ts b/ts/nni_manager/training_service/reusable/routerTrainingService.ts index e4320faaea..546ed7f041 100644 --- a/ts/nni_manager/training_service/reusable/routerTrainingService.ts +++ b/ts/nni_manager/training_service/reusable/routerTrainingService.ts @@ -101,13 +101,14 @@ class RouterTrainingService implements TrainingService { public async setClusterMetadata(key: string, value: string): Promise { if (this.internalTrainingService === undefined) { + // Need to refactor configuration, remove heterogeneous_config field in the future if (key === TrialConfigMetadataKey.HETEROGENEOUS_CONFIG){ this.internalTrainingService = component.get(TrialDispatcher); const heterogenousConfig: HeterogenousConfig = JSON.parse(value); if (this.internalTrainingService === undefined) { throw new Error("internalTrainingService not initialized!"); } - // Initialize storageService for pai + // Initialize storageService for pai, only support singleton for now, need refactor if (heterogenousConfig.trainingServicePlatforms.includes('pai')) { Container.bind(StorageService) .to(MountedStorageService) diff --git a/ts/nni_manager/training_service/reusable/trialDispatcher.ts b/ts/nni_manager/training_service/reusable/trialDispatcher.ts index f9181d3240..da1c999340 100644 --- a/ts/nni_manager/training_service/reusable/trialDispatcher.ts +++ b/ts/nni_manager/training_service/reusable/trialDispatcher.ts @@ -689,7 +689,7 @@ class TrialDispatcher implements TrainingService { const envName = `nni_exp_${this.experimentId}_env_${envId}`; const environment = environmentService.createEnvironmentInformation(envId, envName); environment.environmentService = environmentService; - + this.log.info(`Assign environment service ${environmentService.getPlatform} to environment ${envId}`); environment.command = `sh ../install_nni.sh && python3 -m nni.tools.trial_tool.trial_runner`; if (this.isDeveloping) { From 92dd6f83b21d42f0d778cf5bb32a8d5a0d2a67aa Mon Sep 17 00:00:00 2001 From: Shinai Yang Date: Thu, 3 Dec 2020 16:03:18 +0800 Subject: [PATCH 12/24] add example --- .../TrainingService/HeterogeneousMode.md | 5 +-- .../mnist-tfv1/config_heterogeneous.yml | 32 +++++++++++++++++++ 2 files changed, 33 insertions(+), 4 deletions(-) create mode 100644 examples/trials/mnist-tfv1/config_heterogeneous.yml diff --git a/docs/en_US/TrainingService/HeterogeneousMode.md b/docs/en_US/TrainingService/HeterogeneousMode.md index c93c337c63..948a239279 100644 --- a/docs/en_US/TrainingService/HeterogeneousMode.md +++ b/docs/en_US/TrainingService/HeterogeneousMode.md @@ -13,7 +13,7 @@ Use `examples/trials/mnist-tfv1` as an example. The NNI config YAML file's conte ```yaml authorName: default experimentName: example_mnist -trialConcurrency: 1 +trialConcurrency: 2 maxExecDuration: 1h maxTrialNum: 10 trainingServicePlatform: heterogeneous @@ -30,14 +30,11 @@ tuner: trial: command: python3 mnist.py codeDir: . - image: msranni/nni gpuNum: 1 heterogeneousConfig: trainingServicePlatforms: - local - remote -localConfig: - reuse: true remoteConfig: reuse: true machineList: diff --git a/examples/trials/mnist-tfv1/config_heterogeneous.yml b/examples/trials/mnist-tfv1/config_heterogeneous.yml new file mode 100644 index 0000000000..6d04896148 --- /dev/null +++ b/examples/trials/mnist-tfv1/config_heterogeneous.yml @@ -0,0 +1,32 @@ +authorName: default +experimentName: example_mnist +trialConcurrency: 3 +maxExecDuration: 1h +maxTrialNum: 10 +trainingServicePlatform: heterogeneous +searchSpacePath: search_space.json +#choice: true, false +useAnnotation: false +tuner: + #choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner, GPTuner + #SMAC (SMAC should be installed through nnictl) + builtinTunerName: TPE + classArgs: + #choice: maximize, minimize + optimize_mode: maximize +trial: + command: python3 mnist.py + codeDir: . + gpuNum: 0 +heterogeneousConfig: + trainingServicePlatforms: + - local + - remote +remoteConfig: + reuse: true +machineList: + - ip: 10.1.1.1 + username: bob + passwd: bob123 + #port can be skip if using default ssh port 22 + #port: 22 \ No newline at end of file From a2392e8e80d99305d8a7c8de6049715e39a01194 Mon Sep 17 00:00:00 2001 From: Shinai Yang Date: Thu, 3 Dec 2020 16:35:19 +0800 Subject: [PATCH 13/24] fix eslint --- .../training_service/reusable/environment.ts | 3 -- .../environments/localEnvironmentService.ts | 14 +--------- .../reusable/routerTrainingService.ts | 6 ---- .../reusable/trialDispatcher.ts | 28 +++++++++---------- 4 files changed, 14 insertions(+), 37 deletions(-) diff --git a/ts/nni_manager/training_service/reusable/environment.ts b/ts/nni_manager/training_service/reusable/environment.ts index 032a896f69..e76f0ef09a 100644 --- a/ts/nni_manager/training_service/reusable/environment.ts +++ b/ts/nni_manager/training_service/reusable/environment.ts @@ -3,12 +3,9 @@ 'use strict'; -import { EventEmitter } from "events"; -import { runScript } from "training_service/common/util"; import { getLogger, Logger } from "../../common/log"; import { TrialJobStatus } from "../../common/trainingService"; import { GPUInfo } from "../../training_service/common/gpuData"; -import { WebCommandChannel } from "./channels/webCommandChannel"; import { CommandChannel } from "./commandChannel"; diff --git a/ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts b/ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts index 40eabe229d..e5f40709ce 100644 --- a/ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts +++ b/ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts @@ -3,29 +3,17 @@ 'use strict'; -import * as cpp from 'child-process-promise'; -import { EventEmitter } from "events"; -import * as cp from 'child_process'; import * as fs from 'fs'; import * as path from 'path'; -import * as yaml from 'js-yaml'; -import * as request from 'request'; -import { Deferred } from 'ts-deferred'; import * as tkill from 'tree-kill'; import * as component from '../../../common/component'; import { getExperimentId } from '../../../common/experimentStartupInfo'; import { getLogger, Logger } from '../../../common/log'; import { TrialConfigMetadataKey } from '../../common/trialConfigMetadataKey'; -import { PAIClusterConfig } from '../../pai/paiConfig'; -import { NNIPAIK8STrialConfig } from '../../pai/paiK8S/paiK8SConfig'; import { EnvironmentInformation, EnvironmentService } from '../environment'; -import { StorageService } from '../storageService'; import { TrialConfig } from '../../common/trialConfig'; import { getExperimentRootDir, isAlive } from '../../../common/utils'; -import { execMkdir, validateCodeDir, runScript, execCopydir } from '../../common/util'; -import { FileCommandChannel } from '../channels/fileCommandChannel'; -import { CommandChannel } from "../commandChannel"; - +import { execMkdir, runScript, execCopydir } from '../../common/util'; @component.Singleton export class LocalEnvironmentService extends EnvironmentService { diff --git a/ts/nni_manager/training_service/reusable/routerTrainingService.ts b/ts/nni_manager/training_service/reusable/routerTrainingService.ts index 546ed7f041..d5c747d9c5 100644 --- a/ts/nni_manager/training_service/reusable/routerTrainingService.ts +++ b/ts/nni_manager/training_service/reusable/routerTrainingService.ts @@ -13,17 +13,11 @@ import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey'; import { PAIClusterConfig } from '../pai/paiConfig'; import { PAIK8STrainingService } from '../pai/paiK8S/paiK8STrainingService'; import { RemoteMachineTrainingService } from '../remote_machine/remoteMachineTrainingService'; -import { EnvironmentService } from './environment'; -import { OpenPaiEnvironmentService } from './environments/openPaiEnvironmentService'; -import { LocalEnvironmentService } from './environments/localEnvironmentService'; -import { AMLEnvironmentService } from './environments/amlEnvironmentService'; -import { RemoteEnvironmentService } from './environments/remoteEnvironmentService'; import { MountedStorageService } from './storages/mountedStorageService'; import { StorageService } from './storageService'; import { TrialDispatcher } from './trialDispatcher'; import { RemoteConfig } from './remote/remoteConfig'; import { HeterogenousConfig } from './heterogenous/heterogenousConfig'; -import { LocalConfig, LocalTrainingService } from '../local/localTrainingService'; /** diff --git a/ts/nni_manager/training_service/reusable/trialDispatcher.ts b/ts/nni_manager/training_service/reusable/trialDispatcher.ts index da1c999340..865ae10088 100644 --- a/ts/nni_manager/training_service/reusable/trialDispatcher.ts +++ b/ts/nni_manager/training_service/reusable/trialDispatcher.ts @@ -10,7 +10,7 @@ import { Writable } from 'stream'; import { String } from 'typescript-string-operations'; import * as component from '../../common/component'; import { NNIError, NNIErrorNames, MethodNotImplementedError } from '../../common/errors'; -import { getBasePort, getExperimentId, getPlatform } from '../../common/experimentStartupInfo'; +import { getBasePort, getExperimentId } from '../../common/experimentStartupInfo'; import { getLogger, Logger } from '../../common/log'; import { NNIManagerIpConfig, TrainingService, TrialJobApplicationForm, TrialJobMetric, TrialJobStatus, LogType } from '../../common/trainingService'; import { delay, getExperimentRootDir, getIPV4Address, getLogLevel, getVersion, mkDirPSync, randomSelect, uniqueString } from '../../common/utils'; @@ -28,7 +28,6 @@ import { StorageService } from './storageService'; import { TrialDetail } from './trial'; import { AMLEnvironmentService } from './environments/amlEnvironmentService'; import { OpenPaiEnvironmentService } from './environments/openPaiEnvironmentService'; -import { LocalTrainingService } from 'training_service/local/localTrainingService'; import { LocalEnvironmentService } from './environments/localEnvironmentService'; import { RemoteEnvironmentService } from './environments/remoteEnvironmentService'; import { AMLCommandChannel } from './channels/amlCommandChannel'; @@ -133,8 +132,7 @@ class TrialDispatcher implements TrainingService { const trialId: string = uniqueString(5); - let trialWorkingFolder: string = ""; - const trialJobDetail: TrialDetail = new TrialDetail(trialId, "WAITING", Date.now(), trialWorkingFolder, form); + const trialJobDetail: TrialDetail = new TrialDetail(trialId, "WAITING", Date.now(), "", form); this.trials.set(trialId, trialJobDetail); @@ -195,7 +193,7 @@ class TrialDispatcher implements TrainingService { if (this.trialConfig === undefined) { throw new Error(`trial config shouldn't be undefined in run()`); } - for(let environmentService of this.environmentServiceList) { + for(const environmentService of this.environmentServiceList) { const runnerSettings: RunnerSettings = new RunnerSettings(); runnerSettings.nniManagerIP = this.nniManagerIp; @@ -256,7 +254,7 @@ class TrialDispatcher implements TrainingService { await this.prefetchEnvironments(); this.log.info(`TrialDispatcher: run loop started.`); const promiseList: Promise[] = []; - for(let commandChannel of this.commandChannelDict.values()) { + for(const commandChannel of this.commandChannelDict.values()) { promiseList.push(commandChannel.run()); } promiseList.push(this.environmentMaintenanceLoop()); @@ -302,9 +300,9 @@ class TrialDispatcher implements TrainingService { // Validate to make sure codeDir doesn't have too many files await validateCodeDir(this.trialConfig.codeDir); break; - case TrialConfigMetadataKey.PLATFORM_LIST: + case TrialConfigMetadataKey.PLATFORM_LIST: { const platforms: string[] = value.split(","); - for(let platform of platforms) { + for(const platform of platforms) { let environmentService: EnvironmentService; switch(platform) { case 'local': @@ -339,9 +337,9 @@ class TrialDispatcher implements TrainingService { } this.environmentServiceList.push(environmentService); } - + } } - for(let environmentService of this.environmentServiceList) { + for(const environmentService of this.environmentServiceList) { await environmentService.config(key, value); } } @@ -376,7 +374,7 @@ class TrialDispatcher implements TrainingService { } this.commandEmitter.off("command", this.handleCommand); - for(let commandChannel of this.commandChannelDict.values()) { + for(const commandChannel of this.commandChannelDict.values()) { commandChannel.stop(); } } @@ -400,7 +398,7 @@ class TrialDispatcher implements TrainingService { } } - for (let environment of environments) { + for (const environment of environments) { if (environment.environmentService === undefined) { throw new Error(`${environment.id} do not have environment service!`); } @@ -636,7 +634,7 @@ class TrialDispatcher implements TrainingService { let requestedCount = 0; let hasMoreEnvironments = false; for (let index = 0; index < neededEnvironmentCount; index++) { - let environmentService: EnvironmentService | undefined = this.selectEnvironmentService(); + const environmentService: EnvironmentService | undefined = this.selectEnvironmentService(); if (environmentService !== undefined) { hasMoreEnvironments = true; await this.requestEnvironment(environmentService); @@ -662,7 +660,7 @@ class TrialDispatcher implements TrainingService { // Schedule a environment platform for environment private selectEnvironmentService(): EnvironmentService | undefined { const validEnvironmentServiceList = []; - for(let environmentService of this.environmentServiceList){ + for(const environmentService of this.environmentServiceList){ if (environmentService.hasMoreEnvironments) { validEnvironmentServiceList.push(environmentService); } @@ -675,7 +673,7 @@ class TrialDispatcher implements TrainingService { } private async prefetchEnvironments (): Promise { - for (let environmentService of this.environmentServiceList) { + for (const environmentService of this.environmentServiceList) { const number = environmentService.prefetchedEnvironmentCount; this.log.info(`Initialize environments total number: ${number}`); for (let index = 0; index < number; index++) { From 401378d857036b396624eda7f87d5079993da4b3 Mon Sep 17 00:00:00 2001 From: Shinai Yang Date: Fri, 4 Dec 2020 12:23:19 +0800 Subject: [PATCH 14/24] fix ut --- docs/en_US/training_services.rst | 1 + .../reusable/channels/webCommandChannel.ts | 5 ++++- .../training_service/reusable/environment.ts | 4 +--- .../environments/amlEnvironmentService.ts | 8 +++---- .../reusable/test/trialDispatcher.test.ts | 4 +++- .../reusable/test/utEnvironmentService.ts | 22 +++++-------------- .../reusable/trialDispatcher.ts | 5 ++--- 7 files changed, 19 insertions(+), 30 deletions(-) diff --git a/docs/en_US/training_services.rst b/docs/en_US/training_services.rst index da286b9d50..bb3081fd07 100644 --- a/docs/en_US/training_services.rst +++ b/docs/en_US/training_services.rst @@ -12,3 +12,4 @@ Introduction to NNI Training Services FrameworkController<./TrainingService/FrameworkControllerMode> DLTS<./TrainingService/DLTSMode> AML<./TrainingService/AMLMode> + Heterogeneous<./TrainingService/HeterogeneousMode> diff --git a/ts/nni_manager/training_service/reusable/channels/webCommandChannel.ts b/ts/nni_manager/training_service/reusable/channels/webCommandChannel.ts index 0a30b853ec..3bd9c504aa 100644 --- a/ts/nni_manager/training_service/reusable/channels/webCommandChannel.ts +++ b/ts/nni_manager/training_service/reusable/channels/webCommandChannel.ts @@ -9,7 +9,7 @@ import { INITIALIZED } from '../../../core/commands'; import { CommandChannel, RunnerConnection } from "../commandChannel"; import { Channel, EnvironmentInformation } from "../environment"; -export class WebRunnerConnection extends RunnerConnection { +class WebRunnerConnection extends RunnerConnection { public readonly clients: WebSocket[] = []; public async close(): Promise { @@ -50,6 +50,7 @@ export class WebCommandChannel extends CommandChannel { client.onerror = (event): void => { this.log.error(`error on client ${JSON.stringify(event)}`); } + this.clients.set(client, undefined); client.onmessage = (message): void => { this.receivedWebSocketMessage(client, message); @@ -90,11 +91,13 @@ export class WebCommandChannel extends CommandChannel { private receivedWebSocketMessage(client: WebSocket, message: MessageEvent): void { let connection = this.clients.get(client) as WebRunnerConnection | undefined; const rawCommands = message.data.toString(); + if (connection === undefined) { // undefined means it's expecting initializing message. const commands = this.parseCommands(rawCommands); let isValid = false; this.log.debug(`WebCommandChannel: received initialize message: ${JSON.stringify(rawCommands)}`); + if (commands.length > 0) { const commandType = commands[0][0]; const result = commands[0][1]; diff --git a/ts/nni_manager/training_service/reusable/environment.ts b/ts/nni_manager/training_service/reusable/environment.ts index e76f0ef09a..2f29fdf164 100644 --- a/ts/nni_manager/training_service/reusable/environment.ts +++ b/ts/nni_manager/training_service/reusable/environment.ts @@ -6,7 +6,6 @@ import { getLogger, Logger } from "../../common/log"; import { TrialJobStatus } from "../../common/trainingService"; import { GPUInfo } from "../../training_service/common/gpuData"; -import { CommandChannel } from "./commandChannel"; export type EnvironmentStatus = 'UNKNOWN' | 'WAITING' | 'RUNNING' | 'SUCCEEDED' | 'FAILED' | 'USER_CANCELED'; @@ -121,8 +120,7 @@ export class EnvironmentInformation { } export abstract class EnvironmentService { - - protected commandChannel: CommandChannel | undefined; + public abstract get hasStorageService(): boolean; public abstract config(key: string, value: string): Promise; public abstract refreshEnvironmentStatus(environment: EnvironmentInformation): Promise; diff --git a/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts b/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts index 4098b28dba..a242f18cb2 100644 --- a/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts +++ b/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts @@ -3,7 +3,6 @@ 'use strict'; -import { EventEmitter } from "events"; import * as fs from 'fs'; import * as path from 'path'; import * as component from '../../../common/component'; @@ -14,8 +13,7 @@ import { TrialConfigMetadataKey } from '../../common/trialConfigMetadataKey'; import { validateCodeDir } from '../../common/util'; import { AMLClient } from '../aml/amlClient'; import { AMLClusterConfig, AMLEnvironmentInformation, AMLTrialConfig } from '../aml/amlConfig'; -import { AMLCommandChannel } from '../channels/amlCommandChannel'; -import { CommandChannel } from "../commandChannel"; +import { Channel } from '../environment'; import { EnvironmentInformation, EnvironmentService } from '../environment'; @@ -41,8 +39,8 @@ export class AMLEnvironmentService extends EnvironmentService { return false; } - public createCommandChannel(commandEmitter: EventEmitter): CommandChannel { - return new AMLCommandChannel(commandEmitter); + public get getCommandChanneName(): Channel { + return 'aml'; } public createEnvironmentInformation(envId: string, envName: string): EnvironmentInformation { diff --git a/ts/nni_manager/training_service/reusable/test/trialDispatcher.test.ts b/ts/nni_manager/training_service/reusable/test/trialDispatcher.test.ts index 958738de54..d3570428b8 100644 --- a/ts/nni_manager/training_service/reusable/test/trialDispatcher.test.ts +++ b/ts/nni_manager/training_service/reusable/test/trialDispatcher.test.ts @@ -4,6 +4,7 @@ import * as chai from 'chai'; import * as path from 'path'; import { Scope } from "typescript-ioc"; +import { EventEmitter } from 'events'; import * as component from '../../../common/component'; import { getLogger, Logger } from "../../../common/log"; import { TrialJobApplicationForm, TrialJobStatus } from '../../../common/trainingService'; @@ -210,7 +211,8 @@ describe('Unit Test for TrialDispatcher', () => { trialRunPromise = trialDispatcher.run(); environmentService = component.get(EnvironmentService) as UtEnvironmentService; - commandChannel = environmentService.testGetCommandChannel(); + + commandChannel = new UtCommandChannel(new EventEmitter()); }); afterEach(async () => { diff --git a/ts/nni_manager/training_service/reusable/test/utEnvironmentService.ts b/ts/nni_manager/training_service/reusable/test/utEnvironmentService.ts index f4983c19ab..8364e849fd 100644 --- a/ts/nni_manager/training_service/reusable/test/utEnvironmentService.ts +++ b/ts/nni_manager/training_service/reusable/test/utEnvironmentService.ts @@ -1,13 +1,9 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -import { EnvironmentInformation, EnvironmentService, EnvironmentStatus } from "../environment"; -import { EventEmitter } from "events"; -import { CommandChannel } from "../commandChannel"; -import { UtCommandChannel } from "./utCommandChannel"; +import { Channel, EnvironmentInformation, EnvironmentService, EnvironmentStatus } from "../environment"; export class UtEnvironmentService extends EnvironmentService { - private utCommandChannel: UtCommandChannel | undefined; private allEnvironments = new Map(); private hasMoreEnvironmentsInternal = true; @@ -27,6 +23,10 @@ export class UtEnvironmentService extends EnvironmentService { return 'ut'; } + public get getCommandChanneName(): Channel { + return 'ut'; + } + public testSetEnvironmentStatus(environment: EnvironmentInformation, newStatus: EnvironmentStatus): void { environment.status = newStatus; } @@ -39,13 +39,6 @@ export class UtEnvironmentService extends EnvironmentService { return this.allEnvironments; } - public testGetCommandChannel(): UtCommandChannel { - if (this.utCommandChannel === undefined) { - throw new Error(`command channel shouldn't be undefined.`); - } - return this.utCommandChannel; - } - public testSetNoMoreEnvironment(hasMore: boolean): void { this.hasMoreEnvironmentsInternal = hasMore; } @@ -54,11 +47,6 @@ export class UtEnvironmentService extends EnvironmentService { return this.hasMoreEnvironmentsInternal; } - public createCommandChannel(commandEmitter: EventEmitter): CommandChannel { - this.utCommandChannel = new UtCommandChannel(commandEmitter) - return this.utCommandChannel; - } - public async config(_key: string, _value: string): Promise { // do nothing } diff --git a/ts/nni_manager/training_service/reusable/trialDispatcher.ts b/ts/nni_manager/training_service/reusable/trialDispatcher.ts index 865ae10088..2bd65606c3 100644 --- a/ts/nni_manager/training_service/reusable/trialDispatcher.ts +++ b/ts/nni_manager/training_service/reusable/trialDispatcher.ts @@ -59,7 +59,7 @@ class TrialDispatcher implements TrainingService { private environmentServiceList: EnvironmentService[] = []; private commandChannelDict: Map; private commandEmitter: EventEmitter; - private nniManagerIp: string; + private nniManagerIp: string | undefined; // uses to accelerate trial manager loop // true means there is updates, and trial loop should run a cycle immediately. @@ -87,7 +87,6 @@ class TrialDispatcher implements TrainingService { this.metricsEmitter = new EventEmitter(); this.experimentId = getExperimentId(); this.experimentRootDir = getExperimentRootDir(); - this.nniManagerIp = getIPV4Address(); const logLevel = getLogLevel(); this.log.debug(`current folder ${__dirname}`); @@ -196,7 +195,7 @@ class TrialDispatcher implements TrainingService { for(const environmentService of this.environmentServiceList) { const runnerSettings: RunnerSettings = new RunnerSettings(); - runnerSettings.nniManagerIP = this.nniManagerIp; + runnerSettings.nniManagerIP = this.nniManagerIp === undefined? getIPV4Address() : this.nniManagerIp; runnerSettings.nniManagerPort = getBasePort() + 1; runnerSettings.commandChannel = environmentService.getCommandChanneName; runnerSettings.enableGpuCollector = this.enableGpuScheduler; From 5bd5c38ea78bd0869318532bf8823a3add8ae9ba Mon Sep 17 00:00:00 2001 From: Shinai Yang Date: Fri, 4 Dec 2020 21:30:26 +0800 Subject: [PATCH 15/24] fix ut --- .../reusable/test/trialDispatcher.test.ts | 29 +++++++++---------- .../reusable/trialDispatcher.ts | 7 +++-- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/ts/nni_manager/training_service/reusable/test/trialDispatcher.test.ts b/ts/nni_manager/training_service/reusable/test/trialDispatcher.test.ts index d3570428b8..ff5cbbcd46 100644 --- a/ts/nni_manager/training_service/reusable/test/trialDispatcher.test.ts +++ b/ts/nni_manager/training_service/reusable/test/trialDispatcher.test.ts @@ -11,8 +11,8 @@ import { TrialJobApplicationForm, TrialJobStatus } from '../../../common/trainin import { cleanupUnitTest, delay, prepareUnitTest, uniqueString } from '../../../common/utils'; import { INITIALIZED, KILL_TRIAL_JOB, NEW_TRIAL_JOB, SEND_TRIAL_JOB_PARAMETER, TRIAL_END, GPU_INFO } from '../../../core/commands'; import { TrialConfigMetadataKey } from '../../../training_service/common/trialConfigMetadataKey'; -import { Command } from '../commandChannel'; -import { EnvironmentInformation, EnvironmentService } from "../environment"; +import { Command, CommandChannel } from '../commandChannel'; +import { Channel, EnvironmentInformation, EnvironmentService } from "../environment"; import { TrialDetail } from '../trial'; import { TrialDispatcher } from "../trialDispatcher"; import { UtCommandChannel } from './utCommandChannel'; @@ -55,7 +55,7 @@ async function waitResult(callback: () => Promise, return undefined; } -async function waitResultMust(callback: () => Promise, waitMs: number = 1000, interval: number = 1): Promise { +async function waitResultMust(callback: () => Promise, waitMs: number = 10000, interval: number = 1): Promise { const result = await waitResult(callback, waitMs, interval, true); // this error should be thrown in waitResult already. if (result === undefined) { @@ -202,17 +202,21 @@ describe('Unit Test for TrialDispatcher', () => { nniManagerIp: "127.0.0.1", } trialDispatcher = new TrialDispatcher(); - component.Container.bind(EnvironmentService) - .to(UtEnvironmentService) - .scope(Scope.Singleton); await trialDispatcher.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, JSON.stringify(trialConfig)); await trialDispatcher.setClusterMetadata(TrialConfigMetadataKey.NNI_MANAGER_IP, JSON.stringify(nniManagerIpConfig)); - trialRunPromise = trialDispatcher.run(); - - environmentService = component.get(EnvironmentService) as UtEnvironmentService; + // set ut environment + let environmentServiceList: EnvironmentService[] = []; + environmentService = new UtEnvironmentService(); + environmentServiceList.push(environmentService); + trialDispatcher.environmentServiceList = environmentServiceList; + // set ut command channel + commandChannel = new UtCommandChannel(trialDispatcher.commandEmitter); + let commandChannelDict: Map = new Map(); + commandChannelDict.set('ut', commandChannel); + trialDispatcher.commandChannelDict = commandChannelDict; - commandChannel = new UtCommandChannel(new EventEmitter()); + trialRunPromise = trialDispatcher.run(); }); afterEach(async () => { @@ -260,9 +264,6 @@ describe('Unit Test for TrialDispatcher', () => { await waitEnvironment(2, previousEnvironments, environmentService, commandChannel); await verifyTrialRunning(commandChannel, trialDetail); await verifyTrialResult(commandChannel, trialDetail, -1); - await waitResultMust(async () => { - return environment.status === 'USER_CANCELED' ? true : undefined; - }); chai.assert.equal(environmentService.testGetEnvironments().size, 2, "as env not reused, so only 2 envs should be here."); const trials = await trialDispatcher.listTrialJobs(); @@ -435,12 +436,10 @@ describe('Unit Test for TrialDispatcher', () => { let environment = await waitEnvironment(1, previousEnvironments, environmentService, commandChannel); await verifyTrialRunning(commandChannel, trialDetail); await verifyTrialResult(commandChannel, trialDetail, 0); - environmentService.testSetEnvironmentStatus(environment, 'SUCCEEDED'); await waitResultMust(async () => { return environment.status === 'SUCCEEDED' ? true : undefined; }); - trialDetail = await newTrial(trialDispatcher); await waitEnvironment(2, previousEnvironments, environmentService, commandChannel); await verifyTrialRunning(commandChannel, trialDetail); diff --git a/ts/nni_manager/training_service/reusable/trialDispatcher.ts b/ts/nni_manager/training_service/reusable/trialDispatcher.ts index 2bd65606c3..cf79133f48 100644 --- a/ts/nni_manager/training_service/reusable/trialDispatcher.ts +++ b/ts/nni_manager/training_service/reusable/trialDispatcher.ts @@ -55,10 +55,11 @@ class TrialDispatcher implements TrainingService { private readonly trials: Map; private readonly environments: Map; + // make public for ut + public environmentServiceList: EnvironmentService[] = []; + public commandChannelDict: Map; + public commandEmitter: EventEmitter; - private environmentServiceList: EnvironmentService[] = []; - private commandChannelDict: Map; - private commandEmitter: EventEmitter; private nniManagerIp: string | undefined; // uses to accelerate trial manager loop From ab34de67ed85f3effeb3b399942af35f4d019fa5 Mon Sep 17 00:00:00 2001 From: Shinai Yang Date: Tue, 8 Dec 2020 20:25:16 +0800 Subject: [PATCH 16/24] fix comments --- .../training_service/reusable/environment.ts | 4 +- .../environments/amlEnvironmentService.ts | 2 +- .../environments/localEnvironmentService.ts | 2 +- .../environments/openPaiEnvironmentService.ts | 2 +- .../environments/remoteEnvironmentService.ts | 2 +- .../reusable/trialDispatcher.ts | 44 +++++++++---------- 6 files changed, 28 insertions(+), 28 deletions(-) diff --git a/ts/nni_manager/training_service/reusable/environment.ts b/ts/nni_manager/training_service/reusable/environment.ts index 2f29fdf164..3eb74643dd 100644 --- a/ts/nni_manager/training_service/reusable/environment.ts +++ b/ts/nni_manager/training_service/reusable/environment.ts @@ -133,9 +133,9 @@ export abstract class EnvironmentService { return 0; } - public abstract get getPlatform(): string; + public abstract get getName(): string; - public get getCommandChanneName(): Channel { + public get getCommandChannelName(): Channel { return 'web'; } diff --git a/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts b/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts index a242f18cb2..99e2c9f74c 100644 --- a/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts +++ b/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts @@ -47,7 +47,7 @@ export class AMLEnvironmentService extends EnvironmentService { return new AMLEnvironmentInformation(envId, envName); } - public get getPlatform(): string { + public get getName(): string { return 'aml'; } diff --git a/ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts b/ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts index e5f40709ce..572f052df2 100644 --- a/ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts +++ b/ts/nni_manager/training_service/reusable/environments/localEnvironmentService.ts @@ -37,7 +37,7 @@ export class LocalEnvironmentService extends EnvironmentService { return false; } - public get getPlatform(): string { + public get getName(): string { return 'local'; } diff --git a/ts/nni_manager/training_service/reusable/environments/openPaiEnvironmentService.ts b/ts/nni_manager/training_service/reusable/environments/openPaiEnvironmentService.ts index b4b071a089..bdb1903ddd 100644 --- a/ts/nni_manager/training_service/reusable/environments/openPaiEnvironmentService.ts +++ b/ts/nni_manager/training_service/reusable/environments/openPaiEnvironmentService.ts @@ -45,7 +45,7 @@ export class OpenPaiEnvironmentService extends EnvironmentService { return true; } - public get getPlatform(): string { + public get getName(): string { return 'pai'; } diff --git a/ts/nni_manager/training_service/reusable/environments/remoteEnvironmentService.ts b/ts/nni_manager/training_service/reusable/environments/remoteEnvironmentService.ts index d898373fee..67ee6ff0b4 100644 --- a/ts/nni_manager/training_service/reusable/environments/remoteEnvironmentService.ts +++ b/ts/nni_manager/training_service/reusable/environments/remoteEnvironmentService.ts @@ -63,7 +63,7 @@ export class RemoteEnvironmentService extends EnvironmentService { return false; } - public get getPlatform(): string { + public get getName(): string { return 'remote'; } diff --git a/ts/nni_manager/training_service/reusable/trialDispatcher.ts b/ts/nni_manager/training_service/reusable/trialDispatcher.ts index cf79133f48..f9f1c37722 100644 --- a/ts/nni_manager/training_service/reusable/trialDispatcher.ts +++ b/ts/nni_manager/training_service/reusable/trialDispatcher.ts @@ -155,9 +155,9 @@ class TrialDispatcher implements TrainingService { "parameters": form.hyperParameters, } const commandChannel: CommandChannel | undefined = - this.commandChannelDict.get(environment.environmentService.getCommandChanneName); + this.commandChannelDict.get(environment.environmentService.getCommandChannelName); if (commandChannel === undefined) { - throw new Error(`${environment.environmentService.getCommandChanneName} command channel not initialized!`); + throw new Error(`${environment.environmentService.getCommandChannelName} command channel not initialized!`); } await commandChannel.sendCommand(environment, SEND_TRIAL_JOB_PARAMETER, message); @@ -174,9 +174,9 @@ class TrialDispatcher implements TrainingService { const environment = trial.environment; if (environment && environment.environmentService) { const commandChannel: CommandChannel | undefined = - this.commandChannelDict.get(environment.environmentService.getCommandChanneName); + this.commandChannelDict.get(environment.environmentService.getCommandChannelName); if (commandChannel === undefined) { - throw new Error(`${environment.environmentService.getCommandChanneName} command channel not initialized!`); + throw new Error(`${environment.environmentService.getCommandChannelName} command channel not initialized!`); } await commandChannel.sendCommand(environment, KILL_TRIAL_JOB, trial.id); trial.isEarlyStopped = isEarlyStopped; @@ -198,18 +198,18 @@ class TrialDispatcher implements TrainingService { const runnerSettings: RunnerSettings = new RunnerSettings(); runnerSettings.nniManagerIP = this.nniManagerIp === undefined? getIPV4Address() : this.nniManagerIp; runnerSettings.nniManagerPort = getBasePort() + 1; - runnerSettings.commandChannel = environmentService.getCommandChanneName; + runnerSettings.commandChannel = environmentService.getCommandChannelName; runnerSettings.enableGpuCollector = this.enableGpuScheduler; runnerSettings.command = this.trialConfig.command; runnerSettings.nniManagerVersion = this.enableVersionCheck ? await getVersion() : ''; runnerSettings.logCollection = this.logCollection; - runnerSettings.platform = environmentService.getPlatform; + runnerSettings.platform = environmentService.getName; runnerSettings.experimentId = this.experimentId; const commandChannel: CommandChannel | undefined = - this.commandChannelDict.get(environmentService.getCommandChanneName); + this.commandChannelDict.get(environmentService.getCommandChannelName); if (commandChannel === undefined) { - throw new Error(`${environmentService.getCommandChanneName} command channel not initialized!`); + throw new Error(`${environmentService.getCommandChannelName} command channel not initialized!`); } await commandChannel.start(); this.log.info(`TrialDispatcher: started channel: ${commandChannel.constructor.name}`); @@ -320,8 +320,8 @@ class TrialDispatcher implements TrainingService { default: throw new Error(`${platform} not supported!`); } - if (!this.commandChannelDict.has(environmentService.getCommandChanneName)) { - switch(environmentService.getCommandChanneName) { + if (!this.commandChannelDict.has(environmentService.getCommandChannelName)) { + switch(environmentService.getCommandChannelName) { case 'aml': this.commandChannelDict.set('aml', new AMLCommandChannel(this.commandEmitter)); break; @@ -332,7 +332,7 @@ class TrialDispatcher implements TrainingService { this.commandChannelDict.set('file', new FileCommandChannel(this.commandEmitter)); break; default: - throw new Error(`Unsupported channel ${environmentService.getCommandChanneName}`); + throw new Error(`Unsupported channel ${environmentService.getCommandChannelName}`); } } this.environmentServiceList.push(environmentService); @@ -365,9 +365,9 @@ class TrialDispatcher implements TrainingService { } await environment.environmentService.stopEnvironment(environment); const commandChannel: CommandChannel | undefined = - this.commandChannelDict.get(environment.environmentService.getCommandChanneName); + this.commandChannelDict.get(environment.environmentService.getCommandChannelName); if (commandChannel === undefined) { - throw new Error(`${environment.environmentService.getCommandChanneName} command channel not initialized!`); + throw new Error(`${environment.environmentService.getCommandChannelName} command channel not initialized!`); } this.log.info(`stopped environment ${environment.id}.`); } @@ -390,9 +390,9 @@ class TrialDispatcher implements TrainingService { throw new Error(`${environment.id} do not have environment service!`); } const commandChannel: CommandChannel | undefined = - this.commandChannelDict.get(environment.environmentService.getCommandChanneName); + this.commandChannelDict.get(environment.environmentService.getCommandChannelName); if (commandChannel === undefined) { - throw new Error(`${environment.environmentService.getCommandChanneName} command channel not initialized!`); + throw new Error(`${environment.environmentService.getCommandChannelName} command channel not initialized!`); } await commandChannel.close(environment); } @@ -483,9 +483,9 @@ class TrialDispatcher implements TrainingService { if (environment.nodeCount > completedCount) { this.log.info(`stop partial completed trial ${trial.id}`); const commandChannel: CommandChannel | undefined = - this.commandChannelDict.get(environment.environmentService.getCommandChanneName); + this.commandChannelDict.get(environment.environmentService.getCommandChannelName); if (commandChannel === undefined) { - throw new Error(`${environment.environmentService.getCommandChanneName} command channel not initialized!`); + throw new Error(`${environment.environmentService.getCommandChannelName} command channel not initialized!`); } await commandChannel.sendCommand(environment, KILL_TRIAL_JOB, trial.id); } @@ -687,7 +687,7 @@ class TrialDispatcher implements TrainingService { const envName = `nni_exp_${this.experimentId}_env_${envId}`; const environment = environmentService.createEnvironmentInformation(envId, envName); environment.environmentService = environmentService; - this.log.info(`Assign environment service ${environmentService.getPlatform} to environment ${envId}`); + this.log.info(`Assign environment service ${environmentService.getName} to environment ${envId}`); environment.command = `sh ../install_nni.sh && python3 -m nni.tools.trial_tool.trial_runner`; if (this.isDeveloping) { @@ -706,9 +706,9 @@ class TrialDispatcher implements TrainingService { environment.isAlive = true; } const commandChannel: CommandChannel | undefined = - this.commandChannelDict.get(environment.environmentService.getCommandChanneName); + this.commandChannelDict.get(environment.environmentService.getCommandChannelName); if (commandChannel === undefined) { - throw new Error(`${environment.environmentService.getCommandChanneName} command channel not initialized!`); + throw new Error(`${environment.environmentService.getCommandChannelName} command channel not initialized!`); } await commandChannel.open(environment); this.log.info(`requested environment ${environment.id} and job id is ${environment.envId}.`); @@ -755,9 +755,9 @@ class TrialDispatcher implements TrainingService { throw new Error(`${environment.id} does not have environment service!`); } const commandChannel: CommandChannel | undefined = - this.commandChannelDict.get(environment.environmentService.getCommandChanneName); + this.commandChannelDict.get(environment.environmentService.getCommandChannelName); if (commandChannel === undefined) { - throw new Error(`${environment.environmentService.getCommandChanneName} command channel not initialized!`); + throw new Error(`${environment.environmentService.getCommandChannelName} command channel not initialized!`); } await commandChannel.sendCommand(trial.environment, NEW_TRIAL_JOB, trial.settings); } From e03f063c131f3e13ed70f4f508513e11d4fe54ac Mon Sep 17 00:00:00 2001 From: Shinai Yang Date: Fri, 11 Dec 2020 11:36:30 +0800 Subject: [PATCH 17/24] fix tslint --- ts/nni_manager/training_service/reusable/trialDispatcher.ts | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ts/nni_manager/training_service/reusable/trialDispatcher.ts b/ts/nni_manager/training_service/reusable/trialDispatcher.ts index caf83387c6..9dd82aa5b8 100644 --- a/ts/nni_manager/training_service/reusable/trialDispatcher.ts +++ b/ts/nni_manager/training_service/reusable/trialDispatcher.ts @@ -727,8 +727,7 @@ class TrialDispatcher implements TrainingService { } if (environment.environmentService.hasStorageService) { const storageService = component.get(StorageService); - let trialWorkingFolder = storageService.joinPath('trials', trial.id); - trial.workingDirectory = trialWorkingFolder; + trial.workingDirectory = storageService.joinPath('trials', trial.id); } trial.settings = { trialId: trial.id, From 3ce49c02d451bd33eb894ff060e8756ac93b76ae Mon Sep 17 00:00:00 2001 From: Shinai Yang Date: Fri, 11 Dec 2020 11:50:56 +0800 Subject: [PATCH 18/24] fix trial.log cause metrics error in reuse mode --- nni/runtime/log.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nni/runtime/log.py b/nni/runtime/log.py index da96138e09..7a188e8a9d 100644 --- a/nni/runtime/log.py +++ b/nni/runtime/log.py @@ -31,7 +31,7 @@ def init_logger() -> None: if trial_platform == 'unittest': return - if trial_platform: + if trial_platform and not trial_env_vars.REUSE_MODE: _init_logger_trial() return From 17626fbd6c1de8ec5ea9b40d28aa6eb09034c5ee Mon Sep 17 00:00:00 2001 From: Shinai Yang Date: Fri, 11 Dec 2020 15:06:27 +0800 Subject: [PATCH 19/24] add environmentServiceFactory --- .../training_service/reusable/environment.ts | 21 ++++++++++++++++ .../environments/amlEnvironmentService.ts | 6 ++--- .../reusable/trialDispatcher.ts | 24 ++----------------- 3 files changed, 26 insertions(+), 25 deletions(-) diff --git a/ts/nni_manager/training_service/reusable/environment.ts b/ts/nni_manager/training_service/reusable/environment.ts index 098ee9be31..f72cc37e16 100644 --- a/ts/nni_manager/training_service/reusable/environment.ts +++ b/ts/nni_manager/training_service/reusable/environment.ts @@ -9,6 +9,10 @@ import { GPUInfo } from "../../training_service/common/gpuData"; import { CommandChannel } from "./commandChannel"; import { WebCommandChannel } from './channels/webCommandChannel'; import { EventEmitter } from "events"; +import { AMLEnvironmentService } from './environments/amlEnvironmentService'; +import { OpenPaiEnvironmentService } from './environments/openPaiEnvironmentService'; +import { LocalEnvironmentService } from './environments/localEnvironmentService'; +import { RemoteEnvironmentService } from './environments/remoteEnvironmentService'; export type EnvironmentStatus = 'UNKNOWN' | 'WAITING' | 'RUNNING' | 'SUCCEEDED' | 'FAILED' | 'USER_CANCELED'; @@ -170,6 +174,23 @@ export abstract class EnvironmentService { } } +export class EnvironmentServiceFactory { + public static createEnvironmentService(name: string): EnvironmentService { + switch(name) { + case 'local': + return new LocalEnvironmentService(); + case 'remote': + return new RemoteEnvironmentService(); + case 'aml': + return new AMLEnvironmentService(); + case 'pai': + return new OpenPaiEnvironmentService(); + default: + throw new Error(`${name} not supported!`); + } + } +} + export class NodeInformation { public id: string; public status: TrialJobStatus = "UNKNOWN"; diff --git a/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts b/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts index d8b5668950..a9b5ee58fd 100644 --- a/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts +++ b/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts @@ -20,7 +20,7 @@ import { AMLCommandChannel } from '../channels/amlCommandChannel'; /** - * Collector AML jobs info from AML cluster, and update pai job status locally + * Collector AML jobs info from AML cluster, and update aml job status locally */ @component.Singleton export class AMLEnvironmentService extends EnvironmentService { @@ -40,7 +40,7 @@ export class AMLEnvironmentService extends EnvironmentService { public get hasStorageService(): boolean { return false; } - + public initCommandChannel(eventEmitter: EventEmitter): void { this.commandChannel = new AMLCommandChannel(eventEmitter); } @@ -73,7 +73,7 @@ export class AMLEnvironmentService extends EnvironmentService { this.log.debug(`AML not proccessed metadata key: '${key}', value: '${value}'`); } } - + public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise { environments.forEach(async (environment) => { const amlClient = (environment as AMLEnvironmentInformation).amlClient; diff --git a/ts/nni_manager/training_service/reusable/trialDispatcher.ts b/ts/nni_manager/training_service/reusable/trialDispatcher.ts index 9dd82aa5b8..b3ea1612e6 100644 --- a/ts/nni_manager/training_service/reusable/trialDispatcher.ts +++ b/ts/nni_manager/training_service/reusable/trialDispatcher.ts @@ -21,15 +21,11 @@ import { TrialConfig } from '../common/trialConfig'; import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey'; import { validateCodeDir } from '../common/util'; import { Command, CommandChannel } from './commandChannel'; -import { EnvironmentInformation, EnvironmentService, NodeInformation, RunnerSettings, TrialGpuSummary, Channel } from './environment'; +import { EnvironmentInformation, EnvironmentService, NodeInformation, RunnerSettings, TrialGpuSummary, EnvironmentServiceFactory } from './environment'; import { GpuScheduler } from './gpuScheduler'; import { MountedStorageService } from './storages/mountedStorageService'; import { StorageService } from './storageService'; import { TrialDetail } from './trial'; -import { AMLEnvironmentService } from './environments/amlEnvironmentService'; -import { OpenPaiEnvironmentService } from './environments/openPaiEnvironmentService'; -import { LocalEnvironmentService } from './environments/localEnvironmentService'; -import { RemoteEnvironmentService } from './environments/remoteEnvironmentService'; /** @@ -285,23 +281,7 @@ class TrialDispatcher implements TrainingService { case TrialConfigMetadataKey.PLATFORM_LIST: { const platforms: string[] = value.split(","); for(const platform of platforms) { - let environmentService: EnvironmentService; - switch(platform) { - case 'local': - environmentService = new LocalEnvironmentService(); - break; - case 'remote': - environmentService = new RemoteEnvironmentService(); - break; - case 'aml': - environmentService = new AMLEnvironmentService(); - break; - case 'pai': - environmentService = new OpenPaiEnvironmentService(); - break; - default: - throw new Error(`${platform} not supported!`); - } + let environmentService: EnvironmentService = EnvironmentServiceFactory.createEnvironmentService(platform); environmentService.initCommandChannel(this.commandEmitter); this.environmentMaintenceLoopInterval = Math.max(environmentService.environmentMaintenceLoopInterval, this.environmentMaintenceLoopInterval); From 6b880486775f82dc6881bd2e8b402c968b577ec9 Mon Sep 17 00:00:00 2001 From: Shinai Yang Date: Fri, 11 Dec 2020 15:09:43 +0800 Subject: [PATCH 20/24] refactor ut --- ts/nni_manager/training_service/reusable/environment.ts | 2 +- .../training_service/reusable/test/trialDispatcher.test.ts | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ts/nni_manager/training_service/reusable/environment.ts b/ts/nni_manager/training_service/reusable/environment.ts index f72cc37e16..d1bc747df2 100644 --- a/ts/nni_manager/training_service/reusable/environment.ts +++ b/ts/nni_manager/training_service/reusable/environment.ts @@ -134,7 +134,7 @@ export abstract class EnvironmentService { public abstract stopEnvironment(environment: EnvironmentInformation): Promise; public abstract startEnvironment(environment: EnvironmentInformation): Promise; // Make public for ut - public commandChannel: CommandChannel | undefined; + protected commandChannel: CommandChannel | undefined; // It is used to set prefetched environment count, default value is 0 for OpenPAI and AML mode, // in remote mode, this value is set to the length of machine list. diff --git a/ts/nni_manager/training_service/reusable/test/trialDispatcher.test.ts b/ts/nni_manager/training_service/reusable/test/trialDispatcher.test.ts index 1fb721067f..ba835cedb3 100644 --- a/ts/nni_manager/training_service/reusable/test/trialDispatcher.test.ts +++ b/ts/nni_manager/training_service/reusable/test/trialDispatcher.test.ts @@ -208,8 +208,8 @@ describe('Unit Test for TrialDispatcher', () => { environmentServiceList.push(environmentService); trialDispatcher.environmentServiceList = environmentServiceList; // set ut command channel - commandChannel = new UtCommandChannel(trialDispatcher.commandEmitter); - environmentService.commandChannel = commandChannel; + environmentService.initCommandChannel(trialDispatcher.commandEmitter); + commandChannel = environmentService.getCommandChannel as UtCommandChannel; trialDispatcher.commandChannelSet = new Set().add(environmentService.getCommandChannel); trialDispatcher.environmentMaintenceLoopInterval = 1000; From 09c2131677d3b88db925206734a8a07cc0274ec5 Mon Sep 17 00:00:00 2001 From: Shinai Yang Date: Fri, 11 Dec 2020 19:58:03 +0800 Subject: [PATCH 21/24] fix eslint --- .../training_service/reusable/channels/webCommandChannel.ts | 2 +- .../reusable/environments/amlEnvironmentService.ts | 1 - ts/nni_manager/training_service/reusable/trialDispatcher.ts | 6 +++--- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/ts/nni_manager/training_service/reusable/channels/webCommandChannel.ts b/ts/nni_manager/training_service/reusable/channels/webCommandChannel.ts index 773180ad81..f292b5451a 100644 --- a/ts/nni_manager/training_service/reusable/channels/webCommandChannel.ts +++ b/ts/nni_manager/training_service/reusable/channels/webCommandChannel.ts @@ -47,7 +47,7 @@ export class WebCommandChannel extends CommandChannel { super(commandEmitter); } - public static getInstance(commandEmitter: EventEmitter) { + public static getInstance(commandEmitter: EventEmitter): CommandChannel { if (!this.commandChannel) { this.commandChannel = new WebCommandChannel(commandEmitter); } diff --git a/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts b/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts index a9b5ee58fd..6a59b81c0e 100644 --- a/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts +++ b/ts/nni_manager/training_service/reusable/environments/amlEnvironmentService.ts @@ -13,7 +13,6 @@ import { TrialConfigMetadataKey } from '../../common/trialConfigMetadataKey'; import { validateCodeDir } from '../../common/util'; import { AMLClient } from '../aml/amlClient'; import { AMLClusterConfig, AMLEnvironmentInformation, AMLTrialConfig } from '../aml/amlConfig'; -import { Channel } from '../environment'; import { EnvironmentInformation, EnvironmentService } from '../environment'; import { EventEmitter } from "events"; import { AMLCommandChannel } from '../channels/amlCommandChannel'; diff --git a/ts/nni_manager/training_service/reusable/trialDispatcher.ts b/ts/nni_manager/training_service/reusable/trialDispatcher.ts index b3ea1612e6..941b080328 100644 --- a/ts/nni_manager/training_service/reusable/trialDispatcher.ts +++ b/ts/nni_manager/training_service/reusable/trialDispatcher.ts @@ -281,7 +281,7 @@ class TrialDispatcher implements TrainingService { case TrialConfigMetadataKey.PLATFORM_LIST: { const platforms: string[] = value.split(","); for(const platform of platforms) { - let environmentService: EnvironmentService = EnvironmentServiceFactory.createEnvironmentService(platform); + const environmentService: EnvironmentService = EnvironmentServiceFactory.createEnvironmentService(platform); environmentService.initCommandChannel(this.commandEmitter); this.environmentMaintenceLoopInterval = Math.max(environmentService.environmentMaintenceLoopInterval, this.environmentMaintenceLoopInterval); @@ -347,7 +347,7 @@ class TrialDispatcher implements TrainingService { if (!environmentServiceDict.has(environment.environmentService)) { environmentServiceDict.set(environment.environmentService, [environment]); } else { - let environmentsList: EnvironmentInformation[] | undefined = environmentServiceDict.get(environment.environmentService); + const environmentsList: EnvironmentInformation[] | undefined = environmentServiceDict.get(environment.environmentService); if (environmentsList === undefined) { throw new Error(`Environment list not initialized!`); } @@ -356,7 +356,7 @@ class TrialDispatcher implements TrainingService { } } // Refresh all environments - let taskList: Promise[] = []; + const taskList: Promise[] = []; for (const environmentService of environmentServiceDict.keys()) { const environmentsList: EnvironmentInformation[] | undefined = environmentServiceDict.get(environmentService); if (environmentsList) { From f383650e43f23aa30d120b38845fab0ff06dd71c Mon Sep 17 00:00:00 2001 From: Shinai Yang Date: Mon, 14 Dec 2020 10:34:40 +0800 Subject: [PATCH 22/24] fix sphinx --- .../{HeterogeneousMode.md => HeterogeneousMode.rst} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename docs/en_US/TrainingService/{HeterogeneousMode.md => HeterogeneousMode.rst} (100%) diff --git a/docs/en_US/TrainingService/HeterogeneousMode.md b/docs/en_US/TrainingService/HeterogeneousMode.rst similarity index 100% rename from docs/en_US/TrainingService/HeterogeneousMode.md rename to docs/en_US/TrainingService/HeterogeneousMode.rst From 6fcaa3df7eba87dba4f1263f457df263aec796cc Mon Sep 17 00:00:00 2001 From: Shinai Yang Date: Mon, 14 Dec 2020 22:32:48 +0800 Subject: [PATCH 23/24] fix ut --- .../TrainingService/HeterogeneousMode.rst | 67 +++++++++---------- .../training_service/reusable/environment.ts | 21 ------ .../environments/environmentServiceFactory.ts | 22 ++++++ .../reusable/trialDispatcher.ts | 3 +- 4 files changed, 56 insertions(+), 57 deletions(-) create mode 100644 ts/nni_manager/training_service/reusable/environments/environmentServiceFactory.ts diff --git a/docs/en_US/TrainingService/HeterogeneousMode.rst b/docs/en_US/TrainingService/HeterogeneousMode.rst index 948a239279..150e9b7bde 100644 --- a/docs/en_US/TrainingService/HeterogeneousMode.rst +++ b/docs/en_US/TrainingService/HeterogeneousMode.rst @@ -1,5 +1,6 @@ **Run an Experiment on Heterogeneous Mode** -=== +================================= + Run NNI on heterogeneous mode means that NNI will run trials jobs in multiple kinds of training platforms. For example, NNI could submit trial jobs to remote machine and AML simultaneously。 ## Setup environment @@ -10,40 +11,36 @@ NNI has supported [local](./LocalMode.md), [remote](./RemoteMachineMode.md), [pa ## Run an experiment Use `examples/trials/mnist-tfv1` as an example. The NNI config YAML file's content is like: -```yaml -authorName: default -experimentName: example_mnist -trialConcurrency: 2 -maxExecDuration: 1h -maxTrialNum: 10 -trainingServicePlatform: heterogeneous -searchSpacePath: search_space.json -#choice: true, false -useAnnotation: false -tuner: - #choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner, GPTuner - #SMAC (SMAC should be installed through nnictl) - builtinTunerName: TPE - classArgs: - #choice: maximize, minimize - optimize_mode: maximize -trial: - command: python3 mnist.py - codeDir: . - gpuNum: 1 -heterogeneousConfig: - trainingServicePlatforms: - - local - - remote -remoteConfig: - reuse: true -machineList: - - ip: 10.1.1.1 - username: bob - passwd: bob123 - #port can be skip if using default ssh port 22 - #port: 22 -``` +.. code-block:: yaml + authorName: default + experimentName: example_mnist + trialConcurrency: 2 + maxExecDuration: 1h + maxTrialNum: 10 + trainingServicePlatform: heterogeneous + searchSpacePath: search_space.json + #choice: true, false + useAnnotation: false + tuner: + builtinTunerName: TPE + classArgs: + #choice: maximize, minimize + optimize_mode: maximize + trial: + command: python3 mnist.py + codeDir: . + gpuNum: 1 + heterogeneousConfig: + trainingServicePlatforms: + - local + - remote + remoteConfig: + reuse: true + machineList: + - ip: 10.1.1.1 + username: bob + passwd: bob123 + Configurations for heterogeneous mode: heterogeneousConfig: diff --git a/ts/nni_manager/training_service/reusable/environment.ts b/ts/nni_manager/training_service/reusable/environment.ts index d1bc747df2..3f021676db 100644 --- a/ts/nni_manager/training_service/reusable/environment.ts +++ b/ts/nni_manager/training_service/reusable/environment.ts @@ -9,10 +9,6 @@ import { GPUInfo } from "../../training_service/common/gpuData"; import { CommandChannel } from "./commandChannel"; import { WebCommandChannel } from './channels/webCommandChannel'; import { EventEmitter } from "events"; -import { AMLEnvironmentService } from './environments/amlEnvironmentService'; -import { OpenPaiEnvironmentService } from './environments/openPaiEnvironmentService'; -import { LocalEnvironmentService } from './environments/localEnvironmentService'; -import { RemoteEnvironmentService } from './environments/remoteEnvironmentService'; export type EnvironmentStatus = 'UNKNOWN' | 'WAITING' | 'RUNNING' | 'SUCCEEDED' | 'FAILED' | 'USER_CANCELED'; @@ -174,23 +170,6 @@ export abstract class EnvironmentService { } } -export class EnvironmentServiceFactory { - public static createEnvironmentService(name: string): EnvironmentService { - switch(name) { - case 'local': - return new LocalEnvironmentService(); - case 'remote': - return new RemoteEnvironmentService(); - case 'aml': - return new AMLEnvironmentService(); - case 'pai': - return new OpenPaiEnvironmentService(); - default: - throw new Error(`${name} not supported!`); - } - } -} - export class NodeInformation { public id: string; public status: TrialJobStatus = "UNKNOWN"; diff --git a/ts/nni_manager/training_service/reusable/environments/environmentServiceFactory.ts b/ts/nni_manager/training_service/reusable/environments/environmentServiceFactory.ts new file mode 100644 index 0000000000..2a94e0c993 --- /dev/null +++ b/ts/nni_manager/training_service/reusable/environments/environmentServiceFactory.ts @@ -0,0 +1,22 @@ +import { AMLEnvironmentService } from './amlEnvironmentService'; +import { OpenPaiEnvironmentService } from './openPaiEnvironmentService'; +import { LocalEnvironmentService } from './localEnvironmentService'; +import { RemoteEnvironmentService } from './remoteEnvironmentService'; +import { EnvironmentService } from '../environment'; + +export class EnvironmentServiceFactory { + public static createEnvironmentService(name: string): EnvironmentService { + switch(name) { + case 'local': + return new LocalEnvironmentService(); + case 'remote': + return new RemoteEnvironmentService(); + case 'aml': + return new AMLEnvironmentService(); + case 'pai': + return new OpenPaiEnvironmentService(); + default: + throw new Error(`${name} not supported!`); + } + } +} diff --git a/ts/nni_manager/training_service/reusable/trialDispatcher.ts b/ts/nni_manager/training_service/reusable/trialDispatcher.ts index 941b080328..1316fad2a5 100644 --- a/ts/nni_manager/training_service/reusable/trialDispatcher.ts +++ b/ts/nni_manager/training_service/reusable/trialDispatcher.ts @@ -21,7 +21,8 @@ import { TrialConfig } from '../common/trialConfig'; import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey'; import { validateCodeDir } from '../common/util'; import { Command, CommandChannel } from './commandChannel'; -import { EnvironmentInformation, EnvironmentService, NodeInformation, RunnerSettings, TrialGpuSummary, EnvironmentServiceFactory } from './environment'; +import { EnvironmentInformation, EnvironmentService, NodeInformation, RunnerSettings, TrialGpuSummary } from './environment'; +import { EnvironmentServiceFactory } from './environments/environmentServiceFactory'; import { GpuScheduler } from './gpuScheduler'; import { MountedStorageService } from './storages/mountedStorageService'; import { StorageService } from './storageService'; From 396682760c7402edfc30c5d06a0c8d37d1549f73 Mon Sep 17 00:00:00 2001 From: Shinai Yang Date: Mon, 14 Dec 2020 23:00:30 +0800 Subject: [PATCH 24/24] fix Sphinx doc --- docs/en_US/TrainingService/HeterogeneousMode.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/en_US/TrainingService/HeterogeneousMode.rst b/docs/en_US/TrainingService/HeterogeneousMode.rst index 150e9b7bde..56b762f748 100644 --- a/docs/en_US/TrainingService/HeterogeneousMode.rst +++ b/docs/en_US/TrainingService/HeterogeneousMode.rst @@ -1,5 +1,5 @@ **Run an Experiment on Heterogeneous Mode** -================================= +=========================================== Run NNI on heterogeneous mode means that NNI will run trials jobs in multiple kinds of training platforms. For example, NNI could submit trial jobs to remote machine and AML simultaneously。 @@ -12,6 +12,7 @@ NNI has supported [local](./LocalMode.md), [remote](./RemoteMachineMode.md), [pa Use `examples/trials/mnist-tfv1` as an example. The NNI config YAML file's content is like: .. code-block:: yaml + authorName: default experimentName: example_mnist trialConcurrency: 2