diff --git a/python/tvm/autotvm/task/space.py b/python/tvm/autotvm/task/space.py index cf9cd809aa8d..afbfb4c03988 100644 --- a/python/tvm/autotvm/task/space.py +++ b/python/tvm/autotvm/task/space.py @@ -19,7 +19,7 @@ """ Template configuration space. -Each template function can be parametrized by a ConfigSpace. +Each template function can be parameterized by a ConfigSpace. The space is declared when we invoke the template function with ConfigSpace. During evaluation, we pass in a ConfigEntity, which contains a specific entity in the space. This entity contains deterministic parameters. @@ -63,7 +63,7 @@ class TransformSpace(object): Each operator has some tunable parameters (e.g. the split factor). Then the tuning process is just to find good parameters of these op. - So the all the combinations of the parameters of these op forms our search space. + So all the combinations of the parameters of these op form our search space. Naming convention: We call the set of all possible values as XXXSpace. (XXX can be Split, Reorder, Config ...) @@ -797,7 +797,7 @@ def add_flop(self, flop): def raise_error(self, msg): """register error in config - Using this to actively detect error when scheudling. + Using this to actively detect error when scheduling. Otherwise these error will occur during runtime, which will cost more time. @@ -848,6 +848,8 @@ def get(self, index): index: int index in the space """ + if index < 0 or index >= len(self): + raise IndexError("Index out of range: size {}, got index {}".format(len(self), index)) entities = OrderedDict() t = index for name, space in self.space_map.items(): diff --git a/tests/python/unittest/test_autotvm_common.py b/tests/python/unittest/test_autotvm_common.py index 917036fc24a1..60f7d8bafb1b 100644 --- a/tests/python/unittest/test_autotvm_common.py +++ b/tests/python/unittest/test_autotvm_common.py @@ -101,6 +101,6 @@ def get_sample_records(n): inps, ress = [], [] for i in range(n): - inps.append(MeasureInput(target, tsk, tsk.config_space.get(i))) + inps.append(MeasureInput(target, tsk, tsk.config_space.get(i % len(tsk.config_space)))) ress.append(MeasureResult((i + 1,), 0, i, time.time())) return list(zip(inps, ress))