Skip to content

Commit

Permalink
Make outputs go to correct cell when generated in threads/asyncio
Browse files Browse the repository at this point in the history
  • Loading branch information
krassowski committed Dec 21, 2023
1 parent 6d97970 commit 9e9c40e
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 36 deletions.
98 changes: 62 additions & 36 deletions ipykernel/iostream.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@

import asyncio
import atexit
import contextvars
import io
import os
import sys
import threading
import traceback
import warnings
from binascii import b2a_hex
from collections import deque
from collections import defaultdict, deque
from io import StringIO, TextIOBase
from threading import local
from typing import Any, Callable, Deque, Dict, Optional
Expand Down Expand Up @@ -412,7 +413,7 @@ def __init__(
name : str {'stderr', 'stdout'}
the name of the standard stream to replace
pipe : object
the pip object
the pipe object
echo : bool
whether to echo output
watchfd : bool (default, True)
Expand Down Expand Up @@ -446,13 +447,18 @@ def __init__(
self.pub_thread = pub_thread
self.name = name
self.topic = b"stream." + name.encode()
self.parent_header = {}
self._parent_header: contextvars.ContextVar[Dict[str, Any]] = contextvars.ContextVar(
"parent_header"
)
self._parent_header.set({})
self._thread_parents = {}
self._parent_header_global = {}
self._master_pid = os.getpid()
self._flush_pending = False
self._subprocess_flush_pending = False
self._io_loop = pub_thread.io_loop
self._buffer_lock = threading.RLock()
self._buffer = StringIO()
self._buffers = defaultdict(StringIO)
self.echo = None
self._isatty = bool(isatty)
self._should_watch = False
Expand Down Expand Up @@ -495,6 +501,24 @@ def __init__(
msg = "echo argument must be a file-like object"
raise ValueError(msg)

@property
def parent_header(self):
try:
# asyncio-specific
return self._parent_header.get()
except LookupError:
try:
# thread-specific
return self._thread_parents[threading.current_thread().ident]
except KeyError:
# global (fallback)
return self._parent_header_global

@parent_header.setter
def parent_header(self, value):
self._parent_header_global = value
return self._parent_header.set(value)

def isatty(self):
"""Return a bool indicating whether this is an 'interactive' stream.
Expand Down Expand Up @@ -598,28 +622,28 @@ def _flush(self):
if self.echo is not sys.__stderr__:
print(f"Flush failed: {e}", file=sys.__stderr__)

data = self._flush_buffer()
if data:
# FIXME: this disables Session's fork-safe check,
# since pub_thread is itself fork-safe.
# There should be a better way to do this.
self.session.pid = os.getpid()
content = {"name": self.name, "text": data}
msg = self.session.msg("stream", content, parent=self.parent_header)

# Each transform either returns a new
# message or None. If None is returned,
# the message has been 'used' and we return.
for hook in self._hooks:
msg = hook(msg)
if msg is None:
return

self.session.send(
self.pub_thread,
msg,
ident=self.topic,
)
for parent, data in self._flush_buffers():
if data:
# FIXME: this disables Session's fork-safe check,
# since pub_thread is itself fork-safe.
# There should be a better way to do this.
self.session.pid = os.getpid()
content = {"name": self.name, "text": data}
msg = self.session.msg("stream", content, parent=parent)

# Each transform either returns a new
# message or None. If None is returned,
# the message has been 'used' and we return.
for hook in self._hooks:
msg = hook(msg)
if msg is None:
return

self.session.send(
self.pub_thread,
msg,
ident=self.topic,
)

def write(self, string: str) -> Optional[int]: # type:ignore[override]
"""Write to current stream after encoding if necessary
Expand All @@ -630,6 +654,7 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override]
number of items from input parameter written to stream.
"""
parent = self.parent_header

if not isinstance(string, str):
msg = f"write() argument must be str, not {type(string)}" # type:ignore[unreachable]
Expand All @@ -649,7 +674,7 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override]
is_child = not self._is_master_process()
# only touch the buffer in the IO thread to avoid races
with self._buffer_lock:
self._buffer.write(string)
self._buffers[frozenset(parent.items())].write(string)
if is_child:
# mp.Pool cannot be trusted to flush promptly (or ever),
# and this helps.
Expand All @@ -675,19 +700,20 @@ def writable(self):
"""Test whether the stream is writable."""
return True

def _flush_buffer(self):
def _flush_buffers(self):
"""clear the current buffer and return the current buffer data."""
buf = self._rotate_buffer()
data = buf.getvalue()
buf.close()
return data
buffers = self._rotate_buffers()
for frozen_parent, buffer in buffers.items():
data = buffer.getvalue()
buffer.close()
yield dict(frozen_parent), data

def _rotate_buffer(self):
def _rotate_buffers(self):
"""Returns the current buffer and replaces it with an empty buffer."""
with self._buffer_lock:
old_buffer = self._buffer
self._buffer = StringIO()
return old_buffer
old_buffers = self._buffers
self._buffers = defaultdict(StringIO)
return old_buffers

@property
def _hooks(self):
Expand Down
70 changes: 70 additions & 0 deletions ipykernel/ipkernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import builtins
import gc
import getpass
import os
import signal
Expand All @@ -14,6 +15,7 @@
import comm
from IPython.core import release
from IPython.utils.tokenutil import line_at_cursor, token_at_cursor
from jupyter_client.session import extract_header
from traitlets import Any, Bool, HasTraits, Instance, List, Type, observe, observe_compat
from zmq.eventloop.zmqstream import ZMQStream

Expand All @@ -22,6 +24,7 @@
from .compiler import XCachingCompiler
from .debugger import Debugger, _is_debugpy_available
from .eventloops import _use_appnope
from .iostream import OutStream
from .kernelbase import Kernel as KernelBase
from .kernelbase import _accepts_parameters
from .zmqshell import ZMQInteractiveShell
Expand Down Expand Up @@ -66,6 +69,10 @@ def _get_comm_manager(*args, **kwargs):
comm.create_comm = _create_comm
comm.get_comm_manager = _get_comm_manager

import threading

threading_start = threading.Thread.start


class IPythonKernel(KernelBase):
"""The IPython Kernel class."""
Expand Down Expand Up @@ -151,6 +158,11 @@ def __init__(self, **kwargs):

appnope.nope()

if hasattr(gc, "callbacks"):
# while `gc.callbacks` exists since Python 3.3, pypy does not
# implement it even as of 3.9.
gc.callbacks.append(self._clean_thread_parent_frames)

help_links = List(
[
{
Expand Down Expand Up @@ -341,6 +353,12 @@ def set_sigint_result():
# restore the previous sigint handler
signal.signal(signal.SIGINT, save_sigint)

async def execute_request(self, stream, ident, parent):
"""Override for cell output - cell reconciliation."""
parent_header = extract_header(parent)
self._associate_identity_of_new_threads_with(parent_header)
await super().execute_request(stream, ident, parent)

async def do_execute(
self,
code,
Expand Down Expand Up @@ -706,6 +724,58 @@ def do_clear(self):
self.shell.reset(False)
return dict(status="ok")

def _associate_identity_of_new_threads_with(self, parent_header):
"""Intercept the identity of any thread started after this method finished,
and associate the thread's output with the parent header frame, which allows
to direct the outputs to the cell which started the thread.
This is a no-op if the `self._stdout` and `self._stderr` are not
sub-classes of `OutStream`.
"""
stdout = self._stdout
stderr = self._stderr

def start_closure(self: threading.Thread):
"""Wrap the `threading.Thread.start` to intercept thread identity.
This is needed because there is no "start" hook yet, but there
might be one in the future: https://bugs.python.org/issue14073
"""

threading_start(self)
for stream in [stdout, stderr]:
if isinstance(stream, OutStream):
stream._thread_parents[self.ident] = parent_header

threading.Thread.start = start_closure # type:ignore[method-assign]

def _clean_thread_parent_frames(
self, phase: t.Literal["start", "stop"], info: t.Dict[str, t.Any]
):
"""Clean parent frames of threads which are no longer running.
This is meant to be invoked by garbage collector callback hook.
The implementation enumerates the threads because there is no "exit" hook yet,
but there might be one in the future: https://bugs.python.org/issue14073
This is a no-op if the `self._stdout` and `self._stderr` are not
sub-classes of `OutStream`.
"""
# Only run before the garbage collector starts
if phase != "start":
return
active_threads = {thread.ident for thread in threading.enumerate()}
for stream in [self._stdout, self._stderr]:
if isinstance(stream, OutStream):
thread_parents = stream._thread_parents
for identity in list(thread_parents.keys()):
if identity not in active_threads:
try:
del thread_parents[identity]
except KeyError:
pass


# This exists only for backwards compatibility - use IPythonKernel instead

Expand Down
8 changes: 8 additions & 0 deletions ipykernel/kernelbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from ipykernel.jsonutil import json_clean

from ._version import kernel_protocol_version
from .iostream import OutStream


def _accepts_parameters(meth, param_names):
Expand Down Expand Up @@ -272,6 +273,13 @@ def _parent_header(self):
def __init__(self, **kwargs):
"""Initialize the kernel."""
super().__init__(**kwargs)

# Kernel application may swap stdout and stderr to OutStream,
# which is the case in `IPKernelApp.init_io`, hence `sys.stdout`
# can already by different from TextIO at initialization time.
self._stdout: OutStream | t.TextIO = sys.stdout
self._stderr: OutStream | t.TextIO = sys.stderr

# Build dict of handlers for message types
self.shell_handlers = {}
for msg_type in self.msg_types:
Expand Down

0 comments on commit 9e9c40e

Please sign in to comment.