Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make outputs go to correct cell when generated in threads/asyncio #1186

Merged
merged 6 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 62 additions & 36 deletions ipykernel/iostream.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@

import asyncio
import atexit
import contextvars
import io
import os
import sys
import threading
import traceback
import warnings
from binascii import b2a_hex
from collections import deque
from collections import defaultdict, deque
from io import StringIO, TextIOBase
from threading import local
from typing import Any, Callable, Deque, Dict, Optional
Expand Down Expand Up @@ -412,7 +413,7 @@ def __init__(
name : str {'stderr', 'stdout'}
the name of the standard stream to replace
pipe : object
the pip object
the pipe object
echo : bool
whether to echo output
watchfd : bool (default, True)
Expand Down Expand Up @@ -446,13 +447,18 @@ def __init__(
self.pub_thread = pub_thread
self.name = name
self.topic = b"stream." + name.encode()
self.parent_header = {}
self._parent_header: contextvars.ContextVar[Dict[str, Any]] = contextvars.ContextVar(
"parent_header"
)
self._parent_header.set({})
self._thread_parents = {}
self._parent_header_global = {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious why the ContextVar and the _thread_parents dict? Is a threading.local not more standard?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because contextvar works well for asyncio edge cases. There is some more explanation in the PEP 567, but the gist is:

Thread-local variables are insufficient for asynchronous tasks that execute concurrently in the same OS thread. Any context manager that saves and restores a context value using threading.local() will have its context values bleed to other code unexpectedly when used in async/await code.

That said, I will take another look at using threading.local instead of _thread_parents.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand that we need ContextVar now, because also when we create new tasks we want to output to the right output cell, didn't think of that!

I think if we use ContextVar, we do not need the thread local storage, it's a superset of threading.local. In combination with overriding the Thread ctor, we can drop _thread_parents.

self._master_pid = os.getpid()
self._flush_pending = False
self._subprocess_flush_pending = False
self._io_loop = pub_thread.io_loop
self._buffer_lock = threading.RLock()
self._buffer = StringIO()
self._buffers = defaultdict(StringIO)
self.echo = None
self._isatty = bool(isatty)
self._should_watch = False
Expand Down Expand Up @@ -495,6 +501,24 @@ def __init__(
msg = "echo argument must be a file-like object"
raise ValueError(msg)

@property
def parent_header(self):
    """Return the parent header routing output for the current context.

    Resolution order:
    1. asyncio-specific: the ``ContextVar`` value, so tasks direct output
       to the cell which scheduled them;
    2. thread-specific: the header registered for this thread's ident;
    3. global fallback set by the most recent cell execution.
    """
    missing = object()
    header = self._parent_header.get(missing)
    if header is not missing:
        return header
    # `.get` with a sentinel (rather than indexing) tolerates concurrent
    # removal of entries from `_thread_parents` by the cleanup callback.
    header = self._thread_parents.get(threading.current_thread().ident, missing)
    if header is not missing:
        return header
    # global (fallback)
    return self._parent_header_global

@parent_header.setter
def parent_header(self, value):
    """Record `value` both globally and in the current asyncio context."""
    self._parent_header_global = value
    return self._parent_header.set(value)

def isatty(self):
"""Return a bool indicating whether this is an 'interactive' stream.

Expand Down Expand Up @@ -598,28 +622,28 @@ def _flush(self):
if self.echo is not sys.__stderr__:
print(f"Flush failed: {e}", file=sys.__stderr__)

data = self._flush_buffer()
if data:
# FIXME: this disables Session's fork-safe check,
# since pub_thread is itself fork-safe.
# There should be a better way to do this.
self.session.pid = os.getpid()
content = {"name": self.name, "text": data}
msg = self.session.msg("stream", content, parent=self.parent_header)

# Each transform either returns a new
# message or None. If None is returned,
# the message has been 'used' and we return.
for hook in self._hooks:
msg = hook(msg)
if msg is None:
return

self.session.send(
self.pub_thread,
msg,
ident=self.topic,
)
for parent, data in self._flush_buffers():
if data:
# FIXME: this disables Session's fork-safe check,
# since pub_thread is itself fork-safe.
# There should be a better way to do this.
self.session.pid = os.getpid()
content = {"name": self.name, "text": data}
msg = self.session.msg("stream", content, parent=parent)

# Each transform either returns a new
# message or None. If None is returned,
# the message has been 'used' and we return.
for hook in self._hooks:
msg = hook(msg)
if msg is None:
return

self.session.send(
self.pub_thread,
msg,
ident=self.topic,
)

def write(self, string: str) -> Optional[int]: # type:ignore[override]
"""Write to current stream after encoding if necessary
Expand All @@ -630,6 +654,7 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override]
number of items from input parameter written to stream.

"""
parent = self.parent_header

if not isinstance(string, str):
msg = f"write() argument must be str, not {type(string)}" # type:ignore[unreachable]
Expand All @@ -649,7 +674,7 @@ def write(self, string: str) -> Optional[int]: # type:ignore[override]
is_child = not self._is_master_process()
# only touch the buffer in the IO thread to avoid races
with self._buffer_lock:
self._buffer.write(string)
self._buffers[frozenset(parent.items())].write(string)
if is_child:
# mp.Pool cannot be trusted to flush promptly (or ever),
# and this helps.
Expand All @@ -675,19 +700,20 @@ def writable(self):
"""Test whether the stream is writable."""
return True

def _flush_buffer(self):
def _flush_buffers(self):
"""clear the current buffer and return the current buffer data."""
buf = self._rotate_buffer()
data = buf.getvalue()
buf.close()
return data
buffers = self._rotate_buffers()
for frozen_parent, buffer in buffers.items():
data = buffer.getvalue()
buffer.close()
yield dict(frozen_parent), data

def _rotate_buffer(self):
def _rotate_buffers(self):
"""Returns the current buffer and replaces it with an empty buffer."""
with self._buffer_lock:
old_buffer = self._buffer
self._buffer = StringIO()
return old_buffer
old_buffers = self._buffers
self._buffers = defaultdict(StringIO)
return old_buffers

@property
def _hooks(self):
Expand Down
70 changes: 70 additions & 0 deletions ipykernel/ipkernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import builtins
import gc
import getpass
import os
import signal
Expand All @@ -14,6 +15,7 @@
import comm
from IPython.core import release
from IPython.utils.tokenutil import line_at_cursor, token_at_cursor
from jupyter_client.session import extract_header
from traitlets import Any, Bool, HasTraits, Instance, List, Type, observe, observe_compat
from zmq.eventloop.zmqstream import ZMQStream

Expand All @@ -22,6 +24,7 @@
from .compiler import XCachingCompiler
from .debugger import Debugger, _is_debugpy_available
from .eventloops import _use_appnope
from .iostream import OutStream
from .kernelbase import Kernel as KernelBase
from .kernelbase import _accepts_parameters
from .zmqshell import ZMQInteractiveShell
Expand Down Expand Up @@ -66,6 +69,10 @@ def _get_comm_manager(*args, **kwargs):
comm.create_comm = _create_comm
comm.get_comm_manager = _get_comm_manager

import threading

# Keep a reference to the original, unpatched `threading.Thread.start`:
# `_associate_identity_of_new_threads_with` replaces `threading.Thread.start`
# with a wrapper, and that wrapper must delegate to this saved original to
# actually start the thread.
threading_start = threading.Thread.start


class IPythonKernel(KernelBase):
"""The IPython Kernel class."""
Expand Down Expand Up @@ -151,6 +158,11 @@ def __init__(self, **kwargs):

appnope.nope()

if hasattr(gc, "callbacks"):
# while `gc.callbacks` exists since Python 3.3, pypy does not
# implement it even as of 3.9.
gc.callbacks.append(self._clean_thread_parent_frames)

help_links = List(
[
{
Expand Down Expand Up @@ -341,6 +353,12 @@ def set_sigint_result():
# restore the previous sigint handler
signal.signal(signal.SIGINT, save_sigint)

async def execute_request(self, stream, ident, parent):
    """Override for cell output - cell reconciliation.

    Records this request's parent header before delegating to the base
    implementation, so that output from threads started while the cell
    executes is routed back to the cell that started them (via
    `_associate_identity_of_new_threads_with`).
    """
    parent_header = extract_header(parent)
    self._associate_identity_of_new_threads_with(parent_header)
    await super().execute_request(stream, ident, parent)

async def do_execute(
self,
code,
Expand Down Expand Up @@ -706,6 +724,58 @@ def do_clear(self):
self.shell.reset(False)
return dict(status="ok")

def _associate_identity_of_new_threads_with(self, parent_header):
    """Attribute output of threads started from now on to `parent_header`.

    Replaces `threading.Thread.start` with a wrapper that records the
    identity of every thread started after this call, which allows the
    streams to direct the thread's output to the cell which started it.

    This is a no-op if the `self._stdout` and `self._stderr` are not
    sub-classes of `OutStream`.
    """
    streams = (self._stdout, self._stderr)

    def _patched_start(thread: threading.Thread):
        """Replacement `threading.Thread.start` intercepting thread identity.

        This is needed because there is no "start" hook yet, but there
        might be one in the future: https://bugs.python.org/issue14073
        """
        threading_start(thread)
        for stream in streams:
            if isinstance(stream, OutStream):
                stream._thread_parents[thread.ident] = parent_header

    threading.Thread.start = _patched_start  # type:ignore[method-assign]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wondering if there could be a way to not monkey-patch threading.Thread.start.
In akernel it's the print function that is monkey-patched, so this PR goes further because it works for lower-level writes to stdout/stderr, but with monkey-patching there is always a risk that another library changes the same object, and we cannot know that our change will be the last applied.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding as per https://bugs.python.org/issue14073 is that there is no other way, but we could mention our use case as another situation motivating introduction of start/exit hooks/callbacks to threads (which is something Python committers have considered in the past but I presume it was not a sufficiently high priority until now).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We crossed 'comments', see my other thread on this exact line for (what I think is) a better solution.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it is thread-safe. If a thread starts a new thread, that new thread will start outputting in the last executed cell.
I've taken a different approach in Solara, where I patch the ctor and run from the start:
https://github.com/widgetti/solara/blob/0ac4767a8c5f8c8b221cafe41fa9ac36270adbe8/solara/server/patch.py#L336

I think this approach might work better.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To summarize the strategy, the ctor (__init__) would take a reference to the parent_header, and in run, that parent header is then set in a thread_local storage (in your case _thread_parents).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If a thread starts a new thread, that new thread will start outputting in the last executed cell.

Correct (there is no such a problem with asyncio side of things).

I think this approach might work better.

Thank you for the link, I will take a look!

Copy link
Member Author

@krassowski krassowski Dec 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Taking a closer look at solara, current_context takes the role of _thread_parents, right? This does not seem to differ that much. Your approach of patching earlier on initialization rather than startup has the advantage that it will work with:

from threading import Thread
from time import sleep

def child_target():
    for i in range(iterations):
        print(i, end='', flush=True)
        sleep(interval)

def parent_target():
    thread = Thread(target=child_target)
    sleep(interval)
    thread.start()

Thread(target=parent_target).start()

but still not with:

def parent_target():
    sleep(interval)
    Thread(target=child_target).start()

Thread(target=parent_target).start()

do I see this right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the end my implementation converged to overriding the same methods as yours after all in e1258de :)


def _clean_thread_parent_frames(
self, phase: t.Literal["start", "stop"], info: t.Dict[str, t.Any]
):
"""Clean parent frames of threads which are no longer running.
This is meant to be invoked by garbage collector callback hook.

The implementation enumerates the threads because there is no "exit" hook yet,
but there might be one in the future: https://bugs.python.org/issue14073

This is a no-op if the `self._stdout` and `self._stderr` are not
sub-classes of `OutStream`.
"""
# Only run before the garbage collector starts
if phase != "start":
return
active_threads = {thread.ident for thread in threading.enumerate()}
for stream in [self._stdout, self._stderr]:
if isinstance(stream, OutStream):
thread_parents = stream._thread_parents
for identity in list(thread_parents.keys()):
if identity not in active_threads:
try:
del thread_parents[identity]
except KeyError:
pass


# This exists only for backwards compatibility - use IPythonKernel instead

Expand Down
8 changes: 8 additions & 0 deletions ipykernel/kernelbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from ipykernel.jsonutil import json_clean

from ._version import kernel_protocol_version
from .iostream import OutStream


def _accepts_parameters(meth, param_names):
Expand Down Expand Up @@ -272,6 +273,13 @@ def _parent_header(self):
def __init__(self, **kwargs):
"""Initialize the kernel."""
super().__init__(**kwargs)

# Kernel application may swap stdout and stderr to OutStream,
# which is the case in `IPKernelApp.init_io`, hence `sys.stdout`
# can already by different from TextIO at initialization time.
self._stdout: OutStream | t.TextIO = sys.stdout
self._stderr: OutStream | t.TextIO = sys.stderr

# Build dict of handlers for message types
self.shell_handlers = {}
for msg_type in self.msg_types:
Expand Down
64 changes: 64 additions & 0 deletions tests/test_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,70 @@ def test_simple_print():
_check_master(kc, expected=True)


def test_print_to_correct_cell_from_thread():
    """should print to the cell that spawned the thread, not a subsequently run cell"""
    iterations = 5
    interval = 0.25
    code = f"""\
from threading import Thread
from time import sleep

def thread_target():
    for i in range({iterations}):
        print(i, end='', flush=True)
        sleep({interval})

Thread(target=thread_target).start()
"""
    with kernel() as kc:
        thread_msg_id = kc.execute(code)
        kc.execute("pass")

        for expected in map(str, range(iterations)):
            # skip over non-stream iopub traffic (status, execute_input, ...)
            while True:
                msg = kc.get_iopub_msg(timeout=interval * 2)
                if msg["msg_type"] == "stream":
                    break
            content = msg["content"]
            assert content["name"] == "stdout"
            assert content["text"] == expected
            # this is crucial as the parent header decides to which cell the output goes
            assert msg["parent_header"]["msg_id"] == thread_msg_id


def test_print_to_correct_cell_from_asyncio():
    """should print to the cell that scheduled the task, not a subsequently run cell"""
    iterations = 5
    interval = 0.25
    code = f"""\
import asyncio

async def async_task():
    for i in range({iterations}):
        print(i, end='', flush=True)
        await asyncio.sleep({interval})

loop = asyncio.get_event_loop()
loop.create_task(async_task());
"""
    with kernel() as kc:
        thread_msg_id = kc.execute(code)
        kc.execute("pass")

        for expected in map(str, range(iterations)):
            # skip over non-stream iopub traffic (status, execute_input, ...)
            while True:
                msg = kc.get_iopub_msg(timeout=interval * 2)
                if msg["msg_type"] == "stream":
                    break
            content = msg["content"]
            assert content["name"] == "stdout"
            assert content["text"] == expected
            # this is crucial as the parent header decides to which cell the output goes
            assert msg["parent_header"]["msg_id"] == thread_msg_id


@pytest.mark.skip(reason="Currently don't capture during test as pytest does its own capturing")
def test_capture_fd():
"""simple print statement in kernel"""
Expand Down
Loading