Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refined semantics for Settings thread-safety #1947

Merged
merged 2 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 40 additions & 33 deletions dspy/dsp/utils/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,36 +19,47 @@
async_max_workers=8,
)

# Global base configuration
# Global base configuration and owner tracking
main_thread_config = copy.deepcopy(DEFAULT_CONFIG)
config_owner_thread_id = None

# Global lock for settings configuration
global_lock = threading.Lock()

class ThreadLocalOverrides(threading.local):
def __init__(self):
self.overrides = dotdict() # Initialize thread-local overrides
self.overrides = dotdict()


# Create the thread-local storage
thread_local_overrides = ThreadLocalOverrides()


class Settings:
"""
A singleton class for DSPy configuration settings.

This is thread-safe. User threads are supported both through ParallelExecutor and native threading.
- If native threading is used, the thread inherits the initial config from the main thread.
- If ParallelExecutor is used, the thread inherits the initial config from its parent thread.
Thread-safe global configuration.
- 'configure' can be called by only one 'owner' thread (the first thread that calls it).
- Other threads see the configured global values from 'main_thread_config'.
- 'context' sets thread-local overrides. These overrides propagate to threads spawned
inside that context block, when (and only when!) using a ParallelExecutor that copies overrides.

1. Only one unique thread (which can be any thread!) can call dspy.configure.
2. It affects global state, visible to all threads. As a result, user threads work, but they shouldn't be
mixed with concurrent dspy.configure calls from the "main" (owner) thread.
(TODO: In the future, add a warning when user-thread reads are closely followed in time by .configure calls.)
3. Any thread can use dspy.context. It propagates to child threads created with DSPy primitives: Parallel, asyncify, etc.
"""

_instance = None

def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance.lock = threading.Lock() # maintained here for DSPy assertions.py
return cls._instance

@property
def lock(self):
return global_lock

def __getattr__(self, name):
overrides = getattr(thread_local_overrides, 'overrides', dotdict())
if name in overrides:
Expand All @@ -64,8 +75,6 @@ def __setattr__(self, name, value):
else:
self.configure(**{name: value})

# Dictionary-like access

def __getitem__(self, key):
return self.__getattr__(key)

Expand All @@ -88,42 +97,40 @@ def copy(self):

@property
def config(self):
config = self.copy()
if 'lock' in config:
del config['lock']
return config

# Configuration methods
return self.copy()

def configure(self, **kwargs):
global main_thread_config
global main_thread_config, config_owner_thread_id
current_thread_id = threading.get_ident()

# Get or initialize thread-local overrides
overrides = getattr(thread_local_overrides, 'overrides', dotdict())
thread_local_overrides.overrides = dotdict(
{**copy.deepcopy(DEFAULT_CONFIG), **main_thread_config, **overrides, **kwargs}
)
with self.lock:
# First configuration establishes ownership. Once ownership is established, only the owner thread can configure.
if config_owner_thread_id in [None, current_thread_id]:
config_owner_thread_id = current_thread_id
else:
raise RuntimeError("dspy.settings can only be changed by the thread that initially configured it.")

# Update main_thread_config, in the main thread only
if threading.current_thread() is threading.main_thread():
main_thread_config = thread_local_overrides.overrides
# Update global config
for k, v in kwargs.items():
main_thread_config[k] = v

@contextmanager
def context(self, **kwargs):
"""Context manager for temporary configuration changes."""
global main_thread_config
"""
Context manager for temporary configuration changes at the thread level.
Does not affect global configuration. Changes only apply to the current thread.
If threads are spawned inside this block using ParallelExecutor, they will inherit these overrides.
"""

original_overrides = getattr(thread_local_overrides, 'overrides', dotdict()).copy()
original_main_thread_config = main_thread_config.copy()
new_overrides = dotdict({**main_thread_config, **original_overrides, **kwargs})
thread_local_overrides.overrides = new_overrides

self.configure(**kwargs)
try:
yield
finally:
thread_local_overrides.overrides = original_overrides

if threading.current_thread() is threading.main_thread():
main_thread_config = original_main_thread_config

def __repr__(self):
overrides = getattr(thread_local_overrides, 'overrides', dotdict())
combined_config = {**main_thread_config, **overrides}
Expand Down
44 changes: 23 additions & 21 deletions dspy/utils/asyncify.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

def get_async_max_workers():
import dspy

return dspy.settings.async_max_workers


Expand All @@ -31,28 +30,31 @@ def asyncify(program: Module) -> Callable[[Any, Any], Awaitable[Any]]:
Wraps a DSPy program so that it can be called asynchronously. This is useful for running a
program in parallel with another task (e.g., another DSPy program).

