diff --git a/kernel_tuner/utils/directives.py b/kernel_tuner/utils/directives.py
index d06b0d2c..7d9bb4e7 100644
--- a/kernel_tuner/utils/directives.py
+++ b/kernel_tuner/utils/directives.py
@@ -32,6 +32,9 @@ class Cxx(Language):
     def get(self) -> str:
         return "cxx"
 
+    def end_string(self) -> str:
+        return "#pragma tuner stop"
+
 
 class Fortran(Language):
     """Class to represent Fortran code"""
@@ -39,6 +42,9 @@ class Fortran(Language):
     def get(self) -> str:
         return "fortran"
 
+    def end_string(self) -> str:
+        return "!$tuner stop"
+
 
 class Code(object):
     """Class to represent the directive and host code of the application"""
@@ -356,24 +362,34 @@ def extract_directive_code(code: str, langs: Code, kernel_name: str = None) -> d
     """Extract explicitly marked directive sections from code"""
     if is_cxx(langs.language):
         start_string = "#pragma tuner start"
-        end_string = "#pragma tuner stop"
     elif is_fortran(langs.language):
         start_string = "!$tuner start"
-        end_string = "!$tuner stop"
 
-    return extract_code(start_string, end_string, code, langs, kernel_name)
+    return extract_code(start_string, langs.language.end_string(), code, langs, kernel_name)
 
 
 def extract_initialization_code(code: str, langs: Code) -> str:
     """Extract the initialization section from code"""
     if is_cxx(langs.language):
         start_string = "#pragma tuner initialize"
-        end_string = "#pragma tuner stop"
     elif is_fortran(langs.language):
         start_string = "!$tuner initialize"
-        end_string = "!$tuner stop"
 
-    init_code = extract_code(start_string, end_string, code, langs)
+    init_code = extract_code(start_string, langs.language.end_string(), code, langs)
+    if len(init_code) >= 1:
+        return "\n".join(init_code.values()) + "\n"
+    else:
+        return ""
+
+
+def extract_deinitialization_code(code: str, langs: Code) -> str:
+    """Extract the deinitialization section from code"""
+    if is_cxx(langs.language):
+        start_string = "#pragma tuner deinitialize"
+    elif is_fortran(langs.language):
+        start_string = "!$tuner deinitialize"
+
+    init_code = extract_code(start_string, langs.language.end_string(), code, langs)
     if len(init_code) >= 1:
         return "\n".join(init_code.values()) + "\n"
     else:
@@ -508,6 +524,7 @@ def generate_directive_function(
     langs: Code,
     data: dict = None,
     initialization: str = "",
+    deinitialization: str = "",
     user_dimensions: dict = None,
 ) -> str:
     """Generate tunable function for one directive"""
@@ -535,13 +552,17 @@ def generate_directive_function(
         else:
             code += body
         code = end_timing_cxx(code)
+        if len(deinitialization) > 1:
+            code += deinitialization + "\n"
         code += "\n}"
     elif is_fortran(langs.language):
         body = wrap_timing(body, langs.language)
         if data is not None:
             code += wrap_data(body + "\n", langs, data, preprocessor, user_dimensions)
         else:
-            code += body
+            code += body + "\n"
+        if len(deinitialization) > 1:
+            code += deinitialization + "\n"
         name = signature.split(" ")[1].split("(")[0]
         code += f"\nend function {name}\nend module kt\n"
 
diff --git a/pyproject.toml b/pyproject.toml
index 3175ed34..13d1cb64 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -57,7 +57,7 @@ generate-setup-file = false
 # ATTENTION: if anything is changed here, run `poetry update`
 [tool.poetry.dependencies]
 python = ">=3.9,<3.13" # NOTE when changing the supported Python versions, also change the test versions in the noxfile
-numpy = ">=1.26.0" # Python 3.12 requires numpy at least 1.26
+numpy = "^1.26.0" # Python 3.12 requires numpy at least 1.26
 scipy = ">=1.11.0"
 packaging = "*" # required by file_utils
 jsonschema = "*"
diff --git a/test/utils/test_directives.py b/test/utils/test_directives.py
index 3542cdcf..bed2d871 100644
--- a/test/utils/test_directives.py
+++ b/test/utils/test_directives.py
@@ -326,6 +326,13 @@ def test_extract_initialization_code():
     assert extract_initialization_code(code_f90, Code(OpenACC(), Fortran())) == "integer :: value\n"
 
 
+def test_extract_deinitialization_code():
+    code_cpp = "#pragma tuner deinitialize\nconst int value = 42;\n#pragma tuner stop\n"
+    code_f90 = "!$tuner deinitialize\ninteger :: value\n!$tuner stop\n"
+    assert extract_deinitialization_code(code_cpp, Code(OpenACC(), Cxx())) == "const int value = 42;\n"
+    assert extract_deinitialization_code(code_f90, Code(OpenACC(), Fortran())) == "integer :: value\n"
+
+
 def test_add_present_openacc():
     acc_cxx = Code(OpenACC(), Cxx())
     acc_f90 = Code(OpenACC(), Fortran())