Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clearer error message when wrong load function is used #102

Merged
merged 6 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cogsworth/pop.py
Original file line number Diff line number Diff line change
Expand Up @@ -1231,6 +1231,9 @@ def load(file_name):

BSE_settings = {}
with h5.File(file_name, "r") as file:
if "numeric_params" not in file.keys():
raise ValueError((f"{file_name} is not a Population file, "
"perhaps you meant to use `cogsworth.sfh.load`?"))
numeric_params = file["numeric_params"][...]

store_entire_orbits = file["numeric_params"].attrs["store_entire_orbits"]
Expand Down
2 changes: 2 additions & 0 deletions cogsworth/sfh.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,8 @@ def load(file_name, key="sfh"):

# load the parameters back in using yaml
with h5.File(file_name, "r") as file:
if key not in file.keys():
raise ValueError((f"Can't find a saved SFH in {file_name} under the key {key}."))
params = yaml.load(file[key].attrs["params"], Loader=yaml.Loader)

# get the current module, get a class using the name, delete it from parameters that will be passed
Expand Down
39 changes: 32 additions & 7 deletions cogsworth/tests/test_pop.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import unittest
import cogsworth.pop as pop
import cogsworth.sfh as sfh
import cogsworth.observables as obs
import os
import pytest
Expand All @@ -25,7 +26,7 @@ def test_bad_inputs(self):

def test_io(self):
"""Check that a population can be saved and re-loaded"""
p = pop.Population(2, bcm_timestep_conditions=[['dtp=100000.0']])
p = pop.Population(2, processes=1, bcm_timestep_conditions=[['dtp=100000.0']])
p.create_population()

p.save("testing-pop-io", overwrite=True)
Expand Down Expand Up @@ -55,6 +56,30 @@ def test_io(self):

os.remove("testing-pop-io.h5")

def test_wrong_load_function(self):
"""Check that errors are properly raised when the wrong load function is used"""
g = sfh.Wagg2022(10000)
g.save("test-sfh-for-load")

it_broke = False
try:
pop.load("test-sfh-for-load")
except ValueError:
it_broke = True
os.remove("test-sfh-for-load.h5")
self.assertTrue(it_broke)

p = pop.Population(2, processes=1)
p.create_population()
p.save("test-pop-for-load", overwrite=True)
it_broke = False
try:
sfh.load("test-pop-for-load")
except ValueError:
it_broke = True
os.remove("test-pop-for-load.h5")
self.assertTrue(it_broke)

def test_orbit_storage(self):
"""Test that we can control how orbits are stored"""
p = pop.Population(20, final_kstar1=[13, 14], processes=1, store_entire_orbits=True)
Expand All @@ -72,7 +97,7 @@ def test_orbit_storage(self):

def test_overly_stringent_cutoff(self):
"""Make sure that it crashes if the m1_cutoff is too large to create anything"""
p = pop.Population(10, m1_cutoff=10000)
p = pop.Population(10, processes=1, m1_cutoff=10000)

it_broke = False
try:
Expand All @@ -84,14 +109,14 @@ def test_overly_stringent_cutoff(self):

def test_interface(self):
"""Test the interface of this class with the other modules"""
p = pop.Population(10, final_kstar1=[13, 14], store_entire_orbits=False)
p = pop.Population(10, processes=1, final_kstar1=[13, 14], store_entire_orbits=False)
p.create_population()

# ensure we get something that disrupts to ensure coverage
MAX_REPS = 5
i = 0
while not p.disrupted.any() and i < MAX_REPS:
p = pop.Population(10, final_kstar1=[13, 14])
p = pop.Population(10, processes=1, final_kstar1=[13, 14])
p.create_population()
i += 1
if i == MAX_REPS:
Expand Down Expand Up @@ -123,7 +148,7 @@ def test_interface(self):

def test_getters(self):
"""Test the property getters"""
p = pop.Population(2, store_entire_orbits=False)
p = pop.Population(2, processes=1, store_entire_orbits=False)

# test getters from sampling
p.mass_singles
Expand Down Expand Up @@ -159,7 +184,7 @@ def test_getters(self):

def test_singles_evolution(self):
"""Check everything works well when evolving singles"""
p = pop.Population(2, BSE_settings={"binfrac": 0.0},
p = pop.Population(2, processes=1, BSE_settings={"binfrac": 0.0},
sampling_params={'keep_singles': True, 'total_mass': 100,
'sampling_target': 'total_mass'})
p.create_population(with_timing=False)
Expand All @@ -169,7 +194,7 @@ def test_singles_evolution(self):
def test_singles_bad_input(self):
"""Test what happens when you mess up single stars"""
it_failed = True
p = pop.Population(1, BSE_settings={"binfrac": 0.0},
p = pop.Population(1, processes=1, BSE_settings={"binfrac": 0.0},
sampling_params={'total_mass': 1000, 'sampling_target': 'total_mass'})
try:
p.sample_initial_binaries()
Expand Down
Loading