Skip to content

Commit

Permalink
Address code review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
RaoulSchaffranek committed Nov 21, 2024
1 parent 70f1a68 commit f8fa54f
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 165 deletions.
66 changes: 30 additions & 36 deletions pyk/src/pyk/rpc/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
from dataclasses import dataclass
from functools import partial
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import TYPE_CHECKING, Any, Final
from typing import TYPE_CHECKING, Any, Final, NamedTuple

from typing_extensions import Protocol

from ..cli.cli import Options

if TYPE_CHECKING:
from collections.abc import Callable, Iterable
from collections.abc import Callable
from pathlib import Path


Expand Down Expand Up @@ -72,41 +72,16 @@ class JsonRpcMethod(Protocol):
def __call__(self, **kwargs: Any) -> Any: ...


@dataclass(frozen=True)
class JsonRpcRequest:

class JsonRpcRequest(NamedTuple):
id: str | int
method: str
params: Any
id: Any

@staticmethod
def validate(request_dict: Any, valid_methods: Iterable[str]) -> JsonRpcRequest | JsonRpcError:
required_fields = ['jsonrpc', 'method', 'id']
for field in required_fields:
if field not in request_dict:
return JsonRpcError(-32600, f'Invalid request: missing field "{field}"', request_dict.get('id', None))

jsonrpc_version = request_dict['jsonrpc']
if jsonrpc_version != JsonRpcServer.JSONRPC_VERSION:
return JsonRpcError(
-32600, f'Invalid request: bad version: "{jsonrpc_version}"', request_dict.get('id', None)
)

method_name = request_dict['method']
if method_name not in valid_methods:
return JsonRpcError(-32601, f'Method "{method_name}" not found.', request_dict.get('id', None))

return JsonRpcRequest(
method=request_dict['method'], params=request_dict.get('params', None), id=request_dict.get('id', None)
)


@dataclass(frozen=True)
class JsonRpcBatchRequest:
class JsonRpcBatchRequest(NamedTuple):
requests: tuple[JsonRpcRequest]


@dataclass(frozen=True)
class JsonRpcResult:

def encode(self) -> bytes:
Expand All @@ -118,7 +93,7 @@ class JsonRpcError(JsonRpcResult):

code: int
message: str
id: Any
id: str | int | None

