Skip to content

Commit

Permalink
Transparent reorganization and optimization of GP models
Browse files Browse the repository at this point in the history
  • Loading branch information
LucaMalavolta committed Dec 20, 2024
1 parent 148b353 commit 0b7be7a
Show file tree
Hide file tree
Showing 34 changed files with 474 additions and 1,984 deletions.
2 changes: 1 addition & 1 deletion pyorbit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@
from .subroutines.input_parser import yaml_parser


__version__ = "10.8.4"
__version__ = "10.9.0"
24 changes: 17 additions & 7 deletions pyorbit/models/abstract_gaussian_processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
class AbstractGaussianProcesses(object):

def __init__(self, *args, **kwargs):
pass
self.use_stellar_rotation_period = False
self.use_stellar_activity_decay = False

def _prepare_hyperparameters_conditions(self, mc, **kwargs):

def _prepare_hyperparameter_conditions(self, mc, **kwargs):

if kwargs.get('hyperparameters_condition', False):
self.hyper_condition = self._hypercond_01
Expand Down Expand Up @@ -55,7 +57,7 @@ def _prepare_decay_replacement(self, mc, parameter_name ='Pdec' , **kwargs):
self.list_pams_common.update(['activity_decay'])
self.list_pams_common.discard(parameter_name)

def _set_derivative_option(self, mc, dataset, **kwargs):
def _set_derivative_option(self, mc, dataset, return_flag=False, **kwargs):

if 'derivative'in kwargs:
use_derivative = kwargs['derivative'].get(dataset.name_ref, False)
Expand All @@ -70,18 +72,26 @@ def _set_derivative_option(self, mc, dataset, **kwargs):
else:
use_derivative = True

""" instead of taking an action on the parameter, the flag is returned"""
if return_flag:
return use_derivative

if not use_derivative:
self.fix_list[dataset.name_ref] = {'rot_amp': [0., 0.]}

def update_parameter_values(self, parameter_values, prepend=''):
def update_parameter_values(self,
parameter_values,
prepend='',
replace_rotation='Prot',
replace_decay='Pdec'):

if self.use_stellar_rotation_period:
parameter_values['Prot'] = parameter_values['rotation_period']
parameter_values[replace_rotation] = parameter_values['rotation_period']

if self.use_stellar_activity_decay:
parameter_values['Pdec'] = parameter_values['activity_decay']
parameter_values[replace_decay] = parameter_values['activity_decay']

def check_parameter_values(self, parameter_values):
def check_hyperparameter_values(self, parameter_values):

if not self.hyper_condition(parameter_values):
return -np.inf
Expand Down
69 changes: 31 additions & 38 deletions pyorbit/models/celerite2_granulation_oscillation_rotation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pyorbit.subroutines.common import np, OrderedSet
from pyorbit.models.abstract_model import AbstractModel
from pyorbit.models.abstract_gaussian_processes import AbstractGaussianProcesses
from pyorbit.keywords_definitions import *

try:
Expand All @@ -9,7 +10,7 @@
pass


class Celerite2_Granulation_Oscillation_Rotation(AbstractModel):
class Celerite2_Granulation_Oscillation_Rotation(AbstractModel, AbstractGaussianProcesses):

r"""A
Expand Down Expand Up @@ -42,6 +43,7 @@ class Celerite2_Granulation_Oscillation_Rotation(AbstractModel):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
super(AbstractModel, self).__init__(*args, **kwargs)

try:
import celerite2
Expand All @@ -66,13 +68,7 @@ def __init__(self, *args, **kwargs):

def initialize_model(self, mc, **kwargs):

for common_ref in self.common_ref:
if mc.common_models[common_ref].model_class == 'activity':
self.use_stellar_rotation_period = getattr(mc.common_models[common_ref], 'use_stellar_rotation_period', False)
break

for keyword in keywords_stellar_rotation:
self.use_stellar_rotation_period = kwargs.get(keyword, self.use_stellar_rotation_period)
self._prepare_rotation_replacement(mc, **kwargs)

self.rotation_kernels = kwargs.get('rotation_kernels', 1)
self.granulation_kernels = kwargs.get('granulation_kernels', 2)
Expand Down Expand Up @@ -160,8 +156,7 @@ def lnlk_compute(self, parameter_values, dataset):
In celerite2 the old function "set_parameter_vector" has been removed
and the kernel has to be defined every time
"""
if self.use_stellar_rotation_period:
parameter_values['Prot'] = parameter_values['rotation_period']
self.update_parameter_values(parameter_values)

