Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add vision detect_labels() system tests. #2830

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added system_tests/data/car.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
81 changes: 74 additions & 7 deletions system_tests/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@

from system_test_utils import unique_resource_id
from retry import RetryErrors
import six


_SYS_TESTS_DIR = os.path.abspath(os.path.dirname(__file__))
LOGO_FILE = os.path.join(_SYS_TESTS_DIR, 'data', 'logo.png')
FACE_FILE = os.path.join(_SYS_TESTS_DIR, 'data', 'faces.jpg')
LABEL_FILE = os.path.join(_SYS_TESTS_DIR, 'data', 'car.jpg')


class Config(object):
Expand All @@ -54,6 +56,14 @@ def tearDownModule():
bucket_retry(Config.TEST_BUCKET.delete)(force=True)


class BaseVisionTestCase(unittest.TestCase):
def _assert_coordinate(self, coordinate):
if coordinate is None:
return
self.assertIsInstance(coordinate, (int, float))
self.assertNotEqual(coordinate, 0.0)


class TestVisionClientLogo(unittest.TestCase):
def setUp(self):
self.to_delete_by_case = []
Expand Down Expand Up @@ -111,20 +121,14 @@ def test_detect_logos_gcs(self):
self._assert_logo(logo)


class TestVisionClientFace(unittest.TestCase):
class TestVisionClientFace(BaseVisionTestCase):
def setUp(self):
self.to_delete_by_case = []

def tearDown(self):
for value in self.to_delete_by_case:
value.delete()

def _assert_coordinate(self, coordinate):
if coordinate is None:
return
self.assertIsInstance(coordinate, (int, float))
self.assertGreater(abs(coordinate), 0.0)

def _assert_likelihood(self, likelihood):
from google.cloud.vision.likelihood import Likelihood

Expand Down Expand Up @@ -215,3 +219,66 @@ def test_detect_faces_filename(self):
self.assertEqual(len(faces), 5)
for face in faces:
self._assert_face(face)


class TestVisionClientLabel(BaseVisionTestCase):
DESCRIPTIONS = (
'car',
'vehicle',
'land vehicle',
'automotive design',
'wheel',
'automobile make',
'luxury vehicle',
'sports car',
'performance car',
'automotive exterior',
)

def setUp(self):
self.to_delete_by_case = []

def tearDown(self):
for value in self.to_delete_by_case:
value.delete()

def _assert_label(self, label):

self.assertIsInstance(label, EntityAnnotation)
self.assertIn(label.description, self.DESCRIPTIONS)
self.assertIsInstance(label.mid, six.text_type)
self.assertGreater(label.score, 0.0)

def test_detect_labels_content(self):
client = Config.CLIENT
with open(LABEL_FILE, 'rb') as image_file:
image = client.image(content=image_file.read())
labels = image.detect_labels()
self.assertEqual(len(labels), 10)
for label in labels:
self._assert_label(label)

def test_detect_labels_gcs(self):
bucket_name = Config.TEST_BUCKET.name
blob_name = 'car.jpg'
blob = Config.TEST_BUCKET.blob(blob_name)
self.to_delete_by_case.append(blob) # Clean-up.
with open(LABEL_FILE, 'rb') as file_obj:
blob.upload_from_file(file_obj)

source_uri = 'gs://%s/%s' % (bucket_name, blob_name)

client = Config.CLIENT
image = client.image(source_uri=source_uri)
labels = image.detect_labels()
self.assertEqual(len(labels), 10)
for label in labels:
self._assert_label(label)

def test_detect_labels_filename(self):
client = Config.CLIENT
image = client.image(filename=LABEL_FILE)
labels = image.detect_labels()
self.assertEqual(len(labels), 10)
for label in labels:
self._assert_label(label)