Skip to content

Commit

Permalink
Merge pull request #267 from KernelTuner/directives
Browse files Browse the repository at this point in the history
ESiWACE3 hackathon
  • Loading branch information
isazi authored Jun 28, 2024
2 parents abd8de0 + 9b88aaa commit e046cfd
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 8 deletions.
35 changes: 28 additions & 7 deletions kernel_tuner/utils/directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,19 @@ class Cxx(Language):
def get(self) -> str:
return "cxx"

def end_string(self) -> str:
return "#pragma tuner stop"


class Fortran(Language):
"""Class to represent Fortran code"""

def get(self) -> str:
return "fortran"

def end_string(self) -> str:
return "!$tuner stop"


class Code(object):
"""Class to represent the directive and host code of the application"""
Expand Down Expand Up @@ -356,24 +362,34 @@ def extract_directive_code(code: str, langs: Code, kernel_name: str = None) -> d
"""Extract explicitly marked directive sections from code"""
if is_cxx(langs.language):
start_string = "#pragma tuner start"
end_string = "#pragma tuner stop"
elif is_fortran(langs.language):
start_string = "!$tuner start"
end_string = "!$tuner stop"

return extract_code(start_string, end_string, code, langs, kernel_name)
return extract_code(start_string, langs.language.end_string(), code, langs, kernel_name)


def extract_initialization_code(code: str, langs: Code) -> str:
"""Extract the initialization section from code"""
if is_cxx(langs.language):
start_string = "#pragma tuner initialize"
end_string = "#pragma tuner stop"
elif is_fortran(langs.language):
start_string = "!$tuner initialize"
end_string = "!$tuner stop"

init_code = extract_code(start_string, end_string, code, langs)
init_code = extract_code(start_string, langs.language.end_string(), code, langs)
if len(init_code) >= 1:
return "\n".join(init_code.values()) + "\n"
else:
return ""


def extract_deinitialization_code(code: str, langs: Code) -> str:
"""Extract the deinitialization section from code"""
if is_cxx(langs.language):
start_string = "#pragma tuner deinitialize"
elif is_fortran(langs.language):
start_string = "!$tuner deinitialize"

init_code = extract_code(start_string, langs.language.end_string(), code, langs)
if len(init_code) >= 1:
return "\n".join(init_code.values()) + "\n"
else:
Expand Down Expand Up @@ -508,6 +524,7 @@ def generate_directive_function(
langs: Code,
data: dict = None,
initialization: str = "",
deinitialization: str = "",
user_dimensions: dict = None,
) -> str:
"""Generate tunable function for one directive"""
Expand Down Expand Up @@ -535,13 +552,17 @@ def generate_directive_function(
else:
code += body
code = end_timing_cxx(code)
if len(deinitialization) > 1:
code += deinitialization + "\n"
code += "\n}"
elif is_fortran(langs.language):
body = wrap_timing(body, langs.language)
if data is not None:
code += wrap_data(body + "\n", langs, data, preprocessor, user_dimensions)
else:
code += body
code += body + "\n"
if len(deinitialization) > 1:
code += deinitialization + "\n"
name = signature.split(" ")[1].split("(")[0]
code += f"\nend function {name}\nend module kt\n"

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ generate-setup-file = false
# ATTENTION: if anything is changed here, run `poetry update`
[tool.poetry.dependencies]
python = ">=3.9,<3.13" # NOTE when changing the supported Python versions, also change the test versions in the noxfile
numpy = ">=1.26.0" # Python 3.12 requires numpy at least 1.26
numpy = "^1.26.0" # Python 3.12 requires numpy at least 1.26
scipy = ">=1.11.0"
packaging = "*" # required by file_utils
jsonschema = "*"
Expand Down
7 changes: 7 additions & 0 deletions test/utils/test_directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,13 @@ def test_extract_initialization_code():
assert extract_initialization_code(code_f90, Code(OpenACC(), Fortran())) == "integer :: value\n"


def test_extract_deinitialization_code():
code_cpp = "#pragma tuner deinitialize\nconst int value = 42;\n#pragma tuner stop\n"
code_f90 = "!$tuner deinitialize\ninteger :: value\n!$tuner stop\n"
assert extract_deinitialization_code(code_cpp, Code(OpenACC(), Cxx())) == "const int value = 42;\n"
assert extract_deinitialization_code(code_f90, Code(OpenACC(), Fortran())) == "integer :: value\n"


def test_add_present_openacc():
acc_cxx = Code(OpenACC(), Cxx())
acc_f90 = Code(OpenACC(), Fortran())
Expand Down

0 comments on commit e046cfd

Please sign in to comment.