Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Add native support for v2 config #3466

Merged
merged 28 commits into from
Apr 9, 2021
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions nni/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,26 +78,26 @@ def _run_advisor(exp_params):


def _create_tuner(exp_params):
if exp_params.get('tuner').get('builtinTunerName'):
if exp_params['tuner'].get('name'):
tuner = create_builtin_class_instance(
exp_params.get('tuner').get('builtinTunerName'),
exp_params.get('tuner').get('classArgs'),
exp_params['tuner']['name'],
exp_params['tuner'].get('classArgs'),
'tuners')
else:
tuner = create_customized_class_instance(exp_params.get('tuner'))
tuner = create_customized_class_instance(exp_params['tuner'])
if tuner is None:
raise AssertionError('Failed to create Tuner instance')
return tuner


def _create_assessor(exp_params):
if exp_params.get('assessor').get('builtinAssessorName'):
if exp_params['assessor'].get('name'):
assessor = create_builtin_class_instance(
exp_params.get('assessor').get('builtinAssessorName'),
exp_params.get('assessor').get('classArgs'),
exp_params['assessor']['name'],
exp_params['assessor'].get('classArgs'),
'assessors')
else:
assessor = create_customized_class_instance(exp_params.get('assessor'))
assessor = create_customized_class_instance(exp_params['assessor'])
if assessor is None:
raise AssertionError('Failed to create Assessor instance')
return assessor
Expand Down
1 change: 1 addition & 0 deletions nni/experiment/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@
from .kubeflow import *
from .frameworkcontroller import *
from .adl import *
from .shared_storage import *
20 changes: 17 additions & 3 deletions nni/experiment/config/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

from ruamel.yaml import YAML

from .base import ConfigBase, PathLike
from . import util

Expand All @@ -27,13 +29,11 @@ def validate(self):
super().validate()
_validate_algo(self)


@dataclass(init=False)
class AlgorithmConfig(_AlgorithmConfig):
name: str
class_args: Optional[Dict[str, Any]] = None


@dataclass(init=False)
class CustomAlgorithmConfig(_AlgorithmConfig):
class_name: str
Expand All @@ -44,6 +44,12 @@ class CustomAlgorithmConfig(_AlgorithmConfig):
class TrainingServiceConfig(ConfigBase):
platform: str

class SharedStorageConfig(ConfigBase):
storage_type: str
local_mount_point: str
remote_mount_point: str
local_mounted: str


@dataclass(init=False)
class ExperimentConfig(ConfigBase):
Expand All @@ -53,7 +59,7 @@ class ExperimentConfig(ConfigBase):
trial_command: str
trial_code_directory: PathLike = '.'
trial_concurrency: int
trial_gpu_number: Optional[int] = None
trial_gpu_number: Optional[int] = None # TODO: in openpai cannot be None
max_experiment_duration: Optional[str] = None
max_trial_number: Optional[int] = None
nni_manager_ip: Optional[str] = None
Expand All @@ -66,6 +72,8 @@ class ExperimentConfig(ConfigBase):
assessor: Optional[_AlgorithmConfig] = None
advisor: Optional[_AlgorithmConfig] = None
training_service: Union[TrainingServiceConfig, List[TrainingServiceConfig]]
shared_storage: Optional[SharedStorageConfig] = None
_deprecated: Optional[Dict[str, Any]] = None

def __init__(self, training_service_platform: Optional[Union[str, List[str]]] = None, **kwargs):
base_path = kwargs.pop('_base_path', None)
Expand Down Expand Up @@ -100,6 +108,12 @@ def validate(self, initialized_tuner: bool = False) -> None:
if self.training_service.use_active_gpu is None:
raise ValueError('Please set "use_active_gpu"')

def json(self) -> Dict[str, Any]:
obj = super().json()
if obj.get('searchSpaceFile'):
obj['searchSpace'] = YAML().load(open(obj.pop('searchSpaceFile')))
return obj

## End of public API ##

@property
Expand Down
Loading