Skip to content

Commit

Permalink
Merge pull request #129 from Autostronomy/printparams
Browse files Browse the repository at this point in the history
more helpful str and repr for users
  • Loading branch information
ConnorStoneAstro authored Oct 11, 2023
2 parents fd815e3 + 14550ba commit 9a72f08
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 22 deletions.
4 changes: 4 additions & 0 deletions astrophot/models/core_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,10 @@ def parameter_order(self):

def __str__(self):
"""String representation for the model."""
return self.parameters.__str__()

def __repr__(self):
"""Detailed string representation for the model."""
return yaml.dump(self.get_state(), indent=2)

def get_state(self):
Expand Down
2 changes: 1 addition & 1 deletion astrophot/param/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def __getitem__(self, key):
for node in self.nodes.values():
if key == node.identity:
return node
raise ValueError(f"Unrecognized key for '{self}': {key}")
raise ValueError(f"Unrecognized key for '{self.name}': {key}")

def __contains__(self, key):
"""Check if a node has a link directly to another node. A check like
Expand Down
23 changes: 19 additions & 4 deletions astrophot/param/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,9 +621,24 @@ def __len__(self):
"""
return self.size


def print_params(self, include_locked=True, include_prof=True, include_id=True):
if self.leaf:
return f"{self.name}" + (f" (id-{self.identity})" if include_id else "") + f": {self.value.detach().cpu().tolist()}" + ("" if self.uncertainty is None else f" +- {self.uncertainty.detach().cpu().tolist()}") + f" [{self.units}]" + ("" if self.limits[0] is None and self.limits[1] is None else f", limits: ({None if self.limits[0] is None else self.limits[0].detach().cpu().tolist()}, {None if self.limits[1] is None else self.limits[1].detach().cpu().tolist()})") + (", cyclic" if self.cyclic else "") + (", locked" if self.locked else "") + (f", prof: {self.prof.detach().cpu().tolist()}" if include_prof and self.prof is not None else "")
elif isinstance(self._value, Parameter_Node):
return self.name + (f" (id-{self.identity})" if include_id else "") + " points to: " + self._value.print_params(include_locked=include_locked, include_prof=include_prof, include_id=include_id)
return self.name + (f" (id-{self.identity}, {('function node, '+self._value.__name__) if isinstance(self._value, FunctionType) else 'branch node'})" if include_id else "") + ":\n"

def __str__(self):
return super().__str__() + " " + ("branch" if self.value is None else str(self.value.detach().cpu().tolist()))
def __repr__(self):
return super().__repr__() + "\nValue: " + ("branch" if self.value is None else str(self.value.detach().cpu().tolist()))
reply = self.print_params(include_locked=True, include_prof=False, include_id=False)
if self.leaf or isinstance(self._value, Parameter_Node):
return reply
reply += "\n".join(node.print_params(include_locked=True, include_prof=False, include_id=False) for node in self.flat(include_locked=True, include_links=False).values())
return reply

def __repr__(self, level = 0, indent = ' '):
reply = indent*level + self.print_params(include_locked=True, include_prof=False, include_id=True)
if self.leaf or isinstance(self._value, Parameter_Node):
return reply
reply += "\n".join(node.__repr__(level = level+1, indent=indent) for node in self.nodes.values())
return reply
37 changes: 22 additions & 15 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,24 @@ def test_AstroPhot_Model(self):

state = model.get_state()

def test_initialize_does_not_recurse(self):
"Test case for error where missing parameter name triggered print that triggered missing parameter name ..."
target = make_basic_sersic()
model = ap.models.AstroPhot_Model(
name="test model",
model_type="sersic galaxy model",
target=target,
)
# Define a function that accesses a parameter that doesn't exist
def calc(params):
return params["A"].value

model["center"].value = calc

with self.assertRaises(ValueError) as context:
model.initialize()
self.assertTrue(str(context.exception) == "Unrecognized key for 'center': A")

def test_basic_model_methods(self):

target = make_basic_sersic()
Expand Down Expand Up @@ -135,10 +153,11 @@ def test_mask(self):


class TestAllModelBasics(unittest.TestCase):
def test_all_model_init(self):
def test_all_model_sample(self):

