Skip to content

Commit

Permalink
commits
Browse files Browse the repository at this point in the history
  • Loading branch information
kashif committed Jun 15, 2024
1 parent a0afd94 commit 8fc61ad
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/gluonts/torch/model/mamba/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ def __init__(
self.context_length = context_length
self.prediction_length = prediction_length
self.distr_output = distr_output
self.param_proj = distr_output.get_args_proj(hidden_size)
self.num_feat_dynamic_real = num_feat_dynamic_real
self.num_feat_static_cat = num_feat_static_cat
self.num_feat_static_real = num_feat_static_real
Expand Down Expand Up @@ -166,6 +165,8 @@ def __init__(
# batch_first=True,
)
self.nonnegative_pred_samples = nonnegative_pred_samples
self.param_proj = distr_output.get_args_proj(self.d_model)


def describe_inputs(self, batch_size=1) -> InputSpec:
return InputSpec(
Expand Down Expand Up @@ -472,7 +473,7 @@ def forward(
)
rnn_input = torch.cat((next_lags, next_features), dim=-1)

output, repeated_state = self.rnn(rnn_input, repeated_state)
output = self.mamba(rnn_input)

repeated_past_target = torch.cat(
(repeated_past_target, scaled_next_sample), dim=1
Expand Down

0 comments on commit 8fc61ad

Please sign in to comment.