self.gp[dataset.name_ref].mean = 0.
i_kernels = 0
Expand All @@ -176,23 +171,23 @@ def lnlk_compute(self, parameter_values, dataset):
for i_k in range(0, self.granulation_kernels):
if i_kernels > 0:
kernel += terms.SHOTerm(sigma=parameter_values['grn_k'+repr(i_k) + '_sigma'],
rho=parameter_values['grn_k'+repr(i_k) + '_period'],
Q=self.Q_granulation)
rho=parameter_values['grn_k'+repr(i_k) + '_period'],
Q=self.Q_granulation)
else:
kernel = terms.SHOTerm(sigma=parameter_values['grn_k'+repr(i_k) + '_sigma'],
rho=parameter_values['grn_k'+repr(i_k) + '_period'],
Q=self.Q_granulation)
rho=parameter_values['grn_k'+repr(i_k) + '_period'],
Q=self.Q_granulation)
i_kernels += 1

for i_k in range(0, self.oscillation_kernels):
if i_kernels > 0:
kernel += terms.SHOTerm(sigma=parameter_values['osc_k'+repr(i_k) + '_sigma'],
rho=parameter_values['osc_k'+repr(i_k) + '_period'],
Q=parameter_values['osc_k'+repr(i_k) + '_Q0'])
rho=parameter_values['osc_k'+repr(i_k) + '_period'],
Q=parameter_values['osc_k'+repr(i_k) + '_Q0'])
else:
kernel = terms.SHOTerm(sigma=parameter_values['osc_k'+repr(i_k) + '_sigma'],
rho=parameter_values['osc_k'+repr(i_k) + '_period'],
Q=parameter_values['osc_k'+repr(i_k) + '_Q0'])
rho=parameter_values['osc_k'+repr(i_k) + '_period'],
Q=parameter_values['osc_k'+repr(i_k) + '_Q0'])
i_kernels += 1


Expand All @@ -204,8 +199,7 @@ def lnlk_compute(self, parameter_values, dataset):

def sample_predict(self, parameter_values, dataset, x0_input=None, return_covariance=False, return_variance=False):

if self.use_stellar_rotation_period:
parameter_values['Prot'] = parameter_values['rotation_period']
self.update_parameter_values(parameter_values)

self.gp[dataset.name_ref].mean = 0.

Expand All @@ -221,23 +215,23 @@ def sample_predict(self, parameter_values, dataset, x0_input=None, return_covari
for i_k in range(0, self.granulation_kernels):
if i_kernels > 0:
kernel += terms.SHOTerm(sigma=parameter_values['grn_k'+repr(i_k) + '_sigma'],
rho=parameter_values['grn_k'+repr(i_k) + '_period'],
Q=self.Q_granulation)
rho=parameter_values['grn_k'+repr(i_k) + '_period'],
Q=self.Q_granulation)
else:
kernel = terms.SHOTerm(sigma=parameter_values['grn_k'+repr(i_k) + '_sigma'],
rho=parameter_values['grn_k'+repr(i_k) + '_period'],
Q=self.Q_granulation)
rho=parameter_values['grn_k'+repr(i_k) + '_period'],
Q=self.Q_granulation)
i_kernels += 1

for i_k in range(0, self.oscillation_kernels):
if i_kernels > 0:
kernel += terms.SHOTerm(sigma=parameter_values['osc_k'+repr(i_k) + '_sigma'],
rho=parameter_values['osc_k'+repr(i_k) + '_period'],
Q=parameter_values['osc_k'+repr(i_k) + '_Q0'])
rho=parameter_values['osc_k'+repr(i_k) + '_period'],
Q=parameter_values['osc_k'+repr(i_k) + '_Q0'])
else:
kernel = terms.SHOTerm(sigma=parameter_values['osc_k'+repr(i_k) + '_sigma'],
rho=parameter_values['osc_k'+repr(i_k) + '_period'],
Q=parameter_values['osc_k'+repr(i_k) + '_Q0'])
rho=parameter_values['osc_k'+repr(i_k) + '_period'],
Q=parameter_values['osc_k'+repr(i_k) + '_Q0'])
i_kernels += 1


Expand All @@ -252,8 +246,7 @@ def sample_predict(self, parameter_values, dataset, x0_input=None, return_covari

def sample_conditional(self, parameter_values, dataset, x0_input=None):

if self.use_stellar_rotation_period:
parameter_values['Prot'] = parameter_values['rotation_period']
self.update_parameter_values(parameter_values)

self.gp[dataset.name_ref].mean = 0.
i_kernels = 0
Expand All @@ -268,23 +261,23 @@ def sample_conditional(self, parameter_values, dataset, x0_input=None):
for i_k in range(0, self.granulation_kernels):
if i_kernels > 0:
kernel += terms.SHOTerm(sigma=parameter_values['grn_k'+repr(i_k) + '_sigma'],
rho=parameter_values['grn_k'+repr(i_k) + '_period'],
Q=self.Q_granulation)
rho=parameter_values['grn_k'+repr(i_k) + '_period'],
Q=self.Q_granulation)
else:
kernel = terms.SHOTerm(sigma=parameter_values['grn_k'+repr(i_k) + '_sigma'],
rho=parameter_values['grn_k'+repr(i_k) + '_period'],
Q=self.Q_granulation)
rho=parameter_values['grn_k'+repr(i_k) + '_period'],
Q=self.Q_granulation)
i_kernels += 1

