Skip to content

Commit

Permalink
fix: make caching optional (#260)
Browse files Browse the repository at this point in the history
* ci: always upload docs artifacts
  • Loading branch information
redeboer authored Apr 1, 2021
1 parent 2c68eae commit f76663e
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ jobs:
EXECUTE_NB: YES
run: make html
- uses: actions/upload-artifact@v2
if: ${{ failure() }}
if: ${{ always() }}
with:
name: html
path: docs/_build/html
Expand Down
47 changes: 30 additions & 17 deletions src/tensorwaves/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,28 +47,36 @@ class UnbinnedNLL(Estimator): # pylint: disable=too-many-instance-attributes
"""

def __init__(
def __init__( # pylint: disable=too-many-arguments
self,
model: Model,
dataset: DataSample,
phsp_dataset: DataSample,
phsp_volume: float = 1.0,
backend: Union[str, tuple, dict] = "numpy",
use_caching: bool = False,
fixed_parameters: Optional[Dict[str, Union[float, complex]]] = None,
) -> None:
fixed_data_inputs = {k: np.array(v) for k, v in dataset.items()}
fixed_phsp_inputs = {k: np.array(v) for k, v in phsp_dataset.items()}
if fixed_parameters:
fixed_data_inputs.update(fixed_parameters)
fixed_phsp_inputs.update(fixed_parameters)
self.__data_function = LambdifiedFunction(
model.performance_optimize(fix_inputs=fixed_data_inputs),
backend,
)
self.__phsp_function = LambdifiedFunction(
model.performance_optimize(fix_inputs=fixed_phsp_inputs),
backend,
)
self.__use_caching = use_caching
self.__dataset = {k: np.array(v) for k, v in dataset.items()}
self.__phsp_dataset = {k: np.array(v) for k, v in phsp_dataset.items()}
if self.__use_caching:
fixed_data_inputs = dict(self.__dataset)
fixed_phsp_inputs = dict(self.__phsp_dataset)
if fixed_parameters:
fixed_data_inputs.update(fixed_parameters)
fixed_phsp_inputs.update(fixed_parameters)
self.__data_function = LambdifiedFunction(
model.performance_optimize(fix_inputs=fixed_data_inputs),
backend,
)
self.__phsp_function = LambdifiedFunction(
model.performance_optimize(fix_inputs=fixed_phsp_inputs),
backend,
)
else:
self.__data_function = LambdifiedFunction(model, backend)
self.__phsp_function = self.__data_function
self.__gradient = gradient_creator(self.__call__, backend)
backend_modules = get_backend_modules(backend)

Expand All @@ -91,10 +99,15 @@ def __call__(
self, parameters: Mapping[str, Union[float, complex]]
) -> float:
self.__data_function.update_parameters(parameters)
self.__phsp_function.update_parameters(parameters)
bare_intensities = self.__data_function({})
if self.__use_caching:
self.__phsp_function.update_parameters(parameters)
bare_intensities = self.__data_function({})
phsp_intensities = self.__phsp_function({})
else:
bare_intensities = self.__data_function(self.__dataset)
phsp_intensities = self.__phsp_function(self.__phsp_dataset)
normalization_factor = 1.0 / (
self.__phsp_volume * self.__mean_function(self.__phsp_function({}))
self.__phsp_volume * self.__mean_function(phsp_intensities)
)
likelihoods = normalization_factor * bare_intensities
return -self.__sum_function(self.__log_function(likelihoods))
Expand Down

0 comments on commit f76663e

Please sign in to comment.