|
2 | 2 | import scipy.sparse as sparse |
3 | 3 | import h5py |
4 | 4 | from scipy import ndimage |
| 5 | +from scipy.optimize import linear_sum_assignment |
| 6 | +from collections import namedtuple |
| 7 | +from skimage.segmentation import relabel_sequential |
| 8 | + |
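| | +# registry of matching criteria: maps a criterion name ('iou', 'iot', 'iop', defined below) |
| | +# to a function that turns a label-overlap matrix into pairwise match scores |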
| 9 | +matching_criteria = dict() |
5 | 10 |
|
6 | 11 | __all__ = [ |
7 | 12 | 'get_binary_jaccard', |
| 13 | + 'adapted_rand', |
| 14 | + 'instance_matching' |
8 | 15 | ] |
9 | 16 |
|
10 | 17 |
|
@@ -478,3 +485,288 @@ def convert_dtype(data): |
478 | 485 | print("\tdistance to proposal : " + str(false_negative_stats)) |
479 | 486 |
|
480 | 487 | return false_positive_stats['mean'], false_negative_stats['mean'] |
| 488 | + |
| 489 | + |
| 490 | +# Code modified from https://github.com/stardist/stardist |
| 491 | + |
| 492 | +# Copied from https://github.com/CSBDeep/CSBDeep/blob/master/csbdeep/utils/utils.py |
| 493 | +def _raise(e): |
| 494 | + if isinstance(e, BaseException): |
| 495 | + raise e |
| 496 | + else: |
| 497 | + raise ValueError(e) |
| 498 | + |
| 499 | +def label_are_sequential(y): |
| 500 | +    """Return True if y contains only the sequential labels 1..N (background 0 is ignored).""" |
| 501 | + labels = np.unique(y) |
| 502 | + return (set(labels)-{0}) == set(range(1,1+labels.max())) |
| 503 | + |
| 504 | + |
| 505 | +def is_array_of_integers(y): |
| 506 | + return isinstance(y,np.ndarray) and np.issubdtype(y.dtype, np.integer) |
| 507 | + |
| 508 | + |
| 509 | +def _check_label_array(y, name=None, check_sequential=False): |
| 510 | + err = ValueError("{label} must be an array of {integers}.".format( |
| 511 | + label = 'labels' if name is None else name, |
| 512 | + integers = ('sequential ' if check_sequential else '') + 'non-negative integers', |
| 513 | + )) |
| 514 | + is_array_of_integers(y) or _raise(err) |
| 515 | + if len(y) == 0: |
| 516 | + return True |
| 517 | + if check_sequential: |
| 518 | + label_are_sequential(y) or _raise(err) |
| 519 | + else: |
| 520 | + y.min() >= 0 or _raise(err) |
| 521 | + return True |
| 522 | + |
| 523 | + |
| 524 | +def label_overlap(x, y, check=True): |
| 525 | + if check: |
| 526 | + _check_label_array(x,'x',True) |
| 527 | + _check_label_array(y,'y',True) |
| 528 | + x.shape == y.shape or _raise(ValueError("x and y must have the same shape")) |
| 529 | + return _label_overlap(x, y) |
| 530 | + |
| 531 | +def _label_overlap(x, y): |
| 532 | +    x = x.ravel() |
| 533 | +    y = y.ravel() |
| 534 | +    # joint histogram of label pairs: overlap[i,j] = number of pixels labeled i in x and j in y |
| 535 | +    overlap = np.zeros((1+x.max(),1+y.max()), dtype=np.uint) |
| 536 | +    np.add.at(overlap, (x, y), 1)  # unbuffered in-place add; equivalent to, but much faster than, a per-pixel Python loop |
| 537 | +    return overlap |
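| | +# Example (sketch): for label arrays x = [0, 1, 1] and y = [0, 1, 2], label_overlap(x, y) gives |
| | +#     [[1, 0, 0],     <- the single background pixel |
| | +#      [0, 1, 1]]     <- object 1 of x overlaps objects 1 and 2 of y by one pixel each |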
| 538 | + |
| 539 | +def _safe_divide(x,y, eps=1e-10): |
| 540 | +    """Divide x by y (elementwise for arrays), returning 0 wherever |y| is below eps.""" |
| 541 | + if np.isscalar(x) and np.isscalar(y): |
| 542 | + return x/y if np.abs(y)>eps else 0.0 |
| 543 | + else: |
| 544 | + out = np.zeros(np.broadcast(x,y).shape, np.float32) |
| 545 | + np.divide(x,y, out=out, where=np.abs(y)>eps) |
| 546 | + return out |
| 547 | + |
| 548 | + |
| 549 | +def intersection_over_union(overlap): |
| 550 | + _check_label_array(overlap,'overlap') |
| 551 | + if np.sum(overlap) == 0: |
| 552 | + return overlap |
| 553 | + n_pixels_pred = np.sum(overlap, axis=0, keepdims=True) |
| 554 | + n_pixels_true = np.sum(overlap, axis=1, keepdims=True) |
| 555 | + return _safe_divide(overlap, (n_pixels_pred + n_pixels_true - overlap)) |
| 556 | + |
| 557 | +matching_criteria['iou'] = intersection_over_union |
| 558 | + |
| 559 | + |
| 560 | +def intersection_over_true(overlap): |
| 561 | + _check_label_array(overlap,'overlap') |
| 562 | + if np.sum(overlap) == 0: |
| 563 | + return overlap |
| 564 | + n_pixels_true = np.sum(overlap, axis=1, keepdims=True) |
| 565 | + return _safe_divide(overlap, n_pixels_true) |
| 566 | + |
| 567 | +matching_criteria['iot'] = intersection_over_true |
| 568 | + |
| 569 | + |
| 570 | +def intersection_over_pred(overlap): |
| 571 | + _check_label_array(overlap,'overlap') |
| 572 | + if np.sum(overlap) == 0: |
| 573 | + return overlap |
| 574 | + n_pixels_pred = np.sum(overlap, axis=0, keepdims=True) |
| 575 | + return _safe_divide(overlap, n_pixels_pred) |
| 576 | + |
| 577 | +matching_criteria['iop'] = intersection_over_pred |
| 578 | + |
| 579 | + |
| 580 | +def precision(tp,fp,fn): |
| 581 | + return tp/(tp+fp) if tp > 0 else 0 |
| 582 | +def recall(tp,fp,fn): |
| 583 | + return tp/(tp+fn) if tp > 0 else 0 |
| 584 | +def accuracy(tp,fp,fn): |
| 585 | +    # called "average precision" in some instance segmentation challenges (differs from the PR-curve definition of AP) |
| 586 | +    # -> https://www.kaggle.com/c/data-science-bowl-2018#evaluation |
| 587 | + return tp/(tp+fp+fn) if tp > 0 else 0 |
| 588 | +def f1(tp,fp,fn): |
| 589 | + # also known as "dice coefficient" |
| 590 | + return (2*tp)/(2*tp+fp+fn) if tp > 0 else 0 |
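| | + |
| | +# Worked example: tp=8, fp=2, fn=4 gives precision = 8/10 = 0.8, recall = 8/12 ~ 0.667, |
| | +# accuracy = 8/14 ~ 0.571, and f1 = 16/22 ~ 0.727. |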
| 591 | + |
| 592 | + |
| 593 | +def instance_matching(y_true, y_pred, thresh=0.5, criterion='iou', report_matches=False): |
| 594 | + """Calculate detection/instance segmentation metrics between ground truth and predicted label images. |
| 595 | +
|
| 596 | + Currently, the following metrics are implemented: |
| 597 | +
|
| 598 | + 'fp', 'tp', 'fn', 'precision', 'recall', 'accuracy', 'f1', 'criterion', 'thresh', 'n_true', 'n_pred', 'mean_true_score', 'mean_matched_score', 'panoptic_quality' |
| 599 | +
|
| 600 | +    Corresponding objects of y_true and y_pred are counted as true positives (tp), false positives (fp), and false negatives (fn) |
| 601 | +    depending on whether their intersection over union (IoU) is >= thresh (for the default criterion='iou'; other criteria can be chosen) |
| 602 | +
|
| 603 | +    * mean_matched_score is the mean IoU of the matched true positives |
| 604 | +
|
| 605 | +    * mean_true_score is the sum of IoUs of the matched true positives, normalized by the total number of ground-truth objects |
| 606 | +
|
| 607 | +    * panoptic_quality is defined as in Eq. 1 of Kirillov et al., "Panoptic Segmentation", CVPR 2019 |
| 608 | +
|
| 609 | + Parameters |
| 610 | + ---------- |
| 611 | + y_true: ndarray |
| 612 | + ground truth label image (integer valued) |
| 613 | + y_pred: ndarray |
| 614 | + predicted label image (integer valued) |
| 615 | + thresh: float |
| 616 | +        threshold for the matching criterion; a scalar or a sequence of thresholds (default 0.5) |
| 617 | + criterion: string |
| 618 | + matching criterion (default IoU) |
| 619 | + report_matches: bool |
| 620 | +        if True, additionally compute matched_pairs, matched_scores, and matched_tps (note that this also returns gt-pred pairs whose scores are below 'thresh') |
| 621 | +
|
| 622 | + Returns |
| 623 | + ------- |
| 624 | +    Dictionary of the metrics listed above (a tuple of such dictionaries if 'thresh' is a sequence of thresholds) |
| 625 | +
|
| 626 | + Examples |
| 627 | + -------- |
| 628 | + >>> y_true = np.zeros((100,100), np.uint16) |
| 629 | + >>> y_true[10:20,10:20] = 1 |
| 630 | + >>> y_pred = np.roll(y_true,5,axis = 0) |
| 631 | +
|
| 632 | + >>> stats = instance_matching(y_true, y_pred) |
| 633 | + >>> print(stats) |
| 634 | +    {'criterion': 'iou', 'thresh': 0.5, 'fp': 1, 'tp': 0, 'fn': 1, 'precision': 0, 'recall': 0, 'accuracy': 0, 'f1': 0, 'n_true': 1, 'n_pred': 1, 'mean_true_score': 0.0, 'mean_matched_score': 0.0, 'panoptic_quality': 0.0} |
| 635 | +
|
| 636 | + """ |
| 637 | + _check_label_array(y_true,'y_true') |
| 638 | + _check_label_array(y_pred,'y_pred') |
| 639 | + y_true.shape == y_pred.shape or _raise(ValueError("y_true ({y_true.shape}) and y_pred ({y_pred.shape}) have different shapes".format(y_true=y_true, y_pred=y_pred))) |
| 640 | + criterion in matching_criteria or _raise(ValueError("Matching criterion '%s' not supported." % criterion)) |
| 641 | + if thresh is None: thresh = 0 |
| 642 | +    thresh = float(thresh) if np.isscalar(thresh) else tuple(map(float,thresh))  # materialize so the thresholds can be iterated more than once |
| 643 | + |
| 644 | + y_true, _, map_rev_true = relabel_sequential(y_true) |
| 645 | + y_pred, _, map_rev_pred = relabel_sequential(y_pred) |
| 646 | + |
| 647 | + map_rev_true = np.array(map_rev_true) |
| 648 | + map_rev_pred = np.array(map_rev_pred) |
| 649 | + |
| 650 | + overlap = label_overlap(y_true, y_pred, check=False) |
| 651 | + scores = matching_criteria[criterion](overlap) |
| 652 | + assert 0 <= np.min(scores) <= np.max(scores) <= 1 |
| 653 | + |
| 654 | + # ignoring background |
| 655 | + scores = scores[1:,1:] |
| 656 | + n_true, n_pred = scores.shape |
| 657 | + n_matched = min(n_true, n_pred) |
| 658 | + |
| 659 | + def _single(thr): |
| 660 | + not_trivial = n_matched > 0 and np.any(scores >= thr) |
| 661 | + if not_trivial: |
| 662 | + # compute optimal matching with scores as tie-breaker |
| 663 | + costs = -(scores >= thr).astype(float) - scores / (2*n_matched) |
| 664 | + true_ind, pred_ind = linear_sum_assignment(costs) |
| 665 | + assert n_matched == len(true_ind) == len(pred_ind) |
| 666 | + match_ok = scores[true_ind,pred_ind] >= thr |
| 667 | + tp = np.count_nonzero(match_ok) |
| 668 | + else: |
| 669 | + tp = 0 |
| 670 | + fp = n_pred - tp |
| 671 | + fn = n_true - tp |
| 672 | + # assert tp+fp == n_pred |
| 673 | + # assert tp+fn == n_true |
| 674 | + |
| 675 | + # the score sum over all matched objects (tp) |
| 676 | + sum_matched_score = np.sum(scores[true_ind,pred_ind][match_ok]) if not_trivial else 0.0 |
| 677 | + |
| 678 | + # the score average over all matched objects (tp) |
| 679 | + mean_matched_score = _safe_divide(sum_matched_score, tp) |
| 680 | + # the score average over all gt/true objects |
| 681 | + mean_true_score = _safe_divide(sum_matched_score, n_true) |
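| | +        # panoptic quality factorizes as mean_matched_score ("segmentation quality") * f1 ("recognition quality") |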
| 682 | + panoptic_quality = _safe_divide(sum_matched_score, tp+fp/2+fn/2) |
| 683 | + |
| 684 | + stats_dict = dict ( |
| 685 | + criterion = criterion, |
| 686 | + thresh = thr, |
| 687 | + fp = fp, |
| 688 | + tp = tp, |
| 689 | + fn = fn, |
| 690 | + precision = precision(tp,fp,fn), |
| 691 | + recall = recall(tp,fp,fn), |
| 692 | + accuracy = accuracy(tp,fp,fn), |
| 693 | + f1 = f1(tp,fp,fn), |
| 694 | + n_true = n_true, |
| 695 | + n_pred = n_pred, |
| 696 | + mean_true_score = mean_true_score, |
| 697 | + mean_matched_score = mean_matched_score, |
| 698 | + panoptic_quality = panoptic_quality, |
| 699 | + ) |
| 700 | + if bool(report_matches): |
| 701 | + if not_trivial: |
| 702 | + stats_dict.update ( |
| 703 | + # int() to be json serializable |
| 704 | + matched_pairs = tuple((int(map_rev_true[i]),int(map_rev_pred[j])) for i,j in zip(1+true_ind,1+pred_ind)), |
| 705 | + matched_scores = tuple(scores[true_ind,pred_ind]), |
| 706 | + matched_tps = tuple(map(int,np.flatnonzero(match_ok))), |
| 707 | + pred_ids = tuple(map_rev_pred), |
| 708 | + gt_ids = tuple(map_rev_true), |
| 709 | + ) |
| 710 | + else: |
| 711 | + stats_dict.update ( |
| 712 | + matched_pairs = (), |
| 713 | + matched_scores = (), |
| 714 | + matched_tps = (), |
| 715 | + pred_ids = (), |
| 716 | + gt_ids = (), |
| 717 | + ) |
| 718 | + return stats_dict |
| 719 | + |
| 720 | + return _single(thresh) if np.isscalar(thresh) else tuple(map(_single,thresh)) |
| 721 | + |
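| | +# Usage sketch (illustrative; 'gt_labels' and 'pred_labels' are hypothetical label arrays): |
| | +# passing a sequence of thresholds returns one stats dict per threshold, e.g. |
| | +# |
| | +#     for s in instance_matching(gt_labels, pred_labels, thresh=(0.5, 0.75, 0.9)): |
| | +#         print(s['thresh'], s['f1'], s['panoptic_quality']) |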
| 722 | + |
| 723 | +def wrapper_matching_dataset_lazy(stats_all, thresh, criterion='iou', by_image=False): |
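| | +    """Aggregate per-image 'instance_matching' results over a dataset, separately for each threshold. |
| | + |
| | +    'stats_all' is a sequence with one entry per image, each itself a sequence of |
| | +    stats dicts (one per threshold in 'thresh'). If by_image=True, the derived |
| | +    metrics are averaged over images; otherwise tp/fp/fn and matched scores are |
| | +    pooled over the whole dataset before the metrics are recomputed. Returns a |
| | +    'DatasetMatching' namedtuple per threshold (unwrapped if there is only one). |
| | +    """ |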
| 724 | + |
| 725 | +    expected_keys = set(('fp', 'tp', 'fn', 'precision', 'recall', 'accuracy', 'f1', 'criterion', 'thresh', 'n_true', 'n_pred', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'))  # keys each per-image stats dict is expected to contain |
| 726 | + |
| 727 | + # accumulate results over all images for each threshold separately |
| 728 | + n_images, n_threshs = len(stats_all), len(thresh) |
| 729 | +    single_thresh = (n_threshs == 1) |
| 730 | + accumulate = [{} for _ in range(n_threshs)] |
| 731 | + for stats in stats_all: |
| 732 | + for i, s in enumerate(stats): |
| 733 | + acc = accumulate[i] |
| 734 | +            for k, v in s.items(): |
| 736 | + if k == 'mean_true_score' and not bool(by_image): |
| 737 | + # convert mean_true_score to "sum_matched_score" |
| 738 | + acc[k] = acc.setdefault(k,0) + v * s['n_true'] |
| 739 | + else: |
| 740 | + try: |
| 741 | + acc[k] = acc.setdefault(k,0) + v |
| 742 | + except TypeError: |
| 743 | +                        pass  # non-numeric values (e.g. the 'criterion' string) are set explicitly below |
| 744 | + |
| 745 | + # normalize/compute 'precision', 'recall', 'accuracy', 'f1' |
| 746 | + for thr,acc in zip(thresh,accumulate): |
| 747 | + acc['criterion'] = criterion |
| 748 | + acc['thresh'] = thr |
| 749 | + acc['by_image'] = bool(by_image) |
| 750 | + if bool(by_image): |
| 751 | + for k in ('precision', 'recall', 'accuracy', 'f1', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'): |
| 752 | + acc[k] /= n_images |
| 753 | + else: |
| 754 | + tp, fp, fn, n_true = acc['tp'], acc['fp'], acc['fn'], acc['n_true'] |
| 755 | + sum_matched_score = acc['mean_true_score'] |
| 756 | + |
| 757 | + mean_matched_score = _safe_divide(sum_matched_score, tp) |
| 758 | + mean_true_score = _safe_divide(sum_matched_score, n_true) |
| 759 | + panoptic_quality = _safe_divide(sum_matched_score, tp+fp/2+fn/2) |
| 760 | + |
| 761 | + acc.update( |
| 762 | + precision = precision(tp,fp,fn), |
| 763 | + recall = recall(tp,fp,fn), |
| 764 | + accuracy = accuracy(tp,fp,fn), |
| 765 | + f1 = f1(tp,fp,fn), |
| 766 | + mean_true_score = mean_true_score, |
| 767 | + mean_matched_score = mean_matched_score, |
| 768 | + panoptic_quality = panoptic_quality, |
| 769 | + ) |
| 770 | + |
| 771 | + accumulate = tuple(namedtuple('DatasetMatching',acc.keys())(*acc.values()) for acc in accumulate) |
| 772 | + return accumulate[0] if single_thresh else accumulate |
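| | + |
| | + |
| | +# Usage sketch (illustrative; 'gt_images' and 'pred_images' are hypothetical lists of label arrays): |
| | +# |
| | +#     threshs = (0.5, 0.75) |
| | +#     stats_all = [instance_matching(gt, pred, thresh=threshs) |
| | +#                  for gt, pred in zip(gt_images, pred_images)] |
| | +#     dataset_stats = wrapper_matching_dataset_lazy(stats_all, threshs) |
| | +#     print(dataset_stats[0].f1, dataset_stats[0].panoptic_quality) |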