Skip to content

Commit

Permalink
Python RPC support
Browse files Browse the repository at this point in the history
  • Loading branch information
LasseBlaauwbroek committed Nov 9, 2023
1 parent 99b0c22 commit 120bf96
Show file tree
Hide file tree
Showing 8 changed files with 299 additions and 148 deletions.
7 changes: 4 additions & 3 deletions graph_api.capnp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ struct Graph {

label :union { # Inlined for efficiency purposes
# Proof state
# Hash a unique id (evar) for the proof state that distinquishes proof states with identical
# Has a unique id (evar) for the proof state that distinquishes proof states with identical
# contents but do not point to the same object nonetheless
proofState @0 :ProofStateIdP;

Expand Down Expand Up @@ -620,8 +620,9 @@ interface ProofObject {
interface PredictionServer {
addGlobalContext @0 GlobalContextAddition -> ();
requestPrediction @1 PredictionRequest -> (predictions :List(Prediction));
checkAlignment @2 () -> (unalignedTactics :List(TacticId), unalignedDefinitions :List(Node));
explore @3 (result :ExecutionResult);
requestTextPrediction @2 PredictionRequest -> (predictions :List(TextPrediction));
checkAlignment @3 () -> (unalignedTactics :List(TacticId), unalignedDefinitions :List(Node));
explore @4 (result :ExecutionResult);
# An interface allowing a proof exploration session to be initiated by Coq. In this case, Coq decides
# what lemma should be proved and immediately presents the agent with the initial execution result.
}
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ requires = [
"setuptools",
"wheel",
"Cython>=3.0.3",
"pycapnp>=2.0.0b1",
"pycapnp @ git+ssh://git@github.com/capnproto/pycapnp.git",
"Jinja2",
"inflection"]
build-backend = "setuptools.build_meta"
Expand Down Expand Up @@ -36,7 +36,7 @@ classifiers = [
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
dependencies = [
"pycapnp>=2.0.0b1",
"pycapnp @ git+ssh://git@github.com/capnproto/pycapnp.git",
"immutables",
"graphviz",
"sanic==23.6.0",
Expand Down
232 changes: 178 additions & 54 deletions pytact/data_reader.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ This trace can later be replayed using
"""

from __future__ import annotations
from contextlib import contextmanager, ExitStack
from contextlib import contextmanager, asynccontextmanager, ExitStack
from dataclasses import dataclass
from typing import Any, Callable, TypeVar, Union, cast, BinaryIO
from collections.abc import Iterable, Sequence, Generator, AsyncGenerator
Expand All @@ -111,6 +111,9 @@ import subprocess
import shutil
import time
import sys
from functools import partial
import asyncio
from asyncio import CancelledError

T = TypeVar('T')
class TupleLike():
Expand Down Expand Up @@ -1451,6 +1454,15 @@ class GlobalContextMessage:
extend the global context of the current message. Once the sub-generator
runs out, the parent generator continues."""

redirect_exceptions : Callable[[BaseException,...], Generator[None, None, None]]
"""A contextmanager used to catch and redirect the specified exceptions back to Coq.

In addition to the specified exceptions, this manager also takes care of
`CancelledError`s thrown as a result of Coq cancelling a request. Instead of terminating
the entire program, such a cancellation should only stop the iteration of the prediction
loop. Otherwise, the loop should continue to respond to requests from Coq.
"""

def _convert_predictions(preds, stack_size):
if isinstance(preds, TacticPredictionsText):
preds = [{'tacticText': pred.tactic_text,
Expand All @@ -1475,15 +1487,12 @@ async def capnp_message_generator_lowlevel(stream: capnp.AsyncIoStream) -> (
`pytact.graph_api_capnp_cython.PredictionProtocol_Request_Reader` after which a
`capnp.lib.capnp._DynamicStructBuilder` message needs to be `send` back.
"""
msg = await graph_api_capnp.PredictionProtocol.Request.read_async(
stream, traversal_limit_in_words=2**64-1)
while msg is not None:
while (msg := await graph_api_capnp.PredictionProtocol.Request.read_async(
stream, traversal_limit_in_words=2**64-1)) is not None:
cython_msg = PredictionProtocol_Request_Reader(msg)
response = yield cython_msg
await response.write_async(stream)
yield
msg = await graph_api_capnp.PredictionProtocol.Request.read_async(
stream, traversal_limit_in_words=2**64-1)

async def capnp_message_generator_from_file_lowlevel(
message_file: BinaryIO,
Expand Down Expand Up @@ -1561,55 +1570,165 @@ if sys.version_info.major == 3 and sys.version_info.minor < 10:
class _MutableBox:
contents: Any

async def prediction_generator(lgenerator, OnlineDefinitionsReader defs, mutret):
async def prediction_generator(lgenerator, OnlineDefinitionsReader defs, mutret, redirect_exceptions):
"""Given the current global context stack `defs`, convert a low-level
generator to a high-level `GlobalContextMessage`"""
msg = await anext(lgenerator, None)
while msg is not None:
if msg.is_initialize:
init = msg.initialize
their_version = init.data_version
our_version = graph_api_capnp.currentVersion
if their_version.major != our_version.major or their_version.minor != our_version.minor:
raise ValueError(
f"This library is compiled for a dataset containing data versioned as "
f"{graph_api_capnp.currentVersion} but file Coq sent a message versioned as "
f"{init.data_version.dynamic}.")
if init.stack_size != defs.graph_index.nodes.size():
mutret.contents = msg
return
else:
response = graph_api_capnp.PredictionProtocol.Response.new_message(initialized=None)
try:
if msg.is_initialize:
init = msg.initialize
their_version = init.data_version
our_version = graph_api_capnp.currentVersion
if their_version.major != our_version.major or their_version.minor != our_version.minor:
raise ValueError(
f"This library is compiled for a dataset containing data versioned as "
f"{graph_api_capnp.currentVersion} but file Coq sent a message versioned as "
f"{init.data_version.dynamic}.")
if init.stack_size != defs.graph_index.nodes.size():
mutret.contents = msg
return
else:
response = graph_api_capnp.PredictionProtocol.Response.new_message(initialized=None)
await lgenerator.asend(response)
with online_definitions_initialize(defs, init) as definitions:
msgm = _MutableBox(None)
pg = prediction_generator(lgenerator, definitions, msgm, redirect_exceptions)
yield GlobalContextMessage(definitions, init.tactics, init.log_annotation, pg,
partial(redirect_exceptions, pg))
if await anext(pg, None) is not None:
raise Exception("Not all prediction requests were consumed")
msg = msgm.contents
elif msg.is_predict:
with online_data_predict(defs, msg.predict) as proof_state:
preds = yield proof_state
response = _convert_predictions(preds, defs.graph_index.nodes.size())
await lgenerator.asend(response)
with online_definitions_initialize(defs, init) as definitions:
msgm = _MutableBox(None)
pg = prediction_generator(lgenerator, definitions, msgm)
yield GlobalContextMessage(definitions, init.tactics, init.log_annotation, pg)
if await anext(pg, None) is not None:
raise Exception("Not all prediction requests were consumed")
msg = msgm.contents
elif msg.is_predict:
with online_data_predict(defs, msg.predict) as proof_state:
preds = yield proof_state
response = _convert_predictions(preds, defs.graph_index.nodes.size())
await lgenerator.asend(response)
yield
msg = await anext(lgenerator, None)
elif msg.is_check_alignment:
alignment = yield CheckAlignmentMessage()
alignment = {'unalignedTactics': alignment.unknown_tactics,
'unalignedDefinitions':
[{'depIndex': defs.graph_index.nodes.size() - 1 - d.node.graph, 'nodeIndex': d.node.nodeid}
for d in alignment.unknown_definitions]}
response = graph_api_capnp.PredictionProtocol.Response.new_message(alignment=alignment)
await lgenerator.asend(response)
yield
msg = await anext(lgenerator, None)
elif msg.is_check_alignment:
alignment = yield CheckAlignmentMessage()
alignment = {'unalignedTactics': alignment.unknown_tactics,
'unalignedDefinitions':
[{'depIndex': defs.graph_index.nodes.size() - 1 - d.node.graph, 'nodeIndex': d.node.nodeid}
for d in alignment.unknown_definitions]}
response = graph_api_capnp.PredictionProtocol.Response.new_message(alignment=alignment)
await lgenerator.asend(response)
yield
msg = await anext(lgenerator, None)
else:
raise Exception(f"Capnp protocol error: Received unknown message type {type(msg)}")
except (Exception, CancelledError) as e:
if isinstance(e, CancelledError) and "ClientCancellation" not in str(e):
raise
await lgenerator.athrow(e)
yield
msg = await anext(lgenerator, None)

@asynccontextmanager
async def redirect_exceptions(gen, *excs: BaseException) -> AsyncGenerator[None, None, None]:
try:
yield
except excs as e:
await gen.athrow(e)
except CancelledError as e:
if "ClientCancellation" in str(e):
task = asyncio.current_task()
if hasattr(task, "uncancel"): # Uncancel was introduced in Python 3.11
task.uncancel()
await gen.athrow(e)
else:
raise Exception("Capnp protocol error")
raise

@asynccontextmanager
async def fake_redirect_exceptions(gen, *execs: BaseException) -> AsyncGenerator[None, None, None]:
yield

class _Server2Generator(graph_api_capnp.PredictionServer.Server):

def __init__(self):
self.request_event = asyncio.Event()
self.response_event = asyncio.Event()
self.lock = asyncio.Lock()
self.main_task = asyncio.current_task()

def disconnected(self):
self.message = None
self.request_event.set()

async def get_request(self):
await self.request_event.wait()
request = self.message
self.request_event.clear()
return request

def put_response(self, response):
self.message = response
self.response_event.set()

async def _communicate(self, request):
async with self.lock:
self.message = request
self.request_event.set()
try:
await self.response_event.wait()
except asyncio.CancelledError:
self.request_event.clear()
self.main_task.cancel("ClientCancellation")
await self.response_event.wait()
response = self.message
self.response_event.clear()
if isinstance(response, BaseException):
raise response
raise
else:
response = self.message
self.response_event.clear()
if isinstance(response, BaseException):
raise response
return response

async def addGlobalContext_context(self, _context):
await self._communicate(apic.PredictionProtocol_Request_Reader(
graph_api_capnp.PredictionProtocol.Request.new_message(initialize=_context.params).as_reader()))

async def requestPrediction_context(self, _context):
response = await self._communicate(apic.PredictionProtocol_Request_Reader(
graph_api_capnp.PredictionProtocol.Request.new_message(predict=_context.params).as_reader()))
_context.results.predictions = response.prediction # TODO: Avoid copying operation

async def requestTextPrediction_context(self, _context):
response = await self._communicate(apic.PredictionProtocol_Request_Reader(
graph_api_capnp.PredictionProtocol.Request.new_message(predict=_context.params).as_reader()))
_context.results.predictions = response.textPrediction # TODO: Avoid copying operation

async def checkAlignment_context(self, _context):
response = await self._communicate(apic.PredictionProtocol_Request_Reader(
graph_api_capnp.PredictionProtocol.Request.new_message(checkAlignment=None).as_reader()))
_context.results.unalignedTactics = response.alignment.unalignedTactics # TODO: Avoid copying
_context.results.unalignedDefinitions = response.alignment.unalignedDefinitions # TODO: Avoid copying

async def capnp_rpc_message_generator_lowlevel(stream):
server = _Server2Generator()
async def server_task():
await capnp.TwoPartyServer(stream, bootstrap=server).on_disconnect()
server.disconnected()
task = asyncio.ensure_future(server_task())
while (request := await server.get_request()) is not None:
try:
response = yield request
server.put_response(response)
yield
except (Exception, CancelledError) as e:
if isinstance(e, CancelledError) and "ClientCancellation" not in str(e):
raise
server.put_response(e)
yield
await task

async def capnp_message_generator(socket: capnp.AsyncIoStream,
record: BinaryIO | None = None) -> GlobalContextMessage:
def capnp_message_generator(stream: capnp.AsyncIoStream,
rpc : bool = False,
record: BinaryIO | None = None) -> GlobalContextMessage:
"""A generator that facilitates communication between a prediction server and a Coq process.

Given a `socket`, this function creates a `GlobalContextMessage` `context`. This message contains an
Expand All @@ -1628,16 +1747,21 @@ async def capnp_message_generator(socket: capnp.AsyncIoStream,
When `record` is passed a file descriptor, all received and sent messages will be dumped into that file
descriptor. These messages can then be replayed later using `capnp_message_generator_from_file`.
"""
lgenerator = capnp_message_generator_lowlevel(socket)
if rpc:
lgenerator = capnp_rpc_message_generator_lowlevel(stream)
redirect = redirect_exceptions
else:
lgenerator = capnp_message_generator_lowlevel(stream)
redirect = fake_redirect_exceptions
if record is not None:
lgenerator = record_lowlevel_generator(record, lgenerator)
defs = OnlineDefinitionsReader.init_empty()
pg = prediction_generator(lgenerator, defs, _MutableBox(None))
return GlobalContextMessage(defs, [], None, pg)
pg = prediction_generator(lgenerator, defs, _MutableBox(None), redirect)
return GlobalContextMessage(defs, [], None, pg, partial(redirect, pg))

async def capnp_message_generator_from_file(message_file: BinaryIO,
check : Callable[[Any, Any, Any], None] | None = None,
record: BinaryIO | None = None) -> GlobalContextMessage:
def capnp_message_generator_from_file(message_file: BinaryIO,
check : Callable[[Any, Any, Any], None] | None = None,
record: BinaryIO | None = None) -> GlobalContextMessage:
"""Replay and verify a pre-recorded communication sequence between Coq and a prediction server.

Highlevel variant of `capnp_message_generator_from_file_lowlevel`.
Expand All @@ -1657,8 +1781,8 @@ async def capnp_message_generator_from_file(message_file: BinaryIO,
if record is not None:
lgenerator = record_lowlevel_generator(record, lgenerator)
defs = OnlineDefinitionsReader.init_empty()
pg = prediction_generator(lgenerator, defs, _MutableBox(None))
return GlobalContextMessage(defs, [], None, pg)
pg = prediction_generator(lgenerator, defs, _MutableBox(None), fake_redirect_exceptions)
return GlobalContextMessage(defs, [], None, pg, partial(fake_redirect_exceptions, pg))


@contextmanager
Expand Down
Loading

0 comments on commit 120bf96

Please sign in to comment.