diff --git a/src/nni_manager/training_service/dlts/dltsTrainingService.ts b/src/nni_manager/training_service/dlts/dltsTrainingService.ts
index 6414f5236e..e2e5868c46 100644
--- a/src/nni_manager/training_service/dlts/dltsTrainingService.ts
+++ b/src/nni_manager/training_service/dlts/dltsTrainingService.ts
@@ -38,7 +38,9 @@ class DLTSTrainingService implements TrainingService {
     private versionCheck: boolean = true;
     private logCollection: string = 'none';
     private isMultiPhase: boolean = false;
+    private dltsRestServerHost: string;
     private dltsRestServerPort?: number;
+    private jobMode: boolean;
 
     private readonly trialJobsMap: Map;
     private nniManagerIpConfig?: NNIManagerIpConfig;
@@ -51,7 +53,9 @@ class DLTSTrainingService implements TrainingService {
         this.trialJobsMap = new Map();
         this.jobQueue = [];
         this.experimentId = getExperimentId();
-        this.log.info('Construct DLTS training service.');
+        this.dltsRestServerHost = getIPV4Address();
+        this.jobMode = 'DLTS_JOB_ID' in process.env;
+        this.log.info(`Construct DLTS training service in ${this.jobMode ? 'job mode' : 'local mode'}.`);
     }
 
     public async run(): Promise {
@@ -60,12 +64,89 @@ class DLTSTrainingService implements TrainingService {
         await restServer.start();
         restServer.setEnableVersionCheck = this.versionCheck;
         this.log.info(`DLTS Training service rest server listening on: ${restServer.endPoint}`);
+        if (this.jobMode) {
+            await this.exposeRestServerPort(restServer.clusterRestServerPort);
+        } else {
+            this.dltsRestServerPort = restServer.clusterRestServerPort;
+        }
         await Promise.all([
             this.statusCheckingLoop(),
             this.submitJobLoop()]);
         this.log.info('DLTS training service exit.');
     }
 
+    /**
+     * Asks the DLTS dashboard "endpoints" API to expose the given port for the job
+     * named by the DLTS_JOB_ID environment variable, then polls once per second until
+     * that endpoint is reported as running and stores its node name and port.
+     *
+     * Polling also ends, without storing anything, once `this.stopping` is set.
+     *
+     * @param port pod port of the REST server that should be exposed
+     * @throws Error when the cluster config or DLTS_JOB_ID is missing; errors
+     *         reported by the HTTP client are propagated as rejections
+     */
+    private async exposeRestServerPort(port: number): Promise<void> {
+        if (this.dltsClusterConfig == null) {
+            throw Error('Cluster config is not set');
+        }
+        const jobId = process.env['DLTS_JOB_ID'];
+        if (jobId === undefined) {
+            // Fail fast instead of querying ".../jobs/undefined/endpoints".
+            throw Error('DLTS_JOB_ID is not set');
+        }
+        const { dashboard, cluster, email, password } = this.dltsClusterConfig;
+        // NOTE(review): assumes `dashboard` already ends with '/' - confirm against the config.
+        const uri = `${dashboard}api/clusters/${cluster}/jobs/${jobId}/endpoints`;
+        const qs = { email, password };
+
+        while (!this.stopping) {
+            this.log.debug('Checking endpoints');
+            const endpoints = await new Promise<unknown>((resolve, reject) => {
+                request.get(uri, { qs, json: true }, (error, _response, body) => {
+                    if (error) {
+                        reject(error);
+                    } else {
+                        resolve(body);
+                    }
+                });
+            });
+            this.log.debug('Endpoints: %o', endpoints);
+            // The HTTP status is not inspected: any non-array body just leads to another poll.
+            if (Array.isArray(endpoints)) {
+                const restServerEndpoint = endpoints.find(({ podPort }) => podPort === port);
+                if (restServerEndpoint == null) {
+                    // No endpoint for this port is listed yet: request one.
+                    this.log.debug('Exposing %d', port);
+                    await new Promise<void>((resolve, reject) => {
+                        request.post(uri, {
+                            qs,
+                            json: true,
+                            body: {
+                                endpoints: [{
+                                    name: "nni-rest-server",
+                                    podPort: port
+                                }]
+                            }
+                        }, (error) => {
+                            if (error) {
+                                reject(error);
+                            } else {
+                                resolve();
+                            }
+                        });
+                    });
+                } else if (restServerEndpoint['status'] === 'running') {
+                    // Store the node name and port reported for the running endpoint.
+                    this.dltsRestServerHost = restServerEndpoint['nodeName'];
+                    this.dltsRestServerPort = restServerEndpoint['port'];
+                    return;
+                }
+            }
+            await new Promise<void>(resolve => setTimeout(resolve, 1000));
+        }
+    }
+
     private async statusCheckingLoop(): Promise {
         while (!this.stopping) {
             const updateDLTSTrialJobs: Promise[] = [];
@@ -400,7 +481,7 @@ class DLTSTrainingService implements TrainingService {
             );
         }
         // tslint:disable-next-line: strict-boolean-expressions
-        const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : getIPV4Address();
+        const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : this.dltsRestServerHost;
         const version: string = this.versionCheck ? await getVersion() : '';
         const nniDLTSTrialCommand: string = String.Format(
             DLTS_TRIAL_COMMAND_FORMAT,