-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
85 lines (73 loc) · 5.99 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import cv2
from tqdm import tqdm
from PIL import Image
from torch.utils import data
from torchvision import transforms
from preproc import preproc
from config import Config
from utils.utils import path_to_image
# Disable Pillow's decompression-bomb pixel limit so very large images can be opened
# without a DecompressionBombWarning/Error.
Image.MAX_IMAGE_PIXELS = None
# Project configuration object; MyData reads its dataset/training settings from here.
config = Config()
# Class-label names (currently unused; kept for reference).
# _class_labels_TR_sorted = 'Airplane, Ant, Antenna, Archery, Axe, BabyCarriage, Bag, BalanceBeam, Balcony, Balloon, Basket, BasketballHoop, Beatle, Bed, Bee, Bench, Bicycle, BicycleFrame, BicycleStand, Boat, Bonsai, BoomLift, Bridge, BunkBed, Butterfly, Button, Cable, CableLift, Cage, Camcorder, Cannon, Canoe, Car, CarParkDropArm, Carriage, Cart, Caterpillar, CeilingLamp, Centipede, Chair, Clip, Clock, Clothes, CoatHanger, Comb, ConcretePumpTruck, Crack, Crane, Cup, DentalChair, Desk, DeskChair, Diagram, DishRack, DoorHandle, Dragonfish, Dragonfly, Drum, Earphone, Easel, ElectricIron, Excavator, Eyeglasses, Fan, Fence, Fencing, FerrisWheel, FireExtinguisher, Fishing, Flag, FloorLamp, Forklift, GasStation, Gate, Gear, Goal, Golf, GymEquipment, Hammock, Handcart, Handcraft, Handrail, HangGlider, Harp, Harvester, Headset, Helicopter, Helmet, Hook, HorizontalBar, Hydrovalve, IroningTable, Jewelry, Key, KidsPlayground, Kitchenware, Kite, Knife, Ladder, LaundryRack, Lightning, Lobster, Locust, Machine, MachineGun, MagazineRack, Mantis, Medal, MemorialArchway, Microphone, Missile, MobileHolder, Monitor, Mosquito, Motorcycle, MovingTrolley, Mower, MusicPlayer, MusicStand, ObservationTower, Octopus, OilWell, OlympicLogo, OperatingTable, OutdoorFitnessEquipment, Parachute, Pavilion, Piano, Pipe, PlowHarrow, PoleVault, Punchbag, Rack, Racket, Rifle, Ring, Robot, RockClimbing, Rope, Sailboat, Satellite, Scaffold, Scale, Scissor, Scooter, Sculpture, Seadragon, Seahorse, Seal, SewingMachine, Ship, Shoe, ShoppingCart, ShoppingTrolley, Shower, Shrimp, Signboard, Skateboarding, Skeleton, Skiing, Spade, SpeedBoat, Spider, Spoon, Stair, Stand, Stationary, SteeringWheel, Stethoscope, Stool, Stove, StreetLamp, SweetStand, Swing, Sword, TV, Table, TableChair, TableLamp, TableTennis, Tank, Tapeline, Teapot, Telescope, Tent, TobaccoPipe, Toy, Tractor, TrafficLight, TrafficSign, Trampoline, TransmissionTower, Tree, Tricycle, TrimmerCover, Tripod, Trombone, Truck, Trumpet, Tuba, UAV, Umbrella, UnevenBars, UtilityPole, VacuumCleaner, Violin, Wakesurfing, Watch, WaterTower, WateringPot, Well, WellLid, Wheel, Wheelchair, WindTurbine, Windmill, WineGlass, WireWhisk, Yacht'
# class_labels_TR_sorted = _class_labels_TR_sorted.split(', ')
class MyData(data.Dataset):
    """Dataset of (image, ground-truth mask) pairs under ``config.dataset_root``.

    ``datasets`` is one sub-directory name, or several joined with ``'+'`` to
    train on a combined set. Each sub-directory is expected to contain an
    ``Imgs`` folder of input images and a ``GT`` folder of masks that share the
    image's file stem (mask extension is looked up from ``_LABEL_EXTS``).

    Items are ``(image, label)`` when ``is_train`` is True, otherwise
    ``(image, label, label_path)``.
    """

    # Mask extensions tried, in this order, for each image stem.
    _LABEL_EXTS = ('.png', '.jpg', '.PNG', '.JPG', '.JPEG', '.jpeg')

    def __init__(self, datasets, image_size, is_train=True):
        """Index (and optionally pre-load) all image/mask pairs.

        Args:
            datasets: Dataset folder name(s) under ``config.dataset_root``,
                separated by ``'+'``.
            image_size: Stored as ``size_train``/``size_test``; not otherwise
                used inside this class.
            is_train: If True, items are augmented with ``preproc`` and
                returned as 2-tuples; otherwise the mask path is appended.

        Images that have no matching mask are skipped (with a warning) so that
        ``image_paths[i]`` and ``label_paths[i]`` always refer to the same sample.
        """
        self.size_train = image_size
        self.size_test = image_size
        # A falsy config.train_size means "keep the original resolution".
        self.keep_size = not config.train_size
        self.data_size = (config.train_size, config.train_size)
        self.is_train = is_train
        self.load_all = config.load_all
        self.device = config.device

        # Resize is omitted when images are pre-loaded (already read at
        # data_size by path_to_image) or when the original size is kept.
        # NOTE(review): with load_all, preproc still runs per item with no
        # Resize afterwards; if preproc can change the spatial size, batches
        # may fail to collate — confirm against preproc.
        skip_resize = bool(self.load_all or self.keep_size)
        image_steps = [
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
        label_steps = [transforms.ToTensor()]
        if not skip_resize:
            image_steps.insert(0, transforms.Resize(self.data_size))
            label_steps.insert(0, transforms.Resize(self.data_size))
        self.transform_image = transforms.Compose(image_steps)
        self.transform_label = transforms.Compose(label_steps)

        self.image_paths, self.label_paths = self._collect_pairs(config.dataset_root, datasets)

        if self.load_all:
            self.images_loaded, self.labels_loaded = [], []
            pairs = zip(self.image_paths, self.label_paths)
            for image_path, label_path in tqdm(pairs, total=len(self.image_paths)):
                image, label = self._read_pair(image_path, label_path)
                self.images_loaded.append(image)
                self.labels_loaded.append(label)

    @classmethod
    def _find_label_path(cls, label_dir, stem):
        """Return the first existing ``label_dir/stem+ext`` over ``_LABEL_EXTS``, else None."""
        for ext in cls._LABEL_EXTS:
            candidate = os.path.join(label_dir, stem + ext)
            if os.path.exists(candidate):
                return candidate
        return None

    @classmethod
    def _collect_pairs(cls, dataset_root, datasets):
        """Return aligned ``(image_paths, label_paths)`` for the '+'-joined ``datasets``.

        Entries of ``Imgs`` without a mask in the sibling ``GT`` folder are
        dropped (and reported via ``warnings.warn``) instead of shifting the
        alignment of every later pair.
        """
        image_paths, label_paths, unpaired = [], [], []
        for dataset in datasets.split('+'):
            image_dir = os.path.join(dataset_root, dataset, 'Imgs')
            label_dir = os.path.join(dataset_root, dataset, 'GT')
            # Sorting makes sample order independent of filesystem listing order.
            for name in sorted(os.listdir(image_dir)):
                image_path = os.path.join(image_dir, name)
                # splitext strips only the final extension, so dots elsewhere
                # in the name or directory path are left untouched.
                label_path = cls._find_label_path(label_dir, os.path.splitext(name)[0])
                if label_path is None:
                    unpaired.append(image_path)
                    continue
                image_paths.append(image_path)
                label_paths.append(label_path)
        if unpaired:
            import warnings
            warnings.warn(
                f'{len(unpaired)} image(s) without a ground-truth mask were skipped, '
                f'e.g. {unpaired[0]}'
            )
        return image_paths, label_paths

    def _read_pair(self, image_path, label_path):
        """Read an RGB image and its grayscale mask at ``self.data_size`` via ``path_to_image``."""
        image = path_to_image(image_path, size=self.data_size, color_type='rgb')
        label = path_to_image(label_path, size=self.data_size, color_type='gray')
        return image, label

    def __getitem__(self, index):
        """Return the transformed sample at ``index`` (see class docstring for the tuple shape)."""
        if self.load_all:
            image, label = self.images_loaded[index], self.labels_loaded[index]
        else:
            image, label = self._read_pair(self.image_paths[index], self.label_paths[index])
        if self.is_train:
            image, label = preproc(image, label, preproc_methods=config.preproc_methods)
        image, label = self.transform_image(image), self.transform_label(label)
        if self.is_train:
            return image, label
        return image, label, self.label_paths[index]

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.image_paths)