Skip to content

Commit

Permalink
[cirqflow] Convenience method for loading results (quantumlib#4720)
Browse files Browse the repository at this point in the history
Add cg.ExecutableGroupResultFilesystemRecord.from_json(run_id)
  • Loading branch information
mpharrigan authored and MichaelBroughton committed Jan 22, 2022
1 parent af4016a commit bae11f5
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 0 deletions.
17 changes: 17 additions & 0 deletions cirq-google/cirq_google/workflow/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,23 @@ class ExecutableGroupResultFilesystemRecord:

run_id: str

@classmethod
def from_json(
cls, *, run_id: str, base_data_dir: str = "."
) -> 'ExecutableGroupResultFilesystemRecord':
fn = f'{base_data_dir}/{run_id}/ExecutableGroupResultFilesystemRecord.json.gz'
egr_record = cirq.read_json_gzip(fn)
if not isinstance(egr_record, cls):
raise ValueError(
f"The file located at {fn} is not an `ExecutableGroupFilesystemRecord`."
)
if egr_record.run_id != run_id:
raise ValueError(
f"The loaded run_id {run_id} does not match the provided run_id {run_id}"
)

return egr_record

def load(self, *, base_data_dir: str = ".") -> 'cg.ExecutableGroupResult':
"""Using the filename references in this dataclass, load a `cg.ExecutableGroupResult`
from its constituent parts.
Expand Down
46 changes: 46 additions & 0 deletions cirq-google/cirq_google/workflow/io_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import pytest

import cirq
import cirq_google as cg
Expand Down Expand Up @@ -47,6 +50,46 @@ def test_egr_filesystem_record_repr():
cg_assert_equivalent_repr(egr_fs_record)


def test_egr_filesystem_record_from_json(tmpdir):
run_id = 'my-run-id'
egr_fs_record = cg.ExecutableGroupResultFilesystemRecord(
runtime_configuration_path='RuntimeConfiguration.json.gz',
shared_runtime_info_path='SharedRuntimeInfo.jzon.gz',
executable_result_paths=[
'ExecutableResult.1.json.gz',
'ExecutableResult.2.json.gz',
],
run_id=run_id,
)

# Test 1: normal
os.makedirs(f'{tmpdir}/{run_id}')
cirq.to_json_gzip(
egr_fs_record, f'{tmpdir}/{run_id}/ExecutableGroupResultFilesystemRecord.json.gz'
)
egr_fs_record2 = cg.ExecutableGroupResultFilesystemRecord.from_json(
run_id=run_id, base_data_dir=tmpdir
)
assert egr_fs_record == egr_fs_record2

# Test 2: bad object type
cirq.to_json_gzip(
cirq.Circuit(), f'{tmpdir}/{run_id}/ExecutableGroupResultFilesystemRecord.json.gz'
)
with pytest.raises(ValueError, match=r'.*not an `ExecutableGroupFilesystemRecord`.'):
cg.ExecutableGroupResultFilesystemRecord.from_json(run_id=run_id, base_data_dir=tmpdir)

# Test 3: Mismatched run id
os.makedirs(f'{tmpdir}/questionable_run_id')
cirq.to_json_gzip(
egr_fs_record, f'{tmpdir}/questionable_run_id/ExecutableGroupResultFilesystemRecord.json.gz'
)
with pytest.raises(ValueError, match=r'.*does not match the provided run_id'):
cg.ExecutableGroupResultFilesystemRecord.from_json(
run_id='questionable_run_id', base_data_dir=tmpdir
)


def test_filesystem_saver(tmpdir, patch_cirq_default_resolvers):
assert patch_cirq_default_resolvers
run_id = 'asdf'
Expand All @@ -56,11 +99,13 @@ def test_filesystem_saver(tmpdir, patch_cirq_default_resolvers):
shared_rt_info = cg.SharedRuntimeInfo(run_id=run_id)
fs_saver.initialize(rt_config, shared_rt_info=shared_rt_info)

# Test 1: assert fs_saver.initialize() has worked.
rt_config2 = cirq.read_json_gzip(f'{tmpdir}/{run_id}/QuantumRuntimeConfiguration.json.gz')
shared_rt_info2 = cirq.read_json_gzip(f'{tmpdir}/{run_id}/SharedRuntimeInfo.json.gz')
assert rt_config == rt_config2
assert shared_rt_info == shared_rt_info2

# Test 2: assert `consume_result()` works.
# you shouldn't actually mutate run_id in the shared runtime info, but we want to test
# updating the shared rt info object:
shared_rt_info.run_id = 'updated_run_id'
Expand All @@ -76,6 +121,7 @@ def test_filesystem_saver(tmpdir, patch_cirq_default_resolvers):
assert shared_rt_info == shared_rt_info3
assert exe_result == exe_result3

# Test 3: assert loading egr_record works.
egr_record: cg.ExecutableGroupResultFilesystemRecord = cirq.read_json_gzip(
f'{fs_saver.data_dir}/ExecutableGroupResultFilesystemRecord.json.gz'
)
Expand Down
4 changes: 4 additions & 0 deletions cirq-google/cirq_google/workflow/quantum_runtime_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,14 @@ def test_execute(tmpdir, run_id_in, patch_cirq_default_resolvers):
f'{tmpdir}/{run_id}/ExecutableGroupResultFilesystemRecord.json.gz'
)
exegroup_result: cg.ExecutableGroupResult = egr_record.load(base_data_dir=tmpdir)
helper_loaded_result = cg.ExecutableGroupResultFilesystemRecord.from_json(
run_id=run_id, base_data_dir=tmpdir
).load(base_data_dir=tmpdir)

# TODO(gh-4699): Don't null-out device once it's serializable.
assert isinstance(returned_exegroup_result.shared_runtime_info.device, cg.SerializableDevice)
returned_exegroup_result.shared_runtime_info.device = None

assert returned_exegroup_result == exegroup_result
assert manual_exegroup_result == exegroup_result
assert helper_loaded_result == exegroup_result

0 comments on commit bae11f5

Please sign in to comment.