Fix Mamba slow path bug with dtype mismatch. (huggingface#32691)
* Fix Mamba slow path bug with dtype mismatch.

* Update test_modeling_mamba.py

* Improve style.

* Fix issue with cache position of dtype mismatch test.

* Change test for slow path.

* Revert changes.

* Switch to buggy code and add test to catch it.

* Fix the dtype mismatch bug and add test code to verify it.

* Fix minor bug with test.

* Fix incorrect dtype of model output.

* Fix incorrect dtype of cache.

* Fix incorrect dtype of ssm cache.

* Fix incorrect dtype of conv state.

* Remove assertion for ssm state.

* Add assertion for conv state dtype.

* Fix all issues with dtype mismatch test.
Adibvafa authored and BernardZach committed Dec 6, 2024
1 parent 09b0aba commit 462e507
Showing 2 changed files with 25 additions and 1 deletion.
src/transformers/cache_utils.py: 1 addition, 1 deletion
@@ -1797,7 +1797,7 @@ def update_conv_state(
         cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
 
         conv_state = conv_state.roll(shifts=-1, dims=-1)
-        conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
+        conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype)
         self.conv_states[layer_idx].zero_()
         self.conv_states[layer_idx] += conv_state
         return self.conv_states[layer_idx]
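
The one-line fix above works around a PyTorch rule: writing into a tensor through a tensor index (index_put_) does not type-promote, so assigning a float16 hidden state into a float32 conv state raises a dtype-mismatch RuntimeError unless the source is cast first. A minimal standalone sketch of that behavior in plain PyTorch (not part of the commit; the shapes are illustrative stand-ins for (batch, intermediate_size, conv_kernel)):

import torch

conv_state = torch.zeros(1, 8, 4, dtype=torch.float32)      # cache buffer created in float32
new_conv_state = torch.randn(1, 8, 4, dtype=torch.float16)  # hidden states from a float16 model
cache_position = torch.arange(4)

try:
    # Unpatched write: only the device is matched, so index_put_ sees mixed dtypes and raises.
    conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
except RuntimeError as err:
    print(f"unpatched write fails: {err}")

# Patched write: match both device and dtype before storing into the cache.
conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype)

Casting on write keeps the cache dtype authoritative: a MambaCache created in float32 stays float32 even when the model itself runs in half precision.
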
tests/models/mamba/test_modeling_mamba.py: 24 additions, 0 deletions
@@ -421,6 +421,30 @@ def recursive_check(tuple_object, dict_object):
     def test_beam_sample_generate(self):
         pass
 
+    def test_dtype_mismatch_handled_in_cache(self):
+        config, input_ids, *args = self.model_tester.prepare_config_and_inputs()
+        model = MambaModel(config)
+        model.to(torch_device).to(torch.float16)
+        model.eval()
+
+        # Create cache with float32 dtype
+        cache_params = MambaCache(config, batch_size=input_ids.size(0), dtype=torch.float32, device=torch_device)
+
+        # If code is correct, no error occurs and test passes
+        outputs = model(
+            input_ids,
+            cache_params=cache_params,
+            use_cache=True,
+            cache_position=torch.arange(0, config.conv_kernel, device=input_ids.device),
+        )
+
+        self.assertIsNotNone(outputs)
+        self.assertIsNotNone(outputs.last_hidden_state)
+        self.assertEqual(
+            outputs.last_hidden_state.shape,
+            (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.hidden_size),
+        )
+
 
 @require_torch
 class MambaIntegrationTests(unittest.TestCase):
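
At the cache level, the same scenario can be exercised through MambaCache.update_conv_state directly. A hedged sketch, assuming the default MambaConfig sizes and a CPU device (neither is taken from the test above):

import torch

from transformers import MambaConfig
from transformers.cache_utils import MambaCache

config = MambaConfig()  # default sizes, used only for illustration
cache = MambaCache(config, batch_size=1, dtype=torch.float32, device="cpu")

# Half-precision update with the conv-state layout (batch, intermediate_size, conv_kernel).
new_state = torch.randn(1, config.intermediate_size, config.conv_kernel, dtype=torch.float16)
cache_position = torch.arange(config.conv_kernel)

# With the patched update_conv_state, the write is cast to the cache dtype instead of raising.
updated = cache.update_conv_state(layer_idx=0, new_conv_state=new_state, cache_position=cache_position)
assert updated.dtype == torch.float32  # the conv state keeps the dtype the cache was created with
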
