inference.py
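# Inference script for the Severstal steel defect detection competition:
# loads a trained ResNet50 U-Net, predicts four per-class defect masks for
# each test image, post-processes them, and writes run-length-encoded
# predictions to submission.csv.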
import os
import cv2 as cv
import numpy as np
import pandas as pd
import torch
from albumentations import Normalize, Compose
from albumentations.pytorch import ToTensor
from torch.utils.data import Dataset, DataLoader
import segmentation_models_pytorch as smp
from tqdm import tqdm
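

# Threshold per-pixel probabilities and drop connected components smaller than
# `min_size`; masks are 256x1600, the fixed image size in the Severstal dataset.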
def post_process(probability, threshold, min_size):
    '''Post processing of each predicted mask: components with fewer pixels
    than `min_size` are ignored'''
    mask = cv.threshold(probability, threshold, 1, cv.THRESH_BINARY)[1]
    num_component, component = cv.connectedComponents(mask.astype(np.uint8))
    predictions = np.zeros((256, 1600), np.float32)
    num = 0
    for c in range(1, num_component):
        p = (component == c)
        if p.sum() > min_size:
            predictions[p] = 1
            num += 1
    return predictions, num


class TestDataset(Dataset):
    '''Dataset for test prediction'''
    def __init__(self, root, df, mean, std):
        self.root = root
        df['ImageId'] = df['ImageId_ClassId'].apply(lambda x: x.split('_')[0])
        self.fnames = df['ImageId'].unique().tolist()
        self.num_samples = len(self.fnames)
        self.transform = Compose(
            [
                Normalize(mean=mean, std=std, p=1),
                ToTensor(),
            ]
        )

    def __getitem__(self, idx):
        fname = self.fnames[idx]
        path = os.path.join(self.root, fname)
        image = cv.imread(path)
        images = self.transform(image=image)["image"]
        return fname, images

    def __len__(self):
        return self.num_samples
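

# Run-length encoding in the column-major order expected by the Kaggle
# submission format: alternating start positions and run lengths.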
def mask2rle(img):
    # https://www.kaggle.com/paulorzp/rle-functions-run-lenght-encode-decode
    '''
    img: numpy array, 1 - mask, 0 - background
    Returns the run length as a formatted string
    '''
    pixels = img.T.flatten()
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)
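

# Inference configuration (Kaggle kernel paths); mean and std are the standard
# ImageNet normalization statistics.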
sample_submission_path = '../input/severstal-steel-defect-detection/sample_submission.csv'
test_data_folder = "../input/severstal-steel-defect-detection/test_images"
# initialize test dataloader
best_threshold = 0.3
num_workers = 2
batch_size = 4
print('best_threshold', best_threshold)
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
df = pd.read_csv(sample_submission_path)
testset = DataLoader(
    TestDataset(test_data_folder, df, mean, std),
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True
)
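# shuffle=False keeps a deterministic prediction order; pin_memory=True speeds
# up host-to-GPU transfers in the inference loop.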
# Initialize model and load trained weights
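# The checkpoint stores the weights under "model_state_dict"; map_location
# keeps the loaded tensors on the CPU and load_state_dict copies them into the
# GPU-resident model.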
ckpt_path = "../input/1.1.resnet50_severstal/best.pth"
device = torch.device("cuda")
model = smp.Unet("resnet50", encoder_weights=None, classes=4, activation=None)
model.to(device)
model.eval()
state = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
model.load_state_dict(state["model_state_dict"])
del state
# start prediction
predictions = []
for i, batch in enumerate(tqdm(testset)):
    fnames, images = batch
    batch_preds = torch.sigmoid(model(images.to(device)))
    batch_preds = batch_preds.detach().cpu().numpy()
    for fname, preds in zip(fnames, batch_preds):
        for cls, pred in enumerate(preds):
            # post_process returns (mask, component_count); only the mask is encoded
            pred, num = post_process(pred, best_threshold, 3500)
            rle = mask2rle(pred)
            name = fname + f"_{cls+1}"
            predictions.append([name, rle])
# save predictions to submission.csv
df = pd.DataFrame(predictions, columns=['ImageId_ClassId', 'EncodedPixels'])
df.to_csv("submission.csv", index=False)
print(df.head())
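
# Each row of submission.csv pairs an ImageId_ClassId key (e.g. "xyz.jpg_3")
# with an RLE-encoded mask; images with no predicted defect for a class get an
# empty EncodedPixels string.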