"""Module to detect slow waves.
"""
from logging import getLogger
from numpy import (argmin, concatenate, diff, hstack, logical_and, newaxis,
ones, percentile, sign, sum, vstack, where, zeros)
try:
from PyQt5.QtCore import Qt
from PyQt5.QtWidgets import QProgressDialog
except ImportError:
pass
from .spindle import (detect_events, transform_signal, within_duration,
remove_straddlers)
from ..graphoelement import SlowWaves
# Fallback upper bound (in seconds) on the slow wave duration, used when the
# maximum duration in ``opts.duration`` is None.
MAXIMUM_DURATION = 5

# module-level logger
lg = getLogger(__name__)
class DetectSlowWave:
    """Design slow wave detection on a single channel.

    Parameters
    ----------
    method : str
        one of the predefined methods: 'Massimini2004', 'AASM/Massimini2004',
        'Ngo2015' or 'Staresina2015'
    duration : tuple of float
        min and max duration of SWs. If None, the default (min_dur, max_dur)
        of the chosen method is used.

    Attributes
    ----------
    invert : bool
        if True, the detection functions flip the sign of the data before
        detecting slow waves. False by default.
    trough_duration : tuple of float or None
        min and max duration of the first half-wave (trough). Set only for
        the Massimini2004-based methods; None for the other methods.
    det_filt : dict
        filter parameters for the Massimini2004-based methods; for Ngo2015
        and Staresina2015 it only holds the frequency band shown by __repr__
    duration : tuple of float
        min and max duration of SWs used during detection

    Raises
    ------
    ValueError
        if method is not one of the predefined methods
    """
    def __init__(self, method='Massimini2004', duration=None):

        self.method = method
        self.trough_duration = None
        self.invert = False

        if method == 'Massimini2004':
            self.det_filt = {'order': 2,
                             'freq': (0.1, 4.)}
            self.trough_duration = (0.3, 1.)
            self.max_trough_amp = -80
            self.min_ptp = 140
            self.min_dur = 0
            self.max_dur = None

        elif method == 'AASM/Massimini2004':
            self.det_filt = {'order': 2,
                             'freq': (0.1, 4.)}
            self.trough_duration = (0.25, 1.)
            self.max_trough_amp = -40
            self.min_ptp = 75
            self.min_dur = 0
            self.max_dur = None

        elif method == 'Ngo2015':
            self.lowpass = {'order': 2,
                            'freq': 3.5}
            self.min_dur = 0.833
            self.max_dur = 2.0
            self.peak_thresh = 1.25
            self.ptp_thresh = 1.25
            self.det_filt = {'freq': (0.5, 1.20)}  # for repr

        elif method == 'Staresina2015':
            self.lowpass = {'order': 3,
                            'freq': 1.25}
            self.min_dur = 0.8
            self.max_dur = 2.0
            self.ptp_thresh = 75
            self.det_filt = {'freq': (0.5, 1.25)}  # for repr

        else:
            raise ValueError('Unknown method')

        if duration is None:
            self.duration = (self.min_dur, self.max_dur)
        else:
            self.duration = duration

    def __repr__(self):
        return ('detsw_{0}_{1:04.2f}-{2:04.2f}Hz'
                ''.format(self.method, *self.det_filt['freq']))

    def __call__(self, data, parent=None):
        """Detect slow waves on the data.

        Parameters
        ----------
        data : instance of Data
            data used for detection
        parent : QWidget
            for use with GUI, as parent widget for the progress bar

        Returns
        -------
        instance of graphoelement.SlowWaves
            description of the detected SWs, sorted by start time; None if
            the user cancels the progress dialog
        """
        if parent is not None:
            progress = QProgressDialog('Finding slow waves', 'Abort',
                                       0, data.number_of('chan')[0], parent)
            progress.setWindowModality(Qt.ApplicationModal)

        chans = data.axis['chan'][0]
        slowwave = SlowWaves()
        slowwave.chan_name = chans

        # the time axis does not depend on the channel: build it only once
        time = hstack(data.axis['time'])

        all_slowwaves = []
        for i, chan in enumerate(chans):
            lg.info('Detecting slow waves on chan %s', chan)
            dat_orig = hstack(data(chan=chan))
            dat_orig = dat_orig - dat_orig.mean()  # demean

            if 'Massimini2004' in self.method:
                sw_in_chan = detect_Massimini2004(dat_orig, data.s_freq, time,
                                                  self)
            elif 'Ngo2015' == self.method:
                sw_in_chan = detect_Ngo2015(dat_orig, data.s_freq, time, self)
            elif 'Staresina2015' == self.method:
                sw_in_chan = detect_Staresina2015(dat_orig, data.s_freq, time,
                                                  self)
            else:
                raise ValueError('Unknown method')

            for sw in sw_in_chan:
                sw['chan'] = chan
            all_slowwaves.extend(sw_in_chan)

            if parent is not None:
                progress.setValue(i)
                if progress.wasCanceled():
                    return
            # end of loop over chan

        slowwave.events = sorted(all_slowwaves, key=lambda x: x['start'])

        if parent is not None:
            # len(chans) == i + 1 after the loop, but is also defined when
            # there are no channels at all
            progress.setValue(len(chans))

        return slowwave
