Skip to content

Commit

Permalink
mostly use torchvision CelebA
Browse files Browse the repository at this point in the history
  • Loading branch information
djsutherland committed Jun 10, 2019
1 parent 5259eea commit ba6e558
Showing 1 changed file with 14 additions and 58 deletions.
72 changes: 14 additions & 58 deletions igms/datasets.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,21 @@
import os

import numpy as np
import pandas as pd
import torch
from torchvision.datasets.folder import default_loader
from torchvision.datasets import CelebA as tv_CelebA
from torchvision import transforms


class CelebA(tv_CelebA):
    """torchvision's ``CelebA`` dataset with an extra ``attr_names`` attribute.

    All positional and keyword arguments are forwarded unchanged to the
    torchvision ``CelebA`` constructor; item access, length, and transforms
    are inherited from it.
    """

    # https://github.com/pytorch/vision/pull/1008 does this
    def __init__(self, *args, **kwargs):
        """Initialise the parent dataset, then load the attribute names.

        Sets ``self.attr_names`` to the whitespace-separated tokens on the
        second line of ``list_attr_celeba.txt`` under
        ``<root>/<base_folder>``.
        """
        super().__init__(*args, **kwargs)

        attr_path = os.path.join(
            self.root, self.base_folder, "list_attr_celeba.txt"
        )
        with open(attr_path) as f:
            # The first line is read and discarded; the names are on the
            # second line.
            _ = f.readline()
            self.attr_names = f.readline().split()

@staticmethod
def default_transform(out_size=64, max_crop=160, min_crop=140):
Expand All @@ -78,15 +36,13 @@ def default_transform(out_size=64, max_crop=160, min_crop=140):
def get_dataset(spec, out_size, **kwargs):
    """Build a dataset from a colon-separated specification string.

    The spec has the form ``kind[:root[:split]]``. ``kind`` is compared
    case-insensitively and only ``"celeba"`` is recognised. ``root`` has a
    leading ``~`` expanded and falls back to ``"data"`` when it is missing
    or empty.

    Args:
        spec: Dataset specification string, e.g. ``"celeba:~/datasets"``.
        out_size: Passed to ``CelebA.default_transform`` when the caller
            did not supply a ``transform`` keyword argument.
        **kwargs: Extra keyword arguments for the dataset constructor.
            ``root`` is always overwritten from the spec; ``split`` is
            overwritten only when the spec provides one.

    Returns:
        The constructed ``CelebA`` dataset.

    Raises:
        ValueError: If ``kind`` is not a known dataset.
        AssertionError: If a CelebA spec has fields left over after the
            split.
    """
    parts = spec.split(":")
    kind = parts.pop(0).lower()

    # The root is consumed from the spec before the kind is checked, so it
    # replaces any ``root`` the caller passed in kwargs.
    root_field = parts.pop(0) if parts else ""
    kwargs["root"] = os.path.expanduser(root_field) or "data"

    if kind == "celeba":
        if parts:
            kwargs["split"] = parts.pop(0)
        assert not parts, f"Unexpected extra fields in dataset spec: {parts}"
        # Only install the default transform when the caller has not
        # chosen one.
        if "transform" not in kwargs:
            kwargs["transform"] = CelebA.default_transform(out_size=out_size)
        return CelebA(**kwargs)
    else:
        raise ValueError(f"Unknown dataset {kind}")

0 comments on commit ba6e558

Please sign in to comment.