Skip to content

Commit

Permalink
Merge pull request #54277 from dwoz/win_runas_plus
Browse files Browse the repository at this point in the history
Win runas plus
  • Loading branch information
dwoz authored Aug 21, 2019
2 parents 93cf40a + f719591 commit de77762
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 25 deletions.
29 changes: 7 additions & 22 deletions salt/utils/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,40 +680,29 @@ def kill_children(self, *args, **kwargs):

class MultiprocessingProcess(multiprocessing.Process, NewStyleClassMixIn):

def __new__(cls, *args, **kwargs):
instance = super(MultiprocessingProcess, cls).__new__(cls)
# Patch the run method at runtime because decorating the run method
# with a function with a similar behavior would be ignored once this
# class'es run method is overridden.
instance._original_run = instance.run
instance.run = instance._run
return instance

def __init__(self, *args, **kwargs):
log_queue = kwargs.pop('log_queue', None)
log_queue_level = kwargs.pop('log_queue_level', None)
super(MultiprocessingProcess, self).__init__(*args, **kwargs)
if salt.utils.platform.is_windows():
# On Windows, subclasses should call super if they define
# __setstate__ and/or __getstate__
self._args_for_getstate = copy.copy(args)
self._kwargs_for_getstate = copy.copy(kwargs)

self.log_queue = kwargs.pop('log_queue', None)
self.log_queue = log_queue
if self.log_queue is None:
self.log_queue = salt.log.setup.get_multiprocessing_logging_queue()
else:
# Set the logging queue so that it can be retrieved later with
# salt.log.setup.get_multiprocessing_logging_queue().
salt.log.setup.set_multiprocessing_logging_queue(self.log_queue)

self.log_queue_level = kwargs.pop('log_queue_level', None)
self.log_queue_level = log_queue_level
if self.log_queue_level is None:
self.log_queue_level = salt.log.setup.get_multiprocessing_logging_level()
else:
salt.log.setup.set_multiprocessing_logging_level(self.log_queue_level)

# Call __init__ from 'multiprocessing.Process' only after removing
# 'log_queue' and 'log_queue_level' from kwargs.
super(MultiprocessingProcess, self).__init__(*args, **kwargs)

self._after_fork_methods = [
(MultiprocessingProcess._setup_process_logging, [self], {}),
]
Expand All @@ -737,10 +726,6 @@ def __getstate__(self):
kwargs['log_queue'] = self.log_queue
if 'log_queue_level' not in kwargs:
kwargs['log_queue_level'] = self.log_queue_level
# Remove the version of these in the parent process since
# they are no longer needed.
del self._args_for_getstate
del self._kwargs_for_getstate
return {'args': args,
'kwargs': kwargs,
'_after_fork_methods': self._after_fork_methods,
Expand All @@ -750,11 +735,11 @@ def __getstate__(self):
def _setup_process_logging(self):
salt.log.setup.setup_multiprocessing_logging(self.log_queue)

def _run(self):
def run(self):
for method, args, kwargs in self._after_fork_methods:
method(*args, **kwargs)
try:
return self._original_run()
return super(MultiprocessingProcess, self).run()
except SystemExit:
# These are handled by multiprocessing.Process._bootstrap()
raise
Expand Down
4 changes: 3 additions & 1 deletion salt/utils/win_runas.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def runas(cmdLine, username, password=None, cwd=None):
# Create the environment for the user
env = win32profile.CreateEnvironmentBlock(user_token, False)

hProcess = None
try:
# Start the process in a suspended state.
process_info = salt.platform.win.CreateProcessWithTokenW(
Expand Down Expand Up @@ -216,7 +217,8 @@ def runas(cmdLine, username, password=None, cwd=None):
stderr = f_err.read()
ret['stderr'] = stderr
finally:
salt.platform.win.kernel32.CloseHandle(hProcess)
if hProcess is not None:
salt.platform.win.kernel32.CloseHandle(hProcess)
win32api.CloseHandle(th)
win32api.CloseHandle(user_token)
if impersonation_token:
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/utils/test_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def test_signal_processing_test_after_fork_called(self):
log_to_mock = 'salt.utils.process.MultiprocessingProcess._setup_process_logging'
with patch(sig_to_mock) as ma, patch(log_to_mock) as mb:
self.sh_proc = salt.utils.process.SignalHandlingMultiprocessingProcess(target=self.no_op_target)
self.sh_proc._run()
self.sh_proc.run()
ma.assert_called()
mb.assert_called()

Expand All @@ -342,7 +342,7 @@ def test_signal_processing_test_final_methods_called(self):
with patch(sig_to_mock):
with patch(teardown_to_mock) as ma, patch(log_to_mock) as mb:
self.sh_proc = salt.utils.process.SignalHandlingMultiprocessingProcess(target=self.no_op_target)
self.sh_proc._run()
self.sh_proc.run()
ma.assert_called()
mb.assert_called()

Expand Down

0 comments on commit de77762

Please sign in to comment.