def detect_Massimini2004(dat_orig, s_freq, time, opts):
    """Slow wave detection based on Massimini et al., 2004.

    Parameters
    ----------
    dat_orig : ndarray (dtype='float')
        vector with the data for one channel
    s_freq : float
        sampling frequency
    time : ndarray (dtype='float')
        vector with the time points for each sample
    opts : instance of 'DetectSlowWave'
        'invert' : bool
            if True, the sign of the data is flipped before detection
        'det_filt' : dict
            parameters for 'double_butter'
        'duration' : tuple of float
            min and max duration of SW
        'min_ptp' : float
            min peak-to-peak amplitude
        'trough_duration' : tuple of float
            min and max duration of first half-wave (trough)
        'max_trough_amp' : float
            The trough amplitude has a negative value, so this parameter sets
            the minimum depth of the trough (compared in absolute value)

    Returns
    -------
    list of dict
        list of detected SWs (empty list if none is found)

    References
    ----------
    Massimini, M. et al. J Neurosci 24(31) 6862-70 (2004).
    """
    if opts.invert:
        dat_orig = -dat_orig

    dat_det = transform_signal(dat_orig, s_freq, 'double_butter',
                               opts.det_filt)
    # candidate first half-waves: stretches where the filtered signal is > 0
    above_zero = detect_events(dat_det, 'above_thresh', value=0.)

    sw_in_chan = []
    if above_zero is not None:
        troughs = within_duration(above_zero, time, opts.trough_duration)

        if troughs is not None:
            # keep only half-waves that are deep enough
            troughs = select_peaks(dat_det, troughs, opts.max_trough_amp)

            if troughs is not None:
                # append the second half-wave; events below min_ptp are dropped
                events = _add_halfwave(dat_det, troughs, s_freq, opts)

                if len(events):
                    events = within_duration(events, time, opts.duration)
                    events = remove_straddlers(events, time, s_freq)

                sw_in_chan = make_slow_waves(events, dat_det, time, s_freq)

    if not sw_in_chan:
        lg.info('No slow wave found')

    return sw_in_chan
