Skip to content

Commit

Permalink
Use native TF checkpoints for the BLIP TF tests (huggingface#22593)
Browse files Browse the repository at this point in the history
* Use native TF checkpoints for the TF tests

* Remove unneeded exceptions
  • Loading branch information
Rocketknight1 authored and novice03 committed Jun 23, 2023
1 parent 5ed982b commit 439b58c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 21 deletions.
25 changes: 8 additions & 17 deletions tests/models/blip/test_modeling_tf_blip.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,7 @@ def test_save_load_fast_init_to_base(self):
@slow
def test_model_from_pretrained(self):
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
try:
model = TFBlipVisionModel.from_pretrained(model_name)
except OSError:
model = TFBlipVisionModel.from_pretrained(model_name, from_pt=True)
model = TFBlipVisionModel.from_pretrained(model_name)
self.assertIsNotNone(model)


Expand Down Expand Up @@ -320,10 +317,7 @@ def test_save_load_fast_init_to_base(self):
@slow
def test_model_from_pretrained(self):
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
try:
model = TFBlipTextModel.from_pretrained(model_name)
except OSError:
model = TFBlipTextModel.from_pretrained(model_name, from_pt=True)
model = TFBlipTextModel.from_pretrained(model_name)
self.assertIsNotNone(model)

def test_pt_tf_model_equivalence(self, allow_missing_keys=True):
Expand Down Expand Up @@ -432,7 +426,7 @@ def test_load_vision_text_config(self):
@slow
def test_model_from_pretrained(self):
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = TFBlipModel.from_pretrained(model_name, from_pt=True)
model = TFBlipModel.from_pretrained(model_name)
self.assertIsNotNone(model)

def test_pt_tf_model_equivalence(self, allow_missing_keys=True):
Expand Down Expand Up @@ -635,7 +629,7 @@ def test_load_vision_text_config(self):
@slow
def test_model_from_pretrained(self):
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = TFBlipModel.from_pretrained(model_name, from_pt=True)
model = TFBlipModel.from_pretrained(model_name)
self.assertIsNotNone(model)

@unittest.skip(reason="Tested in individual model tests")
Expand Down Expand Up @@ -750,10 +744,7 @@ def test_load_vision_text_config(self):
@slow
def test_model_from_pretrained(self):
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
try:
model = TFBlipModel.from_pretrained(model_name)
except OSError:
model = TFBlipModel.from_pretrained(model_name, from_pt=True)
model = TFBlipModel.from_pretrained(model_name)
self.assertIsNotNone(model)


Expand All @@ -769,7 +760,7 @@ def prepare_img():
@slow
class TFBlipModelIntegrationTest(unittest.TestCase):
def test_inference_image_captioning(self):
model = TFBlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", from_pt=True)
model = TFBlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
image = prepare_img()

Expand All @@ -796,7 +787,7 @@ def test_inference_image_captioning(self):
)

def test_inference_vqa(self):
model = TFBlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base", from_pt=True)
model = TFBlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")

image = prepare_img()
Expand All @@ -808,7 +799,7 @@ def test_inference_vqa(self):
self.assertEqual(out[0].numpy().tolist(), [30522, 1015, 102])

def test_inference_itm(self):
model = TFBlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco", from_pt=True)
model = TFBlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco")
processor = BlipProcessor.from_pretrained("Salesforce/blip-itm-base-coco")

image = prepare_img()
Expand Down
5 changes: 1 addition & 4 deletions tests/models/blip/test_modeling_tf_blip_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,7 @@ def test_save_load_fast_init_to_base(self):
@slow
def test_model_from_pretrained(self):
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
try:
model = TFBlipTextModel.from_pretrained(model_name)
except OSError:
model = TFBlipTextModel.from_pretrained(model_name, from_pt=True)
model = TFBlipTextModel.from_pretrained(model_name)
self.assertIsNotNone(model)

def test_pt_tf_model_equivalence(self, allow_missing_keys=True):
Expand Down

0 comments on commit 439b58c

Please sign in to comment.