Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
Abe404 committed Apr 21, 2023
2 parents ae041a5 + f61a5de commit ab73637
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 10 deletions.
46 changes: 40 additions & 6 deletions painter/src/main/python/plot_seg_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def get_cache_key(seg_dir, annot_dir, fname):
fname = os.path.splitext(fname)[0] + '.png'
return fname

def compute_seg_metrics(seg_dir, annot_dir, fname):
def compute_seg_metrics(seg_dir, annot_dir, fname, model_dir):

# annot and seg are both PNG
fname = os.path.splitext(fname)[0] + '.png'
Expand Down Expand Up @@ -133,9 +133,31 @@ def compute_seg_metrics(seg_dir, annot_dir, fname):
corrected_segmentation_metrics = compute_metrics_from_masks(
seg, corrected, np.sum(foreground > 0), np.sum(background > 0))

corrected_segmentation_metrics['model_name'] = get_model_name_for_seg(
seg_path, model_dir)

return corrected_segmentation_metrics


def get_model_name_for_seg(seg_path, models_dir):
# get modified time
seg_m_time = os.path.getmtime(seg_path)
model_fnames = sorted(os.listdir(models_dir))
for i, m in enumerate(model_fnames):
model_m_time = int(m.replace('.pkl', '').split('_')[1])
# if it's the last model then it must be this one.
if i == (len(model_fnames) - 1):
return m
next_m = model_fnames[i+1]
next_model_m_time = int(next_m.replace('.pkl', '').split('_')[1])
# if the model is before the seg
if model_m_time < seg_m_time:
# and the next model is after the seg
if next_model_m_time > seg_m_time:
# then return the model
return m


class Thread(QtCore.QThread):
progress_change = QtCore.pyqtSignal(int, int)
done = QtCore.pyqtSignal(str)
Expand All @@ -145,6 +167,7 @@ def __init__(self, proj_dir, csv_path, fnames):
self.proj_dir = proj_dir
self.seg_dir = os.path.join(proj_dir, 'segmentations')
self.annot_dir = os.path.join(proj_dir, 'annotations')
self.model_dir = os.path.join(proj_dir, 'models')
self.csv_path = csv_path
self.fnames = fnames

Expand Down Expand Up @@ -184,7 +207,8 @@ def run(self):
# but segmentation may now be available so we still need to ignore this cached result and
# recompute metrics, just incase the segmentation now exists.
if metrics is None:
metrics = compute_seg_metrics(self.seg_dir, self.annot_dir, fname)
metrics = compute_seg_metrics(self.seg_dir, self.annot_dir,
fname, self.model_dir)
cache_dict[cache_key] = metrics

if metrics:
Expand Down Expand Up @@ -315,7 +339,8 @@ def add_file_metrics(self, fname):
cache_dict = pickle.load(open(cache_dict_path, 'rb'))
# cache_key is just the fname for now.
cache_key = get_cache_key(seg_dir, annot_dir, fname)
metrics = compute_seg_metrics(seg_dir, annot_dir, fname)
metrics = compute_seg_metrics(seg_dir, annot_dir, fname,
os.path.join(self.proj_dir, 'models'))
if metrics:
cache_dict[cache_key] = metrics
self.plot_window.add_point(fname, metrics)
Expand Down Expand Up @@ -348,6 +373,7 @@ def view_plot_from_csv(self, csv_fpath):
"annot_fg": int(annot_fg),
"annot_bg": int(annot_bg)
})

self.plot_window = QtGraphMetricsPlot(
fnames, metrics_list, rolling_n=30)
self.plot_window.setWindowTitle(
Expand Down Expand Up @@ -402,7 +428,6 @@ def __init__(self, fnames, metrics_list, rolling_n, selected_fname=None):
# so image 7 (index of 6) will be the first correctively annotated.
self.first_corrective_idx = 6


self.highlight_point = None
self.show_selected = True
self.graph_plot = None
Expand Down Expand Up @@ -631,7 +656,7 @@ def render_data(self):
self.graph_plot.clear()

def hover_tip(x, y, data):
return f'{int(x)} {data} {self.metric_display_name}: {round(y, 4)}'
return f'{int(x)} {data} {self.metric_display_name}: {round(y, 4)} model: {self.get_model_name(int(x)-1)}'

self.scatter = pg.ScatterPlotItem(size=8, symbol='x', clickable=True, hoverable=True,
hoverBrush=pg.mkBrush('grey'), hoverPen=pg.mkPen('grey'),
Expand All @@ -651,7 +676,16 @@ def hover_tip(x, y, data):
if self.highlight_point_fname is not None:
self.render_highlight_point()
self.add_events()


def get_model_name(self, x):
# we want to know which model was used to generate each segmentation.
# x is the index of the segmentation in the list.
if x < len(self.metrics_list):
if self.metrics_list[x] and 'model_name' in self.metrics_list[x]:
return self.metrics_list[x]['model_name']
return ''


def get_corrected_dice(self):
corrected_dice = []
for i, m in enumerate(self.metrics_list):
Expand Down
13 changes: 9 additions & 4 deletions painter/src/main/python/root_painter.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,10 +592,15 @@ def view_metric_csv():
self.metrics_plot = MetricsPlot()


view_metrics_csv_btn.triggered.connect(view_metric_csv)
extras_menu.addAction(view_metrics_csv_btn)


# view_metrics_csv_btn.triggered.connect(view_metric_csv)
# This has been disabled because the metrics are getting constantly
# expanded (new features) and the code that loads metrics from csv
# needs to be udpated to include the new metrics. I'm going to disable
# this functionality for a while and see if anyone notices. If they
# notice/complain, then I think it would be worthwhile to update this
# functionality and think about how to make it work with both new
# (including more metrics) and old metrics CSV files.
# extras_menu.addAction(view_metrics_csv_btn)

if project_open:
metrics_plot_btn = QtWidgets.QAction(QtGui.QIcon('missing.png'),
Expand Down

0 comments on commit ab73637

Please sign in to comment.