diff --git a/mypy/dmypy/client.py b/mypy/dmypy/client.py
index c3a2308d1b44..229740e44db0 100644
--- a/mypy/dmypy/client.py
+++ b/mypy/dmypy/client.py
@@ -17,7 +17,7 @@
 from typing import Any, Callable, Mapping, NoReturn
 
 from mypy.dmypy_os import alive, kill
-from mypy.dmypy_util import DEFAULT_STATUS_FILE, receive
+from mypy.dmypy_util import DEFAULT_STATUS_FILE, receive, send
 from mypy.ipc import IPCClient, IPCException
 from mypy.util import check_python_version, get_terminal_width, should_force_color
 from mypy.version import __version__
@@ -659,28 +659,29 @@ def request(
     # so that it can format the type checking output accordingly.
     args["is_tty"] = sys.stdout.isatty() or should_force_color()
     args["terminal_width"] = get_terminal_width()
-    bdata = json.dumps(args).encode("utf8")
     _, name = get_status(status_file)
     try:
         with IPCClient(name, timeout) as client:
-            client.write(bdata)
-            response = receive(client)
+            send(client, args)
+
+            final = False
+            while not final:
+                response = receive(client)
+                final = bool(response.pop("final", False))
+                # Display debugging output written to stdout/stderr in the server process for convenience.
+                # This should not be confused with "out" and "err" fields in the response.
+                # Those fields hold the output of the "check" command, and are handled in check_output().
+                stdout = response.pop("stdout", None)
+                if stdout:
+                    sys.stdout.write(stdout)
+                stderr = response.pop("stderr", None)
+                if stderr:
+                    sys.stderr.write(stderr)
     except (OSError, IPCException) as err:
         return {"error": str(err)}
     # TODO: Other errors, e.g. ValueError, UnicodeError
-    else:
-        # Display debugging output written to stdout/stderr in the server process for convenience.
-        # This should not be confused with "out" and "err" fields in the response.
-        # Those fields hold the output of the "check" command, and are handled in check_output().
-        stdout = response.get("stdout")
-        if stdout:
-            sys.stdout.write(stdout)
-        stderr = response.get("stderr")
-        if stderr:
-            print("-" * 79)
-            print("stderr:")
-            sys.stdout.write(stderr)
-        return response
+
+    return response
 
 
 def get_status(status_file: str) -> tuple[int, str]:
diff --git a/mypy/dmypy_server.py b/mypy/dmypy_server.py
index faa9a23fadfb..9cc0888fc208 100644
--- a/mypy/dmypy_server.py
+++ b/mypy/dmypy_server.py
@@ -23,7 +23,7 @@
 import mypy.build
 import mypy.errors
 import mypy.main
-from mypy.dmypy_util import receive
+from mypy.dmypy_util import WriteToConn, receive, send
 from mypy.find_sources import InvalidSourceList, create_source_list
 from mypy.fscache import FileSystemCache
 from mypy.fswatcher import FileData, FileSystemWatcher
@@ -208,10 +208,12 @@ def _response_metadata(self) -> dict[str, str]:
 
     def serve(self) -> None:
         """Serve requests, synchronously (no thread or fork)."""
+
         command = None
         server = IPCServer(CONNECTION_NAME, self.timeout)
         orig_stdout = sys.stdout
         orig_stderr = sys.stderr
+
         try:
             with open(self.status_file, "w") as f:
                 json.dump({"pid": os.getpid(), "connection_name": server.connection_name}, f)
@@ -219,10 +221,8 @@ def serve(self) -> None:
             while True:
                 with server:
                     data = receive(server)
-                    debug_stdout = io.StringIO()
-                    debug_stderr = io.StringIO()
-                    sys.stdout = debug_stdout
-                    sys.stderr = debug_stderr
+                    sys.stdout = WriteToConn(server, "stdout")  # type: ignore[assignment]
+                    sys.stderr = WriteToConn(server, "stderr")  # type: ignore[assignment]
                     resp: dict[str, Any] = {}
                     if "command" not in data:
                         resp = {"error": "No command found in request"}
@@ -239,15 +239,13 @@ def serve(self) -> None:
                                 tb = traceback.format_exception(*sys.exc_info())
                                 resp = {"error": "Daemon crashed!\n" + "".join(tb)}
                                 resp.update(self._response_metadata())
-                                resp["stdout"] = debug_stdout.getvalue()
-                                resp["stderr"] = debug_stderr.getvalue()
-                                server.write(json.dumps(resp).encode("utf8"))
+                                resp["final"] = True
+                                send(server, resp)
                                 raise
-                    resp["stdout"] = debug_stdout.getvalue()
-                    resp["stderr"] = debug_stderr.getvalue()
+                    resp["final"] = True
                     try:
                         resp.update(self._response_metadata())
-                        server.write(json.dumps(resp).encode("utf8"))
+                        send(server, resp)
                     except OSError:
                         pass  # Maybe the client hung up
                     if command == "stop":
diff --git a/mypy/dmypy_util.py b/mypy/dmypy_util.py
index 2aae41d998da..d95cba9f40b5 100644
--- a/mypy/dmypy_util.py
+++ b/mypy/dmypy_util.py
@@ -6,7 +6,7 @@
 from __future__ import annotations
 
 import json
-from typing import Any, Final
+from typing import Any, Final, Iterable
 
 from mypy.ipc import IPCBase
 
@@ -14,7 +14,7 @@
 
 
 def receive(connection: IPCBase) -> Any:
-    """Receive JSON data from a connection until EOF.
+    """Receive single JSON data frame from a connection.
 
     Raise OSError if the data received is not valid JSON or if it is
     not a dict.
@@ -23,9 +23,39 @@ def receive(connection: IPCBase) -> Any:
     if not bdata:
         raise OSError("No data received")
     try:
-        data = json.loads(bdata.decode("utf8"))
+        data = json.loads(bdata)
     except Exception as e:
         raise OSError("Data received is not valid JSON") from e
     if not isinstance(data, dict):
         raise OSError(f"Data received is not a dict ({type(data)})")
     return data
+
+
+def send(connection: IPCBase, data: Any) -> None:
+    """Send data to a connection encoded and framed.
+
+    The data must be JSON-serializable. We assume that a single send call is a
+    single frame to be sent on the connection.
+    """
+    connection.write(json.dumps(data))
+
+
+class WriteToConn:
+    """Helper class to write to a connection instead of standard output."""
+
+    def __init__(self, server: IPCBase, output_key: str = "stdout"):
+        self.server = server
+        self.output_key = output_key
+
+    def write(self, output: str) -> int:
+        resp: dict[str, Any] = {}
+        resp[self.output_key] = output
+        send(self.server, resp)
+        return len(output)
+
+    def writelines(self, lines: Iterable[str]) -> None:
+        for s in lines:
+            self.write(s)
+
+    def flush(self) -> None:
+        """Do nothing; write() sends each chunk right away, so nothing is buffered."""
diff --git a/mypy/ipc.py b/mypy/ipc.py
index d026f2429a0f..ab01f1b79e7d 100644
--- a/mypy/ipc.py
+++ b/mypy/ipc.py
@@ -7,6 +7,7 @@
 from __future__ import annotations
 
 import base64
+import codecs
 import os
 import shutil
 import sys
@@ -40,6 +41,10 @@ class IPCBase:
 
     This contains logic shared between the client and server, such as reading
     and writing.
+    We want to be able to send multiple "messages" over a single connection and
+    to be able to separate the messages. We do this by encoding the messages
+    in an alphabet that does not contain spaces, then adding a space for
+    separation. The last framed message is also followed by a space.
     """
 
     connection: _IPCHandle
@@ -47,12 +52,30 @@ class IPCBase:
     def __init__(self, name: str, timeout: float | None) -> None:
         self.name = name
         self.timeout = timeout
+        self.buffer = bytearray()
 
-    def read(self, size: int = 100000) -> bytes:
-        """Read bytes from an IPC connection until its empty."""
-        bdata = bytearray()
+    def frame_from_buffer(self) -> bytearray | None:
+        """Return a full frame from the bytes we have in the buffer."""
+        space_pos = self.buffer.find(b" ")
+        if space_pos == -1:
+            return None
+        # We have a full frame
+        bdata = self.buffer[:space_pos]
+        self.buffer = self.buffer[space_pos + 1 :]
+        return bdata
+
+    def read(self, size: int = 100000) -> str:
+        """Read from an IPC connection until we have a full frame and return it decoded."""
+        bdata: bytearray | None = bytearray()
         if sys.platform == "win32":
             while True:
+                # Check if we already have a message in the buffer before
+                # receiving any more data from the socket.
+                bdata = self.frame_from_buffer()
+                if bdata is not None:
+                    break
+
+                # Receive more data into the buffer.
                 ov, err = _winapi.ReadFile(self.connection, size, overlapped=True)
                 try:
                     if err == _winapi.ERROR_IO_PENDING:
@@ -66,7 +89,10 @@ def read(self, size: int = 100000) -> bytes:
                 _, err = ov.GetOverlappedResult(True)
                 more = ov.getbuffer()
                 if more:
-                    bdata.extend(more)
+                    self.buffer.extend(more)
+                    bdata = self.frame_from_buffer()
+                    if bdata is not None:
+                        break
                 if err == 0:
                     # we are done!
                     break
@@ -77,17 +103,34 @@ def read(self, size: int = 100000) -> bytes:
                     raise IPCException("ReadFile operation aborted.")
         else:
             while True:
+                # Check if we already have a message in the buffer before
+                # receiving any more data from the socket.
+                bdata = self.frame_from_buffer()
+                if bdata is not None:
+                    break
+
+                # Receive more data into the buffer.
                 more = self.connection.recv(size)
                 if not more:
+                    # Connection closed
                     break
-                bdata.extend(more)
-        return bytes(bdata)
+                self.buffer.extend(more)
+
+        if not bdata:
+            # Socket was empty and we didn't get any frame.
+            # This should only happen if the socket was closed.
+            return ""
+        return codecs.decode(bdata, "base64").decode("utf8")
+
+    def write(self, data: str) -> None:
+        """Write to an IPC connection."""
+
+        # Frame the data by base64-encoding it and terminating it with a space.
+        encoded_data = codecs.encode(data.encode("utf8"), "base64") + b" "
 
-    def write(self, data: bytes) -> None:
-        """Write bytes to an IPC connection."""
         if sys.platform == "win32":
             try:
-                ov, err = _winapi.WriteFile(self.connection, data, overlapped=True)
+                ov, err = _winapi.WriteFile(self.connection, encoded_data, overlapped=True)
                 try:
                     if err == _winapi.ERROR_IO_PENDING:
                         timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE
@@ -101,12 +144,11 @@ def write(self, data: bytes) -> None:
                     raise
                 bytes_written, err = ov.GetOverlappedResult(True)
                 assert err == 0, err
-                assert bytes_written == len(data)
+                assert bytes_written == len(encoded_data)
             except OSError as e:
                 raise IPCException(f"Failed to write with error: {e.winerror}") from e
         else:
-            self.connection.sendall(data)
-            self.connection.shutdown(socket.SHUT_WR)
+            self.connection.sendall(encoded_data)
 
     def close(self) -> None:
         if sys.platform == "win32":
diff --git a/mypy/test/testipc.py b/mypy/test/testipc.py
index 9034f514bb45..8ef656dc4579 100644
--- a/mypy/test/testipc.py
+++ b/mypy/test/testipc.py
@@ -15,14 +15,25 @@
 def server(msg: str, q: Queue[str]) -> None:
     server = IPCServer(CONNECTION_NAME)
     q.put(server.connection_name)
-    data = b""
+    data = ""
     while not data:
         with server:
-            server.write(msg.encode())
+            server.write(msg)
             data = server.read()
     server.cleanup()
 
 
+def server_multi_message_echo(q: Queue[str]) -> None:
+    server = IPCServer(CONNECTION_NAME)
+    q.put(server.connection_name)
+    data = ""
+    with server:
+        while data != "quit":
+            data = server.read()
+            server.write(data)
+    server.cleanup()
+
+
 class IPCTests(TestCase):
     def test_transaction_large(self) -> None:
         queue: Queue[str] = Queue()
@@ -31,8 +42,8 @@ def test_transaction_large(self) -> None:
         p.start()
         connection_name = queue.get()
         with IPCClient(connection_name, timeout=1) as client:
-            assert client.read() == msg.encode()
-            client.write(b"test")
+            assert client.read() == msg
+            client.write("test")
         queue.close()
         queue.join_thread()
         p.join()
@@ -44,12 +55,37 @@ def test_connect_twice(self) -> None:
         p.start()
         connection_name = queue.get()
         with IPCClient(connection_name, timeout=1) as client:
-            assert client.read() == msg.encode()
-            client.write(b"")  # don't let the server hang up yet, we want to connect again.
+            assert client.read() == msg
+            client.write("")  # don't let the server hang up yet, we want to connect again.
 
         with IPCClient(connection_name, timeout=1) as client:
-            assert client.read() == msg.encode()
-            client.write(b"test")
+            assert client.read() == msg
+            client.write("test")
+        queue.close()
+        queue.join_thread()
+        p.join()
+        assert p.exitcode == 0
+
+    def test_multiple_messages(self) -> None:
+        queue: Queue[str] = Queue()
+        p = Process(target=server_multi_message_echo, args=(queue,), daemon=True)
+        p.start()
+        connection_name = queue.get()
+        with IPCClient(connection_name, timeout=1) as client:
+            # "foo bar" with extra accents on letters.
+            # In UTF-8 encoding so we don't confuse editors opening this file.
+            fancy_text = b"f\xcc\xb6o\xcc\xb2\xf0\x9d\x91\x9c \xd0\xb2\xe2\xb7\xa1a\xcc\xb6r\xcc\x93\xcd\x98\xcd\x8c"
+            client.write(fancy_text.decode("utf-8"))
+            assert client.read() == fancy_text.decode("utf-8")
+
+            client.write("Test with spaces")
+            client.write("Test write before reading previous")
+            time.sleep(0)  # yield to the server to force reading of all messages by server.
+            assert client.read() == "Test with spaces"
+            assert client.read() == "Test write before reading previous"
+
+            client.write("quit")
+            assert client.read() == "quit"
         queue.close()
         queue.join_thread()
         p.join()