diff --git a/kernel_tuner/interface.py b/kernel_tuner/interface.py index e369d608c..42b501b61 100644 --- a/kernel_tuner/interface.py +++ b/kernel_tuner/interface.py @@ -275,6 +275,7 @@ def tune_kernel(kernel_name, kernel_string, problem_size, arguments, max_threads = dev.max_threads #move data to GPU + _check_argument_list(arguments) gpu_args = dev.ready_argument_list(arguments) #compute cartesian product of all tunable parameters @@ -519,3 +520,7 @@ def _check_kernel_correctness(dev, func, gpu_args, threads, grid, answer, instan raise Exception("Error " + instance_string + " failed correctness check") return correct +def _check_argument_list(args): + for (i, arg) in enumerate(args): + if not isinstance(arg, (numpy.ndarray, numpy.generic)): + raise TypeError("Argument at position " + str(i) + " of type: " + str(type(arg)) + " should be of type numpy.ndarray or numpy scalar") diff --git a/test/test_interface.py b/test/test_interface.py index d7624720c..296c23fe6 100644 --- a/test/test_interface.py +++ b/test/test_interface.py @@ -117,3 +117,22 @@ def test_check_kernel_correctness(dev_interface): assert args[1] == 'gpu_args' assert dev.memcpy_dtoh.called == 1 assert test + +def test_check_argument_list1(): + args = [numpy.int32(5), 'blah', numpy.array([1, 2, 3])] + try: + kernel_tuner._check_argument_list(args) + print("Expected a TypeError to be raised") + assert False + except TypeError as e: + print(str(e)) + assert "at position 1" in str(e) + except: + print("Expected a TypeError to be raised") + assert False + +def test_check_argument_list2(): + args = [numpy.int32(5), numpy.float64(4.6), numpy.array([1, 2, 3])] + kernel_tuner._check_argument_list(args) + #test that no exception is raised + assert True