-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Improve docstring of base tuner and assessor #1669
Changes from 4 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,45 +18,113 @@ | |
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. | ||
# ================================================================================================== | ||
|
||
""" | ||
Assessor analyzes trial's intermediate results (e.g., periodically evaluated accuracy on test dataset) | ||
to tell whether this trial can be early stopped or not. | ||
|
||
See :class:`Assessor`'s specification and ``docs/en_US/assessors.rst`` for details. | ||
""" | ||
|
||
import logging | ||
from enum import Enum | ||
import logging | ||
|
||
from .recoverable import Recoverable | ||
|
||
__all__ = ['AssessResult', 'Assessor'] | ||
|
||
_logger = logging.getLogger(__name__) | ||
|
||
|
||
class AssessResult(Enum): | ||
""" | ||
Enum class for :meth:`Assessor.assess_trial` return value. | ||
""" | ||
|
||
Good = True | ||
"""The trial works well.""" | ||
|
||
Bad = False | ||
"""The trial works poorly and should be early stopped.""" | ||
|
||
|
||
class Assessor(Recoverable): | ||
""" | ||
Assessor analyzes trial's intermediate results (e.g., periodically evaluated accuracy on test dataset) | ||
to tell whether this trial can be early stopped or not. | ||
|
||
This is the abstract base class for all assessors. | ||
Early stopping algorithms should derive this class and override :meth:`assess_trial` method, | ||
which receives intermediate results from trials and give an assessing result. | ||
|
||
If :meth:`assess_trial` returns :obj:`AssessResult.Bad` for a trial, | ||
it hints NNI framework that the trial is likely to result in a poor final accuracy, | ||
and therefore should be killed to save resource. | ||
|
||
If an assessor wants to be notified when a trial ends, it can also override :meth:`trial_end`. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. wants to be notified |
||
|
||
To write a new assessor, you can reference :class:`~nni.medianstop_assessor.MedianstopAssessor`'s code as an example. | ||
|
||
See Also | ||
-------- | ||
Builtin assessors: | ||
:class:`~nni.medianstop_assessor.MedianstopAssessor` | ||
:class:`~nni.curvefitting_assessor.CurvefittingAssessor` | ||
""" | ||
|
||
def assess_trial(self, trial_job_id, trial_history): | ||
"""Determines whether a trial should be killed. Must override. | ||
trial_job_id: identifier of the trial (str). | ||
trial_history: a list of intermediate result objects. | ||
Returns AssessResult.Good or AssessResult.Bad. | ||
""" | ||
Abstract method for determining whether a trial should be killed. Must override. | ||
|
||
The NNI framework has little guarantee on ``trial_history``. | ||
This method is not guaranteed to be invoked for each time ``trial_history`` gets updated. | ||
It is also possible that a trial's history keeps updating after receiving a bad result. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. updating |
||
And if the trial failed and retried, ``trial_history`` may be inconsistent with its previous value. | ||
|
||
The only guarantee is that ``trial_history`` is always growing. | ||
It will not be empty and will always be longer than previous value. | ||
|
||
This is an example of how :meth:`assess_trial` gets invoked sequentially: | ||
|
||
:: | ||
|
||
trial_job_id | trial_history | return value | ||
------------ | --------------- | ------------ | ||
Trial_A | [1.0, 2.0] | Good | ||
Trial_B | [1.5, 1.3] | Bad | ||
Trial_B | [1.5, 1.3, 1.9] | Good | ||
Trial_A | [0.9, 1.8, 2.3] | Good | ||
|
||
Parameters | ||
---------- | ||
trial_job_id: str | ||
Unique identifier of the trial. | ||
trial_history: list | ||
Intermediate results of this trial. The element type is decided by trial code. | ||
|
||
Returns | ||
------- | ||
AssessResult | ||
:obj:`AssessResult.Good` or :obj:`AssessResult.Bad`. | ||
""" | ||
raise NotImplementedError('Assessor: assess_trial not implemented') | ||
|
||
def trial_end(self, trial_job_id, success): | ||
"""Invoked when a trial is completed or terminated. Do nothing by default. | ||
trial_job_id: identifier of the trial (str). | ||
success: True if the trial successfully completed; False if failed or terminated. | ||
""" | ||
Abstract method invoked when a trial is completed or terminated. Do nothing by default. | ||
|
||
def load_checkpoint(self): | ||
"""Load the checkpoint of assessr. | ||
path: checkpoint directory for assessor | ||
Parameters | ||
---------- | ||
trial_job_id: str | ||
Unique identifier of the trial. | ||
success: bool | ||
True if the trial successfully completed; False if failed or terminated. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe I don't know about Python. Shouldn't there be a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Pylint with defaut configuration will complain if there is a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If Pylint didn't complain, I'm fine with it. |
||
""" | ||
|
||
def load_checkpoint(self): | ||
checkpoint_path = self.get_checkpoint_path() | ||
_logger.info('Load checkpoint ignored by assessor, checkpoint path: %s', checkpoint_path) | ||
|
||
def save_checkpoint(self): | ||
"""Save the checkpoint of assessor. | ||
path: checkpoint directory for assessor | ||
""" | ||
checkpoint_path = self.get_checkpoint_path() | ||
_logger.info('Save checkpoint ignored by assessor, checkpoint path: %s', checkpoint_path) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .curvefitting_assessor import CurvefittingAssessor |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .medianstop_assessor import MedianstopAssessor |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,31 +17,119 @@ | |
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT | ||
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. | ||
# ================================================================================================== | ||
|
||
""" | ||
Tuner is an AutoML algorithm, which generates a new configuration for the next try. | ||
A new trial will run with this configuration. | ||
|
||
See :class:`Tuner`'s specification and ``docs/en_US/tuners.rst`` for details. | ||
""" | ||
|
||
import logging | ||
|
||
import nni | ||
|
||
from .recoverable import Recoverable | ||
|
||
__all__ = ['Tuner'] | ||
|
||
_logger = logging.getLogger(__name__) | ||
|
||
|
||
class Tuner(Recoverable): | ||
""" | ||
Tuner is an AutoML algorithm, which generates a new configuration for the next try. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. generates configurations to run with. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This sentence is copied from overview. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure. "For the next try" looks really Chi-english to me. |
||
A new trial will run with this configuration. | ||
|
||
This is the abstract base class for all tuners. | ||
Tuning algorithms should derive this class and override :meth:`update_search_space`, :meth:`receive_trial_result`, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should inherit / be derived from this class There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @ultmaster merged here, please create another pr for your comments |
||
as well as :meth:`generate_parameters` or :meth:`generate_multiple_parameters`. | ||
|
||
After initializing, NNI will first call :meth:`update_search_space` to tell tuner the feasible region, | ||
and then call :meth:`generate_parameters` one or more times to request for hyper-parameter configurations. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Better mention There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you mean There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I mean |
||
|
||
The framework will train several models with the given configurations. | ||
When one of them is finished, the final accuracy will be reported to :meth:`receive_trial_result`. | ||
And then another configuration will be requested and trained, until the whole experiment finishes. | ||
|
||
If a tuner wants to know when a trial ends, it can also override :meth:`trial_end`. | ||
|
||
The type/format of search space and hyper-parameters are not limited, | ||
as long as they are JSON-serializable and in sync with trial code. | ||
For HPO tuners, however, there is a widely shared common interface, | ||
which supports ``choice``, ``randint``, ``uniform``, and so on. | ||
See ``docs/en_US/Tutorial/SearchSpaceSpec.md`` for details of this interface. | ||
|
||
[WIP] For advanced tuners which take advantage of trials' intermediate results, | ||
an ``Advisor`` interface is under development. | ||
|
||
See Also | ||
-------- | ||
Builtin tuners: | ||
:class:`~nni.hyperopt_tuner.hyperopt_tuner.HyperoptTuner` | ||
:class:`~nni.evolution_tuner.evolution_tuner.EvolutionTuner` | ||
:class:`~nni.smac_tuner.smac_tuner.SMACTuner` | ||
:class:`~nni.gridsearch_tuner.gridsearch_tuner.GridSearchTuner` | ||
:class:`~nni.networkmorphism_tuner.networkmorphism_tuner.NetworkMorphismTuner` | ||
:class:`~nni.metis_tuner.metis_tuner.MetisTuner` | ||
""" | ||
|
||
def generate_parameters(self, parameter_id, **kwargs): | ||
"""Returns a set of trial (hyper-)parameters, as a serializable object. | ||
User code must override either this function or 'generate_multiple_parameters()'. | ||
""" | ||
Abstract method which provides one set of hyper-parameters. | ||
|
||
This method will get called when the framework is about to launch a new trial, | ||
if user does not override :meth:`generate_multiple_parameters`. | ||
|
||
The return value will be received by trials via :func:`nni.get_next_parameter`. | ||
It should fit in the search space, though the framework will not verify this. | ||
|
||
User code must override either this method or :meth:`generate_multiple_parameters`. | ||
|
||
Parameters | ||
---------- | ||
parameter_id: int | ||
Unique identifier for requested hyper-parameters. This will later be used in :meth:`receive_trial_result`. | ||
**kwargs: | ||
Unstable parameters which should be ignored by normal users. | ||
|
||
Returns | ||
------- | ||
any | ||
The hyper-parameters, a dict in most cases, but could be any JSON-serializable type when needed. | ||
|
||
Raises | ||
------ | ||
nni.NoMoreTrialError | ||
If the search space is fully explored, tuner can raise this exception. | ||
[FIXME] Currently some tuners also raise this exception when they are waiting more trial results. | ||
""" | ||
raise NotImplementedError('Tuner: generate_parameters not implemented') | ||
|
||
def generate_multiple_parameters(self, parameter_id_list, **kwargs): | ||
"""Returns multiple sets of trial (hyper-)parameters, as iterable of serializable objects. | ||
Call 'generate_parameters()' by 'count' times by default. | ||
User code must override either this function or 'generate_parameters()'. | ||
If there's no more trial, user should raise nni.NoMoreTrialError exception in generate_parameters(). | ||
If so, this function will only return sets of trial (hyper-)parameters that have already been collected. | ||
""" | ||
Callback method which provides multiple sets of hyper-parameters. | ||
|
||
This method will get called when the framework is about to launch one or more new trials. | ||
|
||
If user does not override this method, it will invoke :meth:`generate_parameters` on each parameter ID. | ||
|
||
See :meth:`generate_parameters` for details. | ||
|
||
User code must override either this method or :meth:`generate_parameters`. | ||
|
||
Parameters | ||
---------- | ||
parameter_id_list: list of int | ||
Unique identifiers for each set of requested hyper-parameters. | ||
These will later be used in :meth:`receive_trial_result`. | ||
**kwargs: | ||
Unstable parameters which should be ignored by normal users. | ||
|
||
Returns | ||
------- | ||
list | ||
List of hyper-parameters. An empty list indicates there are no more trials. | ||
""" | ||
result = [] | ||
for parameter_id in parameter_id_list: | ||
|
@@ -54,56 +142,74 @@ def generate_multiple_parameters(self, parameter_id_list, **kwargs): | |
return result | ||
|
||
def receive_trial_result(self, parameter_id, parameters, value, **kwargs): | ||
"""Invoked when a trial reports its final result. Must override. | ||
By default this only reports results of algorithm-generated hyper-parameters. | ||
Use `accept_customized_trials()` to receive results from user-added parameters. | ||
""" | ||
Abstract method invoked when a trial reports its final result. Must override. | ||
|
||
This method only reports results of algorithm-generated hyper-parameters. | ||
Currently customized trials added from web UI will not report result to this method. | ||
|
||
Parameters | ||
---------- | ||
parameter_id: int | ||
parameters: object created by 'generate_parameters()' | ||
value: object reported by trial | ||
customized: bool, true if the trial is created from web UI, false if generated by algorithm | ||
trial_job_id: str, only available in multiphase mode. | ||
Unique identifier of used hyper-parameters, same with :meth:`generate_parameters`. | ||
parameters | ||
Hyper-parameters generated by :meth:`generate_parameters`. | ||
value | ||
Result from trial (the return value of :func:`nni.report_final_result`). | ||
**kwargs: | ||
Unstable parameters which should be ignored by normal users. | ||
""" | ||
raise NotImplementedError('Tuner: receive_trial_result not implemented') | ||
|
||
def accept_customized_trials(self, accept=True): | ||
"""Enable or disable receiving results of user-added hyper-parameters. | ||
By default `receive_trial_result()` will only receive results of algorithm-generated hyper-parameters. | ||
If tuners want to receive those of customized parameters as well, they can call this function in `__init__()`. | ||
""" | ||
def _accept_customized_trials(self, accept=True): | ||
# FIXME: because Tuner is designed as interface, this API should not be here | ||
|
||
# Enable or disable receiving results of user-added hyper-parameters. | ||
# By default `receive_trial_result()` will only receive results of algorithm-generated hyper-parameters. | ||
# If tuners want to receive those of customized parameters as well, they can call this function in `__init__()`. | ||
|
||
# pylint: disable=attribute-defined-outside-init | ||
# FIXME: because tuner is designed as interface, this API should not be here | ||
self._accept_customized = accept | ||
|
||
def trial_end(self, parameter_id, success, **kwargs): | ||
"""Invoked when a trial is completed or terminated. Do nothing by default. | ||
parameter_id: int | ||
success: True if the trial successfully completed; False if failed or terminated | ||
""" | ||
Abstract method invoked when a trial is completed or terminated. Do nothing by default. | ||
|
||
Parameters | ||
---------- | ||
parameter_id: int | ||
Unique identifier of used hyper-parameters, same with :meth:`generate_parameters`. | ||
success: bool | ||
True if the trial successfully completed; False if failed or terminated. | ||
""" | ||
|
||
def update_search_space(self, search_space): | ||
"""Update the search space of tuner. Must override. | ||
search_space: JSON object | ||
""" | ||
Abstract method for updating the search space. Must override. | ||
|
||
Tuners are advised to support updating search space at run-time. | ||
If a tuner can only set search space once before generating first hyper-parameters, | ||
it should explicitly document this behaviour. | ||
|
||
Parameters | ||
---------- | ||
search_space | ||
JSON object defined by experiment owner. | ||
""" | ||
raise NotImplementedError('Tuner: update_search_space not implemented') | ||
|
||
def load_checkpoint(self): | ||
"""Load the checkpoint of tuner. | ||
path: checkpoint directory for tuner | ||
""" | ||
checkpoint_path = self.get_checkpoint_path() | ||
_logger.info('Load checkpoint ignored by tuner, checkpoint path: %s', checkpoint_path) | ||
|
||
def save_checkpoint(self): | ||
"""Save the checkpoint of tuner. | ||
path: checkpoint directory for tuner | ||
""" | ||
checkpoint_path = self.get_checkpoint_path() | ||
_logger.info('Save checkpoint ignored by tuner, checkpoint path: %s', checkpoint_path) | ||
|
||
def import_data(self, data): | ||
"""Import additional data for tuning | ||
data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value' | ||
""" | ||
# Import additional data for tuning | ||
# data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value' | ||
pass | ||
|
||
def _on_exit(self): | ||
pass | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here.