diff --git a/cirq-google/cirq_google/engine/stream_manager.py b/cirq-google/cirq_google/engine/stream_manager.py index b5bb5696eda..c45e43d81fc 100644 --- a/cirq-google/cirq_google/engine/stream_manager.py +++ b/cirq-google/cirq_google/engine/stream_manager.py @@ -109,8 +109,6 @@ class StreamManager: def __init__(self, grpc_client: quantum.QuantumEngineServiceAsyncClient): self._grpc_client = grpc_client - # TODO(#5996) Make this local to the asyncio thread. - self._request_queue: Optional[asyncio.Queue] = None # Used to determine whether the stream coroutine is actively running, and provides a way to # cancel it. self._manage_stream_loop_future: Optional[duet.AwaitableFuture[None]] = None @@ -121,6 +119,16 @@ def __init__(self, grpc_client: quantum.QuantumEngineServiceAsyncClient): # interface. self._response_demux = ResponseDemux() self._next_available_message_id = 0 + # Construct queue in AsyncioExecutor to ensure it binds to the correct event loop, since it + # is used by asyncio coroutines. + self._request_queue = self._executor.submit(self._make_request_queue).result() + + async def _make_request_queue(self) -> asyncio.Queue[Optional[quantum.QuantumRunStreamRequest]]: + """Returns a queue used to back the request iterator passed to the stream. + + If `None` is put into the queue, the request iterator will stop. + """ + return asyncio.Queue() def submit( self, project_name: str, program: quantum.QuantumProgram, job: quantum.QuantumJob @@ -153,8 +161,12 @@ def submit( raise ValueError('Program name must be set.') if self._manage_stream_loop_future is None or self._manage_stream_loop_future.done(): - self._manage_stream_loop_future = self._executor.submit(self._manage_stream) - return self._executor.submit(self._manage_execution, project_name, program, job) + self._manage_stream_loop_future = self._executor.submit( + self._manage_stream, self._request_queue + ) + return self._executor.submit( + self._manage_execution, self._request_queue, project_name, program, job + ) def stop(self) -> None: """Closes the open stream and resets all management resources.""" @@ -168,9 +180,9 @@ def stop(self) -> None: def _reset(self): """Resets the manager state.""" - self._request_queue = None self._manage_stream_loop_future = None self._response_demux = ResponseDemux() + self._request_queue = self._executor.submit(self._make_request_queue).result() @property def _executor(self) -> AsyncioExecutor: @@ -178,7 +190,9 @@ def _executor(self) -> AsyncioExecutor: # clients: https://github.com/grpc/grpc/issues/25364. return AsyncioExecutor.instance() - async def _manage_stream(self) -> None: + async def _manage_stream( + self, request_queue: asyncio.Queue[Optional[quantum.QuantumRunStreamRequest]] + ) -> None: """The stream coroutine, an asyncio coroutine to manage QuantumRunStream. This coroutine reads responses from the stream and forwards them to the ResponseDemux, where @@ -187,25 +201,32 @@ async def _manage_stream(self) -> None: When the stream breaks, the stream is reopened, and all execution coroutines are notified. There is at most a single instance of this coroutine running. + + Args: + request_queue: The queue holding requests from the execution coroutine. """ - self._request_queue = asyncio.Queue() while True: try: # The default gRPC client timeout is used. response_iterable = await self._grpc_client.quantum_run_stream( - _request_iterator(self._request_queue) + _request_iterator(request_queue) ) async for response in response_iterable: self._response_demux.publish(response) except asyncio.CancelledError: + await request_queue.put(None) break except BaseException as e: - # TODO(#5996) Close the request iterator to close the existing stream. # Note: the message ID counter is not reset upon a new stream. + await request_queue.put(None) self._response_demux.publish_exception(e) # Raise to all request tasks async def _manage_execution( - self, project_name: str, program: quantum.QuantumProgram, job: quantum.QuantumJob + self, + request_queue: asyncio.Queue[Optional[quantum.QuantumRunStreamRequest]], + project_name: str, + program: quantum.QuantumProgram, + job: quantum.QuantumJob, ) -> Union[quantum.QuantumResult, quantum.QuantumJob]: """The execution coroutine, an asyncio coroutine to manage the lifecycle of a job execution. @@ -216,8 +237,20 @@ async def _manage_execution( error by sending another request. The exact request type depends on the error. There is one execution coroutine per running job submission. + + Args: + request_queue: The queue used to send requests to the stream coroutine. + project_name: The full project ID resource path associated with the job. + program: The Quantum Engine program representing the circuit to be executed. + job: The Quantum Engine job to be executed. + + Raises: + concurrent.futures.CancelledError: if either the request is cancelled or the stream + coroutine is cancelled. + google.api_core.exceptions.GoogleAPICallError: if the stream breaks with a non-retryable + error. + ValueError: if the response is of a type which is not recognized by this client. """ - # Construct requests ahead of time to be reused for retries. create_program_and_job_request = quantum.QuantumRunStreamRequest( parent=project_name, create_quantum_program_and_job=quantum.CreateQuantumProgramAndJobRequest( @@ -225,19 +258,12 @@ async def _manage_execution( ), ) - while self._request_queue is None: - # Wait for the stream coroutine to start. - # Ignoring coverage since this is rarely triggered. - # TODO(#5996) Consider awaiting for the queue to become available, once it is changed - # to be local to the asyncio thread. - await asyncio.sleep(1) # pragma: no cover - current_request = create_program_and_job_request while True: try: current_request.message_id = self._generate_message_id() response_future = self._response_demux.subscribe(current_request.message_id) - await self._request_queue.put(current_request) + await request_queue.put(current_request) response = await response_future # Broken stream @@ -325,16 +351,15 @@ def _is_retryable_error(e: google_exceptions.GoogleAPICallError) -> bool: return any(isinstance(e, exception_type) for exception_type in RETRYABLE_GOOGLE_API_EXCEPTIONS) -# TODO(#5996) Add stop signal to the request iterator. async def _request_iterator( - request_queue: asyncio.Queue, + request_queue: asyncio.Queue[Optional[quantum.QuantumRunStreamRequest]], ) -> AsyncIterator[quantum.QuantumRunStreamRequest]: """The request iterator for Quantum Engine client RPC quantum_run_stream(). Every call to this method generates a new iterator. """ - while True: - yield await request_queue.get() + while request := await request_queue.get(): + yield request def _to_create_job_request( diff --git a/cirq-google/cirq_google/engine/stream_manager_test.py b/cirq-google/cirq_google/engine/stream_manager_test.py index 3732547cdca..42e6defbcc8 100644 --- a/cirq-google/cirq_google/engine/stream_manager_test.py +++ b/cirq-google/cirq_google/engine/stream_manager_test.py @@ -68,21 +68,26 @@ def setup(client_constructor): class FakeQuantumRunStream: """A fake Quantum Engine client which supports QuantumRunStream and CancelQuantumJob.""" + _REQUEST_STOPPED = 'REQUEST_STOPPED' + def __init__(self): self.all_stream_requests: List[quantum.QuantumRunStreamRequest] = [] self.all_cancel_requests: List[quantum.CancelQuantumJobRequest] = [] self._executor = AsyncioExecutor.instance() self._request_buffer = duet.AsyncCollector[quantum.QuantumRunStreamRequest]() + self._request_iterator_stopped = duet.AwaitableFuture() # asyncio.Queue needs to be initialized inside the asyncio thread because all callers need # to use the same event loop. - self._responses_and_exceptions_future = duet.AwaitableFuture[asyncio.Queue]() + self._responses_and_exceptions_future: duet.AwaitableFuture[ + asyncio.Queue[Union[quantum.QuantumRunStreamResponse, BaseException]] + ] = duet.AwaitableFuture() async def quantum_run_stream( self, requests: AsyncIterator[quantum.QuantumRunStreamRequest], **kwargs ) -> Awaitable[AsyncIterable[quantum.QuantumRunStreamResponse]]: """Fakes the QuantumRunStream RPC. - Once a request is received, it is appended to `stream_requests`, and the test calling + Once a request is received, it is appended to `all_stream_requests`, and the test calling `wait_for_requests()` is notified. The response is sent when a test calls `reply()` with a `QuantumRunStreamResponse`. If a @@ -91,25 +96,29 @@ async def quantum_run_stream( This is called from the asyncio thread. """ - responses_and_exceptions: asyncio.Queue = asyncio.Queue() + responses_and_exceptions: asyncio.Queue[ + Union[quantum.QuantumRunStreamResponse, BaseException] + ] = asyncio.Queue() self._responses_and_exceptions_future.try_set_result(responses_and_exceptions) async def read_requests(): async for request in requests: self.all_stream_requests.append(request) self._request_buffer.add(request) + await responses_and_exceptions.put(FakeQuantumRunStream._REQUEST_STOPPED) + self._request_iterator_stopped.try_set_result(None) async def response_iterator(): asyncio.create_task(read_requests()) - while True: - response_or_exception = await responses_and_exceptions.get() - if isinstance(response_or_exception, quantum.QuantumRunStreamResponse): - yield response_or_exception - else: # isinstance(response_or_exception, BaseException) - self._responses_and_exceptions_future = duet.AwaitableFuture[asyncio.Queue]() - raise response_or_exception + while ( + message := await responses_and_exceptions.get() + ) != FakeQuantumRunStream._REQUEST_STOPPED: + if isinstance(message, quantum.QuantumRunStreamResponse): + yield message + else: # isinstance(message, BaseException) + self._responses_and_exceptions_future = duet.AwaitableFuture() + raise message - await asyncio.sleep(0) return response_iterator() async def cancel_quantum_job(self, request: quantum.CancelQuantumJobRequest) -> None: @@ -158,6 +167,14 @@ async def send(): await self._executor.submit(send) + async def wait_for_request_iterator_stop(self): + """Wait for the request iterator to stop. + + This must be called from a duet thread. + """ + await self._request_iterator_stopped + self._request_iterator_stopped = duet.AwaitableFuture() + class TestResponseDemux: @pytest.fixture @@ -704,3 +721,91 @@ def test_get_retry_request_or_raise_expects_stream_error( create_quantum_program_and_job_request, create_quantum_job_request, ) + + @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) + def test_broken_stream_stops_request_iterator(self, client_constructor): + expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') + fake_client, manager = setup(client_constructor) + + async def test(): + async with duet.timeout_scope(5): + actual_result_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 + ) + await fake_client.wait_for_requests() + await fake_client.reply( + quantum.QuantumRunStreamResponse( + message_id=fake_client.all_stream_requests[0].message_id, + result=expected_result, + ) + ) + await actual_result_future + await fake_client.reply(google_exceptions.ServiceUnavailable('service unavailable')) + await fake_client.wait_for_request_iterator_stop() + manager.stop() + + duet.run(test) + + @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) + def test_stop_stops_request_iterator(self, client_constructor): + expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') + fake_client, manager = setup(client_constructor) + + async def test(): + async with duet.timeout_scope(5): + actual_result_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 + ) + await fake_client.wait_for_requests() + await fake_client.reply( + quantum.QuantumRunStreamResponse( + message_id=fake_client.all_stream_requests[0].message_id, + result=expected_result, + ) + ) + await actual_result_future + manager.stop() + await fake_client.wait_for_request_iterator_stop() + + duet.run(test) + + @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) + def test_submit_after_stream_breakage(self, client_constructor): + expected_result0 = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') + expected_result1 = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job1') + fake_client, manager = setup(client_constructor) + + async def test(): + async with duet.timeout_scope(5): + actual_result0_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 + ) + await fake_client.wait_for_requests() + await fake_client.reply( + quantum.QuantumRunStreamResponse( + message_id=fake_client.all_stream_requests[0].message_id, + result=expected_result0, + ) + ) + actual_result0 = await actual_result0_future + await fake_client.reply(google_exceptions.ServiceUnavailable('service unavailable')) + actual_result1_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 + ) + await fake_client.wait_for_requests() + await fake_client.reply( + quantum.QuantumRunStreamResponse( + message_id=fake_client.all_stream_requests[1].message_id, + result=expected_result1, + ) + ) + actual_result1 = await actual_result1_future + manager.stop() + + assert len(fake_client.all_stream_requests) == 2 + assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[0] + assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[1] + assert actual_result0 == expected_result0 + assert actual_result1 == expected_result1 + + duet.run(test)