feature(nyz): add new middleware distributed demo #321

Merged
merged 81 commits on Dec 12, 2022

Changes from all commits (81 commits)
8f7135b
demo(nyz): add naive dp demo
PaParaZz1 May 14, 2022
6733b69
demo(nyz): add naive ddp demo
PaParaZz1 May 15, 2022
b907c63
feature(nyz): add naive tb_logger in new evaluator
PaParaZz1 May 15, 2022
4b833d0
Add singleton log writer
sailxjx May 16, 2022
b0e6238
Use get_instance on writer
sailxjx May 16, 2022
7a344e6
feature(nyz): add general logger middleware
PaParaZz1 May 16, 2022
ae37da2
feature(nyz): add soft update in DQN target network
PaParaZz1 May 17, 2022
dbee60a
fix(nyz): fix termination env_step bug and eval task.finish broadcast…
PaParaZz1 May 17, 2022
7a4789d
Merge branch 'dev-dist' of https://gitlab.bj.sensetime.com/open-XLab/…
PaParaZz1 May 17, 2022
e7c9d96
Support distributed dqn
sailxjx Apr 26, 2022
472abb7
Add more desc (ci skip)
sailxjx Apr 26, 2022
f80f047
Support distributed dqn
sailxjx Apr 26, 2022
a563a91
Merge branch 'dev-dist' of https://github.com/opendilab/DI-engine int…
PaParaZz1 May 25, 2022
43f6f01
feature(nyz): add online logger freq
PaParaZz1 May 26, 2022
1e0c4a1
fix(nyz): fix policy set device bug
PaParaZz1 Jun 1, 2022
c043ebd
add offline rl logger
hiha3456 Jun 6, 2022
df8719a
change a bit
hiha3456 Jun 6, 2022
7ade025
add else in checking ctx type
hiha3456 Jun 6, 2022
fe6a32f
add test_logger.py
hiha3456 Jun 6, 2022
4754c80
add mock of offline_logger
hiha3456 Jun 6, 2022
83869c8
add mock of online writer
hiha3456 Jun 6, 2022
302c824
reformat
hiha3456 Jun 6, 2022
51e3e0e
reformat
hiha3456 Jun 6, 2022
aa29252
feature(nyz): polish atari ddp demo and add dist demo
PaParaZz1 Jun 7, 2022
16d8107
fix(nyz): fix mq listen bug when stop
PaParaZz1 Jun 8, 2022
01c9868
demo(nyz): add atari ppo(sm+ddp) demo
PaParaZz1 Jun 8, 2022
d04ad08
Merge branch 'dev-dist' of https://gitlab.bj.sensetime.com/open-XLab/…
PaParaZz1 Jun 8, 2022
dd4e0db
demo(nyz): add ppo ddp avgsplit demo
PaParaZz1 Jun 8, 2022
aa9b779
Merge branch 'dev-dist' of https://gitlab.bj.sensetime.com/open-XLab/…
PaParaZz1 Jun 8, 2022
b35acec
demo(nyz): add ditask + pytorch ddp demo
PaParaZz1 Jun 8, 2022
7e71cbb
fix(nyz): fix dict-type obs bugs
PaParaZz1 Jun 9, 2022
629a5ac
fix(nyz): fix get_shape0 bug when nested structure
PaParaZz1 Jun 10, 2022
88c7964
Route finish event to all processes in the cluster
sailxjx Jun 9, 2022
6ef568f
demo(nyz): add naive dp demo
PaParaZz1 May 14, 2022
60d0927
demo(nyz): add naive ddp demo
PaParaZz1 May 15, 2022
1011fee
feature(nyz): add naive tb_logger in new evaluator
PaParaZz1 May 15, 2022
b45ec41
feature(nyz): add soft update in DQN target network
PaParaZz1 May 17, 2022
f9240c7
fix(nyz): fix termination env_step bug and eval task.finish broadcast…
PaParaZz1 May 17, 2022
96e376e
Add singleton log writer
sailxjx May 16, 2022
ec413df
Use get_instance on writer
sailxjx May 16, 2022
0bb1d77
feature(nyz): add general logger middleware
PaParaZz1 May 16, 2022
c86fb2e
Support distributed dqn
sailxjx Apr 26, 2022
3da3455
Add more desc (ci skip)
sailxjx Apr 26, 2022
2050759
Support distributed dqn
sailxjx Apr 26, 2022
09c99a9
feature(nyz): add online logger freq
PaParaZz1 May 26, 2022
9c41400
fix(nyz): fix policy set device bug
PaParaZz1 Jun 1, 2022
869e63a
add offline rl logger
hiha3456 Jun 6, 2022
5d0fafd
change a bit
hiha3456 Jun 6, 2022
3bd5ba5
add else in checking ctx type
hiha3456 Jun 6, 2022
fcaf1ce
add test_logger.py
hiha3456 Jun 6, 2022
52e3500
add mock of offline_logger
hiha3456 Jun 6, 2022
48106f1
add mock of online writer
hiha3456 Jun 6, 2022
0de3bed
reformat
hiha3456 Jun 6, 2022
3063e47
reformat
hiha3456 Jun 6, 2022
d601185
feature(nyz): polish atari ddp demo and add dist demo
PaParaZz1 Jun 7, 2022
be27ff7
fix(nyz): fix mq listen bug when stop
PaParaZz1 Jun 8, 2022
e5868df
demo(nyz): add atari ppo(sm+ddp) demo
PaParaZz1 Jun 8, 2022
d637d2b
demo(nyz): add ppo ddp avgsplit demo
PaParaZz1 Jun 8, 2022
f72b6c9
demo(nyz): add ditask + pytorch ddp demo
PaParaZz1 Jun 8, 2022
8444785
fix(nyz): fix dict-type obs bugs
PaParaZz1 Jun 9, 2022
d0b1c00
fix(nyz): fix get_shape0 bug when nested structure
PaParaZz1 Jun 10, 2022
0fb89c8
Route finish event to all processes in the cluster
sailxjx Jun 9, 2022
84ea5cb
refactor(nyz): split dist ddp demo implementation
PaParaZz1 Jun 9, 2022
e706494
Merge branch 'dev-dist' of https://github.com/opendilab/DI-engine int…
PaParaZz1 Jul 18, 2022
6daaf2d
Merge branch 'main' into dev-dist
PaParaZz1 Jul 18, 2022
c350bae
feature(nyz): add rdma test demo(ci skip)
PaParaZz1 Jul 19, 2022
3f34393
Merge branch 'main' into dev-dist
PaParaZz1 Sep 8, 2022
0206137
feature(xjx): new style dist version, add storage loader and model lo…
sailxjx Sep 8, 2022
702240b
Merge branch 'main' into dev-dist
PaParaZz1 Oct 19, 2022
018b197
style(nyz): correct yapf style
PaParaZz1 Oct 19, 2022
7d931f9
fix(nyz): fix ctx and logger compatibility bugs
PaParaZz1 Oct 19, 2022
f096cfd
polish(nyz): update demo from cartpole v0 to v1
PaParaZz1 Oct 19, 2022
8d34671
fix(nyz): fix evaluator condition bug
PaParaZz1 Oct 19, 2022
06c3a8f
Merge branch 'main' into dev-dist
PaParaZz1 Nov 8, 2022
5fe4f7e
Merge branch 'main' into dev-dist
PaParaZz1 Nov 14, 2022
590728f
style(nyz): correct flake8 style
PaParaZz1 Nov 14, 2022
2236e8f
demo(nyz): move back to CartPole-v0
PaParaZz1 Nov 17, 2022
6a6798f
fix(nyz): fix context manager env step merge bug(ci skip)
PaParaZz1 Dec 1, 2022
060d1b8
fix(nyz): fix context manager env step merge bug(ci skip)
PaParaZz1 Dec 1, 2022
c12a76c
Merge branch 'main' into dev-dist
PaParaZz1 Dec 11, 2022
dde6009
fix(nyz): fix flake8 style
PaParaZz1 Dec 11, 2022
2 changes: 1 addition & 1 deletion ding/config/config.py
@@ -315,7 +315,7 @@ def save_project_state(exp_name: str) -> None:
def _fn(cmd: str):
return subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.strip().decode("utf-8")

if subprocess.run("git status", shell=True, stderr=subprocess.PIPE).returncode == 0:
if subprocess.run("git status", shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE).returncode == 0:
short_sha = _fn("git describe --always")
log = _fn("git log --stat -n 5")
diff = _fn("git diff")
4 changes: 4 additions & 0 deletions ding/data/__init__.py
@@ -1,3 +1,7 @@
from torch.utils.data import Dataset, DataLoader
from ding.utils.data import create_dataset, offline_data_save_type # for compatibility
from .buffer import *
from .storage import *
from .storage_loader import StorageLoader, FileStorageLoader
from .shm_buffer import ShmBufferContainer, ShmBuffer
from .model_loader import ModelLoader, FileModelLoader
155 changes: 155 additions & 0 deletions ding/data/model_loader.py
@@ -0,0 +1,155 @@
from abc import ABC, abstractmethod
import logging
from os import path
import os
from threading import Thread
from time import sleep, time
from typing import Callable, Optional
import uuid
import torch.multiprocessing as mp

import torch
from ding.data.storage.file import FileModelStorage
from ding.data.storage.storage import Storage
from ding.framework import Supervisor
from ding.framework.supervisor import ChildType, SendPayload


class ModelWorker():

def __init__(self, model: torch.nn.Module) -> None:
self._model = model

def save(self, storage: Storage) -> Storage:
storage.save(self._model.state_dict())
return storage


class ModelLoader(Supervisor, ABC):

def __init__(self, model: torch.nn.Module) -> None:
"""
Overview:
Save and send models asynchronously and load them synchronously.
Arguments:
- model (:obj:`torch.nn.Module`): Torch module.
"""
if next(model.parameters()).is_cuda:
super().__init__(type_=ChildType.PROCESS, mp_ctx=mp.get_context("spawn"))
else:
super().__init__(type_=ChildType.PROCESS)
self._model = model
self._send_callback_loop = None
self._send_callbacks = {}
self._model_worker = ModelWorker(self._model)

def start(self):
if not self._running:
self._model.share_memory()
self.register(self._model_worker)
self.start_link()
self._send_callback_loop = Thread(target=self._loop_send_callback, daemon=True)
self._send_callback_loop.start()

def shutdown(self, timeout: Optional[float] = None) -> None:
super().shutdown(timeout)
self._send_callback_loop = None
self._send_callbacks = {}

def _loop_send_callback(self):
while True:
payload = self.recv(ignore_err=True)
if payload.err:
logging.warning("Got error when loading data: {}".format(payload.err))
if payload.req_id in self._send_callbacks:
del self._send_callbacks[payload.req_id]
else:
if payload.req_id in self._send_callbacks:
callback = self._send_callbacks.pop(payload.req_id)
callback(payload.data)

def load(self, storage: Storage) -> object:
"""
Overview:
Load model synchronously.
Arguments:
- storage (:obj:`Storage`): The model should be wrapped in a storage object, e.g. FileModelStorage.
Returns:
- object (:obj:`object`): The loaded model state, e.g. a state_dict.
"""
return storage.load()

@abstractmethod
def save(self, callback: Callable) -> Storage:
"""
Overview:
Save model asynchronously.
Arguments:
- callback (:obj:`Callable`): The callback function after saving model.
Returns:
- storage (:obj:`Storage`): The storage object is created synchronously, so it can be returned.
"""
raise NotImplementedError


class FileModelLoader(ModelLoader):

def __init__(self, model: torch.nn.Module, dirname: str, ttl: int = 20) -> None:
"""
Overview:
Model loader using files as storage media.
Arguments:
- model (:obj:`torch.nn.Module`): Torch module.
- dirname (:obj:`str`): The directory for saving files.
- ttl (:obj:`int`): Files will be automatically cleaned up after ``ttl`` seconds. Note that \
files that have not timed out when the process stops are not cleaned up \
(to avoid errors when other processes may still be reading them), so you may need to \
remove the remaining files manually.
"""
super().__init__(model)
self._dirname = dirname
self._ttl = ttl
self._files = []
self._cleanup_thread = None

def _start_cleanup(self):
"""
Overview:
Start a cleanup thread to remove files that have exceeded their time-to-live on disk.
"""
if self._cleanup_thread is None:
self._cleanup_thread = Thread(target=self._loop_cleanup, daemon=True)
self._cleanup_thread.start()

def shutdown(self, timeout: Optional[float] = None) -> None:
super().shutdown(timeout)
self._cleanup_thread = None

def _loop_cleanup(self):
while True:
if len(self._files) == 0 or time() - self._files[0][0] < self._ttl:
sleep(1)
continue
_, file_path = self._files.pop(0)
if path.exists(file_path):
os.remove(file_path)

def save(self, callback: Callable) -> FileModelStorage:
if not self._running:
logging.warning("Please start model loader before saving model.")
return
if not path.exists(self._dirname):
os.mkdir(self._dirname)
file_path = "model_{}.pth.tar".format(uuid.uuid1())
file_path = path.join(self._dirname, file_path)
model_storage = FileModelStorage(file_path)
payload = SendPayload(proc_id=0, method="save", args=[model_storage])
self.send(payload)

def clean_callback(storage: Storage):
self._files.append([time(), file_path])
callback(storage)

self._send_callbacks[payload.req_id] = clean_callback
self._start_cleanup()
return model_storage
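
A minimal usage sketch of the loader above (hypothetical model, directory, and timing; a real caller would synchronize on the callback instead of sleeping):

import tempfile
from time import sleep
import torch
from ding.data import FileModelLoader

model = torch.nn.Linear(4, 2)
loader = FileModelLoader(model=model, dirname=tempfile.mkdtemp(), ttl=20)
loader.start()

def on_saved(storage):
    # Runs in the loader's callback thread once the worker process has written the file;
    # the storage object can then be loaded synchronously (e.g. on the receiving side).
    state_dict = loader.load(storage)
    print(sorted(state_dict.keys()))  # ['bias', 'weight']

storage = loader.save(on_saved)
sleep(2)  # crude wait for the asynchronous save to complete
loader.shutdown()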
133 changes: 133 additions & 0 deletions ding/data/shm_buffer.py
@@ -0,0 +1,133 @@
from typing import Any, Optional, Union, Tuple, Dict
from multiprocessing import Array
import ctypes
import numpy as np
import torch

_NTYPE_TO_CTYPE = {
np.bool_: ctypes.c_bool,
np.uint8: ctypes.c_uint8,
np.uint16: ctypes.c_uint16,
np.uint32: ctypes.c_uint32,
np.uint64: ctypes.c_uint64,
np.int8: ctypes.c_int8,
np.int16: ctypes.c_int16,
np.int32: ctypes.c_int32,
np.int64: ctypes.c_int64,
np.float32: ctypes.c_float,
np.float64: ctypes.c_double,
}


class ShmBuffer():
"""
Overview:
Shared memory buffer to store numpy array.
"""

def __init__(
self,
dtype: Union[type, np.dtype],
shape: Tuple[int],
copy_on_get: bool = True,
ctype: Optional[type] = None
) -> None:
"""
Overview:
Initialize the buffer.
Arguments:
- dtype (:obj:`Union[type, np.dtype]`): The dtype of the data to limit the size of the buffer.
- shape (:obj:`Tuple[int]`): The shape of the data to limit the size of the buffer.
- copy_on_get (:obj:`bool`): Whether to copy data when calling get method.
- ctype (:obj:`Optional[type]`): Origin class type, e.g. np.ndarray, torch.Tensor.
"""
if isinstance(dtype, np.dtype): # it is type of gym.spaces.dtype
dtype = dtype.type
self.buffer = Array(_NTYPE_TO_CTYPE[dtype], int(np.prod(shape)))
self.dtype = dtype
self.shape = shape
self.copy_on_get = copy_on_get
self.ctype = ctype

def fill(self, src_arr: np.ndarray) -> None:
"""
Overview:
Fill the shared memory buffer with a numpy array. (Replace the original one.)
Arguments:
- src_arr (:obj:`np.ndarray`): array to fill the buffer.
"""
assert isinstance(src_arr, np.ndarray), type(src_arr)
# for np.array with shape (4, 84, 84) and float32 dtype, reshape is 15~20x faster than flatten
# for np.array with shape (4, 84, 84) and uint8 dtype, reshape is 5~7x faster than flatten
# so we reshape dst_arr rather than flatten src_arr
dst_arr = np.frombuffer(self.buffer.get_obj(), dtype=self.dtype).reshape(self.shape)
np.copyto(dst_arr, src_arr)

def get(self) -> np.ndarray:
"""
Overview:
Get the array stored in the buffer.
Return:
- data (:obj:`np.ndarray`): A copy of the data stored in the buffer (a view of the shared memory if ``copy_on_get`` is False).
"""
data = np.frombuffer(self.buffer.get_obj(), dtype=self.dtype).reshape(self.shape)
if self.copy_on_get:
data = data.copy() # must use np.copy, torch.from_numpy and torch.as_tensor still use the same memory
if self.ctype is torch.Tensor:
data = torch.from_numpy(data)
return data


class ShmBufferContainer(object):
"""
Overview:
Support multiple shared memory buffers. Each key-value pair maps a name to a buffer.
"""

def __init__(
self,
dtype: Union[Dict[Any, type], type, np.dtype],
shape: Union[Dict[Any, tuple], tuple],
copy_on_get: bool = True
) -> None:
"""
Overview:
Initialize the buffer container.
Arguments:
- dtype (:obj:`Union[Dict[Any, type], type, np.dtype]`): The dtype of the data to limit the size of the buffer.
- shape (:obj:`Union[Dict[Any, tuple], tuple]`): If `Dict[Any, tuple]`, use a dict to manage \
multiple buffers; If `tuple`, use single buffer.
- copy_on_get (:obj:`bool`): Whether to copy data when calling get method.
"""
if isinstance(shape, dict):
self._data = {k: ShmBufferContainer(dtype[k], v, copy_on_get) for k, v in shape.items()}
elif isinstance(shape, (tuple, list)):
self._data = ShmBuffer(dtype, shape, copy_on_get)
else:
raise RuntimeError("not support shape: {}".format(shape))
self._shape = shape

def fill(self, src_arr: Union[Dict[Any, np.ndarray], np.ndarray]) -> None:
"""
Overview:
Fill the single buffer or each buffer in the container.
Arguments:
- src_arr (:obj:`Union[Dict[Any, np.ndarray], np.ndarray]`): array to fill the buffer.
"""
if isinstance(self._shape, dict):
for k in self._shape.keys():
self._data[k].fill(src_arr[k])
elif isinstance(self._shape, (tuple, list)):
self._data.fill(src_arr)

def get(self) -> Union[Dict[Any, np.ndarray], np.ndarray]:
"""
Overview:
Get the array(s) stored in the buffer(s).
Return:
- data (:obj:`Union[Dict[Any, np.ndarray], np.ndarray]`): The array(s) stored in the buffer(s).
"""
if isinstance(self._shape, dict):
return {k: self._data[k].get() for k in self._shape.keys()}
elif isinstance(self._shape, (tuple, list)):
return self._data.get()
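
A minimal sketch of how these buffers might be used (hypothetical dtypes and shapes):

import numpy as np
from ding.data import ShmBuffer, ShmBufferContainer

# Single buffer: one (4, 84, 84) float32 frame stack shared across processes.
buf = ShmBuffer(dtype=np.float32, shape=(4, 84, 84))
buf.fill(np.random.rand(4, 84, 84).astype(np.float32))
frame = buf.get()  # a copy by default, since copy_on_get is True

# Container: a dict-type observation, one shared buffer per key.
container = ShmBufferContainer(
    dtype={'obs': np.float32, 'action_mask': np.int8},
    shape={'obs': (8, ), 'action_mask': (3, )},
)
container.fill({'obs': np.zeros(8, dtype=np.float32), 'action_mask': np.ones(3, dtype=np.int8)})
data = container.get()  # {'obs': ndarray(8,), 'action_mask': ndarray(3,)}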
2 changes: 2 additions & 0 deletions ding/data/storage/__init__.py
@@ -0,0 +1,2 @@
from .storage import Storage
from .file import FileStorage, FileModelStorage
25 changes: 25 additions & 0 deletions ding/data/storage/file.py
@@ -0,0 +1,25 @@
from typing import Any
from ding.data.storage import Storage
import pickle

from ding.utils.file_helper import read_file, save_file


class FileStorage(Storage):

def save(self, data: Any) -> None:
with open(self.path, "wb") as f:
pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)

def load(self) -> Any:
with open(self.path, "rb") as f:
return pickle.load(f)


class FileModelStorage(Storage):

def save(self, state_dict: object) -> None:
save_file(self.path, state_dict)

def load(self) -> object:
return read_file(self.path)
16 changes: 16 additions & 0 deletions ding/data/storage/storage.py
@@ -0,0 +1,16 @@
from abc import ABC, abstractmethod
from typing import Any


class Storage(ABC):

def __init__(self, path: str) -> None:
self.path = path

@abstractmethod
def save(self, data: Any) -> None:
raise NotImplementedError

@abstractmethod
def load(self) -> Any:
raise NotImplementedError
18 changes: 18 additions & 0 deletions ding/data/storage/tests/test_storage.py
@@ -0,0 +1,18 @@
import tempfile
import pytest
import os
from os import path
from ding.data.storage import FileStorage


@pytest.mark.unittest
def test_file_storage():
path_ = path.join(tempfile.gettempdir(), "test_storage.txt")
try:
storage = FileStorage(path=path_)
storage.save("test")
content = storage.load()
assert content == "test"
finally:
if path.exists(path_):
os.remove(path_)