ready for VISIUM #29

Open
wants to merge 7 commits into base: master
4,726 changes: 3,728 additions & 998 deletions MAIN.ipynb

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions MODULES/cropper_uncropper.py
@@ -35,8 +35,8 @@ def crop(bounding_box: BB, big_stuff: torch.Tensor, width_small: int, height_sma
dtype=big_stuff.dtype,
device=big_stuff.device).view([-1] + small_dependent_dim)

grid = F.affine_grid(affine, list(cropped_images.shape))
cropped_images = F.grid_sample(big_stuff, grid, mode='bilinear', padding_mode='reflection')
grid = F.affine_grid(affine, list(cropped_images.shape), align_corners=True)
cropped_images = F.grid_sample(big_stuff, grid, mode='bilinear', padding_mode='reflection', align_corners=True)
return cropped_images.view(independent_dim + small_dependent_dim)

@staticmethod
@@ -95,9 +95,9 @@ def uncrop(bounding_box: BB, small_stuff: torch.Tensor, width_big: int, height_b
uncropped_stuff: torch.Tensor = torch.zeros(independent_dim + large_dependent_dim,
dtype=small_stuff.dtype,
device=small_stuff.device).view([-1] + large_dependent_dim)
grid = F.affine_grid(affine_matrix.view(-1, 2, 3), list(uncropped_stuff.shape))
grid = F.affine_grid(affine_matrix.view(-1, 2, 3), list(uncropped_stuff.shape), align_corners=True)
uncropped_stuff = F.grid_sample(small_stuff.view([-1] + small_dependent_dim), grid,
mode='bilinear', padding_mode='zeros')
mode='bilinear', padding_mode='zeros', align_corners=True)
return uncropped_stuff.view(independent_dim + large_dependent_dim)

@staticmethod
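The only functional change in this file is the explicit align_corners=True passed to F.affine_grid and F.grid_sample, which pins the corner-pixel sampling convention and silences the warning PyTorch has emitted since 1.3 when the flag is left unset. A minimal, self-contained sketch of the same pattern (illustrative shapes and values, not this PR's code):

# Illustrative sketch, not part of this PR: affine crop with align_corners pinned.
import torch
import torch.nn.functional as F

big_stuff = torch.rand(4, 1, 80, 80)                      # batch of large images
theta = torch.tensor([[0.4, 0.0, 0.1],
                      [0.0, 0.4, -0.2]])                  # one 2x3 affine matrix
theta = theta.unsqueeze(0).expand(4, -1, -1)              # same affine for every image
grid = F.affine_grid(theta, [4, 1, 28, 28], align_corners=True)
cropped = F.grid_sample(big_stuff, grid, mode='bilinear',
                        padding_mode='reflection', align_corners=True)
assert cropped.shape == (4, 1, 28, 28)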
111 changes: 62 additions & 49 deletions MODULES/namedtuple.py
@@ -4,10 +4,30 @@
import skimage.color
import matplotlib.pyplot as plt


# ---------------------------------------------------------------- #
# ------- Stuff defined in terms of native types ----------------- #
# ------- Stuff Related to PreProcessing ------------------------- #
# ---------------------------------------------------------------- #

class ImageBbox(NamedTuple):
""" Follows Scikit Image convention. Pixels belonging to the bounding box are in the half-open interval:
[min_row;max_row) and [min_col;max_col). """
min_row: int
min_col: int
max_row: int
max_col: int


class PreProcess(NamedTuple):
img: torch.Tensor
roi_mask: torch.Tensor
bbox_original: ImageBbox
bbox_crop: ImageBbox

# -------------------------------------------------------------------------------------- #
# ------- Stuff Related to PostProcessing (i.e. Graph Clustering Based on Modularity) -- #
# -------------------------------------------------------------------------------------- #


class Suggestion(NamedTuple):
best_resolution: float
@@ -26,7 +46,7 @@ def show_best(self, figsize: tuple = (20, 20), fontsize: int = 20):
fontsize=fontsize)


class Concordance(NamedTuple):
class ConcordancePartition(NamedTuple):
joint_distribution: torch.tensor
mutual_information: float
delta_n: int
@@ -96,7 +116,7 @@ def filter_by_size(self, min_size: Optional[int] = None, max_size: Optional[int]
new_membership = old_2_new[self.membership]
return self._replace(membership=new_membership, params=new_dict, sizes=torch.bincount(new_membership))

def concordance_with_partition(self, other_partition) -> Concordance:
def concordance_with_partition(self, other_partition) -> ConcordancePartition:
""" Compute measure of concordance between two partitions:
joint_distribution
mutual_information
@@ -146,38 +166,10 @@ def concordance_with_partition(self, other_partition) -> Concordance:
union = torch.sum(self.sizes[1:]) + torch.sum(other_partition.sizes[1:]) - intersection # exclude background
iou = intersection.float()/union

return Concordance(joint_distribution=pxy,
mutual_information=mutual_information.item(),
delta_n=ny - nx,
iou=iou.item())


class DIST(NamedTuple):
sample: torch.Tensor
kl: torch.Tensor


class ZZ(NamedTuple):
mu: torch.Tensor
std: torch.Tensor


class BB(NamedTuple):
bx: torch.Tensor # shape: n_box, batch_size
by: torch.Tensor
bw: torch.Tensor
bh: torch.Tensor


class NMSoutput(NamedTuple):
nms_mask: torch.Tensor
index_top_k: torch.Tensor


class Checkpoint(NamedTuple):
epoch: int
hyperparams_dict: dict
history_dict: dict
return ConcordancePartition(joint_distribution=pxy,
mutual_information=mutual_information.item(),
delta_n=ny - nx,
iou=iou.item())


class Similarity(NamedTuple):
@@ -205,10 +197,38 @@ def reduce_similarity_radius(self, new_radius: int):


# ---------------------------------------------------------------- #
# -------Stuff defined in term of other sutff -------------------- #
# ------- Stuff Related to Processing (i.e. CompositionalVAE) ---- #
# ---------------------------------------------------------------- #


class DIST(NamedTuple):
sample: torch.Tensor
kl: torch.Tensor


class ZZ(NamedTuple):
mu: torch.Tensor
std: torch.Tensor


class BB(NamedTuple):
bx: torch.Tensor # shape: n_box, batch_size
by: torch.Tensor
bw: torch.Tensor
bh: torch.Tensor


class NMSoutput(NamedTuple):
nms_mask: torch.Tensor
index_top_k: torch.Tensor


class Checkpoint(NamedTuple):
epoch: int
hyperparams_dict: dict
history_dict: dict


class Segmentation(NamedTuple):
""" Where * is the batch dimension which might be NOT present """
raw_image: torch.Tensor # *,ch,w,h
@@ -235,17 +255,18 @@ class Inference(NamedTuple):
big_bg: torch.Tensor
big_img: torch.Tensor
big_mask: torch.Tensor
big_mask_NON_interacting: torch.Tensor
big_mask_NON_interacting: torch.Tensor # Use exclusively to compute overlap penalty
prob: torch.Tensor
bounding_box: BB
zinstance_each_obj: torch.Tensor
kl_zinstance_each_obj: torch.Tensor
kl_zwhere_map: torch.Tensor
kl_logit_map: torch.Tensor


class MetricMiniBatch(NamedTuple):
loss: torch.Tensor
nll: torch.Tensor
mse: torch.Tensor
reg: torch.Tensor
kl_tot: torch.Tensor
kl_instance: torch.Tensor
@@ -262,19 +283,15 @@ class MetricMiniBatch(NamedTuple):


class RegMiniBatch(NamedTuple):
# cost_fg_pixel_fraction: torch.Tensor
cost_overlap: torch.Tensor
cost_vol_absolute: torch.Tensor
# cost_volume_mask_fraction: torch.Tensor
# cost_prob_map_integral: torch.Tensor
# cost_prob_map_fraction: torch.Tensor
# cost_prob_map_TV: torch.Tensor
cost_total: torch.Tensor


class Metric_and_Reg(NamedTuple):
# MetricMiniBatch (in the same order as underlying class)
loss: torch.Tensor
nll: torch.Tensor
mse: torch.Tensor
reg: torch.Tensor
kl_tot: torch.Tensor
kl_instance: torch.Tensor
@@ -289,13 +306,9 @@ class Metric_and_Reg(NamedTuple):
length_GP: torch.Tensor
n_obj_counts: torch.Tensor
# RegMiniBatch (in the same order as underlying class)
# cost_fg_pixel_fraction: torch.Tensor
cost_overlap: torch.Tensor
cost_vol_absolute: torch.Tensor
# cost_volume_mask_fraction: torch.Tensor
# cost_prob_map_integral: torch.Tensor
# cost_prob_map_fraction: torch.Tensor
# cost_prob_map_TV: torch.Tensor
cost_total: torch.Tensor

@classmethod
def from_merge(cls, metrics, regularizations):
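Aside from the Concordance to ConcordancePartition rename and the regrouping of the NamedTuples into pre-processing, post-processing, and processing sections, the underlying logic is untouched. For readers unfamiliar with concordance_with_partition, the mutual information it reports can be read off the joint distribution of the two label vectors; a hedged, stand-alone sketch (illustrative helper, not this repository's implementation):

# Illustrative sketch, not this repository's implementation.
import torch

def mutual_information_from_memberships(a: torch.Tensor, b: torch.Tensor) -> float:
    """ a, b: 1-D integer label vectors over the same pixels (two partitions). """
    na, nb = int(a.max()) + 1, int(b.max()) + 1
    joint = torch.zeros(na, nb)
    for x, y in zip(a.tolist(), b.tolist()):
        joint[x, y] += 1.0
    pxy = joint / joint.sum()                 # joint distribution p(x, y)
    px = pxy.sum(dim=1, keepdim=True)         # marginal p(x)
    py = pxy.sum(dim=0, keepdim=True)         # marginal p(y)
    mask = pxy > 0
    mi = (pxy[mask] * torch.log(pxy[mask] / (px * py)[mask])).sum()
    return mi.item()

a = torch.tensor([0, 0, 1, 1, 2, 2])
b = torch.tensor([0, 0, 1, 1, 1, 2])
print(mutual_information_from_memberships(a, b))  # > 0, the partitions mostly agree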
14 changes: 6 additions & 8 deletions MODULES/preprocessing.py
@@ -1,17 +1,15 @@
import collections
import skimage.exposure
import numpy as np
from .namedtuple import ImageBbox
import PIL.Image
PIL.Image.MAX_IMAGE_PIXELS = None

BBOX = collections.namedtuple("BBOX","min_row, min_col, max_row, max_col")


def pil_to_numpy(pilfile, mode: str = 'L', reduction_factor: int = 1):
""" Open file using pillow, and return numpy array with shape:
w,h,channel if channel > 1
w,h if channel ==1.
Mode can be: 'L', 'RGB', 'I'
Mode can be: 'L', 'RGB', 'I', 'F'
See https://pillow.readthedocs.io/en/3.0.x/handbook/concepts.html
"""
assert (mode == 'L' or mode == 'RGB' or mode == 'I' or mode == 'F')
@@ -87,10 +85,10 @@ def find_bbox(mask):
min_row = row.shape[0] - max(np.arange(start=row.shape[0], stop=0, step=-1) * row)
max_col = max(np.arange(col.shape[0]) * col) + 1
min_col = col.shape[0] - max(np.arange(start=col.shape[0], stop=0, step=-1) * col)
return BBOX(min_row=min_row,
min_col=min_col,
max_row=max_row,
max_col=max_col)
return ImageBbox(min_row=min_row,
min_col=min_col,
max_row=max_row,
max_col=max_col)


#####def normalize_tensor(image, scale_each_image=False, scale_each_channel=False, in_place=False):
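The local BBOX namedtuple is dropped in favour of the shared ImageBbox from namedtuple.py, so pre-processing and the rest of the pipeline agree on scikit-image's half-open [min, max) bounds. A simplified sketch of extracting such a box from a binary mask and slicing with it (assumed shapes, not the exact find_bbox logic):

# Simplified sketch with assumed shapes; not the exact find_bbox implementation.
import numpy as np

def bbox_from_mask(mask: np.ndarray):
    """ Return (min_row, min_col, max_row, max_col) with half-open bounds, scikit-image style. """
    rows = np.any(mask, axis=1)
    cols = np.any(mask, axis=0)
    min_row, max_row = np.where(rows)[0][[0, -1]]
    min_col, max_col = np.where(cols)[0][[0, -1]]
    return int(min_row), int(min_col), int(max_row) + 1, int(max_col) + 1

mask = np.zeros((6, 6), dtype=bool)
mask[2:4, 1:5] = True
min_row, min_col, max_row, max_col = bbox_from_mask(mask)
crop = mask[min_row:max_row, min_col:max_col]   # half-open bounds slice directly
assert crop.shape == (2, 4)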
66 changes: 12 additions & 54 deletions MODULES/unet_model.py
@@ -38,34 +38,29 @@ def __init__(self, params: dict):
# Up path
self.up_path = torch.nn.ModuleList()
for i in range(0, self.n_max_pool):
j = int(j / 2)
ch = int(ch / 2)
j = int(j // 2)
ch = int(ch // 2)
self.ch_list.append(ch)
self.j_list.append(j)
self.up_path.append(UpBlock(self.ch_list[-2], self.ch_list[-1]))

# Compute s_p_k
self.s_p_k = list()
for module in self.down_path:
self.s_p_k = module.__add_to_spk_list__(self.s_p_k)
for module in self.up_path:
self.s_p_k = module.__add_to_spk_list__(self.s_p_k)

# Prediction maps
self.pred_features = MLP_1by1(ch_in=self.ch_list[-1],
ch_out=self.n_ch_output_features,
ch_hidden=-1) # this means there is NO hidden layer

self.encode_zwhere = Encoder1by1(ch_in=self.ch_list[-self.level_zwhere_and_logit_output - 1],
self.ch_in_zwhere = min(20, self.ch_list[-self.level_zwhere_and_logit_output - 1])
self.encode_zwhere = Encoder1by1(ch_in=self.ch_in_zwhere,
dim_z=self.dim_zwhere,
ch_hidden=None) # this means there is ONE hidden layer of automatic size

self.encode_logit = Encoder1by1(ch_in=self.ch_list[-self.level_zwhere_and_logit_output - 1],
self.ch_in_logit = min(20, self.ch_list[-self.level_zwhere_and_logit_output - 1])
self.encode_logit = Encoder1by1(ch_in=self.ch_in_logit,
dim_z=self.dim_logit,
ch_hidden=None) # this means there is ONE hidden layer of automatic size

# I don't need all the channels to predict the background. Few channels are enough
self.ch_in_bg = min(5, self.ch_list[-self.level_background_output - 1])
self.ch_in_bg = min(10, self.ch_list[-self.level_background_output - 1])
self.pred_background = PredictBackground(ch_in=self.ch_in_bg,
ch_out=self.ch_raw_image,
ch_hidden=-1) # this means there is NO hidden layer
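Switching from / to // matters because Python 3's / always returns a float; int(j / 2) still truncates, but floor division keeps the halving of channels and grid spacing in integer arithmetic end to end. A toy illustration (the starting values are assumptions, not the network's configuration):

# Toy illustration of the halving schedule; starting values are assumptions.
ch, j = 64, 16                       # channels and grid spacing at the bottleneck
ch_list, j_list = [ch], [j]
n_max_pool = 4
for _ in range(n_max_pool):
    ch //= 2                         # stays an int under Python 3
    j //= 2
    ch_list.append(ch)
    j_list.append(j)
print(ch_list)                       # [64, 32, 16, 8, 4]
print(j_list)                        # [16, 8, 4, 2, 1]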
@@ -92,8 +87,8 @@ def forward(self, x: torch.Tensor, verbose: bool):
for i, up in enumerate(self.up_path):
dist_to_end_of_net = self.n_max_pool - i
if dist_to_end_of_net == self.level_zwhere_and_logit_output:
zwhere = self.encode_zwhere(x)
logit = self.encode_logit(x)
zwhere = self.encode_zwhere(x[..., :self.ch_in_zwhere, :, :])
logit = self.encode_logit(x[..., :self.ch_in_logit, :, :])
if dist_to_end_of_net == self.level_background_output:
zbg = self.pred_background(x[..., :self.ch_in_bg, :, :]) # only few channels needed for predicting bg

@@ -110,6 +105,7 @@ def forward(self, x: torch.Tensor, verbose: bool):
features=features)

def show_grid(self, ref_image):
""" overimpose a grid the size of the corresponding resolution of each unet layer """

assert len(ref_image.shape) == 4
batch, ch, w_raw, h_raw = ref_image.shape
@@ -121,9 +117,9 @@

for k in range(nj):
j = self.j_list[k]
index_w = 1 + ((counter_w / j) % 2) # either 1 or 2
index_w = 1 + ((counter_w // j) % 2) # either 1 or 2
dx = index_w.float().view(w_raw, 1)
index_h = 1 + ((counter_h / j) % 2) # either 1 or 2
index_h = 1 + ((counter_h // j) % 2) # either 1 or 2
dy = index_h.float().view(1, h_raw)
check_board[k, 0, 0, :, :] = 0.25 * (dy * dx) # dx*dy=1,2,4 multiply by 0.25 to have (0,1)

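The same true-division fix applies in show_grid: counter // j must stay integer so the (counter // j) % 2 parity test alternates cleanly between 1 and 2 in blocks of width j. A compact check of that pattern (hedged sketch):

# Hedged sketch of the parity pattern used to draw the checkerboard at spacing j.
import torch

j, w_raw = 4, 16
counter_w = torch.arange(w_raw)
index_w = 1 + ((counter_w // j) % 2)   # 1,1,1,1,2,2,2,2,1,... blocks of width j
print(index_w.tolist())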
@@ -134,41 +130,3 @@
# ref_image of shape -----> batch, ch, w_raw, h_raw
return ref_image + check_board

def describe_receptive_field(self, image):
""" Show the value of ch_w_h_j_rf_loc as the tensor moves thorugh the net.
Here:
a. w,h are the width and height
b. j is grid spacing
c. rf is the maximum theoretical receptive field
d. wloc,hloc are the location of the center of the first cell
"""
w, h = image.shape[-2:]
j = 1
rf = 1
w_loc = 0.5
h_loc = 0.5
current_layer = (w, h, j, rf, w_loc, h_loc)
i = -1
for i in range(0, len(self.s_p_k)):
print("At layer l= ", i, " we have w_h_j_rf_wloc_hloc= ", current_layer)
current_layer = self.out_from_in(self.s_p_k[i], current_layer)
print("At layer l= ", i + 1, " we have w_h_j_rf_wloc_hloc= ", current_layer)

@staticmethod
def out_from_in(s_p_k, layer_in):
w_in, h_in, j_in, rf_in, wloc_in, hloc_in = layer_in
s = s_p_k[0]
p = s_p_k[1]
k = s_p_k[2]

w_out = numpy.floor((w_in - k + 2 * p) / s) + 1
h_out = numpy.floor((h_in - k + 2 * p) / s) + 1

pad_w = ((w_out - 1) * s - w_in + k) / 2
pad_h = ((h_out - 1) * s - h_in + k) / 2

j_out = j_in * s
rf_out = rf_in + (k - 1) * j_in
wloc_out = wloc_in + ((k - 1) / 2 - pad_w) * j_in
hloc_out = hloc_in + ((k - 1) / 2 - pad_h) * j_in
return int(w_out), int(h_out), j_out, int(rf_out), wloc_out, hloc_out