def to_json(self) -> dict[str, Any]:
return {
Expand Down Expand Up @@ -204,13 +179,12 @@ def _batch_request(self, requests: list[dict[str, Any]]) -> JsonRpcBatchResult:
return JsonRpcBatchResult(tuple(self._single_request(request) for request in requests))

def _single_request(self, request: dict[str, Any]) -> JsonRpcError | JsonRpcSuccess:
validation_result = JsonRpcRequest.validate(request, self.methods.keys())
validation_result = self._validate_request(request)
if isinstance(validation_result, JsonRpcError):
return validation_result

method_name = request['method']
id, method_name, params = validation_result
method = self.methods[method_name]
params = validation_result.params
_LOGGER.info(f'Executing method {method_name}')
result: Any
if type(params) is dict:
Expand All @@ -220,6 +194,26 @@ def _single_request(self, request: dict[str, Any]) -> JsonRpcError | JsonRpcSucc
elif params is None:
result = method()
else:
return JsonRpcError(-32602, 'Unrecognized method parameter format.', validation_result.id)
return JsonRpcError(-32602, 'Unrecognized method parameter format.', id)
_LOGGER.debug(f'Got response {result}')
return JsonRpcSuccess(result, validation_result.id)
return JsonRpcSuccess(result, id)

def _validate_request(self, request_dict: Any) -> JsonRpcRequest | JsonRpcError:
required_fields = ['jsonrpc', 'method', 'id']
for field in required_fields:
if field not in request_dict:
return JsonRpcError(-32600, f'Invalid request: missing field "{field}"', request_dict.get('id', None))

jsonrpc_version = request_dict['jsonrpc']
if jsonrpc_version != JsonRpcServer.JSONRPC_VERSION:
return JsonRpcError(
-32600, f'Invalid request: bad version: "{jsonrpc_version}"', request_dict.get('id', None)
)

method_name = request_dict['method']
if method_name not in self.methods.keys():
return JsonRpcError(-32601, f'Method "{method_name}" not found.', request_dict.get('id', None))

return JsonRpcRequest(
method=request_dict['method'], params=request_dict.get('params', None), id=request_dict.get('id', None)
)
126 changes: 125 additions & 1 deletion pyk/src/tests/integration/test_json_rpc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

import json
from http.client import HTTPConnection
from threading import Thread
from time import sleep
from typing import TYPE_CHECKING

from pyk.cterm import CTerm
from pyk.kast.inner import KApply, KSequence, KSort, KToken
Expand All @@ -11,6 +14,9 @@
from pyk.rpc.rpc import JsonRpcServer, ServeRpcOptions
from pyk.testing import KRunTest

if TYPE_CHECKING:
from typing import Any


class StatefulKJsonRpcServer(JsonRpcServer):
krun: KRun
Expand Down Expand Up @@ -67,7 +73,7 @@ def exec_add(self) -> int:
return int(k_cell.token)


class TestJsonRPCServer(KRunTest):
class TestJsonKRPCServer(KRunTest):
KOMPILE_DEFINITION = """
module JSON-RPC-EXAMPLE-SYNTAX
imports INT-SYNTAX
Expand Down Expand Up @@ -133,3 +139,121 @@ def wait_until_ready() -> None:

server.shutdown()
thread.join()


class StatefulJsonRpcServer(JsonRpcServer):

x: int = 42
y: int = 43

def __init__(self, options: ServeRpcOptions) -> None:
super().__init__(options)

self.register_method('get_x', self.exec_get_x)
self.register_method('get_y', self.exec_get_y)
self.register_method('set_x', self.exec_set_x)
self.register_method('set_y', self.exec_set_y)
self.register_method('add', self.exec_add)

def exec_get_x(self) -> int:
return self.x

def exec_get_y(self) -> int:
return self.y

def exec_set_x(self, n: int) -> None:
self.x = n

def exec_set_y(self, n: int) -> None:
self.y = n

def exec_add(self) -> int:
return self.x + self.y


class TestJsonRPCServer(KRunTest):

def test_json_rpc_server(self) -> None:
server = StatefulJsonRpcServer(ServeRpcOptions({'port': 0}))

def run_server() -> None:
server.serve()

def wait_until_server_is_up() -> None:
while True:
try:
server.port()
return
except ValueError:
sleep(0.1)

thread = Thread(target=run_server)
thread.start()

wait_until_server_is_up()

http_client = HTTPConnection('localhost', server.port())
rpc_client = SimpleClient(http_client)

def wait_until_ready() -> None:
while True:
try:
rpc_client.request('get_x', [])
except ConnectionRefusedError:
sleep(0.1)
continue
break

wait_until_ready()

rpc_client.request('set_x', [123])
res = rpc_client.request('get_x')
assert res == 123

rpc_client.request('set_y', [456])
res = rpc_client.request('get_y')
assert res == 456

res = rpc_client.request('add', [])
assert res == (123 + 456)

res = rpc_client.batch_request(('set_x', [1]), ('set_y', [2]), ('add', []))
assert len(res) == 3
assert res[2]['result'] == 1 + 2

server.shutdown()
thread.join()


class SimpleClient:

client: HTTPConnection
_request_id: int = 0

def __init__(self, client: HTTPConnection) -> None:
self.client = client

def request_id(self) -> int:
self._request_id += 1
return self._request_id

def request(self, method: str, params: Any = None) -> Any:
body = json.dumps({'jsonrpc': '2.0', 'method': method, 'params': params, 'id': self.request_id()})

self.client.request('POST', '/', body)
response = self.client.getresponse()
result = json.loads(response.read())
return result['result']

def batch_request(self, *requests: tuple[str, Any]) -> list[Any]:
body = json.dumps(
[
{'jsonrpc': '2.0', 'method': method, 'params': params, 'id': self.request_id()}
for method, params in requests
]
)

self.client.request('POST', '/', body)
response = self.client.getresponse()
result = json.loads(response.read())
return result
128 changes: 0 additions & 128 deletions pyk/src/tests/unit/test_json_rpc.py

This file was deleted.

0 comments on commit f8fa54f

Please sign in to comment.