Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoTVM] Add index boundary check in ConfigSpace.get() #7234

Merged
merged 2 commits into from
Jan 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions python/tvm/autotvm/task/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"""
Template configuration space.

Each template function can be parametrized by a ConfigSpace.
Each template function can be parameterized by a ConfigSpace.
The space is declared when we invoke the template function with ConfigSpace.
During evaluation, we pass in a ConfigEntity, which contains a specific
entity in the space. This entity contains deterministic parameters.
Expand Down Expand Up @@ -63,7 +63,7 @@ class TransformSpace(object):
Each operator has some tunable parameters (e.g. the split factor).
Then the tuning process is just to find good parameters of these op.

So the all the combinations of the parameters of these op forms our search space.
So all the combinations of the parameters of these op form our search space.

Naming convention:
We call the set of all possible values as XXXSpace. (XXX can be Split, Reorder, Config ...)
Expand Down Expand Up @@ -797,7 +797,7 @@ def add_flop(self, flop):

def raise_error(self, msg):
"""register error in config
Using this to actively detect error when scheudling.
Using this to actively detect error when scheduling.
Otherwise these error will occur during runtime, which
will cost more time.

Expand Down Expand Up @@ -850,6 +850,8 @@ def get(self, index):
index: int
index in the space
"""
if index < 0 or index >= len(self):
raise IndexError("Index out of range: size {}, got index {}".format(len(self), index))
entities = OrderedDict()
t = index
for name, space in self.space_map.items():
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_autotvm_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,6 @@ def get_sample_records(n):

inps, ress = [], []
for i in range(n):
inps.append(MeasureInput(target, tsk, tsk.config_space.get(i)))
inps.append(MeasureInput(target, tsk, tsk.config_space.get(i % len(tsk.config_space))))
ress.append(MeasureResult((i + 1,), 0, i, time.time()))
return list(zip(inps, ress))