diff --git a/AUTHORS.rst b/AUTHORS.rst index d2d950ec4..07bbac26c 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -16,6 +16,7 @@ Contributors Active developers are indicated by (*). Authors of the PINT paper are indicated by (#). +* Gabriella Agazie (*) * Akash Anumarlapudi (*) * Anne Archibald (#*) * Matteo Bachetti (#) diff --git a/CHANGELOG-unreleased.md b/CHANGELOG-unreleased.md index ac9bc2495..2c0e0ffb9 100644 --- a/CHANGELOG-unreleased.md +++ b/CHANGELOG-unreleased.md @@ -9,10 +9,13 @@ the released changes. ## Unreleased ### Changed +- WAVE parameters can be added to a Wave model with `add_wave_component()` in wave.py - Moved design matrix normalization code from `pint.fitter` to the new `pint.utils.normalize_designmatrix()` function. - Made `Residuals` independent of `GLSFitter` (GLS chi2 is now computed using the new function `Residuals._calc_gls_chi2()`). ### Added +- Added WaveX model as DelayComponent with wave amplitudes as fitted parameters ### Fixed +- Wave model `validate()` can correctly use PEPOCH to assign WAVEEPOCH parameter - Fixed RTD by specifying theme explicitly. - `.value()` now works for pairParameters - Setting `model.PARAM1 = model.PARAM2` no longer overrides the name of `PARAM1` diff --git a/src/pint/models/__init__.py b/src/pint/models/__init__.py index e22443493..5440cca96 100644 --- a/src/pint/models/__init__.py +++ b/src/pint/models/__init__.py @@ -42,6 +42,7 @@ from pint.models.timing_model import DEFAULT_ORDER, TimingModel from pint.models.troposphere_delay import TroposphereDelay from pint.models.wave import Wave +from pint.models.wavex import WaveX # Define a standard basic model StandardTimingModel = TimingModel( diff --git a/src/pint/models/timing_model.py b/src/pint/models/timing_model.py index 57eba15fa..1f02b8c75 100644 --- a/src/pint/models/timing_model.py +++ b/src/pint/models/timing_model.py @@ -117,6 +117,7 @@ "spindown", "phase_jump", "wave", + "wavex", ] diff --git a/src/pint/models/wave.py b/src/pint/models/wave.py index 7abe4ed79..194b2b5b8 100644 --- a/src/pint/models/wave.py +++ b/src/pint/models/wave.py @@ -28,7 +28,13 @@ class Wave(PhaseComponent): def __init__(self): super().__init__() - + self.add_param( + MJDParameter( + name="WAVEEPOCH", + description="Reference epoch for wave solution", + time_scale="tdb", + ) + ) self.add_param( floatParameter( name="WAVE_OM", @@ -46,13 +52,6 @@ def __init__(self): parameter_type="pair", ) ) - self.add_param( - MJDParameter( - name="WAVEEPOCH", - description="Reference epoch for wave solution", - time_scale="tdb", - ) - ) self.phase_funcs_component += [self.wave_phase] def setup(self): @@ -64,14 +63,14 @@ def validate(self): super().validate() self.setup() if self.WAVEEPOCH.quantity is None: - if self.PEPOCH.quantity is None: + if self._parent.PEPOCH.quantity is None: raise MissingParameter( "Wave", "WAVEEPOCH", "WAVEEPOCH or PEPOCH are required if " "WAVE_OM is set.", ) else: - self.WAVEEPOCH = self.PEPOCH + self.WAVEEPOCH.quantity = self._parent.PEPOCH.quantity if (not hasattr(self._parent, "F0")) or (self._parent.F0.quantity is None): raise MissingParameter( @@ -95,6 +94,52 @@ def print_par(self, format="pint"): return result + def add_wave_component(self, amps, index=None): + """Add Wave Component + + Parameters + ---------- + + index : int + Interger label for Wave components. + amps : tuple of float or astropy.quantity.Quantity + Sine and cosine amplitudes + + Returns + ------- + + index : + Index that has been assigned to new Wave component + """ + #### If index is None, increment the current max Wave index by 1. Increment using WAVE + if index is None: + dct = self.get_prefix_mapping_component("WAVE") + index = np.max(list(dct.keys())) + 1 + i = f"{int(index):04d}" + + if int(index) in self.get_prefix_mapping_component("WAVE"): + raise ValueError( + f"Index '{index}' is already in use in this model. Please choose another" + ) + + for amp in amps: + if isinstance(amp, u.quantity.Quantity): + amp = amp.to_value(u.s) + self.add_param( + prefixParameter( + name=f"WAVE{index}", + value=amps, + units="s", + description="Wave components", + type_match="pair", + long_double=True, + parameter_type="pair", + ) + ) + self.setup() + self.validate() + return f"{index}" + def wave_phase(self, toas, delays): times = 0 wave_names = ["WAVE%d" % ii for ii in range(1, self.num_wave_terms + 1)] diff --git a/src/pint/models/wavex.py b/src/pint/models/wavex.py new file mode 100644 index 000000000..2861308c1 --- /dev/null +++ b/src/pint/models/wavex.py @@ -0,0 +1,389 @@ +"""Delays expressed as a sum of sinusoids.""" +import astropy.units as u +import numpy as np +from loguru import logger as log +from warnings import warn + +from pint.models.parameter import MJDParameter, floatParameter, prefixParameter +from pint.models.timing_model import DelayComponent, MissingParameter + + +class WaveX(DelayComponent): + """ + Implementation of the wave model as a delay correction + + Delays are expressed as a sum of sinusoids. + + Used for decomposition of timing noise into a series of sine/cosine components with the amplitudes as fitted parameters. + + Parameters supported: + + .. paramtable:: + :class: pint.models.wavex.WaveX + + This is an extension of the L13 method described in Lentati et al., 2013 doi: 10.1103/PhysRevD.87.104021 + This model is similar to the TEMPO2 WAVE model parameters and users can convert a `TimingModel` with a Wave model + to a WaveX model and produce the same results. The main differences are that the WaveX frequencies are explicitly stated, + they do not necessarily need to be harmonics of some base frequency, the wave amplitudes are fittable parameters, and the + sine and cosine amplutides are reported as separate `prefixParameter`s rather than as a single `pairParameter`. + + Analogous parameters in both models have the same units: + WAVEEPOCH is the same as WXEPOCH + WAVEOM and WXFREQ_000N have units of 1/d + WAVEN and WXSIN_000N/WXCOS_000N have units of seconds + + The `pint.utils` functions `translate_wave_to_wavex()` and `translate_wavex_to_wave()` can be used to go back and forth between + two model. + + WARNING: If the choice of WaveX frequencies in a `TimingModel` doesn't correspond to harmonics of some base + freqeuncy, it will not be possible to convert it to a Wave model. + + To set up a WaveX model, users can use the `pint.utils` function `wavex_setup()` with either a list of frequencies or a choice + of harmonics of a base frequency determined by 2 * pi /Timespan + """ + + register = True + category = "wavex" + + def __init__(self): + super().__init__() + self.add_param( + MJDParameter( + name="WXEPOCH", + description="Reference epoch for Fourier representation of red noise", + time_scale="tdb", + ) + ) + self.add_wavex_component(0.1, index=1, wxsin=0, wxcos=0, frozen=False) + self.set_special_params(["WXFREQ_0001", "WXSIN_0001", "WXCOS_0001"]) + self.delay_funcs_component += [self.wavex_delay] + + def add_wavex_component(self, wxfreq, index=None, wxsin=0, wxcos=0, frozen=True): + """ + Add WaveX component + + Parameters + ---------- + + wxfreq : float or astropy.quantity.Quantity + Base frequency for WaveX component + index : int, None + Interger label for WaveX component. If None, will increment largest used index by 1. + wxsin : float or astropy.quantity.Quantity + Sine amplitude for WaveX component + wxcos : float or astropy.quantity.Quantity + Cosine amplitude for WaveX component + frozen : iterable of bool or bool + Indicates whether wavex will be fit + + Returns + ------- + + index : int + Index that has been assigned to new WaveX component + """ + + #### If index is None, increment the current max WaveX index by 1. Increment using WXFREQ + if index is None: + dct = self.get_prefix_mapping_component("WXFREQ_") + index = np.max(list(dct.keys())) + 1 + i = f"{int(index):04d}" + + if int(index) in self.get_prefix_mapping_component("WXFREQ_"): + raise ValueError( + f"Index '{index}' is already in use in this model. Please choose another" + ) + + if isinstance(wxsin, u.quantity.Quantity): + wxsin = wxsin.to_value(u.s) + if isinstance(wxcos, u.quantity.Quantity): + wxcos = wxcos.to_value(u.s) + if isinstance(wxfreq, u.quantity.Quantity): + wxfreq = wxfreq.to_value(1 / u.d) + self.add_param( + prefixParameter( + name=f"WXFREQ_{i}", + description="Component frequency for Fourier representation of red noise", + units="1/d", + value=wxfreq, + parameter_type="float", + ) + ) + self.add_param( + prefixParameter( + name=f"WXSIN_{i}", + description="Sine amplitudes for Fourier representation of red noise", + units="s", + value=wxsin, + frozen=frozen, + parameter_type="float", + ) + ) + self.add_param( + prefixParameter( + name=f"WXCOS_{i}", + description="Cosine amplitudes for Fourier representation of red noise", + units="s", + value=wxcos, + frozen=frozen, + parameter_type="float", + ) + ) + self.setup() + self.validate() + return index + + def add_wavex_components( + self, wxfreqs, indices=None, wxsins=0, wxcoses=0, frozens=True + ): + """ + Add WaveX components with specified base frequencies + + Parameters + ---------- + + wxfreqs : iterable of float or astropy.quantity.Quantity + Base frequencies for WaveX components + indices : iterable of int, None + Interger labels for WaveX components. If None, will increment largest used index by 1. + wxsins : iterable of float or astropy.quantity.Quantity + Sine amplitudes for WaveX components + wxcoses : iterable of float or astropy.quantity.Quantity + Cosine amplitudes for WaveX components + frozens : iterable of bool or bool + Indicates whether sine adn cosine amplitudes of wavex components will be fit + + Returns + ------- + + indices : list + Indices that have been assigned to new WaveX components + """ + + if indices is None: + indices = [None] * len(wxfreqs) + wxsins = np.atleast_1d(wxsins) + wxcoses = np.atleast_1d(wxcoses) + if len(wxsins) == 1: + wxsins = np.repeat(wxsins, len(wxfreqs)) + if len(wxcoses) == 1: + wxcoses = np.repeat(wxcoses, len(wxfreqs)) + if len(wxsins) != len(wxfreqs): + raise ValueError( + f"Number of base frequencies {len(wxfreqs)} doesn't match number of sine ampltudes {len(wxsins)}" + ) + if len(wxcoses) != len(wxfreqs): + raise ValueError( + f"Number of base frequencies {len(wxfreqs)} doesn't match number of cosine ampltudes {len(wxcoses)}" + ) + frozens = np.atleast_1d(frozens) + if len(frozens) == 1: + frozens = np.repeat(frozens, len(wxfreqs)) + if len(frozens) != len(wxfreqs): + raise ValueError( + f"Number of base frequencies must match number of frozen values" + ) + #### If indices is None, increment the current max WaveX index by 1. Increment using WXFREQ + dct = self.get_prefix_mapping_component("WXFREQ_") + last_index = np.max(list(dct.keys())) + added_indices = [] + for wxfreq, index, wxsin, wxcos, frozen in zip( + wxfreqs, indices, wxsins, wxcoses, frozens + ): + if index is None: + index = last_index + 1 + last_index += 1 + elif index in list(dct.keys()): + raise ValueError( + f"Attempting to insert WXFREQ_{index:04d} but it already exists" + ) + added_indices.append(index) + i = f"{int(index):04d}" + + if int(index) in dct: + raise ValueError( + f"Index '{index}' is already in use in this model. Please choose another" + ) + if isinstance(wxfreq, u.quantity.Quantity): + wxfreq = wxfreq.to_value(u.d**-1) + if isinstance(wxsin, u.quantity.Quantity): + wxsin = wxsin.to_value(u.s) + if isinstance(wxcos, u.quantity.Quantity): + wxcos = wxcos.to_value(u.s) + log.trace(f"Adding WXSIN_{i} and WXCOS_{i} at frequency WXFREQ_{i}") + self.add_param( + prefixParameter( + name=f"WXFREQ_{i}", + description="Component frequency for Fourier representation of red noise", + units="1/d", + value=wxfreq, + parameter_type="float", + ) + ) + self.add_param( + prefixParameter( + name=f"WXSIN_{i}", + description="Sine amplitude for Fourier representation of red noise", + units="s", + value=wxsin, + parameter_type="float", + frozen=frozen, + ) + ) + self.add_param( + prefixParameter( + name=f"WXCOS_{i}", + description="Cosine amplitude for Fourier representation of red noise", + units="s", + value=wxcos, + parameter_type="float", + frozen=frozen, + ) + ) + self.setup() + self.validate() + return added_indices + + def remove_wavex_component(self, index): + """ + Remove all WaveX components associated with a given index or list of indices + + Parameters + ---------- + index : float, int, list, np.ndarray + Number or list/array of numbers corresponding to WaveX indices to be removed from model. + """ + + if isinstance(index, (int, float, np.int64)): + indices = [index] + elif isinstance(index, (list, set, np.ndarray)): + indices = index + else: + raise TypeError( + f"index most be a float, int, set, list, or array - not {type(index)}" + ) + for index in indices: + index_rf = f"{int(index):04d}" + for prefix in ["WXFREQ_", "WXSIN_", "WXCOS_"]: + self.remove_param(prefix + index_rf) + self.validate() + + def get_indices(self): + """ + Returns an array of intergers corresponding to WaveX component parameters using WXFREQs + + Returns + ------- + inds : np.ndarray + Array of WaveX indices in model. + """ + inds = [int(p.split("_")[-1]) for p in self.params if "WXFREQ_" in p] + return np.array(inds) + + # Initialize setup + def setup(self): + super().setup() + # Get WaveX mapping and register WXSIN and WXCOS derivatives + for prefix_par in self.get_params_of_type("prefixParameter"): + if prefix_par.startswith("WXSIN_"): + self.register_deriv_funcs(self.d_wavex_delay_d_WXSIN, prefix_par) + if prefix_par.startswith("WXCOS_"): + self.register_deriv_funcs(self.d_wavex_delay_d_WXCOS, prefix_par) + self.wave_freqs = list(self.get_prefix_mapping_component("WXFREQ_").keys()) + self.num_wave_freqs = len(self.wave_freqs) + + def validate(self): + # Validate all the WaveX parameters + super().validate() + self.setup() + WXFREQ_mapping = self.get_prefix_mapping_component("WXFREQ_") + WXSIN_mapping = self.get_prefix_mapping_component("WXSIN_") + WXCOS_mapping = self.get_prefix_mapping_component("WXCOS_") + if WXFREQ_mapping.keys() != WXSIN_mapping.keys(): + raise ValueError( + "WXFREQ_ parameters do not match WXSIN_ parameters." + "Please check your prefixed parameters" + ) + if WXFREQ_mapping.keys() != WXCOS_mapping.keys(): + raise ValueError( + "WXFREQ_ parameters do not match WXCOS_ parameters." + "Please check your prefixed parameters" + ) + # if len(WXFREQ_mapping.keys()) != len(WXSIN_mapping.keys()): + # raise ValueError( + # "The number of WXFREQ_ parameters do not match the number of WXSIN_ parameters." + # "Please check your prefixed parameters" + # ) + # if len(WXFREQ_mapping.keys()) != len(WXCOS_mapping.keys()): + # raise ValueError( + # "The number of WXFREQ_ parameters do not match the number of WXCOS_ parameters." + # "Please check your prefixed parameters" + # ) + if WXSIN_mapping.keys() != WXCOS_mapping.keys(): + raise ValueError( + "WXSIN_ parameters do not match WXCOS_ parameters." + "Please check your prefixed parameters" + ) + if len(WXSIN_mapping.keys()) != len(WXCOS_mapping.keys()): + raise ValueError( + "The number of WXSIN_ and WXCOS_ parameters do not match" + "Please check your prefixed parameters" + ) + wfreqs = np.zeros(len(WXFREQ_mapping)) + for j, index in enumerate(WXFREQ_mapping): + if (getattr(self, f"WXFREQ_{index:04d}").value == 0) or ( + getattr(self, f"WXFREQ_{index:04d}").quantity is None + ): + raise ValueError( + f"WXFREQ_{index:04d} is zero or None. Please check your prefixed parameters" + ) + if getattr(self, f"WXFREQ_{index:04d}").value < 0.0: + warn(f"Frequency WXFREQ_{index:04d} is negative") + wfreqs[j] = getattr(self, f"WXFREQ_{index:04d}").value + wfreqs.sort() + if np.any(np.diff(wfreqs) <= (1.0 / (2.0 * 364.25))): + warn("Frequency resolution is greater than 1/yr") + if self.WXEPOCH.value is None: + if self._parent is not None: + if self._parent.PEPOCH.value is None: + raise MissingParameter( + "WXEPOCH or PEPOCH are required if WaveX is being used" + ) + else: + self.WXEPOCH.quantity = self._parent.PEPOCH.quantity + + def validate_toas(self, toas): + return super().validate_toas(toas) + + def wavex_delay(self, toas, delays): + total_delay = np.zeros(toas.ntoas) * u.s + wave_freqs = self.get_prefix_mapping_component("WXFREQ_") + wave_sins = self.get_prefix_mapping_component("WXSIN_") + wave_cos = self.get_prefix_mapping_component("WXCOS_") + + base_phase = ( + toas.table["tdbld"].data * u.d - self.WXEPOCH.value * u.d - delays.to(u.d) + ) + for idx, param in wave_freqs.items(): + freq = getattr(self, param).quantity + wxsin = getattr(self, wave_sins[idx]).quantity + wxcos = getattr(self, wave_cos[idx]).quantity + arg = 2.0 * np.pi * freq * base_phase + total_delay += wxsin * np.sin(arg.value) + wxcos * np.cos(arg.value) + return total_delay + + def d_wavex_delay_d_WXSIN(self, toas, param, delays, acc_delay=None): + par = getattr(self, param) + freq = getattr(self, f"WXFREQ_{int(par.index):04d}").quantity + base_phase = toas.table["tdbld"].data * u.d - self.WXEPOCH.value * u.d + arg = 2.0 * np.pi * freq * base_phase + deriv = np.sin(arg.value) + return deriv * u.s / par.units + + def d_wavex_delay_d_WXCOS(self, toas, param, delays, acc_delay=None): + par = getattr(self, param) + freq = getattr(self, f"WXFREQ_{int(par.index):04d}").quantity + base_phase = toas.table["tdbld"].data * u.d - self.WXEPOCH.value * u.d + arg = 2.0 * np.pi * freq * base_phase + deriv = np.cos(arg.value) + return deriv * u.s / par.units diff --git a/src/pint/utils.py b/src/pint/utils.py index b3c7006c8..8cfbd95bf 100644 --- a/src/pint/utils.py +++ b/src/pint/utils.py @@ -51,12 +51,18 @@ from astropy.time import Time from loguru import logger as log from scipy.special import fdtrc +from copy import deepcopy +import warnings import pint import pint.pulsar_ecliptic from pint.toa_select import TOASelect + __all__ = [ + "PINTPrecisionError", + "check_longdouble_precision", + "require_longdouble_precision", "PosVel", "numeric_partial", "numeric_partials", @@ -70,26 +76,39 @@ "lines_of", "interesting_lines", "pmtot", - "dmxselections", - "dmxparse", - "dmxstats", + "dmxrange", + "sum_print", "dmx_ranges_old", "dmx_ranges", + "dmxselections", + "dmxstats", + "dmxparse", + "get_prefix_timerange", + "get_prefix_timeranges", + "find_prefix_bytime", + "merge_dmx", + "split_dmx", + "split_swx", + "wavex_setup", + "translate_wave_to_wavex", + "get_wavex_freqs", + "get_wavex_amps", + "translate_wavex_to_wave", "weighted_mean", "ELL1_check", "FTest", "add_dummy_distance", "remove_dummy_distance", "info_string", - "print_color_examples", + "list_parameters", "colorize", + "print_color_examples", "group_iterator", "compute_hash", - "PINTPrecisionError", - "check_longdouble_precision", - "require_longdouble_precision", "get_conjunction", "divide_times", + "convert_dispersion_measure", + "parse_time", "get_unit", ] @@ -1267,6 +1286,342 @@ def split_swx(model, time): return index, newindex +def wavex_setup(model, T_span, freqs=None, n_freqs=None): + """ + Set-up a WaveX model based on either an array of user-provided frequencies or the wave number + frequency calculation. Sine and Cosine amplitudes are initially set to zero + + User specifies T_span and either freqs or n_freqs. This function assumes that the timing model does not already + have any WaveX components. See add_wavex_component() or add_wavex_components() to add WaveX components + to an existing WaveX model. + + Parameters + ---------- + + model : pint.models.timing_model.TimingModel + freqs : iterable of float or astropy.quantity.Quantity, None + User inputed base frequencies + n_freqs : int, None + Number of wave frequencies to calculate using the equation: freq_n = 2 * pi * n / T_span + Where n is the wave number, and T_span is the total time span of the toas in the fitter object + T_span : float, astropy.quantity.Quantity + Time span used to calculate nyquist frequency when using freqs + Time span used to calculate WaveX frequencies when using n_freqs + Usually to be set as the length of the timing baseline the model is being used for + + Returns + ------- + + indices : list + Indices that have been assigned to new WaveX components + """ + from pint.models.wavex import WaveX + + if (freqs is None) and (n_freqs is None): + raise ValueError( + "WaveX component base frequencies are not specified. " + "Please input either freqs or n_freqs" + ) + + if (freqs is not None) and (n_freqs is not None): + raise ValueError( + "Both freqs and n_freqs are specified. Only one or the other should be used" + ) + + if n_freqs <= 0: + raise ValueError("Must use a non-zero number of wave frequencies") + model.add_component(WaveX()) + if isinstance(T_span, u.quantity.Quantity): + T_span.to(u.d) + else: + T_span *= u.d + + nyqist_freq = 1.0 / (2.0 * T_span) + if freqs is not None: + if isinstance(freqs, u.quantity.Quantity): + freqs.to(u.d**-1) + else: + freqs *= u.d**-1 + if len(freqs) == 1: + model.WXFREQ_0001.quantity = freqs + else: + np.array(freqs) + freqs.sort() + if min(np.diff(freqs)) < nyqist_freq: + warnings.warn( + "Wave frequency spacing is finer than frequency resolution of data" + ) + model.WXFREQ_0001.quantity = freqs[0] + model.components["WaveX"].add_wavex_components(freqs[1:]) + + if n_freqs is not None: + if n_freqs == 1: + wave_freq = 2.0 * np.pi / T_span + model.WXFREQ_0001.quantity = wave_freq + else: + wave_numbers = np.arange(1, n_freqs + 1) + wave_freqs = 2.0 * np.pi * wave_numbers / T_span + model.WXFREQ_0001.quantity = wave_freqs[0] + model.components["WaveX"].add_wavex_components(wave_freqs[1:]) + return model.components["WaveX"].get_indices() + + +def _translate_wave_freqs(om, k): + """ + Use Wave model WAVEOM parameter to calculate a WaveX WXFREQ_ frequency parameter for wave number k + + Parameters + ---------- + + om : float or astropy.quantity.Quantity + Base frequency of Wave model solution - parameter WAVEOM + If float is given default units of 1/d assigned + k : int + wave number to use to calculate WaveX WXFREQ_ frequency parameter + + Returns + ------- + + WXFREQ_ quantity in units 1/d that can be used in WaveX model + """ + if isinstance(om, u.quantity.Quantity): + om.to(u.d**-1) + else: + om *= u.d**-1 + return (om * (k + 1)) / (2.0 * np.pi) + + +def _translate_wavex_freqs(wxfreq, k): + """ + Use WaveX model WXFREQ_ parameters and wave number k to calculate the Wave model WAVEOM frequency parameter. + + Parameters + ---------- + + wxfreq : float or astropy.quantity.Quantity + WaveX frequency from which the WAVEOM parameter will be calculated + If float is given default units of 1/d assigned + k : int + wave number to use to calculate Wave WAVEOM parameter + + Returns + ------- + + WAVEOM quantity in units 1/d that can be used in Wave model + """ + if isinstance(wxfreq, u.quantity.Quantity): + wxfreq.to(u.d**-1) + else: + wxfreq *= u.d**-1 + if len(wxfreq) == 1: + return (2.0 * np.pi * wxfreq) / (k + 1.0) + else: + wave_om = [ + ((2.0 * np.pi * wxfreq[i]) / (k[i] + 1.0)) for i in range(len(wxfreq)) + ] + if np.allclose(wave_om, wave_om[0], atol=1e-3): + om = sum(wave_om) / len(wave_om) + return om + else: + return False + + +def translate_wave_to_wavex(model): + """ + Go from a Wave model to a WaveX model + + WaveX frequencies get calculated based on the Wave model WAVEOM parameter and the number of WAVE parameters. + WXFREQ_000k = [WAVEOM * (k+1)] / [2 * pi] + + WaveX amplitudes are taken from the WAVE pair parameters + + Paramters + --------- + model : pint.models.timing_model.TimingModel + TimingModel containing a Wave model to be converted to a WaveX model + + Returns + ------- + New timing model with converted WaveX model included + """ + from pint.models.wavex import WaveX + + new_model = deepcopy(model) + wave_names = [ + f"WAVE{ii}" for ii in range(1, model.components["Wave"].num_wave_terms + 1) + ] + wave_terms = [getattr(model.components["Wave"], name) for name in wave_names] + wave_om = model.components["Wave"].WAVE_OM.quantity + wave_epoch = model.components["Wave"].WAVEEPOCH.quantity + new_model.remove_component("Wave") + new_model.add_component(WaveX()) + new_model.WXEPOCH.value = wave_epoch.value + for k, wave_term in enumerate(wave_terms): + wave_sin_amp, wave_cos_amp = wave_term.quantity + wavex_freq = _translate_wave_freqs(wave_om, k) + if k == 0: + new_model.WXFREQ_0001.value = wavex_freq.value + new_model.WXSIN_0001.value = -wave_sin_amp.value + new_model.WXCOS_0001.value = -wave_cos_amp.value + else: + new_model.components["WaveX"].add_wavex_component( + wavex_freq, wxsin=-wave_sin_amp, wxcos=-wave_cos_amp + ) + return new_model + + +def get_wavex_freqs(model, index=None, quantity=False): + """ + Return the WaveX frequencies for a timing model. + + If index is specified, returns the frequencies corresponding to the user-provided indices. + If index isn't specified, returns all WaveX frequencies in timing model + + Parameters + ---------- + model : pint.models.timing_model.TimingModel + Timing model from which to return WaveX frequencies + index : float, int, list, np.ndarray, None + Number or list/array of numbers corresponding to WaveX frequencies to return + quantity : bool + If set to True, returns a list of astropy.quanitity.Quantity rather than a list of prefixParameters + + Returns + ------- + List of WXFREQ_ parameters + """ + if index is None: + freqs = model.components["WaveX"].get_prefix_mapping_component("WXFREQ_") + if len(freqs) == 1: + values = getattr(model.components["WaveX"], freqs.values()) + else: + values = [ + getattr(model.components["WaveX"], param) for param in freqs.values() + ] + elif isinstance(index, (int, float, np.int64)): + idx_rf = f"{int(index):04d}" + values = getattr(model.components["WaveX"], "WXFREQ_" + idx_rf) + elif isinstance(index, (list, set, np.ndarray)): + idx_rf = [f"{int(idx):04d}" for idx in index] + values = [getattr(model.components["WaveX"], "WXFREQ_" + ind) for ind in idx_rf] + else: + raise TypeError( + f"index most be a float, int, set, list, array, or None - not {type(index)}" + ) + if quantity: + if len(values) == 1: + values = [values[0].quantity] + else: + values = [v.quantity for v in values] + return values + + +def get_wavex_amps(model, index=None, quantity=False): + """ + Return the WaveX amplitudes for a timing model. + + If index is specified, returns the sine/cosine amplitudes corresponding to the user-provided indices. + If index isn't specified, returns all WaveX sine/cosine amplitudes in timing model + + Parameters + ---------- + model : pint.models.timing_model.TimingModel + Timing model from which to return WaveX frequencies + index : float, int, list, np.ndarray, None + Number or list/array of numbers corresponding to WaveX amplitudes to return + quantity : bool + If set to True, returns a list of tuples of astropy.quanitity.Quantity rather than a list of prefixParameters tuples + + Returns + ------- + List of WXSIN_ and WXCOS_ parameters + """ + if index is None: + indices = ( + model.components["WaveX"].get_prefix_mapping_component("WXSIN_").keys() + ) + if len(indices) == 1: + values = ( + getattr(model.components["WaveX"], "WXSIN_" + f"{int(indices):04d}"), + getattr(model.components["WaveX"], "WXCOS_" + f"{int(indices):04d}"), + ) + else: + values = [ + ( + getattr(model.components["WaveX"], "WXSIN_" + f"{int(idx):04d}"), + getattr(model.components["WaveX"], "WXCOS_" + f"{int(idx):04d}"), + ) + for idx in indices + ] + elif isinstance(index, (int, float, np.int64)): + idx_rf = f"{int(index):04d}" + values = ( + getattr(model.components["WaveX"], "WXSIN_" + idx_rf), + getattr(model.components["WaveX"], "WXCOS_" + idx_rf), + ) + elif isinstance(index, (list, set, np.ndarray)): + idx_rf = [f"{int(idx):04d}" for idx in index] + values = [ + ( + getattr(model.components["WaveX"], "WXSIN_" + ind), + getattr(model.components["WaveX"], "WXCOS_" + ind), + ) + for ind in idx_rf + ] + else: + raise TypeError( + f"index most be a float, int, set, list, array, or None - not {type(index)}" + ) + if quantity: + if isinstance(values, tuple): + values = tuple(v.quantity for v in values) + if isinstance(values, list): + values = [tuple((v[0].quantity, v[1].quantity)) for v in values] + return values + + +def translate_wavex_to_wave(model): + """ + Go from a WaveX timing model to a Wave timing model. + WARNING: Not every WaveX model can be appropriately translated into a Wave model. This is dependent on the user's choice of frequencies in the WaveX model. + In order for a WaveX model to be able to be converted into a Wave model, every WaveX frequency must produce the same value of WAVEOM in the calculation: + + WAVEOM = [2 * pi * WXFREQ_000k] / (k + 1) + Paramters + --------- + model : pint.models.timing_model.TimingModel + TimingModel containing a WaveX model to be converted to a Wave model + + Returns + ------- + New timing model with converted Wave model included + """ + from pint.models.wave import Wave + + new_model = deepcopy(model) + indices = model.components["WaveX"].get_indices() + wxfreqs = get_wavex_freqs(model, indices, quantity=True) + wave_om = _translate_wavex_freqs(wxfreqs, (indices - 1)) + if wave_om == False: + raise ValueError( + "This WaveX model cannot be properly translated into a Wave model due to the WaveX frequencies not producing a consistent WAVEOM value" + ) + wave_amps = get_wavex_amps(model, index=indices, quantity=True) + new_model.remove_component("WaveX") + new_model.add_component(Wave()) + new_model.WAVEEPOCH.quantity = model.WXEPOCH.quantity + new_model.WAVE_OM.quantity = wave_om + new_model.WAVE1.quantity = tuple(w * -1.0 for w in wave_amps[0]) + if len(indices) > 1: + for i in range(1, len(indices)): + print(wave_amps[i]) + wave_amps[i] = tuple(w * -1.0 for w in wave_amps[i]) + new_model.components["Wave"].add_wave_component( + wave_amps[i], index=indices[i] + ) + return new_model + + def weighted_mean(arrin, weights_in, inputmean=None, calcerr=False, sdev=False): """Compute weighted mean of input values diff --git a/tests/test_wavex.py b/tests/test_wavex.py new file mode 100644 index 000000000..0c7ec4dca --- /dev/null +++ b/tests/test_wavex.py @@ -0,0 +1,402 @@ +from io import StringIO +import pytest +import numpy as np +from loguru import logger as log + +from astropy import units as u +from pint.models import get_model, get_model_and_toas +from pint.models import model_builder as mb +from pint.models.timing_model import Component, MissingParameter +from pint.fitter import Fitter +from pint.residuals import Residuals +from pint.toa import get_TOAs +from pint.simulation import make_fake_toas_uniform +import pint.utils +from pinttestdata import datadir +from pint.models.wavex import WaveX + +par1 = """ + PSR B1937+21 + LAMBDA 301.9732445337270 + BETA 42.2967523367957 + PMLAMBDA -0.0175 + PMBETA -0.3971 + PX 0.1515 + POSEPOCH 55321.0000 + F0 641.9282333345536244 1 0.0000000000000132 + F1 -4.330899370129D-14 1 2.149749089617D-22 + PEPOCH 55321.000000 + DM 71.016633 + UNITS TDB + """ + +# Introduce a par file with WaveX already present + +par2 = """ + PSR B1937+21 + LAMBDA 301.9732445337270 + BETA 42.2967523367957 + PMLAMBDA -0.0175 + PMBETA -0.3971 + PX 0.1515 + POSEPOCH 55321.0000 + F0 641.9282333345536244 1 0.0000000000000132 + F1 -4.330899370129D-14 1 2.149749089617D-22 + PEPOCH 55321.000000 + DM 71.016633 + WXEPOCH 55321.000000 + WXFREQ_0001 0.1 + WXSIN_0001 1 + WXCOS_0001 1 + UNITS TDB + """ + +wavex_par = """ + WXFREQ_0002 0.2 + WXSIN_0002 2 + WXCOS_0002 2 + WXFREQ_0003 0.3 + WXSIN_0003 3 + WXCOS_0003 3 +""" +wave_par = """ + WAVEEPOCH 55321.000000 + WAVE_OM 0.1 + WAVE1 0.2 0.1 + WAVE2 0.6 0.3""" + + +def wavex_delay(waves, toas): + total_delay = np.zeros(toas.ntoas) * u.s + wave_freqs = waves.get_prefix_mapping_component("WXFREQ_") + wave_sins = waves.get_prefix_mapping_component("WXSIN_") + wave_cos = waves.get_prefix_mapping_component("WXCOS_") + base_phase = toas.table["tdbld"].data * u.d - waves.WXEPOCH.value * u.d + for idx, param in wave_freqs.items(): + freq = getattr(waves, param).quantity + wxsin = getattr(waves, wave_sins[idx]).quantity + wxcos = getattr(waves, wave_cos[idx]).quantity + arg = 2.0 * np.pi * freq * base_phase + total_delay += wxsin * np.sin(arg.value) + wxcos * np.cos(arg.value) + return total_delay + + +def test_derivative(): + # Check that analytical and numerical derivatives are similar + model = mb.get_model(StringIO(par2)) + model.WXFREQ_0001.value = 0.1 + model.WXSIN_0001.value = 0.01 + model.WXCOS_0001.value = 0.05 + toas = make_fake_toas_uniform(55000, 55100, 100, model, obs="gbt") + p = "WXSIN_0001" + log.debug(f"Running derivative for {p}", f"d_delay_d_{p}") + ndf = model.d_delay_d_param_num(toas, p) + adf = model.d_delay_d_param(toas, p) + diff = ndf - adf + print(diff) + if np.all(diff.value) != 0.0: + mean_der = (adf + ndf) / 2.0 + relative_diff = np.abs(diff) / np.abs(mean_der) + msg = f"Derivative test failed at d_delay_d_{p} with max relative difference {np.nanmax(relative_diff).value}" + tol = 0.7 + log.debug( + ( + f"derivative relative diff for d_delay_d_{p}, {np.nanmax(relative_diff).value}" + ) + ) + assert np.nanmax(relative_diff) < tol, msg + + +def test_wxsin_fit(): + # Check that when a par file with a wavex model is used to generate fake toas the wavex parameters don't change much when fitted for + model = get_model(StringIO(par1)) + model.add_component(WaveX()) + model.WXFREQ_0001.value = 0.1 + model.WXSIN_0001.value = 0.01 + model.WXCOS_0001.value = 0.05 + toas = make_fake_toas_uniform(55000, 55100, 100, model, obs="gbt") + for param in model.free_params: + getattr(model, param).frozen = True + model.WXSIN_0001.value = 0.02 + model.WXSIN_0001.frozen = False + f = Fitter.auto(toas, model) + f.fit_toas() + assert np.isclose(f.model.WXSIN_0001.value, 0.01, atol=1e-3) + + +def test_wxcos_fit(): + # Check that when a par file with a wavex model is used to generate fake toas the wavex parameters don't change much when fitted for + model = get_model(StringIO(par1)) + model.add_component(WaveX()) + model.WXFREQ_0001.value = 0.1 + model.WXSIN_0001.value = 0.01 + model.WXCOS_0001.value = 0.05 + toas = make_fake_toas_uniform(55000, 55100, 100, model, obs="gbt") + for param in model.free_params: + getattr(model, param).frozen = True + model.WXCOS_0001.value = 0.09 + model.WXCOS_0001.frozen = False + f = Fitter.auto(toas, model) + f.fit_toas() + assert np.isclose(f.model.WXCOS_0001.value, 0.05, atol=1e-3) + + +def test_wavex_resids_amp(): + # Check that the amplitude of residuals somewhat matches independent calculation of wave delay for a single component + model = get_model(StringIO(par1)) + toas = make_fake_toas_uniform(55000, 55100, 500, model, obs="gbt") + wave_model = get_model(StringIO(par2)) + rs = Residuals(toas, wave_model) + injected_amp = np.sqrt( + wave_model.WXSIN_0001.quantity**2 + wave_model.WXCOS_0001.quantity**2 + ) + assert np.isclose(max(rs.resids), injected_amp, atol=1e-2) + assert np.isclose(min(rs.resids), -injected_amp, atol=1e-2) + + +def test_multiple_wavex_resids_amp(): + # Check that residuals for multiple components match independent calculation + model = get_model(StringIO(par1)) + toas = make_fake_toas_uniform(55000, 55100, 500, model, obs="gbt") + wave_model = get_model(StringIO(par2 + wavex_par)) + rs = Residuals(toas, wave_model) + wave_delays = wavex_delay(wave_model.components["WaveX"], toas) + assert np.allclose(rs.resids, -wave_delays, atol=max(rs.resids.value) / 10.0) + + +def test_wavex_from_par(): + # Check that a par file with wavex components present produces expected indices + model = get_model(StringIO(par2 + wavex_par)) + indices = model.components["WaveX"].get_indices() + assert np.all(np.array(indices) == np.array([1, 2, 3])) + + +def test_add_wavex_to_par(): + # Add a wavex component to par file that has none and check against par file with some WaveX model + model = get_model(StringIO(par1)) + toas = make_fake_toas_uniform(55000, 55100, 100, model, obs="gbt") + model.add_component(WaveX()) + index = model.components["WaveX"].get_indices() + model.WXFREQ_0001.quantity = 0.1 * (1 / u.d) + model.WXSIN_0001.quantity = 1 * u.s + model.WXCOS_0001.quantity = 1 * u.s + wavex_model = get_model(StringIO(par2)) + assert np.all( + np.array(index) == np.array(wavex_model.components["WaveX"].get_indices()) + ) + assert np.all( + model.components["WaveX"].wavex_delay(toas, 0.0 * u.s) + == wavex_model.components["WaveX"].wavex_delay(toas, 0.0 * u.s) + ) + + +def test_add_existing_index(): + # Check that trying to add an existing index fails + model = get_model(StringIO(par2 + wavex_par)) + with pytest.raises(ValueError): + index = model.components["WaveX"].add_wavex_component(0.01, index=2) + + +def test_add_existing_indices(): + # Check that trying to add multiple existing indices fails + model = get_model(StringIO(par2 + wavex_par)) + with pytest.raises(ValueError): + indices = model.components["WaveX"].add_wavex_components( + [0.01, 0.02], indices=[2, 3] + ) + + +def test_multiple_wavex_none_indices(): + model = get_model(StringIO(par2 + wavex_par)) + model.components["WaveX"].add_wavex_components([0.01, 0.02]) + indices = model.components["WaveX"].get_indices() + assert np.all(indices == np.array(range(1, len(indices) + 1))) + + +def test_add_then_remove_wavex(): + # Check that adding and then removing a wavex component actually gets rid of it + model = get_model(StringIO(par2)) + model.components["WaveX"].add_wavex_component(0.2, index=2, wxsin=2, wxcos=2) + indices = model.components["WaveX"].get_indices() + model.components["WaveX"].remove_wavex_component(2) + index = model.components["WaveX"].get_indices() + assert np.all(np.array(len(indices)) != np.array(len(index))) + + +def test_multiple_wavex(): + # Check that when adding multiple wavex component pythonically is consistent with a par file with the same components + model = get_model(StringIO(par2)) + toas = make_fake_toas_uniform(55000, 55100, 100, model, obs="gbt") + wavex_model = get_model(StringIO(par2 + wavex_par)) + indices = model.components["WaveX"].add_wavex_components( + [0.2, 0.3], indices=[2, 3], wxsins=[2, 3], wxcoses=[2, 3] + ) + assert np.all(np.array(indices) == np.array([2, 3])) + assert np.all( + model.components["WaveX"].wavex_delay(toas, 0.0 * u.s) + == wavex_model.components["WaveX"].wavex_delay(toas, 0.0 * u.s) + ) + + +def test_multiple_wavex_unit_conversion(): + # Check that input frequencies and amplitudes in different units convert properly + model = get_model(StringIO(par2)) + freqs = [2e-7 * u.s**-1, 3e-7 * u.s**-1] + indices = model.components["WaveX"].add_wavex_components( + [2e-7 * u.s**-1, 3e-7 * u.s**-1], + indices=[2, 3], + wxsins=[2, 3], + wxcoses=[2, 3], + frozens=False, + ) + assert getattr(model, f"WXFREQ_0002").value == freqs[0].to(u.d**-1).value + assert getattr(model, f"WXFREQ_0003").value == freqs[1].to(u.d**-1).value + + +def test_cos_amp_missing(): + # Check that validate fails when using a model with missing cosine amplitudes for the frequencies present + bad_wavex_par = """ + WXFREQ_0002 0.2 + WXSIN_0002 2 + """ + with pytest.raises(ValueError): + model = get_model(StringIO(par2 + bad_wavex_par)) + + +def test_sin_amp_missing(): + # Check that validate fails when using a model with missing cosine amplitudes for the frequencies present + bad_wavex_par = """ + WXFREQ_0002 0.2 + WXCOS_0002 2 + """ + with pytest.raises(ValueError): + model = get_model(StringIO(par2 + bad_wavex_par)) + + +def test_bad_wxfreq_value(): + # Check that putting a zero, or None value for an added frequency raises ValueErrors + model = get_model(StringIO(par2)) + with pytest.raises(ValueError): + model.components["WaveX"].add_wavex_component(0) + model.components["WaveX"].add_wavex_component(None) + + +def test_missing_epoch_parameters(): + bad_par = """ + PSR B1937+21 + LAMBDA 301.9732445337270 + BETA 42.2967523367957 + PMLAMBDA -0.0175 + PMBETA -0.3971 + PX 0.1515 + F0 641.9282333345536244 1 0.0000000000000132 + F1 -4.330899370129D-14 1 2.149749089617D-22 + DM 71.016633 + WXFREQ_0001 0.1 + WXSIN_0001 1 + WXCOS_0001 1 + UNITS TDB + """ + with pytest.raises(MissingParameter): + model = get_model(StringIO(bad_par)) + + +def test_sin_cos_mismatch(): + # Check that having mismatching sine and cosine amplitudes raises ValueErrors + bad_wavex_par = """ + WXFREQ_0002 0.2 + WXSIN_0002 2 + WXCOS_0003 2 + WXFREQ_0003 0.3 + WXSIN_0003 2 + WXCOS_0004 2 + """ + with pytest.raises(ValueError): + model = get_model(StringIO(par2 + bad_wavex_par)) + + +def test_multiple_wavex_broadcast_frozens(): + # Check that when a single False is given for frozens, it gets broadcast to all the sine and cosine amplitudes + model = get_model(StringIO(par2)) + indices = model.components["WaveX"].add_wavex_components( + [0.2, 0.3], + indices=[2, 3], + wxsins=[2, 3], + wxcoses=[2, 3], + frozens=False, + ) + for index in indices: + assert getattr(model, f"WXSIN_{index:04d}").frozen == False + assert getattr(model, f"WXCOS_{index:04d}").frozen == False + + +def test_multiple_wavex_wrong_cos_amps(): + # Check that code breaks when adding an extra cosine amplitude than there are frequencies, indices, and sine amplitudes for + model = get_model(StringIO(par2)) + with pytest.raises(ValueError): + indices = model.components["WaveX"].add_wavex_components( + [0.2, 0.3], indices=[2, 3], wxsins=[2, 3], wxcoses=[2, 3, 4] + ) + + +def test_multiple_wavex_wrong_sin_amps(): + # Check that code breaks when adding an extra sine amplitude than there are frequencies, indices, and cosine amplitudes for + model = get_model(StringIO(par2)) + with pytest.raises(ValueError): + indices = model.components["WaveX"].add_wavex_components( + [0.2, 0.3], indices=[2, 3], wxsins=[2, 3, 4], wxcoses=[2, 3] + ) + + +def test_multiple_wavex_wrong_freqs(): + # Check that code breaks when not adding enough frequencies for the number of indices, sine amps, and cosine amps given + model = get_model(StringIO(par2)) + with pytest.raises(ValueError): + indices = model.components["WaveX"].add_wavex_components( + [0.2, 0.3], indices=[2, 3, 4], wxsins=[2, 3, 4], wxcoses=[2, 3, 4] + ) + + +def test_multiple_wavex_wrong_frozens(): + # Check that adding to many elements to frozens breaks code + model = get_model(StringIO(par2)) + with pytest.raises(ValueError): + indices = model.components["WaveX"].add_wavex_components( + [0.2, 0.3], + indices=[2, 3], + wxsins=[2, 3], + wxcoses=[2, 3], + frozens=[False, False, False], + ) + + +def test_multiple_wavex_explicit_indices(): + # Check that adding specific indices is done correctly + model = get_model(StringIO(par2)) + indices = model.components["WaveX"].add_wavex_components( + [0.2, 0.3], indices=[3, 4], wxsins=[2, 3], wxcoses=[2, 3] + ) + assert np.all(np.array(indices) == np.array([3, 4])) + + +def test_multiple_wavex_explicit_indices_duplicate(): + # Check that adding a duplicate index fails + model = get_model(StringIO(par2)) + with pytest.raises(ValueError): + indices = model.components["WaveX"].add_wavex_components( + [0.2, 0.3], indices=[1, 3], wxsins=[2, 3], wxcoses=[2, 3] + ) + + +def test_wave_wavex_roundtrip_conversion(): + # Check that when starting with a TimingModel with a Wave model, conversion to a WaveX mode and then back produces consistent results + model = get_model(StringIO(par1)) + toas = make_fake_toas_uniform(55000, 55100, 500, model, obs="gbt") + wave_model = get_model(StringIO(par1 + wave_par)) + wave_to_wavex_model = pint.utils.translate_wave_to_wavex(wave_model) + wavex_to_wave_model = pint.utils.translate_wavex_to_wave(wave_to_wavex_model) + rs_wave = Residuals(toas, wave_model) + rs_wave_to_wavex = Residuals(toas, wave_to_wavex_model) + rs_wavex_to_wave = Residuals(toas, wavex_to_wave_model) + assert np.allclose(rs_wave.resids, rs_wave_to_wavex.resids, atol=1e-3) + assert np.allclose(rs_wave.resids, rs_wavex_to_wave.resids, atol=1e-3)