def detect_Ngo2015(dat_orig, s_freq, time, opts):
    """Slow wave detection based on Ngo et al., 2015.

    Parameters
    ----------
    dat_orig : ndarray (dtype='float')
        vector with the data for one channel
    s_freq : float
        sampling frequency
    time : ndarray (dtype='float')
        vector with the time points for each sample
    opts : instance of 'DetectSlowWave'
        'invert' : bool
            if True, the sign of the data is flipped before detection
        'lowpass' : dict
            parameters for 'low_butter'
        'duration' : tuple of float
            min and max duration of SW (interval between consecutive
            positive-to-negative zero crossings)
        'peak_thresh' : float
            the mean trough amplitude over all candidate events is multiplied
            by this scalar to yield a threshold; SWs whose trough is below
            (more negative than) this threshold are kept
        'ptp_thresh' : float
            the mean peak-to-peak amplitude over the remaining events is
            multiplied by this scalar to yield a threshold; SWs with a larger
            peak-to-peak amplitude are kept

    Returns
    -------
    list of dict
        list of detected SWs (empty list if none is found)

    References
    ----------
    Ngo, H-V. et al. J Neurosci 35(17) 6630-8 (2015).
    """
    if opts.invert:
        dat_orig = -dat_orig

    sw_in_chan = []

    # filter to SO band:
    dat_det = transform_signal(dat_orig, s_freq, 'low_butter', opts.lowpass)
    # detect positive-to-negative zero crossings:
    idx_zx = find_zero_crossings(dat_det, xtype='pos_to_neg')
    # find zero-crossing intervals within duration:
    events = find_intervals(idx_zx, s_freq, opts.duration)

    if events is not None:
        # find start, trough, -to+ zero crossing, peak and end:
        events = find_peaks_in_slowwwave(dat_det, events)

    # Both thresholds are means over the surviving events; skip each step when
    # no event is left (the mean of an empty array is NaN).
    if events is not None and len(events):
        # Trough threshold: peak_thresh times the mean trough amplitude
        trough_vals = dat_det[events[:, 1]]
        neg_peak_thresh = trough_vals.mean() * opts.peak_thresh
        events = events[trough_vals < neg_peak_thresh, :]

    if events is not None and len(events):
        # Peak-to-peak threshold: ptp_thresh times the mean ptp amplitude
        ptp = dat_det[events[:, 3]] - dat_det[events[:, 1]]
        ptp_thresh = ptp.mean() * opts.ptp_thresh
        events = events[ptp > ptp_thresh, :]

    if events is not None and len(events):
        events = remove_straddlers(events, time, s_freq)
        sw_in_chan = make_slow_waves(events, dat_det, time, s_freq)

    # log only when nothing was found (the condition used to be inverted)
    if not sw_in_chan:
        lg.info('No slow waves found')

    return sw_in_chan
def detect_Staresina2015(dat_orig, s_freq, time, opts):
    """Slow wave detection based on Staresina et al., 2015.

    Parameters
    ----------
    dat_orig : ndarray (dtype='float')
        vector with the data for one channel
    s_freq : float
        sampling frequency
    time : ndarray (dtype='float')
        vector with the time points for each sample
    opts : instance of 'DetectSlowWave'
        'invert' : bool
            if True, the sign of the data is flipped before detection
        'lowpass' : dict
            parameters for 'low_butter'
        'duration' : tuple of float
            min and max duration of SW (interval between consecutive
            positive-to-negative zero crossings)
        'ptp_thresh' : float
            percentile (0-100) of the peak-to-peak amplitudes of all
            candidate events; SWs at or above this percentile are kept

    Returns
    -------
    list of dict
        list of detected SWs (empty list if none is found)

    References
    ----------
    Staresina, B. et al. Nat Neurosci 18(11) 1679-86 (2015).
    """
    if opts.invert:
        dat_orig = -dat_orig

    sw_in_chan = []

    # filter to SO band:
    dat_det = transform_signal(dat_orig, s_freq, 'low_butter', opts.lowpass)
    # positive-to-negative zero crossings delimit the candidate SWs
    idx_zx = find_zero_crossings(dat_det, xtype='pos_to_neg')
    events = find_intervals(idx_zx, s_freq, opts.duration)

    if events is not None:
        # find start, trough, -to+ zero crossing, peak and end:
        events = find_peaks_in_slowwwave(dat_det, events)

    # the percentile threshold needs at least one candidate event
    if events is not None and len(events):
        # Peak-to-peak amplitude threshold
        ptp = dat_det[events[:, 3]] - dat_det[events[:, 1]]
        ptp_thresh = percentile(ptp, opts.ptp_thresh)
        events = events[ptp >= ptp_thresh, :]

        events = remove_straddlers(events, time, s_freq)
        sw_in_chan = make_slow_waves(events, dat_det, time, s_freq)

    # log only when nothing was found (the condition used to be inverted)
    if not sw_in_chan:
        lg.info('No slow waves found')

    return sw_in_chan
