diff --git a/cirq-google/cirq_google/workflow/io.py b/cirq-google/cirq_google/workflow/io.py index 7cab45988ea..1165748de42 100644 --- a/cirq-google/cirq_google/workflow/io.py +++ b/cirq-google/cirq_google/workflow/io.py @@ -44,6 +44,23 @@ class ExecutableGroupResultFilesystemRecord: run_id: str + @classmethod + def from_json( + cls, *, run_id: str, base_data_dir: str = "." + ) -> 'ExecutableGroupResultFilesystemRecord': + fn = f'{base_data_dir}/{run_id}/ExecutableGroupResultFilesystemRecord.json.gz' + egr_record = cirq.read_json_gzip(fn) + if not isinstance(egr_record, cls): + raise ValueError( + f"The file located at {fn} is not an `ExecutableGroupFilesystemRecord`." + ) + if egr_record.run_id != run_id: + raise ValueError( + f"The loaded run_id {run_id} does not match the provided run_id {run_id}" + ) + + return egr_record + def load(self, *, base_data_dir: str = ".") -> 'cg.ExecutableGroupResult': """Using the filename references in this dataclass, load a `cg.ExecutableGroupResult` from its constituent parts. diff --git a/cirq-google/cirq_google/workflow/io_test.py b/cirq-google/cirq_google/workflow/io_test.py index 63b53ce01c5..91fd9dec204 100644 --- a/cirq-google/cirq_google/workflow/io_test.py +++ b/cirq-google/cirq_google/workflow/io_test.py @@ -11,6 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os + +import pytest import cirq import cirq_google as cg @@ -47,6 +50,46 @@ def test_egr_filesystem_record_repr(): cg_assert_equivalent_repr(egr_fs_record) +def test_egr_filesystem_record_from_json(tmpdir): + run_id = 'my-run-id' + egr_fs_record = cg.ExecutableGroupResultFilesystemRecord( + runtime_configuration_path='RuntimeConfiguration.json.gz', + shared_runtime_info_path='SharedRuntimeInfo.jzon.gz', + executable_result_paths=[ + 'ExecutableResult.1.json.gz', + 'ExecutableResult.2.json.gz', + ], + run_id=run_id, + ) + + # Test 1: normal + os.makedirs(f'{tmpdir}/{run_id}') + cirq.to_json_gzip( + egr_fs_record, f'{tmpdir}/{run_id}/ExecutableGroupResultFilesystemRecord.json.gz' + ) + egr_fs_record2 = cg.ExecutableGroupResultFilesystemRecord.from_json( + run_id=run_id, base_data_dir=tmpdir + ) + assert egr_fs_record == egr_fs_record2 + + # Test 2: bad object type + cirq.to_json_gzip( + cirq.Circuit(), f'{tmpdir}/{run_id}/ExecutableGroupResultFilesystemRecord.json.gz' + ) + with pytest.raises(ValueError, match=r'.*not an `ExecutableGroupFilesystemRecord`.'): + cg.ExecutableGroupResultFilesystemRecord.from_json(run_id=run_id, base_data_dir=tmpdir) + + # Test 3: Mismatched run id + os.makedirs(f'{tmpdir}/questionable_run_id') + cirq.to_json_gzip( + egr_fs_record, f'{tmpdir}/questionable_run_id/ExecutableGroupResultFilesystemRecord.json.gz' + ) + with pytest.raises(ValueError, match=r'.*does not match the provided run_id'): + cg.ExecutableGroupResultFilesystemRecord.from_json( + run_id='questionable_run_id', base_data_dir=tmpdir + ) + + def test_filesystem_saver(tmpdir, patch_cirq_default_resolvers): assert patch_cirq_default_resolvers run_id = 'asdf' @@ -56,11 +99,13 @@ def test_filesystem_saver(tmpdir, patch_cirq_default_resolvers): shared_rt_info = cg.SharedRuntimeInfo(run_id=run_id) fs_saver.initialize(rt_config, shared_rt_info=shared_rt_info) + # Test 1: assert fs_saver.initialize() has worked. rt_config2 = cirq.read_json_gzip(f'{tmpdir}/{run_id}/QuantumRuntimeConfiguration.json.gz') shared_rt_info2 = cirq.read_json_gzip(f'{tmpdir}/{run_id}/SharedRuntimeInfo.json.gz') assert rt_config == rt_config2 assert shared_rt_info == shared_rt_info2 + # Test 2: assert `consume_result()` works. # you shouldn't actually mutate run_id in the shared runtime info, but we want to test # updating the shared rt info object: shared_rt_info.run_id = 'updated_run_id' @@ -76,6 +121,7 @@ def test_filesystem_saver(tmpdir, patch_cirq_default_resolvers): assert shared_rt_info == shared_rt_info3 assert exe_result == exe_result3 + # Test 3: assert loading egr_record works. egr_record: cg.ExecutableGroupResultFilesystemRecord = cirq.read_json_gzip( f'{fs_saver.data_dir}/ExecutableGroupResultFilesystemRecord.json.gz' ) diff --git a/cirq-google/cirq_google/workflow/quantum_runtime_test.py b/cirq-google/cirq_google/workflow/quantum_runtime_test.py index c3dcb610efc..a11da8210c8 100644 --- a/cirq-google/cirq_google/workflow/quantum_runtime_test.py +++ b/cirq-google/cirq_google/workflow/quantum_runtime_test.py @@ -182,6 +182,9 @@ def test_execute(tmpdir, run_id_in, patch_cirq_default_resolvers): f'{tmpdir}/{run_id}/ExecutableGroupResultFilesystemRecord.json.gz' ) exegroup_result: cg.ExecutableGroupResult = egr_record.load(base_data_dir=tmpdir) + helper_loaded_result = cg.ExecutableGroupResultFilesystemRecord.from_json( + run_id=run_id, base_data_dir=tmpdir + ).load(base_data_dir=tmpdir) # TODO(gh-4699): Don't null-out device once it's serializable. assert isinstance(returned_exegroup_result.shared_runtime_info.device, cg.SerializableDevice) @@ -189,3 +192,4 @@ def test_execute(tmpdir, run_id_in, patch_cirq_default_resolvers): assert returned_exegroup_result == exegroup_result assert manual_exegroup_result == exegroup_result + assert helper_loaded_result == exegroup_result