From 3ea485c26cca15eaf0f13507f10f33785cae52b3 Mon Sep 17 00:00:00 2001 From: swijaya Date: Wed, 29 Jan 2025 15:17:24 -0800 Subject: [PATCH] Export default tracing functions in `mcmc` module PiperOrigin-RevId: 721125228 --- .../python/experimental/mcmc/__init__.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tensorflow_probability/python/experimental/mcmc/__init__.py b/tensorflow_probability/python/experimental/mcmc/__init__.py index e3a60ea073..aec880509f 100644 --- a/tensorflow_probability/python/experimental/mcmc/__init__.py +++ b/tensorflow_probability/python/experimental/mcmc/__init__.py @@ -70,6 +70,8 @@ from tensorflow_probability.python.experimental.mcmc.weighted_resampling import resample_independent from tensorflow_probability.python.experimental.mcmc.weighted_resampling import resample_stratified from tensorflow_probability.python.experimental.mcmc.weighted_resampling import resample_systematic +from tensorflow_probability.python.experimental.mcmc.windowed_sampling import default_hmc_trace_fn +from tensorflow_probability.python.experimental.mcmc.windowed_sampling import default_nuts_trace_fn from tensorflow_probability.python.experimental.mcmc.windowed_sampling import windowed_adaptive_hmc from tensorflow_probability.python.experimental.mcmc.windowed_sampling import windowed_adaptive_nuts from tensorflow_probability.python.experimental.mcmc.with_reductions import WithReductions @@ -135,6 +137,8 @@ 'simple_heuristic_tuning', 'snaper_criterion', 'step_kernel', + 'default_hmc_trace_fn', + 'default_nuts_trace_fn', 'windowed_adaptive_hmc', - 'windowed_adaptive_nuts' - ] + 'windowed_adaptive_nuts', +]