#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Wed Feb 7 22:55:20 2018
@author: felix
"""
def match_auto_man_detections(man_xy, auto_xy, max_dist=None):
    """
    Pair manual and automatic detections by minimum total Euclidean distance.

    Solves a linear sum assignment on the pairwise distance matrix between the
    two point sets, so each manual point is paired with at most one automatic
    point and vice versa.

    Parameters
    ----------
    man_xy : array_like, shape (n_manual, n_dims)
        Manually annotated point coordinates.
    auto_xy : array_like, shape (n_auto, n_dims)
        Automatically detected point coordinates (same n_dims as man_xy).
    max_dist : float or None, optional
        If given, assigned pairs whose distance is not strictly less than
        ``max_dist`` are discarded.

    Returns
    -------
    ind_i : ndarray of int
        Row indices into ``man_xy`` of the matched points.
    ind_j : ndarray of int
        Row indices into ``auto_xy``, aligned element-wise with ``ind_i``.
        Both arrays are empty when either input is empty or when no assigned
        pair lies within ``max_dist``.
    """
    import numpy as np
    from scipy.optimize import linear_sum_assignment
    from scipy.spatial.distance import cdist

    man_xy = np.asarray(man_xy, dtype=float)
    auto_xy = np.asarray(auto_xy, dtype=float)

    # Nothing can be matched when one side has no points (whatever the
    # empty array's shape happens to be).
    if man_xy.size == 0 or auto_xy.size == 0:
        empty = np.zeros(0, dtype=np.intp)
        return empty, empty.copy()

    # Euclidean distance matrix, shape (n_manual, n_auto).
    dists = cdist(man_xy, auto_xy)
    ind_i, ind_j = linear_sum_assignment(dists)

    if max_dist is not None:
        # Vectorised filter; stays valid (empty arrays) when no pair survives.
        keep = dists[ind_i, ind_j] < max_dist
        ind_i, ind_j = ind_i[keep], ind_j[keep]

    return ind_i, ind_j
def compute_matching_statistics(detections_manual, detections_auto, max_match_dist=10):
    """
    Computes matched, missed manual and extra automatic detections per image.

    Parameters
    ----------
    detections_manual : sequence of array_like
        Entry ``ii`` holds the manual point coordinates for image ``ii``.
    detections_auto : sequence
        Entry ``ii`` is indexed with ``[0]`` to obtain the automatic point
        coordinates for image ``ii`` (presumably ``[centroids, SNR]`` pairs as
        built by the detection script -- confirm against the saved data).
    max_match_dist : float, optional
        Assigned pairs at or beyond this distance are not counted as matches.

    Returns
    -------
    overlap_stats : ndarray of int, shape (n_images, 3)
        Per-image ``[n_intersect, n_missed, n_over]``: matched pairs (TP),
        manual points left unmatched (FN) and automatic points left
        unmatched (FP).
    """
    import numpy as np

    def _n_points(xy):
        # Treat any empty array as zero points, regardless of its shape.
        return 0 if np.size(xy) == 0 else len(xy)

    overlap_stats = []
    for ii, man_xy in enumerate(detections_manual):
        auto_xy = detections_auto[ii][0]
        man_ind, auto_ind = match_auto_man_detections(man_xy, auto_xy, max_match_dist)

        n_intersect = len(man_ind)
        # Unmatched counts are taken against THIS image's own point lists.
        n_missed = len(np.setdiff1d(np.arange(_n_points(man_xy)), man_ind))
        n_over = len(np.setdiff1d(np.arange(_n_points(auto_xy)), auto_ind))
        overlap_stats.append([n_intersect, n_missed, n_over])

    if not overlap_stats:
        # np.vstack([]) raises; return a well-formed empty table instead.
        return np.zeros((0, 3), dtype=int)
    return np.vstack(overlap_stats)
if __name__ == "__main__":
    # Compute the detection accuracy of centrioles for the proposed algorithm.
    import numpy as np
    import pylab as plt
    import file_io as fio
    import training_fn as training
    import image_fn
    import visualization as viz
    import scipy.io as spio
    import os
    from tqdm import tqdm

    # All detection files to test. Each .mat is read for its 'files',
    # 'manual' and 'auto' entries.
    inputfiles = ['/media/felix/Elements/Raff Lab/Centriole Distancing GitHub/detections-auto-vs-manual_Early-CV0.30.mat',
                  '/media/felix/Elements/Raff Lab/Centriole Distancing GitHub/detections-auto-vs-manual_Mid-CV0.30.mat',
                  '/media/felix/Elements/Raff Lab/Centriole Distancing GitHub/detections-auto-vs-manual_Late-CV0.30.mat']
    phase_names = ['early', 'mid', 'late']

    # Matching radius in pixels; presumably half of the 32 px detection
    # patch size -- confirm.
    patch_size = 32
    max_match_dist = patch_size // 2

    #==============================================================================
    # Load detections
    #==============================================================================
    detections = [spio.loadmat(f) for f in inputfiles]
    early_detections, mid_detections, late_detections = detections

    #==============================================================================
    # Compute Statistics (TP/FP)
    #==============================================================================
    detect_stats = [compute_matching_statistics(np.squeeze(d['manual']),
                                                np.squeeze(d['auto']),
                                                max_match_dist=max_match_dist)
                    for d in detections]
    early_detect_stats, mid_detect_stats, late_detect_stats = detect_stats

    #==============================================================================
    # Construct the TP/FP table for all (aggregated).
    #==============================================================================
    # precision = TP/(TP+FP)
    # recall = TP/(TP+FN)
    # F1 = 2TP/(2TP+FP+FN), equal to 2PR/(P+R) whenever that is defined and
    # 0 (instead of 0/0) for an image with no true positives.
    all_stats = np.vstack(detect_stats)
    # NOTE(review): assumes the 'files' entries line up row-for-row with the
    # per-image stats; the commented-out saving code stores `imgtifs` rather
    # than `imgtifs_match` -- confirm.
    all_files = np.hstack([d['files'] for d in detections])

    tp = all_stats[:, 0].astype(float)
    fn = all_stats[:, 1].astype(float)
    fp = all_stats[:, 2].astype(float)

    # Images where a ratio is 0/0 get NaN rather than raising ZeroDivisionError.
    with np.errstate(divide='ignore', invalid='ignore'):
        all_precision = tp / (tp + fp)
        all_recall = tp / (tp + fn)
        all_F_score = 2 * tp / (2 * tp + fp + fn)

    # nanargmax: plain argmax would report the first NaN entry as the maximum.
    print(all_files[np.nanargmax(all_F_score)])

    # Images where the metric is undefined are excluded from the means.
    precision_mean = np.nanmean(all_precision)
    recall_mean = np.nanmean(all_recall)
    print('mean precision: %.3f' % precision_mean)
    print('mean recall: %.3f' % recall_mean)

    #==============================================================================
    # Draw the Venn diagram.
    #==============================================================================
    from matplotlib_venn import venn2, venn2_circles

    for phase, stats in zip(phase_names, detect_stats):
        n_intersect, n_missed, n_over = stats.sum(axis=0)
        # venn2 subset order: (only A, only B, A and B) with A=Manual, B=Auto.
        subsets = (n_missed, n_over, n_intersect)
        venn2(subsets=subsets, set_labels=('Manual', 'Auto'), set_colors=('g', 'r'), alpha=0.4)
        venn2_circles(subsets, color='k', alpha=1.0, linestyle='solid', linewidth=2.0)
        # plt.savefig('Venn-Centriole_detections-%s.svg' % phase, dpi=300)
        plt.show()
# infolder = '/media/felix/Elements/Raff Lab/Centriole Distancing/Sectioned Training/Early S-phase'
## infolder = '/media/felix/Elements/Raff Lab/Centriole Distancing/Sectioned Training/Mid S-phase'
## infolder = '/media/felix/Elements/Raff Lab/Centriole Distancing/Sectioned Training/Late S-phase'
# intifs = fio.locate_centriole_files(infolder)
#
##==============================================================================
## Load the corresponding img stack.
##==============================================================================
#
# imgfolder = '/media/felix/Elements/Raff Lab/Centriole Distancing/Datasets/OM NM Tracking with Felix/cep97-gfp distancing images (for felix)/cycle 12 - cep97-gfp in wt'
# imgtifs = fio.locate_centriole_files(imgfolder)
#
# # retrieve the relevant tifs.
# imgtifs_match = fio.retrieve_matching_imgs(intifs, imgtifs)
#
#
##==============================================================================
## Iterate over the images and try to detect accurately each of the centriole positions and compare with the annotated ver.
## highlight using circles matplotlib.
##==============================================================================
# n_tifs = len(intifs)
# print(n_tifs)
#
#
# """
# Iterate over the image set.
# """
# all_centriole_detections_manual = []
# all_centriole_detections_auto = []
#
#
# for i in range(n_tifs)[:]:
#
# tif = intifs[i]
# im_tif = imgtifs_match[i]
#
# print '%d: processing %s' %(i+1, tif)
# basename = tif.split('/')[-1].split('.tif')[0] # get the basename.
#
# # read the time-slice stack
# raw_stack_img = fio.read_stack_time_img(im_tif, n_timepoints=n_timepoints, n_slices=n_slices)
# annot_stack_img = fio.read_stack_time_img(tif, n_timepoints=n_timepoints, n_slices=n_slices) # use only the first time point.
#
#
##==============================================================================
## 1. From Manual
##==============================================================================
# annot_slice_img_blue = training.extract_dot_annotation_zstack(annot_stack_img[0], c=[0,0,255])
# annot_slice_img_blue = np.max(annot_slice_img_blue, axis=0) # max projection image.
# centriole_centroids_manual = image_fn.locate_centroids_simple(annot_slice_img_blue).astype(np.int)
#
##==============================================================================
## 2. From Automatic Centriole Detection
##==============================================================================
#
# detect_dict = image_fn.detect_centrioles_full_slice(raw_stack_img, size, aniso_params, patch_size, CV_thresh, tslice=0, filter_border=True, separation=5, invert=False, minmass=10) # set tslice=-1 if just a single timepoint.
#
# # parse the detection.
# centriole_centroids_auto = detect_dict['centriole_centroids']
# f = detect_dict['centriole_pos']
# max_slice_im = detect_dict['max_proj_full_img']
# max_slice_im_denoise = detect_dict['max_proj_full_img_denoise']
# slice_bg_mask = detect_dict['background_mask']
# slice_valid_detection_mask = detect_dict['valid_detection_mask']
# centriole_SNR = detect_dict['centriole_SNR']
#
# all_centriole_detections_manual.append(centriole_centroids_manual)
# all_centriole_detections_auto.append([centriole_centroids_auto, centriole_SNR])
#
#
##==============================================================================
## 3. Visualize the Detections (Visual Comparison)
##==============================================================================
#
# """
# Visualise the detections.
# """
# fig, ax = plt.subplots(figsize=(10,10))
# plt.imshow(max_slice_im, cmap='gray')
# viz.draw_circles(centriole_centroids_manual, ax, radii=patch_size/2, col='g', lw=2)
# viz.draw_circles(centriole_centroids_auto, ax, radii=patch_size*1, col='r', lw=2)
# ax.grid('off')
# ax.axis('off')
# fig.savefig(os.path.join(infolder, basename + '_manual-auto-compare.svg'), dpi=300)
# plt.show()
# plt.close()
#
###==============================================================================
### 4. Save the detections per file in order to compute mAP and TP/FP statistics.
###==============================================================================
#
# spio.savemat('detections-auto-vs-manual_Early-CV%.2f.mat' %(CV_thresh), {'files':imgtifs,
# 'manual':all_centriole_detections_manual,
# 'auto':all_centriole_detections_auto})