Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small improvements to searchspaces and simulation mode #251

Merged
merged 7 commits into from
May 23, 2024
2 changes: 2 additions & 0 deletions kernel_tuner/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,8 @@ def tune_kernel(

# create search space
searchspace = Searchspace(tune_params, restrictions, runner.dev.max_threads)
restrictions = searchspace._modified_restrictions
tuning_options.restrictions = restrictions
if verbose:
print(f"Searchspace has {searchspace.size} configurations after restrictions.")

Expand Down
6 changes: 4 additions & 2 deletions kernel_tuner/runners/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,9 @@ def run(self, parameter_space, tuning_options):
continue

# if the element is not in the cache, raise an error
logging.debug(f"kernel configuration {element} not in cache")
raise ValueError(f"Kernel configuration {element} not in cache - in simulation mode, all configurations must be present in the cache")
check = util.check_restrictions(tuning_options.restrictions, dict(zip(tuning_options['tune_params'].keys(), element)), True)
err_string = f"kernel configuration {element} not in cache, does {'' if check else 'not '}pass extra restriction check ({check})"
logging.debug(err_string)
raise ValueError(f"{err_string} - in simulation mode, all configurations must be present in the cache")

return results
15 changes: 15 additions & 0 deletions kernel_tuner/searchspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def __init__(
restrictions = restrictions if restrictions is not None else []
self.tune_params = tune_params
self.restrictions = restrictions
# the searchspace can add commonly used constraints (e.g. maxprod(blocks) <= maxthreads)
self._modified_restrictions = restrictions
self.param_names = list(self.tune_params.keys())
self.params_values = tuple(tuple(param_vals) for param_vals in self.tune_params.values())
self.params_values_indices = None
Expand Down Expand Up @@ -166,6 +168,10 @@ def __build_searchspace_bruteforce(self, block_size_names: list, max_threads: in
block_size_restriction_unspaced = f"{'*'.join(used_block_size_names)} <= {max_threads}"
if block_size_restriction_spaced not in restrictions and block_size_restriction_unspaced not in restrictions:
restrictions.append(block_size_restriction_spaced)
if isinstance(self._modified_restrictions, list) and block_size_restriction_spaced not in self._modified_restrictions:
self._modified_restrictions.append(block_size_restriction_spaced)
if isinstance(self.restrictions, list):
self.restrictions.append(block_size_restriction_spaced)

# check for search space restrictions
if restrictions is not None:
Expand Down Expand Up @@ -293,6 +299,11 @@ def __build_searchspace(self, block_size_names: list, max_threads: int, solver:
)
if len(valid_block_size_names) > 0:
parameter_space.addConstraint(MaxProdConstraint(max_threads), valid_block_size_names)
max_block_size_product = f"{' * '.join(valid_block_size_names)} <= {max_threads}"
if isinstance(self._modified_restrictions, list) and max_block_size_product not in self._modified_restrictions:
self._modified_restrictions.append(max_block_size_product)
if isinstance(self.restrictions, list):
self.restrictions.append((MaxProdConstraint(max_threads), valid_block_size_names))

# construct the parameter space with the constraints applied
return parameter_space.getSolutionsAsListDict(order=self.param_names)
Expand All @@ -302,10 +313,14 @@ def __add_restrictions(self, parameter_space: Problem) -> Problem:
if isinstance(self.restrictions, list):
for restriction in self.restrictions:
required_params = self.param_names

# convert to a Constraint type if necessary
if isinstance(restriction, tuple):
restriction, required_params = restriction
if callable(restriction) and not isinstance(restriction, Constraint):
restriction = FunctionConstraint(restriction)

# add the Constraint
if isinstance(restriction, FunctionConstraint):
parameter_space.addConstraint(restriction, required_params)
elif isinstance(restriction, Constraint):
Expand Down
2 changes: 1 addition & 1 deletion kernel_tuner/strategies/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __call__(self, x, check_restrictions=True):
# check if max_fevals is reached or time limit is exceeded
util.check_stop_criterion(self.tuning_options)

# snap values in x to nearest actual value for each parameter unscale x if needed
# snap values in x to nearest actual value for each parameter, unscale x if needed
if self.snap:
if self.scaling:
params = unscale_and_snap_to_nearest(x, self.searchspace.tune_params, self.tuning_options.eps)
Expand Down
87 changes: 49 additions & 38 deletions kernel_tuner/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,6 @@ class StopCriterionReached(Exception):
"block_size_x",
"block_size_y",
"block_size_z",
"ngangs",
"nworkers",
"vlength",
]


Expand Down Expand Up @@ -248,9 +245,37 @@ def check_block_size_params_names_list(block_size_names, tune_params):
UserWarning,
)

def check_restriction(restrict, params: dict) -> bool:
"""Check whether a configuration meets a search space restriction."""
# if it's a python-constraint, convert to function and execute
if isinstance(restrict, Constraint):
restrict = convert_constraint_restriction(restrict)
return restrict(list(params.values()))
# if it's a string, fill in the parameters and evaluate
elif isinstance(restrict, str):
return eval(replace_param_occurrences(restrict, params))
# if it's a function, call it
elif callable(restrict):
return restrict(**params)
# if it's a tuple, use only the parameters in the second argument to call the restriction
elif (isinstance(restrict, tuple) and len(restrict) == 2
and callable(restrict[0]) and isinstance(restrict[1], (list, tuple))):
# unpack the tuple
restrict, selected_params = restrict
# look up the selected parameters and their value
selected_params = dict((key, params[key]) for key in selected_params)
# call the restriction
if isinstance(restrict, Constraint):
restrict = convert_constraint_restriction(restrict)
return restrict(list(selected_params.values()))
else:
return restrict(**selected_params)
# otherwise, raise an error
else:
raise ValueError(f"Unkown restriction type {type(restrict)} ({restrict})")

def check_restrictions(restrictions, params: dict, verbose: bool) -> bool:
"""Check whether a specific configuration meets the search space restrictions."""
"""Check whether a configuration meets the search space restrictions."""
if callable(restrictions):
valid = restrictions(params)
if not valid and verbose is True:
Expand All @@ -260,40 +285,13 @@ def check_restrictions(restrictions, params: dict, verbose: bool) -> bool:
for restrict in restrictions:
# Check the type of each restriction and validate accordingly. Re-implement as a switch when Python >= 3.10.
try:
# if it's a python-constraint, convert to function and execute
if isinstance(restrict, Constraint):
restrict = convert_constraint_restriction(restrict)
if not restrict(params.values()):
valid = False
break
# if it's a string, fill in the parameters and evaluate
elif isinstance(restrict, str):
if not eval(replace_param_occurrences(restrict, params)):
valid = False
break
# if it's a function, call it
elif callable(restrict):
if not restrict(**params):
valid = False
break
# if it's a tuple, use only the parameters in the second argument to call the restriction
elif (isinstance(restrict, tuple) and len(restrict) == 2
and callable(restrict[0]) and isinstance(restrict[1], (list, tuple))):
# unpack the tuple
restrict, selected_params = restrict
# look up the selected parameters and their value
selected_params = dict((key, params[key]) for key in selected_params)
# call the restriction
if not restrict(**selected_params):
valid = False
break
# otherwise, raise an error
else:
raise ValueError(f"Unkown restriction type {type(restrict)} ({restrict})")
valid = check_restriction(restrict, params)
if not valid:
break
except ZeroDivisionError:
logging.debug(f"Restriction {restrict} with configuration {get_instance_string(params)} divides by zero.")
if not valid and verbose is True:
print(f"skipping config {get_instance_string(params)}, reason: config fails restriction")
print(f"skipping config {get_instance_string(params)}, reason: config fails restriction {restrict}")
return valid


Expand All @@ -311,6 +309,9 @@ def f_restrict(p):
elif isinstance(restrict, MaxProdConstraint):
def f_restrict(p):
return np.prod(p) <= restrict._maxprod
elif isinstance(restrict, MinProdConstraint):
def f_restrict(p):
return np.prod(p) >= restrict._minprod
elif isinstance(restrict, MaxSumConstraint):
def f_restrict(p):
return sum(p) <= restrict._maxsum
Expand Down Expand Up @@ -1005,6 +1006,9 @@ def to_equality_constraint(restriction: str, params: list[str]) -> Optional[Unio
params_used = list(params_used)
finalized_constraint = None
if try_to_constraint and " or " not in res and " and " not in res:
# if applicable, strip the outermost round brackets
while parsed_restriction[0] == '(' and parsed_restriction[-1] == ')' and '(' not in parsed_restriction[1:] and ')' not in parsed_restriction[:1]:
parsed_restriction = parsed_restriction[1:-1]
# check if we can turn this into the built-in numeric comparison constraint
finalized_constraint = to_numeric_constraint(parsed_restriction, params_used)
if finalized_constraint is None:
Expand Down Expand Up @@ -1059,8 +1063,15 @@ def compile_restrictions(restrictions: list, tune_params: dict, monolithic = Fal
# return the restrictions and used parameters
if len(restrictions_ignore) == 0:
return compiled_restrictions
restrictions_ignore = list(zip(restrictions_ignore, (() for _ in restrictions_ignore)))
return restrictions_ignore + compiled_restrictions

# use the required parameters or add an empty tuple for unknown parameters of ignored restrictions
noncompiled_restrictions = []
for r in restrictions_ignore:
if isinstance(r, tuple) and len(r) == 2 and isinstance(r[1], (list, tuple)):
noncompiled_restrictions.append(r)
else:
noncompiled_restrictions.append((r, ()))
return noncompiled_restrictions + compiled_restrictions


def process_cache(cache, kernel_options, tuning_options, runner):
Expand Down Expand Up @@ -1181,7 +1192,7 @@ def correct_open_cache(cache, open_cache=True):
filestr = cachefile.read().strip()

# if file was not properly closed, pretend it was properly closed
if len(filestr) > 0 and not filestr[-3:] == "}\n}":
if len(filestr) > 0 and not filestr[-3:] in ["}\n}", "}}}"]:
# remove the trailing comma if any, and append closing brackets
if filestr[-1] == ",":
filestr = filestr[:-1]
Expand Down
Loading