-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'ml6team:main' into feature/commoncrawl-download-segments
- Loading branch information
Showing 7 changed files with 82 additions and 7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import pandas as pd | ||
import requests | ||
from caption_images.src.main import CaptionImagesComponent | ||
from fondant.abstract_component_test import AbstractComponentTest | ||
|
||
|
||
class TestCaptionImagesComponent(AbstractComponentTest):
    """Run the shared component test against ``CaptionImagesComponent``.

    The hooks below supply the component instance, the input frame built
    from two downloaded images, and the captions the transform is expected
    to produce for them.
    """

    # Upper bound (seconds) on each image download, so an unresponsive host
    # fails the test instead of hanging the whole run.
    _DOWNLOAD_TIMEOUT = 30

    def create_component(self):
        """Return the ``CaptionImagesComponent`` instance under test."""
        return CaptionImagesComponent(
            model_id="Salesforce/blip-image-captioning-base",
            batch_size=4,
            max_new_tokens=2,
        )

    def _download(self, url):
        """Return the raw response body served at ``url``.

        Raises:
            requests.HTTPError: if the server answers with an error status,
                so an error page is never handed on as image bytes.
            requests.Timeout: if the server does not respond within
                ``_DOWNLOAD_TIMEOUT`` seconds.
        """
        response = requests.get(url, timeout=self._DOWNLOAD_TIMEOUT)
        response.raise_for_status()
        return response.content

    def create_input_data(self):
        """Download the sample images and wrap their bytes in a DataFrame."""
        image_urls = [
            "https://cdn.pixabay.com/photo/2023/06/29/09/52/angkor-thom-8096092_1280.jpg",
            "https://cdn.pixabay.com/photo/2023/07/19/18/56/japanese-beetle-8137606_1280.png",
        ]
        # NOTE(review): with a nested dict pandas treats the inner key
        # ("data") as a row label, whereas create_output_data uses a tuple
        # column key - confirm this is the layout transform() expects.
        return pd.DataFrame(
            {"images": {"data": [self._download(url) for url in image_urls]}},
        )

    def create_output_data(self):
        """Return the expected captions, one row per input image."""
        return pd.DataFrame(
            data={("captions", "text"): {0: "a motorcycle", 1: "a beetle"}},
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
import pandas as pd | ||
import pytest | ||
|
||
|
||
class AbstractComponentTest(ABC):
    """Template for component tests built from three factory hooks.

    A concrete test class provides the component, its input data and the
    expected output; the shared ``test_transform`` check is inherited.
    """

    @abstractmethod
    def create_component(self):
        """Build the component that is to be tested.

        Must be overridden by each concrete test class.
        """
        raise NotImplementedError

    @abstractmethod
    def create_input_data(self):
        """Build the input data passed to the component's ``transform``.

        Must be overridden by each concrete test class.
        """
        raise NotImplementedError

    @abstractmethod
    def create_output_data(self):
        """Build the output data that ``transform`` is expected to return.

        Must be overridden by each concrete test class.
        """
        raise NotImplementedError

    @pytest.fixture(autouse=True)
    def __setUp(self):
        """Prepare the test instance before each test method runs.

        Calls the three factory hooks and stores their results on ``self``.
        Common setup steps shared by all components belong here.
        """
        self.component = self.create_component()
        self.input_data = self.create_input_data()
        self.expected_output_data = self.create_output_data()

    def test_transform(self):
        """Default test for the component's ``transform`` method.

        Runs ``transform`` on the prepared input and asserts that the
        resulting frame equals the expected output frame.
        """
        pd.testing.assert_frame_equal(
            self.component.transform(self.input_data),
            self.expected_output_data,
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters