Skip to content

Commit

Permalink
python: added missing job tests
Browse files Browse the repository at this point in the history
Signed-off-by: Soumyendra Shrivastava <shrivastavasoumyendra@gmail.com>
  • Loading branch information
soumyendra98 committed Mar 23, 2024
1 parent 6b64ead commit bcd7dca
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 39 deletions.
33 changes: 32 additions & 1 deletion python/tests/integration/sdk/test_job_ops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#
# Copyright (c) 2018-2023, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2018-2024, NVIDIA CORPORATION. All rights reserved.
#
from datetime import datetime
import unittest

from tests.integration.sdk.remote_enabled_test import RemoteEnabledTest
Expand Down Expand Up @@ -51,3 +52,33 @@ def test_job_wait(self):
existing_obj = [entry.name for entry in self.bucket.list_all_objects()]
for name in object_names:
self.assertNotIn(name, existing_obj)

@unittest.skipIf(
not REMOTE_SET,
"Remote bucket is not set",
)
def test_job_wait_single_node(self):
obj_name = "test-obj"
_ = self._create_object_with_content(obj_name=obj_name)

evict_job_id = self.bucket.objects(obj_names=[obj_name]).evict()
self.client.job(evict_job_id).wait(timeout=TEST_TIMEOUT)

job_id = self.bucket.object(obj_name).blob_download()
self.assertNotEqual(job_id, "")
self.client.job(job_id=job_id).wait_single_node(timeout=TEST_TIMEOUT)

objects = self.bucket.list_objects(props="name,cached", prefix=obj_name).entries
self._validate_objects_cached(objects, True)

def test_get_within_timeframe(self):
start_time = datetime.now().time()
job_id = self.client.job(job_kind="lru").start()
self.client.job(job_id=job_id).wait()
end_time = datetime.now().time()
self.assertNotEqual(job_id, "")
jobs_list = self.client.job(job_id=job_id).get_within_timeframe(
start_time=start_time, end_time=end_time
)

self.assertTrue(len(jobs_list) > 0)
1 change: 0 additions & 1 deletion python/tests/integration/sdk/test_object_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,6 @@ def test_delete(self):
def test_blob_download(self):
obj_name = "obj-blob-download"
_ = self._create_object_with_content(obj_name=obj_name)
self._register_for_post_test_cleanup(names=[obj_name], is_bucket=False)

evict_job_id = self.bucket.objects(obj_names=[obj_name]).evict()
self.client.job(evict_job_id).wait(timeout=TEST_TIMEOUT)
Expand Down
111 changes: 74 additions & 37 deletions python/tests/unit/sdk/test_job.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
from datetime import datetime, timedelta
from typing import Dict, List
from unittest.mock import Mock, patch, call

Expand All @@ -12,14 +13,14 @@
ACT_START,
WHAT_QUERY_XACT_STATS,
)
from aistore.sdk.errors import Timeout
from aistore.sdk.errors import Timeout, JobInfoNotFound
from aistore.sdk.request_client import RequestClient
from aistore.sdk.types import JobStatus, JobArgs, BucketModel, ActionMsg, JobSnapshot
from aistore.sdk.utils import probing_frequency
from aistore.sdk.job import Job


# pylint: disable=unused-variable
# pylint: disable=unused-variable, too-many-public-methods
class TestJob(unittest.TestCase):
def setUp(self):
self.mock_client = Mock()
Expand Down Expand Up @@ -240,53 +241,89 @@ def job_start_exec_assert(self, job, expected_json, expected_params, **kwargs):
)

@patch("aistore.sdk.job.time.sleep", Mock())
@patch("aistore.sdk.job.Job._query_job_snapshots")
def test_wait_single_node_finishes_successfully(self, mock_query_snapshot):
finished_snapshot = [
JobSnapshot(
id=self.job_id,
is_idle=True,
end_time="2024-01-01T00:00:00Z",
aborted=False,
)
]
mock_query_snapshot.return_value = finished_snapshot
def test_wait_single_node_finishes_successfully(self):
finished_snapshot = {
"key": [
JobSnapshot(
id=self.job_id,
is_idle=True,
end_time="2024-01-01T00:00:00Z",
aborted=False,
)
]
}
self.mock_client.request_deserialize.return_value = finished_snapshot

self.job.wait_single_node()

mock_query_snapshot.assert_called()
self.assertEqual(mock_query_snapshot.call_count, 1)
self.mock_client.request_deserialize.assert_called()
self.assertEqual(self.mock_client.request_deserialize.call_count, 1)

@patch("aistore.sdk.job.time.sleep", Mock())
@patch("aistore.sdk.job.Job._query_job_snapshots")
def test_wait_single_node_is_aborted(self, mock_query_snapshot):
aborted_snapshot = [
JobSnapshot(
id=self.job_id,
is_idle=True,
end_time="2024-01-01T00:00:00Z",
aborted=True,
)
]
mock_query_snapshot.return_value = aborted_snapshot
def test_wait_single_node_is_aborted(self):
aborted_snapshot = {
"key": [
JobSnapshot(
id=self.job_id,
is_idle=True,
end_time="2024-01-01T00:00:00Z",
aborted=True,
)
]
}
self.mock_client.request_deserialize.return_value = aborted_snapshot

self.job.wait_single_node()
mock_query_snapshot.assert_called()
self.mock_client.request_deserialize.assert_called()

@patch("aistore.sdk.job.time.sleep", Mock())
@patch("aistore.sdk.job.Job._query_job_snapshots")
def test_wait_single_node_timeout(self, mock_query_snapshot):
ongoing_snapshot = [
def test_wait_single_node_timeout(self):
ongoing_snapshots = {
"key": [
JobSnapshot(
id=self.job_id,
is_idle=False,
end_time="0001-01-01T00:00:00Z",
aborted=False,
)
]
}
self.mock_client.request_deserialize.return_value = ongoing_snapshots

with self.assertRaises(Timeout):
self.job.wait_single_node()

self.mock_client.request_deserialize.assert_called()

def test_get_within_timeframe_found_jobs(self):
start_time = datetime.now() - timedelta(days=1)
end_time = datetime.now()

mock_snapshots = [
JobSnapshot(
id=self.job_id,
is_idle=False,
end_time="0001-01-01T00:00:00Z",
id="1234",
kind="test job",
start_time=(start_time.isoformat() + "Z"),
end_time=(end_time.isoformat() + "Z"),
aborted=False,
is_idle=True,
)
]
mock_query_snapshot.return_value = ongoing_snapshot

with self.assertRaises(Timeout):
self.job.wait_single_node()
self.mock_client.request_deserialize.return_value = {"key": mock_snapshots}

found_jobs = self.job.get_within_timeframe(start_time.time(), end_time.time())

self.assertEqual(len(found_jobs), len(mock_snapshots))
for found_job, expected_snapshot in zip(found_jobs, mock_snapshots):
self.assertEqual(found_job.id, expected_snapshot.id)
self.assertEqual(found_job.start_time, expected_snapshot.start_time)
self.assertEqual(found_job.end_time, expected_snapshot.end_time)

def test_get_within_timeframe_no_jobs_found(self):
start_time = datetime.now() - timedelta(days=1)
end_time = datetime.now()
self.mock_client.request_deserialize.return_value = {}

mock_query_snapshot.assert_called()
with self.assertRaises(JobInfoNotFound):
self.job.get_within_timeframe(start_time.time(), end_time.time())

0 comments on commit bcd7dca

Please sign in to comment.