Skip to content

Commit

Permalink
Merge pull request #210 from icecube/seyfert_fix
Browse files Browse the repository at this point in the history
Add `IrregularParameterGrid` class
  • Loading branch information
chiarabellenghi authored Feb 10, 2024
2 parents 6c3e66f + b62e95d commit 24511ca
Showing 1 changed file with 184 additions and 1 deletion.
185 changes: 184 additions & 1 deletion skyllh/core/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1346,6 +1346,189 @@ def round_to_upper_grid_point(self, value):
return gp


class IrregularParameterGrid(object):
"""This class provides a data holder for a parameter that has a set of
discrete values on a grid. Thus, the parameter has a value grid.
This class represents a one-dimensional irregular grid.
"""
@staticmethod
def from_BinningDefinition(binning):
"""Creates a IrregularParameterGrid instance from a BinningDefinition instance.
Parameters
----------
binning : BinningDefinition instance
The BinningDefinition instance that should be used to create the
ParameterGrid instance from.
Returns
-------
param_grid : instance of ParameterGrid
The created ParameterGrid instance.
"""
return IrregularParameterGrid(
name=binning.name,
grid=binning.binedges,
)

def __init__(self, name, grid):
"""Creates a new parameter grid.
Parameters
----------
name : str
The name of the parameter.
grid : sequence of float
The sequence of float values defining the discrete grid values of
the parameter.
"""
self.name = name
self.grid = grid

def __str__(self):
"""Pretty string representation.
"""
return '{:s} = {:s}'.format(
self._name, str(self._grid))

@property
def name(self):
"""The name of the parameter.
"""
return self._name

@name.setter
def name(self, name):
if not isinstance(name, str):
raise TypeError(
'The name property must be of type str!')
self._name = name

@property
def grid(self):
"""The numpy.ndarray with the grid values of the parameter.
"""
return self._grid

@grid.setter
def grid(self, arr):
if not issequence(arr):
raise TypeError(
'The grid property must be a sequence!')
if not isinstance(arr, np.ndarray):
arr = np.array(arr, dtype=np.float64)
if arr.ndim != 1:
raise ValueError(
'The grid property must be a 1D numpy.ndarray!')
self._grid = arr

@property
def ndim(self):
"""The dimensionality of the parameter grid.
"""
return self._grid.ndim

def add_extra_lower_and_upper_bin(self):
"""Adds an extra lower and upper bin to this parameter grid. This is
usefull when interpolation or gradient methods require an extra bin on
each side of the grid.
"""
newgrid = np.empty((self._grid.size+2,))
newgrid[1:-1] = self._grid
newgrid[0] = newgrid[1] - (newgrid[2] - newgrid[1])
newgrid[-1] = newgrid[-2] + (newgrid[-2] - newgrid[-3])
del self._grid
self.grid = newgrid

def copy(self):
"""Copies this IrregularParameterGrid object and returns the copy.
"""
copy = deepcopy(self)
return copy

def round_to_nearest_grid_point(self, value):
"""Rounds the given value to the nearest grid point.
Note: If the given value is precisely in between two nearest grid
points, the lower nearest grid point will be returned!
Parameters
----------
value : float | ndarray of float
The value(s) to round.
Returns
-------
grid_point : float | ndarray of float
The calculated grid point(s).
"""
scalar_input = np.isscalar(value)

grid_middle = (self.grid[1:] + self.grid[:-1])/2
idx = np.searchsorted(grid_middle, value, side='left')
gp = self.grid[idx]


if scalar_input:
return gp.item()

return gp

def round_to_lower_grid_point(self, value):
"""Rounds the given value to the nearest grid point that is lower than
the given value.
Note: If the given value is a grid point, that grid point will be
returned!
Parameters
----------
value : float | ndarray of float
The value(s) to round.
Returns
-------
grid_point : float | ndarray of float
The calculated grid point(s).
"""
scalar_input = np.isscalar(value)

idx = np.searchsorted(self.grid, value, side='right') - 1
gp = self.grid[idx]

if scalar_input:
return gp.item()

return gp

def round_to_upper_grid_point(self, value):
"""Rounds the given value to the nearest grid point that is larger than
the given value.
Note: If the given value is a grid point, the next grid point will be
returned!
Parameters
----------
value : float | ndarray of float
The value(s) to round.
Returns
-------
grid_point : ndarray of float
The calculated grid point(s).
"""
scalar_input = np.isscalar(value)

idx = np.searchsorted(self.grid, value, side='right')
gp = self.grid[idx]

if scalar_input:
return gp.item()

return gp


class ParameterGridSet(
NamedObjectCollection):
"""Describes a set of parameter grids.
Expand All @@ -1364,7 +1547,7 @@ def __init__(
"""
super().__init__(
objs=param_grids,
obj_type=ParameterGrid,
obj_type=type(param_grids[0]),
**kwargs)

@property
Expand Down

0 comments on commit 24511ca

Please sign in to comment.