Bug fix/time window exclusions #8

Merged on May 31, 2024 (26 commits)
Commits
b99f78b
minor formatting
mhuen Sep 26, 2022
ea7c70f
Fix bug: PDF is now renormalized at once at a DOM and not individuall…
mhuen Sep 26, 2022
d809a25
Fix bug: PDF is now renormalized at once at a DOM and not individuall…
mhuen Sep 26, 2022
de97faa
Fix bug: PDF is now renormalized at once at a DOM and not individuall…
mhuen Sep 27, 2022
a5408f6
Update version number to indicate incompatibility with PDF re-normali…
mhuen Sep 28, 2022
345f8b6
Update version number to indicate incompatibility with PDF re-normali…
mhuen Sep 28, 2022
dbb5a14
Properly re-normalize DOM charge PDF
mhuen Sep 28, 2022
d939d8d
Reco viisualization: allow to pass additional kwargs
mhuen Nov 4, 2022
4602aeb
Reco viisualization: allow to pass additional kwargs
mhuen Nov 4, 2022
2258e56
Reco viisualization: allow to pass additional kwargs
mhuen Nov 4, 2022
ee64292
Reco viisualization: allow to pass additional kwargs
mhuen Nov 4, 2022
8b74be3
Reco viisualization: allow to pass additional kwargs
mhuen Nov 4, 2022
2cbfe1c
Reco viisualization: allow to pass additional kwargs
mhuen Nov 4, 2022
d56a0d9
Reco viisualization: allow to pass additional kwargs
mhuen Nov 4, 2022
091a14d
Reco viisualization: allow to pass additional kwargs
mhuen Nov 4, 2022
257a1c2
Reco viisualization: allow to pass additional kwargs
mhuen Nov 4, 2022
3b5ae79
Reco viisualization: allow to pass additional kwargs
mhuen Nov 4, 2022
6578259
Reco viisualization: allow to pass additional kwargs
mhuen Nov 4, 2022
2392f33
Reco viisualization: allow to pass additional kwargs
mhuen Nov 11, 2022
fcc952d
Reco viisualization: allow to pass additional kwargs
mhuen Nov 11, 2022
0a6b773
Merge branch 'master' into BugFix/TimeWindowExclusions
mhuen Nov 16, 2022
b644e78
merge
mhuen May 29, 2024
e59409b
Update test data to accomodate change in exclusion handling
mhuen May 29, 2024
9201ca4
Add version check when loading component
mhuen May 29, 2024
e6d5b16
Merge branch 'VersionControl' into BugFix/TimeWindowExclusions
mhuen May 29, 2024
2d430db
Add incompatibility check
mhuen May 31, 2024
15 changes: 15 additions & 0 deletions CONTRIBUTING.md
@@ -22,3 +22,18 @@ You can also manually run the pre-commit on single files or on all files via:
If you need to commit something even though there are errors (this should not have to be done!), then you can add the flag `--no-verify` to the `git commit` command. This will bypass the pre-commit hooks.

Additional information is provided here: https://pre-commit.com/


## Non-backward compatible changes

New contributions to this repository should aim to maintain backwards compatibility, such that
models trained with earlier versions of the software may still be run in later software versions.
However, this is not always possible. Where breaking changes are required, they
should be documented in the `__version_compatibility__` dictionary in the `egenerator.__about__`
file. When loading saved components from disk, this dictionary is used to verify compatibility
of a previously saved model with the current software version. Breaking changes may either be of
type `global` if they affect all components and trained models of the software, or of type
`local` if only certain components are affected. In the latter case, a list of
`affected_components` must be provided in the corresponding dictionary entry.
This list contains the class string of each affected component in the event-generator
software.
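The entry format described above could be checked with a small helper along the following lines. This is only a sketch: `validate_compatibility_entries` is a hypothetical function that does not exist in event-generator, and the version `"9.9.9"` and the class string are made-up placeholders.

```python
def validate_compatibility_entries(entries):
    """Raise KeyError if an entry lacks the keys described above.

    Hypothetical helper for illustration; not part of event-generator.
    """
    for version, info in entries.items():
        change_type = info.get("type")
        if change_type not in ("global", "local"):
            raise KeyError(
                "Entry {!r} needs type 'global' or 'local'.".format(version)
            )
        # Local changes must name the components they affect.
        if change_type == "local" and not info.get("affected_components"):
            raise KeyError(
                "Local entry {!r} needs 'affected_components'.".format(version)
            )


# Made-up entries: version "9.9.9" and the class string are placeholders.
example_entries = {
    "1.1.0": {"Description": "Affects every component.", "type": "global"},
    "9.9.9": {
        "Description": "Affects one component only.",
        "type": "local",
        "affected_components": ["my_package.my_module.MyComponent"],
    },
}

validate_compatibility_entries(example_entries)  # passes silently
```

A check like this could run in a unit test so that a malformed entry is caught before a release rather than when a user loads a model.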
44 changes: 42 additions & 2 deletions egenerator/__about__.py
@@ -4,8 +4,8 @@
__url__ = "https://github.com/icecube/event-generator"

__version_major__ = 1
__version_minor__ = 0
__version_patch__ = 3
__version_minor__ = 1
__version_patch__ = 0
__version_info__ = "-dev"

__version__ = "{}.{}.{}{}".format(
@@ -14,3 +14,43 @@
__version_patch__,
__version_info__,
)

# A dictionary of changes that are not backwards compatible
# with previous versions. The keys are the versions that
# contain the breaking changes and the values contain
# information on the changes that were made.
# Mandatory keys are:
# "type": "global" or "local"
# "global" means that the change affects all components
# "local" means that the change affects only specific components
# and the key "affected_components" must be present.
#
# Example:
# __version_compatibility__ = {
#     "1.0.0": {
#         "Description": "Description of the breaking change",
#         "type": "global",
#     },
#     "1.0.1": {
#         "Description": "Description of the breaking change",
#         "type": "local",
#         "affected_components": ["class_string1", "class_string2"],
#     },
# }
__version_compatibility__ = {
    "1.1.0": {
        "Description": (
            "Bugfix: Fixed a bug in the re-normalization for time "
            "window exclusions. The bug re-normalized the individual "
            "mixture model components instead of the whole mixture. "
            "This bug thus led to changes in the shape of the pulse "
            "arrival time PDF when exclusions were used. This is now "
            "fixed to instead re-normalize the whole mixture. "
            "Older models will have compensated for this effect if "
            "already trained as a mixture of multiple sources. Thus, "
            "introducing this bugfix will lead to incompatibilities "
            "with older models."
        ),
        "type": "global",
    },
}
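The effect described in the `"1.1.0"` entry can be reproduced with a toy example. The sketch below is not the event-generator implementation: it assumes a two-component Gaussian mixture with made-up weights and a made-up exclusion window, and compares per-component re-normalization (the old behaviour as described) with a single re-normalization of the whole mixture (the fixed behaviour as described).

```python
import math


def normal_pdf(t, mu, sigma):
    z = (t - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi))


def normal_cdf(t, mu, sigma):
    return 0.5 * (1.0 + math.erf((t - mu) / (sigma * math.sqrt(2.0))))


# Assumed toy mixture: (weight, mean, width) per component.
COMPONENTS = [(0.7, 0.0, 1.0), (0.3, 5.0, 2.0)]
# Assumed exclusion window (t_start, t_stop).
WINDOW = (-1.0, 1.0)


def excluded_fraction(mu, sigma, window=WINDOW):
    """Probability mass of one component inside the exclusion window."""
    return normal_cdf(window[1], mu, sigma) - normal_cdf(window[0], mu, sigma)


def mixture_pdf(t):
    return sum(w * normal_pdf(t, mu, s) for w, mu, s in COMPONENTS)


def renormalized_per_component(t):
    """Old behaviour as described: each component is rescaled on its own."""
    return sum(
        w * normal_pdf(t, mu, s) / (1.0 - excluded_fraction(mu, s))
        for w, mu, s in COMPONENTS
    )


def renormalized_whole_mixture(t):
    """Fixed behaviour as described: one scale factor for the mixture."""
    excluded = sum(w * excluded_fraction(mu, s) for w, mu, s in COMPONENTS)
    return mixture_pdf(t) / (1.0 - excluded)


def shape_ratio(pdf, t_a=2.0, t_b=6.0):
    """Ratio of the PDF at two times outside the exclusion window."""
    return pdf(t_a) / pdf(t_b)
```

Both variants integrate to one over the non-excluded region, but only the whole-mixture variant keeps `shape_ratio` equal to that of the original mixture. The per-component variant boosts the component that overlaps the window most, which changes the shape of the PDF outside the window.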
6 changes: 4 additions & 2 deletions egenerator/__init__.py
@@ -1,18 +1,20 @@
from .__about__ import (
__version_compatibility__,
__version_info__,
__version_major__,
__version_minor__,
__version_patch__,
__version_info__,
__version__,
__description__,
__url__,
)

__all__ = [
"__version_compatibility__",
"__version_info__",
"__version_major__",
"__version_minor__",
"__version_patch__",
"__version_info__",
"__version__",
"__description__",
"__url__",
Expand Down
34 changes: 31 additions & 3 deletions egenerator/ic3/visualization.py
@@ -79,6 +79,27 @@ def __init__(self, context):
self.AddParameter(
"n_doms_y", "Number of DOMs to plot along y-axis.", 5
)
self.AddParameter(
"dom_pdf_kwargs",
"Additional keyword arguments passed on to `plot_dom_pdf`.",
{},
)
self.AddParameter(
"dom_cdf_kwargs",
"Additional keyword arguments passed on to `plot_dom_cdf`.",
{},
)
self.AddParameter(
"pdf_file_template",
"The file template name to which the PDF will be saved to",
"dom_pdf_{run_id:06d}_{event_id:06d}.png",
)
self.AddParameter(
"cdf_file_template",
"The file template name to which the CDF will be saved to",
"dom_cdf_{run_id:06d}_{event_id:06d}.png",
)
self.AddParameter("add_event_header", "Add event information.", True)

def Configure(self):
"""Configures Module and loads model from file."""
@@ -92,6 +113,11 @@ def Configure(self):
self.num_threads = self.GetParameter("num_threads")
self.n_doms_x = self.GetParameter("n_doms_x")
self.n_doms_y = self.GetParameter("n_doms_y")
self.pdf_file_template = self.GetParameter("pdf_file_template")
self.cdf_file_template = self.GetParameter("cdf_file_template")
self.dom_pdf_kwargs = self.GetParameter("dom_pdf_kwargs")
self.dom_cdf_kwargs = self.GetParameter("dom_cdf_kwargs")
self.add_event_header = self.GetParameter("add_event_header")

if isinstance(self.model_names, str):
self.model_names = [self.model_names]
@@ -158,8 +184,10 @@ def Configure(self):
output_dir=self.output_dir,
n_doms_x=self.n_doms_x,
n_doms_y=self.n_doms_y,
pdf_file_template="dom_pdf_{run_id:06d}_{event_id:06d}.png",
cdf_file_template="dom_cdf_{run_id:06d}_{event_id:06d}.png",
pdf_file_template=self.pdf_file_template,
cdf_file_template=self.cdf_file_template,
dom_pdf_kwargs=self.dom_pdf_kwargs,
dom_cdf_kwargs=self.dom_cdf_kwargs,
)

def Physics(self, frame):
@@ -176,7 +204,7 @@ def Physics(self, frame):
assert n == 1, "Currently only 1-event at a time is supported"

# collect event meta data
if "I3EventHeader" in frame:
if "I3EventHeader" in frame and self.add_event_header:
header = frame["I3EventHeader"]
event_header = {
"run_id": header.run_id,
Expand Down
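The new `pdf_file_template` and `cdf_file_template` parameters use Python format fields such as `{run_id:06d}`. How the plotting code consumes them is not shown in this diff, so the following is only a sketch under the assumption that the template is formatted with the event header dictionary; `output_path` is a hypothetical helper.

```python
import os


def output_path(output_dir, file_template, event_header):
    """Fill a file template with event meta data (hypothetical helper)."""
    return os.path.join(output_dir, file_template.format(**event_header))


# Default template from the diff, filled with made-up run and event ids.
example_path = output_path(
    "plots",
    "dom_pdf_{run_id:06d}_{event_id:06d}.png",
    {"run_id": 123, "event_id": 45},
)
```

With `:06d`, the ids are zero-padded to six digits, so output files sort by run and event in a directory listing.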
21 changes: 12 additions & 9 deletions egenerator/loss/default.py
@@ -1214,7 +1214,7 @@ def normalized_dom_charge_pdf(
"""

# underneath 5e-5 the log_negative_binomial function becomes unstable
eps = 5e-5
eps = 1e-7
dtype = getattr(
tf, self.configuration.config["config"]["float_precision"]
)
@@ -1223,13 +1223,6 @@ def normalized_dom_charge_pdf(
hits_true = tf.squeeze(data_batch_dict["x_dom_charge"], axis=-1)
hits_pred = tf.squeeze(result_tensors["dom_charges"], axis=-1)

# shape: [n_batch, 1, 1]
event_total = tf.reduce_sum(hits_pred, axis=[1, 2], keepdims=True)

# shape: [n_batch, 86, 60]
dom_pdf = hits_pred / event_total
llh_dom = hits_true * tf.math.log(dom_pdf + eps)

# throw error if this is being used with time window exclusions
# one needs to calculate cumulative pdf from exclusion window and
# scale up the pulse pdf by this factor
@@ -1242,6 +1235,8 @@ def normalized_dom_charge_pdf(
), "Model must deal with time exclusions!"

# mask out dom exclusions
# Note that this needs to be done prior to computing `event_total`
# such that the PDF is properly normalized over active DOMs
if (
"x_dom_exclusions" in tensors.names
and tensors.list[tensors.get_index("x_dom_exclusions")].exists
@@ -1250,7 +1245,15 @@ def normalized_dom_charge_pdf(
tf.squeeze(data_batch_dict["x_dom_exclusions"], axis=-1),
dtype=dtype,
)
llh_dom = llh_dom * mask_valid
hits_true = hits_true * mask_valid
hits_pred = hits_pred * mask_valid

# shape: [n_batch, 1, 1]
event_total = tf.reduce_sum(hits_pred, axis=[1, 2], keepdims=True)

# shape: [n_batch, 86, 60]
dom_pdf = hits_pred / event_total
llh_dom = hits_true * tf.math.log(dom_pdf + eps)

if sort_loss_terms:
loss_terms = [
Expand Down
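The reordering in this hunk matters because the DOM charge PDF should sum to one over the active DOMs only. The sketch below uses plain Python lists for a single event instead of TensorFlow tensors, with made-up charges and a made-up mask. For simplicity it masks the PDF itself in the old ordering, whereas the removed code masked the log-likelihood terms; the normalization issue is the same.

```python
# Made-up predicted charges for three DOMs; the second DOM is excluded.
HITS_PRED = [2.0, 6.0, 2.0]
MASK_VALID = [1.0, 0.0, 1.0]


def dom_pdf(hits_pred, mask_valid, mask_first):
    """Per-DOM charge PDF for one event, with two orderings of the mask."""
    if mask_first:
        # New ordering: drop excluded DOMs before computing the total.
        hits_pred = [p * m for p, m in zip(hits_pred, mask_valid)]
    event_total = sum(hits_pred)
    pdf = [p / event_total for p in hits_pred]
    if not mask_first:
        # Old ordering: excluded DOMs still contributed to the total.
        pdf = [p * m for p, m in zip(pdf, mask_valid)]
    return pdf


pdf_new = dom_pdf(HITS_PRED, MASK_VALID, mask_first=True)
pdf_old = dom_pdf(HITS_PRED, MASK_VALID, mask_first=False)
```

Here `pdf_new` sums to one over the active DOMs, while `pdf_old` sums to less than one because the excluded DOM's charge is still part of the denominator.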
54 changes: 54 additions & 0 deletions egenerator/manager/component.py
@@ -983,6 +983,60 @@ def load(self, dir_path, modified_sub_components={}, **kwargs):
)
)

        # check if the saved component was made with a newer version
        # than the one currently used
        if version_control.is_newer_version(
            version_base=egenerator.__version__,
            version_test=config_dict["event_generator_version"],
        ):
            msg = (
                "The saved component was created with a newer version of "
                "Event-Generator. Make sure the component is still "
                "compatible with this version!"
            )
            self._logger.error(msg)

        # go through compatibility changes since the saved version
        for version, info in egenerator.__version_compatibility__.items():
            is_newer = version_control.is_newer_version(
                version_base=config_dict["event_generator_version"],
                version_test=version,
            )

            # check if this version is compatible
            if is_newer:
                if info["type"] == "global":
                    msg = "A global change was made in "
                    msg += "Event-Generator version {!r} leading to "
                    msg += "incompatibility with the version of this model {!r}."
                    msg = msg.format(
                        version,
                        config_dict["event_generator_version"],
                    )
                    self._logger.fatal(msg)
                    raise ValueError(msg)
                elif info["type"] == "local":
                    if (
                        self.configuration.class_string
                        in info["affected_components"]
                    ):
                        msg = "A local change was made to the component {!r} in "
                        msg += "Event-Generator version {!r} leading to "
                        msg += "incompatibility with the version of this model {!r}."
                        msg = msg.format(
                            self.configuration.class_string,
                            version,
                            config_dict["event_generator_version"],
                        )
                        self._logger.fatal(msg)
                        raise ValueError(msg)
                else:
                    raise KeyError(
                        "Unknown type of compatibility change: {}.".format(
                            info["type"]
                        )
                    )

# check if this is the correct class
if (
self.configuration.class_string
Expand Down
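The check added to `load` can be exercised in isolation with a stand-in for `version_control.is_newer_version`. The implementation of that function is not part of this diff, so the version below is an assumption: it compares only the numeric `major.minor.patch` part and ignores suffixes such as `-dev`. `find_incompatibilities` is a hypothetical helper that collects the offending versions instead of raising.

```python
import re


def parse_version(version):
    """Extract (major, minor, patch) from strings such as '1.1.0-dev'."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError("Cannot parse version {!r}".format(version))
    return tuple(int(part) for part in match.groups())


def is_newer_version(version_base, version_test):
    """Assumed stand-in: True if version_test is newer than version_base."""
    return parse_version(version_test) > parse_version(version_base)


def find_incompatibilities(saved_version, class_string, compatibility):
    """Return versions whose breaking changes affect the saved component."""
    breaking = []
    for version, info in compatibility.items():
        if not is_newer_version(saved_version, version):
            continue  # change was already present when the model was saved
        if info["type"] == "global":
            breaking.append(version)
        elif info["type"] == "local":
            if class_string in info["affected_components"]:
                breaking.append(version)
        else:
            raise KeyError(
                "Unknown type of compatibility change: {}.".format(info["type"])
            )
    return breaking


compatibility = {"1.1.0": {"Description": "...", "type": "global"}}
old_model = find_incompatibilities("1.0.3-dev", "some.Component", compatibility)
new_model = find_incompatibilities("1.1.0-dev", "some.Component", compatibility)
```

Under these assumptions, a model saved with `1.0.3-dev` is flagged by the `1.1.0` entry, while a model saved with `1.1.0-dev` is not. Whether the real `is_newer_version` treats the `-dev` suffix the same way should be checked against its source.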