Skip to content

Commit

Permalink
Merge pull request #355 from ecmwf-ifs/naml-fix-get-local-arrays
Browse files Browse the repository at this point in the history
Transformations: Test and fix corner case in get_local_arrays
  • Loading branch information
mlange05 authored Aug 2, 2024
2 parents 7d56565 + 5045915 commit 117c5d1
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
23 changes: 19 additions & 4 deletions loki/transformations/tests/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,13 +459,21 @@ def test_transform_utilites_find_driver_loops(frontend):


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_utilites_get_local_arrays(frontend):
def test_transform_utilites_get_local_arrays(frontend, tmp_path):
""" Test :any:`get_local_arrays` utility. """

fcode = """
subroutine test_get_local_arrays(n, start, end, arr)
module test_get_local_arrays_mod
implicit none
type my_dim
integer :: a(2)
end type my_dim
contains
subroutine test_get_local_arrays(n, dims, start, end, arr)
integer, intent(in) :: n, start, end
real, intent(inout) :: arr(n)
type(my_dim), intent(in) :: dims
real, intent(inout) :: arr(dims%a(2))
real :: local(n), tmp
integer :: i
Expand All @@ -479,8 +487,10 @@ def test_transform_utilites_get_local_arrays(frontend):
ARR(ji) = tmp * local(i)
end do
end subroutine test_get_local_arrays
end module test_get_local_arrays_mod
"""
routine = Subroutine.from_source(fcode, frontend=frontend)
module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
routine = module['test_get_local_arrays']

locals = get_local_arrays(routine, routine.body, unique=True)
assert len(locals) == 1
Expand All @@ -494,6 +504,11 @@ def test_transform_utilites_get_local_arrays(frontend):
assert len(locals) == 1
assert locals[0] == 'local(i)'

# Test for component arrays on arguments in spec
locals = get_local_arrays(routine, routine.spec, unique=True)
assert len(locals) == 1
assert locals[0] == 'local(1:n)' if frontend == OMNI else 'local(n)'


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_utilites_check_routine_pragmas(frontend, tmp_path):
Expand Down
2 changes: 1 addition & 1 deletion loki/transformations/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,7 @@ def get_local_arrays(routine, section, unique=True):
variables = FindVariables(unique=unique).visit(section)

# Filter all variables by argument name to get local arrays
arrays = [v for v in variables if isinstance(v, sym.Array)]
arrays = [v for v in variables if isinstance(v, sym.Array) and not v.parent]
arrays = [v for v in arrays if str(v.name).lower() not in arg_names]

return arrays
Expand Down

0 comments on commit 117c5d1

Please sign in to comment.