-
Notifications
You must be signed in to change notification settings - Fork 6
/
dataset.py
129 lines (105 loc) · 4.29 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torch
import torchvision
from torchvision.transforms import Compose
import numpy as np
import cv2 as cv
import os
from random import sample
from torchvision.transforms import functional as F
def img_to_tensor(img):
    """Return ``img`` (an H x W x C NumPy array) as a C x H x W torch tensor.

    No data is copied: the tensor wraps the transposed view, so it shares
    memory with ``img``.
    """
    return torch.from_numpy(img.transpose(2, 0, 1))
def to_monochrome(x):
    """Return ``x`` converted to a float32 NumPy array.

    Despite the name, no colour-to-grey conversion is performed: the array
    keeps the input's shape, including any channel axis.
    """
    as_array = np.array(x)
    return as_array.astype(np.float32)
def to_tensor(x):
    """Prepend a length-1 axis to ``x`` and wrap the result as a torch tensor.

    For example, an (H, W) array becomes a (1, H, W) tensor; the dtype is
    preserved.
    """
    return torch.from_numpy(np.expand_dims(x, axis=0))
# Short alias for torchvision's ToTensor transform class (the class itself,
# not an instance); it is not used by the code visible in this file.
ImageToTensor = torchvision.transforms.ToTensor
def custom_blur_demo(image):
    """Sharpen ``image`` with a 3x3 kernel applied through ``cv.filter2D``.

    Despite the name this is a sharpening filter, not a blur: the kernel is
    the identity minus the 4-neighbour Laplacian, i.e. each pixel gets its
    summed difference from its up/down/left/right neighbours added back on.
    ``ddepth=-1`` keeps the output depth equal to the input's.
    """
    sharpen = np.array(
        [
            [0, -1, 0],
            [-1, 5, -1],
            [0, -1, 0],
        ],
        dtype=np.float32,
    )
    return cv.filter2D(image, -1, kernel=sharpen)
class SasDataset(Dataset):
    """Dataset over every entry of ``root``, each read as a colour image.

    Only ``mode='test'`` currently produces samples: each item is a dict with
    ``'img'`` (a float32 C x H x W tensor scaled to [0, 1], channels in the
    order returned by ``cv.imread``) and ``'file_name'`` (the file name
    without its extension).

    Args:
        root: Directory whose entries, sorted by name, form the dataset.
        mode: Stored as ``self.mode``; only ``'test'`` yields samples.
        is_ndvi: Stored as ``self.is_ndvi``; not used by this class.
    """

    def __init__(self, root, mode='train', is_ndvi=False):
        self.root = root
        self.mode = mode
        # Presumably per-channel B, G, R means (by the name); not applied here.
        self.mean_bgr = [104.00699, 116.66877, 122.67892]
        self.is_ndvi = is_ndvi
        # Every directory entry is treated as an image; sorted for a stable order.
        self.imgList = sorted(os.listdir(self.root))
        # NOTE(review): these pipelines are built but never applied by
        # __getitem__; kept in case other code uses them directly.
        self.imgTransforms = Compose([img_to_tensor])
        self.maskTransforms = Compose([
            torchvision.transforms.Lambda(to_monochrome),
            torchvision.transforms.Lambda(to_tensor),
        ])

    def __getitem__(self, idx):
        """Load entry ``idx``; see the class docstring for the sample format.

        Raises:
            ValueError: if OpenCV cannot read the file as an image.
            NotImplementedError: if ``self.mode`` is not ``'test'``.
        """
        imgPath = os.path.join(self.root, self.imgList[idx])
        img = cv.imread(imgPath, cv.IMREAD_COLOR)
        # cv.imread signals failure by returning None instead of raising.
        if img is None:
            raise ValueError(f"could not read image file: {imgPath}")
        img = np.array(img, dtype=np.float32)
        img /= 255.
        # HWC -> CHW; copy() gives the tensor its own C-contiguous buffer.
        img = torch.from_numpy(img.transpose((2, 0, 1)).copy()).float()
        # splitext keeps interior dots, e.g. "tile.01.png" -> "tile.01".
        imgName = os.path.splitext(os.path.basename(imgPath))[0]
        if self.mode == 'test':
            return {'img': img, 'file_name': imgName}
        # No target loading exists for other modes; previously this fell
        # through and returned None, which only failed later in collate.
        raise NotImplementedError(
            f"SasDataset only produces samples for mode='test', got mode={self.mode!r}"
        )

    def __len__(self):
        return len(self.imgList)
