-
Notifications
You must be signed in to change notification settings - Fork 4
/
dataset.py
65 lines (49 loc) · 1.71 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch as t
from water_dataset import WaterDataset
from skimage import transform as sktsf
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
from torch.utils import data as data_
from torch.utils.data import Dataset
import torch.utils.data.distributed
from torchvision import transforms as tvtsf
import torchvision.datasets as datasets
import torchvision.models as models
# from data import util
import numpy as np
from config import opt
import os
import logging
class TrainDataset(Dataset):
    """Map-style torch ``Dataset`` over one split of a ``WaterDataset``.

    Each item is a ``(label, data)`` pair of tensors. ``data`` is flattened
    to shape ``(1, N)``, where N is the number of elements of the example.

    Args:
        config: configuration object; only ``config.data_dir`` is read here.
        split: split name forwarded to ``WaterDataset`` (default ``'train'``).
    """

    def __init__(self, config, split='train'):
        self.config = config  # kept on the instance; not read elsewhere in this class
        data_root = config.data_dir
        self.db = WaterDataset(data_root, split=split)

    def __getitem__(self, idx):
        """Return ``(label, data)`` tensors for example ``idx``.

        ``data`` is reshaped to ``(1, -1)``. Dtypes follow whatever NumPy
        infers from the raw values returned by ``get_example``.
        """
        raw_label, raw_data = self.db.get_example(idx)
        label_tensor = t.from_numpy(np.array(raw_label))
        # np.array copies the raw data into a fresh array, which sidesteps
        # torch.from_numpy's refusal of negative-stride arrays.
        # TODO: locate the source of the negative stride so the copy can be dropped.
        data_tensor = t.from_numpy(np.array(raw_data)).contiguous()
        return label_tensor, data_tensor.view(1, -1)

    def __len__(self):
        """Number of examples in the wrapped split."""
        return len(self.db)
class TestDataset(Dataset):
    """Map-style torch ``Dataset`` over the evaluation split of a ``WaterDataset``.

    Items are ``(label, data)`` tensor pairs, with ``data`` flattened to
    shape ``(1, N)``.

    Args:
        config: configuration object; only ``config.data_dir`` is read here.
        split: split name forwarded to ``WaterDataset`` (default ``'test'``).
    """

    def __init__(self, config, split='test'):
        self.config = config  # stored for callers; unused inside this class
        data_root = config.data_dir
        self.db = WaterDataset(data_root, split=split)

    def __getitem__(self, idx):
        """Fetch example ``idx`` as a ``(label, data)`` pair of tensors.

        ``data`` is viewed as ``(1, -1)``. Dtypes are whatever NumPy infers
        from the values ``get_example`` returns.
        """
        raw_label, raw_data = self.db.get_example(idx)
        label_tensor = t.from_numpy(np.array(raw_label))
        # Copying through np.array guarantees positive strides, which
        # torch.from_numpy requires.
        # TODO: find which source array has a negative stride and fix it there
        # instead of copying every example.
        data_tensor = t.from_numpy(np.array(raw_data)).contiguous()
        return label_tensor, data_tensor.view(1, -1)

    def __len__(self):
        """Number of examples in the wrapped split."""
        return len(self.db)