-
Notifications
You must be signed in to change notification settings - Fork 0
/
code.py
74 lines (53 loc) · 2.13 KB
/
code.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import pandas as pd
import numpy as np
import os
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
from torch import nn
import torch.nn.functional as F
from torch import optim
from skimage import io, transform
import glob2
#import cv2
from skimage.util import random_noise
from PIL import Image
# Gather every path matched by the glob pattern below (in iteration order);
# the resulting list is handed to ImageDataset, which opens each entry with PIL.
image_p = list(glob2.iglob(r"alien_pred/train/fall/*"))
class ImageDataset(Dataset):
    """Dataset over a list of image file paths, all sharing one fixed label.

    Args:
        image_path (sequence of str): Paths of the image files to load.
        transform (callable, optional): Optional transform applied to each
            image after it is loaded as an RGB PIL image.
        label (int, optional): Label returned with every sample. Defaults
            to 1, the value the original implementation hard-coded.
    """

    def __init__(self, image_path, transform=None, label=1):
        self.image_path = image_path
        self.transform = transform
        self.label = label

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.image_path)

    def __getitem__(self, idx):
        """Load image ``idx`` and return ``(image, label)``.

        ``image`` is the RGB PIL image, or whatever ``transform`` returns
        for it. ``label`` is a 0-d numpy array holding ``self.label``.
        """
        img_name = self.image_path[idx]
        # Close the underlying file handle; convert() returns a new, fully
        # loaded image, so the data stays valid after the file is closed.
        with Image.open(img_name) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        # One scalar label per sample. The previous code built an array with
        # one entry per dataset element, so every sample carried N labels.
        label = np.array(self.label)
        return image, label
# Resize to 64x64, convert to a float CHW tensor, then add Gaussian noise with
# std 0.1 to every element. The noise can push values outside [0, 1].
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x + 0.1 * torch.randn_like(x)),
])
trainset = ImageDataset(image_path=image_p, transform=transform)

if len(trainset) == 0:
    raise FileNotFoundError("No images found: image_p is empty, check the glob pattern")

sample = trainset[0]
print(sample[0])

# sample[0] is a float32 CHW tensor (ToTensor output plus noise). PIL cannot
# build an RGB image from float32 data, so clamp back into [0, 1], move the
# channel axis last and scale to 8-bit before converting.
img_tensor = sample[0].clamp(0.0, 1.0)
img_array = (img_tensor.permute(1, 2, 0).numpy() * 255).round().astype(np.uint8)
new_im = Image.fromarray(img_array)
new_im.save(r"alien_pred/numpy_altered_sample2.jpeg")

# Display the same noisy sample; previously plt.show() had nothing to draw.
plt.imshow(img_array)
plt.axis("off")
plt.show()