Skip to content

Commit

Permalink
Merge pull request #117 from TomWagg/combine
Browse files Browse the repository at this point in the history
Concatenation for `Population`s and `StarFormationHistory`s
  • Loading branch information
TomWagg authored Jun 27, 2024
2 parents ea19671 + cc4693b commit 85e0300
Show file tree
Hide file tree
Showing 5 changed files with 220 additions and 1 deletion.
2 changes: 1 addition & 1 deletion cogsworth/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.0.0"
__version__ = "1.1.0"
93 changes: 93 additions & 0 deletions cogsworth/pop.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,9 @@ def __repr__(self):
def __len__(self):
return self.n_binaries_match

def __add__(self, other):
return concat(self, other)

def __getitem__(self, ind):
# convert any Pandas Series to numpy arrays
ind = ind.values if isinstance(ind, pd.Series) else ind
Expand Down Expand Up @@ -1366,6 +1369,96 @@ def load(file_name, parts=["initial_binaries", "initial_galaxy", "stellar_evolut
return p


def concat(*pops):
"""Concatenate multiple populations into a single population
NOTE: The final population will have the same settings as the first population in the list (but data
from all populations)
Parameters
----------
pops : `list` of :class:`~cogsworth.Population` or :class:`~cogsworth.EvolvedPopulation`
List of populations to concatenate
Returns
-------
total_pop : :class:`~cogsworth.Population` or :class:`~cogsworth.EvolvedPopulation`
The concatenated population
"""
# ensure the input is a list of populations
pops = list(pops)
assert all([isinstance(pop, Population) for pop in pops])

# if there's only one population then just return it
if len(pops) == 1:
return pops[0]
elif len(pops) == 0:
raise ValueError("No populations provided to concatenate")

# get the offset for the bin numbers
bin_num_offset = max(pops[0].bin_nums) + 1

# create a new population to store the final population (just a copy of the first population)
final_pop = pops[0][:]

# loop over the remaining populations
for pop in pops[1:]:
# sum the total numbers of binaries
final_pop.n_binaries += pop.n_binaries

# combine the star formation history distributions
if final_pop._initial_galaxy is not None:
if pop._initial_galaxy is None:
raise ValueError(f"Population {pop} does not have an initial galaxy, but the first does")

final_pop._initial_galaxy += pop._initial_galaxy

if final_pop._initial_binaries is not None:
if pop._initial_binaries is None:
raise ValueError(f"Population {pop} does not have initial binaries, but the first does")
new_initial_binaries = pop._initial_binaries.copy()
new_initial_binaries.index += bin_num_offset
final_pop._initial_binaries = pd.concat([final_pop._initial_binaries, pop._initial_binaries])

# loop through pandas tables that may need to be copied
for table in ["_initC", "_bpp", "_bcm", "_kick_info"]:
# only copy if the table exists in the main population
if getattr(final_pop, table) is not None:
# if the table doesn't exist in the new population then raise an error
if getattr(pop, table) is None:
raise ValueError(f"Population {pop} does not have a {table} table, but the first does")

# otherwise copy the table and update the bin nums
new_table = getattr(pop, table).copy()
new_table.index += bin_num_offset
new_table["bin_num"] += bin_num_offset
setattr(final_pop, table, pd.concat([getattr(final_pop, table), new_table]))

# sum the sampling numbers
final_pop._n_singles_req += pop._n_singles_req
final_pop._n_bin_req += pop._n_bin_req
final_pop._mass_singles += pop._mass_singles
final_pop._mass_binaries += pop._mass_binaries
final_pop.n_binaries_match += pop.n_binaries_match

if final_pop._orbits is not None or pop._orbits is not None:
raise NotImplementedError("Cannot concatenate populations with orbits for now")

bin_num_offset = max(final_pop._bpp["bin_num"]) + 1

# reset auto-calculated class variables
final_pop._bin_nums = None
final_pop._classes = None
final_pop._final_pos = None
final_pop._final_vel = None
final_pop._final_bpp = None
final_pop._disrupted = None
final_pop._escaped = None
final_pop._observables = None

return final_pop


class EvolvedPopulation(Population):
def __init__(self, n_binaries, mass_singles=None, mass_binaries=None, n_singles_req=None, n_bin_req=None,
bpp=None, bcm=None, initC=None, kick_info=None, **pop_kwargs):
Expand Down
37 changes: 37 additions & 0 deletions cogsworth/sfh.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ def __len__(self):
def __repr__(self):
return f"<{self.__class__.__name__}, size={len(self)}>"

def __add__(self, other):
return concat(self, other)

def __getitem__(self, ind):
# ensure indexing with the right type
if not isinstance(ind, (int, slice, list, np.ndarray, tuple)):
Expand Down Expand Up @@ -1005,6 +1008,40 @@ def load(file_name, key="sfh"):
return loaded_sfh


def concat(*sfhs):
"""Concatenate multiple StarFormationHistory objects together.
Parameters
----------
*sfhs : `StarFormationHistory`
Any number of StarFormationHistory objects to concatenate
Returns
-------
`StarFormationHistory`
A new StarFormationHistory object that is the concatenation of all the input objects
"""
# check that all the objects are of the same type
sfhs = list(sfhs)
assert all([isinstance(sfh, StarFormationHistory) for sfh in sfhs])
if len(sfhs) == 1:
return sfhs[0]
elif len(sfhs) == 0:
raise ValueError("No objects to concatenate")

# create a new object with the same parameters as the first
new_sfh = sfhs[0][:]

# concatenate the velocity components if they exist
for attr in ["_tau", "_Z", "_which_comp", "_x", "_y", "_z", "v_R", "v_T", "v_z"]:
if hasattr(sfhs[0], attr):
setattr(new_sfh, attr, np.concatenate([getattr(sfh, attr) for sfh in sfhs]))

new_sfh._size = len(new_sfh._tau)

return new_sfh


def simplify_params(params, dont_save=["_tau", "_Z", "_x", "_y", "_z", "_which_comp", "v_R", "v_T", "v_z",
"_df", "_agama_pot", "__citations__"]):
# delete any keys that we don't want to save
Expand Down
84 changes: 84 additions & 0 deletions cogsworth/tests/test_pop.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,3 +475,87 @@ def test_galactic_pool(self):
p.perform_stellar_evolution()
p.perform_galactic_evolution()
self.assertTrue(p.pool is None)

def test_concat(self):
"""Check that we can concatenate populations"""
p = pop.Population(10)
q = pop.Population(10)
p.perform_stellar_evolution()
q.perform_stellar_evolution()

r = p + q
self.assertTrue(len(r) == len(p) + len(q))
self.assertTrue(len(r.initC["bin_num"].unique()) == len(r))
self.assertTrue(len(r.initial_galaxy) == len(r))

self.assertTrue(len(pop.concat(p)) == len(p))
self.assertTrue(len(sfh.concat(p.initial_galaxy)) == len(p.initial_galaxy))

def test_concat_wrong_type(self):
"""Check that we can't concatenate with the wrong type"""
p = pop.Population(10)
it_failed = False
try:
p + 1
except AssertionError:
it_failed = True
self.assertTrue(it_failed)

def test_concat_empty(self):
it_failed = False
try:
sfh.concat()
except ValueError:
it_failed = True
self.assertTrue(it_failed)

it_failed = False
try:
pop.concat()
except ValueError:
it_failed = True
self.assertTrue(it_failed)

def test_concat_mismatch(self):
"""Check that we can't concatenate populations with different stuff"""
p = pop.Population(10)
q = pop.Population(10)
p.perform_stellar_evolution()

it_failed = False
try:
p + q
except ValueError:
it_failed = True
self.assertTrue(it_failed)

q.sample_initial_galaxy()
q._initial_binaries = None
it_failed = False
try:
p + q
except ValueError:
it_failed = True
self.assertTrue(it_failed)

q.sample_initial_binaries()
it_failed = False
try:
p + q
except ValueError:
it_failed = True
self.assertTrue(it_failed)

def test_concat_no_orbits(self):
"""Check that we can't concatenate populations without orbits"""
p = pop.Population(10)
q = pop.Population(10)
p.create_population()
q.create_population()

it_failed = False
try:
r = p + q
except NotImplementedError:
it_failed = True
self.assertTrue(it_failed)
5 changes: 5 additions & 0 deletions docs/modules/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ Full changelog

This page tracks all of the changes that have been made to ``cogsworth``. We follow the standard versioning convention of A.B.C, where C is a patch/bugfix, B is a large bugfix or new feature and A is a major new breaking change. B/C are backwards compatible but A changes may be breaking.

1.1.1
=====

- New feature: Concatenate multiple populations together with ``pop.concat`` or simply `+` (see #116)

1.0.0
=====

Expand Down

0 comments on commit 85e0300

Please sign in to comment.