Skip to content

Commit

Permalink
Merge pull request #128 from TomWagg/gala
Browse files Browse the repository at this point in the history
Use latest gala, add a catch for slicing half-loaded files
  • Loading branch information
TomWagg authored Aug 27, 2024
2 parents 96595f8 + 6573ad9 commit e3dd5a2
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 54 deletions.
17 changes: 16 additions & 1 deletion cogsworth/pop.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,22 @@ def __getitem__(self, ind):
# convert any Pandas Series to numpy arrays
ind = ind.values if isinstance(ind, pd.Series) else ind

# if the population is associated with a file, make sure it's entirely loaded before slicing
if self._file is not None:
parts = ["initial_binaries", "bpp", "initial_galaxy", "orbits"]
vars = [self._initial_binaries, self._bpp, self._initial_galaxy, self._orbits]
masks = {f"has_{p}": False for p in parts}
with h5.File(self._file, "r") as f:
for p in parts:
masks[f"has_{p}"] = p in f
missing_parts = [p for i, p in enumerate(parts) if (masks[f"has_{p}"] and vars[i] is None)]

if len(missing_parts) > 0:
raise ValueError(("This population was loaded from a file but you haven't loaded all parts "
"yet. You need to do this before indexing it. The missing parts are: "
f"{missing_parts}. You either need to access each of these variables or "
"reload the entire population using all parts."))

# ensure indexing with the right type
ALLOWED_TYPES = (int, slice, list, np.ndarray, tuple)
if not isinstance(ind, ALLOWED_TYPES):
Expand Down Expand Up @@ -262,7 +278,6 @@ def __getitem__(self, ind):
sampling_params=self.sampling_params,
store_entire_orbits=self.store_entire_orbits)
new_pop.n_binaries_match = new_pop.n_binaries
new_pop._file = self._file

# proxy for checking whether sampling has been done
if self._mass_binaries is not None:
Expand Down
19 changes: 19 additions & 0 deletions cogsworth/tests/test_pop.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import cogsworth.pop as pop
import cogsworth.sfh as sfh
import cogsworth.observables as obs
import h5py as h5
import os
import pytest

Expand Down Expand Up @@ -421,6 +422,24 @@ def test_indexing_mixed_types(self):
it_worked = False
self.assertFalse(it_worked)

def test_indexing_loaded_pop(self):
"""Test indexing fails when trying to slice a half-loaded population"""
p = pop.Population(10)
p.perform_stellar_evolution()

with h5.File("DUMMY.h5", "w") as f:
f.create_dataset("orbits", data=[])
p._file = "DUMMY.h5"
p._orbits = None

it_worked = True
try:
p[:5]
except ValueError:
it_worked = False
self.assertFalse(it_worked)
os.remove("DUMMY.h5")

def test_evolved_pop(self):
"""Check that the EvolvedPopulation class works as it should"""
p = pop.Population(10)
Expand Down
97 changes: 45 additions & 52 deletions docs/case_studies/binaries_and_potentials.ipynb

Large diffs are not rendered by default.

Binary file modified docs/case_studies/plots/bin_pot_effects.pdf
Binary file not shown.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ install_requires =
astropy >= 5.0
scipy >= 1.8
pandas >= 2.1
gala @ git+https://github.com/adrn/gala
gala >= 1.9.1
cosmic-popsynth >= 3.4.16

[options.package_data]
Expand Down

0 comments on commit e3dd5a2

Please sign in to comment.