Skip to content

Commit

Permalink
Refine session backend
Browse files Browse the repository at this point in the history
  • Loading branch information
刘宝 committed Apr 21, 2022
1 parent 41b464a commit 0411cd9
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 58 deletions.
30 changes: 12 additions & 18 deletions mars/deploy/oscar/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,22 @@ async def new_cluster_in_isolation(
cuda_devices: Union[List[int], str] = "auto",
subprocess_start_method: str = None,
backend: str = None,
config: Union[str, Dict] = None,
config: Union[Dict] = None,
web: bool = True,
timeout: float = None,
n_supervisor_process: int = 0,
) -> ClientType:
if subprocess_start_method is None:
subprocess_start_method = "spawn" if sys.platform == "win32" else "forkserver"
# load config file to dict.
if not config or isinstance(config, str):
config = load_config(config)
if backend is None:
backend = (
config.get("task", {})
.get("task_executor_config", {})
.get("backend", "mars")
)
cluster = LocalCluster(
address,
n_worker,
Expand All @@ -74,9 +83,7 @@ async def new_cluster_in_isolation(
n_supervisor_process,
)
await cluster.start()
return await LocalClient.create(
cluster, backend, cluster.execution_backend, timeout
)
return await LocalClient.create(cluster, backend, timeout)


async def new_cluster(
Expand Down Expand Up @@ -132,9 +139,6 @@ def __init__(
):
# load third party extensions.
init_extension_entrypoints()
# load config file to dict.
if not config or isinstance(config, str):
config = load_config(config)
self._address = address
self._subprocess_start_method = subprocess_start_method
self._config = config
Expand Down Expand Up @@ -180,14 +184,6 @@ def __init__(

self._exiting_check_task = None

@property
def execution_backend(self):
return (
self._config.get("task", {})
.get("task_executor_config", {})
.get("backend", "mars")
)

@property
def external_address(self):
return self._supervisor_pool.external_address
Expand Down Expand Up @@ -281,14 +277,12 @@ async def create(
cls,
cluster: LocalCluster,
backend: str = None,
execution_backend: str = None,
timeout: float = None,
) -> ClientType:
backend = backend or "oscar"
backend = backend or "mars"
session = await _new_session(
cluster.external_address,
backend=backend,
execution_backend=execution_backend,
default=True,
timeout=timeout,
)
Expand Down
49 changes: 17 additions & 32 deletions mars/deploy/oscar/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ def init(
cls,
address: str,
session_id: str,
backend: str = "oscar",
backend: str = "mars",
new: bool = True,
**kwargs,
) -> "AbstractSession":
Expand Down Expand Up @@ -660,14 +660,6 @@ def fetch_log(
return fetch(tileables, self, offsets=offsets, sizes=sizes)


_type_name_to_session_cls: Dict[str, Type[AbstractAsyncSession]] = dict()


def register_session_cls(session_cls: Type[AbstractAsyncSession]):
_type_name_to_session_cls[session_cls.name] = session_cls
return session_cls


@dataclass
class ChunkFetchInfo:
tileable: TileableType
Expand Down Expand Up @@ -757,15 +749,12 @@ def gen_submit_tileable_graph(
return graph, to_execute_tileables


@register_session_cls
class _IsolatedSession(AbstractAsyncSession):
name = "oscar"

def __init__(
self,
address: str,
session_id: str,
execution_backend: str,
backend: str,
session_api: AbstractSessionAPI,
meta_api: AbstractMetaAPI,
lifecycle_api: AbstractLifecycleAPI,
Expand All @@ -778,7 +767,7 @@ def __init__(
request_rewriter: Callable = None,
):
super().__init__(address, session_id)
self._execution_backend = execution_backend
self._backend = backend
self._session_api = session_api
self._task_api = task_api
self._meta_api = meta_api
Expand Down Expand Up @@ -807,7 +796,7 @@ async def _init(
cls,
address: str,
session_id: str,
execution_backend: str,
backend: str,
new: bool = True,
timeout: float = None,
):
Expand All @@ -829,7 +818,7 @@ async def _init(
return cls(
address,
session_id,
execution_backend,
backend,
session_api,
meta_api,
lifecycle_api,
Expand All @@ -846,13 +835,13 @@ async def init(
cls,
address: str,
session_id: str,
backend: str,
new: bool = True,
timeout: float = None,
**kwargs,
) -> "AbstractAsyncSession":
init_local = kwargs.pop("init_local", False)
request_rewriter = kwargs.pop("request_rewriter", None)
execution_backend = kwargs.pop("execution_backend", "mars")
if init_local:
from .local import new_cluster_in_isolation

Expand All @@ -870,18 +859,18 @@ async def init(
return await _IsolatedWebSession._init(
address,
session_id,
backend,
new=new,
timeout=timeout,
request_rewriter=request_rewriter,
execution_backend=execution_backend,
)
else:
return await cls._init(
address,
session_id,
backend,
new=new,
timeout=timeout,
execution_backend=execution_backend,
)

async def _update_progress(self, task_id: str, progress: Progress):
Expand Down Expand Up @@ -1102,9 +1091,7 @@ async def fetch(self, *tileables, **kwargs) -> list:
unexpected_keys = ", ".join(list(kwargs.keys()))
raise TypeError(f"`fetch` got unexpected arguments: {unexpected_keys}")

fetcher = Fetcher.create(
self._execution_backend, get_storage_api=self._get_storage_api
)
fetcher = Fetcher.create(self._backend, get_storage_api=self._get_storage_api)

with enter_mode(build=True):
chunks = []
Expand Down Expand Up @@ -1330,7 +1317,7 @@ async def _init(
cls,
address: str,
session_id: str,
execution_backend: str,
backend: str,
new: bool = True,
timeout: float = None,
request_rewriter: Callable = None,
Expand All @@ -1355,7 +1342,7 @@ async def _init(
return cls(
address,
session_id,
execution_backend,
backend,
session_api,
meta_api,
lifecycle_api,
Expand Down Expand Up @@ -1430,13 +1417,12 @@ async def init(
cls,
address: str,
session_id: str,
backend: str = "oscar",
backend: str = "mars",
new: bool = True,
**kwargs,
) -> "AbstractSession":
session_cls = _type_name_to_session_cls[backend]
isolation = ensure_isolation_created(kwargs)
coro = session_cls.init(address, session_id, new=new, **kwargs)
coro = _IsolatedSession.init(address, session_id, backend, new=new, **kwargs)
fut = asyncio.run_coroutine_threadsafe(coro, isolation.loop)
isolated_session = await asyncio.wrap_future(fut)
return AsyncSession(address, session_id, isolated_session, isolation)
Expand Down Expand Up @@ -1602,13 +1588,12 @@ def init(
cls,
address: str,
session_id: str,
backend: str = "oscar",
backend: str = "mars",
new: bool = True,
**kwargs,
) -> "AbstractSession":
session_cls = _type_name_to_session_cls[backend]
isolation = ensure_isolation_created(kwargs)
coro = session_cls.init(address, session_id, new=new, **kwargs)
coro = _IsolatedSession.init(address, session_id, backend, new=new, **kwargs)
fut = asyncio.run_coroutine_threadsafe(coro, isolation.loop)
isolated_session = fut.result()
return SyncSession(address, session_id, isolated_session, isolation)
Expand Down Expand Up @@ -1978,7 +1963,7 @@ def _new_session_id():
async def _new_session(
address: str,
session_id: str = None,
backend: str = "oscar",
backend: str = "mars",
default: bool = False,
**kwargs,
) -> AbstractSession:
Expand All @@ -1996,7 +1981,7 @@ async def _new_session(
def new_session(
address: str = None,
session_id: str = None,
backend: str = "oscar",
backend: str = "mars",
default: bool = True,
new: bool = True,
**kwargs,
Expand Down
6 changes: 1 addition & 5 deletions mars/deploy/oscar/tests/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,14 @@
_IsolatedSession,
AsyncSession,
ensure_isolation_created,
register_session_cls,
_ensure_sync,
)


CONFIG_FILE = os.path.join(os.path.dirname(__file__), "check_enabled_config.yml")


@register_session_cls
class CheckedSession(ObjectCheckMixin, _IsolatedSession):
name = "test"

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._tileable_checked = dict()
Expand Down Expand Up @@ -99,7 +95,7 @@ def new_test_session(
address: str = None,
session_id: str = None,
default: bool = False,
backend: str = "test",
backend: str = "mars",
**kwargs,
):
isolation = ensure_isolation_created(kwargs)
Expand Down
6 changes: 3 additions & 3 deletions mars/services/subtask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Iterator
from typing import Any, Dict, List, Iterator, Tuple
from ...core import ChunkGraph
from ...core.operand import (
Fetch,
Expand All @@ -27,7 +27,7 @@ def iter_input_data_keys(
subtask: Subtask,
chunk_graph: ChunkGraph,
chunk_key_to_data_keys: Dict[str, List[str]],
) -> Iterator[str, bool]:
) -> Iterator[Tuple[str, bool]]:
"""An iterator yield (input data key, is shuffle)."""
data_keys = set()
for chunk in chunk_graph.iter_indep():
Expand All @@ -52,7 +52,7 @@ def get_mapper_data_keys(key: str, context: Dict[str, Any]) -> List[str]:

def iter_output_data(
chunk_graph: ChunkGraph, context: Dict[str, Any]
) -> Iterator[str, Any, bool]:
) -> Iterator[Tuple[str, Any, bool]]:
"""An iterator yield (output chunk key, output data, is shuffle)."""
data_keys = set()
for result_chunk in chunk_graph.result_chunks:
Expand Down

0 comments on commit 0411cd9

Please sign in to comment.