-
Notifications
You must be signed in to change notification settings - Fork 152
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added temporary test files for review purposes
- Loading branch information
ysqyang
committed
Dec 27, 2021
1 parent
f293980
commit a904142
Showing
3 changed files
with
118 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import asyncio | ||
import time | ||
|
||
import torch | ||
|
||
from maro.rl_v3.distributed.remote import RemoteOps | ||
from maro.rl_v3.workflows.test.ops_creator import ops_creator | ||
|
||
# (host, port) that remote trainer ops connect to; must match the dispatcher
# side of the RemoteOps transport — TODO confirm against the dispatcher config.
DISPATCHER_ADDRESS = ("127.0.0.1", 10000)
|
||
|
||
class SingleTrainer:
    """Trainer that owns exactly one training ops object, local or remote.

    Args:
        name: Trainer name; its ops is the sole registry entry whose key
            starts with ``"<name>."``.
        ops_creator: Mapping of ops name -> factory callable taking the
            ops name and returning an ops object with a ``step(X, y)`` method.
        dispatcher_address: Optional (host, port); when given, the ops is a
            RemoteOps proxy, otherwise it is created locally.
    """

    def __init__(self, name, ops_creator, dispatcher_address=None) -> None:
        self._name = name
        # Resolve the single ops key registered under this trainer's prefix.
        ops_name = [key for key in ops_creator if key.startswith(f"{self._name}.")].pop()
        if dispatcher_address:
            # Bug fix: honor the caller-supplied address (the original used the
            # module-level constant) and store under one attribute name so
            # train() works in both branches (the original stored the remote
            # ops as `remote_ops` but local as `_ops`).
            self._ops = RemoteOps(ops_name, dispatcher_address)
        else:
            # Bug fix: the original looked up ops_creator["ops"], which is not
            # a registered key; use the resolved ops_name.
            self._ops = ops_creator[ops_name](ops_name)

    async def train(self, X, y):
        """Run one training step; awaits the result only if it is awaitable
        (RemoteOps.step presumably returns an awaitable, local step is sync —
        TODO confirm RemoteOps contract)."""
        import inspect

        result = self._ops.step(X, y)
        if inspect.isawaitable(result):
            await result
|
||
|
||
class MultiTrainer:
    """Trainer that drives several training ops objects in parallel.

    Args:
        name: Trainer name; its ops are all registry entries whose keys
            start with ``"<name>."``.
        ops_creator: Mapping of ops name -> factory callable.
        dispatcher_address: Optional (host, port); when given, every ops is a
            RemoteOps proxy, otherwise all are created locally.
    """

    def __init__(self, name, ops_creator, dispatcher_address=None) -> None:
        self._name = name
        ops_names = [key for key in ops_creator if key.startswith(f"{self._name}.")]
        if dispatcher_address:
            # Bug fix: honor the caller-supplied address; the original
            # ignored the parameter and used the module-level constant.
            self._ops_list = [RemoteOps(ops_name, dispatcher_address) for ops_name in ops_names]
        else:
            self._ops_list = [ops_creator[ops_name](ops_name) for ops_name in ops_names]

    async def train(self, X, y):
        """Run 3 rounds of one step per ops; remote steps run concurrently.

        Local (synchronous) step results are not awaitable and are skipped by
        the gather, so this works in both local and remote mode (the original
        awaited sync results and would fail locally).
        """
        import inspect

        for _ in range(3):
            results = [ops.step(X, y) for ops in self._ops_list]
            await asyncio.gather(*[r for r in results if inspect.isawaitable(r)])
|
||
|
||
# Build one single-ops trainer and one multi-ops trainer, both wired to the
# local dispatcher address.
single_trainer = SingleTrainer("single", ops_creator, dispatcher_address=DISPATCHER_ADDRESS)
multi_trainer = MultiTrainer("multi", ops_creator, dispatcher_address=DISPATCHER_ADDRESS)

