# Patch a904142f3df3a65ffaaded3d3e37dfbc208aa5d7 (ysqyang, 2021-12-27):
# "added temporary test files for review purposes".
# ---- maro/rl_v3/workflows/test/main_test.py (head) ----

import asyncio
import inspect
import time

import torch

from maro.rl_v3.distributed.remote import RemoteOps
from maro.rl_v3.workflows.test.ops_creator import ops_creator

# (host, port) the RemoteOps dispatcher is expected to listen on.
DISPATCHER_ADDRESS = ("127.0.0.1", 10000)


class SingleTrainer:
    """Trainer that owns exactly one train-ops instance.

    Args:
        name: Trainer name; the single ops registered under ``"<name>.<id>"``
            in ``ops_creator`` is selected.
        ops_creator: Mapping of ops name -> factory callable taking the name.
        dispatcher_address: ``(host, port)`` of the dispatcher.  If given, a
            ``RemoteOps`` proxy is used; otherwise the ops is created locally.
    """

    def __init__(self, name, ops_creator, dispatcher_address=None) -> None:
        self._name = name
        # Exactly one registered ops name is expected to carry this prefix.
        ops_name = [n for n in ops_creator if n.startswith(f"{self._name}.")].pop()
        if dispatcher_address:
            # BUG FIX: previously passed the module-level DISPATCHER_ADDRESS,
            # silently ignoring the dispatcher_address argument.
            self._ops = RemoteOps(ops_name, dispatcher_address)
        else:
            # BUG FIX: previously looked up ops_creator["ops"] (a key that is
            # never registered) and train() then read self.remote_ops, which
            # was only assigned in the remote branch -- local mode could never
            # work.  Both branches now assign the same attribute.
            self._ops = ops_creator[ops_name](ops_name)

    async def train(self, X, y):
        """Run one training step on the single ops instance."""
        result = self._ops.step(X, y)
        # RemoteOps.step presumably returns an awaitable, while a locally
        # created ops steps synchronously -- TODO confirm RemoteOps semantics.
        if inspect.isawaitable(result):
            await result


class MultiTrainer:
    """Trainer that fans each training step out to several ops instances.

    Same arguments as ``SingleTrainer``, except that every ops whose
    registered name starts with ``"<name>."`` participates.
    """

    def __init__(self, name, ops_creator, dispatcher_address=None) -> None:
        self._name = name
        ops_names = [n for n in ops_creator if n.startswith(f"{self._name}.")]
        if dispatcher_address:
            # BUG FIX: previously passed the module-level DISPATCHER_ADDRESS
            # instead of the dispatcher_address argument.
            self._ops_list = [RemoteOps(n, dispatcher_address) for n in ops_names]
        else:
            self._ops_list = [ops_creator[n](n) for n in ops_names]

    async def train(self, X, y):
        """Run three rounds; each round steps all ops concurrently."""
        for _ in range(3):
            results = [ops.step(X, y) for ops in self._ops_list]
            # Local ops step synchronously and return plain values; only
            # gather the awaitables (remote proxies).
            pending = [r for r in results if inspect.isawaitable(r)]
            if pending:
                await asyncio.gather(*pending)
# ---- maro/rl_v3/workflows/test/main_test.py (tail) ----
# Small driver: trains both trainers concurrently and prints wall-clock time.

def _run_demo() -> None:
    """Build both trainers against the dispatcher and time one joint run."""
    single_trainer = SingleTrainer("single", ops_creator, dispatcher_address=DISPATCHER_ADDRESS)
    multi_trainer = MultiTrainer("multi", ops_creator, dispatcher_address=DISPATCHER_ADDRESS)

    # Random classification batches matching the (5 -> 2) and (7 -> 3) models.
    X, y = torch.rand(10, 5), torch.randint(0, 2, (10,))
    X2, y2 = torch.rand(15, 7), torch.randint(0, 3, (15,))

    async def train():
        return await asyncio.gather(
            single_trainer.train(X, y),
            multi_trainer.train(X2, y2),
        )

    start = time.time()
    asyncio.run(train())
    print(f"total_time: {time.time() - start}")


