diff --git a/kernel_tuner/strategies/common.py b/kernel_tuner/strategies/common.py index 034fefd6f..2185cb2f7 100644 --- a/kernel_tuner/strategies/common.py +++ b/kernel_tuner/strategies/common.py @@ -203,8 +203,14 @@ def snap_to_nearest_config(x, tune_params): """Helper func that for each param selects the closest actual value.""" params = [] for i, k in enumerate(tune_params.keys()): - values = np.array(tune_params[k]) - idx = np.abs(values - x[i]).argmin() + values = tune_params[k] + + # if `x[i]` is in `values`, use that value, otherwise find the closest match + if x[i] in values: + idx = values.index(x[i]) + else: + idx = np.argmin([abs(v - x[i]) for v in values]) + params.append(values[idx]) return params diff --git a/test/test_common.py b/test/test_common.py index 132068843..7c1bd6838 100644 --- a/test/test_common.py +++ b/test/test_common.py @@ -50,9 +50,10 @@ def test_snap_to_nearest_config(): tune_params['x'] = [0, 1, 2, 3, 4, 5] tune_params['y'] = [0, 1, 2, 3, 4, 5] tune_params['z'] = [0, 1, 2, 3, 4, 5] + tune_params['w'] = ['a', 'b', 'c'] - x = [-5.7, 3.14, 1e6] - expected = [0, 3, 5] + x = [-5.7, 3.14, 1e6, 'b'] + expected = [0, 3, 5, 'b'] answer = common.snap_to_nearest_config(x, tune_params) assert answer == expected