def select_peaks(data, events, limit):
    """Keep the events whose peak/trough amplitude reaches a limit.

    The comparison is done on absolute values, so the sign of both the data
    at the peak/trough and of the limit is ignored.

    Parameters
    ----------
    data : ndarray (dtype='float')
        vector with data
    events : ndarray (dtype='int')
        N x 2+ matrix with peak/trough in second position
    limit : float
        minimum amplitude of the peak/trough (compared in absolute value)

    Returns
    -------
    ndarray (dtype='int')
        rows of events for which |data[peak/trough]| >= |limit|
    """
    selected = abs(data[events[:, 1]]) >= abs(limit)
    return events[selected, :]
def make_slow_waves(events, data, time, s_freq):
    """Create dict for each slow wave, based on events of time points.

    Parameters
    ----------
    events : ndarray (dtype='int')
        N x 5 matrix with start, trough, zero, peak, end samples (the end
        sample is exclusive)
    data : ndarray (dtype='float')
        vector with the data
    time : ndarray (dtype='float')
        vector with time points
    s_freq : float
        sampling frequency

    Returns
    -------
    list of dict
        list of all the SWs, with keys 'start', 'trough_time', 'zero_time',
        'peak_time', 'end' (time of the last sample), 'trough_val',
        'peak_val', 'dur' (s) and 'ptp' (peak-to-peak amplitude, in signal
        units)
    """
    slow_waves = []
    for ev in events:
        trough_val = data[ev[1]]
        peak_val = data[ev[3]]
        slow_waves.append({
            'start': time[ev[0]],
            'trough_time': time[ev[1]],
            'zero_time': time[ev[2]],
            'peak_time': time[ev[3]],
            'end': time[ev[4] - 1],
            'trough_val': trough_val,
            'peak_val': peak_val,
            'dur': (ev[4] - ev[0]) / s_freq,
            # amplitude difference, not the distance between sample indices
            'ptp': abs(peak_val - trough_val),
        })

    return slow_waves
def _add_halfwave(data, events, s_freq, opts):
    """Find the next zero crossing and the intervening peak and add
    them to events. If no zero found before max_dur, event is discarded. If
    peak-to-peak is smaller than min_ptp, the event is discarded.

    Parameters
    ----------
    data : ndarray (dtype='float')
        vector with the data
    events : ndarray (dtype='int')
        N x 3 matrix with start, trough, end samples
    s_freq : float
        sampling frequency
    opts : instance of 'DetectSlowWave'
        'duration' : tuple of float
            min and max duration of SW; if max is None, MAXIMUM_DURATION is
            used to limit the search for the next zero crossing
        'min_ptp' : float
            min peak-to-peak amplitude

    Returns
    -------
    ndarray (dtype='int')
        N x 5 matrix with start, trough, - to + zero crossing, peak,
        and end samples
    """
    max_dur = opts.duration[1]
    if max_dur is None:
        max_dur = MAXIMUM_DURATION

    window = int(s_freq * max_dur)
    peak_and_end = zeros((events.shape[0], 2), dtype='int')
    events = concatenate((events, peak_and_end), axis=1)
    selected = zeros(events.shape[0], dtype=bool)

    for i, ev in enumerate(events):  # ev is a view: writes go into events
        # next zero crossing, searched up to max_dur after the event start
        zero_crossings = where(diff(sign(data[ev[2]:ev[0] + window])))[0]
        # test the size, not .any(): a crossing at offset 0 is a valid index
        if zero_crossings.size == 0:
            continue

        ev[4] = ev[2] + zero_crossings[0] + 1
        ev[3] = ev[2] + argmin(data[ev[2]:ev[4]])

        if abs(data[ev[1]] - data[ev[3]]) < opts.min_ptp:
            continue

        selected[i] = True

    return events[selected, :]
