added temporary test files for review purposes
ysqyang committed Dec 27, 2021
1 parent f293980 commit a904142
Showing 3 changed files with 118 additions and 0 deletions.
52 changes: 52 additions & 0 deletions maro/rl_v3/workflows/test/main_test.py
@@ -0,0 +1,52 @@
import asyncio
import time

import torch

from maro.rl_v3.distributed.remote import RemoteOps
from maro.rl_v3.workflows.test.ops_creator import ops_creator

DISPATCHER_ADDRESS = ("127.0.0.1", 10000)


class SingleTrainer:
    def __init__(self, name, ops_creator, dispatcher_address=None) -> None:
        self._name = name
        ops_name = [name for name in ops_creator if name.startswith(f"{self._name}.")].pop()
        if dispatcher_address:
            # Proxy whose step() calls are forwarded to a remote worker via the dispatcher.
            self._ops = RemoteOps(ops_name, dispatcher_address)
        else:
            # Fall back to a local ops instance built from the registry.
            self._ops = ops_creator[ops_name](ops_name)

    async def train(self, X, y):
        await self._ops.step(X, y)


class MultiTrainer:
def __init__(self, name, ops_creator, dispatcher_address=None) -> None:
self._name = name
ops_names = [name for name in ops_creator if name.startswith(f"{self._name}.")]
if dispatcher_address:
            self._ops_list = [RemoteOps(ops_name, dispatcher_address) for ops_name in ops_names]
else:
self._ops_list = [ops_creator[ops_name](ops_name) for ops_name in ops_names]

async def train(self, X, y):
for _ in range(3):
await asyncio.gather(*[ops.step(X, y) for ops in self._ops_list])


single_trainer = SingleTrainer("single", ops_creator, dispatcher_address=DISPATCHER_ADDRESS)
multi_trainer = MultiTrainer("multi", ops_creator, dispatcher_address=DISPATCHER_ADDRESS)

X, y = torch.rand(10, 5), torch.randint(0, 2, (10,))
X2, y2 = torch.rand(15, 7), torch.randint(0, 3, (15,))

t0 = time.time()


async def train():
    return await asyncio.gather(
        single_trainer.train(X, y),
        multi_trainer.train(X2, y2),
    )


asyncio.run(train())
print(f"total_time: {time.time() - t0}")
60 changes: 60 additions & 0 deletions maro/rl_v3/workflows/test/ops_creator.py
@@ -0,0 +1,60 @@
import time

from torch import nn
from torch.nn import CrossEntropyLoss
from torch.optim import Adam


SLEEP = 5


class Model(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Small two-layer MLP that outputs raw class scores (logits).
        self._net = nn.Sequential(
            nn.Linear(input_dim, 3),
            nn.ReLU(),
            nn.Linear(3, output_dim),
        )

    def forward(self, inputs):
        return self._net(inputs)


class TrainOps:
    def __init__(self, name, model):
        self.name = name
        self._loss_fn = CrossEntropyLoss()  # the model outputs logits, so use cross-entropy rather than NLLLoss
        self._model = model
        self._optim = Adam(self._model.parameters(), lr=0.001)

    def step(self, X, y):
        # One gradient update, followed by an artificial delay that simulates a long-running step.
        y_pred = self._model(X)
        loss = self._loss_fn(y_pred, y)
        self._optim.zero_grad()
        loss.backward()
        self._optim.step()
        time.sleep(SLEEP)


class TrainOps2:
    # Identical to TrainOps; kept as a separate class so the registry below exercises more than one ops type.
    def __init__(self, name, model):
        self.name = name
        self._loss_fn = CrossEntropyLoss()
        self._model = model
        self._optim = Adam(self._model.parameters(), lr=0.001)

    def step(self, X, y):
        y_pred = self._model(X)
        loss = self._loss_fn(y_pred, y)
        self._optim.zero_grad()
        loss.backward()
        self._optim.step()
        time.sleep(SLEEP)


ops_creator = {
"single.ops": lambda name: TrainOps(name, Model(5, 2)),
"multi.ops0": lambda name: TrainOps2(name, Model(7, 3)),
"multi.ops1": lambda name: TrainOps2(name, Model(7, 3)),
"multi.ops2": lambda name: TrainOps2(name, Model(7, 3))
}
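For reference, a minimal local sketch of how this registry is consumed (mirroring the non-dispatcher branch in main_test.py above); each factory takes the ops name and returns an ops object whose step() performs one synchronous update:

import torch

from maro.rl_v3.workflows.test.ops_creator import ops_creator

ops = ops_creator["single.ops"]("single.ops")  # TrainOps wrapping Model(5, 2)
X, y = torch.rand(10, 5), torch.randint(0, 2, (10,))
ops.step(X, y)  # one gradient update, then blocks for SLEEP seconds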
6 changes: 6 additions & 0 deletions maro/rl_v3/workflows/test/test_worker.py
@@ -0,0 +1,6 @@
import os

from maro.rl_v3.distributed.train_ops_worker import TrainOpsWorker
from maro.rl_v3.workflows.test.ops_creator import ops_creator

TrainOpsWorker(os.getenv("ID"), ops_creator, "127.0.0.1").start()
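test_worker.py takes its worker ID from the ID environment variable, so multiple workers can be spawned from a single launcher. A hypothetical launcher sketch (the worker count of four matches the number of entries in ops_creator above; a dispatcher process is assumed to be running separately and is not started here):

import os
import subprocess
import sys

# Spawn one worker per entry in ops_creator, each with a distinct ID.
procs = [
    subprocess.Popen(
        [sys.executable, "-m", "maro.rl_v3.workflows.test.test_worker"],
        env={**os.environ, "ID": f"worker{i}"},
    )
    for i in range(4)
]
for p in procs:
    p.wait()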