# BUG FIX: the driver previously ran at import time; guarding it keeps the
# module importable (by tools, tests, workers) without side effects.
if __name__ == "__main__":
    _run_demo()


# ---- maro/rl_v3/workflows/test/ops_creator.py ----
import time

from torch import nn
from torch.nn import CrossEntropyLoss
from torch.optim import Adam

# Artificial per-step delay (seconds); makes dispatch timing observable.
SLEEP = 5


class Model(nn.Module):
    """Minimal MLP: Linear(input_dim, 3) -> ReLU -> Linear(3, output_dim)."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # BUG FIX: the layers were previously registered twice -- once as
        # _fc1/_relu/_fc2 attributes and again inside _net -- duplicating
        # every parameter entry in the state dict.  Register them once.
        self._net = nn.Sequential(
            nn.Linear(input_dim, 3),
            nn.ReLU(),
            nn.Linear(3, output_dim),
        )

    def forward(self, inputs):
        return self._net(inputs)


class TrainOps:
    """One train-ops unit: a model plus its optimizer and loss.

    Args:
        name: Ops name (kept for identification by the dispatcher/worker).
        model: The ``nn.Module`` to train.
        sleep: Seconds to sleep after each step.  Defaults to ``SLEEP``;
            exposed as a parameter so tests can disable the delay.
    """

    def __init__(self, name, model, sleep=SLEEP):
        self.name = name
        # BUG FIX: the original applied NLLLoss to raw linear outputs, but
        # NLLLoss expects log-probabilities, so the quantity minimized was
        # not a valid classification loss.  CrossEntropyLoss = LogSoftmax +
        # NLLLoss and takes raw logits directly.
        self._loss_fn = CrossEntropyLoss()
        self._model = model
        self._optim = Adam(self._model.parameters(), lr=0.001)
        self._sleep = sleep

    def step(self, X, y):
        """Run one optimization step on batch (X, y); return the loss value.

        The return value is new (the original returned None) and may be
        ignored by callers.
        """
        loss = self._loss_fn(self._model(X), y)
        self._optim.zero_grad()
        loss.backward()
        self._optim.step()
        time.sleep(self._sleep)
        return loss.item()


class TrainOps2(TrainOps):
    """Behaviourally identical to ``TrainOps`` (the original was a verbatim
    copy); kept as a distinct class in case the dispatcher distinguishes ops
    by type -- TODO confirm whether a separate class is actually needed."""
# Registry of ops factories.  Keys follow "<trainer name>.<ops id>"; the
# trainers select entries by the "<trainer name>." prefix, so "single" gets
# exactly one ops and "multi" gets three.
ops_creator = {
    "single.ops": lambda name: TrainOps(name, Model(5, 2)),
    "multi.ops0": lambda name: TrainOps2(name, Model(7, 3)),
    "multi.ops1": lambda name: TrainOps2(name, Model(7, 3)),
    "multi.ops2": lambda name: TrainOps2(name, Model(7, 3)),
}


# ---- maro/rl_v3/workflows/test/test_worker.py ----
import os

from maro.rl_v3.distributed.train_ops_worker import TrainOpsWorker
from maro.rl_v3.workflows.test.ops_creator import ops_creator

# BUG FIX: the worker previously started at import time; guard it so the
# module can be imported without side effects.
if __name__ == "__main__":
    worker_id = os.getenv("ID")
    if worker_id is None:
        # BUG FIX: previously None was passed straight into TrainOpsWorker,
        # producing an obscure downstream failure; fail fast instead.
        raise RuntimeError("environment variable 'ID' must be set to the worker id")
    TrainOpsWorker(worker_id, ops_creator, "127.0.0.1").start()