Skip to content

Commit

Permalink
Update gemma_backbone_test.py
Browse files Browse the repository at this point in the history
Better test messages
  • Loading branch information
mattdangerw committed Jun 20, 2024
1 parent 12fc70c commit 2001a3d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions keras_nlp/src/models/gemma/gemma_backbone_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,11 @@ def test_distribution(self):

def test_distribution_with_lora(self):
if keras.backend.backend() != "jax":
return
self.skipTest("`ModelParallel` testing requires the Jax backend.")
devices = keras.distribution.list_devices("CPU")
if len(devices) == 1:
# Need more than 1 device for distribution testing.
return
self.skipTest("`ModelParallel` testing requires multiple devices.")
device_mesh = keras.distribution.DeviceMesh(
shape=(1, len(devices)),
axis_names=("batch", "model"),
Expand Down

0 comments on commit 2001a3d

Please sign in to comment.