Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
Fix 3rd-party training service bug (#3726)
Browse files Browse the repository at this point in the history
Co-authored-by: liuzhe <zhe.liu@microsoft.com>
  • Loading branch information
liuzhe-lz and liuzhe authored Jun 7, 2021
1 parent d9dd29f commit 6b52fb1
Show file tree
Hide file tree
Showing 9 changed files with 117 additions and 134 deletions.
4 changes: 2 additions & 2 deletions nni/tools/nnictl/ts_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def register(args):

try:
service_config = {
'node_module_path': info.node_module_path,
'node_class_name': info.node_class_name,
'nodeModulePath': str(info.node_module_path),
'nodeClassName': info.node_class_name,
}
json.dumps(service_config)
except Exception:
Expand Down
168 changes: 72 additions & 96 deletions ts/nni_manager/common/experimentStartupInfo.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,36 +6,43 @@
import * as assert from 'assert';
import * as os from 'os';
import * as path from 'path';
import * as component from '../common/component';

@component.Singleton
class ExperimentStartupInfo {
private readonly API_ROOT_URL: string = '/api/v1/nni';

private experimentId: string = '';
private newExperiment: boolean = true;
private basePort: number = -1;
private initialized: boolean = false;
private logDir: string = '';
private logLevel: string = '';
private readonly: boolean = false;
private dispatcherPipe: string | null = null;
private platform: string = '';
private urlprefix: string = '';

public setStartupInfo(newExperiment: boolean, experimentId: string, basePort: number, platform: string, logDir?: string, logLevel?: string, readonly?: boolean, dispatcherPipe?: string, urlprefix?: string): void {
assert(!this.initialized);
assert(experimentId.trim().length > 0);

const API_ROOT_URL: string = '/api/v1/nni';

let singleton: ExperimentStartupInfo | null = null;

export class ExperimentStartupInfo {

public experimentId: string = '';
public newExperiment: boolean = true;
public basePort: number = -1;
public initialized: boolean = false;
public logDir: string = '';
public logLevel: string = '';
public readonly: boolean = false;
public dispatcherPipe: string | null = null;
public platform: string = '';
public urlprefix: string = '';

constructor(
newExperiment: boolean,
experimentId: string,
basePort: number,
platform: string,
logDir?: string,
logLevel?: string,
readonly?: boolean,
dispatcherPipe?: string,
urlprefix?: string) {
this.newExperiment = newExperiment;
this.experimentId = experimentId;
this.basePort = basePort;
this.initialized = true;
this.platform = platform;

if (logDir !== undefined && logDir.length > 0) {
this.logDir = path.join(path.normalize(logDir), this.getExperimentId());
this.logDir = path.join(path.normalize(logDir), experimentId);
} else {
this.logDir = path.join(os.homedir(), 'nni-experiments', this.getExperimentId());
this.logDir = path.join(os.homedir(), 'nni-experiments', experimentId);
}

if (logLevel !== undefined && logLevel.length > 1) {
Expand All @@ -55,98 +62,67 @@ class ExperimentStartupInfo {
}
}

public getExperimentId(): string {
assert(this.initialized);

return this.experimentId;
}

public getBasePort(): number {
assert(this.initialized);

return this.basePort;
}

public isNewExperiment(): boolean {
assert(this.initialized);

return this.newExperiment;
}

public getPlatform(): string {
assert(this.initialized);

return this.platform;
/**
 * Root URL of the REST API.
 *
 * Equals `API_ROOT_URL` when `urlprefix` is the empty string; otherwise the
 * prefix is inserted in front of it, preceded by a slash.
 */
public get apiRootUrl(): string {
    if (this.urlprefix === '') {
        return API_ROOT_URL;
    }
    return `/${this.urlprefix}${API_ROOT_URL}`;
}

public getLogDir(): string {
assert(this.initialized);

return this.logDir;
}

public getLogLevel(): string {
assert(this.initialized);

return this.logLevel;
}

public isReadonly(): boolean {
assert(this.initialized);

return this.readonly;
}

public getDispatcherPipe(): string | null {
assert(this.initialized);
return this.dispatcherPipe;
}

public getAPIRootUrl(): string {
assert(this.initialized);
return this.urlprefix==''?this.API_ROOT_URL:`/${this.urlprefix}${this.API_ROOT_URL}`;
/**
 * Returns the module-level `ExperimentStartupInfo` singleton.
 *
 * Fails an assertion when the singleton is still `null`.
 */
public static getInstance(): ExperimentStartupInfo {
    const instance = singleton;
    assert(instance !== null);
    return instance!;
}
}

function getExperimentId(): string {
return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).getExperimentId();
/**
 * Function-style shortcut for `ExperimentStartupInfo.getInstance()`.
 *
 * @returns whatever `ExperimentStartupInfo.getInstance()` returns.
 */
export function getExperimentStartupInfo(): ExperimentStartupInfo {
    const info = ExperimentStartupInfo.getInstance();
    return info;
}

function getBasePort(): number {
return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).getBasePort();
/**
 * Builds a new `ExperimentStartupInfo` from the given arguments and stores it
 * in the module-level `singleton`, replacing whatever was stored before.
 *
 * All arguments are forwarded unchanged, in order, to the
 * `ExperimentStartupInfo` constructor.
 */
export function setExperimentStartupInfo(
    newExperiment: boolean,
    experimentId: string,
    basePort: number,
    platform: string,
    logDir?: string,
    logLevel?: string,
    readonly?: boolean,
    dispatcherPipe?: string,
    urlprefix?: string): void {
    const info = new ExperimentStartupInfo(
        newExperiment, experimentId, basePort, platform,
        logDir, logLevel, readonly, dispatcherPipe, urlprefix
    );
    // Assign only after construction succeeded, as a single-expression assignment would.
    singleton = info;
}

function isNewExperiment(): boolean {
return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).isNewExperiment();
/**
 * @returns the `experimentId` field of the current startup info.
 */
export function getExperimentId(): string {
    const { experimentId } = getExperimentStartupInfo();
    return experimentId;
}

function getPlatform(): string {
return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).getPlatform();
/**
 * @returns the `basePort` field of the current startup info.
 */
export function getBasePort(): number {
    const { basePort } = getExperimentStartupInfo();
    return basePort;
}

function getExperimentStartupInfo(): ExperimentStartupInfo {
return component.get<ExperimentStartupInfo>(ExperimentStartupInfo);
/**
 * @returns the `newExperiment` flag of the current startup info.
 */
export function isNewExperiment(): boolean {
    const { newExperiment } = getExperimentStartupInfo();
    return newExperiment;
}

function setExperimentStartupInfo(
newExperiment: boolean, experimentId: string, basePort: number, platform: string, logDir?: string, logLevel?: string, readonly?: boolean, dispatcherPipe?: string, urlprefix?: string): void {
component.get<ExperimentStartupInfo>(ExperimentStartupInfo)
.setStartupInfo(newExperiment, experimentId, basePort, platform, logDir, logLevel, readonly, dispatcherPipe, urlprefix);
/**
 * @returns the `platform` field of the current startup info.
 */
export function getPlatform(): string {
    const { platform } = getExperimentStartupInfo();
    return platform;
}

function isReadonly(): boolean {
return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).isReadonly();
/**
 * @returns the `readonly` flag of the current startup info.
 */
export function isReadonly(): boolean {
    const info = getExperimentStartupInfo();
    return info.readonly;
}

function getDispatcherPipe(): string | null {
return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).getDispatcherPipe();
/**
 * @returns the `dispatcherPipe` field of the current startup info, which is
 *          typed as `string | null`.
 */
export function getDispatcherPipe(): string | null {
    const { dispatcherPipe } = getExperimentStartupInfo();
    return dispatcherPipe;
}

function getAPIRootUrl(): string {
return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).getAPIRootUrl();
/**
 * @returns the `apiRootUrl` property of the current startup info.
 */
export function getAPIRootUrl(): string {
    const info = getExperimentStartupInfo();
    return info.apiRootUrl;
}

export {
ExperimentStartupInfo, getBasePort, getExperimentId, isNewExperiment, getPlatform, getExperimentStartupInfo,
setExperimentStartupInfo, isReadonly, getDispatcherPipe, getAPIRootUrl
};
25 changes: 15 additions & 10 deletions ts/nni_manager/common/pythonScript.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,32 @@
import { spawn } from 'child_process';
import { Logger, getLogger } from './log';

const python = process.platform === 'win32' ? 'python.exe' : 'python3';
const logger: Logger = getLogger('pythonScript');

export async function runPythonScript(script: string, logger?: Logger): Promise<string> {
const python: string = process.platform === 'win32' ? 'python.exe' : 'python3';

export async function runPythonScript(script: string, logTag?: string): Promise<string> {
const proc = spawn(python, [ '-c', script ]);

let stdout: string = '';
let stderr: string = '';
proc.stdout.on('data', (data: string) => { stdout += data; });
proc.stderr.on('data', (data: string) => { stderr += data; });

const procPromise = new Promise<void>((resolve, reject) => {
proc.on('error', (err: Error) => { reject(err); });
proc.on('exit', () => { resolve(); });
});
await procPromise;

const stdout = proc.stdout.read().toString();
const stderr = proc.stderr.read().toString();

if (stderr) {
if (logger === undefined) {
logger = getLogger('pythonScript');
if (logTag) {
logger.warning(`Python script [${logTag}] has stderr:`, stderr);
} else {
logger.warning('Python script has stderr.');
logger.warning(' script:', script);
logger.warning(' stderr:', stderr);
}
logger.warning('python script has stderr.');
logger.warning('script:', script);
logger.warning('stderr:', stderr);
}

return stdout;
Expand Down
10 changes: 5 additions & 5 deletions ts/nni_manager/common/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,21 @@ import * as util from 'util';
import * as glob from 'glob';

import { Database, DataStore } from './datastore';
import { ExperimentStartupInfo, getExperimentStartupInfo, setExperimentStartupInfo } from './experimentStartupInfo';
import { getExperimentStartupInfo, setExperimentStartupInfo } from './experimentStartupInfo';
import { ExperimentConfig, Manager } from './manager';
import { ExperimentManager } from './experimentManager';
import { HyperParameters, TrainingService, TrialJobStatus } from './trainingService';

function getExperimentRootDir(): string {
return getExperimentStartupInfo().getLogDir();
return getExperimentStartupInfo().logDir;
}

/**
 * @returns the `log` entry directly under the experiment root directory,
 *          as joined by `path.join`.
 */
function getLogDir(): string {
    const rootDir = getExperimentRootDir();
    return path.join(rootDir, 'log');
}

function getLogLevel(): string {
return getExperimentStartupInfo().getLogLevel();
return getExperimentStartupInfo().logLevel;
}

function getDefaultDatabaseDir(): string {
Expand Down Expand Up @@ -184,7 +184,6 @@ function generateParamFileName(hyperParameters: HyperParameters): string {
* Must be paired with `cleanupUnitTest()`.
*/
function prepareUnitTest(): void {
Container.snapshot(ExperimentStartupInfo);
Container.snapshot(Database);
Container.snapshot(DataStore);
Container.snapshot(TrainingService);
Expand Down Expand Up @@ -213,8 +212,9 @@ function cleanupUnitTest(): void {
Container.restore(TrainingService);
Container.restore(DataStore);
Container.restore(Database);
Container.restore(ExperimentStartupInfo);
Container.restore(ExperimentManager);
const logLevel: string = parseArg(['--log_level', '-ll']);
setExperimentStartupInfo(true, 'unittest', 8080, 'unittest', undefined, logLevel);
}

let cachedipv4Address: string = '';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import * as path from 'path';
import * as component from '../../../common/component';
import { getLogger, Logger } from '../../../common/log';
import { ExperimentConfig, AmlConfig, flattenConfig } from '../../../common/experimentConfig';
import { ExperimentStartupInfo } from '../../../common/experimentStartupInfo';
import { validateCodeDir } from '../../common/util';
import { AMLClient } from '../aml/amlClient';
import { AMLEnvironmentInformation } from '../aml/amlConfig';
Expand All @@ -29,10 +30,10 @@ export class AMLEnvironmentService extends EnvironmentService {
private experimentRootDir: string;
private config: FlattenAmlConfig;

constructor(experimentRootDir: string, experimentId: string, config: ExperimentConfig) {
constructor(config: ExperimentConfig, info: ExperimentStartupInfo) {
super();
this.experimentId = experimentId;
this.experimentRootDir = experimentRootDir;
this.experimentId = info.experimentId;
this.experimentRootDir = info.logDir;
this.config = flattenConfig(config, 'aml');
validateCodeDir(this.config.trialCodeDirectory);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,22 @@ import { LocalEnvironmentService } from './localEnvironmentService';
import { RemoteEnvironmentService } from './remoteEnvironmentService';
import { EnvironmentService } from '../environment';
import { ExperimentConfig } from '../../../common/experimentConfig';
import { getExperimentId } from '../../../common/experimentStartupInfo';
import { ExperimentStartupInfo } from '../../../common/experimentStartupInfo';
import { getCustomEnvironmentServiceConfig } from '../../../common/nniConfig';
import { getExperimentRootDir, importModule } from '../../../common/utils';

import { importModule } from '../../../common/utils';

export async function createEnvironmentService(name: string, config: ExperimentConfig): Promise<EnvironmentService> {
const expId = getExperimentId();
const rootDir = getExperimentRootDir();
const info = ExperimentStartupInfo.getInstance();

switch(name) {
case 'local':
return new LocalEnvironmentService(rootDir, expId, config);
return new LocalEnvironmentService(config, info);
case 'remote':
return new RemoteEnvironmentService(rootDir, expId, config);
return new RemoteEnvironmentService(config, info);
case 'aml':
return new AMLEnvironmentService(rootDir, expId, config);
return new AMLEnvironmentService(config, info);
case 'openpai':
return new OpenPaiEnvironmentService(rootDir, expId, config);
return new OpenPaiEnvironmentService(config, info);
}

const esConfig = await getCustomEnvironmentServiceConfig(name);
Expand All @@ -30,5 +28,5 @@ export async function createEnvironmentService(name: string, config: ExperimentC
}
const esModule = importModule(esConfig.nodeModulePath);
const esClass = esModule[esConfig.nodeClassName] as any;
return new esClass(rootDir, expId, config);
return new esClass(config, info);
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import * as tkill from 'tree-kill';
import * as component from '../../../common/component';
import { getLogger, Logger } from '../../../common/log';
import { ExperimentConfig } from '../../../common/experimentConfig';
import { ExperimentStartupInfo } from '../../../common/experimentStartupInfo';
import { EnvironmentInformation, EnvironmentService } from '../environment';
import { isAlive, getNewLine } from '../../../common/utils';
import { execMkdir, runScript, getScriptName, execCopydir } from '../../common/util';
Expand All @@ -21,10 +22,10 @@ export class LocalEnvironmentService extends EnvironmentService {
private experimentRootDir: string;
private experimentId: string;

constructor(experimentRootDir: string, experimentId: string, _config: ExperimentConfig) {
constructor(_config: ExperimentConfig, info: ExperimentStartupInfo) {
super();
this.experimentId = experimentId;
this.experimentRootDir = experimentRootDir;
this.experimentId = info.experimentId;
this.experimentRootDir = info.logDir;
}

public get environmentMaintenceLoopInterval(): number {
Expand Down
Loading

0 comments on commit 6b52fb1

Please sign in to comment.