Skip to content

Commit

Permalink
Merge pull request #192 from pynbody/patch-trees
Browse files Browse the repository at this point in the history
Add facility to patch trees where halos (esp subhalos) get 'lost' by halo finder
  • Loading branch information
apontzen authored Jul 26, 2022
2 parents 9b23b06 + 31e0c06 commit 4e3a789
Show file tree
Hide file tree
Showing 3 changed files with 221 additions and 4 deletions.
3 changes: 1 addition & 2 deletions tangos/live_calculation/builtin_functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,12 @@

@BuiltinFunction.register
def match(source_halos, target):
timestep = consistent_collection.ConsistentCollection(source_halos).timestep
if target is None:
results = [None]*len(source_halos)
else:
from ... import relation_finding
if not isinstance(target, core.Base):
target = tangos.get_item(target, core.Session.object_session(timestep))
target = tangos.get_item(target, core.Session.object_session(source_halos[0]))
results = relation_finding.MultiSourceMultiHopStrategy(source_halos, target).all()
# if following assert fails, it might be duplicate links in the database which the
# current MultiSourceMultiHop implementation cannot de-duplicate:
Expand Down
4 changes: 2 additions & 2 deletions tangos/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,6 @@ def add_tools(cls, subparse):
c.add_tools(subparse)

from . import (add_simulation, ahf_merger_tree_importer,
consistent_trees_importer, crosslink, property_deleter,
property_importer, property_writer,
consistent_trees_importer, crosslink, merger_tree_patcher,
property_deleter, property_importer, property_writer,
subfind_merger_tree_importer)
218 changes: 218 additions & 0 deletions tangos/tools/merger_tree_patcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
import argparse

import numpy as np
from sqlalchemy import and_, or_, orm

from .. import core, query
from ..core import get_or_create_dictionary_item
from ..log import logger
from . import GenericTangosTool


class MergerTreePruner(GenericTangosTool):
tool_name = 'prune-trees'
tool_description = '[Experimental] Remove questionable links from merger trees'
parallel = False

@classmethod
def add_parser_arguments(self, parser):
parser.add_argument('--sims', '--for', action='store', nargs='*',
metavar='simulation_name',
help='Specify a simulation (or multiple simulations) to run on')

parser.add_argument('--min-weight-progenitor', action='store', type=float, default=0,
help='The minimum weight that should be associated with a progenitor for retention. '
'To retain minor progenitors, this should normally be set to zero.')

parser.add_argument('--min-weight-descendant', action='store', type=float, default=0.05,
help='The minimum weight that should be associated with a descendant for retention. '
'Weights close to zero are normally suspect, since they indicate the halo'
'transfers only a small fraction of its mass into the descendant.')

parser.add_argument('--max-NDM-progenitor', action='store', type=float, default=2.0,
help='The maximum relative increase in number of dark matter particles in the progenitor'
'for retention. For example, if this is set to 2.0 (default), the link is dropped'
'if there are more than twice as many DM particles in the progenitor as in the'
'descendant, since this suggests a mis-identification.')

parser.add_argument('--min-NDM-progenitor', action='store', type=float, default=0.33,
help='The maximum relative number of dark matter particles in the progenitor'
'for retention. For example, if this is set to 0.33 (default), the link is dropped'
'if there are less than a third as many DM particles in the progenitor as in the'
'descendant, since this suggests a mis-identification.')


def process_options(self, options: argparse.ArgumentParser):
self.options = options

def process_timestep(self, early_timestep: core.TimeStep):
session = core.get_default_session()

late_timestep = early_timestep.next
if late_timestep is None:
return


descendant_alias = orm.aliased(core.halo.SimulationObjectBase, name="descendant")
progenitor_alias = orm.aliased(core.halo.SimulationObjectBase, name="progenitor")
progenitor_halolink = orm.aliased(core.HaloLink, name="descendant_to_progenitor_link")
descendant_halolink = orm.aliased(core.HaloLink, name="progenitor_to_descendant_link")

progenitor_links = session.query(progenitor_halolink.id, descendant_halolink.id)\
.join(descendant_alias, and_(progenitor_halolink.halo_from_id == descendant_alias.id, descendant_alias.timestep_id == late_timestep.id))\
.join(progenitor_alias, and_(progenitor_halolink.halo_to_id == progenitor_alias.id, progenitor_alias.timestep_id == early_timestep.id))\
.outerjoin(descendant_halolink, and_(progenitor_halolink.halo_to_id == descendant_halolink.halo_from_id,
progenitor_halolink.halo_from_id == descendant_halolink.halo_to_id))\
.filter(or_(progenitor_halolink.weight < self.options.min_weight_progenitor,
descendant_halolink.weight < self.options.min_weight_descendant,
progenitor_alias.NDM > self.options.max_NDM_progenitor * descendant_alias.NDM,
progenitor_alias.NDM < self.options.min_NDM_progenitor * descendant_alias.NDM))


delete_links = []
progenitors_and_descendants = progenitor_links.all()
for p,d in progenitors_and_descendants:
delete_links.append(p)
if d is not None:
delete_links.append(d)

