Skip to content

Commit

Permalink
Add mistnet example and fix some bugs when running mistnet
Browse files Browse the repository at this point in the history
Signed-off-by: XinYao1994 <xyao@cs.hku.hk>
  • Loading branch information
XinYao1994 committed Jul 7, 2021
1 parent 706a8f3 commit 84c45a8
Show file tree
Hide file tree
Showing 18 changed files with 439 additions and 31 deletions.
18 changes: 18 additions & 0 deletions examples/federated-learning-mistnet-client.Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Client-side (training worker) image for the MistNet federated-learning example.
FROM tensorflow/tensorflow:1.15.4

# libgl1-mesa-glx is required by opencv-python; use apt-get (stable CLI) with -y
# for non-interactive builds, and drop the apt lists to keep the layer small.
RUN apt-get update \
    && apt-get install -y --no-install-recommends libgl1-mesa-glx \
    && rm -rf /var/lib/apt/lists/*

# Copy only the requirements first so dependency installation is cached
# independently of library source changes.
COPY ./lib/requirements.txt /home

# --no-cache-dir keeps pip's download cache out of the image layer.
RUN pip install --no-cache-dir -r /home/requirements.txt

ENV PYTHONPATH="/home/lib:/home/plato"

COPY ./lib /home/lib
COPY ./plato /home/plato

WORKDIR /home/work
COPY examples/federated_learning/mistnet/ /home/work/

ENTRYPOINT ["python", "train_worker.py"]
18 changes: 18 additions & 0 deletions examples/federated-learning-mistnet.Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Server-side (aggregation worker) image for the MistNet federated-learning example.
FROM tensorflow/tensorflow:1.15.4

# libgl1-mesa-glx is required by opencv-python; use apt-get (stable CLI) with -y
# for non-interactive builds, and drop the apt lists to keep the layer small.
RUN apt-get update \
    && apt-get install -y --no-install-recommends libgl1-mesa-glx \
    && rm -rf /var/lib/apt/lists/*

# Copy only the requirements first so dependency installation is cached
# independently of library source changes.
COPY ./lib/requirements.txt /home

# --no-cache-dir keeps pip's download cache out of the image layer.
RUN pip install --no-cache-dir -r /home/requirements.txt

ENV PYTHONPATH="/home/lib:/home/plato"

COPY ./lib /home/lib
COPY ./plato /home/plato

WORKDIR /home/work
COPY examples/federated_learning/mistnet/ /home/work/

# Raise the open-file limit before starting the aggregation server: it holds
# many concurrent client connections. Shell form is required for `ulimit`.
CMD ["/bin/sh", "-c", "ulimit -n 50000; python agg_worker.py"]
8 changes: 8 additions & 0 deletions examples/federated_learning/mistnet/agg_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Entry point for the MistNet aggregation (server-side) worker.

Starts a MistnetServer, which listens for connecting training workers
and performs the server-side part of the MistNet protocol.
"""
from sedna.service.server import MistnetServer

if __name__ == '__main__':
    # NOTE(review): server.run() presumably blocks serving clients — confirm
    # against sedna.service.server.
    server = MistnetServer()
    server.run()

68 changes: 68 additions & 0 deletions examples/federated_learning/mistnet/client.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
clients:
  # Type
  type: mistnet

  # The total number of clients
  total_clients: 1

  # The number of clients selected in each round
  per_round: 1

  # Should the clients compute test accuracy locally?
  do_test: false

server:
  type: mistnet

  address: 127.0.0.1
  port: 7363

data:
  # The training and testing dataset
  datasource: MNIST

  # Where the dataset is located
  data_path: ./data

  # Number of samples in each partition
  partition_size: 20000

  # Fixed random seed
  random_seed: 1

  # IID, biased, or sharded?
  sampler: iid

trainer:
  # The type of the trainer
  type: basic

  # The maximum number of training rounds
  rounds: 1

  # Whether the training should use multiple GPUs if available
  parallelized: false

  # The maximum number of clients running concurrently
  max_concurrency: 3

  # The target accuracy
  target_accuracy: 0.95

  # Number of epochs for local training in each communication round
  epochs: 10
  batch_size: 32
  optimizer: SGD
  learning_rate: 0.01
  momentum: 0.9
  weight_decay: 0.0

  # The machine learning model
  model_name: lenet5

algorithm:
  # Aggregation algorithm
  type: mistnet

  cut_layer: relu3
  epsilon: null
68 changes: 68 additions & 0 deletions examples/federated_learning/mistnet/server.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
clients:
  # Type
  type: mistnet

  # The total number of clients
  total_clients: 0

  # The number of clients selected in each round
  per_round: 1

  # Should the clients compute test accuracy locally?
  do_test: false

server:
  type: mistnet

  address: 127.0.0.1
  port: 7363

data:
  # The training and testing dataset
  datasource: MNIST

  # Where the dataset is located
  data_path: ./data

  # Number of samples in each partition
  partition_size: 20000

  # Fixed random seed
  random_seed: 1

  # IID, biased, or sharded?
  sampler: iid

trainer:
  # The type of the trainer
  type: basic

  # The maximum number of training rounds
  rounds: 1

  # Whether the training should use multiple GPUs if available
  parallelized: false

  # The maximum number of clients running concurrently
  max_concurrency: 3

  # The target accuracy
  target_accuracy: 0.95

  # Number of epochs for local training in each communication round
  epochs: 10
  batch_size: 32
  optimizer: SGD
  learning_rate: 0.01
  momentum: 0.9
  weight_decay: 0.0

  # The machine learning model
  model_name: lenet5

algorithm:
  # Aggregation algorithm
  type: mistnet

  cut_layer: relu3
  epsilon: null
9 changes: 9 additions & 0 deletions examples/federated_learning/mistnet/train_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""Entry point for the MistNet training (client-side) worker.

Configures a MistWorker and drives its asynchronous client loop to
completion, connecting it to the aggregation server.
"""
import asyncio

from sedna.core.federated_learning import MistWorker

if __name__ == '__main__':
    client = MistWorker()
    client.configure()
    # start_client() is a coroutine; run it to completion on the event loop.
    loop = asyncio.get_event_loop()
    loop.run_until_complete(client.start_client())
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@ def run_server():
)
server.start()


import sedna.service.server
from sedna.service.server import PlatoServer
from torch import nn

if __name__ == '__main__':
# run_server()

model = nn.Sequential(
nn.Linear(28 * 28, 128),
nn.ReLU(),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
clients:
  # Type
  type: simple

  # The total number of clients, server
  total_clients: 0

  # The number of clients selected in each round
  per_round: 1

  # Should the clients compute test accuracy locally?
  do_test: true

server:
  address: 127.0.0.1
  port: 7363

data:
  # The training and testing dataset
  datasource: MNIST

  # Where the dataset is located
  data_path: ./data

  # Number of samples in each partition
  partition_size: 20000

  # IID or non-IID?
  sampler: iid

  # The random seed for sampling data
  random_seed: 1

trainer:
  # The type of the trainer
  type: basic

  # The maximum number of training rounds
  rounds: 5

  # Whether the training should use multiple GPUs if available
  parallelized: false

  # The maximum number of clients running concurrently
  max_concurrency: 1

  # The target accuracy
  target_accuracy: 0.94

  # Number of epochs for local training in each communication round
  epochs: 5
  batch_size: 32
  optimizer: SGD
  learning_rate: 0.01
  momentum: 0.9
  weight_decay: 0.0

  # The machine learning model
  model_name: lenet5

algorithm:
  # Aggregation algorithm
  type: fedavg
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ clients:
total_clients: 0

# The number of clients selected in each round
per_round: 1
per_round: 2

# Should the clients compute test accuracy locally?
do_test: true
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
clients:
  # Type
  type: simple

  # The total number of clients
  total_clients: 3

  # The number of clients selected in each round
  per_round: 1

  # Should the clients compute test accuracy locally?
  do_test: true

server:
  address: 127.0.0.1
  port: 7363

data:
  # The training and testing dataset
  datasource: MNIST

  # Where the dataset is located
  data_path: ./data

  # Number of samples in each partition
  partition_size: 20000

  # IID or non-IID?
  sampler: iid

  # The random seed for sampling data
  random_seed: 1

trainer:
  # The type of the trainer
  type: basic

  # The maximum number of training rounds
  rounds: 5

  # Whether the training should use multiple GPUs if available
  parallelized: false

  # The maximum number of clients running concurrently
  max_concurrency: 1

  # The target accuracy
  target_accuracy: 0.94

  # Number of epochs for local training in each communication round
  epochs: 5
  batch_size: 32
  optimizer: SGD
  learning_rate: 0.01
  momentum: 0.9
  weight_decay: 0.0

  # The machine learning model
  model_name: lenet5

algorithm:
  # Aggregation algorithm
  type: fedavg
Loading

0 comments on commit 84c45a8

Please sign in to comment.