Skip to content

Commit

Permalink
Fix unit tests for variable transformations. Reorganization a bit to …
Browse files Browse the repository at this point in the history
…be consistent with how the DEBUG checking is implemented in PR NCAR#512
  • Loading branch information
dustinswales committed Dec 4, 2023
1 parent b5cc909 commit e6ffbd5
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 79 deletions.
90 changes: 61 additions & 29 deletions scripts/suite_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1210,31 +1210,8 @@ def analyze(self, phase, group, scheme_library, suite_vars, level):
# end if
# Are there any forward/reverse transforms for this variable?
if compat_obj is not None and (compat_obj.has_vert_transforms or compat_obj.has_unit_transforms):
# Add local variable (<var>_local) needed for transformation.
tmp_var = var.clone(var.get_prop_value('local_name')+'_local')
self.__group.manage_variable(tmp_var)

# Create indices, flipping if necessary.
indices = [':']*var.get_rank()
if compat_obj.has_vert_transforms:
dim = find_vertical_dimension(var.get_dimensions())
vdim_name = vert_dim.split(':')[-1]
group_vvar = self.__group.call_list.find_variable(vdim_name)
vname = group_vvar.get_prop_value('local_name')
indices[dim[1]] = vname+':1:-1'

# Add any forward transforms.
if (var.get_prop_value('intent') != 'in'):
self.__forward_transforms.append(
compat_obj.forward_transform(lvar_lname=var.get_prop_value('local_name'),
rvar_lname=tmp_var.get_prop_value('local_name'),
indices=indices))
# Add any reverse transforms.
if (var.get_prop_value('intent') != 'out'):
self.__reverse_transforms.append(
compat_obj.reverse_transform(lvar_lname=tmp_var.get_prop_value('local_name'),
rvar_lname=var.get_prop_value('local_name'),
indices=indices))
self.add_var_transform(var, compat_obj, vert_dim)

# end for
if self.needs_vertical is not None:
self.parent.add_part(self, replace=True) # Should add a vloop
Expand All @@ -1246,6 +1223,58 @@ def analyze(self, phase, group, scheme_library, suite_vars, level):
# end if
return scheme_mods

def add_var_transform(self, var, compat_obj, vert_dim):
"""Add variable transformation before/after call to Scheme in <outfile>"""
# Add dummy variable (<var>_local) needed for transformation.
dummy = var.clone(var.get_prop_value('local_name')+'_local')
self.__group.manage_variable(dummy)

# Create indices for transform.
lindices = [':']*var.get_rank()
rindices = [':']*var.get_rank()

# If needed, modify vertical dimension for vertical orientation flipping
dim = find_vertical_dimension(var.get_dimensions())
vdim_name = vert_dim.split(':')[-1]
group_vvar = self.__group.call_list.find_variable(vdim_name)
vname = group_vvar.get_prop_value('local_name')
lindices[dim[1]] = '1:'+vname
rindices[dim[1]] = '1:'+vname
if compat_obj.has_vert_transforms:
rindices[dim[1]] = vname+':1:-1'

# If needed, modify horizontal dimension for loop substitution.
dim = find_horizontal_dimension(var.get_dimensions())
if compat_obj.has_dim_transforms:
print("SWALES: ",dim)

# Add any forward transforms.
if (var.get_prop_value('intent') != 'in'):
self.__forward_transforms.append([var.get_prop_value('local_name'),
dummy.get_prop_value('local_name'),
lindices, rindices, compat_obj])

# Add any reverse transforms.
if (var.get_prop_value('intent') != 'out'):
self.__reverse_transforms.append([dummy.get_prop_value('local_name'),
var.get_prop_value('local_name'),
rindices, lindices, compat_obj])

def write_var_transform(self, var, dummy, rindices, lindices, compat_obj,
outfile, indent, forward):
"""Write variable transformation needed to call this Scheme <outfile>"""
if forward:
stmt = compat_obj.forward_transform(lvar_lname=dummy,
rvar_lname=var,
lvar_indices=lindices,
rvar_indices=rindices)
else:
stmt = compat_obj.reverse_transform(lvar_lname=var,
rvar_lname=dummy,
lvar_indices=rindices,
rvar_indices=lindices)
outfile.write(stmt, indent+1)

