Skip to content

Commit

Permalink
Semplification of GP models management
Browse files Browse the repository at this point in the history
  • Loading branch information
LucaMalavolta committed Dec 20, 2024
1 parent 88fc0f7 commit 148b353
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 89 deletions.
114 changes: 114 additions & 0 deletions pyorbit/models/abstract_gaussian_processes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from pyorbit.subroutines.common import np
from pyorbit.keywords_definitions import *

class AbstractGaussianProcesses(object):

def __init__(self, *args, **kwargs):
pass

def _prepare_hyperparameters_conditions(self, mc, **kwargs):

if kwargs.get('hyperparameters_condition', False):
self.hyper_condition = self._hypercond_01
else:
self.hyper_condition = self._hypercond_00

if kwargs.get('rotation_decay_condition', False):
self.rotdec_condition = self._hypercond_02
else:
self.rotdec_condition = self._hypercond_00

if kwargs.get('halfrotation_decay_condition', False):
self.halfrotdec_condition = self._hypercond_03
else:
self.halfrotdec_condition = self._hypercond_00


def _prepare_rotation_replacement(self, mc, parameter_name ='Prot', common_pam=True, **kwargs):

for common_ref in self.common_ref:
if mc.common_models[common_ref].model_class == 'activity':
self.use_stellar_rotation_period = getattr(mc.common_models[common_ref], 'use_stellar_rotation_period', False)
break

for keyword in keywords_stellar_rotation:
self.use_stellar_rotation_period = kwargs.get(keyword, self.use_stellar_rotation_period)

if self.use_stellar_rotation_period:
self.list_pams_common.update(['rotation_period'])
if common_pam:
self.list_pams_common.discard(parameter_name)
else:
self.list_pams_dataset.discard(parameter_name)

def _prepare_decay_replacement(self, mc, parameter_name ='Pdec' , **kwargs):

for common_ref in self.common_ref:
if mc.common_models[common_ref].model_class == 'activity':
self.use_stellar_activity_decay = getattr(mc.common_models[common_ref], 'use_stellar_activity_decay', False)
break

for keyword in keywords_stellar_activity_decay:
self.use_stellar_activity_decay = kwargs.get(keyword, self.use_stellar_activity_decay)

if self.use_stellar_activity_decay:
self.list_pams_common.update(['activity_decay'])
self.list_pams_common.discard(parameter_name)

def _set_derivative_option(self, mc, dataset, **kwargs):

if 'derivative'in kwargs:
use_derivative = kwargs['derivative'].get(dataset.name_ref, False)
elif dataset.name_ref in kwargs:
use_derivative = kwargs[dataset.name_ref].get('derivative', False)
else:
if dataset.kind == 'H-alpha' or \
dataset.kind == 'S_index' or \
dataset.kind == 'Ca_HK' or \
dataset.kind == 'FWHM':
use_derivative = False
else:
use_derivative = True

if not use_derivative:
self.fix_list[dataset.name_ref] = {'rot_amp': [0., 0.]}

def update_parameter_values(self, parameter_values, prepend=''):

if self.use_stellar_rotation_period:
parameter_values['Prot'] = parameter_values['rotation_period']

if self.use_stellar_activity_decay:
parameter_values['Pdec'] = parameter_values['activity_decay']

def check_parameter_values(self, parameter_values):

if not self.hyper_condition(parameter_values):
return -np.inf
if not self.rotdec_condition(parameter_values):
return -np.inf
if not self.halfrotdec_condition(parameter_values):
return -np.inf

return True

@staticmethod
def _hypercond_00(parameter_values):
#Condition from Rajpaul 2017, Rajpaul+2021
return True

@staticmethod
def _hypercond_01(parameter_values):
# Condition from Rajpaul 2017, Rajpaul+2021
# Taking into account that Pdec^2 = 2*lambda_2^2
return parameter_values['Pdec']**2 > (3. / 2. / np.pi) * parameter_values['Oamp']**2 * parameter_values['Prot']**2

@staticmethod
def _hypercond_02(parameter_values):
#Condition on Rotation period and decay timescale
return parameter_values['Pdec'] > 2. * parameter_values['Prot']

@staticmethod
def _hypercond_03(parameter_values):
#Condition on Rotation period and decay timescale
return parameter_values['Pdec'] > 0.5 * parameter_values['Prot']
102 changes: 13 additions & 89 deletions pyorbit/models/spleaf_multidimensional_esp_activity_devel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pyorbit.subroutines.common import *
from pyorbit.models.abstract_model import *
from pyorbit.models.abstract_gaussian_processes import *
from pyorbit.keywords_definitions import *

from scipy.linalg import cho_factor, cho_solve, lapack, LinAlgError
Expand All @@ -16,7 +17,7 @@
pass