This implementation propagates the current thread's configuration context to the worker thread.

Args:
program: The DSPy program to be wrapped for asynchronous execution.

Returns:
A function that takes the same arguments as the program, but returns an awaitable that
resolves to the program's output.

Example:
>>> class TestSignature(dspy.Signature):
>>> input_text: str = dspy.InputField()
>>> output_text: str = dspy.OutputField()
>>>
>>> # Create the program and wrap it for asynchronous execution
>>> program = dspy.asyncify(dspy.Predict(TestSignature))
>>>
>>> # Use the program asynchronously
>>> async def get_prediction():
>>> prediction = await program(input_text="Test")
>>> print(prediction) # Handle the result of the asynchronous execution
An async function that, when awaited, runs the program in a worker thread. The current
thread's configuration context is inherited for each call.
"""
import threading

assert threading.current_thread() is threading.main_thread(), "asyncify can only be called from the main thread"
# NOTE: To allow this to be nested, we'd need behavior with contextvars like parallelizer.py
return asyncer.asyncify(program, abandon_on_cancel=True, limiter=get_limiter())
async def async_program(*args, **kwargs) -> Any:
# Capture the current overrides at call-time.
from dspy.dsp.utils.settings import thread_local_overrides
parent_overrides = thread_local_overrides.overrides.copy()

def wrapped_program(*a, **kw):
from dspy.dsp.utils.settings import thread_local_overrides
original_overrides = thread_local_overrides.overrides
thread_local_overrides.overrides = parent_overrides.copy()
try:
return program(*a, **kw)
finally:
thread_local_overrides.overrides = original_overrides

# Create a fresh asyncified callable each time, ensuring the latest context is used.
call_async = asyncer.asyncify(wrapped_program, abandon_on_cancel=True, limiter=get_limiter())
return await call_async(*args, **kwargs)

return async_program
17 changes: 10 additions & 7 deletions dspy/utils/parallelizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

logger = logging.getLogger(__name__)


class ParallelExecutor:
def __init__(
self,
Expand All @@ -20,7 +21,6 @@ def __init__(
compare_results=False,
):
"""Offers isolation between the tasks (dspy.settings) irrespective of whether num_threads == 1 or > 1."""

self.num_threads = num_threads
self.disable_progress_bar = disable_progress_bar
self.max_errors = max_errors
Expand Down Expand Up @@ -72,15 +72,17 @@ def _execute_isolated_single_thread(self, function, data):
file=sys.stdout
)

from dspy.dsp.utils.settings import thread_local_overrides
original_overrides = thread_local_overrides.overrides

for item in data:
with logging_redirect_tqdm():
if self.cancel_jobs.is_set():
break

# Create an isolated context for each task using thread-local overrides
from dspy.dsp.utils.settings import thread_local_overrides
original_overrides = thread_local_overrides.overrides
thread_local_overrides.overrides = thread_local_overrides.overrides.copy()
# Create an isolated context for each task by copying current overrides
# This way, even if an iteration modifies the overrides, it won't affect subsequent iterations
thread_local_overrides.overrides = original_overrides.copy()

try:
result = function(item)
Expand Down Expand Up @@ -122,6 +124,8 @@ def _execute_multi_thread(self, function, data):
@contextlib.contextmanager
def interrupt_handler_manager():
"""Sets the cancel_jobs event when a SIGINT is received, only in the main thread."""

# TODO: Is this check conducive to nested usage of ParallelExecutor?
if threading.current_thread() is threading.main_thread():
default_handler = signal.getsignal(signal.SIGINT)

Expand All @@ -145,7 +149,7 @@ def cancellable_function(parent_overrides, index_item):
if self.cancel_jobs.is_set():
return index, job_cancelled

# Create an isolated context for each task using thread-local overrides
# Create an isolated context for each task by copying parent's overrides
from dspy.dsp.utils.settings import thread_local_overrides
original_overrides = thread_local_overrides.overrides
thread_local_overrides.overrides = parent_overrides.copy()
Expand All @@ -156,7 +160,6 @@ def cancellable_function(parent_overrides, index_item):
thread_local_overrides.overrides = original_overrides

with ThreadPoolExecutor(max_workers=self.num_threads) as executor, interrupt_handler_manager():
# Capture the parent thread's overrides
from dspy.dsp.utils.settings import thread_local_overrides
parent_overrides = thread_local_overrides.overrides.copy()

Expand Down
Loading