Source code for test_centriole_detection_accuracy_mAP

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Wed Feb  7 22:55:20 2018

@author: felix
"""

def match_auto_man_detections(man_xy, auto_xy, max_dist=None):
    """Match manual and automatic detections one-to-one by distance.

    Pairwise Euclidean distances between the two point sets are computed and
    the assignment minimising the total distance is found with
    ``scipy.optimize.linear_sum_assignment``.  Pairs whose distance is not
    strictly below ``max_dist`` are then discarded.

    Parameters
    ----------
    man_xy : array-like, shape (n_manual, n_dims)
        Coordinates of the manual detections.
    auto_xy : array-like, shape (n_auto, n_dims)
        Coordinates of the automatic detections.
    max_dist : float or None, optional
        Upper bound (exclusive) on the distance of an accepted match.
        ``None`` keeps every assigned pair.

    Returns
    -------
    ind_i : numpy.ndarray of int
        Row indices into ``man_xy`` of the matched detections.
    ind_j : numpy.ndarray of int
        Row indices into ``auto_xy``; ``ind_j[k]`` is matched to ``ind_i[k]``.
        Both arrays are empty when nothing matches.
    """
    # Imported locally so the function also works when this module is
    # imported (the file only imports numpy inside its ``__main__`` block).
    import numpy as np
    from scipy.optimize import linear_sum_assignment
    from scipy.spatial.distance import cdist

    man_xy = np.asarray(man_xy, dtype=float)
    auto_xy = np.asarray(auto_xy, dtype=float)

    # Nothing can be matched if either set is empty; this also covers
    # shape-(0, 0) placeholders for which a distance matrix is undefined.
    if man_xy.size == 0 or auto_xy.size == 0:
        empty = np.array([], dtype=int)
        return empty, empty.copy()

    # Euclidean distance matrix, shape (n_manual, n_auto).
    dists = cdist(man_xy, auto_xy)
    ind_i, ind_j = linear_sum_assignment(dists)

    if max_dist is not None:
        # Boolean masking yields empty arrays when no pair qualifies, where
        # stacking an empty list of indices would raise ValueError.
        keep = dists[ind_i, ind_j] < max_dist
        ind_i = ind_i[keep]
        ind_j = ind_j[keep]

    return ind_i, ind_j
def compute_matching_statistics(detections_manual, detections_auto, max_match_dist=10):
    """Count matched, missed and surplus detections for every image.

    For each image ``ii`` the manual detections ``detections_manual[ii]`` are
    matched against the automatic detections ``detections_auto[ii][0]`` with
    ``match_auto_man_detections``.

    Parameters
    ----------
    detections_manual : sequence
        One array of manual detection coordinates per image.
    detections_auto : sequence
        One entry per image; only element ``[0]`` of each entry is read and
        used as the automatic detection coordinates.
    max_match_dist : float, optional
        Maximum distance passed to ``match_auto_man_detections``.

    Returns
    -------
    numpy.ndarray, shape (n_images, 3)
        Rows of ``[n_intersect, n_missed, n_over]``: matched pairs, manual
        detections without a match, automatic detections without a match.
        An array of shape ``(0, 3)`` is returned for an empty input.
    """
    # Imported locally so the function also works when this module is
    # imported (the file only imports numpy inside its ``__main__`` block).
    import numpy as np

    overlap_stats = []

    for ii, manual_xy in enumerate(detections_manual):
        auto_xy = detections_auto[ii][0]
        matched_manual, matched_auto = match_auto_man_detections(
            manual_xy, auto_xy, max_match_dist)

        n_intersect = len(matched_manual)
        # Unmatched counts are taken over the detections of THIS image, not
        # over the number of images in the dataset.
        n_missed = len(np.setdiff1d(np.arange(len(manual_xy)), matched_manual))
        n_over = len(np.setdiff1d(np.arange(len(auto_xy)), matched_auto))

        overlap_stats.append([n_intersect, n_missed, n_over])

    if not overlap_stats:
        # np.vstack cannot stack an empty list.
        return np.zeros((0, 3), dtype=int)

    return np.vstack(overlap_stats)
if __name__ == "__main__":
    # Compute the detection accuracy of centrioles for the proposed algorithm.
    #
    # Loads the saved manual/automatic detections for three datasets
    # (early / mid / late), matches them per image, prints the file with the
    # best F-score plus the mean precision and recall over all images, and
    # draws one Venn diagram per dataset.

    # numpy must stay imported at this level: the functions defined above
    # may look up ``np`` as a module global.
    import numpy as np
    import pylab as plt
    # The project-local modules, os and tqdm below are only referenced by the
    # disabled generation pipeline kept at the end of this block.
    import file_io as fio
    import training_fn as training
    import image_fn
    import visualization as viz
    import scipy.io as spio
    import os
    from tqdm import tqdm

    def _safe_ratio(numerator, denominator):
        """Element-wise numerator / denominator, with 0.0 where the denominator is 0."""
        numerator = np.asarray(numerator, dtype=float)
        denominator = np.asarray(denominator, dtype=float)
        return np.divide(numerator, denominator,
                         out=np.zeros_like(numerator),
                         where=denominator > 0)

    # All detection files to test.
    inputfiles = ['/media/felix/Elements/Raff Lab/Centriole Distancing GitHub/detections-auto-vs-manual_Early-CV0.30.mat',
                  '/media/felix/Elements/Raff Lab/Centriole Distancing GitHub/detections-auto-vs-manual_Mid-CV0.30.mat',
                  '/media/felix/Elements/Raff Lab/Centriole Distancing GitHub/detections-auto-vs-manual_Late-CV0.30.mat']

    #==============================================================================
    #   Load detections
    #==============================================================================
    early_detections, mid_detections, late_detections = [
        spio.loadmat(path) for path in inputfiles]

    #==============================================================================
    #   Compute Statistics (TP/FP)
    #==============================================================================
    # presumably half of a 32-pixel patch — TODO confirm against the detector settings.
    max_match_dist = 32 // 2

    early_detect_stats, mid_detect_stats, late_detect_stats = [
        compute_matching_statistics(np.squeeze(detections['manual']),
                                    np.squeeze(detections['auto']),
                                    max_match_dist=max_match_dist)
        for detections in (early_detections, mid_detections, late_detections)]

    #==============================================================================
    #   Construct the TP/FP table for all (aggregated).
    #==============================================================================
    # Columns of the statistics are [n_intersect, n_missed, n_over],
    # i.e. TP, FN, FP.
    #   precision = TP/(TP+FP)
    #   recall = TP/(TP+FN)
    all_stats = np.vstack([early_detect_stats, mid_detect_stats, late_detect_stats])
    # NOTE(review): assumes this yields one entry per row of all_stats —
    # confirm against the shape loadmat returns for the saved 'files' list.
    all_files = np.hstack([early_detections['files'],
                           mid_detections['files'],
                           late_detections['files']])

    true_pos = all_stats[:, 0]
    false_neg = all_stats[:, 1]
    false_pos = all_stats[:, 2]

    # Undefined ratios (zero denominator) are scored 0.0 rather than NaN:
    # np.argmax treats NaN as the maximum, so an image without any match
    # would otherwise be reported as the best one, and a single NaN would
    # turn both means into NaN.
    all_precision = _safe_ratio(true_pos, true_pos + false_pos)
    all_recall = _safe_ratio(true_pos, true_pos + false_neg)
    all_F_score = _safe_ratio(2 * all_precision * all_recall,
                              all_precision + all_recall)

    print(all_files[np.argmax(all_F_score)])

    precision_mean = np.mean(all_precision)
    recall_mean = np.mean(all_recall)

    print('mean precision: %.3f' %precision_mean)
    print('mean recall: %.3f' %recall_mean)

    #==============================================================================
    #   Draw the Venn diagram.
    #==============================================================================
    from matplotlib_venn import venn2, venn2_circles

    for stage_name, stage_stats in (('early', early_detect_stats),
                                    ('mid', mid_detect_stats),
                                    ('late', late_detect_stats)):
        # venn2 expects (only in A, only in B, in both) = (missed, over, matched).
        subsets = (stage_stats[:, 1].sum(),
                   stage_stats[:, 2].sum(),
                   stage_stats[:, 0].sum())

        # Custom text labels: change the label of group A
        v = venn2(subsets=subsets, set_labels=('Group A', 'Group B'),
                  set_colors=('g', 'r'), alpha=0.4)
        v.get_label_by_id('A').set_text('Manual')
        v.get_label_by_id('B').set_text('Auto')
        v = venn2_circles(subsets, color='k', alpha=1.0,
                          linestyle='solid', linewidth=2.0)
        # plt.savefig('Venn-Centriole_detections-%s.svg' % stage_name, dpi=300)
        plt.show()

    #==============================================================================
    #   Disabled pipeline, kept for reference.  Judging by its final savemat
    #   call it produced the .mat files of the kind loaded above.  It is
    #   Python 2 code and relies on names (n_timepoints, n_slices, size,
    #   aniso_params, patch_size, CV_thresh) that are not defined in this file.
    #==============================================================================
#    infolder = '/media/felix/Elements/Raff Lab/Centriole Distancing/Sectioned Training/Early S-phase'
##    infolder = '/media/felix/Elements/Raff Lab/Centriole Distancing/Sectioned Training/Mid S-phase'
##    infolder = '/media/felix/Elements/Raff Lab/Centriole Distancing/Sectioned Training/Late S-phase'
#    intifs = fio.locate_centriole_files(infolder)
#
##==============================================================================
##   Load the corresponding img stack.
##==============================================================================
#
#    imgfolder = '/media/felix/Elements/Raff Lab/Centriole Distancing/Datasets/OM NM Tracking with Felix/cep97-gfp distancing images (for felix)/cycle 12 - cep97-gfp in wt'
#    imgtifs = fio.locate_centriole_files(imgfolder)
#
#    # retrieve the relevant tifs.
#    imgtifs_match = fio.retrieve_matching_imgs(intifs, imgtifs)
#
#
##==============================================================================
##   Iterate over the images and try to detect accurately each of the centriole positions and compare with the annotated ver.
##       highlight using circles matplotlib.
##==============================================================================
#    n_tifs = len(intifs)
#    print(n_tifs)
#
#
#    """
#    Iterate over the image set.
#    """
#    all_centriole_detections_manual = []
#    all_centriole_detections_auto = []
#
#
#    for i in range(n_tifs)[:]:
#
#        tif = intifs[i]
#        im_tif = imgtifs_match[i]
#
#        print '%d: processing %s' %(i+1, tif)
#        basename = tif.split('/')[-1].split('.tif')[0] # get the basename.
#
#        # read the time-slice stack
#        raw_stack_img = fio.read_stack_time_img(im_tif, n_timepoints=n_timepoints, n_slices=n_slices)
#        annot_stack_img = fio.read_stack_time_img(tif, n_timepoints=n_timepoints, n_slices=n_slices) # use only the first time point.
#
#
##==============================================================================
##       1. From Manual
##==============================================================================
#        annot_slice_img_blue = training.extract_dot_annotation_zstack(annot_stack_img[0], c=[0,0,255])
#        annot_slice_img_blue = np.max(annot_slice_img_blue, axis=0) # max projection image.
#        centriole_centroids_manual = image_fn.locate_centroids_simple(annot_slice_img_blue).astype(np.int)
#
##==============================================================================
##       2. From Automatic Centriole Detection
##==============================================================================
#
#        detect_dict = image_fn.detect_centrioles_full_slice(raw_stack_img, size, aniso_params, patch_size, CV_thresh, tslice=0, filter_border=True, separation=5, invert=False, minmass=10) # set tslice=-1 if just a single timepoint.
#
#        # parse the detection.
#        centriole_centroids_auto = detect_dict['centriole_centroids']
#        f = detect_dict['centriole_pos']
#        max_slice_im = detect_dict['max_proj_full_img']
#        max_slice_im_denoise = detect_dict['max_proj_full_img_denoise']
#        slice_bg_mask = detect_dict['background_mask']
#        slice_valid_detection_mask = detect_dict['valid_detection_mask']
#        centriole_SNR = detect_dict['centriole_SNR']
#
#        all_centriole_detections_manual.append(centriole_centroids_manual)
#        all_centriole_detections_auto.append([centriole_centroids_auto, centriole_SNR])
#
#
##==============================================================================
##       3. Visualize the Detections (Visual Comparison)
##==============================================================================
#
#        """
#        Visualise the detections.
#        """
#        fig, ax = plt.subplots(figsize=(10,10))
#        plt.imshow(max_slice_im, cmap='gray')
#        viz.draw_circles(centriole_centroids_manual, ax, radii=patch_size/2, col='g', lw=2)
#        viz.draw_circles(centriole_centroids_auto, ax, radii=patch_size*1, col='r', lw=2)
#        ax.grid('off')
#        ax.axis('off')
#        fig.savefig(os.path.join(infolder, basename + '_manual-auto-compare.svg'), dpi=300)
#        plt.show()
#        plt.close()
#
###==============================================================================
###       4. Save the detections per file in order to compute mAP and TP/FP statistics.
###==============================================================================
#
#    spio.savemat('detections-auto-vs-manual_Early-CV%.2f.mat' %(CV_thresh), {'files':imgtifs,
#                                                                            'manual':all_centriole_detections_manual,
#                                                                            'auto':all_centriole_detections_auto})