class SPLEAF_Multidimensional_ESP_devel(AbstractModel):
class SPLEAF_Multidimensional_ESP_devel(AbstractModel, AbstractGaussianProcesses):
''' Three parameters out of four are the same for all the datasets, since they are related to
the properties of the physical process rather than the observed effects on a dataset
From Grunblatt+2015, Affer+2016
Expand All @@ -28,7 +29,9 @@ class SPLEAF_Multidimensional_ESP_devel(AbstractModel):
default_common = 'activity'

def __init__(self, *args, **kwargs):

super().__init__(*args, **kwargs)
super(AbstractModel, self).__init__(*args, **kwargs)

self.model_class = 'multidimensional_gaussian_process'

Expand Down Expand Up @@ -102,45 +105,9 @@ def initialize_model(self, mc, **kwargs):
print(' S+LEAF model, number of harmonics:', self.n_harmonics)
print()

if kwargs.get('hyperparameters_condition', False):
self.hyper_condition = self._hypercond_01
else:
self.hyper_condition = self._hypercond_00

if kwargs.get('rotation_decay_condition', False):
self.rotdec_condition = self._hypercond_02
else:
self.rotdec_condition = self._hypercond_00

if kwargs.get('halfrotation_decay_condition', False):
self.halfrotdec_condition = self._hypercond_03
else:
self.halfrotdec_condition = self._hypercond_00

for common_ref in self.common_ref:
if mc.common_models[common_ref].model_class == 'activity':
self.use_stellar_rotation_period = getattr(mc.common_models[common_ref], 'use_stellar_rotation_period', False)
break

for keyword in keywords_stellar_rotation:
self.use_stellar_rotation_period = kwargs.get(keyword, self.use_stellar_rotation_period)

if self.use_stellar_rotation_period:
self.list_pams_common.update(['rotation_period'])
self.list_pams_common.discard('Prot')


for common_ref in self.common_ref:
if mc.common_models[common_ref].model_class == 'activity':
self.use_stellar_activity_decay = getattr(mc.common_models[common_ref], 'use_stellar_activity_decay', False)
break

for keyword in keywords_stellar_activity_decay:
self.use_stellar_activity_decay = kwargs.get(keyword, self.use_stellar_activity_decay)

if self.use_stellar_activity_decay:
self.list_pams_common.update(['activity_decay'])
self.list_pams_common.discard('Pdec')
self._prepare_hyperparameters_conditions(mc, **kwargs)
self._prepare_rotation_replacement(mc, **kwargs)
self._prepare_decay_replacement(mc, **kwargs)

def initialize_model_dataset(self, mc, dataset, **kwargs):

Expand Down Expand Up @@ -204,31 +171,13 @@ def initialize_model_dataset(self, mc, dataset, **kwargs):

self._reset_kernel()

if 'derivative'in kwargs:
use_derivative = kwargs['derivative'].get(dataset.name_ref, False)
elif dataset.name_ref in kwargs:
use_derivative = kwargs[dataset.name_ref].get('derivative', False)
else:
if dataset.kind == 'H-alpha' or \
dataset.kind == 'S_index' or \
dataset.kind == 'Ca_HK' or \
dataset.kind == 'FWHM':
use_derivative = False
else:
use_derivative = True

if not use_derivative:
self.fix_list[dataset.name_ref] = {'rot_amp': [0., 0.]}
self._set_derivative_option(mc, dataset, **kwargs)

return

def add_internal_dataset(self, parameter_values, dataset):

if self.use_stellar_rotation_period:
parameter_values['Prot'] = parameter_values['rotation_period']

if self.use_stellar_activity_decay:
parameter_values['Pdec'] = parameter_values['activity_decay']
self.update_parameter_values(parameter_values)

self.internal_parameter_values = parameter_values

Expand All @@ -246,13 +195,9 @@ def add_internal_dataset(self, parameter_values, dataset):

def lnlk_compute(self):


if not self.hyper_condition(self.internal_parameter_values):
return -np.inf
if not self.rotdec_condition(self.internal_parameter_values):
return -np.inf
if not self.halfrotdec_condition(self.internal_parameter_values):
return -np.inf
pass_conditions = self.check_parameter_values(self.internal_parameter_values)
if not pass_conditions:
return pass_conditions

"""
Randomly reset the kernel with a probability of 0.1%
Expand Down Expand Up @@ -326,26 +271,5 @@ def _reset_kernel(self):

self.D_spleaf = spleaf_cov.Cov(self.spleaf_time, **kwargs)
self.D_param = self.D_spleaf.param[1:]
#print(self.D_param)

@staticmethod
def _hypercond_00(parameter_values):
#Condition from Rajpaul 2017, Rajpaul+2021
return True

@staticmethod
def _hypercond_01(parameter_values):
# Condition from Rajpaul 2017, Rajpaul+2021
# Taking into account that Pdec^2 = 2*lambda_2^2
return parameter_values['Pdec']**2 > (3. / 2. / np.pi) * parameter_values['Oamp']**2 * parameter_values['Prot']**2

@staticmethod
def _hypercond_02(parameter_values):
#Condition on Rotation period and decay timescale
return parameter_values['Pdec'] > 2. * parameter_values['Prot']

@staticmethod
def _hypercond_03(parameter_values):
#Condition on Rotation period and decay timescale
return parameter_values['Pdec'] > 0.5 * parameter_values['Prot']


0 comments on commit 148b353

Please sign in to comment.