def build_loader(cfg, *, valid_ratio=0.15, seed=None):
    """Build train/validation loaders from a random split of ``cfg.trainData``.

    Both loaders read the same directory. A shuffled index list is split so
    that ``floor(valid_ratio * N)`` entries go to validation and the rest to
    training, where N is the number of entries in ``cfg.trainData``.

    Args:
        cfg: Config object providing ``trainData``, ``batch_size`` and
            ``num_workers``.
        valid_ratio: Fraction of entries held out for validation, in [0, 1).
        seed: Optional seed for a reproducible split. ``None`` uses the global
            ``random`` state, as before.

    Returns:
        ``(train_loader, valid_loader)``.

    Raises:
        ValueError: if ``valid_ratio`` is outside [0, 1).
    """
    from random import Random

    if not 0 <= valid_ratio < 1:
        raise ValueError(f"valid_ratio must be in [0, 1), got {valid_ratio!r}")

    # Every directory entry counts as a sample, matching SasDataset.imgList;
    # no sort is needed just to count.
    num_train = len(os.listdir(cfg.trainData))
    indices = list(range(num_train))
    # sample(xs, len(xs)) is a full shuffle. Without a seed, keep using the
    # module-level ``sample`` (global RNG) so external random.seed() applies.
    shuffle = sample if seed is None else Random(seed).sample
    indices = shuffle(indices, len(indices))
    split = int(np.floor(valid_ratio * num_train))
    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_dataset = SasDataset(cfg.trainData, mode='train')
    val_dataset = SasDataset(cfg.trainData, mode='valid')
    train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, sampler=train_sampler,
                              num_workers=cfg.num_workers, pin_memory=True, drop_last=True)
    # drop_last=False: dropping the final partial batch would silently skip
    # validation samples, and would leave the loader empty whenever the
    # validation split is smaller than one batch.
    valid_loader = DataLoader(val_dataset, batch_size=cfg.batch_size, sampler=valid_sampler,
                              num_workers=cfg.num_workers, pin_memory=True, drop_last=False)
    return train_loader, valid_loader
if __name__ == '__main__':
    # Visual smoke test: show each training image next to its target.
    import matplotlib.pyplot as plt
    from config_eval import Config

    cfg = Config()
    train_loader, valid_loader = build_loader(cfg)
    # NOTE(review): SasDataset.__getitem__ only builds samples for mode='test'
    # (a dict, not an (image, target) pair), so this loop cannot run against
    # the 'train' loader until training samples are implemented there.
    for x, y in train_loader:
        # Assumes x is a (B, 3, H, W) float batch in [0, 1] with OpenCV's BGR
        # channel order (like the 'test' samples) and y a (B, 1, H, W) mask
        # -- TODO confirm. imshow needs a single H x W (x C) image and clips
        # float RGB data outside [0, 1], so show the first sample as uint8 RGB.
        img = x[0].numpy().transpose(1, 2, 0)[:, :, ::-1]
        img = np.clip(img * 255, 0, 255).round().astype(np.uint8)
        mask = np.squeeze(y[0].numpy())
        plt.subplot(121)
        plt.imshow(img)
        plt.subplot(122)
        plt.imshow(mask)
        plt.show()
    # NOTE: the FMix example below relies on make_low_freq_image,
    # binarise_mask, save_image and make_grid, none of which are imported in
    # this file.
    # DECAY_POWER = 3
    # SHAPE = 512
    # LAMBDA = 0.5
    # NUM_IMAGES = 1
    # dataIter = iter(valid_loader)
    # batch, target = next(dataIter)
    # batch1 = batch[:NUM_IMAGES]
    # batch2 = batch[NUM_IMAGES:]
    #
    # soft_masks_np = [make_low_freq_image(DECAY_POWER, [SHAPE, SHAPE]) for _ in range(NUM_IMAGES)]
    # soft_masks = torch.from_numpy(np.stack(soft_masks_np, axis=0)).float().repeat(1, 3, 1, 1)
    #
    # masks_np = [binarise_mask(mask, LAMBDA, [SHAPE, SHAPE]) for mask in soft_masks_np]
    # masks = torch.from_numpy(np.stack(masks_np, axis=0)).float().repeat(1, 3, 1, 1)
    #
    # mix = batch1 * masks + batch2 * (1 - masks)
    # image = torch.cat((soft_masks, masks, batch1, batch2, mix), 0)
    # save_image(image, 'fmix_example.png', nrow=NUM_IMAGES, pad_value=1)
    #
    # plt.figure(figsize=(NUM_IMAGES, 5))
    # plt.imshow(make_grid(image, nrow=NUM_IMAGES, pad_value=5).permute(1, 2, 0).numpy())
    # plt.show()