diff --git a/ding/config/config.py b/ding/config/config.py index 0403d50304..51a347de7c 100644 --- a/ding/config/config.py +++ b/ding/config/config.py @@ -315,7 +315,7 @@ def save_project_state(exp_name: str) -> None: def _fn(cmd: str): return subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.strip().decode("utf-8") - if subprocess.run("git status", shell=True, stderr=subprocess.PIPE).returncode == 0: + if subprocess.run("git status", shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE).returncode == 0: short_sha = _fn("git describe --always") log = _fn("git log --stat -n 5") diff = _fn("git diff") diff --git a/ding/data/__init__.py b/ding/data/__init__.py index 79ac868c86..b72987cac9 100644 --- a/ding/data/__init__.py +++ b/ding/data/__init__.py @@ -1,3 +1,7 @@ from torch.utils.data import Dataset, DataLoader from ding.utils.data import create_dataset, offline_data_save_type # for compatibility from .buffer import * +from .storage import * +from .storage_loader import StorageLoader, FileStorageLoader +from .shm_buffer import ShmBufferContainer, ShmBuffer +from .model_loader import ModelLoader, FileModelLoader diff --git a/ding/data/buffer/tests/test_benchmark.py b/ding/data/buffer/tests/test_buffer_benchmark.py similarity index 100% rename from ding/data/buffer/tests/test_benchmark.py rename to ding/data/buffer/tests/test_buffer_benchmark.py diff --git a/ding/data/model_loader.py b/ding/data/model_loader.py new file mode 100644 index 0000000000..cd3182897b --- /dev/null +++ b/ding/data/model_loader.py @@ -0,0 +1,155 @@ +from abc import ABC, abstractmethod +import logging +from os import path +import os +from threading import Thread +from time import sleep, time +from typing import Callable, Optional +import uuid +import torch.multiprocessing as mp + +import torch +from ding.data.storage.file import FileModelStorage +from ding.data.storage.storage import Storage +from ding.framework import Supervisor +from ding.framework.supervisor import ChildType, SendPayload + + +class ModelWorker(): + + def __init__(self, model: torch.nn.Module) -> None: + self._model = model + + def save(self, storage: Storage) -> Storage: + storage.save(self._model.state_dict()) + return storage + + +class ModelLoader(Supervisor, ABC): + + def __init__(self, model: torch.nn.Module) -> None: + """ + Overview: + Save and send models asynchronously and load them synchronously. + Arguments: + - model (:obj:`torch.nn.Module`): Torch module. 
+ """ + if next(model.parameters()).is_cuda: + super().__init__(type_=ChildType.PROCESS, mp_ctx=mp.get_context("spawn")) + else: + super().__init__(type_=ChildType.PROCESS) + self._model = model + self._send_callback_loop = None + self._send_callbacks = {} + self._model_worker = ModelWorker(self._model) + + def start(self): + if not self._running: + self._model.share_memory() + self.register(self._model_worker) + self.start_link() + self._send_callback_loop = Thread(target=self._loop_send_callback, daemon=True) + self._send_callback_loop.start() + + def shutdown(self, timeout: Optional[float] = None) -> None: + super().shutdown(timeout) + self._send_callback_loop = None + self._send_callbacks = {} + + def _loop_send_callback(self): + while True: + payload = self.recv(ignore_err=True) + if payload.err: + logging.warning("Got error when loading data: {}".format(payload.err)) + if payload.req_id in self._send_callbacks: + del self._send_callbacks[payload.req_id] + else: + if payload.req_id in self._send_callbacks: + callback = self._send_callbacks.pop(payload.req_id) + callback(payload.data) + + def load(self, storage: Storage) -> object: + """ + Overview: + Load model synchronously. + Arguments: + - storage (:obj:`Stroage`): The model should be wrapped in a storage object, e.g. FileModelStorage. + Returns: + - object (:obj:): The loaded model. + """ + return storage.load() + + @abstractmethod + def save(self, callback: Callable) -> Storage: + """ + Overview: + Save model asynchronously. + Arguments: + - callback (:obj:`Callable`): The callback function after saving model. + Returns: + - storage (:obj:`Storage`): The storage object is created synchronously, so it can be returned. + """ + raise NotImplementedError + + +class FileModelLoader(ModelLoader): + + def __init__(self, model: torch.nn.Module, dirname: str, ttl: int = 20) -> None: + """ + Overview: + Model loader using files as storage media. + Arguments: + - model (:obj:`torch.nn.Module`): Torch module. + - dirname (:obj:`str`): The directory for saving files. + - ttl (:obj:`int`): Files will be automatically cleaned after ttl. Note that \ + files that do not time out when the process is stopped are not cleaned up \ + (to avoid errors when other processes read the file), so you may need to \ + clean up the remaining files manually + """ + super().__init__(model) + self._dirname = dirname + self._ttl = ttl + self._files = [] + self._cleanup_thread = None + + def _start_cleanup(self): + """ + Overview: + Start a cleanup thread to clean up files that are taking up too much time on the disk. 
+ """ + if self._cleanup_thread is None: + self._cleanup_thread = Thread(target=self._loop_cleanup, daemon=True) + self._cleanup_thread.start() + + def shutdown(self, timeout: Optional[float] = None) -> None: + super().shutdown(timeout) + self._cleanup_thread = None + + def _loop_cleanup(self): + while True: + if len(self._files) == 0 or time() - self._files[0][0] < self._ttl: + sleep(1) + continue + _, file_path = self._files.pop(0) + if path.exists(file_path): + os.remove(file_path) + + def save(self, callback: Callable) -> FileModelStorage: + if not self._running: + logging.warning("Please start model loader before saving model.") + return + if not path.exists(self._dirname): + os.mkdir(self._dirname) + file_path = "model_{}.pth.tar".format(uuid.uuid1()) + file_path = path.join(self._dirname, file_path) + model_storage = FileModelStorage(file_path) + payload = SendPayload(proc_id=0, method="save", args=[model_storage]) + self.send(payload) + + def clean_callback(storage: Storage): + self._files.append([time(), file_path]) + callback(storage) + + self._send_callbacks[payload.req_id] = clean_callback + self._start_cleanup() + return model_storage diff --git a/ding/data/shm_buffer.py b/ding/data/shm_buffer.py new file mode 100644 index 0000000000..b76f5d56e9 --- /dev/null +++ b/ding/data/shm_buffer.py @@ -0,0 +1,133 @@ +from typing import Any, Optional, Union, Tuple, Dict +from multiprocessing import Array +import ctypes +import numpy as np +import torch + +_NTYPE_TO_CTYPE = { + np.bool_: ctypes.c_bool, + np.uint8: ctypes.c_uint8, + np.uint16: ctypes.c_uint16, + np.uint32: ctypes.c_uint32, + np.uint64: ctypes.c_uint64, + np.int8: ctypes.c_int8, + np.int16: ctypes.c_int16, + np.int32: ctypes.c_int32, + np.int64: ctypes.c_int64, + np.float32: ctypes.c_float, + np.float64: ctypes.c_double, +} + + +class ShmBuffer(): + """ + Overview: + Shared memory buffer to store numpy array. + """ + + def __init__( + self, + dtype: Union[type, np.dtype], + shape: Tuple[int], + copy_on_get: bool = True, + ctype: Optional[type] = None + ) -> None: + """ + Overview: + Initialize the buffer. + Arguments: + - dtype (:obj:`Union[type, np.dtype]`): The dtype of the data to limit the size of the buffer. + - shape (:obj:`Tuple[int]`): The shape of the data to limit the size of the buffer. + - copy_on_get (:obj:`bool`): Whether to copy data when calling get method. + - ctype (:obj:`Optional[type]`): Origin class type, e.g. np.ndarray, torch.Tensor. + """ + if isinstance(dtype, np.dtype): # it is type of gym.spaces.dtype + dtype = dtype.type + self.buffer = Array(_NTYPE_TO_CTYPE[dtype], int(np.prod(shape))) + self.dtype = dtype + self.shape = shape + self.copy_on_get = copy_on_get + self.ctype = ctype + + def fill(self, src_arr: np.ndarray) -> None: + """ + Overview: + Fill the shared memory buffer with a numpy array. (Replace the original one.) + Arguments: + - src_arr (:obj:`np.ndarray`): array to fill the buffer. + """ + assert isinstance(src_arr, np.ndarray), type(src_arr) + # for np.array with shape (4, 84, 84) and float32 dtype, reshape is 15~20x faster than flatten + # for np.array with shape (4, 84, 84) and uint8 dtype, reshape is 5~7x faster than flatten + # so we reshape dst_arr rather than flatten src_arr + dst_arr = np.frombuffer(self.buffer.get_obj(), dtype=self.dtype).reshape(self.shape) + np.copyto(dst_arr, src_arr) + + def get(self) -> np.ndarray: + """ + Overview: + Get the array stored in the buffer. + Return: + - data (:obj:`np.ndarray`): A copy of the data stored in the buffer. 
+ """ + data = np.frombuffer(self.buffer.get_obj(), dtype=self.dtype).reshape(self.shape) + if self.copy_on_get: + data = data.copy() # must use np.copy, torch.from_numpy and torch.as_tensor still use the same memory + if self.ctype is torch.Tensor: + data = torch.from_numpy(data) + return data + + +class ShmBufferContainer(object): + """ + Overview: + Support multiple shared memory buffers. Each key-value is name-buffer. + """ + + def __init__( + self, + dtype: Union[Dict[Any, type], type, np.dtype], + shape: Union[Dict[Any, tuple], tuple], + copy_on_get: bool = True + ) -> None: + """ + Overview: + Initialize the buffer container. + Arguments: + - dtype (:obj:`Union[type, np.dtype]`): The dtype of the data to limit the size of the buffer. + - shape (:obj:`Union[Dict[Any, tuple], tuple]`): If `Dict[Any, tuple]`, use a dict to manage \ + multiple buffers; If `tuple`, use single buffer. + - copy_on_get (:obj:`bool`): Whether to copy data when calling get method. + """ + if isinstance(shape, dict): + self._data = {k: ShmBufferContainer(dtype[k], v, copy_on_get) for k, v in shape.items()} + elif isinstance(shape, (tuple, list)): + self._data = ShmBuffer(dtype, shape, copy_on_get) + else: + raise RuntimeError("not support shape: {}".format(shape)) + self._shape = shape + + def fill(self, src_arr: Union[Dict[Any, np.ndarray], np.ndarray]) -> None: + """ + Overview: + Fill the one or many shared memory buffer. + Arguments: + - src_arr (:obj:`Union[Dict[Any, np.ndarray], np.ndarray]`): array to fill the buffer. + """ + if isinstance(self._shape, dict): + for k in self._shape.keys(): + self._data[k].fill(src_arr[k]) + elif isinstance(self._shape, (tuple, list)): + self._data.fill(src_arr) + + def get(self) -> Union[Dict[Any, np.ndarray], np.ndarray]: + """ + Overview: + Get the one or many arrays stored in the buffer. + Return: + - data (:obj:`np.ndarray`): The array(s) stored in the buffer. 
+ """ + if isinstance(self._shape, dict): + return {k: self._data[k].get() for k in self._shape.keys()} + elif isinstance(self._shape, (tuple, list)): + return self._data.get() diff --git a/ding/data/storage/__init__.py b/ding/data/storage/__init__.py new file mode 100644 index 0000000000..962fbbbf18 --- /dev/null +++ b/ding/data/storage/__init__.py @@ -0,0 +1,2 @@ +from .storage import Storage +from .file import FileStorage, FileModelStorage diff --git a/ding/data/storage/file.py b/ding/data/storage/file.py new file mode 100644 index 0000000000..e6a89910b8 --- /dev/null +++ b/ding/data/storage/file.py @@ -0,0 +1,25 @@ +from typing import Any +from ding.data.storage import Storage +import pickle + +from ding.utils.file_helper import read_file, save_file + + +class FileStorage(Storage): + + def save(self, data: Any) -> None: + with open(self.path, "wb") as f: + pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL) + + def load(self) -> Any: + with open(self.path, "rb") as f: + return pickle.load(f) + + +class FileModelStorage(Storage): + + def save(self, state_dict: object) -> None: + save_file(self.path, state_dict) + + def load(self) -> object: + return read_file(self.path) diff --git a/ding/data/storage/storage.py b/ding/data/storage/storage.py new file mode 100644 index 0000000000..e6a0dae679 --- /dev/null +++ b/ding/data/storage/storage.py @@ -0,0 +1,16 @@ +from abc import ABC, abstractmethod +from typing import Any + + +class Storage(ABC): + + def __init__(self, path: str) -> None: + self.path = path + + @abstractmethod + def save(self, data: Any) -> None: + raise NotImplementedError + + @abstractmethod + def load(self) -> Any: + raise NotImplementedError diff --git a/ding/data/storage/tests/test_storage.py b/ding/data/storage/tests/test_storage.py new file mode 100644 index 0000000000..8f6f1d2c47 --- /dev/null +++ b/ding/data/storage/tests/test_storage.py @@ -0,0 +1,18 @@ +import tempfile +import pytest +import os +from os import path +from ding.data.storage import FileStorage + + +@pytest.mark.unittest +def test_file_storage(): + path_ = path.join(tempfile.gettempdir(), "test_storage.txt") + try: + storage = FileStorage(path=path_) + storage.save("test") + content = storage.load() + assert content == "test" + finally: + if path.exists(path_): + os.remove(path_) diff --git a/ding/data/storage_loader.py b/ding/data/storage_loader.py new file mode 100644 index 0000000000..daf18e2d82 --- /dev/null +++ b/ding/data/storage_loader.py @@ -0,0 +1,305 @@ +from dataclasses import dataclass +import os +import torch +import numpy as np +import uuid +import treetensor.torch as ttorch +from abc import ABC, abstractmethod +from ditk import logging +from time import sleep, time +from threading import Lock, Thread +from typing import Any, Callable, Dict, List, Optional, Union +from ding.data import FileStorage, Storage +from os import path +from ding.data.shm_buffer import ShmBuffer +from ding.framework.supervisor import RecvPayload, Supervisor, ChildType, SendPayload + + +@dataclass +class ShmObject: + id_: ShmBuffer + buf: Any + + +class StorageWorker: + + def load(self, storage: Storage) -> Any: + return storage.load() + + +class StorageLoader(Supervisor, ABC): + + def __init__(self, worker_num: int = 3) -> None: + """ + Overview: + Save and send data synchronously and load them asynchronously. + Arguments: + - worker_num (:obj:`int`): Subprocess worker number. + """ + super().__init__(type_=ChildType.PROCESS) + self._load_lock = Lock() # Load (first meet) should be called one by one. 
diff --git a/ding/data/storage_loader.py b/ding/data/storage_loader.py
new file mode 100644
index 0000000000..daf18e2d82
--- /dev/null
+++ b/ding/data/storage_loader.py
@@ -0,0 +1,305 @@
+from dataclasses import dataclass
+import os
+import torch
+import numpy as np
+import uuid
+import treetensor.torch as ttorch
+from abc import ABC, abstractmethod
+from ditk import logging
+from time import sleep, time
+from threading import Lock, Thread
+from typing import Any, Callable, Dict, List, Optional, Union
+from ding.data import FileStorage, Storage
+from os import path
+from ding.data.shm_buffer import ShmBuffer
+from ding.framework.supervisor import RecvPayload, Supervisor, ChildType, SendPayload
+
+
+@dataclass
+class ShmObject:
+    id_: ShmBuffer
+    buf: Any
+
+
+class StorageWorker:
+
+    def load(self, storage: Storage) -> Any:
+        return storage.load()
+
+
+class StorageLoader(Supervisor, ABC):
+
+    def __init__(self, worker_num: int = 3) -> None:
+        """
+        Overview:
+            Save and send data synchronously and load them asynchronously.
+        Arguments:
+            - worker_num (:obj:`int`): Subprocess worker number.
+        """
+        super().__init__(type_=ChildType.PROCESS)
+        self._load_lock = Lock()  # load() (first meet) should be called one at a time.
+        self._callback_map: Dict[str, Callable] = {}
+        self._shm_obj_map: Dict[int, ShmObject] = {}
+        self._worker_num = worker_num
+        self._req_count = 0
+
+    def shutdown(self, timeout: Optional[float] = None) -> None:
+        super().shutdown(timeout)
+        self._recv_loop = None
+        self._callback_map = {}
+        self._shm_obj_map = {}
+        self._req_count = 0
+
+    def start_link(self) -> None:
+        if not self._running:
+            super().start_link()
+            self._recv_loop = Thread(target=self._loop_recv, daemon=True)
+            self._recv_loop.start()
+
+    @property
+    def _next_proc_id(self):
+        return self._req_count % self._worker_num
+
+    @abstractmethod
+    def save(self, obj: Union[Dict, List]) -> Storage:
+        """
+        Overview:
+            Save data with a storage object synchronously.
+        Arguments:
+            - obj (:obj:`Union[Dict, List]`): The data (traj or episodes), can be numpy, tensor or treetensor.
+        Returns:
+            - storage (:obj:`Storage`): The storage object.
+        """
+        raise NotImplementedError
+
+    def load(self, storage: Storage, callback: Callable):
+        """
+        Overview:
+            Load data from a storage object asynchronously. \
+            This function will analyze the data structure the first time it meets a new kind of data, \
+            then allocate a shared memory buffer for each subprocess; these shared memory buffers \
+            are responsible for asynchronously loading data into memory.
+        Arguments:
+            - storage (:obj:`Storage`): The storage object.
+            - callback (:obj:`Callable`): Callback function after data loaded.
+        """
+        with self._load_lock:
+            if not self._running:
+                self._first_meet(storage, callback)
+                return
+
+        payload = SendPayload(proc_id=self._next_proc_id, method="load", args=[storage])
+        self._callback_map[payload.req_id] = callback
+        self.send(payload)
+        self._req_count += 1
+
+    def _first_meet(self, storage: Storage, callback: Callable):
+        """
+        Overview:
+            When we first meet an object type, we load this object directly and analyze its structure,
+            to allocate the shared memory object and create subprocess workers.
+        Arguments:
+            - storage (:obj:`Storage`): The storage object.
+            - callback (:obj:`Callable`): Callback function after data loaded.
+        """
+        obj = storage.load()
+        # Create a worker (with its own shared memory buffer) for each subprocess.
+        for i in range(self._worker_num):
+            shm_obj = self._create_shm_buffer(obj)
+            self._shm_obj_map[i] = shm_obj
+            self.register(StorageWorker, shm_buffer=shm_obj, shm_callback=self._shm_callback)
+        self.start_link()
+        callback(obj)
+
+    def _loop_recv(self):
+        while True:
+            payload = self.recv(ignore_err=True)
+            if payload.err:
+                logging.warning("Got error when loading data: {}".format(payload.err))
+                if payload.req_id in self._callback_map:
+                    del self._callback_map[payload.req_id]
+            else:
+                self._shm_putback(payload, self._shm_obj_map[payload.proc_id])
+                if payload.req_id in self._callback_map:
+                    callback = self._callback_map.pop(payload.req_id)
+                    callback(payload.data)
+
+    def _create_shm_buffer(self, obj: Union[Dict, List]) -> Optional[ShmObject]:
+        """
+        Overview:
+            Create a shared object (buf and callback) by walking through the data structure.
+        Arguments:
+            - obj (:obj:`Union[Dict, List]`): The data (traj or episodes), can be numpy, tensor or treetensor.
+        Returns:
+            - shm_buf (:obj:`Optional[ShmObject]`): The shared memory buffer.
+        """
+ """ + max_level = 2 + + def to_shm(obj: Dict, level: int): + if level > max_level: + return + shm_buf = None + if isinstance(obj, Dict) or isinstance(obj, ttorch.Tensor): + shm_buf = {} + for key, val in obj.items(): + # Only numpy array can fill into shm buffer + if isinstance(val, np.ndarray): + shm_buf[key] = ShmBuffer(val.dtype, val.shape, copy_on_get=False) + elif isinstance(val, torch.Tensor): + shm_buf[key] = ShmBuffer( + val.numpy().dtype, val.numpy().shape, copy_on_get=False, ctype=torch.Tensor + ) + # Recursive parsing structure + elif isinstance(val, Dict) or isinstance(val, ttorch.Tensor) or isinstance(val, List): + buf = to_shm(val, level=level + 1) + if buf: + shm_buf[key] = buf + elif isinstance(obj, List): + # Double the size of buffer + shm_buf = [to_shm(o, level=level) for o in obj] * 2 + if all(s is None for s in shm_buf): + shm_buf = [] + return shm_buf + + shm_buf = to_shm(obj, level=0) + if shm_buf is not None: + random_id = self._random_id() + shm_buf = ShmObject(id_=ShmBuffer(random_id.dtype, random_id.shape, copy_on_get=False), buf=shm_buf) + return shm_buf + + def _random_id(self) -> np.ndarray: + return np.random.randint(1, 9e6, size=(1)) + + def _shm_callback(self, payload: RecvPayload, shm_obj: ShmObject): + """ + Overview: + Called in subprocess, put payload.data into buf. + Arguments: + - payload (:obj:`RecvPayload`): The recv payload with meta info of the data. + - shm_obj (:obj:`ShmObject`): The shm buffer. + """ + assert isinstance(payload.data, type( + shm_obj.buf + )), "Data type ({}) and buf type ({}) are not match!".format(type(payload.data), type(shm_obj.buf)) + + # Sleep while shm object is not ready. + while shm_obj.id_.get()[0] != 0: + sleep(0.001) + + max_level = 2 + + def shm_callback(data: Union[Dict, List, ttorch.Tensor], buf: Union[Dict, List], level: int): + if level > max_level: + return + + if isinstance(buf, List): + assert isinstance(data, List), "Data ({}) and buf ({}) type not match".format(type(data), type(buf)) + elif isinstance(buf, Dict): + assert isinstance(data, ttorch.Tensor) or isinstance( + data, Dict + ), "Data ({}) and buf ({}) type not match".format(type(data), type(buf)) + + if isinstance(data, Dict) or isinstance(data, ttorch.Tensor): + for key, val in data.items(): + if isinstance(val, torch.Tensor): + val = val.numpy() + buf_val = buf.get(key) + if buf_val is None: + continue + if isinstance(buf_val, ShmBuffer) and isinstance(val, np.ndarray): + buf_val.fill(val) + data[key] = None + else: + shm_callback(val, buf_val, level=level + 1) + elif isinstance(data, List): + for i, data_ in enumerate(data): + shm_callback(data_, buf[i], level=level) + + shm_callback(payload.data, buf=shm_obj.buf, level=0) + id_ = self._random_id() + shm_obj.id_.fill(id_) + payload.extra = id_ + + def _shm_putback(self, payload: RecvPayload, shm_obj: ShmObject): + """ + Overview: + Called in main process, put buf back into payload.data. + Arguments: + - payload (:obj:`RecvPayload`): The recv payload with meta info of the data. + - shm_obj (:obj:`ShmObject`): The shm buffer. 
+ """ + assert isinstance(payload.data, type( + shm_obj.buf + )), "Data type ({}) and buf type ({}) are not match!".format(type(payload.data), type(shm_obj.buf)) + + assert shm_obj.id_.get()[0] == payload.extra[0], "Shm object and payload do not match ({} - {}).".format( + shm_obj.id_.get()[0], payload.extra[0] + ) + + def shm_putback(data: Union[Dict, List], buf: Union[Dict, List]): + if isinstance(data, Dict) or isinstance(data, ttorch.Tensor): + for key, val in data.items(): + buf_val = buf.get(key) + if buf_val is None: + continue + if val is None and isinstance(buf_val, ShmBuffer): + data[key] = buf[key].get() + else: + shm_putback(val, buf_val) + elif isinstance(data, List): + for i, data_ in enumerate(data): + shm_putback(data_, buf[i]) + + shm_putback(payload.data, buf=shm_obj.buf) + shm_obj.id_.fill(np.array([0])) + + +class FileStorageLoader(StorageLoader): + + def __init__(self, dirname: str, ttl: int = 20, worker_num: int = 3) -> None: + """ + Overview: + Dump and load object with file storage. + Arguments: + - dirname (:obj:`str`): The directory to save files. + - ttl (:obj:`str`): Maximum time to keep a file, after which it will be deleted. + - worker_num (:obj:`int`): Number of subprocess worker loaders. + """ + super().__init__(worker_num) + self._dirname = dirname + self._files = [] + self._cleanup_thread = None + self._ttl = ttl # # Delete files created 10 minutes ago. + + def save(self, obj: Union[Dict, List]) -> FileStorage: + if not path.exists(self._dirname): + os.mkdir(self._dirname) + filename = "{}.pkl".format(uuid.uuid1()) + full_path = path.join(self._dirname, filename) + f = FileStorage(full_path) + f.save(obj) + self._files.append([time(), f.path]) + self._start_cleanup() + return f + + def _start_cleanup(self): + """ + Overview: + Start a cleanup thread to clean up files that are taking up too much time on the disk. 
+ """ + if self._cleanup_thread is None: + self._cleanup_thread = Thread(target=self._loop_cleanup, daemon=True) + self._cleanup_thread.start() + + def shutdown(self, timeout: Optional[float] = None) -> None: + super().shutdown(timeout) + self._cleanup_thread = None + + def _loop_cleanup(self): + while True: + if len(self._files) == 0 or time() - self._files[0][0] < self._ttl: + sleep(1) + continue + _, file_path = self._files.pop(0) + if path.exists(file_path): + os.remove(file_path) diff --git a/ding/data/tests/test_model_loader.py b/ding/data/tests/test_model_loader.py new file mode 100644 index 0000000000..caf8c07186 --- /dev/null +++ b/ding/data/tests/test_model_loader.py @@ -0,0 +1,74 @@ +import shutil +import tempfile +from time import sleep, time +import pytest +from ding.data.model_loader import FileModelLoader +from ding.data.storage.file import FileModelStorage +from ding.model import DQN +from ding.config import compile_config +from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config +from os import path +import torch + + +@pytest.mark.tmp # gitlab ci and local test pass, github always fail +def test_model_loader(): + tempdir = path.join(tempfile.gettempdir(), "test_model_loader") + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + model = DQN(**cfg.policy.model) + loader = FileModelLoader(model=model, dirname=tempdir, ttl=1) + try: + loader.start() + model_storage = None + + def save_model(storage): + nonlocal model_storage + model_storage = storage + + start = time() + loader.save(save_model) + save_time = time() - start + print("Save time: {:.4f}s".format(save_time)) + assert save_time < 0.1 + sleep(0.5) + assert isinstance(model_storage, FileModelStorage) + assert len(loader._files) > 0 + + state_dict = loader.load(model_storage) + model.load_state_dict(state_dict) + + sleep(2) + assert not path.exists(model_storage.path) + assert len(loader._files) == 0 + finally: + if path.exists(tempdir): + shutil.rmtree(tempdir) + + +@pytest.mark.benchmark +def test_model_loader_benchmark(): + model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.Linear(1024, 100)) # 40MB + tempdir = path.join(tempfile.gettempdir(), "test_model_loader") + loader = FileModelLoader(model=model, dirname=tempdir) + + try: + loader.start() + count = 0 + + def send_callback(_): + nonlocal count + count += 1 + + start = time() + for _ in range(5): + loader.save(send_callback) + sleep(0.2) + + while count < 5: + sleep(0.001) + + assert time() - start < 1.2 + finally: + if path.exists(tempdir): + shutil.rmtree(tempdir) + loader.shutdown() diff --git a/ding/data/tests/test_shm_buffer.py b/ding/data/tests/test_shm_buffer.py new file mode 100644 index 0000000000..04334b4799 --- /dev/null +++ b/ding/data/tests/test_shm_buffer.py @@ -0,0 +1,20 @@ +import pytest +import numpy as np +import timeit +from ding.data.shm_buffer import ShmBuffer +import multiprocessing as mp + + +def subprocess(shm_buf): + data = np.random.rand(1024, 1024).astype(np.float32) + res = timeit.repeat(lambda: shm_buf.fill(data), repeat=5, number=1000) + print("Mean: {:.4f}s, STD: {:.4f}s, Mean each call: {:.4f}ms".format(np.mean(res), np.std(res), np.mean(res))) + + +@pytest.mark.benchmark +def test_shm_buffer(): + data = np.random.rand(1024, 1024).astype(np.float32) + shm_buf = ShmBuffer(data.dtype, data.shape, copy_on_get=False) + proc = mp.Process(target=subprocess, args=[shm_buf]) + proc.start() + proc.join() diff --git a/ding/data/tests/test_storage_loader.py 
diff --git a/ding/data/tests/test_storage_loader.py b/ding/data/tests/test_storage_loader.py
new file mode 100644
index 0000000000..5ab07acd73
--- /dev/null
+++ b/ding/data/tests/test_storage_loader.py
@@ -0,0 +1,176 @@
+import os
+import timeit
+import pytest
+import tempfile
+import shutil
+import numpy as np
+import torch
+import treetensor.torch as ttorch
+from ding.data.shm_buffer import ShmBuffer
+from ding.data.storage_loader import FileStorageLoader
+from time import sleep, time
+from os import path
+from ding.framework.supervisor import RecvPayload
+
+
+@pytest.mark.tmp  # gitlab ci and local test pass, github always fail
+def test_file_storage_loader():
+    tempdir = path.join(tempfile.gettempdir(), "test_storage_loader")
+    loader = FileStorageLoader(dirname=tempdir)
+    try:
+        total_num = 200
+        storages = []
+        for i in range(10):
+            # 21MB
+            data = [
+                {
+                    "s": "abc",
+                    "obs": np.random.rand(4, 84, 84).astype(np.float32),
+                    # "next_obs": np.random.rand(4, 84, 84).astype(np.float32),
+                    # "obs": torch.rand(4, 84, 84, dtype=torch.float32),
+                    "next_obs": torch.rand(4, 84, 84, dtype=torch.float32)
+                } for _ in range(96)
+            ]
+            storage = loader.save(data)
+            storages.append(storage)
+
+        start = time()
+        for i in range(total_num):
+            storage = storages[i % 10]
+            data = storage.load()
+        origin_time_cost = time() - start
+        print("Load time cost: {:.4f}s".format(origin_time_cost))
+
+        call_times = 0
+
+        def callback(data):
+            assert data[0]['obs'] is not None
+            nonlocal call_times
+            call_times += 1
+
+        # First initialize shared memory is very slow, discard this time cost.
+        start = time()
+        loader._first_meet(storage=storages[0], callback=callback)
+        print("Initialize shared memory time: {:.4f}s".format(time() - start))
+
+        start = time()
+        for i in range(1, total_num):
+            storage = storages[i % 10]
+            loader.load(storage, callback)
+
+        while True:
+            if call_times == total_num:
+                break
+            sleep(0.01)
+        new_time_cost = time() - start
+        print("Loader time cost: {:.4f}s".format(new_time_cost))
+
+        assert new_time_cost < origin_time_cost
+    finally:
+        if path.exists(tempdir):
+            shutil.rmtree(tempdir)
+        loader.shutdown()
+
+
+@pytest.mark.unittest
+def test_file_storage_loader_cleanup():
+    tempdir = path.join(tempfile.gettempdir(), "test_storage_loader")
+    loader = FileStorageLoader(dirname=tempdir, ttl=1)
+    try:
+        storages = []
+        for _ in range(4):
+            data = np.random.rand(4, 84, 84).astype(np.float32)
+            storage = loader.save(data)
+            storages.append(storage)
+            sleep(0.5)
+        assert len(os.listdir(tempdir)) < 4
+    finally:
+        if path.exists(tempdir):
+            shutil.rmtree(tempdir)
+        loader.shutdown()
+
+
+@pytest.mark.unittest
+def test_shared_object():
+    loader = FileStorageLoader(dirname="")
+
+    # ========== Test array ==========
+    obj = [{"obs": np.random.rand(100, 100)} for _ in range(10)]
+    shm_obj = loader._create_shm_buffer(obj)
+    assert len(shm_obj.buf) == len(obj) * 2
+    assert isinstance(shm_obj.buf[0]["obs"], ShmBuffer)
+
+    # Callback
+    payload = RecvPayload(proc_id=0, data=obj)
+    loader._shm_callback(payload=payload, shm_obj=shm_obj)
+    assert len(payload.data) == 10
+    assert all(d["obs"] is None for d in payload.data)
+
+    # ========== Putback ==========
+    loader._shm_putback(payload=payload, shm_obj=shm_obj)
+    obj = payload.data
+    assert len(obj) == 10
+    for o in obj:
+        assert isinstance(o["obs"], np.ndarray)
+        assert o["obs"].shape == (100, 100)
+
+    # ========== Test dict ==========
+    obj = {"obs": torch.rand(100, 100, dtype=torch.float32)}
+    shm_obj = loader._create_shm_buffer(obj)
+    assert isinstance(shm_obj.buf["obs"], ShmBuffer)
+
+    payload = RecvPayload(proc_id=0, data=obj)
+    loader._shm_callback(payload=payload, shm_obj=shm_obj)
+    assert payload.data["obs"] is None
+
+    loader._shm_putback(payload=payload, shm_obj=shm_obj)
+    assert isinstance(payload.data["obs"], torch.Tensor)
+    assert payload.data["obs"].shape == (100, 100)
+
+    # ========== Test treetensor ==========
+    obj = {"trajectories": [ttorch.as_tensor({"obs": torch.rand(10, 10, dtype=torch.float32)}) for _ in range(10)]}
+    shm_obj = loader._create_shm_buffer(obj)
+
+    payload = RecvPayload(proc_id=0, data=obj)
+    loader._shm_callback(payload=payload, shm_obj=shm_obj)
+    assert len(payload.data["trajectories"]) == 10
+    for traj in payload.data["trajectories"]:
+        assert traj["obs"] is None
+
+    loader._shm_putback(payload=payload, shm_obj=shm_obj)
+    for traj in payload.data["trajectories"]:
+        assert isinstance(traj["obs"], torch.Tensor)
+        assert traj["obs"].shape == (10, 10)
+
+
+@pytest.mark.benchmark
+def test_shared_object_benchmark():
+    loader = FileStorageLoader(dirname="")
+    # ========== Test treetensor ==========
+    obj = {
+        "env_step": 0,
+        "trajectories": [
+            ttorch.as_tensor(
+                {
+                    "done": False,
+                    "reward": torch.tensor([1, 0], dtype=torch.int32),
+                    "obs": torch.rand(4, 84, 84, dtype=torch.float32),
+                    "next_obs": torch.rand(4, 84, 84, dtype=torch.float32),
+                    "action": torch.tensor([1], dtype=torch.int32),
+                    "collect_train_iter": torch.tensor([1], dtype=torch.int32),
+                    "env_data_id": torch.tensor([1], dtype=torch.int32),
+                }
+            ) for _ in range(10)
+        ]
+    }
+    buf = loader._create_shm_buffer(obj)
+    payload = RecvPayload(proc_id=0, data=obj)
+    loader._shm_callback(payload=payload, shm_obj=buf)
+
+    def stmt():
+        payload.extra = buf.id_.get()
+        loader._shm_putback(payload=payload, shm_obj=buf)
+
+    res = timeit.repeat(stmt, repeat=5, number=1000)
+    print("Mean: {:.4f}s, STD: {:.4f}s, Mean each call: {:.4f}ms".format(np.mean(res), np.std(res), np.mean(res)))
+    assert np.mean(res) < 1
diff --git a/ding/entry/cli_ditask.py b/ding/entry/cli_ditask.py
index 68ec836fe6..443fe1a6b6 100644
--- a/ding/entry/cli_ditask.py
+++ b/ding/entry/cli_ditask.py
@@ -43,8 +43,7 @@ def print_version(ctx: Context, param: Option, value: bool) -> None:
 @click.option(
     "--ports",
     type=str,
-    default="50515",
-    help="The port addresses that the tasks listen to, e.g. 50515,50516, default: 50515"
+    help="The port addresses that the tasks listen to, e.g. 50515,50516; defaults: k8s/local: 50515, slurm: 15151"
 )
 @click.option("--attach-to", type=str, help="The addresses to connect to.")
 @click.option("--address", type=str, help="The address to listen to (without port).")
@@ -62,6 +61,8 @@ def print_version(ctx: Context, param: Option, value: bool) -> None:
 @click.option("--redis-host", type=str, help="Redis host.")
 @click.option("--redis-port", type=int, help="Redis port.")
 @click.option("-m", "--main", type=str, help="Main function of entry module.")
+@click.option("--startup-interval", type=int, default=1, help="Start up interval between each task.")
+@click.option("--local_rank", type=int, default=0, help="Compatibility with PyTorch DDP")
 def cli_ditask(*args, **kwargs):
     return _cli_ditask(*args, **kwargs)
@@ -86,7 +87,7 @@ def _parse_platform_args(platform: str, platform_spec: str, all_args: dict):
         parsed_args = PLATFORM_PARSERS[platform](platform_spec, **all_args)
     except Exception as e:
         click.echo("error when parse platform spec configure: {}".format(e))
-        exit(1)
+        raise e
     return parsed_args
@@ -105,6 +106,8 @@ def _cli_ditask(
     mq_type: str,
     redis_host: str,
     redis_port: int,
+    startup_interval: int,
+    local_rank: int = 0,
     platform: str = None,
     platform_spec: str = None,
 ):
@@ -128,6 +131,7 @@ def _cli_ditask(
     mod = importlib.import_module(mod_name)
     main_func = getattr(mod, func_name)
     # Parse arguments
+    ports = ports or 50515
     if not isinstance(ports, int):
         ports = ports.split(",")
         ports = list(map(lambda i: int(i), ports))
@@ -152,5 +156,6 @@ def _cli_ditask(
         node_ids=node_ids,
         mq_type=mq_type,
         redis_host=redis_host,
-        redis_port=redis_port
+        redis_port=redis_port,
+        startup_interval=startup_interval
     )(main_func)
diff --git a/ding/entry/cli_parsers/k8s_parser.py b/ding/entry/cli_parsers/k8s_parser.py
index 3767ef2879..6f2b0aebe7 100644
--- a/ding/entry/cli_parsers/k8s_parser.py
+++ b/ding/entry/cli_parsers/k8s_parser.py
@@ -1,11 +1,12 @@
 import os
 import numpy as np
-from typing import List, Optional
+from time import sleep
+from typing import Dict, List, Optional
 
 
 class K8SParser():
 
-    def __init__(self, platform_spec: Optional[str] = None, **kwargs) -> None:
+    def __init__(self, platform_spec: Optional[Dict] = None, **kwargs) -> None:
         """
         Overview:
             Should only set global cluster properties
@@ -14,9 +15,9 @@ def __init__(self, platform_spec: Optional[str] = None, **kwargs) -> None:
         self.nodelist = self._parse_node_list()
         self.ntasks = len(self.nodelist)
         self.platform_spec = platform_spec
-        self.parallel_workers = kwargs.get("parallel_workers", 1)
-        self.topology = kwargs.get("topology", "alone")
-        self.ports = kwargs.get("ports", 50515)
+        self.parallel_workers = kwargs.get("parallel_workers") or 1
+        self.topology = kwargs.get("topology") or "alone"
+        self.ports = int(kwargs.get("ports") or 50515)
         self.tasks = {}
 
     def parse(self) -> dict:
@@ -49,13 +50,13 @@ def _get_task(self, procid: int) -> dict:
         else:
             task = {}
         if "ports" not in task:
-            task["ports"] = self._get_ports()
+            task["ports"] = self.kwargs.get("ports") or self._get_ports()
         if "address" not in task:
-            task["address"] = self._get_address(procid)
+            task["address"] = self.kwargs.get("address") or self._get_address(procid)
         if "node_ids" not in task:
-            task["node_ids"] = self._get_node_id(procid)
+            task["node_ids"] = self.kwargs.get("node_ids") or self._get_node_id(procid)
 
-        task["attach_to"] = self._get_attach_to(procid, task.get("attach_to"))
+        task["attach_to"] = self.kwargs.get("attach_to") or self._get_attach_to(procid, task.get("attach_to"))
         task["topology"] = self.topology
task["parallel_workers"] = self.parallel_workers diff --git a/ding/entry/cli_parsers/slurm_parser.py b/ding/entry/cli_parsers/slurm_parser.py index 3a335eb758..c46716438b 100644 --- a/ding/entry/cli_parsers/slurm_parser.py +++ b/ding/entry/cli_parsers/slurm_parser.py @@ -1,30 +1,55 @@ import os import re -from typing import List +from time import sleep +import numpy as np +from typing import Any, Dict, List, Optional class SlurmParser(): - def __init__(self, platform_spec: str, **kwargs) -> None: + def __init__(self, platform_spec: Optional[Dict] = None, **kwargs) -> None: """ Overview: Should only set global cluster properties """ self.kwargs = kwargs self.ntasks = int(os.environ["SLURM_NTASKS"]) - self.tasks = platform_spec["tasks"] + self.platform_spec = platform_spec + self.tasks = {} self.ntasks_per_node = int(os.environ["SLURM_NTASKS_PER_NODE"]) self.nodelist = self._parse_node_list() + self.ports = int(kwargs.get("ports") or 15151) + self.parallel_workers = kwargs.get("parallel_workers") or 1 + self.topology = kwargs.get("topology") or "alone" def parse(self) -> dict: - assert len(self.tasks) == self.ntasks procid = int(os.environ["SLURM_PROCID"]) - nodename = os.environ["SLURMD_NODENAME"] - task = self._get_node_args(procid) + task = self._get_task(procid) # Validation - assert task["address"] == nodename + assert task["address"] == os.environ["SLURMD_NODENAME"] return {**self.kwargs, **task} + def _get_task(self, procid: int) -> Dict[str, Any]: + if procid in self.tasks: + return self.tasks.get(procid) + if self.platform_spec: + task = self.platform_spec["tasks"][procid] + else: + task = {} + if "ports" not in task: + task["ports"] = self._get_ports(procid) + if "address" not in task: + task["address"] = self._get_address(procid) + if "node_ids" not in task: + task["node_ids"] = self._get_node_id(procid) + + task["attach_to"] = self._get_attach_to(procid, task.get("attach_to")) + task["topology"] = self.topology + task["parallel_workers"] = self.parallel_workers + + self.tasks[procid] = task + return task + def _parse_node_list(self) -> List[str]: nodelist = os.environ["SLURM_NODELIST"] result = re.match(r"(.*)?\[(.*)\]$", nodelist) @@ -40,58 +65,86 @@ def _parse_node_list(self) -> List[str]: nodelist.append(prefix + tail) elif isinstance(nodelist, str): nodelist = [nodelist] + if self.ntasks_per_node > 1: + expand_nodelist = [] # Expand node for each task + for node in nodelist: + for _ in range(self.ntasks_per_node): + expand_nodelist.append(node) + nodelist = expand_nodelist return nodelist - def _get_node_args(self, procid: int) -> dict: - """ - Overview: - Complete node properties, use environment vars in list instead of on current node. - For example, if you want to set nodename in this function, please derive it from SLURM_NODELIST, - the variable from SLURMD_NODENAME should only be used in validation. 
- """ - task = self.tasks[procid] - if "attach_to" in task: - task["attach_to"] = self._get_attach_to(task["attach_to"]) - if "address" not in task: - task["address"] = self._get_address(procid) - if "ports" not in task: - task["ports"] = self._get_ports(procid) - if "node_ids" not in task: - task["node_ids"] = procid - return task + def _get_attach_to(self, procid: int, attach_to: Optional[str] = None) -> str: + if attach_to: + attach_to = [self._get_attach_to_part(part) for part in attach_to.split(",")] + elif procid == 0: + attach_to = [] + else: + if self.topology == "mesh": + prev_tasks = [self._get_task(i) for i in range(procid)] + attach_to = [self._get_attach_to_from_task(task) for task in prev_tasks] + attach_to = list(np.concatenate(attach_to)) + elif self.topology == "star": + head_task = self._get_task(0) + attach_to = self._get_attach_to_from_task(head_task) + else: + attach_to = [] - def _get_attach_to(self, attach_to: str) -> str: - attach_to = [self._get_attach_to_part(part) for part in attach_to.split(",")] return ",".join(attach_to) def _get_attach_to_part(self, attach_part: str) -> str: + """ + Overview: + Parse each part of attach_to. + Arguments: + - attach_part (:obj:`str`): The attach_to field with specific pattern, e.g. $node:0 + Returns + - attach_to (:obj:`str`): The real address, e.g. tcp://SH-0:50000 + """ if not attach_part.startswith("$node."): return attach_part attach_node_id = int(attach_part[6:]) - attach_node = self._get_node_args(self._get_procid_from_nodeid(attach_node_id)) - return "tcp://{}:{}".format(attach_node["address"], attach_node["ports"]) + attach_task = self._get_task(self._get_procid_from_nodeid(attach_node_id)) + return self._get_tcp_link(attach_task["address"], attach_task["ports"]) + + def _get_attach_to_from_task(self, task: dict) -> List[str]: + """ + Overview: + Get attach_to list from task, note that parallel_workers will affact the connected processes. + Arguments: + - task (:obj:`dict`): The task object. + Returns + - attach_to (:obj:`str`): The real address, e.g. 
tcp://SH-0:50000 + """ + port = task.get("ports") + address = task.get("address") + ports = [int(port) + i for i in range(self.parallel_workers)] + attach_to = [self._get_tcp_link(address, port) for port in ports] + return attach_to def _get_procid_from_nodeid(self, nodeid: int) -> int: procid = None - for i, task in enumerate(self.tasks): - if task.get("node_ids") == nodeid: - procid = i - break - elif nodeid == i: + for i in range(self.ntasks): + task = self._get_task(i) + if task["node_ids"] == nodeid: procid = i break if procid is None: raise Exception("Can not find procid from nodeid: {}".format(nodeid)) return procid - def _get_ports(self, procid: int) -> List[int]: - ports = 15151 + procid % self.ntasks_per_node - return ports + def _get_ports(self, procid) -> int: + return self.ports + (procid % self.ntasks_per_node) * self.parallel_workers def _get_address(self, procid: int) -> str: - address = self.nodelist[procid // self.ntasks_per_node] + address = self.nodelist[procid] return address + def _get_node_id(self, procid: int) -> int: + return procid * self.parallel_workers + + def _get_tcp_link(self, address: str, port: int) -> str: + return "tcp://{}:{}".format(address, port) + def slurm_parser(platform_spec: str, **kwargs) -> dict: return SlurmParser(platform_spec, **kwargs).parse() diff --git a/ding/entry/cli_parsers/tests/test_slurm_parser.py b/ding/entry/cli_parsers/tests/test_slurm_parser.py index f56efdd663..9b817ba48a 100644 --- a/ding/entry/cli_parsers/tests/test_slurm_parser.py +++ b/ding/entry/cli_parsers/tests/test_slurm_parser.py @@ -10,12 +10,7 @@ def set_slurm_env(): os.environ["SLURM_NTASKS"] = '6' # Parameter n,Process count / Task count os.environ["SLURM_NTASKS_PER_NODE"] = '3' # Parameter ntasks-per-node,process count of each node os.environ["SLURM_NODELIST"] = 'SH-IDC1-10-5-38-[190,215]' # All the nodes - os.environ["SLURM_SRUN_COMM_PORT"] = '42932' # Available ports - os.environ["SLURM_TOPOLOGY_ADDR"] = 'SH-IDC1-10-5-38-215' # Name of current node - os.environ["SLURM_NODEID"] = '1' # Node order,start from 0 os.environ["SLURM_PROCID"] = '3' # Proc order,start from 0,the read proc order may be different from nominal order - os.environ["SLURM_LOCALID"] = '0' # Proc order on current node, smaller or equal than ntasks-per-node - 1 - os.environ["SLURM_GTIDS"] = '2,3' # All the proc ids on current node os.environ["SLURMD_NODENAME"] = 'SH-IDC1-10-5-38-215' # Name of current node yield @@ -23,12 +18,7 @@ def set_slurm_env(): del os.environ["SLURM_NTASKS"] del os.environ["SLURM_NTASKS_PER_NODE"] del os.environ["SLURM_NODELIST"] - del os.environ["SLURM_SRUN_COMM_PORT"] - del os.environ["SLURM_TOPOLOGY_ADDR"] - del os.environ["SLURM_NODEID"] del os.environ["SLURM_PROCID"] - del os.environ["SLURM_LOCALID"] - del os.environ["SLURM_GTIDS"] del os.environ["SLURMD_NODENAME"] @@ -73,8 +63,22 @@ def test_slurm_parser(): "tcp://SH-IDC1-10-5-38-190:15152," +\ "tcp://SH-IDC1-10-5-38-190:15153" + # Test without platform_spec + all_args = slurm_parser(None, topology="mesh", mq_type="nng") + assert all_args["address"] == "SH-IDC1-10-5-38-215" + assert all_args["node_ids"] == 3 + assert all_args["parallel_workers"] == 1 + assert all_args[ + "attach_to" + ] == "tcp://SH-IDC1-10-5-38-190:15151," +\ + "tcp://SH-IDC1-10-5-38-190:15152," +\ + "tcp://SH-IDC1-10-5-38-190:15153" + # Test _parse_node_list sp = SlurmParser(platform_spec) os.environ["SLURM_NODELIST"] = 'SH-IDC1-10-5-[38-40]' - nodelist = sp._parse_node_list() - assert nodelist == ['SH-IDC1-10-5-38', 'SH-IDC1-10-5-39', 
'SH-IDC1-10-5-40'] + nodelist = sp._parse_node_list() # Nodes * parallel_workers + assert nodelist == [ + 'SH-IDC1-10-5-38', 'SH-IDC1-10-5-38', 'SH-IDC1-10-5-38', 'SH-IDC1-10-5-39', 'SH-IDC1-10-5-39', + 'SH-IDC1-10-5-39', 'SH-IDC1-10-5-40', 'SH-IDC1-10-5-40', 'SH-IDC1-10-5-40' + ] diff --git a/ding/entry/tests/test_cli_ditask.py b/ding/entry/tests/test_cli_ditask.py index 66cd37e906..6bb64e5e6e 100644 --- a/ding/entry/tests/test_cli_ditask.py +++ b/ding/entry/tests/test_cli_ditask.py @@ -25,7 +25,8 @@ def test_cli_ditask(): "node_ids": 0, "mq_type": "nng", "redis_host": "", - "redis_port": "" + "redis_port": "", + "startup_interval": 1 } os.environ["DI_NODES"] = '127.0.0.1' os.environ["DI_RANK"] = '0' diff --git a/ding/envs/env_manager/base_env_manager.py b/ding/envs/env_manager/base_env_manager.py index 47dfdd5732..6dec958194 100644 --- a/ding/envs/env_manager/base_env_manager.py +++ b/ding/envs/env_manager/base_env_manager.py @@ -122,16 +122,6 @@ def __init__( self._action_space = self._env_ref.action_space self._reward_space = self._env_ref.reward_space self._env_ref.close() - try: - global space_log_flag - if space_log_flag: - logging.info("Env Space Information:") - logging.info("\tObservation Space: {}".format(self._observation_space)) - logging.info("\tAction Space: {}".format(self._action_space)) - logging.info("\tReward Space: {}".format(self._reward_space)) - space_log_flag = False - except: - pass self._env_states = {i: EnvState.VOID for i in range(self._env_num)} self._env_seed = {i: None for i in range(self._env_num)} self._episode_num = self._cfg.episode_num @@ -240,6 +230,16 @@ def launch(self, reset_param: Optional[Dict] = None) -> None: value is the cooresponding reset parameters. """ assert self._closed, "Please first close the env manager" + try: + global space_log_flag + if space_log_flag: + logging.info("Env Space Information:") + logging.info("\tObservation Space: {}".format(self._observation_space)) + logging.info("\tAction Space: {}".format(self._action_space)) + logging.info("\tReward Space: {}".format(self._reward_space)) + space_log_flag = False + except: + pass if reset_param is not None: assert len(reset_param) == len(self._env_fn) self._create_state() @@ -475,6 +475,8 @@ def ready_obs(self) -> tnp.array: """ active_env = [i for i, s in self._env_states.items() if s == EnvState.RUN] obs = [self._ready_obs[i] for i in active_env] + if isinstance(obs[0], dict): + obs = [tnp.array(o) for o in obs] return tnp.stack(obs) def step(self, actions: List[tnp.ndarray]) -> List[tnp.ndarray]: diff --git a/ding/envs/env_manager/env_supervisor.py b/ding/envs/env_manager/env_supervisor.py index b3e1c17fcf..ec5e29beab 100644 --- a/ding/envs/env_manager/env_supervisor.py +++ b/ding/envs/env_manager/env_supervisor.py @@ -5,10 +5,10 @@ import gym from ding.framework import Supervisor from typing import TYPE_CHECKING, Any, List, Union, Dict, Optional, Callable -from ding.framework.supervisor import ChildType, RecvPayload, SendPayload, SharedObject +from ding.framework.supervisor import ChildType, RecvPayload, SendPayload from ding.utils import make_key_as_identifier from ditk import logging -from ding.envs.env_manager.subprocess_env_manager import ShmBufferContainer +from ding.data import ShmBufferContainer import enum import treetensor.numpy as tnp import numbers @@ -106,9 +106,7 @@ def __init__( for env_id in range(len(self._env_fn)) } for env_init in env_fn: - self.register( - env_init, shared_object=SharedObject(buf=self._obs_buffers, callback=self._shm_callback) - ) + 
                self.register(env_init, shm_buffer=self._obs_buffers, shm_callback=self._shm_callback)
         else:
             for env_init in env_fn:
                 self.register(env_init)
@@ -136,6 +134,11 @@ def _init_states(self):
         self._last_called = defaultdict(lambda: {"step": math.inf, "reset": math.inf})
 
     def _shm_callback(self, payload: RecvPayload, obs_buffers: Any):
+        """
+        Overview:
+            This method is called in the child worker, so we can put large data into shared memory
+            and replace the original payload data with None, reducing the serialization/deserialization cost.
+        """
         if payload.method == "reset" and payload.data is not None:
             obs_buffers[payload.proc_id].fill(payload.data)
             payload.data = None
diff --git a/ding/envs/env_manager/subprocess_env_manager.py b/ding/envs/env_manager/subprocess_env_manager.py
index 627b8be51b..e3bf00d4b1 100644
--- a/ding/envs/env_manager/subprocess_env_manager.py
+++ b/ding/envs/env_manager/subprocess_env_manager.py
@@ -1,5 +1,5 @@
 from typing import Any, Union, List, Tuple, Dict, Callable, Optional
-from multiprocessing import Pipe, connection, get_context, Array
+from multiprocessing import connection, get_context
 from collections import namedtuple
 from ditk import logging
 import platform
@@ -8,33 +8,19 @@
 import gym
 import traceback
 import torch
-import ctypes
 import pickle
 import cloudpickle
 import numpy as np
 import treetensor.numpy as tnp
 from easydict import EasyDict
 from types import MethodType
+from ding.data import ShmBufferContainer, ShmBuffer
 from ding.envs.env import BaseEnvTimestep
 from ding.utils import PropagatingThread, LockContextType, LockContext, ENV_MANAGER_REGISTRY, make_key_as_identifier, \
     remove_illegal_item
 from .base_env_manager import BaseEnvManager, EnvState, timeout_wrapper
 
-_NTYPE_TO_CTYPE = {
-    np.bool_: ctypes.c_bool,
-    np.uint8: ctypes.c_uint8,
-    np.uint16: ctypes.c_uint16,
-    np.uint32: ctypes.c_uint32,
-    np.uint64: ctypes.c_uint64,
-    np.int8: ctypes.c_int8,
-    np.int16: ctypes.c_int16,
-    np.int32: ctypes.c_int32,
-    np.int64: ctypes.c_int64,
-    np.float32: ctypes.c_float,
-    np.float64: ctypes.c_double,
-}
-
 
 def is_abnormal_timestep(timestep: namedtuple) -> bool:
     if isinstance(timestep.info, dict):
@@ -45,110 +31,6 @@ def is_abnormal_timestep(timestep: namedtuple) -> bool:
         raise TypeError("invalid env timestep type: {}".format(type(timestep.info)))
 
 
-class ShmBuffer():
-    """
-    Overview:
-        Shared memory buffer to store numpy array.
-    """
-
-    def __init__(self, dtype: Union[type, np.dtype], shape: Tuple[int], copy_on_get: bool = True) -> None:
-        """
-        Overview:
-            Initialize the buffer.
-        Arguments:
-            - dtype (:obj:`Union[type, np.dtype]`): The dtype of the data to limit the size of the buffer.
-            - shape (:obj:`Tuple[int]`): The shape of the data to limit the size of the buffer.
-            - copy_on_get (:obj:`bool`): Whether to copy data when calling get method.
-        """
-        if isinstance(dtype, np.dtype):  # it is type of gym.spaces.dtype
-            dtype = dtype.type
-        self.buffer = Array(_NTYPE_TO_CTYPE[dtype], int(np.prod(shape)))
-        self.dtype = dtype
-        self.shape = shape
-        self.copy_on_get = copy_on_get
-
-    def fill(self, src_arr: np.ndarray) -> None:
-        """
-        Overview:
-            Fill the shared memory buffer with a numpy array. (Replace the original one.)
-        Arguments:
-            - src_arr (:obj:`np.ndarray`): array to fill the buffer.
- """ - assert isinstance(src_arr, np.ndarray), type(src_arr) - # for np.array with shape (4, 84, 84) and float32 dtype, reshape is 15~20x faster than flatten - # for np.array with shape (4, 84, 84) and uint8 dtype, reshape is 5~7x faster than flatten - # so we reshape dst_arr rather than flatten src_arr - dst_arr = np.frombuffer(self.buffer.get_obj(), dtype=self.dtype).reshape(self.shape) - np.copyto(dst_arr, src_arr) - - def get(self) -> np.ndarray: - """ - Overview: - Get the array stored in the buffer. - Return: - - data (:obj:`np.ndarray`): A copy of the data stored in the buffer. - """ - data = np.frombuffer(self.buffer.get_obj(), dtype=self.dtype).reshape(self.shape) - if self.copy_on_get: - data = data.copy() # must use np.copy, torch.from_numpy and torch.as_tensor still use the same memory - return data - - -class ShmBufferContainer(object): - """ - Overview: - Support multiple shared memory buffers. Each key-value is name-buffer. - """ - - def __init__( - self, - dtype: Union[Dict[Any, type], type, np.dtype], - shape: Union[Dict[Any, tuple], tuple], - copy_on_get: bool = True - ) -> None: - """ - Overview: - Initialize the buffer container. - Arguments: - - dtype (:obj:`Union[type, np.dtype]`): The dtype of the data to limit the size of the buffer. - - shape (:obj:`Union[Dict[Any, tuple], tuple]`): If `Dict[Any, tuple]`, use a dict to manage \ - multiple buffers; If `tuple`, use single buffer. - - copy_on_get (:obj:`bool`): Whether to copy data when calling get method. - """ - if isinstance(shape, dict): - self._data = {k: ShmBufferContainer(dtype[k], v, copy_on_get) for k, v in shape.items()} - elif isinstance(shape, (tuple, list)): - self._data = ShmBuffer(dtype, shape, copy_on_get) - else: - raise RuntimeError("not support shape: {}".format(shape)) - self._shape = shape - - def fill(self, src_arr: Union[Dict[Any, np.ndarray], np.ndarray]) -> None: - """ - Overview: - Fill the one or many shared memory buffer. - Arguments: - - src_arr (:obj:`Union[Dict[Any, np.ndarray], np.ndarray]`): array to fill the buffer. - """ - if isinstance(self._shape, dict): - for k in self._shape.keys(): - self._data[k].fill(src_arr[k]) - elif isinstance(self._shape, (tuple, list)): - self._data.fill(src_arr) - - def get(self) -> Union[Dict[Any, np.ndarray], np.ndarray]: - """ - Overview: - Get the one or many arrays stored in the buffer. - Return: - - data (:obj:`np.ndarray`): The array(s) stored in the buffer. 
- """ - if isinstance(self._shape, dict): - return {k: self._data[k].get() for k in self._shape.keys()} - elif isinstance(self._shape, (tuple, list)): - return self._data.get() - - class CloudPickleWrapper: """ Overview: diff --git a/ding/example/__init__.py b/ding/example/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/ding/example/cql.py b/ding/example/cql.py index 5651121d8d..1e1c678dd0 100644 --- a/ding/example/cql.py +++ b/ding/example/cql.py @@ -5,9 +5,9 @@ from ding.envs import DingEnvWrapper, BaseEnvManagerV2 from ding.data import create_dataset from ding.config import compile_config -from ding.framework import task +from ding.framework import task, ding_init from ding.framework.context import OfflineRLContext -from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher +from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher, offline_logger from ding.utils import set_pkg_seed from dizoo.classic_control.pendulum.envs.pendulum_env import PendulumEnv from dizoo.classic_control.pendulum.config.pendulum_cql_config import main_config, create_config @@ -18,6 +18,7 @@ def main(): # For demostration, we also can train a RL policy (e.g. SAC) and collect some data logging.getLogger().setLevel(logging.INFO) cfg = compile_config(main_config, create_cfg=create_config, auto=True) + ding_init(cfg) with task.start(async_mode=False, ctx=OfflineRLContext()): evaluator_env = BaseEnvManagerV2( env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager @@ -33,6 +34,7 @@ def main(): task.use(offline_data_fetcher(cfg, dataset)) task.use(trainer(cfg, policy.learn_mode)) task.use(CkptSaver(cfg, policy, train_freq=100)) + task.use(offline_logger()) task.run() diff --git a/ding/example/dqn.py b/ding/example/dqn.py index 27947914d5..c3670def0a 100644 --- a/ding/example/dqn.py +++ b/ding/example/dqn.py @@ -1,5 +1,43 @@ +""" +# Example of DQN pipeline + +Use the pipeline on a single process: + +> python3 -u ding/example/dqn.py + +Use the pipeline on multiple processes: + +We surpose there are N processes (workers) = 1 learner + 1 evaluator + (N-2) collectors + +## First Example —— Execute on one machine with multi processes. + +Execute 4 processes with 1 learner + 1 evaluator + 2 collectors +Remember to keep them connected by mesh to ensure that they can exchange information with each other. + +> ditask --package . --main ding.example.dqn.main --parallel-workers 4 --topology mesh + +## Second Example —— Execute on multiple machines. + +1. Execute 1 learner + 1 evaluator on one machine. + +> ditask --package . --main ding.example.dqn.main --parallel-workers 2 --topology mesh --node-ids 0 --ports 50515 + +2. Execute 2 collectors on another machine. (Suppose the ip of the first machine is 127.0.0.1). + Here we use `alone` topology instead of `mesh` because the collectors do not need communicate with each other. + Remember the `node_ids` cannot be duplicated with the learner, evaluator processes. + And remember to set the `ports` (should not conflict with others) and `attach_to` parameters. + The value of the `attach_to` parameter should be obtained from the log of the + process started earlier (e.g. 'NNG listen on tcp://10.0.0.4:50515'). + +> ditask --package . --main ding.example.dqn.main --parallel-workers 2 --topology alone --node-ids 2 \ + --ports 50517 --attach-to tcp://10.0.0.4:50515,tcp://127.0.0.1:50516 + +3. 
+"""
 import gym
 from ditk import logging
+from ding.data.model_loader import FileModelLoader
+from ding.data.storage_loader import FileStorageLoader
 from ding.model import DQN
 from ding.policy import DQNPolicy
 from ding.envs import DingEnvWrapper, BaseEnvManagerV2
@@ -8,14 +46,14 @@
 from ding.framework import task, ding_init
 from ding.framework.context import OnlineRLContext
 from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
-    eps_greedy_handler, CkptSaver, online_logger
+    eps_greedy_handler, CkptSaver, ContextExchanger, ModelExchanger, online_logger
 from ding.utils import set_pkg_seed
 from dizoo.classic_control.cartpole.config.cartpole_dqn_config import main_config, create_config
 
 
 def main():
     logging.getLogger().setLevel(logging.INFO)
-    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+    cfg = compile_config(main_config, create_cfg=create_config, auto=True, save_cfg=task.router.node_id == 0)
     ding_init(cfg)
     with task.start(async_mode=False, ctx=OnlineRLContext()):
         collector_env = BaseEnvManagerV2(
@@ -33,6 +71,22 @@ def main():
         buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
         policy = DQNPolicy(cfg.policy, model=model)
 
+        # Consider the case with multiple processes
+        if task.router.is_active:
+            # You can use labels to distinguish between workers with different roles;
+            # here we use node_id to distinguish them.
+            if task.router.node_id == 0:
+                task.add_role(task.role.LEARNER)
+            elif task.router.node_id == 1:
+                task.add_role(task.role.EVALUATOR)
+            else:
+                task.add_role(task.role.COLLECTOR)
+
+            # Sync their context and model between each worker.
+            task.use(ContextExchanger(skip_n_iter=1))
+            task.use(ModelExchanger(model))
+
+        # Here is the part of the single-process pipeline.
         task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
         task.use(eps_greedy_handler(cfg))
         task.use(StepCollector(cfg, policy.collect_mode, collector_env))
@@ -40,6 +94,7 @@ def main():
         task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
         task.use(online_logger(train_show_freq=10))
         task.use(CkptSaver(cfg, policy, train_freq=100))
+
         task.run()
diff --git a/ding/example/sac.py b/ding/example/sac.py
index 3b45129029..8abb4ce1a5 100644
--- a/ding/example/sac.py
+++ b/ding/example/sac.py
@@ -4,10 +4,10 @@
 from ding.envs import BaseEnvManagerV2
 from ding.data import DequeBuffer
 from ding.config import compile_config
-from ding.framework import task
+from ding.framework import task, ding_init
 from ding.framework.context import OnlineRLContext
 from ding.framework.middleware import data_pusher, StepCollector, interaction_evaluator, \
-    CkptSaver, OffPolicyLearner, termination_checker
+    CkptSaver, OffPolicyLearner, termination_checker, online_logger
 from ding.utils import set_pkg_seed
 from dizoo.classic_control.pendulum.envs.pendulum_env import PendulumEnv
 from dizoo.classic_control.pendulum.config.pendulum_sac_config import main_config, create_config
@@ -16,6 +16,7 @@ def main():
     logging.getLogger().setLevel(logging.INFO)
     cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+    ding_init(cfg)
     with task.start(async_mode=False, ctx=OnlineRLContext()):
         collector_env = BaseEnvManagerV2(
             env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
@@ -38,6 +39,7 @@ def main():
         task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
         task.use(CkptSaver(cfg, policy, train_freq=100))
         task.use(termination_checker(max_train_iter=10000))
+        task.use(online_logger())
         task.run()
diff --git a/ding/framework/__init__.py b/ding/framework/__init__.py
index 274d6f2364..72c23d0475 100644
--- a/ding/framework/__init__.py
+++ b/ding/framework/__init__.py
@@ -1,5 +1,5 @@
 from .context import Context, OnlineRLContext, OfflineRLContext
-from .task import Task, task
+from .task import Task, task, VoidMiddleware
 from .parallel import Parallel
 from .event_loop import EventLoop
 from .supervisor import Supervisor
diff --git a/ding/framework/context.py b/ding/framework/context.py
index a22c1616d4..1dbe998b49 100644
--- a/ding/framework/context.py
+++ b/ding/framework/context.py
@@ -11,6 +11,7 @@ class Context:
     Context is an object that passes contextual data between middlewares, whose life cycle
     is only one training iteration. It is a dict that reflects itself, so you can set
     any properties as you wish.
+    Note that the initial value of each property must be equal to False.
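+    For example, ``total_step: int = 0`` and the newly added ``action: Dict = None`` both
+    start from values that evaluate to False.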
""" _kept_keys: set = dataclasses.field(default_factory=set) total_step: int = 0 @@ -59,6 +60,8 @@ class OnlineRLContext(Context): trajectories: List = None episodes: List = None trajectory_end_idx: List = dataclasses.field(default_factory=list) + action: Dict = None + inference_output: Dict = None # eval eval_value: float = -np.inf last_eval_iter: int = -1 diff --git a/ding/framework/event_loop.py b/ding/framework/event_loop.py index b5f58720a1..6641d07adb 100644 --- a/ding/framework/event_loop.py +++ b/ding/framework/event_loop.py @@ -1,6 +1,7 @@ from collections import defaultdict from typing import Callable, Optional from concurrent.futures import ThreadPoolExecutor +from copy import copy import fnmatch from ditk import logging @@ -35,7 +36,10 @@ def off(self, event: str, fn: Optional[Callable] = None) -> None: """ for e in fnmatch.filter(self._listeners.keys(), event): if fn: - self._listeners[e].remove(fn) + try: + self._listeners[e].remove(fn) + except: + pass else: self._listeners[e] = [] @@ -79,7 +83,7 @@ def _trigger(self, event: str, *args, **kwargs) -> None: if event not in self._listeners: logging.debug("Event {} is not registered in the callbacks of {}!".format(event, self._name)) return - for fn in self._listeners[event]: + for fn in copy(self._listeners[event]): try: fn(*args, **kwargs) except Exception as e: diff --git a/ding/framework/message_queue/nng.py b/ding/framework/message_queue/nng.py index c0b52f7090..379601b0ed 100644 --- a/ding/framework/message_queue/nng.py +++ b/ding/framework/message_queue/nng.py @@ -30,6 +30,7 @@ def listen(self) -> None: sleep(0.1) # Wait for peers to bind for contact in self.attach_to: sock.dial(contact) + logging.info("NNG listen on {}, attach to {}".format(self.listen_to, self.attach_to)) self._running = True def publish(self, topic: str, data: bytes) -> None: diff --git a/ding/framework/middleware/__init__.py b/ding/framework/middleware/__init__.py index a2c428932c..558d9affa3 100644 --- a/ding/framework/middleware/__init__.py +++ b/ding/framework/middleware/__init__.py @@ -2,3 +2,4 @@ from .collector import StepCollector, EpisodeCollector from .learner import OffPolicyLearner, HERLearner from .ckpt_handler import CkptSaver +from .distributer import ContextExchanger, ModelExchanger diff --git a/ding/framework/middleware/ckpt_handler.py b/ding/framework/middleware/ckpt_handler.py index eaec7243d6..c23cf58b6d 100644 --- a/ding/framework/middleware/ckpt_handler.py +++ b/ding/framework/middleware/ckpt_handler.py @@ -17,7 +17,12 @@ class CkptSaver: The class used to save checkpoint data. """ - def __init__(self, cfg: EasyDict, policy: Policy, train_freq: Optional[int] = None): + def __new__(cls, *args, **kwargs): + if task.router.is_active and not (task.has_role(task.role.LEARNER) or task.has_role(task.role.EVALUATOR)): + return task.void() + return super(CkptSaver, cls).__new__(cls) + + def __init__(self, cfg: EasyDict, policy: Policy, train_freq: Optional[int] = None, save_finish: bool = True): """ Overview: Initialize the `CkptSaver`. @@ -25,6 +30,7 @@ def __init__(self, cfg: EasyDict, policy: Policy, train_freq: Optional[int] = No - cfg (:obj:`EasyDict`): Config which should contain the following keys: `cfg.exp_name`. - policy (:obj:`Policy`): Policy used to save the checkpoint. - train_freq (:obj:`int`): Number of training iterations between each saving checkpoint data. + - save_finish (:obj:`bool`): Whether save final ckpt when ``task.finish = True``. 
""" self.policy = policy self.train_freq = train_freq @@ -33,6 +39,7 @@ def __init__(self, cfg: EasyDict, policy: Policy, train_freq: Optional[int] = No os.makedirs(self.prefix) self.last_save_iter = 0 self.max_eval_value = -np.inf + self.save_finish = save_finish def __call__(self, ctx: Union["OnlineRLContext", "OfflineRLContext"]) -> None: """ @@ -54,10 +61,10 @@ def __call__(self, ctx: Union["OnlineRLContext", "OfflineRLContext"]) -> None: self.last_save_iter = ctx.train_iter # best episode return so far - if ctx.eval_value > self.max_eval_value: - save_file("{}/eval.pth.tar".format(self.prefix), self.policy.learn_mode.state_dict()) + if ctx.eval_value is not None and ctx.eval_value > self.max_eval_value: + save_file("{}/eval.pth.tar".format(self.prefix), self.policy.eval_mode.state_dict()) self.max_eval_value = ctx.eval_value # finish - if task.finish: + if task.finish and self.save_finish: save_file("{}/final.pth.tar".format(self.prefix), self.policy.learn_mode.state_dict()) diff --git a/ding/framework/middleware/collector.py b/ding/framework/middleware/collector.py index fa70a00766..6660025b33 100644 --- a/ding/framework/middleware/collector.py +++ b/ding/framework/middleware/collector.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Callable, List +from typing import TYPE_CHECKING from easydict import EasyDict from ding.policy import get_random_policy @@ -17,6 +17,11 @@ class StepCollector: process. Use the `__call__` method to execute the whole collection process. """ + def __new__(cls, *args, **kwargs): + if task.router.is_active and not task.has_role(task.role.COLLECTOR): + return task.void() + return super(StepCollector, cls).__new__(cls) + def __init__(self, cfg: EasyDict, policy, env: BaseEnvManager, random_collect_size: int = 0) -> None: """ Arguments: diff --git a/ding/framework/middleware/distributer.py b/ding/framework/middleware/distributer.py new file mode 100644 index 0000000000..c68a4b808f --- /dev/null +++ b/ding/framework/middleware/distributer.py @@ -0,0 +1,289 @@ +from time import sleep, time +from dataclasses import fields +from typing import TYPE_CHECKING, List, Dict, Any, Optional, Union +from ditk import logging +from ding.framework import task +from ding.data import StorageLoader, Storage, ModelLoader +if TYPE_CHECKING: + from ding.framework.context import Context + from torch.nn import Module + + +class ContextExchanger: + + def __init__(self, skip_n_iter: int = 1, storage_loader: Optional[StorageLoader] = None) -> None: + """ + Overview: + Exchange context between processes, + support properties: trajectories, episodes, env_step, env_episode, train_iter + Arguments: + - skip_n_iter (:obj:`int`): For collectors, it may be necessary to skip waiting \ + for the first n iterations to collect data for the learner to learn. This parameter \ + will not work on learner. + - storage_loader (:obj:`Optional[StorageLoader]`): Turn data into storage class to reduce \ + the network overhead. 
+ """ + if not task.router.is_active: + raise RuntimeError("ContextHandler should be used in parallel mode!") + self._state = {} + self._local_state = {} # just save local state, not send to remote node + if task.has_role(task.role.COLLECTOR): + self._local_state['env_step'] = 0 + self._local_state['env_episode'] = 0 + self._event_name = "context_exchanger_{role}" + self._skip_n_iter = skip_n_iter + self._storage_loader = storage_loader + for role in task.role: # Only subscribe to other roles + if not task.has_role(role): + task.on(self._event_name.format(role=role), self.put) + if storage_loader: + task.once("finish", lambda _: storage_loader.shutdown()) + + def __new__(cls, *args, **kwargs): + if not task.router.is_active: + return task.void() + + if len(task.roles) == 0: + logging.warning("The task does not have any roles defined, the ContextExchanger will not work.") + return task.void() + + if len(task.roles) > 1: + logging.warning( + "Use multiple roles in one exchanger may lead to unexpected result, please check your code." + ) + + return super(ContextExchanger, cls).__new__(cls) + + def __call__(self, ctx: "Context"): + self.merge(ctx) + yield + payload = self.fetch(ctx) + if payload: + if self._storage_loader and task.has_role(task.role.COLLECTOR): + payload = self._storage_loader.save(payload) + for role in task.roles: + task.emit(self._event_name.format(role=role), payload, only_remote=True) + + def __del__(self): + if self._storage_loader: + self._storage_loader.shutdown() + + def put(self, payload: Union[Dict, Storage]): + """ + Overview: + Get attributes from ctx on the callback of event. + Each attribute should have a standalone put handler, which named `_put_{key}` + """ + + def callback(payload: Dict): + for key, item in payload.items(): + fn_name = "_put_{}".format(key) + if hasattr(self, fn_name): + getattr(self, fn_name)(item) + else: + logging.warning("Receive unexpected key ({}) in context exchanger".format(key)) + + if isinstance(payload, Storage): + assert self._storage_loader is not None, "Storage loader is not defined when data is a storage object." + self._storage_loader.load(payload, callback) + else: + callback(payload) + + def fetch(self, ctx: "Context") -> Dict[str, Any]: + """ + Overview: + Fetch attributes from ctx before emit them to the event bus. + Each attribute should have a standalone fetch handler, which named `_fetch_{key}` + """ + payload = {} + for field in fields(ctx): + key, item = field.name, getattr(ctx, field.name) + fn_name = "_fetch_{}".format(key) + if hasattr(self, fn_name): + value = getattr(self, fn_name)(item) + if value is not None: + payload[key] = value + return payload + + def merge(self, ctx: "Context"): + if task.has_role(task.role.LEARNER): + # Learner should always wait for trajs. + # TODO: Automaticlly wait based on properties, not roles. + while len(self._state) == 0: + sleep(0.01) + elif ctx.total_step >= self._skip_n_iter: + start = time() + while len(self._state) == 0: + if time() - start > 60: + logging.warning("Timeout when waiting for new context! 
Node id: {}".format(task.router.node_id)) + break + sleep(0.01) + + for k, v in self._state.items(): + if not task.has_role(task.role.COLLECTOR) and k.startswith('increment_'): + pure_k = k.split('increment_')[-1] + setattr(ctx, pure_k, getattr(ctx, pure_k) + v) + else: + setattr(ctx, k, v) + self._state = {} + + # Handle each attibute of context + def _put_trajectories(self, traj: List[Any]): + if not task.has_role(task.role.LEARNER): + return + if "trajectories" not in self._state: + self._state["trajectories"] = [] + self._state["trajectories"].extend(traj) + + def _fetch_trajectories(self, traj: List[Any]): + if task.has_role(task.role.COLLECTOR): + return traj + + def _put_episodes(self, episodes: List[Any]): + if not task.has_role(task.role.LEARNER): + return + if "episodes" not in self._state: + self._state["episodes"] = [] + self._state["episodes"].extend(episodes) + + def _fetch_episodes(self, episodes: List[Any]): + if task.has_role(task.role.COLLECTOR): + return episodes + + def _put_trajectory_end_idx(self, trajectory_end_idx: List[str]): + if not task.has_role(task.role.LEARNER): + return + if "trajectory_end_idx" not in self._state: + self._state["trajectory_end_idx"] = [] + self._state["trajectory_end_idx"].extend(trajectory_end_idx) + + def _fetch_trajectory_end_idx(self, trajectory_end_idx: List[str]): + if task.has_role(task.role.COLLECTOR): + return trajectory_end_idx + + def _put_env_step(self, increment_env_step: int): + if not task.has_role(task.role.COLLECTOR): + if 'increment_env_step' not in self._state: + self._state['increment_env_step'] = 0 + self._state["increment_env_step"] += increment_env_step + + def _fetch_env_step(self, env_step: int): + if task.has_role(task.role.COLLECTOR): + increment_env_step = env_step - self._local_state['env_step'] + self._local_state['env_step'] = env_step + return increment_env_step + + def _put_env_episode(self, increment_env_episode: int): + if not task.has_role(task.role.COLLECTOR): + if 'increment_env_episode' not in self._state: + self._state['increment_env_episode'] = 0 + self._state["increment_env_episode"] += increment_env_episode + + def _fetch_env_episode(self, env_episode: int): + if task.has_role(task.role.COLLECTOR): + increment_env_episode = env_episode - self._local_state['env_episode'] + self._local_state['env_episode'] = env_episode + return increment_env_episode + + def _put_train_iter(self, train_iter: int): + if not task.has_role(task.role.LEARNER): + self._state["train_iter"] = train_iter + + def _fetch_train_iter(self, train_iter: int): + if task.has_role(task.role.LEARNER): + return train_iter + + +class ModelExchanger: + + def __init__(self, model: "Module", model_loader: Optional[ModelLoader] = None) -> None: + """ + Overview: + Exchange model between processes, only the learner will send the model, + otherwise the model will only be received. + If you are using a shared model on a single host, there is no need to use this middleware. + Arguments: + - model (:obj:`torch.nn.Module`): Pytorch module. + - model_loader (:obj:`ModelLoader`): Encode model in subprocess. 
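+        Examples (a sketch matching the test below; the dirname is hypothetical):
+            >>> model_loader = FileModelLoader(policy._model, dirname="./tmp_model")
+            >>> task.use(ModelExchanger(policy._model, model_loader=model_loader))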
+ """ + self._model = model + self._model_loader = model_loader + self._event_name = "model_exchanger" + self._state_dict_cache: Optional[Union[object, Storage]] = None + self._is_learner = task.has_role(task.role.LEARNER) + if not self._is_learner: + task.on(self._event_name, self._cache_state_dict) + if model_loader: + task.once("finish", lambda _: model_loader.shutdown()) + + def _cache_state_dict(self, state_dict: Union[object, Storage]): + self._state_dict_cache = state_dict + + def __new__(cls, *args, **kwargs): + if not task.router.is_active: + return task.void() + + if len(task.roles) == 0: + logging.warning("The task does not have any roles defined, the ModelExchanger will not work.") + return task.void() + + if len(task.roles) > 1: + logging.warning( + "Use multiple roles in one exchanger may lead to unexpected result, please check your code." + ) + + return super(ModelExchanger, cls).__new__(cls) + + def __call__(self, ctx: "Context") -> Any: + if self._model_loader: + self._model_loader.start() + + if not self._is_learner: + if ctx.total_step != 0: # Skip first iteration + self._update_model() + else: + yield + self._send_model() + + def _update_model(self): + start = time() + while True: + if task.finish: + return + if time() - start > 60: + logging.warning("Timeout when waiting for new model! Node id: {}".format(task.router.node_id)) + break + if self._state_dict_cache is None: + sleep(0.01) + else: + if isinstance(self._state_dict_cache, Storage) and self._model_loader is not None: + try: + self._model.load_state_dict(self._model_loader.load(self._state_dict_cache)) + self._state_dict_cache = None + break + except FileNotFoundError as e: + logging.warning( + "Model file has been deleted on node {}, maybe you can increase the ttl.".format( + task.router.node_id + ) + ) + self._state_dict_cache = None + continue + else: + self._model.load_state_dict(self._state_dict_cache) + self._state_dict_cache = None + break + + def _send_model(self): + if self._model_loader: + self._model_loader.save(self._send_callback) + else: + task.emit(self._event_name, self._model.state_dict(), only_remote=True) + + def _send_callback(self, storage: Storage): + if task.running: + task.emit(self._event_name, storage, only_remote=True) + + def __del__(self): + if self._model_loader: + self._model_loader.shutdown() diff --git a/ding/framework/middleware/functional/__init__.py b/ding/framework/middleware/functional/__init__.py index 26bebea62f..66bdc6dd0e 100644 --- a/ding/framework/middleware/functional/__init__.py +++ b/ding/framework/middleware/functional/__init__.py @@ -3,7 +3,7 @@ sqil_data_pusher, buffer_saver from .collector import inferencer, rolloutor, TransitionList from .evaluator import interaction_evaluator -from .termination_checker import termination_checker +from .termination_checker import termination_checker, ddp_termination_checker from .logger import online_logger, offline_logger, wandb_online_logger, wandb_offline_logger from .ctx_helper import final_ctx_saver @@ -11,3 +11,5 @@ from .explorer import eps_greedy_handler, eps_greedy_masker from .advantage_estimator import gae_estimator from .enhancer import reward_estimator, her_data_enhancer, nstep_reward_enhancer + +from .timer import epoch_timer diff --git a/ding/framework/middleware/functional/advantage_estimator.py b/ding/framework/middleware/functional/advantage_estimator.py index 56d9a9b01d..fa3e7fad39 100644 --- a/ding/framework/middleware/functional/advantage_estimator.py +++ 
b/ding/framework/middleware/functional/advantage_estimator.py @@ -43,6 +43,8 @@ def _gae(ctx: "OnlineRLContext"): data = ctx.trajectories # list data = ttorch_collate(data) with torch.no_grad(): + if cfg.policy.get("cuda", False): + data = data.cuda() value = model.forward(data.obs, mode='compute_critic')['value'] next_value = model.forward(data.next_obs, mode='compute_critic')['value'] data.value = value @@ -54,6 +56,8 @@ def _gae(ctx: "OnlineRLContext"): # done is bool type when acquired from env.step data_ = gae_data(data.value, next_value, data.reward, data.done.float(), traj_flag.float()) data.adv = gae(data_, cfg.policy.collect.discount_factor, cfg.policy.collect.gae_lambda) + if cfg.policy.get("cuda", False): + data = data.cpu() if buffer_ is None: ctx.train_data = data else: diff --git a/ding/framework/middleware/functional/collector.py b/ding/framework/middleware/functional/collector.py index 227017bc0d..f5b3ff0d89 100644 --- a/ding/framework/middleware/functional/collector.py +++ b/ding/framework/middleware/functional/collector.py @@ -74,8 +74,8 @@ def _inference(ctx: "OnlineRLContext"): obs = ttorch.as_tensor(env.ready_obs).to(dtype=ttorch.float32) ctx.obs = obs # TODO mask necessary rollout - num_envs = get_shape0(obs) - obs = {i: obs[i] for i in range(num_envs)} # TBD + + obs = {i: obs[i] for i in range(get_shape0(obs))} # TBD inference_output = policy.forward(obs, **ctx.collect_kwargs) ctx.action = [to_ndarray(v['action']) for v in inference_output.values()] # TBD ctx.inference_output = inference_output diff --git a/ding/framework/middleware/functional/data_processor.py b/ding/framework/middleware/functional/data_processor.py index d4afdc0292..e2b3aafdc8 100644 --- a/ding/framework/middleware/functional/data_processor.py +++ b/ding/framework/middleware/functional/data_processor.py @@ -5,6 +5,7 @@ import torch from ding.data import Buffer, Dataset, DataLoader, offline_data_save_type from ding.data.buffer.middleware import PriorityExperienceReplay +from ding.framework import task if TYPE_CHECKING: from ding.framework import OnlineRLContext, OfflineRLContext @@ -18,6 +19,8 @@ def data_pusher(cfg: EasyDict, buffer_: Buffer, group_by_env: Optional[bool] = N - cfg (:obj:`EasyDict`): Config. - buffer (:obj:`Buffer`): Buffer to push the data in. """ + if task.router.is_active and not task.has_role(task.role.LEARNER): + return task.void() def _push(ctx: "OnlineRLContext"): """ diff --git a/ding/framework/middleware/functional/enhancer.py b/ding/framework/middleware/functional/enhancer.py index 4bd9ad45e4..b983945791 100644 --- a/ding/framework/middleware/functional/enhancer.py +++ b/ding/framework/middleware/functional/enhancer.py @@ -2,7 +2,7 @@ from easydict import EasyDict from ditk import logging import torch -from ding.policy import Policy +from ding.framework import task if TYPE_CHECKING: from ding.framework import OnlineRLContext from ding.reward_model import BaseRewardModel, HerRewardModel @@ -17,6 +17,8 @@ def reward_estimator(cfg: EasyDict, reward_model: "BaseRewardModel") -> Callable - cfg (:obj:`EasyDict`): Config. - reward_model (:obj:`BaseRewardModel`): Reward model. """ + if task.router.is_active and not task.has_role(task.role.LEARNER): + return task.void() def _enhance(ctx: "OnlineRLContext"): """ @@ -40,6 +42,8 @@ def her_data_enhancer(cfg: EasyDict, buffer_: "Buffer", her_reward_model: "HerRe - her_reward_model (:obj:`HerRewardModel`): Hindsight Experience Replay (HER) model \ which is used to process episodes. 
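+        Examples (a wiring sketch; ``buffer_`` and ``her_reward_model`` come from user code):
+            >>> task.use(her_data_enhancer(cfg, buffer_, her_reward_model))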
""" + if task.router.is_active and not task.has_role(task.role.LEARNER): + return task.void() def _fetch_and_enhance(ctx: "OnlineRLContext"): """ @@ -69,6 +73,9 @@ def _fetch_and_enhance(ctx: "OnlineRLContext"): def nstep_reward_enhancer(cfg: EasyDict) -> Callable: + if task.router.is_active and not task.has_role(task.role.LEARNER): + return task.void() + def _enhance(ctx: "OnlineRLContext"): nstep = cfg.policy.nstep gamma = cfg.policy.discount_factor diff --git a/ding/framework/middleware/functional/evaluator.py b/ding/framework/middleware/functional/evaluator.py index 262363a646..f5553c4679 100644 --- a/ding/framework/middleware/functional/evaluator.py +++ b/ding/framework/middleware/functional/evaluator.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Callable, Any, List, Union +from typing import Callable, Any, List, Union from abc import ABC, abstractmethod from collections import deque from ditk import logging @@ -8,16 +8,13 @@ import treetensor.torch as ttorch from easydict import EasyDict from ding.envs import BaseEnvManager -from ding.framework.context import OfflineRLContext +from ding.framework.context import Context, OfflineRLContext, OnlineRLContext from ding.policy import Policy from ding.data import Dataset, DataLoader from ding.framework import task -from ding.torch_utils import to_list, to_ndarray, get_shape0 +from ding.torch_utils import tensor_to_list, to_list, to_ndarray, get_shape0 from ding.utils import lists_to_dicts -if TYPE_CHECKING: - from ding.framework import Context, OnlineRLContext - class IMetric(ABC): @@ -223,10 +220,12 @@ def interaction_evaluator(cfg: EasyDict, policy: Policy, env: BaseEnvManager, re - env (:obj:`BaseEnvManager`): The env for the evaluation. - render (:obj:`bool`): Whether to render env images and policy logits. """ + if task.router.is_active and not task.has_role(task.role.EVALUATOR): + return task.void() env.seed(cfg.seed, dynamic_seed=False) - def _evaluate(ctx: "OnlineRLContext"): + def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]): """ Overview: - The evaluation will be executed if the task begins and enough train_iter passed \ @@ -238,6 +237,7 @@ def _evaluate(ctx: "OnlineRLContext"): - eval_value (:obj:`float`): The average reward in the current evaluation. 
""" + # evaluation will be executed if the task begins or enough train_iter after last evaluation if ctx.last_eval_iter != -1 and \ (ctx.train_iter - ctx.last_eval_iter < cfg.policy.eval.evaluator.eval_freq): return @@ -251,8 +251,7 @@ def _evaluate(ctx: "OnlineRLContext"): while not eval_monitor.is_finished(): obs = ttorch.as_tensor(env.ready_obs).to(dtype=ttorch.float32) - num_envs = get_shape0(obs) - obs = {i: obs[i] for i in range(num_envs)} # TBD + obs = {i: obs[i] for i in range(get_shape0(obs))} # TBD inference_output = policy.forward(obs) if render: eval_monitor.update_video(env.ready_imgs) @@ -271,14 +270,16 @@ def _evaluate(ctx: "OnlineRLContext"): episode_return = eval_monitor.get_episode_return() episode_return = np.mean(episode_return) stop_flag = episode_return >= cfg.env.stop_value and ctx.train_iter > 0 - if isinstance(ctx, OfflineRLContext): - logging.info('Evaluation: Train Iter({})\tEpisode Return({:.3f})'.format(ctx.train_iter, episode_return)) - else: + if isinstance(ctx, OnlineRLContext): logging.info( 'Evaluation: Train Iter({})\tEnv Step({})\tEpisode Return({:.3f})'.format( ctx.train_iter, ctx.env_step, episode_return ) ) + elif isinstance(ctx, OfflineRLContext): + logging.info('Evaluation: Train Iter({})\tEval Reward({:.3f})'.format(ctx.train_iter, episode_return)) + else: + raise TypeError("not supported ctx type: {}".format(type(ctx))) ctx.last_eval_iter = ctx.train_iter ctx.eval_value = episode_return ctx.eval_output = {'reward': episode_return} diff --git a/ding/framework/middleware/functional/explorer.py b/ding/framework/middleware/functional/explorer.py index 4c7364004d..45aa9bd24a 100644 --- a/ding/framework/middleware/functional/explorer.py +++ b/ding/framework/middleware/functional/explorer.py @@ -1,6 +1,7 @@ -from typing import TYPE_CHECKING, Callable, List +from typing import TYPE_CHECKING, Callable from easydict import EasyDict from ding.rl_utils import get_epsilon_greedy_fn +from ding.framework import task if TYPE_CHECKING: from ding.framework import OnlineRLContext @@ -13,6 +14,8 @@ def eps_greedy_handler(cfg: EasyDict) -> Callable: Arguments: - cfg (:obj:`EasyDict`): Config. 
""" + if task.router.is_active and not task.has_role(task.role.COLLECTOR): + return task.void() eps_cfg = cfg.policy.other.eps handle = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type) diff --git a/ding/framework/middleware/functional/logger.py b/ding/framework/middleware/functional/logger.py index a762ad6e52..8a595ccf26 100644 --- a/ding/framework/middleware/functional/logger.py +++ b/ding/framework/middleware/functional/logger.py @@ -11,6 +11,7 @@ import wandb import h5py import pickle +from ding.framework import task from ding.envs import BaseEnvManagerV2 from ding.utils import DistributedWriter from ding.torch_utils import to_ndarray @@ -43,11 +44,16 @@ def return_distribution(reward): def online_logger(record_train_iter: bool = False, train_show_freq: int = 100) -> Callable: + if task.router.is_active and not task.has_role(task.role.LEARNER): + return task.void() writer = DistributedWriter.get_instance() last_train_show_iter = -1 def _logger(ctx: "OnlineRLContext"): + if task.finish: + writer.close() nonlocal last_train_show_iter + if not np.isinf(ctx.eval_value): if record_train_iter: writer.add_scalar('basic/eval_episode_return_mean-env_step', ctx.eval_value, ctx.env_step) @@ -82,9 +88,13 @@ def _logger(ctx: "OnlineRLContext"): def offline_logger() -> Callable: + if task.router.is_active and not task.has_role(task.role.LEARNER): + return task.void() writer = DistributedWriter.get_instance() def _logger(ctx: "OfflineRLContext"): + if task.finish: + writer.close() if not np.isinf(ctx.eval_value): writer.add_scalar('basic/eval_episode_return_mean-train_iter', ctx.eval_value, ctx.train_iter) if ctx.train_output is not None: @@ -121,7 +131,8 @@ def wandb_online_logger( - anonymous (:obj:`bool`): Open the anonymous mode of wandb or not. The anonymous mode allows visualization of data without wandb count. 
''' - + if task.router.is_active and not task.has_role(task.role.LEARNER): + return task.void() color_list = ["orange", "red", "blue", "purple", "green", "darkcyan"] metric_list = ["q_value", "target q_value", "loss", "lr", "entropy"] # Initialize wandb with default settings diff --git a/ding/framework/middleware/functional/termination_checker.py b/ding/framework/middleware/functional/termination_checker.py index 58c371d57b..3f7cdc0cc4 100644 --- a/ding/framework/middleware/functional/termination_checker.py +++ b/ding/framework/middleware/functional/termination_checker.py @@ -1,5 +1,8 @@ from typing import TYPE_CHECKING, Union, Callable, Optional +from ditk import logging import numpy as np +import torch +from ding.utils import broadcast from ding.framework import task if TYPE_CHECKING: @@ -16,7 +19,35 @@ def _check(ctx: Union["OnlineRLContext", "OfflineRLContext"]): # ">" is better than ">=" when taking logger result into consideration if ctx.env_step > max_env_step: task.finish = True + logging.info('Exceeded maximum number of env_step({}), program is terminated'.format(ctx.env_step)) if ctx.train_iter > max_train_iter: task.finish = True + logging.info('Exceeded maximum number of train_iter({}), program is terminated'.format(ctx.train_iter)) + + return _check + + +def ddp_termination_checker(max_env_step=None, max_train_iter=None, rank=0): + if rank == 0: + if max_env_step is None: + max_env_step = np.inf + if max_train_iter is None: + max_train_iter = np.inf + + def _check(ctx): + if rank == 0: + if ctx.env_step > max_env_step: + finish = torch.ones(1).long().cuda() + logging.info('Exceeded maximum number of env_step({}), program is terminated'.format(ctx.env_step)) + elif ctx.train_iter > max_train_iter: + finish = torch.ones(1).long().cuda() + logging.info('Exceeded maximum number of train_iter({}), program is terminated'.format(ctx.train_iter)) + else: + finish = torch.LongTensor([task.finish]).cuda() + else: + finish = torch.zeros(1).long().cuda() + # broadcast finish result to other DDP workers + broadcast(finish, 0) + task.finish = finish.cpu().bool().item() return _check diff --git a/ding/framework/middleware/functional/timer.py b/ding/framework/middleware/functional/timer.py new file mode 100644 index 0000000000..db8a2c0056 --- /dev/null +++ b/ding/framework/middleware/functional/timer.py @@ -0,0 +1,35 @@ +import numpy as np +from collections import deque +from ditk import logging +from time import time + +from ding.framework import task +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ding.framework.context import Context + + +def epoch_timer(print_per: int = 1, smooth_window: int = 10): + """ + Overview: + Print time cost of each epoch. + Arguments: + - print_per (:obj:`int`): Print each N epoch. + - smooth_window (:obj:`int`): The window size to smooth the mean. 
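+    Examples (a minimal usage sketch):
+        >>> task.use(epoch_timer(print_per=10))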
+ """ + records = deque(maxlen=print_per * smooth_window) + + def _epoch_timer(ctx: "Context"): + start = time() + yield + time_cost = time() - start + records.append(time_cost) + if ctx.total_step % print_per == 0: + logging.info( + "[Epoch Timer][Node:{:>2}]: Cost: {:.2f}ms, Mean: {:.2f}ms".format( + task.router.node_id or 0, time_cost * 1000, + np.mean(records) * 1000 + ) + ) + + return _epoch_timer diff --git a/ding/framework/middleware/functional/trainer.py b/ding/framework/middleware/functional/trainer.py index 29d58d664a..7cd1855553 100644 --- a/ding/framework/middleware/functional/trainer.py +++ b/ding/framework/middleware/functional/trainer.py @@ -3,10 +3,7 @@ from ditk import logging import numpy as np from ding.policy import Policy -from ding.framework import task, OfflineRLContext - -if TYPE_CHECKING: - from ding.framework import OnlineRLContext +from ding.framework import task, OfflineRLContext, OnlineRLContext def trainer(cfg: EasyDict, policy: Policy) -> Callable: @@ -33,17 +30,18 @@ def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]): return train_output = policy.forward(ctx.train_data) if ctx.train_iter % cfg.policy.learn.learner.hook.log_show_after_iter == 0: - if isinstance(ctx, OfflineRLContext): - logging.info( - 'Training: Train Iter({})\tLoss({:.3f})'.format(ctx.train_iter, train_output['total_loss']) - ) - else: + if isinstance(ctx, OnlineRLContext): logging.info( 'Training: Train Iter({})\tEnv Step({})\tLoss({:.3f})'.format( ctx.train_iter, ctx.env_step, train_output['total_loss'] ) ) - + elif isinstance(ctx, OfflineRLContext): + logging.info( + 'Training: Train Iter({})\tLoss({:.3f})'.format(ctx.train_iter, train_output['total_loss']) + ) + else: + raise TypeError("not supported ctx type: {}".format(type(ctx))) ctx.train_iter += 1 ctx.train_output = train_output diff --git a/ding/framework/middleware/learner.py b/ding/framework/middleware/learner.py index 4d60f117bd..91184a7b9b 100644 --- a/ding/framework/middleware/learner.py +++ b/ding/framework/middleware/learner.py @@ -17,6 +17,11 @@ class OffPolicyLearner: the `__call__` method to execute the whole learning process. 
""" + def __new__(cls, *args, **kwargs): + if task.router.is_active and not task.has_role(task.role.LEARNER): + return task.void() + return super(OffPolicyLearner, cls).__new__(cls) + def __init__( self, cfg: EasyDict, diff --git a/ding/framework/middleware/tests/test_ckpt_handler.py b/ding/framework/middleware/tests/test_ckpt_handler.py index 9a22ffece0..56a3dbf0d4 100644 --- a/ding/framework/middleware/tests/test_ckpt_handler.py +++ b/ding/framework/middleware/tests/test_ckpt_handler.py @@ -11,7 +11,7 @@ from unittest.mock import Mock, patch from ding.framework import task -from ding.utils import save_file +from ding.policy.base_policy import Policy class TheModelClass(nn.Module): @@ -22,10 +22,14 @@ def state_dict(self): class MockPolicy(Mock): - def __init__(self, model) -> None: - super(MockPolicy, self).__init__() + def __init__(self, model, **kwargs) -> None: + super(MockPolicy, self).__init__(model) self.learn_mode = model + @property + def eval_mode(self): + return EasyDict({"state_dict": lambda: {}}) + @pytest.mark.unittest def test_ckpt_saver(): diff --git a/ding/framework/middleware/tests/test_distributer.py b/ding/framework/middleware/tests/test_distributer.py new file mode 100644 index 0000000000..c7c323bac9 --- /dev/null +++ b/ding/framework/middleware/tests/test_distributer.py @@ -0,0 +1,223 @@ +import shutil +from time import sleep +import pytest +import numpy as np +import tempfile + +import torch +from ding.data.model_loader import FileModelLoader +from ding.data.storage_loader import FileStorageLoader +from ding.framework import task +from ding.framework.context import OnlineRLContext +from ding.framework.middleware.distributer import ContextExchanger, ModelExchanger +from ding.framework.parallel import Parallel +from ding.utils.default_helper import set_pkg_seed +from os import path + + +def context_exchanger_main(): + with task.start(ctx=OnlineRLContext()): + if task.router.node_id == 0: + task.add_role(task.role.LEARNER) + elif task.router.node_id == 1: + task.add_role(task.role.COLLECTOR) + + task.use(ContextExchanger(skip_n_iter=1)) + + if task.has_role(task.role.LEARNER): + + def learner_context(ctx: OnlineRLContext): + assert len(ctx.trajectories) == 2 + assert len(ctx.trajectory_end_idx) == 4 + assert len(ctx.episodes) == 8 + assert ctx.env_step > 0 + assert ctx.env_episode > 0 + yield + ctx.train_iter += 1 + + task.use(learner_context) + elif task.has_role(task.role.COLLECTOR): + + def collector_context(ctx: OnlineRLContext): + if ctx.total_step > 0: + assert ctx.train_iter > 0 + yield + ctx.trajectories = [np.random.rand(10, 10) for _ in range(2)] + ctx.trajectory_end_idx = [1 for _ in range(4)] + ctx.episodes = [np.random.rand(10, 10) for _ in range(8)] + ctx.env_step += 1 + ctx.env_episode += 1 + + task.use(collector_context) + + task.run(max_step=3) + + +@pytest.mark.unittest +def test_context_exchanger(): + Parallel.runner(n_parallel_workers=2)(context_exchanger_main) + + +def context_exchanger_with_storage_loader_main(): + with task.start(ctx=OnlineRLContext()): + if task.router.node_id == 0: + task.add_role(task.role.LEARNER) + elif task.router.node_id == 1: + task.add_role(task.role.COLLECTOR) + + tempdir = path.join(tempfile.gettempdir(), "test_storage_loader") + storage_loader = FileStorageLoader(dirname=tempdir) + try: + task.use(ContextExchanger(skip_n_iter=1, storage_loader=storage_loader)) + + if task.has_role(task.role.LEARNER): + + def learner_context(ctx: OnlineRLContext): + assert len(ctx.trajectories) == 2 + assert len(ctx.trajectory_end_idx) 
== 4 + assert len(ctx.episodes) == 8 + assert ctx.env_step > 0 + assert ctx.env_episode > 0 + yield + ctx.train_iter += 1 + + task.use(learner_context) + elif task.has_role(task.role.COLLECTOR): + + def collector_context(ctx: OnlineRLContext): + if ctx.total_step > 0: + assert ctx.train_iter > 0 + yield + ctx.trajectories = [np.random.rand(10, 10) for _ in range(2)] + ctx.trajectory_end_idx = [1 for _ in range(4)] + ctx.episodes = [np.random.rand(10, 10) for _ in range(8)] + ctx.env_step += 1 + ctx.env_episode += 1 + + task.use(collector_context) + + task.run(max_step=3) + finally: + storage_loader.shutdown() + sleep(1) + if path.exists(tempdir): + shutil.rmtree(tempdir) + + +@pytest.mark.unittest +def test_context_exchanger_with_storage_loader(): + Parallel.runner(n_parallel_workers=2)(context_exchanger_with_storage_loader_main) + + +class MockPolicy: + + def __init__(self) -> None: + self._model = self._get_model(10, 10) + + def _get_model(self, X_shape, y_shape) -> torch.nn.Module: + return torch.nn.Sequential( + torch.nn.Linear(X_shape, 24), torch.nn.ReLU(), torch.nn.Linear(24, 24), torch.nn.ReLU(), + torch.nn.Linear(24, y_shape) + ) + + def train(self, X, y): + loss_fn = torch.nn.MSELoss(reduction="mean") + optimizer = torch.optim.Adam(self._model.parameters(), lr=0.01) + y_pred = self._model(X) + loss = loss_fn(y_pred, y) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + def predict(self, X): + with torch.no_grad(): + return self._model(X) + + +def model_exchanger_main(): + with task.start(ctx=OnlineRLContext()): + set_pkg_seed(0, use_cuda=False) + policy = MockPolicy() + X = torch.rand(10) + y = torch.rand(10) + + if task.router.node_id == 0: + task.add_role(task.role.LEARNER) + else: + task.add_role(task.role.COLLECTOR) + + task.use(ModelExchanger(policy._model)) + + if task.has_role(task.role.LEARNER): + + def train(ctx): + policy.train(X, y) + sleep(0.3) + + task.use(train) + else: + y_pred1 = policy.predict(X) + + def pred(ctx): + if ctx.total_step > 0: + y_pred2 = policy.predict(X) + # Ensure model is upgraded + assert any(y_pred1 != y_pred2) + sleep(0.3) + + task.use(pred) + + task.run(2) + + +@pytest.mark.unittest +def test_model_exchanger(): + Parallel.runner(n_parallel_workers=2, startup_interval=0)(model_exchanger_main) + + +def model_exchanger_main_with_model_loader(): + with task.start(ctx=OnlineRLContext()): + set_pkg_seed(0, use_cuda=False) + policy = MockPolicy() + X = torch.rand(10) + y = torch.rand(10) + + if task.router.node_id == 0: + task.add_role(task.role.LEARNER) + else: + task.add_role(task.role.COLLECTOR) + + tempdir = path.join(tempfile.gettempdir(), "test_model_loader") + model_loader = FileModelLoader(policy._model, dirname=tempdir) + task.use(ModelExchanger(policy._model, model_loader=model_loader)) + + try: + if task.has_role(task.role.LEARNER): + + def train(ctx): + policy.train(X, y) + sleep(0.3) + + task.use(train) + else: + y_pred1 = policy.predict(X) + + def pred(ctx): + if ctx.total_step > 0: + y_pred2 = policy.predict(X) + # Ensure model is upgraded + assert any(y_pred1 != y_pred2) + sleep(0.3) + + task.use(pred) + task.run(2) + finally: + model_loader.shutdown() + sleep(0.3) + if path.exists(tempdir): + shutil.rmtree(tempdir) + + +@pytest.mark.unittest +def test_model_exchanger_with_model_loader(): + Parallel.runner(n_parallel_workers=2, startup_interval=0)(model_exchanger_main_with_model_loader) diff --git a/ding/framework/middleware/tests/test_enhancer.py b/ding/framework/middleware/tests/test_enhancer.py index 
a1765c031d..10d34b264f 100644 --- a/ding/framework/middleware/tests/test_enhancer.py +++ b/ding/framework/middleware/tests/test_enhancer.py @@ -2,8 +2,7 @@ import torch from ding.framework import OnlineRLContext from ding.data.buffer import DequeBuffer -from easydict import EasyDict -from typing import Any, List, Dict, Optional +from typing import Any import numpy as np import copy from ding.framework.middleware.functional.enhancer import reward_estimator, her_data_enhancer diff --git a/ding/framework/parallel.py b/ding/framework/parallel.py index 70134f6584..38e343e495 100644 --- a/ding/framework/parallel.py +++ b/ding/framework/parallel.py @@ -3,8 +3,8 @@ import random import time import traceback -from mpire.pool import WorkerPool import pickle +from mpire.pool import WorkerPool from ditk import logging import tempfile import socket @@ -27,6 +27,7 @@ def __init__(self) -> None: self._listener = None self.is_active = False self.node_id = None + self.local_id = None self.labels = set() self._event_loop = EventLoop("parallel_{}".format(id(self))) self._retries = 0 # Retries in auto recovery @@ -34,17 +35,24 @@ def __init__(self) -> None: def _run( self, node_id: int, + local_id: int, + n_parallel_workers: int, labels: Optional[Set[str]] = None, auto_recover: bool = False, max_retries: int = float("inf"), mq_type: str = "nng", + startup_interval: int = 1, **kwargs ) -> None: self.node_id = node_id + self.local_id = local_id + self.startup_interval = startup_interval + self.n_parallel_workers = n_parallel_workers self.labels = labels or set() self.auto_recover = auto_recover self.max_retries = max_retries self._mq = MQ_REGISTRY.get(mq_type)(**kwargs) + time.sleep(self.local_id * self.startup_interval) self._listener = Thread(target=self.listen, name="mq_listener", daemon=True) self._listener.start() @@ -63,7 +71,8 @@ def runner( auto_recover: bool = False, max_retries: int = float("inf"), redis_host: Optional[str] = None, - redis_port: Optional[int] = None + redis_port: Optional[int] = None, + startup_interval: int = 1 ) -> Callable: """ Overview: @@ -85,6 +94,7 @@ def runner( - max_retries (:obj:`int`): Max retries for auto recover. - redis_host (:obj:`str`): Redis server host. - redis_port (:obj:`int`): Redis server port. + - startup_interval (:obj:`int`): Start up interval between each task. Returns: - _runner (:obj:`Callable`): The wrapper function for main. """ @@ -102,7 +112,10 @@ def _runner(main_process: Callable, *args, **kwargs) -> None: - main_process (:obj:`Callable`): The main function, your program start from here. 
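+            Examples (as exercised by the tests in this PR; ``main`` is your entry function):
+                >>> Parallel.runner(n_parallel_workers=2, startup_interval=0.1)(main)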
""" runner_params = args_parsers[mq_type](**all_args) - params_group = [[runner_kwargs, (main_process, args, kwargs)] for runner_kwargs in runner_params] + params_group = [] + for i, runner_kwargs in enumerate(runner_params): + runner_kwargs["local_id"] = i + params_group.append([runner_kwargs, (main_process, args, kwargs)]) if n_parallel_workers == 1: cls._subprocess_runner(*params_group[0]) @@ -128,7 +141,6 @@ def _nng_args_parser( ) -> Dict[str, dict]: attach_to = attach_to or [] nodes = cls.get_node_addrs(n_parallel_workers, protocol=protocol, address=address, ports=ports) - logging.info("Bind subprocesses on these addresses: {}".format(nodes)) def cleanup_nodes(): for node in nodes: @@ -156,6 +168,7 @@ def topology_network(i: int) -> List[str]: "node_id": candidate_node_ids[i], "listen_to": nodes[i], "attach_to": topology_network(i), + "n_parallel_workers": n_parallel_workers, } runner_params.append(runner_kwargs) @@ -166,7 +179,7 @@ def _redis_args_parser(cls, n_parallel_workers: int, node_ids: Optional[Union[Li runner_params = [] candidate_node_ids = cls.padding_param(node_ids, n_parallel_workers, 0) for i in range(n_parallel_workers): - runner_kwargs = {**kwargs, "node_id": candidate_node_ids[i]} + runner_kwargs = {**kwargs, "n_parallel_workers": n_parallel_workers, "node_id": candidate_node_ids[i]} runner_params.append(runner_kwargs) return runner_params @@ -179,6 +192,7 @@ def _subprocess_runner(cls, runner_kwargs: dict, main_params: Tuple[Union[List, - runner_params (:obj:`Tuple[Union[List, Dict]]`): Args and kwargs for runner. - main_params (:obj:`Tuple[Union[List, Dict]]`): Args and kwargs for main function. """ + logging.getLogger().setLevel(logging.INFO) main_process, args, kwargs = main_params with Parallel() as router: @@ -320,7 +334,7 @@ def emit(self, event: str, *args, **kwargs) -> None: if self.is_active: payload = {"a": args, "k": kwargs} try: - data = pickle.dumps(payload, protocol=-1) + data = pickle.dumps(payload, protocol=pickle.HIGHEST_PROTOCOL) except AttributeError as e: logging.error("Arguments are not pickable! Event: {}, Args: {}".format(event, args)) raise e @@ -351,12 +365,12 @@ def get_ip(cls): try: # doesn't even have to be reachable s.connect(('10.255.255.255', 1)) - IP = s.getsockname()[0] + ip = s.getsockname()[0] except Exception: - IP = '127.0.0.1' + ip = '127.0.0.1' finally: s.close() - return IP + return ip def __enter__(self) -> "Parallel": return self diff --git a/ding/framework/supervisor.py b/ding/framework/supervisor.py index 22f67177f0..7d385c12c6 100644 --- a/ding/framework/supervisor.py +++ b/ding/framework/supervisor.py @@ -1,5 +1,7 @@ from abc import ABC, abstractmethod -import multiprocessing as mp +import functools +import torch.multiprocessing as mp +from multiprocessing.context import BaseContext import threading import queue import platform @@ -12,6 +14,13 @@ from enum import Enum +@functools.lru_cache(maxsize=1) +def get_mp_ctx() -> BaseContext: + context = 'spawn' if platform.system().lower() == 'windows' else 'fork' + mp_ctx = mp.get_context(context) + return mp_ctx + + @dataclass class SendPayload: proc_id: int @@ -29,6 +38,7 @@ class RecvPayload: method: str = None data: Any = None err: Exception = None + extra: Any = None class ReserveMethod(Enum): @@ -41,27 +51,16 @@ class ChildType(Enum): THREAD = "thread" -@dataclass -class SharedObject: - buf: Any - callback: Callable - - class Child(ABC): """ Abstract class of child process/thread. 
""" - def __init__( - self, proc_id: int, init: Callable, *args, shared_object: Optional[SharedObject] = None, **kwargs - ) -> None: + def __init__(self, proc_id: int, init: Union[Callable, object], **kwargs) -> None: self._proc_id = proc_id self._init = init - self._args = args - self._kwargs = kwargs self._recv_queue = None self._send_queue = None - self._shared_object = shared_object @abstractmethod def start(self, recv_queue: Union[mp.Queue, queue.Queue]): @@ -82,15 +81,17 @@ def send(self, payload: SendPayload): def _target( self, proc_id: int, - init: Callable, - args: List, - kwargs: Dict[str, Any], + init: Union[Callable, object], send_queue: Union[mp.Queue, queue.Queue], recv_queue: Union[mp.Queue, queue.Queue], - shared_object: Optional[SharedObject] = None + shm_buffer: Optional[Any] = None, + shm_callback: Optional[Callable] = None ): send_payload = SendPayload(proc_id=proc_id) - child_ins = init(*args, **kwargs) + if isinstance(init, Callable): + child_ins = init() + else: + child_ins = init while True: try: send_payload: SendPayload = send_queue.get() @@ -103,8 +104,8 @@ def _target( recv_payload = RecvPayload( proc_id=proc_id, req_id=send_payload.req_id, method=send_payload.method, data=data ) - if shared_object: - shared_object.callback(recv_payload, shared_object.buf) + if shm_callback is not None and shm_buffer is not None: + shm_callback(recv_payload, shm_buffer) recv_queue.put(recv_payload) except Exception as e: logging.warning(traceback.format_exc()) @@ -121,27 +122,35 @@ def __del__(self): class ChildProcess(Child): def __init__( - self, proc_id: int, init: Callable, *args, shared_object: Optional[SharedObject] = None, **kwargs + self, + proc_id: int, + init: Union[Callable, object], + shm_buffer: Optional[Any] = None, + shm_callback: Optional[Callable] = None, + mp_ctx: Optional[BaseContext] = None, + **kwargs ) -> None: - super().__init__(proc_id, init, *args, shared_object=shared_object, **kwargs) + super().__init__(proc_id, init, **kwargs) self._proc = None + self._mp_ctx = mp_ctx + self._shm_buffer = shm_buffer + self._shm_callback = shm_callback def start(self, recv_queue: mp.Queue): - self._recv_queue = recv_queue - context = 'spawn' if platform.system().lower() == 'windows' else 'fork' - ctx = mp.get_context(context) - self._send_queue = ctx.Queue() - proc = ctx.Process( - target=self._target, - args=( - self._proc_id, self._init, self._args, self._kwargs, self._send_queue, self._recv_queue, - self._shared_object - ), - name="supervisor_child_{}_{}".format(self._proc_id, time.time()), - daemon=True - ) - proc.start() - self._proc = proc + if self._proc is None: + self._recv_queue = recv_queue + ctx = self._mp_ctx or get_mp_ctx() + self._send_queue = ctx.Queue() + proc = ctx.Process( + target=self._target, + args=( + self._proc_id, self._init, self._send_queue, self._recv_queue, self._shm_buffer, self._shm_callback + ), + name="supervisor_child_{}_{}".format(self._proc_id, time.time()), + daemon=True + ) + proc.start() + self._proc = proc def shutdown(self, timeout: Optional[float] = None): if self._proc: @@ -156,28 +165,30 @@ def shutdown(self, timeout: Optional[float] = None): self._send_queue = None def send(self, payload: SendPayload): + if self._send_queue is None: + logging.warning("Child worker has been terminated or not started.") + return self._send_queue.put(payload) class ChildThread(Child): - def __init__( - self, proc_id: int, init: Callable, *args, shared_object: Optional[SharedObject] = None, **kwargs - ) -> None: - super().__init__(proc_id, 
init, *args, shared_object=shared_object, **kwargs) + def __init__(self, proc_id: int, init: Union[Callable, object], *args, **kwargs) -> None: + super().__init__(proc_id, init, *args, **kwargs) self._thread = None def start(self, recv_queue: queue.Queue): - self._recv_queue = recv_queue - self._send_queue = queue.Queue() - thread = threading.Thread( - target=self._target, - args=(self._proc_id, self._init, self._args, self._kwargs, self._send_queue, self._recv_queue), - name="supervisor_child_{}_{}".format(self._proc_id, time.time()), - daemon=True - ) - thread.start() - self._thread = thread + if self._thread is None: + self._recv_queue = recv_queue + self._send_queue = queue.Queue() + thread = threading.Thread( + target=self._target, + args=(self._proc_id, self._init, self._send_queue, self._recv_queue), + name="supervisor_child_{}_{}".format(self._proc_id, time.time()), + daemon=True + ) + thread.start() + self._thread = thread def shutdown(self, timeout: Optional[float] = None): if self._thread: @@ -187,6 +198,9 @@ def shutdown(self, timeout: Optional[float] = None): self._send_queue = None def send(self, payload: SendPayload): + if self._send_queue is None: + logging.warning("Child worker has been terminated or not started.") + return self._send_queue.put(payload) @@ -194,26 +208,32 @@ class Supervisor: TYPE_MAPPING = {ChildType.PROCESS: ChildProcess, ChildType.THREAD: ChildThread} - QUEUE_MAPPING = { - ChildType.PROCESS: mp.get_context('spawn' if platform.system().lower() == 'windows' else 'fork').Queue, - ChildType.THREAD: queue.Queue - } - - def __init__(self, type_: ChildType) -> None: + def __init__(self, type_: ChildType, mp_ctx: Optional[BaseContext] = None) -> None: self._children: List[Child] = [] self._type = type_ self._child_class = self.TYPE_MAPPING[self._type] self._running = False self.__queue = None + self._mp_ctx = mp_ctx or get_mp_ctx() - def register(self, init: Callable, *args, shared_object: Optional[SharedObject] = None, **kwargs) -> None: + def register( + self, + init: Union[Callable, object], + shm_buffer: Optional[Any] = None, + shm_callback: Optional[Callable] = None + ) -> None: proc_id = len(self._children) - self._children.append(self._child_class(proc_id, init, *args, shared_object=shared_object, **kwargs)) + self._children.append( + self._child_class(proc_id, init, shm_buffer=shm_buffer, shm_callback=shm_callback, mp_ctx=self._mp_ctx) + ) @property def _recv_queue(self) -> Union[queue.Queue, mp.Queue]: if not self.__queue: - self.__queue = self.QUEUE_MAPPING[self._type]() + if self._type is ChildType.PROCESS: + self.__queue = self._mp_ctx.Queue() + elif self._type is ChildType.THREAD: + self.__queue = queue.Queue() return self.__queue @_recv_queue.setter @@ -233,6 +253,9 @@ def send(self, payload: SendPayload) -> None: Arguments: - payload (:obj:`SendPayload`): Send payload. 
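+            Examples (a sketch based on the supervisor tests below):
+                >>> sv.send(SendPayload(proc_id=0, method="step", args=["any action"]))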
""" + if not self._running: + logging.warning("Please call start_link before sending any payload to child process.") + return self._children[payload.proc_id].send(payload) def recv(self, ignore_err: bool = False, timeout: float = None) -> RecvPayload: diff --git a/ding/framework/task.py b/ding/framework/task.py index 53e95716b0..ed3e14eb93 100644 --- a/ding/framework/task.py +++ b/ding/framework/task.py @@ -7,8 +7,11 @@ import concurrent.futures import fnmatch import math +import enum from types import GeneratorType from typing import Any, Awaitable, Callable, Dict, Generator, Iterable, List, Optional, Set, Union +import inspect + from ding.framework.context import Context from ding.framework.parallel import Parallel from ding.framework.event_loop import EventLoop @@ -50,11 +53,27 @@ def runtime_handler(task: "Task", *args, async_mode: Optional[bool] = None, **kw return runtime_handler +class Role(str, enum.Enum): + LEARNER = "learner" + COLLECTOR = "collector" + EVALUATOR = "evaluator" + + +class VoidMiddleware: + + def __call__(self, _): + return + + class Task: """ Tash will manage the execution order of the entire pipeline, register new middleware, and generate new context objects. """ + role = Role + + def __init__(self) -> None: + self.router = Parallel() def start( self, @@ -71,6 +90,7 @@ def start( self._wrappers = [] self.ctx = ctx or Context() self._backward_stack = OrderedDict() + self._roles = set() # Bind event loop functions self._event_loop = EventLoop("task_{}".format(id(self))) @@ -85,7 +105,6 @@ def start( self.labels = labels or set() # Parallel segment - self.router = Parallel() if async_mode or self.router.is_active: self._activate_async() @@ -99,6 +118,21 @@ def sync_finish(value): self.init_labels() return self + def add_role(self, role: Role): + self._roles.add(role) + + def has_role(self, role: Role) -> bool: + if len(self._roles) == 0: + return True + return role in self._roles + + @property + def roles(self) -> Set[Role]: + return self._roles + + def void(self): + return VoidMiddleware() + def init_labels(self): if self.async_mode: self.labels.add("async") @@ -120,6 +154,9 @@ def use(self, fn: Callable, lock: Union[bool, Lock] = False) -> 'Task': Returns: - task (:obj:`Task`): The task. """ + assert isinstance(fn, Callable), "Middleware function should be a callable object, current fn {}".format(fn) + if isinstance(fn, VoidMiddleware): # Skip void function + return self for wrapper in self._wrappers: fn = wrapper(fn) self._middleware.append(self.wrap(fn, lock=lock)) @@ -192,7 +229,6 @@ def wrap(self, fn: Callable, lock: Union[bool, Lock] = False) -> Callable: if lock is True: lock = self._thread_lock - @wraps(fn) def forward(ctx: Context): if lock: with lock: @@ -212,6 +248,11 @@ def backward(): return backward + if hasattr(fn, "__name__"): + forward = wraps(fn)(forward) + else: + forward = wraps(fn.__class__)(forward) + return forward @enable_async @@ -258,6 +299,10 @@ def backward(self, backward_stack: Optional[Dict[str, Generator]] = None) -> Non except StopIteration: continue + @property + def running(self): + return self._running + def serial(self, *fns: List[Callable]) -> Callable: """ Overview: @@ -330,6 +375,8 @@ def stop(self) -> None: Overview: Stop and cleanup every thing in the runtime of task. 
""" + if self.router.is_active: + self.emit("finish", True) if self._thread_pool: self._thread_pool.shutdown() self._event_loop.stop() @@ -472,8 +519,6 @@ def finish(self): @finish.setter def finish(self, value: bool): self._finish = value - if self.router.is_active and value is True: - self.emit("finish", value) def _wrap_event_name(self, event: str) -> str: """ diff --git a/ding/framework/tests/test_event_loop.py b/ding/framework/tests/test_event_loop.py index 1dddae6164..2f3545f3f5 100644 --- a/ding/framework/tests/test_event_loop.py +++ b/ding/framework/tests/test_event_loop.py @@ -31,6 +31,8 @@ def callback(n, lock): assert counter == 10 # Test once + counter = 0 + loop.once("count", callback) loop.once("count", callback) loop.emit("count", 10, lock) sleep(0.1) diff --git a/ding/framework/tests/test_parallel.py b/ding/framework/tests/test_parallel.py index b042cb3a57..429072a3fc 100644 --- a/ding/framework/tests/test_parallel.py +++ b/ding/framework/tests/test_parallel.py @@ -1,9 +1,7 @@ from collections import defaultdict import pytest import time -import os from ding.framework import Parallel -from ding.utils.design_helper import SingletonMetaclass def parallel_main(): @@ -28,8 +26,8 @@ def test_callback(key): @pytest.mark.unittest def test_parallel_run(): - Parallel.runner(n_parallel_workers=2)(parallel_main) - Parallel.runner(n_parallel_workers=2, protocol="tcp")(parallel_main) + Parallel.runner(n_parallel_workers=2, startup_interval=0.1)(parallel_main) + Parallel.runner(n_parallel_workers=2, protocol="tcp", startup_interval=0.1)(parallel_main) def uncaught_exception_main(): @@ -45,7 +43,7 @@ def uncaught_exception_main(): def test_uncaught_exception(): # Make one process crash, then the parent process will also crash and output the stack of the wrong process. with pytest.raises(Exception) as exc_info: - Parallel.runner(n_parallel_workers=2, topology="mesh")(uncaught_exception_main) + Parallel.runner(n_parallel_workers=2, topology="mesh", startup_interval=0.1)(uncaught_exception_main) e = exc_info._excinfo[1] assert "uncaught exception" in str(e) @@ -54,6 +52,7 @@ def disconnected_main(): router = Parallel() if router.node_id == 0: + time.sleep(0.1) # Receive two messages then exit greets = [] router.on("greeting", lambda: greets.append(".")) @@ -75,7 +74,7 @@ def disconnected_main(): def test_disconnected(): # Make one process exit normally and the rest will still run, even if the network request # is not received by other processes. 
- Parallel.runner(n_parallel_workers=2, topology="mesh")(disconnected_main) + Parallel.runner(n_parallel_workers=2, topology="mesh", startup_interval=0.1)(disconnected_main) class AutoRecover: @@ -145,9 +144,13 @@ def main(cls): @pytest.mark.unittest def test_auto_recover(): # With max_retries=1 - Parallel.runner(n_parallel_workers=3, topology="mesh", auto_recover=True, max_retries=1)(AutoRecover.main) + Parallel.runner( + n_parallel_workers=3, topology="mesh", auto_recover=True, max_retries=1, startup_interval=0.1 + )(AutoRecover.main) # With max_retries=0 with pytest.raises(Exception) as exc_info: - Parallel.runner(n_parallel_workers=3, topology="mesh", auto_recover=True, max_retries=0)(AutoRecover.main) + Parallel.runner( + n_parallel_workers=3, topology="mesh", auto_recover=True, max_retries=0, startup_interval=0.1 + )(AutoRecover.main) e = exc_info._excinfo[1] assert "P1 Error" in str(e) diff --git a/ding/framework/tests/test_supervisor.py b/ding/framework/tests/test_supervisor.py index 57a0f0d49f..b4fdb95dc0 100644 --- a/ding/framework/tests/test_supervisor.py +++ b/ding/framework/tests/test_supervisor.py @@ -1,9 +1,9 @@ import multiprocessing as mp import ctypes -from time import sleep +from time import sleep, time from typing import Any, Dict, List import pytest -from ding.framework.supervisor import RecvPayload, SendPayload, Supervisor, ChildType, SharedObject +from ding.framework.supervisor import RecvPayload, SendPayload, Supervisor, ChildType class MockEnv(): @@ -25,13 +25,16 @@ def block(self): def block_reset(self): sleep(10) + def sleep1(self): + sleep(1) + @pytest.mark.unittest @pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD]) def test_supervisor(type_): sv = Supervisor(type_=type_) for _ in range(3): - sv.register(MockEnv, "AnyArgs") + sv.register(lambda: MockEnv("AnyArgs")) sv.start_link() for env_id in range(len(sv._children)): @@ -71,6 +74,25 @@ def test_supervisor(type_): sv.shutdown() +@pytest.mark.unittest +def test_supervisor_spawn(): + sv = Supervisor(type_=ChildType.PROCESS, mp_ctx=mp.get_context("spawn")) + for _ in range(3): + sv.register(MockEnv("AnyArgs")) + sv.start_link() + + for env_id in range(len(sv._children)): + sv.send(SendPayload(proc_id=env_id, method="step", args=["any action"])) + + recv_states: List[RecvPayload] = [] + for _ in range(3): + recv_states.append(sv.recv()) + + assert sum([payload.proc_id for payload in recv_states]) == 3 + assert all([payload.data == 1 for payload in recv_states]) + sv.shutdown() + + class MockCrashEnv(MockEnv): def step(self, _): @@ -86,8 +108,8 @@ def step(self, _): def test_crash_supervisor(type_): sv = Supervisor(type_=type_) for _ in range(2): - sv.register(MockEnv, "AnyArgs") - sv.register(MockCrashEnv, "AnyArgs") + sv.register(lambda: MockEnv("AnyArgs")) + sv.register(lambda: MockCrashEnv("AnyArgs")) sv.start_link() # Send 6 messages, will cause the third subprocess crash @@ -126,7 +148,7 @@ def test_crash_supervisor(type_): def test_recv_all(type_): sv = Supervisor(type_=type_) for _ in range(3): - sv.register(MockEnv, "AnyArgs") + sv.register(lambda: MockEnv("AnyArgs")) sv.start_link() # Test recv_all @@ -162,7 +184,7 @@ def recv_callback(recv_payload: RecvPayload, remain_payloads: Dict[str, SendPayl def test_timeout(type_): sv = Supervisor(type_=type_) for _ in range(3): - sv.register(MockEnv, "AnyArgs") + sv.register(lambda: MockEnv("AnyArgs")) sv.start_link() send_payloads = [] @@ -202,7 +224,7 @@ def test_timeout(type_): def test_timeout_with_callback(type_): sv = 
Supervisor(type_=type_) for _ in range(3): - sv.register(MockEnv, "AnyArgs") + sv.register(lambda: MockEnv("AnyArgs")) sv.start_link() send_payloads = [] @@ -239,25 +261,50 @@ def recv_callback(recv_payload: RecvPayload, remain_payloads: Dict[str, SendPayl sv.shutdown(timeout=1) -@pytest.mark.unittest +@pytest.mark.tmp # passes on GitLab CI and locally, but always fails on GitHub CI def test_shared_memory(): sv = Supervisor(type_=ChildType.PROCESS) def shm_callback(payload: RecvPayload, shm: Any): - shm[payload.proc_id] = payload.data + shm[payload.proc_id] = payload.req_id payload.data = 0 shm = mp.Array(ctypes.c_uint8, 3) for i in range(3): - sv.register(MockEnv, "AnyArgs", shared_object=SharedObject(buf=shm, callback=shm_callback)) + sv.register(lambda: MockEnv("AnyArgs"), shm_buffer=shm, shm_callback=shm_callback) sv.start_link() + # Send initial requests for env_id in range(len(sv._children)): - sv.send(SendPayload(proc_id=env_id, method="step", args=["any action"])) + sv.send(SendPayload(proc_id=env_id, req_id=env_id, method="sleep1", args=[])) - for i in range(3): + start = time() + for i in range(6): payload = sv.recv() assert payload.data == 0 + assert shm[payload.proc_id] == payload.req_id + sv.send(SendPayload(proc_id=payload.proc_id, req_id=i, method="sleep1", args=[])) + + # Non-blocking + assert time() - start < 3 sv.shutdown() + + +@pytest.mark.benchmark +@pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD]) +def test_supervisor_benchmark(type_): + sv = Supervisor(type_=type_) + for _ in range(3): + sv.register(lambda: MockEnv("AnyArgs")) + sv.start_link() + + for env_id in range(len(sv._children)): + sv.send(SendPayload(proc_id=env_id, method="step", args=[""])) + + start = time() + for _ in range(1000): + payload = sv.recv() + sv.send(SendPayload(proc_id=payload.proc_id, method="step", args=[""])) + + assert time() - start < 1 diff --git a/ding/framework/tests/test_task.py b/ding/framework/tests/test_task.py index 7ad13977bc..8b6f9ee1de 100644 --- a/ding/framework/tests/test_task.py +++ b/ding/framework/tests/test_task.py @@ -1,6 +1,7 @@ +import multiprocessing as mp import pytest from threading import Lock -from time import sleep +from time import sleep, time import random import dataclasses from ding.framework import task, Context, Parallel @@ -125,7 +126,7 @@ def _counter(ctx): @pytest.mark.unittest def test_parallel_pipeline(): - Parallel.runner(n_parallel_workers=2)(parallel_main) + Parallel.runner(n_parallel_workers=2, startup_interval=0.1)(parallel_main) @pytest.mark.unittest @@ -162,7 +163,7 @@ def emit_remote_main(): @pytest.mark.unittest def test_emit_remote(): - Parallel.runner(n_parallel_workers=2)(emit_remote_main) + Parallel.runner(n_parallel_workers=2, startup_interval=0.1)(emit_remote_main) @pytest.mark.unittest @@ -228,7 +229,7 @@ def early_stop_main(): @pytest.mark.unittest def test_early_stop(): - Parallel.runner(n_parallel_workers=2)(early_stop_main) + Parallel.runner(n_parallel_workers=2, startup_interval=0.1)(early_stop_main) @pytest.mark.unittest @@ -333,3 +334,49 @@ def slowest(ctx): task.use(fast, lock=lock) task.run(1) assert task.ctx.result == "slowest" + + +def broadcast_finish_main(): + with task.start(): + + def tick(ctx: Context): + if task.router.node_id == 1 and ctx.total_step == 1: + task.finish = True + sleep(1) + + task.use(tick) + task.run(20) + + +def broadcast_main_target(): + Parallel.runner( + n_parallel_workers=1, protocol="tcp", address="127.0.0.1", topology="mesh", ports=50555, 
startup_interval=0.1 + )(broadcast_finish_main) + + +def broadcast_secondary_target(): + "Start two standalone processes and connect to the main process." + Parallel.runner( + n_parallel_workers=2, + protocol="tcp", + address="127.0.0.1", + topology="alone", + ports=50556, + attach_to=["tcp://127.0.0.1:50555"], + node_ids=[1, 2], + startup_interval=0.1 + )(broadcast_finish_main) + + +@pytest.mark.tmp # passes on GitLab CI and locally, but always fails on GitHub CI +@pytest.mark.timeout(10) +def test_broadcast_finish(): + start = time() + ctx = mp.get_context("spawn") + main_process = ctx.Process(target=broadcast_main_target) + secondary_process = ctx.Process(target=broadcast_secondary_target) + main_process.start() + secondary_process.start() + main_process.join() + secondary_process.join() + assert (time() - start) < 10 diff --git a/ding/framework/wrapper/step_timer.py b/ding/framework/wrapper/step_timer.py index f7d123bc62..dfabdd1476 100644 --- a/ding/framework/wrapper/step_timer.py +++ b/ding/framework/wrapper/step_timer.py @@ -5,11 +5,20 @@ import numpy as np import time from ditk import logging +from ding.framework import task class StepTimer: def __init__(self, print_per_step: int = 1, smooth_window: int = 10) -> None: + """ + Overview: + Print the time cost of each step (the execution of one middleware). + Arguments: + - print_per_step (:obj:`int`): Print once every N steps. + - smooth_window (:obj:`int`): The window size to smooth the mean. + """ + self.print_per_step = print_per_step self.records = defaultdict(lambda: deque(maxlen=print_per_step * smooth_window)) @@ -36,11 +45,12 @@ def executor(ctx): time_cost += time.time() - start_time else: time_cost = time.time() - start_time - self.records[step_name].append(time_cost * 1000) + self.records[step_name].append(time_cost) if ctx.total_step % self.print_per_step == 0: logging.info( - "[Step Timer] {}: Cost: {:.2f}ms, Mean: {:.2f}ms".format( - step_name, time_cost * 1000, np.mean(self.records[step_name]) + "[Step Timer][Node:{:>2}] {}: Cost: {:.2f}ms, Mean: {:.2f}ms".format( + task.router.node_id or 0, step_name, time_cost * 1000, + np.mean(self.records[step_name]) * 1000 ) ) diff --git a/ding/policy/base_policy.py b/ding/policy/base_policy.py index 6221b48809..53181ef9ee 100644 --- a/ding/policy/base_policy.py +++ b/ding/policy/base_policy.py @@ -75,7 +75,6 @@ def __init__( if len(set(self._enable_field).intersection(set(['learn']))) > 0: self._rank = get_rank() if self._cfg.learn.multi_gpu else 0 if self._cuda: - torch.cuda.set_device(self._rank % torch.cuda.device_count()) model.cuda() if self._cfg.learn.multi_gpu: bp_update_sync = self._cfg.learn.get('bp_update_sync', True) @@ -84,7 +83,6 @@ def __init__( else: self._rank = 0 if self._cuda: - torch.cuda.set_device(self._rank % torch.cuda.device_count()) model.cuda() self._model = model self._device = 'cuda:{}'.format(self._rank % torch.cuda.device_count()) if self._cuda else 'cpu' diff --git a/ding/policy/dqn.py b/ding/policy/dqn.py index 776010e245..9efc68350f 100644 --- a/ding/policy/dqn.py +++ b/ding/policy/dqn.py @@ -157,12 +157,22 @@ def _init_learn(self) -> None: # use model_wrapper for specialized demands of different modes self._target_model = copy.deepcopy(self._model) - self._target_model = model_wrap( - self._target_model, - wrapper_name='target', - update_type='assign', - update_kwargs={'freq': self._cfg.learn.target_update_freq} - ) + if 'target_update_freq' in self._cfg.learn: + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='assign', + 
update_kwargs={'freq': self._cfg.learn.target_update_freq} + ) + elif 'target_theta' in self._cfg.learn: + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='momentum', + update_kwargs={'theta': self._cfg.learn.target_theta} + ) + else: + raise RuntimeError("DQN needs a target network, please specify either target_update_freq or target_theta") self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample') self._learn_model.reset() self._target_model.reset() @@ -203,7 +213,7 @@ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: # Target q value with torch.no_grad(): target_q_value = self._target_model.forward(data['next_obs'])['logit'] - # Max q value action (main model) + # Max q value action (main model), i.e. Double DQN target_q_action = self._learn_model.forward(data['next_obs'])['action'] data_n = q_nstep_td_data( diff --git a/ding/policy/sac.py b/ding/policy/sac.py index d2f289d045..75f42d5223 100644 --- a/ding/policy/sac.py +++ b/ding/policy/sac.py @@ -366,8 +366,6 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: torch.zeros_like(self._alpha)).requires_grad_() loss_dict['total_loss'] = sum(loss_dict.values()) - info_dict = {} - # ============= # after update # ============= @@ -384,8 +382,7 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: 'q_value_2': target_q_value[1].detach().mean().item(), 'target_value': target_value.detach().mean().item(), 'entropy': entropy.item(), - **info_dict, - **loss_dict + #**loss_dict } def _state_dict_learn(self) -> Dict[str, Any]: diff --git a/ding/torch_utils/tests/test_data_helper.py b/ding/torch_utils/tests/test_data_helper.py index 218ce59ba7..629d081b0a 100644 --- a/ding/torch_utils/tests/test_data_helper.py +++ b/ding/torch_utils/tests/test_data_helper.py @@ -4,9 +4,10 @@ import torch import torch.nn as nn from torch.utils.data import DataLoader +import treetensor.torch as ttorch from ding.torch_utils import CudaFetcher, to_device, to_dtype, to_tensor, to_ndarray, to_list, \ - tensor_to_list, same_shape, build_log_buffer, get_tensor_data + tensor_to_list, same_shape, build_log_buffer, get_tensor_data, get_shape0 from ding.utils import EasyTimer @@ -132,6 +133,18 @@ def test_get_tensor_data(self): with pytest.raises(TypeError): get_tensor_data(EasyTimer()) + def test_get_shape0(self): + a = { + 'a': { + 'b': torch.randn(4, 3) + }, + 'c': { + 'd': torch.randn(4) + }, + } + a = ttorch.as_tensor(a) + assert get_shape0(a) == 4 + @pytest.mark.unittest def test_log_dict(): diff --git a/ding/utils/data/collate_fn.py b/ding/utils/data/collate_fn.py index c97f995cab..0485938d82 100644 --- a/ding/utils/data/collate_fn.py +++ b/ding/utils/data/collate_fn.py @@ -65,11 +65,12 @@ def default_collate(batch: Sequence, - ret (:obj:`Union[torch.Tensor, Mapping, Sequence]`): the collated data, with the batch size merged into each data field.\ The return dtype depends on the original element dtype and can be [torch.Tensor, Mapping, Sequence]. 
""" - elem = batch[0] - elem_type = type(elem) if isinstance(batch, ttorch.Tensor): return batch.json() + + elem = batch[0] + elem_type = type(elem) if isinstance(elem, torch.Tensor): out = None if torch_ge_131() and torch.utils.data.get_worker_info() is not None: diff --git a/ding/utils/data/structure/__init__.py b/ding/utils/data/structure/__init__.py index 9e8011f9d4..3cc58828a6 100644 --- a/ding/utils/data/structure/__init__.py +++ b/ding/utils/data/structure/__init__.py @@ -1 +1,2 @@ from .cache import Cache +from .lifo_deque import LifoDeque diff --git a/ding/utils/data/structure/lifo_deque.py b/ding/utils/data/structure/lifo_deque.py new file mode 100644 index 0000000000..00d9221e5c --- /dev/null +++ b/ding/utils/data/structure/lifo_deque.py @@ -0,0 +1,12 @@ +from queue import LifoQueue +from collections import deque + + +class LifoDeque(LifoQueue): + """ + Like LifoQueue, but automatically replaces the oldest data when the queue is full. + """ + + def _init(self, maxsize): + self.maxsize = maxsize + 1 + self.queue = deque(maxlen=maxsize) diff --git a/ding/utils/pytorch_ddp_dist_helper.py b/ding/utils/pytorch_ddp_dist_helper.py index 96847be357..3c9d5881fb 100644 --- a/ding/utils/pytorch_ddp_dist_helper.py +++ b/ding/utils/pytorch_ddp_dist_helper.py @@ -114,7 +114,9 @@ def dist_finalize() -> None: Overview: Finalize distributed training resources """ - dist.destroy_process_group() + # This operation usually hangs out so we ignore it temporally. + # dist.destroy_process_group() + pass class DistContext: diff --git a/dizoo/atari/example/atari_dqn.py b/dizoo/atari/example/atari_dqn.py index 4f4be49ecb..1ac9fdc0c8 100644 --- a/dizoo/atari/example/atari_dqn.py +++ b/dizoo/atari/example/atari_dqn.py @@ -42,7 +42,7 @@ def main(): task.use(data_pusher(cfg, buffer_)) task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) task.use(CkptSaver(cfg, policy, train_freq=1000)) - task.use(termination_checker(max_train_iter=int(1e7))) + task.use(termination_checker(max_env_step=int(1e7))) task.run() diff --git a/dizoo/atari/example/atari_dqn_ddp.py b/dizoo/atari/example/atari_dqn_ddp.py new file mode 100644 index 0000000000..22e30eba89 --- /dev/null +++ b/dizoo/atari/example/atari_dqn_ddp.py @@ -0,0 +1,59 @@ +from copy import deepcopy +from ditk import logging +from ding.model import DQN +from ding.policy import DQNPolicy +from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2 +from ding.data import DequeBuffer +from ding.config import compile_config +from ding.utils import DistContext, get_rank +from ding.framework import task, ding_init +from ding.framework.context import OnlineRLContext +from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \ + eps_greedy_handler, CkptSaver, nstep_reward_enhancer, online_logger, ddp_termination_checker +from ding.utils import set_pkg_seed +from dizoo.atari.envs.atari_env import AtariEnv +from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config + + +def main(): + logging.getLogger().setLevel(logging.INFO) + main_config.exp_name = 'pong_dqn_seed0_ddp' + main_config.policy.learn.multi_gpu = True + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + ding_init(cfg) + with DistContext(): + rank = get_rank() + with task.start(async_mode=False, ctx=OnlineRLContext()): + collector_cfg = deepcopy(cfg.env) + collector_cfg.is_train = True + evaluator_cfg = deepcopy(cfg.env) + evaluator_cfg.is_train = False + collector_env = SubprocessEnvManagerV2( + 
env_fn=[lambda: AtariEnv(collector_cfg) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager + ) + evaluator_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(evaluator_cfg) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager + ) + + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = DQN(**cfg.policy.model) + buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) + policy = DQNPolicy(cfg.policy, model=model) + + if rank == 0: + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(eps_greedy_handler(cfg)) + task.use(StepCollector(cfg, policy.collect_mode, collector_env)) + task.use(nstep_reward_enhancer(cfg)) + task.use(data_pusher(cfg, buffer_)) + task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) + if rank == 0: + task.use(CkptSaver(cfg, policy, train_freq=1000)) + task.use(online_logger(record_train_iter=True)) + task.use(ddp_termination_checker(max_env_step=int(1e7), rank=rank)) + task.run() + + +if __name__ == "__main__": + main() diff --git a/dizoo/atari/example/atari_dqn_dist.py b/dizoo/atari/example/atari_dqn_dist.py new file mode 100644 index 0000000000..d692a9f3e3 --- /dev/null +++ b/dizoo/atari/example/atari_dqn_dist.py @@ -0,0 +1,85 @@ +from copy import deepcopy +from ditk import logging +from ding.model import DQN +from ding.policy import DQNPolicy +from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2 +from ding.data import DequeBuffer +from ding.config import compile_config +from ding.framework import task, ding_init +from ding.framework.context import OnlineRLContext +from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \ + eps_greedy_handler, CkptSaver, context_exchanger, model_exchanger, termination_checker, nstep_reward_enhancer, \ + online_logger +from ding.utils import set_pkg_seed +from dizoo.atari.envs.atari_env import AtariEnv +from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config + + +def main(): + logging.getLogger().setLevel(logging.INFO) + main_config.exp_name = 'pong_dqn_seed0_ditask_dist' + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + ding_init(cfg) + with task.start(async_mode=False, ctx=OnlineRLContext()): + assert task.router.is_active, "Please execute this script with ditask! See note in the header." 
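
# [Editor's note] The label-gated branches below do by hand what the new Role API in
# ding/framework/task.py (earlier in this patch) formalizes. A minimal usage sketch,
# illustrative only and not part of the patch; note the deliberate semantics of
# has_role: a task with no roles assigned behaves as if it holds every role, so
# single-process pipelines keep working without any add_role calls.
from ding.framework import task

with task.start():
    assert task.has_role(task.role.LEARNER)      # no roles assigned yet, so any role matches
    task.add_role(task.role.COLLECTOR)
    assert task.has_role(task.role.COLLECTOR)    # explicitly assigned role matches
    assert not task.has_role(task.role.LEARNER)  # other roles no longer match
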
+ + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = DQN(**cfg.policy.model) + policy = DQNPolicy(cfg.policy, model=model) + + if 'learner' in task.router.labels: + logging.info("Learner running on node {}".format(task.router.node_id)) + buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) + task.use( + context_exchanger( + send_keys=["train_iter"], + recv_keys=["trajectories", "episodes", "env_step", "env_episode"], + skip_n_iter=0 + ) + ) + task.use(model_exchanger(model, is_learner=True)) + task.use(nstep_reward_enhancer(cfg)) + task.use(data_pusher(cfg, buffer_)) + task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) + task.use(CkptSaver(cfg, policy, train_freq=1000)) + + elif 'evaluator' in task.router.labels: + logging.info("Evaluator running on node {}".format(task.router.node_id)) + evaluator_cfg = deepcopy(cfg.env) + evaluator_cfg.is_train = False + evaluator_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(evaluator_cfg) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager + ) + task.use(context_exchanger(recv_keys=["train_iter", "env_step"], skip_n_iter=1)) + task.use(model_exchanger(model, is_learner=False)) + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(CkptSaver(cfg, policy, save_finish=False)) + task.use(online_logger(record_train_iter=True)) + + elif 'collector' in task.router.labels: + logging.info("Collector running on node {}".format(task.router.node_id)) + collector_cfg = deepcopy(cfg.env) + collector_cfg.is_train = True + collector_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(collector_cfg) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager + ) + task.use( + context_exchanger( + send_keys=["trajectories", "episodes", "env_step", "env_episode"], + recv_keys=["train_iter"], + skip_n_iter=1 + ) + ) + task.use(model_exchanger(model, is_learner=False)) + task.use(eps_greedy_handler(cfg)) + task.use(StepCollector(cfg, policy.collect_mode, collector_env)) + task.use(termination_checker(max_env_step=int(1e7))) + else: + raise KeyError("invalid router labels: {}".format(task.router.labels)) + + task.run() + + +if __name__ == "__main__": + main() diff --git a/dizoo/atari/example/atari_dqn_dist_ddp.py b/dizoo/atari/example/atari_dqn_dist_ddp.py new file mode 100644 index 0000000000..6b615abb21 --- /dev/null +++ b/dizoo/atari/example/atari_dqn_dist_ddp.py @@ -0,0 +1,107 @@ +from copy import deepcopy +from ditk import logging +from ding.model import DQN +from ding.policy import DQNPolicy +from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2 +from ding.data import DequeBuffer +from ding.config import compile_config +from ding.framework import task, ding_init +from ding.framework.context import OnlineRLContext +from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \ + eps_greedy_handler, CkptSaver, context_exchanger, model_exchanger, termination_checker, nstep_reward_enhancer, \ + online_logger +from ding.utils import set_pkg_seed +from dizoo.atari.envs.atari_env import AtariEnv +from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config + + +logging.getLogger().setLevel(logging.INFO) +main_config.exp_name = 'pong_dqn_seed0_ditask_dist_ddp' + + +def learner(): + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + ding_init(cfg) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = DQN(**cfg.policy.model) + policy = DQNPolicy(cfg.policy, 
model=model, enable_field=['learn']) + buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) + + with task.start(async_mode=False, ctx=OnlineRLContext()): + assert task.router.is_active, "Please execute this script with ditask! See note in the header." + logging.info("Learner running on node {}".format(task.router.node_id)) + + from ding.utils import DistContext, get_rank + with DistContext(): + rank = get_rank() + task.use( + context_exchanger( + send_keys=["train_iter"], + recv_keys=["trajectories", "episodes", "env_step", "env_episode"], + skip_n_iter=0 + ) + ) + task.use(model_exchanger(model, is_learner=True)) + task.use(nstep_reward_enhancer(cfg)) + task.use(data_pusher(cfg, buffer_)) + task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) + if rank == 0: + task.use(CkptSaver(cfg, policy, train_freq=1000)) + task.run() + + +def collector(): + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + ding_init(cfg) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = DQN(**cfg.policy.model) + policy = DQNPolicy(cfg.policy, model=model, enable_field=['collect']) + collector_cfg = deepcopy(cfg.env) + collector_cfg.is_train = True + collector_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(collector_cfg) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager + ) + + with task.start(async_mode=False, ctx=OnlineRLContext()): + assert task.router.is_active, "Please execute this script with ditask! See note in the header." + logging.info("Collector running on node {}".format(task.router.node_id)) + + task.use( + context_exchanger( + send_keys=["trajectories", "episodes", "env_step", "env_episode"], + recv_keys=["train_iter"], + skip_n_iter=1 + ) + ) + task.use(model_exchanger(model, is_learner=False)) + task.use(eps_greedy_handler(cfg)) + task.use(StepCollector(cfg, policy.collect_mode, collector_env)) + task.use(termination_checker(max_env_step=int(1e7))) + task.run() + + +def evaluator(): + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + ding_init(cfg) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = DQN(**cfg.policy.model) + policy = DQNPolicy(cfg.policy, model=model, enable_field=['eval']) + evaluator_cfg = deepcopy(cfg.env) + evaluator_cfg.is_train = False + evaluator_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(evaluator_cfg) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager + ) + + with task.start(async_mode=False, ctx=OnlineRLContext()): + assert task.router.is_active, "Please execute this script with ditask! See note in the header." 
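
# [Editor's note] A hand-rolled sketch, illustrative only, of the two target-network
# update rules that the ding/policy/dqn.py hunk earlier in this patch now selects
# between: update_type='assign' copies the online weights every target_update_freq
# train iterations, while update_type='momentum' blends them each iteration,
# target <- theta * online + (1 - theta) * target (the blend direction is our reading
# of the wrapper, not spelled out in this patch).
import torch


def momentum_update(target: torch.nn.Module, online: torch.nn.Module, theta: float) -> None:
    # Soft (Polyak-style) update of the target parameters toward the online parameters.
    with torch.no_grad():
        for t_param, o_param in zip(target.parameters(), online.parameters()):
            t_param.mul_(1 - theta).add_(theta * o_param)
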
+ logging.info("Evaluator running on node {}".format(task.router.node_id)) + + task.use(context_exchanger(recv_keys=["train_iter", "env_step"], skip_n_iter=1)) + task.use(model_exchanger(model, is_learner=False)) + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(CkptSaver(cfg, policy, save_finish=False)) + task.use(online_logger(record_train_iter=True)) + task.run() diff --git a/dizoo/atari/example/atari_dqn_dist_rdma.py b/dizoo/atari/example/atari_dqn_dist_rdma.py new file mode 100644 index 0000000000..71fb1d64a1 --- /dev/null +++ b/dizoo/atari/example/atari_dqn_dist_rdma.py @@ -0,0 +1,72 @@ +from copy import deepcopy +from ditk import logging +from ding.model import DQN +from ding.policy import DQNPolicy +from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2 +from ding.data import DequeBuffer +from ding.config import compile_config +from ding.framework import task, ding_init +from ding.framework.context import OnlineRLContext +from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \ + eps_greedy_handler, CkptSaver, context_exchanger, model_exchanger, termination_checker, nstep_reward_enhancer, \ + online_logger +from ding.utils import set_pkg_seed +from dizoo.atari.envs.atari_env import AtariEnv +from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config + + +def main(): + logging.getLogger().setLevel(logging.INFO) + main_config.exp_name = 'pong_dqn_seed0_dist_rdma' + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + ding_init(cfg) + with task.start(async_mode=False, ctx=OnlineRLContext()): + assert task.router.is_active, "Please execute this script with ditask! See note in the header." + + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = DQN(**cfg.policy.model) + policy = DQNPolicy(cfg.policy, model=model) + + if 'learner' in task.router.labels: + logging.info("Learner running on node {}".format(task.router.node_id)) + buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) + task.use( + context_exchanger( + send_keys=["train_iter"], + recv_keys=["trajectories", "episodes", "env_step", "env_episode"], + skip_n_iter=0 + ) + ) + task.use(model_exchanger(model, is_learner=True)) + task.use(nstep_reward_enhancer(cfg)) + task.use(data_pusher(cfg, buffer_)) + task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) + task.use(CkptSaver(cfg, policy, train_freq=1000)) + + elif 'collector' in task.router.labels: + logging.info("Collector running on node {}".format(task.router.node_id)) + collector_cfg = deepcopy(cfg.env) + collector_cfg.is_train = True + collector_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(collector_cfg) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager + ) + task.use( + context_exchanger( + send_keys=["trajectories", "episodes", "env_step", "env_episode"], + recv_keys=["train_iter"], + skip_n_iter=1 + ) + ) + task.use(model_exchanger(model, is_learner=False)) + task.use(eps_greedy_handler(cfg)) + task.use(StepCollector(cfg, policy.collect_mode, collector_env)) + task.use(termination_checker(max_env_step=int(1e7))) + else: + raise KeyError("invalid router labels: {}".format(task.router.labels)) + + task.run() + + +if __name__ == "__main__": + main() diff --git a/dizoo/atari/example/atari_dqn_dp.py b/dizoo/atari/example/atari_dqn_dp.py new file mode 100644 index 0000000000..cea9618061 --- /dev/null +++ b/dizoo/atari/example/atari_dqn_dp.py @@ -0,0 +1,53 @@ +from copy import 
deepcopy +from ditk import logging +from ding.model import DQN +from ding.policy import DQNPolicy +from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2 +from ding.data import DequeBuffer +from ding.config import compile_config +from ding.torch_utils import DataParallel +from ding.framework import task +from ding.framework.context import OnlineRLContext +from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \ + eps_greedy_handler, CkptSaver, nstep_reward_enhancer, termination_checker +from ding.utils import set_pkg_seed +from dizoo.atari.envs.atari_env import AtariEnv +from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config + + +def main(): + logging.getLogger().setLevel(logging.INFO) + main_config.exp_name = 'pong_dqn_seed0_dp' + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + with task.start(async_mode=False, ctx=OnlineRLContext()): + collector_cfg = deepcopy(cfg.env) + collector_cfg.is_train = True + evaluator_cfg = deepcopy(cfg.env) + evaluator_cfg.is_train = False + collector_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(collector_cfg) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager + ) + evaluator_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(evaluator_cfg) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager + ) + + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = DQN(**cfg.policy.model) + model = DataParallel(model) + buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) + policy = DQNPolicy(cfg.policy, model=model) + + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(eps_greedy_handler(cfg)) + task.use(StepCollector(cfg, policy.collect_mode, collector_env)) + task.use(nstep_reward_enhancer(cfg)) + task.use(data_pusher(cfg, buffer_)) + task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) + task.use(CkptSaver(cfg, policy, train_freq=1000)) + task.use(termination_checker(max_env_step=int(1e7))) + task.run() + + +if __name__ == "__main__": + main() diff --git a/dizoo/atari/example/atari_ppo.py b/dizoo/atari/example/atari_ppo.py new file mode 100644 index 0000000000..94b99ca8c2 --- /dev/null +++ b/dizoo/atari/example/atari_ppo.py @@ -0,0 +1,47 @@ +from copy import deepcopy +from ditk import logging +from ding.model import VAC +from ding.policy import PPOPolicy +from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2 +from ding.data import DequeBuffer +from ding.config import compile_config +from ding.framework import task +from ding.framework.context import OnlineRLContext +from ding.framework.middleware import multistep_trainer, StepCollector, interaction_evaluator, CkptSaver, \ + gae_estimator, termination_checker +from ding.utils import set_pkg_seed +from dizoo.atari.envs.atari_env import AtariEnv +from dizoo.atari.config.serial.pong.pong_onppo_config import main_config, create_config + + +def main(): + logging.getLogger().setLevel(logging.INFO) + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + with task.start(async_mode=False, ctx=OnlineRLContext()): + collector_cfg = deepcopy(cfg.env) + collector_cfg.is_train = True + evaluator_cfg = deepcopy(cfg.env) + evaluator_cfg.is_train = False + collector_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(collector_cfg) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager + ) + evaluator_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(evaluator_cfg) 
for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager + ) + + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = VAC(**cfg.policy.model) + policy = PPOPolicy(cfg.policy, model=model) + + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(StepCollector(cfg, policy.collect_mode, collector_env)) + task.use(gae_estimator(cfg, policy.collect_mode)) + task.use(multistep_trainer(cfg, policy.learn_mode)) + task.use(CkptSaver(cfg, policy, train_freq=1000)) + task.use(termination_checker(max_env_step=int(1e7))) + task.run() + + +if __name__ == "__main__": + main() diff --git a/dizoo/atari/example/atari_ppo_ddp.py b/dizoo/atari/example/atari_ppo_ddp.py new file mode 100644 index 0000000000..e498e03394 --- /dev/null +++ b/dizoo/atari/example/atari_ppo_ddp.py @@ -0,0 +1,56 @@ +from copy import deepcopy +from ditk import logging +from ding.model import VAC +from ding.policy import PPOPolicy +from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2 +from ding.data import DequeBuffer +from ding.config import compile_config +from ding.framework import task, ding_init +from ding.framework.context import OnlineRLContext +from ding.framework.middleware import multistep_trainer, StepCollector, interaction_evaluator, CkptSaver, \ + gae_estimator, ddp_termination_checker, online_logger +from ding.utils import set_pkg_seed, DistContext, get_rank, get_world_size +from dizoo.atari.envs.atari_env import AtariEnv +from dizoo.atari.config.serial.pong.pong_onppo_config import main_config, create_config + + +def main(): + logging.getLogger().setLevel(logging.INFO) + with DistContext(): + rank, world_size = get_rank(), get_world_size() + main_config.exp_name = 'pong_ppo_seed0_ddp_avgsplit' + main_config.policy.learn.multi_gpu = True + main_config.policy.learn.batch_size = main_config.policy.learn.batch_size // world_size + main_config.policy.collect.n_sample = main_config.policy.collect.n_sample // world_size + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + ding_init(cfg) + with task.start(async_mode=False, ctx=OnlineRLContext()): + collector_cfg = deepcopy(cfg.env) + collector_cfg.is_train = True + evaluator_cfg = deepcopy(cfg.env) + evaluator_cfg.is_train = False + collector_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(collector_cfg) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager + ) + evaluator_env = SubprocessEnvManagerV2( + env_fn=[lambda: AtariEnv(evaluator_cfg) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager + ) + + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = VAC(**cfg.policy.model) + policy = PPOPolicy(cfg.policy, model=model) + + if rank == 0: + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(StepCollector(cfg, policy.collect_mode, collector_env)) + task.use(gae_estimator(cfg, policy.collect_mode)) + task.use(multistep_trainer(cfg, policy.learn_mode)) + if rank == 0: + task.use(CkptSaver(cfg, policy, train_freq=1000)) + task.use(ddp_termination_checker(max_env_step=int(1e7), rank=rank)) + task.run() + + +if __name__ == "__main__": + main() diff --git a/dizoo/mujoco/example/mujoco_sac.py b/dizoo/mujoco/example/mujoco_sac.py index 0349d8d161..471e4c8f29 100644 --- a/dizoo/mujoco/example/mujoco_sac.py +++ b/dizoo/mujoco/example/mujoco_sac.py @@ -37,7 +37,7 @@ def main(): task.use(data_pusher(cfg, buffer_)) task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) task.use(CkptSaver(cfg, policy, train_freq=500)) - 
task.use(termination_checker(max_train_iter=int(3e6))) + task.use(termination_checker(max_env_step=int(3e6))) task.run()
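
# [Editor's note] Usage sketch, illustrative only, for the new LifoDeque in
# ding/utils/data/structure/lifo_deque.py (earlier in this patch). Backing LifoQueue
# with a bounded deque, while reporting maxsize + 1 to the base class, means put()
# never blocks: once full, the oldest entry is silently evicted, and get() still
# returns newest-first.
from ding.utils.data.structure import LifoDeque

q = LifoDeque(maxsize=3)
for i in range(5):
    q.put(i)            # 0 and 1 are evicted as 2, 3 and 4 arrive
assert q.get() == 4     # LIFO: the newest item comes out first
assert q.qsize() == 2   # 2 and 3 remain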