Source code for wonambi.detect.agreement

"""Module for agreement and consensus analysis between raters"""

from numpy import (arange, argmax, asarray, concatenate, diff, invert, 
                   logical_and, maximum, mean, minimum, newaxis, ones, repeat, 
                   sum, vstack, where, zeros)

from .. import Graphoelement

[docs]class MatchedEvents: """Class for storing matched events and producing statistics. Parameters ---------- tp : ndarray true positives as boolean array of shape len(detection) x len(standard) fp : ndarray indices of false positives in detection fn : ndarray indices of false negatives in standard detection : list of dict list of detected events tested against the standard, with 'start', 'end' and 'chan' standard : list of dict list of ground-truth events, with 'start', 'end' and 'chan' threshold : float minimum intersection-union score for events to be considered overlapping """ def __init__(self, tp, fp, fn, detection, standard, threshold): self.tp = tp self.fp = fp self.fn = fn self.detection = detection self.standard = standard self.threshold = threshold self.n_tp = sum(tp) self.n_fp = len(fp) self.n_fn = len(fn) @property def recall(self): tp = self.n_tp fn = self.n_fn if tp + fn == 0: return 0 return tp / (tp + fn) @property def precision(self): tp = self.n_tp fp = self.n_fp if tp + fp == 0: return 0 return tp / (tp + fp) @property def f1score(self): recall = self.recall precision = self.precision if precision + recall == 0: return 0 return 2 * precision * recall / (precision + recall)
[docs] def to_annot(self, annot, category, name, s_freq=512): """Write matched events to Wonambi XML file for visualization. Parameters ---------- annot : instance of Annotations Annotations file category : str 'tp_cons', 'tp_det', 'tp_std', 'fp' or 'fn' name : str name for the event type s_freq : int sampling frequency, in Hz, only required for 'tp_cons' category """ if 'tp_cons' == category: cons = consensus((self.detection, self.standard), 1, s_freq) events = cons.events elif 'tp_det' == category: events = asarray(self.detection)[self.tp.any(axis=1)] elif 'tp_std' == category: events = asarray(self.standard)[self.tp.any(axis=0)] elif 'fp' == category: events = asarray(self.detection)[self.fp] elif 'fn' == category: events = asarray(self.standard)[self.fn] else: raise ValueError("Invalid category.") annot.add_events(events, name=name, chan=None)
[docs] def all_to_annot(self, annot, names=['TPd', 'TPs', 'FP', 'FN']): """Convenience function to write all events to XML by category, showing overlapping TP detection and TP standard.""" self.to_annot(annot, 'tp_det', names[0]) self.to_annot(annot, 'tp_std', names[1]) self.to_annot(annot, 'fp', names[2]) self.to_annot(annot, 'fn', names[3])
[docs]def consensus(events, threshold, s_freq, min_duration=None, weights=None): """Take two or more event lists and output a merged list based on consensus. Parameters ---------- events: tuple of lists of dict two or more lists of events from different raters, with 'start', 'end' and 'chan' threshold : float value between 0 and 1 to threshold consensus. Consensus is computed on a per-sample basis. For a given rater, if an event is present at a sample, that rater-sample is assigned the value 1; otherwise it is assigned 0. The arithmetic mean is taken per sample across all raters, and if this mean exceeds 'threshold', the sample is counted as belonging to a merged event. s_freq : int sampling frequency, in Hz min_duration : float, optional minimum duration for merged events, in s. weights : list of float a vector of relative weights of each event type Returns ------- instance of wonambi.Graphoelement events merged by consensus """ chan = [one_rater[0]['chan'] for one_rater in events if one_rater][0] beg = min([one_rater[0]['start'] for one_rater in events if one_rater]) end = max([one_rater[-1]['end'] for one_rater in events if one_rater]) n_samples = int((end - beg) * s_freq) times = arange(beg, end + 1/s_freq, 1/s_freq) if weights is None: weights = ones(len(events)) positives = zeros((len(events), n_samples)) for i, (one_rater, wt) in enumerate(zip(events, weights)): for ev in one_rater: n_start = int((ev['start'] - beg) * s_freq) n_end = int((ev['end'] - beg) * s_freq) positives[i, n_start:n_end].fill(wt) consensus = mean(positives, axis=0) consensus[consensus >= threshold] = 1 consensus[consensus < 1] = 0 consensus = concatenate(([0], consensus, [0])) on_off = diff(consensus) onsets = where(on_off == 1) offsets = where(on_off == -1) start_times = times[onsets] end_times = times[offsets] merged = vstack((start_times, end_times)) if min_duration: merged = merged[:, merged[1, :] - merged[0, :] >= min_duration] out = Graphoelement() out.events = [{'start': merged[0, i], 'end': merged[1, i], 'chan': chan} for i in range(merged.shape[1])] return out
[docs]def consensus_exact(events, threshold, s_freq, window=None, min_duration=None, weights=None): """Take two or more event lists and output a merged list based on consensus, where agreement is exactly equal to a threshold. This is useful when combining >2 event types, and creating a consensus event type based on some combination of these events. Parameters ---------- events: tuple of lists of dict two or more lists of events from different raters, with 'start', 'end' and 'chan' threshold : float value between 0 and 1 to threshold consensus. Consensus is computed on a per-sample basis. For a given rater, if an event is present at a sample, that rater-sample is assigned the value 1; otherwise it is assigned 0. The arithmetic sum is taken per sample across all raters, and if this exactly equals 'threshold', the sample is counted as belonging to a merged event. s_freq : int sampling frequency, in Hz min_duration : float, optional minimum duration for merged events, in s. weights : a dict containing event names (str) and their corresponding weighting (int) e.g. {'low' : 1,'med' : 2,'high' : 3} Returns ------- instance of wonambi.Graphoelement events merged by consensus and named by confidence rating Notes ----- This function is a modification of agreement.consensus contributed by Nathan Cross. """ chan = [one_rater[0]['chan'] for one_rater in events if one_rater][0] if window is None: beg = min([one_rater[0]['start'] for one_rater in events if one_rater]) end = max([one_rater[-1]['end'] for one_rater in events if one_rater]) else: beg = window[0] end = window[1] n_samples = int((end - beg) * s_freq) times = arange(beg, end + 1/s_freq, 1/s_freq) if weights is None: weights = {'low':1,'med':2,'high':3} positives = zeros((len(events), n_samples)) for i, one_rater in enumerate(events): for ev in one_rater: n_start = int((ev['start'] - beg) * s_freq) n_end = int((ev['end'] - beg) * s_freq) if ev['name'] == 'low': positives[i, n_start:n_end].fill(weights['low']) elif ev['name'] == 'med': positives[i, n_start:n_end].fill(weights['med']) elif ev['name'] == 'high': positives[i, n_start:n_end].fill(weights['high']) consensus = sum(positives, axis=0) consensus[consensus != threshold] = 0 consensus[consensus == threshold] = 1 consensus = concatenate(([0], consensus, [0])) on_off = diff(consensus) onsets = where(on_off == 1) offsets = where(on_off == -1) start_times = times[onsets] end_times = times[offsets] merged = vstack((start_times, end_times)) if min_duration: merged = merged[:, merged[1, :] - merged[0, :] >= min_duration] out = Graphoelement() out.events = [{'start': merged[0, i], 'end': merged[1, i], 'chan': chan} for i in range(merged.shape[1])] return out
[docs]def match_events(detection, standard, threshold): """Find best matches between detected and standard events, by a thresholded intersection-union rule. Parameters ---------- detection : list of dict list of detected events to be tested against the standard, with 'start', 'end' and 'chan' standard : list of dict list of ground-truth events, with 'start', 'end' and 'chan' threshold : float minimum intersection-union score to match a pair, between 0 and 1 Returns ------- instance of MatchedEvents indices of true positives, false positives and false negatives, with statistics (recall, precision, F1) """ # Vectorize start and end times and set up for broadcasting det_beg = asarray([x['start'] for x in detection])[:, newaxis] det_end = asarray([x['end'] for x in detection])[:, newaxis] std_beg = asarray([x['start'] for x in standard])[newaxis, :] std_end = asarray([x['end'] for x in standard])[newaxis, :] # Get durations and broadcast them det_dur = repeat(det_end - det_beg, len(standard), axis=1) std_dur = repeat(std_end - std_beg, len(detection), axis=0) # Subtract every end by every start and find overlaps det_minus_std = det_end - std_beg # array of shape (len(det), len(std)) std_minus_det = std_end - det_beg overlapping = logical_and(det_minus_std > 0, std_minus_det > 0) # Find intersection and union shorter_diff = minimum(det_minus_std, std_minus_det) longer_diff = maximum(det_minus_std, std_minus_det) shorter_dur = minimum(det_dur, std_dur) longer_dur = maximum(det_dur, std_dur) interx = minimum(shorter_diff, shorter_dur) union = maximum(longer_diff, longer_dur) # Compute intersection-union score and set non-overlapping pairs to 0 iu = interx / union iu[invert(overlapping)] = 0 # Threshold IU score to yield True Positive candidates iu[iu <= threshold] = 0 # If no events, tp and fp are empty, fn is all events if iu.size == 0: tp = fp = asarray([]) fn = arange(len(standard)) else: # Find partial matches, round 1 det_match1 = argmax(iu, axis=1) std_match1 = argmax(iu, axis=0) # Find full matches, round 1, then remove them from IU tp = zeros(iu.shape, dtype=bool) for i, j in enumerate(std_match1): if det_match1[j] == i: tp[j, i] = True iu[j, :].fill(0) iu[:, i].fill(0) # Round 2 det_match2 = argmax(iu, axis=1) std_match2 = argmax(iu, axis=0) for i, j in enumerate(std_match2): if det_match2[j] == i: tp[j, i] = True # Find false positives and false negatives fp = where(logical_and(det_match1 == 0, det_match2 == 0))[0] fn = where(logical_and(std_match1 == 0, std_match2 == 0))[0] # Store in MatchedEvents class, which computes statistics match = MatchedEvents(tp, fp, fn, detection, standard, threshold) return match