row_count = session.query(core.HaloLink).filter(core.HaloLink.id.in_(delete_links)).delete()
session.commit()
if row_count>0:
logger.info(f"Deleted {row_count} links between timesteps {early_timestep} and {late_timestep}")

def run_calculation_loop(self):
simulations = core.sim_query_from_name_list(self.options.sims)
logger.info("This tool is experimental. Use at your own risk!")
for simulation in simulations:
logger.info("Processing %s", simulation)

for ts in simulation.timesteps[::-1]:
self.process_timestep(ts)


class MergerTreePatcher(GenericTangosTool):
tool_name = 'patch-trees'
tool_description = "[Experimental] Attempt to patch up merger trees"
parallel = False

@classmethod
def add_parser_arguments(self, parser):
parser.add_argument('--sims', '--for', action='store', nargs='*',
metavar='simulation_name',
help='Specify a simulation (or multiple simulations) to run on')
parser.add_argument('--include-only', action='store', type=str, nargs='*', default=['NDM()>100'],
help="Specify a filter that describes which objects the calculation should be executed for. "
"Multiple filters may be specified, in which case they must all evaluate to true for "
"the object to be included.")
parser.add_argument('--relative-match', action='store', type=str, default=['NStar()','NDM()'],
help='Variables to minimise when finding a halo match. If multiple expressions'
'are provided, their relative differences are added in quadrature')
parser.add_argument('--max-for-match', action='store', type=float, default=0.1,
help='The maximum fractional offset in values of variables specified in relative-match'
'that qualifies as a match')
parser.add_argument('--position-variable', action='store', type=str, default='shrink_center/a()',
help='The name of the variable which represents the 3d position')
parser.add_argument('--max-position-offset', action='store', type=float, default=100.0,
help='The maximum value of the position offset that is allowed in comoving units (i.e.'
'will be divided by scalefactor)')
parser.add_argument('--max-timesteps', action='store', type=int, default=6,
help='The maximum number of steps back in time to scan for a match')


def process_options(self, options):
self.options = options


def create_link(self, halo, matched):
session = core.get_default_session()
core.creator.get_creator(session) # just to prevent any commits happening later when not expected

d_id = get_or_create_dictionary_item(session, "patch-trees")
logger.info(f"Creating a link from {halo} to {matched}")

next = halo

ts = halo.timestep.previous

while ts!=matched.timestep:
phantom = core.halo.PhantomHalo(ts, halo.halo_number, 0)
# TODO: shouldn't assume this halo number is "free" for use, probably phantom halo IDs should be
# allocated sequentially

session.add(phantom)
session.add(core.HaloLink(next, phantom, d_id, 1.0))
session.add(core.HaloLink(phantom, next, d_id, 1.0))

next = phantom
ts = ts.previous

session.add(core.HaloLink(next, matched, d_id, 1.0))
session.add(core.HaloLink(matched, next, d_id, 1.0))
session.commit()

_candidates_cache = {}

def _get_candidate_id_position_values(self, ts_candidate):
if ts_candidate.id not in self._candidates_cache:
self._candidates_cache[ts_candidate.id] = ts_candidate.calculate_all("dbid()",
self.options.position_variable,
*self.options.relative_match)
return self._candidates_cache[ts_candidate.id]

def fixup(self, halo):
success = False
ts_candidate = halo.timestep.previous
if ts_candidate is None:
return

source_values = [halo.calculate(m) for m in self.options.relative_match]
source_position = halo.calculate(self.options.position_variable)

for i in range(self.options.max_timesteps):

dbid, candidate_positions, *candidate_values = self._get_candidate_id_position_values(ts_candidate)

offsets = [cv-sv for sv, cv in zip(source_values, candidate_values)]
rel_offsets = [(o/sv)**2 for o, sv in zip(offsets, source_values)]
rel_offsets = np.sqrt(sum(rel_offsets))

pos_offset = np.linalg.norm(source_position - candidate_positions, axis=1)
mask = pos_offset < self.options.max_position_offset*(1.+ts_candidate.redshift)

dbid = dbid[mask]
rel_offsets = rel_offsets[mask]

if len(rel_offsets)>0:
if np.min(rel_offsets) < self.options.max_for_match:
match_dbid = dbid[np.argmin(rel_offsets)]
match = query.get_halo(match_dbid)
self.create_link(halo,match)
success = True
break

ts_candidate = ts_candidate.previous
if ts_candidate is None:
break

if not success:
logger.info(f"No luck in finding a match for {halo}")


def run_calculation_loop(self):
simulations = core.sim_query_from_name_list(self.options.sims)
logger.info("This tool is experimental. Use at your own risk!")

for simulation in simulations:
logger.info("Processing %s",simulation)
logger.info(self.options.include_only)
include_criterion = "(!has_link(earlier(1))) & " + (" & ".join(self.options.include_only))

logger.info("Query for missing links is %s",include_criterion)
for ts in simulation.timesteps[::-1]:
dbids, flag = ts.calculate_all("dbid()", include_criterion)
dbids = dbids[flag]
logger.info(f"Timestep {ts} has {len(dbids)} broken links to consider")
for dbid in dbids:
obj = query.get_halo(dbid)
self.fixup(obj)

0 comments on commit 4e3a789

Please sign in to comment.