def write(self, outfile, errcode, indent):
# Unused arguments are for consistent write interface
# pylint: disable=unused-argument
Expand All @@ -1257,14 +1286,17 @@ def write(self, outfile, errcode, indent):
my_args = self.call_list.call_string(cldicts=cldicts,
is_func_call=True,
subname=self.subroutine_name)
stmt = 'call {}({})'
# Write the scheme call.

outfile.write('if ({} == 0) then'.format(errcode), indent)
# Write any reverse transforms.
for reverse_transform in self.__reverse_transforms: outfile.write(reverse_transform, indent+1)
for (dummy, var, rindices, lindices, compat_obj) in self.__reverse_transforms:
tstmt = self.write_var_transform(dummy, var, rindices, lindices, compat_obj, outfile, indent, False)
# Write the scheme call.
stmt = 'call {}({})'
outfile.write(stmt.format(self.subroutine_name, my_args), indent+1)
# Write any forward transforms.
for forward_transform in self.__forward_transforms: outfile.write(forward_transform, indent+1)
for (var, dummy, lindices, rindices, compat_obj) in self.__forward_transforms:
tstmt = self.write_var_transform(dummy, var, rindices, lindices, compat_obj, outfile, indent, True)
outfile.write('end if', indent)

def schemes(self):
Expand Down
22 changes: 12 additions & 10 deletions scripts/var_props.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,13 +953,15 @@ def __init__(self, var1_stdname, var1_type, var1_kind, var1_units,
# end if
self.__incompat_reason = " and ".join([x for x in incompat_reason if x])

def forward_transform(self, lvar_lname, rvar_lname, indices,
def forward_transform(self, lvar_lname, rvar_lname, rvar_indices, lvar_indices,
adjust_hdim=None, flip_vdim=None):
"""Compute and return the the forward transform from "var1" to "var2".
<lvar_lname> is the local name of "var2".
<rvar_lname> is the local name of "var1".
<indices> is a tuple of the loop indices for "var1" (i.e., "var1"
will show up in the RHS of the transform as "var1(indices)".
<rvar_indices> is a tuple of the loop indices for "var1" (i.e., "var1"
will show up in the RHS of the transform as "var1(rvar_indices)".
<lvar_indices> is a tuple of the loop indices for "var1" (i.e., "var2"
will show up in the LHS of the transform as "var2(lvar_indices)".
If <adjust_hdim> is not None, it should be a string containing the
local name of the "horizontal_loop_begin" variable. This is used to
compute the offset in the horizontal axis index between one and
Expand All @@ -972,8 +974,8 @@ def forward_transform(self, lvar_lname, rvar_lname, indices,
"vertical_interface_dimension").
"""
# Dimension transform (Indices handled externally)
rhs_term = f"{rvar_lname}({','.join(indices)})"
lhs_term = f"{lvar_lname}"
rhs_term = f"{rvar_lname}({','.join(rvar_indices)})"
lhs_term = f"{lvar_lname}({','.join(lvar_indices)})"

if self.has_kind_transforms:
kind = self.__kind_transforms[1]
Expand All @@ -991,13 +993,13 @@ def forward_transform(self, lvar_lname, rvar_lname, indices,
# end if
return f"{lhs_term} = {rhs_term}"

def reverse_transform(self, lvar_lname, rvar_lname, indices,
def reverse_transform(self, lvar_lname, rvar_lname, rvar_indices, lvar_indices,
adjust_hdim=None, flip_vdim=None):
"""Compute and return the the reverse transform from "var2" to "var1".
<lvar_lname> is the local name of "var1".
<rvar_lname> is the local name of "var2".
<indices> is a tuple of the loop indices for "var2" (i.e., "var2"
will show up in the RHS of the transform as "var2(indices)".
<rvar_indices> is a tuple of the loop indices for "var1" (i.e., "var1"
will show up in the RHS of the transform as "var1(rvar_indices)".
If <adjust_hdim> is not None, it should be a string containing the
local name of the "horizontal_loop_begin" variable. This is used to
compute the offset in the horizontal axis index between one and
Expand All @@ -1010,8 +1012,8 @@ def reverse_transform(self, lvar_lname, rvar_lname, indices,
"vertical_interface_dimension").
"""
# Dimension transforms (Indices handled exrernally)
lhs_term = f"{lvar_lname}({','.join(indices)})"
rhs_term = f"{rvar_lname}"
lhs_term = f"{lvar_lname}({','.join(lvar_indices)})"
rhs_term = f"{rvar_lname}({','.join(rvar_indices)})"

if self.has_kind_transforms:
kind = self.__kind_transforms[0]
Expand Down
55 changes: 15 additions & 40 deletions test/unit_tests/test_var_transforms.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -282,13 +282,12 @@ def test_valid_dim_transforms(self):
self.assertIsInstance(compat, VarCompatObj,
msg=self.__inst_emsg.format(type(compat)))
rindices = ("hind", "vind")
fwd_stmt = compat.forward_transform(v2_lname, v1_lname, rindices,
adjust_hdim=None, flip_vdim=None)
lindices = rindices
fwd_stmt = compat.forward_transform(v2_lname, v1_lname, rindices, lindices)
ind_str = ','.join(rindices)
expected = f"{v2_lname}({ind_str}) = {v1_lname}({ind_str})"
self.assertEqual(fwd_stmt, expected)
rev_stmt = compat.reverse_transform(v1_lname, v2_lname, rindices,
adjust_hdim=None, flip_vdim=None)
rev_stmt = compat.reverse_transform(v1_lname, v2_lname, rindices, lindices)
expected = f"{v1_lname}({ind_str}) = {v2_lname}({ind_str})"
self.assertEqual(rev_stmt, expected)

Expand All @@ -298,17 +297,13 @@ def test_valid_dim_transforms(self):
msg=self.__inst_emsg.format(type(compat)))
rindices = ("hind", "vind")
lindices = ("hind-col_start+1", "vind")
fwd_stmt = compat.forward_transform(v2_lname, v1_lname, rindices,
adjust_hdim='col_start',
flip_vdim=None)
fwd_stmt = compat.forward_transform(v2_lname, v1_lname, rindices, lindices)
lind_str = ','.join(lindices)
rind_str = ','.join(rindices)
expected = f"{v2_lname}({lind_str}) = {v1_lname}({rind_str})"
self.assertEqual(fwd_stmt, expected)
rev_stmt = compat.reverse_transform(v1_lname, v2_lname, rindices,
adjust_hdim='col_start',
flip_vdim=None)
lindices = ("hind+col_start-1", "vind")
rev_stmt = compat.reverse_transform(v1_lname, v2_lname, rindices, lindices)
lind_str = ','.join(lindices)
expected = f"{v1_lname}({lind_str}) = {v2_lname}({rind_str})"
self.assertEqual(rev_stmt, expected)
Expand All @@ -320,17 +315,13 @@ def test_valid_dim_transforms(self):
msg=self.__inst_emsg.format(type(compat)))
rindices = ("hind", "vind")
lindices = ("hind-col_start+1", "pver-vind+1")
fwd_stmt = compat.forward_transform(v2_lname, v1_lname, rindices,
adjust_hdim='col_start',
flip_vdim='pver')
fwd_stmt = compat.forward_transform(v2_lname, v1_lname, rindices, lindices)
lind_str = ','.join(lindices)
rind_str = ','.join(rindices)
expected = f"{v2_lname}({lind_str}) = {v1_lname}({rind_str})"
self.assertEqual(fwd_stmt, expected)
rev_stmt = compat.reverse_transform(v1_lname, v2_lname, rindices,
adjust_hdim='col_start',
flip_vdim='pver')
lindices = ("hind+col_start-1", "pver-vind+1")
rev_stmt = compat.reverse_transform(v1_lname, v2_lname, rindices, lindices)
lind_str = ','.join(lindices)
expected = f"{v1_lname}({lind_str}) = {v2_lname}({rind_str})"
self.assertEqual(rev_stmt, expected)
Expand All @@ -342,17 +333,13 @@ def test_valid_dim_transforms(self):
rindices = ("hind", "vind")
lindices = ("hind-col_start+1", "vind")
conv = f"273.15_{real_array1.get_prop_value('kind')}"
fwd_stmt = compat.forward_transform(v3_lname, v1_lname, rindices,
adjust_hdim='col_start',
flip_vdim=None)
fwd_stmt = compat.forward_transform(v3_lname, v1_lname, rindices, lindices)
lind_str = ','.join(lindices)
rind_str = ','.join(rindices)
expected = f"{v3_lname}({lind_str}) = {v1_lname}({rind_str})+{conv}"
self.assertEqual(fwd_stmt, expected)
rev_stmt = compat.reverse_transform(v1_lname, v3_lname, rindices,
adjust_hdim='col_start',
flip_vdim=None)
lindices = ("hind+col_start-1", "vind")
rev_stmt = compat.reverse_transform(v1_lname, v3_lname, rindices, lindices)
lind_str = ','.join(lindices)
conv = f"273.15_{real_array2.get_prop_value('kind')}"
expected = f"{v1_lname}({lind_str}) = {v3_lname}({rind_str})-{conv}"
Expand All @@ -364,18 +351,14 @@ def test_valid_dim_transforms(self):
msg=self.__inst_emsg.format(type(compat)))
rindices = ("hind", "vind")
lindices = ("hind", "vind")
fwd_stmt = compat.forward_transform(v4_lname, v3_lname, rindices,
adjust_hdim='col_start',
flip_vdim=None)
fwd_stmt = compat.forward_transform(v4_lname, v3_lname, rindices, lindices)
lind_str = ','.join(lindices)
rind_str = ','.join(rindices)
rkind = real_array3.get_prop_value('kind')
expected = f"{v4_lname}({lind_str}) = real({v3_lname}({rind_str}), {rkind})"
self.assertEqual(fwd_stmt, expected)
rev_stmt = compat.reverse_transform(v3_lname, v4_lname, rindices,
adjust_hdim='col_start',
flip_vdim=None)
lindices = ("hind", "vind")
rev_stmt = compat.reverse_transform(v3_lname, v4_lname, rindices, lindices)
lind_str = ','.join(lindices)
rkind = real_array4.get_prop_value('kind')
expected = f"{v3_lname}({lind_str}) = real({v4_lname}({rind_str}), {rkind})"
Expand All @@ -389,17 +372,13 @@ def test_valid_dim_transforms(self):
lindices = ("hind-col_start+1", "vind")
rkind = real_array4.get_prop_value('kind')
conv = f"273.15_{rkind}"
fwd_stmt = compat.forward_transform(v2_lname, v1_lname, rindices,
adjust_hdim='col_start',
flip_vdim=None)
fwd_stmt = compat.forward_transform(v2_lname, v1_lname, rindices, lindices)
lind_str = ','.join(lindices)
rind_str = ','.join(rindices)
expected = f"{v2_lname}({lind_str}) = real({v1_lname}({rind_str}), {rkind})+{conv}"
self.assertEqual(fwd_stmt, expected)
rev_stmt = compat.reverse_transform(v1_lname, v2_lname, rindices,
adjust_hdim='col_start',
flip_vdim=None)
lindices = ("hind+col_start-1", "vind")
rev_stmt = compat.reverse_transform(v1_lname, v2_lname, rindices, lindices)
lind_str = ','.join(lindices)
rkind = real_array1.get_prop_value('kind')
conv = f"273.15_{rkind}"
Expand All @@ -413,20 +392,16 @@ def test_valid_dim_transforms(self):
msg=self.__inst_emsg.format(type(compat)))
rindices = ("hind", "vind")
lindices = ("pver-vind+1", "hind-col_start+1")
fwd_stmt = compat.forward_transform(v4_lname, v5_lname, rindices,
adjust_hdim='col_start',
flip_vdim='pver')
fwd_stmt = compat.forward_transform(v4_lname, v5_lname, rindices, lindices)
lind_str = ','.join(lindices)
rind_str = ','.join(rindices)
rkind = real_array3.get_prop_value('kind')
expected = f"{v4_lname}({lind_str}) = {v5_lname}({rind_str})"
self.assertEqual(fwd_stmt, expected)
rindices = ("vind", "hind")
rind_str = ','.join(rindices)
rev_stmt = compat.reverse_transform(v5_lname, v4_lname, rindices,
adjust_hdim='col_start',
flip_vdim='pver')
lindices = ("hind+col_start-1", "pver-vind+1")
rev_stmt = compat.reverse_transform(v5_lname, v4_lname, rindices, lindices)
lind_str = ','.join(lindices)
rkind = real_array4.get_prop_value('kind')
expected = f"{v5_lname}({lind_str}) = {v4_lname}({rind_str})"
Expand Down

0 comments on commit e6ffbd5

Please sign in to comment.