Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CelebA Dataset to Plato #164

Merged
merged 8 commits into from
May 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion configs/CINIC10/fedavg_vgg16.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ data:
data_path: ./data/CINIC-10

#
download_url: https://iqua.ece.toronto.edu/~bli/CINIC-10.tar.gz
download_url: http://iqua.ece.toronto.edu/baochun/CINIC-10.tar.gz
# Number of samples in each partition
partition_size: 20000

Expand Down
75 changes: 75 additions & 0 deletions configs/CelebA/fedavg_resnet18.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
clients:
# Type
type: simple

# The total number of clients
total_clients: 3

# The number of clients selected in each round
per_round: 1

# Should the clients compute test accuracy locally?
do_test: false

server:
address: 127.0.0.1
port: 8000

data:
# The training and testing dataset
datasource: CelebA

# Only add face identity as labels for training
celeba_targets:
# For ResNet, do not set <attr> to True since it does not match the expected output of ResNet
attr: false
identity: true

# The number of identities (classes) in CelebA
num_classes: 10178

# Where the dataset is located
data_path: ./data

# Number of samples in each partition
partition_size: 20000

# IID or non-IID?
sampler: noniid

# The concentration parameter for the Dirichlet distribution
concentration: 0.5

# The random seed for sampling data
random_seed: 1

trainer:
# The type of the trainer
type: basic

# The maximum number of training rounds
rounds: 5

# Whether the training should use multiple GPUs if available
parallelized: true

# The maximum number of clients running concurrently
max_concurrency: 3

# The target accuracy
target_accuracy: 0.94

# The number of epochs for local training in each communication round
epochs: 5
batch_size: 32
optimizer: SGD
learning_rate: 0.01
momentum: 0.9
weight_decay: 0.0

# The machine learning model
model_name: resnet_18

