diff --git a/dill/_dill.py b/dill/_dill.py index d42432ff..53738dfb 100644 --- a/dill/_dill.py +++ b/dill/_dill.py @@ -56,6 +56,14 @@ from pickle import GLOBAL, POP from _thread import LockType from _thread import RLock as RLockType +try: + from _thread import _ExceptHookArgs as ExceptHookArgsType +except ImportError: + ExceptHookArgsType = None +try: + from _thread import _ThreadHandle as ThreadHandleType +except ImportError: + ThreadHandleType = None #from io import IOBase from types import CodeType, FunctionType, MethodType, GeneratorType, \ TracebackType, FrameType, ModuleType, BuiltinMethodType @@ -775,6 +783,14 @@ def _create_typing_tuple(argz, *args): #NOTE: workaround python/cpython#94245 return typing.Tuple[()] return typing.Tuple[argz] +if ThreadHandleType: + def _create_thread_handle(ident, done, *args): #XXX: ignores 'blocking' + from threading import _make_thread_handle + handle = _make_thread_handle(ident) + if done: + handle._set_done() + return handle + def _create_lock(locked, *args): #XXX: ignores 'blocking' from threading import Lock lock = Lock() @@ -1306,7 +1322,15 @@ def save_generic_alias(pickler, obj): logger.trace(pickler, "# Ga2") return -@register(LockType) +if ThreadHandleType: + @register(ThreadHandleType) + def save_thread_handle(pickler, obj): + logger.trace(pickler, "Th: %s", obj) + pickler.save_reduce(_create_thread_handle, (obj.ident, obj.is_done()), obj=obj) + logger.trace(pickler, "# Th") + return + +@register(LockType) #XXX: copied Thread will have new Event (due to new Lock) def save_lock(pickler, obj): logger.trace(pickler, "Lo: %s", obj) pickler.save_reduce(_create_lock, (obj.locked(),), obj=obj) @@ -1773,7 +1797,7 @@ def save_type(pickler, obj, postproc_list=None): logger.trace(pickler, "# T6") return - # special cases: NoneType, NotImplementedType, EllipsisType, EnumMeta + # special caes: NoneType, NotImplementedType, EllipsisType, EnumMeta, etc elif obj is type(None): logger.trace(pickler, "T7: %s", obj) #XXX: pickler.save_reduce(type, (None,), obj=obj) @@ -1791,6 +1815,10 @@ def save_type(pickler, obj, postproc_list=None): logger.trace(pickler, "T7: %s", obj) pickler.write(GLOBAL + b'enum\nEnumMeta\n') logger.trace(pickler, "# T7") + elif obj is ExceptHookArgsType: #NOTE: must be after NoneType for pypy + logger.trace(pickler, "T7: %s", obj) + pickler.write(GLOBAL + b'threading\nExceptHookArgs\n') + logger.trace(pickler, "# T7") else: _byref = getattr(pickler, '_byref', None) diff --git a/dill/tests/test_threads.py b/dill/tests/test_threads.py new file mode 100644 index 00000000..45f1f58c --- /dev/null +++ b/dill/tests/test_threads.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python +# +# Author: Mike McKerns (mmckerns @caltech and @uqfoundation) +# Copyright (c) 2024 The Uncertainty Quantification Foundation. +# License: 3-clause BSD. The full license text is available at: +# - https://github.com/uqfoundation/dill/blob/master/LICENSE + +import dill +dill.settings['recurse'] = True + + +def test_new_thread(): + import threading + t = threading.Thread() + t_ = dill.copy(t) + assert t.is_alive() == t_.is_alive() + for i in ['daemon','name','ident','native_id']: + if hasattr(t, i): + assert getattr(t, i) == getattr(t_, i) + +def test_run_thread(): + import threading + t = threading.Thread() + t.start() + t_ = dill.copy(t) + assert t.is_alive() == t_.is_alive() + for i in ['daemon','name','ident','native_id']: + if hasattr(t, i): + assert getattr(t, i) == getattr(t_, i) + +def test_join_thread(): + import threading + t = threading.Thread() + t.start() + t.join() + t_ = dill.copy(t) + assert t.is_alive() == t_.is_alive() + for i in ['daemon','name','ident','native_id']: + if hasattr(t, i): + assert getattr(t, i) == getattr(t_, i) + + +if __name__ == '__main__': + test_new_thread() + test_run_thread() + test_join_thread()