-
Notifications
You must be signed in to change notification settings - Fork 0
/
datasets.py
41 lines (32 loc) · 1.26 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from torch.utils.data import Dataset
import numpy as np
import torch
from PIL import Image
import os
from absl import logging
class MedSyn(Dataset):
    """Image dataset whose integer class label is encoded in each file name.

    Every path returned by ``_list_image_files_recursively(img_path)`` is one
    sample. The label of a sample is the part of the file's base name that
    precedes the first underscore, parsed with ``int`` (for example
    ``3_scan.png`` yields label ``3``).
    """

    def __init__(self, img_path, transform):
        """Index the files under ``img_path`` and parse their labels.

        Args:
            img_path: Root directory that is searched recursively for files.
            transform: Stored on ``self.transform``. ``__getitem__`` does not
                apply it; samples are always returned as described there.

        Raises:
            ValueError: If the text before the first underscore of a file
                name cannot be parsed by ``int``.
        """
        super().__init__()
        self.total_files = _list_image_files_recursively(img_path)
        # Label prefix: everything before the first "_" in the base name.
        self.class_names = [
            os.path.basename(path).split("_")[0] for path in self.total_files
        ]
        logging.info('Prepare train dataset done')
        self.targets = [int(name) for name in self.class_names]
        self.transform = transform

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            A ``(image, target)`` tuple. ``image`` is a ``torch.FloatTensor``
            built from the RGB-converted picture divided by 255 and permuted
            from (height, width, channel) to (channel, height, width);
            ``target`` is ``self.targets[idx]``.
        """
        path = self.total_files[idx]
        # NOTE(review): the directory lister also accepts ".npy" paths, which
        # Image.open presumably cannot decode - confirm none are present.
        # The context manager releases the file handle deterministically
        # instead of leaving it to garbage collection; the pixel data is
        # copied into a NumPy array before the handle is closed.
        with Image.open(path) as img:
            pixels = np.array(img.convert('RGB'))
        tensor = torch.from_numpy(pixels / 255).type(torch.FloatTensor)
        return tensor.permute((2, 0, 1)), self.targets[idx]

    def __len__(self):
        """Return the number of indexed files."""
        return len(self.total_files)
def _list_image_files_recursively(data_dir):
    """Return the paths of all image-like files under ``data_dir``.

    Entries are visited in sorted name order and subdirectories are searched
    depth-first at the position where they sort, so the result order is
    deterministic. A file is kept when its name contains a dot and the text
    after the last dot is, case-insensitively, one of ``jpg``, ``jpeg``,
    ``png`` or ``npy``. All other files are ignored.

    Args:
        data_dir: Directory to search.

    Returns:
        A list of paths, each formed by joining ``data_dir`` with the
        relative location of a matching file.
    """
    image_extensions = ("jpg", "jpeg", "png", "npy")
    results = []
    for entry in sorted(os.listdir(data_dir)):
        full_path = os.path.join(data_dir, entry)
        # Test for a directory explicitly: calling os.listdir() on a regular
        # file that is not an image (e.g. "notes.txt") raises
        # NotADirectoryError instead of skipping the file.
        if os.path.isdir(full_path):
            results.extend(_list_image_files_recursively(full_path))
            continue
        # ``dot`` is empty when the name has no "." at all, in which case the
        # entry has no extension and is skipped.
        _, dot, ext = entry.rpartition(".")
        if dot and ext.lower() in image_extensions:
            results.append(full_path)
    return results