def find_zero_crossings(data, xtype='all'):
    """Return the sample indices at which the data changes sign.

    Parameters
    ----------
    data : ndarray (dtype='float')
        vector with the data
    xtype : str
        'all' for every zero crossing, 'neg_to_pos' for negative-to-positive
        crossings only, 'pos_to_neg' for positive-to-negative crossings only

    Returns
    -------
    ndarray of int
        index of the last sample before each selected zero crossing

    Raises
    ------
    ValueError
        if xtype is not one of the three accepted values

    Note
    ----
    A value of exactly 0 in data always creates a zero crossing with the
    nonzero values preceding or following it.
    """
    if xtype not in ('all', 'neg_to_pos', 'pos_to_neg'):
        raise ValueError(
            "Invalid xtype. Choose 'all', 'neg_to_pos' or 'pos_to_neg'.")

    # which sign changes to keep, for each crossing type
    keep_step = {
        'all': lambda step: step != 0,
        'neg_to_pos': lambda step: step > 0,
        'pos_to_neg': lambda step: step < 0,
    }[xtype]

    sign_steps = diff(sign(data))
    return where(keep_step(sign_steps))[0]
def find_intervals(indices, s_freq, duration):
    """From sample indices, find intervals within a certain duration.

    Parameters
    ----------
    indices : ndarray (dtype='int')
        vector with the indices
    s_freq : float
        sampling frequency of indices/data
    duration: tuple of float
        min (inclusive) and max (exclusive) duration (s) of intervals. Either
        bound may be None, in which case that side is not limited.

    Returns
    -------
    ndarray (dtype='int') or None
        N x 2 matrix with start and end samples of consecutive indices whose
        distance lies within duration; None if there is no such interval
    """
    min_dur = duration[0]
    max_dur = duration[1]

    intervals = diff(indices) / s_freq

    in_range = ones(intervals.shape, dtype=bool)
    if min_dur is not None:
        in_range = logical_and(in_range, intervals >= min_dur)
    if max_dur is not None:
        in_range = logical_and(in_range, intervals < max_dur)

    idx_event_starts = where(in_range)[0]
    if not len(idx_event_starts):
        return None

    idx_event_ends = idx_event_starts + 1
    return vstack((indices[idx_event_starts],
                   indices[idx_event_ends]
                   )).T
def find_peaks_in_slowwwave(data, events):
    """Locate trough, rising zero crossing and peak inside each interval.

    Parameters
    ----------
    data : ndarray (dtype='float')
        vector with the data
    events : ndarray (dtype='int')
        N x 2 matrix with start, end samples

    Returns
    -------
    ndarray (dtype='int')
        N x 5 matrix with start, trough, - to + zero crossing, peak,
        and end samples. Intervals without a - to + zero crossing are
        dropped.
    """
    n_events = events.shape[0]
    placeholders = zeros((n_events, 3), dtype='int64')
    out = concatenate((events[:, :1], placeholders, events[:, 1:2]), axis=1)
    keep = ones(n_events, dtype='bool')

    for k, ev in enumerate(events):
        first = ev[0]
        segment = data[first:ev[1]]

        out[k, 1] = first + segment.argmin()  # trough

        rising = where(diff(sign(segment)) > 0)[0]
        if not len(rising):
            # no - to + zero crossing in this interval
            keep[k] = False
            continue

        out[k, 2] = first + rising[0]  # -to+
        out[k, 3] = first + segment.argmax()  # peak

    return out[keep, :]