algorithm:
# Aggregation algorithm
type: fedavg
109 changes: 109 additions & 0 deletions plato/datasources/celeba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""
The CelebA dataset from the torchvision package.
"""
import logging
import os
from typing import Callable, List, Optional, Union

import torch
from torchvision import datasets, transforms

from plato.config import Config
from plato.datasources import base


class CelebA(datasets.CelebA):
    """
    Thin wrapper around torchvision's CelebA dataset that exposes
    <targets> and <classes> attributes based on celebrity identity,
    which the non-IID samplers require.
    """

    def __init__(self,
                 root: str,
                 split: str = "train",
                 target_type: Union[List[str], str] = "attr",
                 transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None,
                 download: bool = False) -> None:
        super().__init__(root,
                         split=split,
                         target_type=target_type,
                         transform=transform,
                         target_transform=target_transform,
                         download=download)
        # Per-sample labels: the identity tensor flattened into a plain list.
        self.targets = self.identity.flatten().tolist()
        # 10178 names so that an identity value indexes its class directly
        # (identities start at 1, so index 0 is unused).
        self.classes = [f'Celebrity #{i}' for i in range(10178)]


class DataSource(base.DataSource):
    """The CelebA datasource: downloads the dataset if needed and builds
    transformed train/test sets for federated training."""

    def __init__(self):
        super().__init__()
        _path = Config().data.data_path

        # Download and decompress only when the extracted folder is absent.
        if not os.path.exists(os.path.join(_path, 'celeba')):
            celeba_url = 'http://iqua.ece.toronto.edu/baochun/celeba.tar.gz'
            DataSource.download(celeba_url, _path)
        else:
            logging.info("CelebA data already decompressed under %s",
                         os.path.join(_path, 'celeba'))

        # Decide which target types to load from the configuration;
        # default to both attributes and identities when unspecified.
        target_types = []
        if hasattr(Config().data, "celeba_targets"):
            targets = Config().data.celeba_targets
            if hasattr(targets, "attr") and targets.attr:
                target_types.append("attr")
            if hasattr(targets, "identity") and targets.identity:
                target_types.append("identity")
        else:
            target_types = ['attr', 'identity']

        # Resize/crop to 32x32 and normalize each channel to [-1, 1].
        image_size = 32
        _transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        self.trainset = CelebA(root=_path,
                               split='train',
                               target_type=target_types,
                               download=False,
                               transform=_transform,
                               target_transform=DataSource._target_transform)
        self.testset = CelebA(root=_path,
                              split='test',
                              target_type=target_types,
                              download=False,
                              transform=_transform,
                              target_transform=DataSource._target_transform)

    @staticmethod
    def _target_transform(label):
        """
        Merge a (possibly multi-part) label into a single tensor.

        When more than one target type is configured, the labels arrive
        as a tuple of tensors; flatten and concatenate them into one
        tensor (e.g. identity becomes the 41st attribute). A single-
        element tuple is unwrapped unchanged, and a non-tuple label is
        returned as-is.

        Generalized from the original two-element-only handling: the
        original silently returned None for tuples of any other length,
        which would crash downstream; now any non-empty tuple is merged.
        """
        if isinstance(label, tuple):
            if len(label) == 1:
                return label[0]
            return torch.cat([part.reshape([-1]) for part in label])
        return label

    @staticmethod
    def input_shape():
        # [num_train_examples, channels, height, width] after the 32x32
        # resize/crop applied in __init__.
        return [162770, 3, 32, 32]

    def num_train_examples(self):
        # Size of the official CelebA train split.
        return 162770

    def num_test_examples(self):
        # Size of the official CelebA test split.
        return 19962
2 changes: 1 addition & 1 deletion plato/datasources/cinic10.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(self):
"Downloading the CINIC-10 dataset. This may take a while.")
url = Config().data.download_url if hasattr(
Config().data, 'download_url'
) else 'https://iqua.ece.toronto.edu/~bli/CINIC-10.tar.gz'
) else 'http://iqua.ece.toronto.edu/baochun/CINIC-10.tar.gz'
DataSource.download(url, _path)

_transform = transforms.Compose([
Expand Down
4 changes: 3 additions & 1 deletion plato/datasources/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
else:
from plato.datasources import (mnist, fashion_mnist, emnist, cifar10,
cinic10, huggingface, pascal_voc,
tiny_imagenet, femnist, feature, qoenflx)
tiny_imagenet, femnist, feature, qoenflx,
celeba)

registered_datasources = OrderedDict([
('MNIST', mnist),
Expand All @@ -42,6 +43,7 @@
('TinyImageNet', tiny_imagenet),
('Feature', feature),
('QoENFLX', qoenflx),
('CelebA', celeba),
])

registered_partitioned_datasources = OrderedDict([('FEMNIST', femnist)])
Expand Down
16 changes: 11 additions & 5 deletions plato/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import torch.nn as nn
import torch.nn.functional as F

from plato.config import Config


class BasicBlock(nn.Module):
expansion = 1
Expand Down Expand Up @@ -176,13 +178,17 @@ def get_model(model_type):

resnet_type = int(model_type.split('_')[1])

num_classes = 10
if hasattr(Config().data, 'num_classes'):
num_classes = Config().data.num_classes

if resnet_type == 18:
return Model(BasicBlock, [2, 2, 2, 2])
return Model(BasicBlock, [2, 2, 2, 2], num_classes)
elif resnet_type == 34:
return Model(BasicBlock, [3, 4, 6, 3])
return Model(BasicBlock, [3, 4, 6, 3], num_classes)
elif resnet_type == 50:
return Model(Bottleneck, [3, 4, 6, 3])
return Model(Bottleneck, [3, 4, 6, 3], num_classes)
elif resnet_type == 101:
return Model(Bottleneck, [3, 4, 23, 3])
return Model(Bottleneck, [3, 4, 23, 3], num_classes)
elif resnet_type == 152:
return Model(Bottleneck, [3, 8, 36, 3])
return Model(Bottleneck, [3, 8, 36, 3], num_classes)