# Random classification batches sized to each trainer's model dimensions.
X, y = torch.rand(10, 5), torch.randint(0, 2, (10,))
X2, y2 = torch.rand(15, 7), torch.randint(0, 3, (15,))

t0 = time.time()


async def train():
    """Launch both trainers concurrently and wait until both finish."""
    jobs = (
        single_trainer.train(X, y),
        multi_trainer.train(X2, y2),
    )
    return await asyncio.gather(*jobs)


asyncio.run(train())
print(f"total_time: {time.time() - t0}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import time | ||
|
||
from torch import nn | ||
from torch.optim import Adam | ||
from torch.nn import NLLLoss | ||
|
||
|
||
# Seconds each ops.step() sleeps after its update, to emulate an expensive
# training iteration for the throughput test.
SLEEP = 5
|
||
|
||
class Model(nn.Module):
    """Tiny feed-forward net: Linear(input_dim, 3) -> ReLU -> Linear(3, output_dim)."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        hidden = 3  # fixed hidden width used by both linear layers
        self._fc1 = nn.Linear(input_dim, hidden)
        self._relu = nn.ReLU()
        self._fc2 = nn.Linear(hidden, output_dim)
        # Kept for compatibility with the original attribute surface.
        self._net = nn.Sequential(self._fc1, self._relu, self._fc2)

    def forward(self, inputs):
        # Chain the layers explicitly; equivalent to self._net(inputs).
        return self._fc2(self._relu(self._fc1(inputs)))
|
||
|
||
class TrainOps:
    """Training ops for the "single" trainer: one optimizer update per step().

    step() sleeps for SLEEP seconds afterwards to emulate an expensive
    training iteration.
    """

    def __init__(self, name, model):
        self.name = name
        # Bug fix: the model emits raw logits (no log-softmax anywhere in
        # Model.forward), so NLLLoss — which expects log-probabilities —
        # computed a meaningless loss. CrossEntropyLoss applies log-softmax
        # internally and matches the intended classification objective.
        self._loss_fn = nn.CrossEntropyLoss()
        self._model = model
        self._optim = Adam(self._model.parameters(), lr=0.001)

    def step(self, X, y):
        """Run one forward/backward/update pass on the batch (X, y)."""
        y_pred = self._model(X)  # raw logits
        loss = self._loss_fn(y_pred, y)
        self._optim.zero_grad()
        loss.backward()
        self._optim.step()
        time.sleep(SLEEP)  # simulate heavy per-step work
|
||
|
||
class TrainOps2:
    """Training ops used by the "multi" trainer; one optimizer update per step().

    NOTE(review): NLLLoss expects log-probabilities but Model emits raw
    logits — presumably acceptable for this throughput test; confirm before
    reusing for real training.
    """

    def __init__(self, name, model):
        self.name = name
        self._model = model
        self._loss_fn = NLLLoss()
        self._optim = Adam(model.parameters(), lr=0.001)

    def step(self, X, y):
        """Run one forward/backward/update pass, then sleep to simulate load."""
        self._optim.zero_grad()
        predictions = self._model(X)
        self._loss_fn(predictions, y).backward()
        self._optim.step()
        time.sleep(SLEEP)
|
||
|
||
# Registry mapping ops names to factory callables: the "single" trainer owns
# one TrainOps; the "multi" trainer owns three identical TrainOps2 instances.
ops_creator = {
    "single.ops": lambda name: TrainOps(name, Model(5, 2)),
    **{
        f"multi.ops{i}": (lambda name: TrainOps2(name, Model(7, 3)))
        for i in range(3)
    },
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
import os | ||
|
||
from maro.rl_v3.distributed.train_ops_worker import TrainOpsWorker | ||
from maro.rl_v3.workflows.test.ops_creator import ops_creator | ||
|
||
# Bug fix: fail fast with a clear KeyError if the worker ID is not set;
# os.getenv("ID") would silently pass None through to TrainOpsWorker.
worker_id = os.environ["ID"]
# Start a worker that serves the registered ops to the dispatcher at localhost.
TrainOpsWorker(worker_id, ops_creator, "127.0.0.1").start()