diff --git a/src/nni_manager/training_service/remote_machine/gpuScheduler.ts b/src/nni_manager/training_service/remote_machine/gpuScheduler.ts index 5e7f065971..4244eb8967 100644 --- a/src/nni_manager/training_service/remote_machine/gpuScheduler.ts +++ b/src/nni_manager/training_service/remote_machine/gpuScheduler.ts @@ -28,6 +28,8 @@ import { parseGpuIndices, RemoteMachineMeta, RemoteMachineScheduleResult, RemoteMachineTrialJobDetail, ScheduleResultType, SSHClientManager } from './remoteMachineData'; +type SCHEDULE_POLICY_NAME = 'random' | 'round-robin'; + /** * A simple GPU scheduler implementation */ @@ -35,13 +37,18 @@ export class GPUScheduler { private readonly machineSSHClientMap : Map; private readonly log: Logger = getLogger(); + private readonly policyName: SCHEDULE_POLICY_NAME = 'round-robin'; + private roundRobinIndex: number = 0; + private configuredRMs: RemoteMachineMeta[] = []; /** * Constructor * @param machineSSHClientMap map from remote machine to sshClient */ constructor(machineSSHClientMap : Map) { + assert(machineSSHClientMap.size > 0); this.machineSSHClientMap = machineSSHClientMap; + this.configuredRMs = Array.from(machineSSHClientMap.keys()); } /** @@ -189,7 +196,21 @@ export class GPUScheduler { private selectMachine(rmMetas: RemoteMachineMeta[]): RemoteMachineMeta { assert(rmMetas !== undefined && rmMetas.length > 0); - return randomSelect(rmMetas); + if (this.policyName === 'random') { + return randomSelect(rmMetas); + } else if (this.policyName === 'round-robin') { + return this.roundRobinSelect(rmMetas); + } else { + throw new Error(`Unsupported schedule policy: ${this.policyName}`); + } + } + + private roundRobinSelect(rmMetas: RemoteMachineMeta[]): RemoteMachineMeta { + while (!rmMetas.includes(this.configuredRMs[this.roundRobinIndex % this.configuredRMs.length])) { + this.roundRobinIndex++; + } + + return this.configuredRMs[this.roundRobinIndex++ % this.configuredRMs.length]; } private selectGPUsForTrial(gpuInfos: GPUInfo[], requiredGPUNum: number): GPUInfo[] {