Skip to content

Commit

Permalink
add kaggle model
Browse files Browse the repository at this point in the history
  • Loading branch information
laitassou committed May 25, 2020
1 parent 60fe0bc commit 540b9c5
Show file tree
Hide file tree
Showing 9 changed files with 522 additions and 0 deletions.
73 changes: 73 additions & 0 deletions models/extract_small_tiles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import os
import cv2
import skimage.io
from tqdm.notebook import tqdm
import zipfile
import numpy as np


##
## Extract tiles from each large image (and its label mask) and store them
## as PNG files inside zip archives.
##


# Input directories: full-size training images and their label masks.
TRAIN = './kaggle/train_images/'
MASKS = './kaggle//train_label_masks/'

# Output archives receiving the PNG-encoded image tiles and mask tiles.
OUT_TRAIN, OUT_MASKS = 'train.zip', 'masks.zip'

# Tile geometry: edge length of a square tile in pixels (sz) and the number
# of tiles kept per image (N).
sz, N = 128, 16


def _split_into_tiles(arr, tile_size):
    """Reshape an (H, W, C) array into (num_tiles, tile_size, tile_size, C).

    H and W must already be multiples of ``tile_size``. Tiles are ordered
    row by row (left to right, then top to bottom).
    """
    rows, cols, channels = arr.shape
    arr = arr.reshape(rows // tile_size, tile_size,
                      cols // tile_size, tile_size, channels)
    # Bring the two grid axes together so each tile is contiguous.
    return arr.transpose(0, 2, 1, 3, 4).reshape(-1, tile_size, tile_size,
                                                channels)


def tile(img, mask, tile_size=None, n_tiles=None):
    """Cut an image and its mask into square tiles and keep the darkest ones.

    Both arrays are padded symmetrically so that height and width become
    multiples of the tile size (image with 255, mask with 0), split into
    tiles, and the ``n_tiles`` tiles with the smallest pixel sum of the
    image are selected. If fewer than ``n_tiles`` tiles exist, blank tiles
    (image 255, mask 0) are appended first.

    Args:
        img: image array of shape (H, W, C).
        mask: mask array with the same height and width as ``img`` and a
            trailing channel axis.
        tile_size: edge length of a tile in pixels; defaults to the
            module-level ``sz``.
        n_tiles: number of tiles to return; defaults to the module-level
            ``N``.

    Returns:
        A list of ``n_tiles`` dicts ``{'img', 'mask', 'idx'}`` ordered from
        the smallest to the largest image pixel sum; ``idx`` is the rank in
        that ordering.
    """
    tile_size = sz if tile_size is None else tile_size
    n_tiles = N if n_tiles is None else n_tiles

    height, width = img.shape[:2]
    pad_h = (tile_size - height % tile_size) % tile_size
    pad_w = (tile_size - width % tile_size) % tile_size
    # Split the padding between both sides; any odd pixel goes to the end.
    spatial_pad = [[pad_h // 2, pad_h - pad_h // 2],
                   [pad_w // 2, pad_w - pad_w // 2],
                   [0, 0]]
    img = np.pad(img, spatial_pad, constant_values=255)
    mask = np.pad(mask, spatial_pad, constant_values=0)

    img = _split_into_tiles(img, tile_size)
    mask = _split_into_tiles(mask, tile_size)

    n_missing = n_tiles - len(img)
    if n_missing > 0:
        # Not enough tiles: top up with blank ones so the output size is fixed.
        batch_pad = [[0, n_missing], [0, 0], [0, 0], [0, 0]]
        mask = np.pad(mask, batch_pad, constant_values=0)
        img = np.pad(img, batch_pad, constant_values=255)

    # Smallest pixel sum first: padding is 255, so low sums mean more content.
    idxs = np.argsort(img.reshape(img.shape[0], -1).sum(-1))[:n_tiles]
    return [
        {'img': tile_img, 'mask': tile_mask, 'idx': i}
        for i, (tile_img, tile_mask) in enumerate(zip(img[idxs], mask[idxs]))
    ]



# Per-tile channel means of x and x**2 (pixels scaled to [0, 1]); used at the
# end to derive the dataset-wide per-channel mean and standard deviation.
x_tot, x2_tot = [], []

# Derive the image ids from the mask file names. Only entries that really end
# with the mask suffix are considered, so stray files in the directory do not
# produce mangled names.
mask_suffix = '_mask.tiff'
names = [fname[:-len(mask_suffix)] for fname in os.listdir(MASKS)
         if fname.endswith(mask_suffix)]

with zipfile.ZipFile(OUT_TRAIN, 'w') as img_out,\
        zipfile.ZipFile(OUT_MASKS, 'w') as mask_out:
    for name in tqdm(names):
        print(name)
        # Open each multi-level TIFF once; skip the sample if either the
        # image or the mask yields no levels.
        img_levels = skimage.io.MultiImage(os.path.join(TRAIN, name + '.tiff'))
        if not len(img_levels):
            continue
        mask_levels = skimage.io.MultiImage(
            os.path.join(MASKS, name + mask_suffix))
        if not len(mask_levels):
            continue

        # The last level of each file is the one that gets tiled.
        img = img_levels[-1]
        mask = mask_levels[-1]

        for t in tile(img, mask):
            tile_img, tile_mask, idx = t['img'], t['mask'], t['idx']
            scaled = tile_img / 255.0
            x_tot.append(scaled.reshape(-1, 3).mean(0))
            x2_tot.append((scaled ** 2).reshape(-1, 3).mean(0))
            # OpenCV encodes assuming BGR channel order, so convert the RGB
            # tile first to keep the colours of the saved PNG correct.
            img_png = cv2.imencode(
                '.png', cv2.cvtColor(tile_img, cv2.COLOR_RGB2BGR))[1]
            img_out.writestr(f'{name}_{idx}.png', img_png)
            # Only the first mask channel is stored.
            mask_png = cv2.imencode('.png', tile_mask[:, :, 0])[1]
            mask_out.writestr(f'{name}_{idx}.png', mask_png)


# Image stats: Var[x] = E[x^2] - E[x]^2, computed per channel. The variance is
# clamped at 0 because floating-point rounding can make it slightly negative,
# which would turn the square root into NaN.
img_avr = np.array(x_tot).mean(0)
img_std = np.sqrt(np.maximum(np.array(x2_tot).mean(0) - img_avr**2, 0))
# img_std is already the standard deviation; do not take the root again.
print('mean:', img_avr, ', std:', img_std)
Loading

0 comments on commit 540b9c5

Please sign in to comment.