Skip to content

Commit

Permalink
add a new feature to allow user config global_trend_sigma (#95)
Browse files Browse the repository at this point in the history
  • Loading branch information
edwinnglabs authored May 23, 2024
1 parent c64053e commit f2096d9
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
2 changes: 1 addition & 1 deletion karpiu/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
name = "karpiu"
__version__ = "0.0.2"
__version__ = "0.0.2.1"
10 changes: 6 additions & 4 deletions karpiu/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
fs_orders: Optional[List[int]] = None,
total_market_sigma_prior: float = None,
default_spend_sigma_prior: float = 0.1,
global_trend_sigma_prior: float = 0.001,
logger: Optional[logging.Logger] = None,
fit_args: Optional[Dict[str, float]] = None,
**kwargs
Expand Down Expand Up @@ -82,6 +83,7 @@ def __init__(

self.spend_cols = deepcopy(spend_cols)
self.default_spend_sigma_prior = default_spend_sigma_prior
self.global_trend_sigma_prior = global_trend_sigma_prior
self.spend_cols.sort()
# self.scalability_df = scalability_df

Expand Down Expand Up @@ -212,7 +214,7 @@ def filter_features(
num_sample=num_sample,
chains=chains,
# use small sigma for global trend as this is a long-term daily model
global_trend_sigma_prior=0.001,
global_trend_sigma_prior=self.global_trend_sigma_prior,
# **self.best_params,
**kwargs,
)
Expand Down Expand Up @@ -296,7 +298,7 @@ def optim_hyper_params(
estimator="stan-map",
verbose=False,
# use small sigma for global trend as this is a long-term daily model
global_trend_sigma_prior=0.001,
global_trend_sigma_prior=self.global_trend_sigma_prior,
**kwargs,
)

Expand Down Expand Up @@ -505,7 +507,7 @@ def fit(
num_sample=num_sample,
chains=chains,
# use small sigma for global trend as this is a long-term daily model
global_trend_sigma_prior=0.001,
global_trend_sigma_prior=self.global_trend_sigma_prior,
**self.best_params,
**kwargs,
)
Expand Down Expand Up @@ -566,7 +568,7 @@ def fit(
num_sample=num_sample,
chains=chains,
# use small sigma for global trend as this is a long-term daily model
global_trend_sigma_prior=0.001,
global_trend_sigma_prior=self.global_trend_sigma_prior,
**self.best_params,
**kwargs,
)
Expand Down

0 comments on commit f2096d9

Please sign in to comment.