From 897abf1322d66304bc0a3e211a5ee57883f2fb5b Mon Sep 17 00:00:00 2001 From: Chengmin Chi Date: Mon, 4 Nov 2019 17:53:46 +0800 Subject: [PATCH] round-robin policy --- .../remote_machine/gpuScheduler.ts | 23 ++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/src/nni_manager/training_service/remote_machine/gpuScheduler.ts b/src/nni_manager/training_service/remote_machine/gpuScheduler.ts index 5e7f065971..4244eb8967 100644 --- a/src/nni_manager/training_service/remote_machine/gpuScheduler.ts +++ b/src/nni_manager/training_service/remote_machine/gpuScheduler.ts @@ -28,6 +28,8 @@ import { parseGpuIndices, RemoteMachineMeta, RemoteMachineScheduleResult, RemoteMachineTrialJobDetail, ScheduleResultType, SSHClientManager } from './remoteMachineData'; +type SCHEDULE_POLICY_NAME = 'random' | 'round-robin'; + /** * A simple GPU scheduler implementation */ @@ -35,13 +37,18 @@ export class GPUScheduler { private readonly machineSSHClientMap : Map; private readonly log: Logger = getLogger(); + private readonly policyName: SCHEDULE_POLICY_NAME = 'round-robin'; + private roundRobinIndex: number = 0; + private configuredRMs: RemoteMachineMeta[] = []; /** * Constructor * @param machineSSHClientMap map from remote machine to sshClient */ constructor(machineSSHClientMap : Map) { + assert(machineSSHClientMap.size > 0); this.machineSSHClientMap = machineSSHClientMap; + this.configuredRMs = Array.from(machineSSHClientMap.keys()); } /** @@ -189,7 +196,21 @@ export class GPUScheduler { private selectMachine(rmMetas: RemoteMachineMeta[]): RemoteMachineMeta { assert(rmMetas !== undefined && rmMetas.length > 0); - return randomSelect(rmMetas); + if (this.policyName === 'random') { + return randomSelect(rmMetas); + } else if (this.policyName === 'round-robin') { + return this.roundRobinSelect(rmMetas); + } else { + throw new Error(`Unsupported schedule policy: ${this.policyName}`); + } + } + + private roundRobinSelect(rmMetas: RemoteMachineMeta[]): RemoteMachineMeta { + while (!rmMetas.includes(this.configuredRMs[this.roundRobinIndex % this.configuredRMs.length])) { + this.roundRobinIndex++; + } + + return this.configuredRMs[this.roundRobinIndex++ % this.configuredRMs.length]; } private selectGPUsForTrial(gpuInfos: GPUInfo[], requiredGPUNum: number): GPUInfo[] {