diff --git a/src/timesfm/patched_decoder.py b/src/timesfm/patched_decoder.py
index cd9d5f8..cdd81be 100644
--- a/src/timesfm/patched_decoder.py
+++ b/src/timesfm/patched_decoder.py
@@ -237,6 +237,7 @@ class PatchedTimeSeriesDecoder(base_layer.BaseLayer):
   stacked_transformer_params_tpl: LayerTpl = template_field(
       transformers.StackedTransformer)
   use_freq: bool = True
+  use_pos_emb: bool = True
 
   def setup(self) -> None:
     """Construct the model."""
@@ -332,16 +333,17 @@ def _preprocess_input(
     model_input = self.input_ff_layer(concat_inputs)
     # A patch should not be padded even if there is at least one zero.
     patched_padding = jnp.min(patched_pads, axis=-1)
-
-    if pos_emb is None:
-      position_emb = self.position_emb(seq_length=model_input.shape[1])
-    else:
-      position_emb = pos_emb
-    if self.do_eval:
-      if position_emb.shape[0] != model_input.shape[0]:
-        position_emb = jnp.repeat(position_emb, model_input.shape[0], axis=0)
-      position_emb = _shift_padded_seq(patched_padding, position_emb)
-    model_input += position_emb
+
+    if self.use_pos_emb:
+      if pos_emb is None:
+        position_emb = self.position_emb(seq_length=model_input.shape[1])
+      else:
+        position_emb = pos_emb
+      if self.do_eval:
+        if position_emb.shape[0] != model_input.shape[0]:
+          position_emb = jnp.repeat(position_emb, model_input.shape[0], axis=0)
+        position_emb = _shift_padded_seq(patched_padding, position_emb)
+      model_input += position_emb
 
     return model_input, patched_padding, stats, patched_inputs
 
diff --git a/src/timesfm/timesfm_base.py b/src/timesfm/timesfm_base.py
index 230fb82..f2a8538 100644
--- a/src/timesfm/timesfm_base.py
+++ b/src/timesfm/timesfm_base.py
@@ -109,6 +109,7 @@ class TimesFmHparams:
   per_core_batch_size: int = 32
   backend: Literal["cpu", "gpu", "tpu"] = "cpu"
   quantiles: Sequence[float] | None = DEFAULT_QUANTILES
+  use_positional_embedding: bool = True
   # Hparams beyond the model.
   point_forecast_mode: Literal["mean", "median"] = "median"
 
@@ -172,6 +173,7 @@ def __init__(self, hparams: TimesFmHparams,
     self.backend = hparams.backend
     self.quantiles = hparams.quantiles
     self.num_heads = hparams.num_heads
+    self.use_pos_emb = hparams.use_positional_embedding
 
     # Rewrite these values in __post_init__ for SPMD.
     self.num_cores = 1
diff --git a/src/timesfm/timesfm_jax.py b/src/timesfm/timesfm_jax.py
index 1b4427f..cbd8576 100644
--- a/src/timesfm/timesfm_jax.py
+++ b/src/timesfm/timesfm_jax.py
@@ -117,6 +117,7 @@ def load_from_checkpoint(
         residual_block_tpl=pax_fiddle.Config(patched_decoder.ResidualBlock),
         quantiles=self.quantiles,
         use_freq=True,
+        use_pos_emb=self.use_pos_emb,
         stacked_transformer_params_tpl=pax_fiddle.Config(
             transformers.StackedTransformer,
             num_heads=self.num_heads,
diff --git a/src/timesfm/timesfm_torch.py b/src/timesfm/timesfm_torch.py
index d28a0b7..7bf298e 100644
--- a/src/timesfm/timesfm_torch.py
+++ b/src/timesfm/timesfm_torch.py
@@ -40,6 +40,7 @@ def __post_init__(self):
         horizon_len=self.output_patch_len,
         head_dim=self.model_dims // self.num_heads,
         quantiles=self.quantiles,
+        use_positional_embedding=self.use_pos_emb,
     )
     self._model = None
     self.num_cores = 1
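
Usage note (not part of the patch): the change threads a single switch from TimesFmHparams.use_positional_embedding down to the decoder's use_pos_emb field, so callers can disable the learned positional embedding without touching decoder code. A minimal sketch of how the new flag would be set, assuming the usual pattern of passing TimesFmHparams into the model constructor; the timesfm.TimesFm entry point and any checkpoint arguments below are assumptions, not shown in this diff:

import timesfm

# New flag from this change; defaults to True, which preserves the old behavior.
hparams = timesfm.TimesFmHparams(
    backend="cpu",
    per_core_batch_size=32,
    use_positional_embedding=False,
)

# `hparams` is then handed to the model constructor, e.g.:
# tfm = timesfm.TimesFm(hparams=hparams, ...)  # checkpoint arguments omitted (hypothetical)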