target = make_basic_sersic()
for model_type in ap.models.Component_Model.List_Model_Names(useable=True):
print(model_type)
MODEL = ap.models.AstroPhot_Model(
name="test model",
model_type=model_type,
Expand All @@ -150,25 +169,13 @@ def test_all_model_init(self):
MODEL[P].value,
f"Model type {model_type} parameter {P} should not be None after initialization",
)
# perhaps add check that uncertainty is not none

def test_all_model_sample(self):

target = make_basic_sersic()
for model_type in ap.models.Component_Model.List_Model_Names(useable=True):
print(model_type)
MODEL = ap.models.AstroPhot_Model(
name="test model",
model_type=model_type,
target=target,
)
MODEL.initialize()
img = MODEL()
self.assertTrue(
torch.all(torch.isfinite(img.data)),
"Model should evaluate a real number for the full image",
)

self.assertIsInstance(str(MODEL), str, "String representation should return string")
self.assertIsInstance(repr(MODEL), str, "Repr should return string")

class TestSersic(unittest.TestCase):
def test_sersic_creation(self):
Expand Down
48 changes: 46 additions & 2 deletions tests/test_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,52 @@ def test_vector_representation(self):
self.assertTrue(torch.all(vec == PG.vector_values()), "representation should be reversible")
self.assertEqual(PG.vector_values().numel(), 5, "masked values shouldn't be shown")

S = str(PG)
R = repr(PG)

def test_printing(self):

def node_func_sqr(P):
return P["test1"].value**2
P1 = Parameter_Node("test1", value = 0.5, uncertainty = 0.3, limits = (-1, 1), locked = False, prof = 1.)
P2 = Parameter_Node("test2", value = 2., uncertainty = 1., locked = False)
P3 = Parameter_Node("test3", value = [4.,5.], uncertainty = [5.,3.], limits = ((0., 1.), None), locked = False)
P4 = Parameter_Node("test4", value = P2)
P5 = Parameter_Node("test5", value = node_func_sqr, link = (P1,))
P6 = Parameter_Node("test6", value = ((5,6),(7,8)), uncertainty = 0.1 * np.zeros((2,2)), limits = (None, 10*np.ones((2,2))))
PG = Parameter_Node("testgroup", link = (P1, P2, P3, P4, P5, P6))

self.assertEqual(str(PG), """testgroup:
test1: 0.5 +- 0.3 [none], limits: (-1.0, 1.0)
test2: 2.0 +- 1.0 [none]
test3: [4.0, 5.0] +- [5.0, 3.0] [none], limits: ([0.0, 1.0], None)
test6: [[5.0, 6.0], [7.0, 8.0]] +- [[0.0, 0.0], [0.0, 0.0]] [none], limits: (None, [[10.0, 10.0], [10.0, 10.0]])""", "String representation should return specific string")

ref_string = """testgroup (id-140071931416000, branch node):
test1 (id-140071931414752): 0.5 +- 0.3 [none], limits: (-1.0, 1.0)
test2 (id-140071931415376): 2.0 +- 1.0 [none]
test3 (id-140071931415472): [4.0, 5.0] +- [5.0, 3.0] [none], limits: ([0.0, 1.0], None)
test4 (id-140071931414272) points to: test2 (id-140071931415376): 2.0 +- 1.0 [none]
test5 (id-140071931414992, function node, node_func_sqr):
test1 (id-140071931414752): 0.5 +- 0.3 [none], limits: (-1.0, 1.0)
test6 (id-140071931415616): [[5.0, 6.0], [7.0, 8.0]] +- [[0.0, 0.0], [0.0, 0.0]] [none], limits: (None, [[10.0, 10.0], [10.0, 10.0]])"""
# Remove ids since they change every time
while "(id-" in ref_string:
start = ref_string.find("(id-")
end = ref_string.find(")", start)+1
ref_string = ref_string[:start] + ref_string[end:]

repr_string = repr(PG)
# Remove ids since they change every time
count = 0
while "(id-" in repr_string:
start = repr_string.find("(id-")
end = repr_string.find(")", start)+1
repr_string = repr_string[:start] + repr_string[end:]
count += 1
if count > 100:
raise RuntimeError("infinite loop! Something very wrong with parameter repr")
self.assertEqual(repr_string, ref_string, "Repr should return specific string")



if __name__ == "__main__":
unittest.main()

0 comments on commit 9a72f08

Please sign in to comment.