
Commit 91de668

Author: donglaiw
Commit message: add instance matching
Parent: 06a6a22

File tree: 1 file changed (+292 -0 lines)

connectomics/utils/evaluate.py

Lines changed: 292 additions & 0 deletions
@@ -2,9 +2,16 @@
import scipy.sparse as sparse
import h5py
from scipy import ndimage
from scipy.optimize import linear_sum_assignment
from collections import namedtuple
from skimage.segmentation import relabel_sequential

matching_criteria = dict()

__all__ = [
    'get_binary_jaccard',
    'adapted_rand',
    'instance_matching'
]

@@ -478,3 +485,288 @@ def convert_dtype(data):
    print("\tdistance to proposal : " + str(false_negative_stats))

    return false_positive_stats['mean'], false_negative_stats['mean']


# Code modified from https://github.com/stardist/stardist

# Copied from https://github.com/CSBDeep/CSBDeep/blob/master/csbdeep/utils/utils.py
def _raise(e):
    if isinstance(e, BaseException):
        raise e
    else:
        raise ValueError(e)

def label_are_sequential(y):
    """ returns true if y has only sequential labels from 1... """
    labels = np.unique(y)
    return (set(labels)-{0}) == set(range(1,1+labels.max()))


def is_array_of_integers(y):
    return isinstance(y,np.ndarray) and np.issubdtype(y.dtype, np.integer)


def _check_label_array(y, name=None, check_sequential=False):
    err = ValueError("{label} must be an array of {integers}.".format(
        label = 'labels' if name is None else name,
        integers = ('sequential ' if check_sequential else '') + 'non-negative integers',
    ))
    is_array_of_integers(y) or _raise(err)
    if len(y) == 0:
        return True
    if check_sequential:
        label_are_sequential(y) or _raise(err)
    else:
        y.min() >= 0 or _raise(err)
    return True


def label_overlap(x, y, check=True):
    if check:
        _check_label_array(x,'x',True)
        _check_label_array(y,'y',True)
        x.shape == y.shape or _raise(ValueError("x and y must have the same shape"))
    return _label_overlap(x, y)

def _label_overlap(x, y):
    # dense overlap (confusion) matrix: overlap[i, j] counts pixels with label i in x and label j in y
    x = x.ravel()
    y = y.ravel()
    overlap = np.zeros((1+x.max(),1+y.max()), dtype=np.uint)
    for i in range(len(x)):
        overlap[x[i],y[i]] += 1
    return overlap

def _safe_divide(x,y, eps=1e-10):
    """computes a safe divide which returns 0 if y is zero"""
    if np.isscalar(x) and np.isscalar(y):
        return x/y if np.abs(y)>eps else 0.0
    else:
        out = np.zeros(np.broadcast(x,y).shape, np.float32)
        np.divide(x,y, out=out, where=np.abs(y)>eps)
        return out


def intersection_over_union(overlap):
    _check_label_array(overlap,'overlap')
    if np.sum(overlap) == 0:
        return overlap
    n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)
    n_pixels_true = np.sum(overlap, axis=1, keepdims=True)
    return _safe_divide(overlap, (n_pixels_pred + n_pixels_true - overlap))

matching_criteria['iou'] = intersection_over_union


def intersection_over_true(overlap):
    _check_label_array(overlap,'overlap')
    if np.sum(overlap) == 0:
        return overlap
    n_pixels_true = np.sum(overlap, axis=1, keepdims=True)
    return _safe_divide(overlap, n_pixels_true)

matching_criteria['iot'] = intersection_over_true


def intersection_over_pred(overlap):
    _check_label_array(overlap,'overlap')
    if np.sum(overlap) == 0:
        return overlap
    n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)
    return _safe_divide(overlap, n_pixels_pred)

matching_criteria['iop'] = intersection_over_pred

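# Illustrative usage sketch for the criteria above (assumes `np` is numpy,
# imported at the top of this module); overlap[i, j] is the pixel overlap
# between true label i and predicted label j:
#
# >>> y_true = np.array([[1, 1, 0],
# ...                    [2, 2, 0]], np.uint16)
# >>> y_pred = np.array([[1, 0, 0],
# ...                    [2, 2, 2]], np.uint16)
# >>> overlap = label_overlap(y_true, y_pred)   # shape (3, 3): labels 0..2 vs 0..2
# >>> iou = matching_criteria['iou'](overlap)   # iou[1, 1] == 0.5, iou[2, 2] == 2/3
# >>> iot = matching_criteria['iot'](overlap)   # iot[2, 2] == 1.0 (pred 2 fully covers true 2)
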
def precision(tp,fp,fn):
    return tp/(tp+fp) if tp > 0 else 0
def recall(tp,fp,fn):
    return tp/(tp+fn) if tp > 0 else 0
def accuracy(tp,fp,fn):
    # also known as "average precision" (?)
    # -> https://www.kaggle.com/c/data-science-bowl-2018#evaluation
    return tp/(tp+fp+fn) if tp > 0 else 0
def f1(tp,fp,fn):
    # also known as "dice coefficient"
    return (2*tp)/(2*tp+fp+fn) if tp > 0 else 0


def instance_matching(y_true, y_pred, thresh=0.5, criterion='iou', report_matches=False):
    """Calculate detection/instance segmentation metrics between ground truth and predicted label images.

    Currently, the following metrics are implemented:

    'fp', 'tp', 'fn', 'precision', 'recall', 'accuracy', 'f1', 'criterion', 'thresh', 'n_true', 'n_pred', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'

    Corresponding objects of y_true and y_pred are counted as true positives (tp), false positives (fp), and false negatives (fn)
    depending on whether their intersection over union (IoU) is >= thresh (for criterion='iou'; other criteria can be chosen).

    * mean_matched_score is the mean score (e.g. IoU) of the matched true positives.

    * mean_true_score is the sum of scores of the matched true positives, normalized by the total number of ground-truth objects.

    * panoptic_quality is defined as in Eq. 1 of Kirillov et al., "Panoptic Segmentation", CVPR 2019.

    Parameters
    ----------
    y_true: ndarray
        ground truth label image (integer valued)
    y_pred: ndarray
        predicted label image (integer valued)
    thresh: float
        threshold for the matching criterion; a single float or an iterable of floats (default 0.5)
    criterion: string
        matching criterion (default 'iou')
    report_matches: bool
        if True, additionally calculate matched_pairs and matched_scores (note that this reports even gt-pred pairs whose scores are below 'thresh')

    Returns
    -------
    Dictionary with the different metrics as keys (a tuple of dictionaries if an iterable of thresholds is given)

    Examples
    --------
    >>> y_true = np.zeros((100,100), np.uint16)
    >>> y_true[10:20,10:20] = 1
    >>> y_pred = np.roll(y_true,5,axis = 0)

    >>> stats = instance_matching(y_true, y_pred)
    >>> print(stats)
    {'criterion': 'iou', 'thresh': 0.5, 'fp': 1, 'tp': 0, 'fn': 1, 'precision': 0, 'recall': 0, 'accuracy': 0, 'f1': 0, 'n_true': 1, 'n_pred': 1, 'mean_true_score': 0.0, 'mean_matched_score': 0.0, 'panoptic_quality': 0.0}

    """
    _check_label_array(y_true,'y_true')
    _check_label_array(y_pred,'y_pred')
    y_true.shape == y_pred.shape or _raise(ValueError("y_true ({y_true.shape}) and y_pred ({y_pred.shape}) have different shapes".format(y_true=y_true, y_pred=y_pred)))
    criterion in matching_criteria or _raise(ValueError("Matching criterion '%s' not supported." % criterion))
    if thresh is None: thresh = 0
    thresh = float(thresh) if np.isscalar(thresh) else map(float,thresh)

    y_true, _, map_rev_true = relabel_sequential(y_true)
    y_pred, _, map_rev_pred = relabel_sequential(y_pred)

    map_rev_true = np.array(map_rev_true)
    map_rev_pred = np.array(map_rev_pred)

    overlap = label_overlap(y_true, y_pred, check=False)
    scores = matching_criteria[criterion](overlap)
    assert 0 <= np.min(scores) <= np.max(scores) <= 1

    # ignoring background
    scores = scores[1:,1:]
    n_true, n_pred = scores.shape
    n_matched = min(n_true, n_pred)

    def _single(thr):
        not_trivial = n_matched > 0 and np.any(scores >= thr)
        if not_trivial:
            # compute optimal matching with scores as tie-breaker
            costs = -(scores >= thr).astype(float) - scores / (2*n_matched)
            true_ind, pred_ind = linear_sum_assignment(costs)
            assert n_matched == len(true_ind) == len(pred_ind)
            match_ok = scores[true_ind,pred_ind] >= thr
            tp = np.count_nonzero(match_ok)
        else:
            tp = 0
        fp = n_pred - tp
        fn = n_true - tp
        # assert tp+fp == n_pred
        # assert tp+fn == n_true

        # the score sum over all matched objects (tp)
        sum_matched_score = np.sum(scores[true_ind,pred_ind][match_ok]) if not_trivial else 0.0

        # the score average over all matched objects (tp)
        mean_matched_score = _safe_divide(sum_matched_score, tp)
        # the score average over all gt/true objects
        mean_true_score = _safe_divide(sum_matched_score, n_true)
        panoptic_quality = _safe_divide(sum_matched_score, tp+fp/2+fn/2)

        stats_dict = dict(
            criterion = criterion,
            thresh = thr,
            fp = fp,
            tp = tp,
            fn = fn,
            precision = precision(tp,fp,fn),
            recall = recall(tp,fp,fn),
            accuracy = accuracy(tp,fp,fn),
            f1 = f1(tp,fp,fn),
            n_true = n_true,
            n_pred = n_pred,
            mean_true_score = mean_true_score,
            mean_matched_score = mean_matched_score,
            panoptic_quality = panoptic_quality,
        )
        if bool(report_matches):
            if not_trivial:
                stats_dict.update(
                    # int() to be json serializable
                    matched_pairs = tuple((int(map_rev_true[i]),int(map_rev_pred[j])) for i,j in zip(1+true_ind,1+pred_ind)),
                    matched_scores = tuple(scores[true_ind,pred_ind]),
                    matched_tps = tuple(map(int,np.flatnonzero(match_ok))),
                    pred_ids = tuple(map_rev_pred),
                    gt_ids = tuple(map_rev_true),
                )
            else:
                stats_dict.update(
                    matched_pairs = (),
                    matched_scores = (),
                    matched_tps = (),
                    pred_ids = (),
                    gt_ids = (),
                )
        return stats_dict

    return _single(thresh) if np.isscalar(thresh) else tuple(map(_single,thresh))

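# A sketch of how the summary scores above relate (with S = sum of matched
# scores over the true positives), and of the multi-threshold call pattern:
#
#   mean_matched_score = S / tp
#   mean_true_score    = S / n_true
#   panoptic_quality   = S / (tp + fp/2 + fn/2)
#
# >>> stats = instance_matching(y_true, y_pred, thresh=[0.5, 0.75, 0.9])
# >>> [s['f1'] for s in stats]   # one stats dict per threshold
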
def wrapper_matching_dataset_lazy(stats_all, thresh, criterion='iou', by_image=False):

    expected_keys = set(('fp', 'tp', 'fn', 'precision', 'recall', 'accuracy', 'f1', 'criterion', 'thresh', 'n_true', 'n_pred', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'))

    # accumulate results over all images for each threshold separately
    n_images, n_threshs = len(stats_all), len(thresh)
    single_thresh = (n_threshs == 1)
    accumulate = [{} for _ in range(n_threshs)]
    for stats in stats_all:
        for i, s in enumerate(stats):
            acc = accumulate[i]
            for k, v in s.items():
                if k == 'mean_true_score' and not bool(by_image):
                    # convert mean_true_score to "sum_matched_score"
                    acc[k] = acc.setdefault(k,0) + v * s['n_true']
                else:
                    try:
                        acc[k] = acc.setdefault(k,0) + v
                    except TypeError:
                        pass

    # normalize/compute 'precision', 'recall', 'accuracy', 'f1'
    for thr,acc in zip(thresh,accumulate):
        acc['criterion'] = criterion
        acc['thresh'] = thr
        acc['by_image'] = bool(by_image)
        if bool(by_image):
            for k in ('precision', 'recall', 'accuracy', 'f1', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'):
                acc[k] /= n_images
        else:
            tp, fp, fn, n_true = acc['tp'], acc['fp'], acc['fn'], acc['n_true']
            # 'mean_true_score' holds the accumulated sum of matched scores at this point
            sum_matched_score = acc['mean_true_score']

            mean_matched_score = _safe_divide(sum_matched_score, tp)
            mean_true_score = _safe_divide(sum_matched_score, n_true)
            panoptic_quality = _safe_divide(sum_matched_score, tp+fp/2+fn/2)

            acc.update(
                precision = precision(tp,fp,fn),
                recall = recall(tp,fp,fn),
                accuracy = accuracy(tp,fp,fn),
                f1 = f1(tp,fp,fn),
                mean_true_score = mean_true_score,
                mean_matched_score = mean_matched_score,
                panoptic_quality = panoptic_quality,
            )

    accumulate = tuple(namedtuple('DatasetMatching',acc.keys())(*acc.values()) for acc in accumulate)
    return accumulate[0] if single_thresh else accumulate
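
# Hypothetical end-to-end sketch (names `gt_vols` and `pred_vols` are assumed
# lists of corresponding ground-truth and predicted label volumes): evaluate
# every image at several thresholds with instance_matching, then reduce to one
# DatasetMatching summary per threshold with wrapper_matching_dataset_lazy.
#
# >>> threshs = [0.5, 0.75]
# >>> stats_all = [instance_matching(t, p, thresh=threshs)
# ...              for t, p in zip(gt_vols, pred_vols)]
# >>> dataset_stats = wrapper_matching_dataset_lazy(stats_all, threshs)
# >>> dataset_stats[0].panoptic_quality   # aggregated panoptic quality at thresh=0.5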

0 commit comments