for i_k in range(0, self.oscillation_kernels):
if i_kernels > 0:
kernel += terms.SHOTerm(sigma=parameter_values['osc_k'+repr(i_k) + '_sigma'],
rho=parameter_values['osc_k'+repr(i_k) + '_period'],
Q=parameter_values['osc_k'+repr(i_k) + '_Q0'])
rho=parameter_values['osc_k'+repr(i_k) + '_period'],
Q=parameter_values['osc_k'+repr(i_k) + '_Q0'])
else:
kernel = terms.SHOTerm(sigma=parameter_values['osc_k'+repr(i_k) + '_sigma'],
rho=parameter_values['osc_k'+repr(i_k) + '_period'],
Q=parameter_values['osc_k'+repr(i_k) + '_Q0'])
rho=parameter_values['osc_k'+repr(i_k) + '_period'],
Q=parameter_values['osc_k'+repr(i_k) + '_Q0'])
i_kernels += 1


Expand Down
29 changes: 9 additions & 20 deletions pyorbit/models/celerite2_granulation_rotation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pyorbit.subroutines.common import np, OrderedSet
from pyorbit.models.abstract_model import AbstractModel
from pyorbit.models.abstract_gaussian_processes import AbstractGaussianProcesses
from pyorbit.keywords_definitions import *

try:
Expand All @@ -9,7 +10,7 @@
pass


class Celerite2_Granulation_Rotation(AbstractModel):
class Celerite2_Granulation_Rotation(AbstractModel, AbstractGaussianProcesses):

r"""A mixture of two SHO terms that can be used to model stellar rotation
This term has two modes in Fourier space: one at ``period`` and one at
Expand Down Expand Up @@ -43,6 +44,7 @@ class Celerite2_Granulation_Rotation(AbstractModel):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
super(AbstractModel, self).__init__(*args, **kwargs)

try:
import celerite2
Expand Down Expand Up @@ -72,17 +74,7 @@ def __init__(self, *args, **kwargs):

def initialize_model(self, mc, **kwargs):

for common_ref in self.common_ref:
if mc.common_models[common_ref].model_class == 'activity':
self.use_stellar_rotation_period = getattr(mc.common_models[common_ref], 'use_stellar_rotation_period', False)
break

for keyword in keywords_stellar_rotation:
self.use_stellar_rotation_period = kwargs.get(keyword, self.use_stellar_rotation_period)

if self.use_stellar_rotation_period:
self.list_pams_common.update(['rotation_period'])
self.list_pams_common.discard('Prot')
self._prepare_rotation_replacement(mc, **kwargs)

def initialize_model_dataset(self, mc, dataset, **kwargs):
self.define_kernel(dataset)
Expand All @@ -106,17 +98,16 @@ def define_kernel(self, dataset):

def lnlk_compute(self, parameter_values, dataset):

if self.use_stellar_rotation_period:
parameter_values['Prot'] = parameter_values['rotation_period']
self.update_parameter_values(parameter_values)

"""
In celerite2 the old function "set_parameter_vector" has been removed
and the kernel has to be defined every time
"""
self.gp[dataset.name_ref].mean = 0.
self.gp[dataset.name_ref].kernel = terms.SHOTerm(sigma=parameter_values['grn_sigma'],
rho=parameter_values['grn_period'],
Q=self.Q_granulation) \
rho=parameter_values['grn_period'],
Q=self.Q_granulation) \
+ terms.RotationTerm(sigma=parameter_values['rot_sigma'],
period=parameter_values['Prot'],
Q0=parameter_values['rot_Q0'],
Expand All @@ -130,8 +121,7 @@ def lnlk_compute(self, parameter_values, dataset):

def sample_predict(self, parameter_values, dataset, x0_input=None, return_covariance=False, return_variance=False):

if self.use_stellar_rotation_period:
parameter_values['Prot'] = parameter_values['rotation_period']
self.update_parameter_values(parameter_values)

self.gp[dataset.name_ref].mean = 0.
self.gp[dataset.name_ref].kernel = terms.SHOTerm(sigma=parameter_values['grn_sigma'],
Expand All @@ -153,8 +143,7 @@ def sample_predict(self, parameter_values, dataset, x0_input=None, return_covari

def sample_conditional(self, parameter_values, dataset, x0_input=None):

if self.use_stellar_rotation_period:
parameter_values['Prot'] = parameter_values['rotation_period']
self.update_parameter_values(parameter_values)

self.gp[dataset.name_ref].mean = 0.
self.gp[dataset.name_ref].kernel = terms.SHOTerm(sigma=parameter_values['grn_sigma'],
Expand Down
Loading

0 comments on commit 0b7be7a

Please sign in to comment.