Skip to content

Commit

Permalink
Fix missing extra_params when constructing operands (mars-project#2999)
Browse files Browse the repository at this point in the history
  • Loading branch information
fyrestone authored and wjsi committed May 7, 2022
1 parent 53b4d16 commit fd91274
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 3 deletions.
4 changes: 2 additions & 2 deletions mars/core/operand/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ def __init__(self: OperandType, *args, **kwargs):
extra_names = (
set(kwargs) - set(self._FIELDS) - set(SchedulingHint.all_hint_names)
)
extras = AttributeDict((k, kwargs.pop(k)) for k in extra_names)
kwargs["extra_params"] = kwargs.pop("extra_params", extras)
extras = dict((k, kwargs.pop(k)) for k in extra_names)
kwargs["extra_params"] = AttributeDict(kwargs.pop("extra_params", extras))
self._extract_scheduling_hint(kwargs)
super().__init__(*args, **kwargs)

Expand Down
2 changes: 2 additions & 0 deletions mars/core/operand/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ class MyOperand5(MyOperand4):


def test_execute():
op = MyOperand(extra_params={"my_extra_params": 1})
assert op.extra_params["my_extra_params"] == 1
MyOperand.register_executor(lambda *_: 2)
assert execute(dict(), MyOperand(_key="1")) == 2
assert execute(dict(), MyOperand2(_key="1")) == 2
Expand Down
9 changes: 8 additions & 1 deletion mars/oscar/backends/mars/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,14 @@ async def kill_sub_pool(
await asyncio.to_thread(process.join, 5)

async def is_sub_pool_alive(self, process: multiprocessing.Process):
return await asyncio.to_thread(process.is_alive)
try:
return await asyncio.to_thread(process.is_alive)
except RuntimeError as ex: # pragma: no cover
if "shutdown" not in str(ex):
# when atexit is triggered, the default pool might be shutdown
# and to_thread will fail
raise
return process.is_alive()

async def recover_sub_pool(self, address: str):
process_index = self._config.get_process_index(address)
Expand Down

0 comments on commit fd91274

Please sign in to comment.