"""Module to select periods of interest, based on number of trials or any of
the axes.
There is some overlap between Select and the Data.__call__(). The main
difference is that Select takes an instance of Data as input and returns
another instance of Data as output, while Data.__call__() returns the actual
content of the data.
Select should be as flexible as possible. There are quite a few cases, which
will be added as we need them.
"""
from collections.abc import Iterable
from logging import getLogger
from numpy import (arange, array, asarray, diff, empty, hstack, inf,
issubsctype, linspace, nan_to_num, ndarray, ones, ravel,
setdiff1d, floor)
from numpy.lib.stride_tricks import as_strided
from math import isclose
from scipy.signal import resample as sci_resample
try:
from PyQt5.QtCore import Qt
from PyQt5.QtWidgets import QProgressDialog
except ImportError:
Qt = None
QProgressDialog = None
from .. import ChanTime
from .montage import montage
from .reject import remove_artf_evts
lg = getLogger(__name__)
class Segments():
    """Class containing a set of data segments for analysis, with metadata.

    Only contains metadata until .read_data is called.

    Attributes
    ----------
    dataset : instance of wonambi.dataset
        metadata for the associated record
    segments : list of dict
        chronological list of segment metadata. Each segment dict contains
        info about start and end times, stage, cycle, channel and event name,
        if applicable. Once read_data is called, the signal data are added to
        each segment dictionary under 'data'.
    """
    def __init__(self, dataset):
        self.dataset = dataset
        self.segments = []

    def __iter__(self):
        for one_event in self.segments:
            yield one_event

    def __len__(self):
        return len(self.segments)

    def __getitem__(self, index):
        return self.segments[index]

    def read_data(self, chan=None, ref_chan=None, grp_name=None,
                  concat_chan=False, average_channels=False, max_s_freq=30000,
                  parent=None):
        """Read data for analysis. Adds data as 'data' in each dict.

        Parameters
        ----------
        chan : list of str
            active channel names as they appear in record, without ref or
            group. If None or an empty list, the channel specified in
            seg['chan'] will be read for each segment
        ref_chan : list of str
            reference channel names as they appear in record, without group
        grp_name : str
            name of the channel group, required in GUI
        concat_chan : bool
            if True, data from all channels will be concatenated
        average_channels : bool
            if True, all channels will be averaged into a single virtual
            channel with label 'avg_chan'
        max_s_freq : int
            maximum sampling frequency
        parent : QWidget
            for GUI only. Identifies parent widget for display of progress
            dialog.
        """
        # avoid mutable default arguments; None behaves like the old []
        if chan is None:
            chan = []
        if ref_chan is None:
            ref_chan = []
        output = []

        # Set up progress bar (GUI only)
        if parent:
            n_subseg = sum([len(x['times']) for x in self.segments])
            progress = QProgressDialog('Fetching signal', 'Abort', 0,
                                       n_subseg, parent)
            progress.setWindowModality(Qt.ApplicationModal)
            counter = 0

        # Begin bundle loop; will yield one segment per loop
        for seg in self.segments:
            one_segment = ChanTime()
            one_segment.axis['chan'] = empty(1, dtype='O')
            one_segment.axis['time'] = empty(1, dtype='O')
            one_segment.data = empty(1, dtype='O')
            subseg = []

            # Subsegment loop; subsegments will be concatenated.
            # NOTE(review): assumes each segment has at least one (t0, t1)
            # interval; an empty 'times' list would fail below at subseg[0].
            for t0, t1 in seg['times']:
                if parent:
                    progress.setValue(counter)
                    counter += 1

                # if channel not specified, use segment channel
                if chan:
                    active_chan = chan
                elif seg['chan']:
                    active_chan = [seg['chan'].split(' (')[0]]
                else:
                    raise ValueError('No channel was specified and the '
                                     'segment at {}-{} has no channel.'.format(
                                             t0, t1))
                if isinstance(active_chan, str):
                    active_chan = [active_chan]
                chan_to_read = active_chan + ref_chan

                # read data from disk
                data = self.dataset.read_data(chan=chan_to_read, begtime=t0,
                                              endtime=t1)

                # Downsample if necessary
                if data.s_freq > max_s_freq:
                    q = int(data.s_freq / max_s_freq)
                    lg.debug('Decimate (no low-pass filter) at ' + str(q))
                    data.data[0] = data.data[0][:, slice(None, None, q)]
                    data.axis['time'][0] = data.axis['time'][0][slice(
                            None, None, q)]
                    data.s_freq = int(data.s_freq / q)

                # apply montage (re-referencing) and store the subsegment
                subseg.append(_create_data(
                        data, active_chan, ref_chan=ref_chan,
                        grp_name=grp_name))

            one_segment.s_freq = s_freq = data.s_freq
            one_segment.axis['chan'][0] = chs = subseg[0].axis['chan'][0]
            one_segment.axis['time'][0] = timeline = hstack(
                    [x.axis['time'][0] for x in subseg])
            one_segment.data[0] = empty((len(active_chan), len(timeline)),
                                        dtype='f')
            # count discontinuities (gaps > 2 samples) between subsegments
            n_stitch = sum(asarray(diff(timeline) > 2 / s_freq, dtype=bool))

            for i, ch in enumerate(subseg[0].axis['chan'][0]):
                one_segment.data[0][i, :] = hstack(
                        [x(chan=ch)[0] for x in subseg])

            # For channel averaging
            if average_channels:
                one_segment.data[0] = one_segment.data[0].mean(0,
                                                               keepdims=True)
                # dtype='U' lets numpy size the string itself; the previous
                # '<U2' truncated the label to 'av'
                one_segment.axis['chan'][0] = array(['avg_chan'], dtype='U')
                active_chan = ['avg_chan']
            # For channel concatenation
            elif concat_chan and len(chs) > 1:
                one_segment.data[0] = ravel(one_segment.data[0])
                one_segment.axis['chan'][0] = asarray([(', ').join(chs)],
                                                      dtype='U')
                # axis['time'] should not be used in this case

            output.append({'data': one_segment,
                           'chan': active_chan,
                           'stage': seg['stage'],
                           'cycle': seg['cycle'],
                           'name': seg['name'],
                           'n_stitch': n_stitch
                           })

            if parent:
                if progress.wasCanceled():
                    parent.parent.statusBar().showMessage('Process canceled by'
                                                          ' user.')
                    return

        if parent:
            progress.setValue(counter)

        self.segments = output

        return 1  # for GUI
def select(data, trial=None, invert=False, **axes_to_select):
    """Define the selection of trials, using ranges or actual values.

    Parameters
    ----------
    data : instance of Data
        data to select from.
    trial : list of int or ndarray (dtype='i'), optional
        index of trials of interest
    invert : bool
        take the opposite selection
    **axes_to_select, optional
        Values need to be tuple or list. If the values in one axis are string,
        then you need to specify all the strings that you want. If the values
        are numeric, then you should specify the range. To select only up to
        one point, you can use (None, value_of_interest). To select multiple
        values, you can pass a numpy array with dtype bool.

    Returns
    -------
    instance, same class as input
        data where selection has been applied.

    Raises
    ------
    TypeError
        if trial is not iterable, or an axis value is not a non-string
        iterable.
    """
    if trial is not None and not isinstance(trial, Iterable):
        raise TypeError('Trial needs to be iterable.')

    for axis_to_select, values_to_select in axes_to_select.items():
        if (not isinstance(values_to_select, Iterable) or
                isinstance(values_to_select, str)):
            raise TypeError(axis_to_select + ' needs to be iterable.')

    if trial is None:
        trial = range(data.number_of('trial'))

    if invert:
        trial = setdiff1d(range(data.number_of('trial')), trial)

    # create output with empty object arrays, one slot per selected trial
    output = data._copy(axis=False)
    for one_axis in output.axis:
        output.axis[one_axis] = empty(len(trial), dtype='O')
    output.data = empty(len(trial), dtype='O')

    to_select = {}
    for cnt, i in enumerate(trial):
        lg.debug('Selection on trial {0: 6}'.format(i))
        for one_axis in output.axis:
            values = data.axis[one_axis][i]

            if one_axis in axes_to_select.keys():
                values_to_select = axes_to_select[one_axis]

                if len(values_to_select) == 0:
                    selected_values = ()

                elif isinstance(values_to_select[0], str):
                    # string axis: keep exactly the named values
                    selected_values = asarray(values_to_select, dtype='U')

                else:
                    # numeric axis: either a boolean mask or a (low, high)
                    # half-open range where None means unbounded.
                    # numpy.issubsctype was removed in NumPy 2.0; compare
                    # the dtype directly instead.
                    if (isinstance(values_to_select, ndarray) and
                            values_to_select.dtype == bool):
                        bool_values = values_to_select
                    elif (values_to_select[0] is None and
                            values_to_select[1] is None):
                        bool_values = ones(len(values), dtype=bool)
                    elif values_to_select[0] is None:
                        bool_values = values < values_to_select[1]
                    elif values_to_select[1] is None:
                        bool_values = values_to_select[0] <= values
                    else:
                        bool_values = ((values_to_select[0] <= values) &
                                       (values < values_to_select[1]))
                    selected_values = values[bool_values]

                if invert:
                    selected_values = setdiff1d(values, selected_values)

                lg.debug('In axis {0}, selecting {1: 6} '
                         'values'.format(one_axis, len(selected_values)))
                to_select[one_axis] = selected_values

            else:
                lg.debug('In axis ' + one_axis + ', selecting all the '
                         'values')
                selected_values = data.axis[one_axis][i]

            output.axis[one_axis][cnt] = selected_values

        output.data[cnt] = data(trial=i, **to_select)

    return output
def resample(data, s_freq, axis='time'):
    """Resample the data to a new sampling frequency.

    Parameters
    ----------
    data : instance of Data
        data to downsample
    s_freq : int or float
        desired sampling frequency
    axis : str
        axis you want to apply downsample on (most likely 'time')

    Returns
    -------
    instance of Data
        downsampled data

    Notes
    -----
    Uses Fourier-domain resampling (scipy.signal.resample); no separate
    anti-aliasing filter is applied by this function.
    """
    output = data._copy()

    for i in range(data.number_of('trial')):
        # number of output samples, rounded down so that any fraction of an
        # output sample at the end is discarded (handles non-integer ratios
        # between the old and new sampling frequencies)
        out_samples = int(floor(data.number_of('time')[i] / data.s_freq
                                * s_freq))

        output.data[i] = sci_resample(
            data.data[i],
            out_samples,
            axis=data.index_of(axis))

        n_samples = output.data[i].shape[data.index_of(axis)]
        output.axis[axis][i] = linspace(data.axis[axis][i][0],
                                        data.axis[axis][i][-1]
                                        + 1 / data.s_freq,
                                        n_samples)

    output.s_freq = s_freq

    return output
def smart_chan(dataset, simple_chan_name, test_chan=None):
    """Map simple channel names to the labels actually used in the dataset.

    Parameters
    ----------
    dataset : instance of Dataset
        info about record
    simple_chan_name : list of str
        simple names for channels, e.g. ['F3', 'Fp2', 'ECG']
    test_chan : list of str, optional
        if given, match against this list of labels instead of the labels in
        the dataset header

    Returns
    -------
    list
        corresponding channel labels as they appear in the dataset, in the
        same order as simple_chan_name

    Raises
    ------
    ValueError
        when a simple name matches no label, or matches several labels
        ambiguously
    """
    if test_chan is None:
        labels = dataset.header['chan_name']
    else:
        labels = test_chan

    resolved = {}
    for simple in simple_chan_name:
        # 1. exact match
        exact = [lbl for lbl in labels if lbl == simple]
        if len(exact) == 1:
            resolved[simple] = exact[0]
            continue
        if len(exact) > 1:
            raise ValueError( f'The record contains {len(exact)} '
                              f'duplicates of channel label {simple}')

        # 2. label starts with the simple name
        prefixed = [lbl for lbl in labels
                    if simple == lbl[:min(len(simple), len(lbl))]]
        if len(prefixed) == 1:
            resolved[simple] = prefixed[0]
            continue
        if len(prefixed) > 1:
            # the simple name starts more than one label
            raise ValueError(
                f'Too many candidates corresponding to {simple}: {prefixed}')

        # 3. simple name appears anywhere in the label
        within = [lbl for lbl in labels if simple in lbl]
        if len(within) == 1:
            resolved[simple] = within[0]
            continue
        if len(within) > 1:
            # discard labels where the match is used as a reference
            # (i.e. immediately preceded by a dash)
            not_ref = [lbl for lbl in within
                       if lbl[lbl.index(simple) - 1] != '-']
            if len(not_ref) == 1:
                resolved[simple] = not_ref[0]
                continue
            raise ValueError(
                f'Too many candidates corresponding to {simple}: {within}')

        raise ValueError(f'Unable to find channel containing {simple}')

    return [resolved[name] for name in simple_chan_name]
def fetch(dataset, annot, cat=(0, 0, 0, 0), evt_type=None, stage=None,
          cycle=None, chan_full=None, epoch=None, epoch_dur=30,
          epoch_overlap=0, epoch_step=None, reject_epoch=False,
          reject_artf=False, min_dur=0, buffer=0):
    """Create instance of Segments for analysis, complete with info about
    stage, cycle, channel, event type. Segments contains only metadata until
    .read_data is called.

    Parameters
    ----------
    dataset : instance of Dataset
        info about record
    annot : instance of Annotations
        scoring info
    cat : tuple of int
        Determines where the signal is concatenated.
        If cat[0] is 1, cycles will be concatenated.
        If cat[1] is 1, different stages will be concatenated.
        If cat[2] is 1, discontinuous signal within a same condition
        (stage, cycle, event type) will be concatenated.
        If cat[3] is 1, events of different types will be concatenated.
        0 in any position indicates no concatenation.
    evt_type : list of str, optional
        Enter a list of event types to get events; otherwise, epochs will
        be returned.
    stage : list of str, optional
        Stage(s) of interest. If None, stage is ignored.
    cycle : list of tuple of two float, optional
        Cycle(s) of interest, as start and end times in seconds from record
        start. If None, cycles are ignored.
    chan_full : list of str or None
        Channel(s) of interest, only used for events (epochs have no
        channel). Channel format is 'chan_name (group_name)'.
        If used for epochs, separate segments will be returned for each
        channel; this is necessary for channel-specific artefact removal (see
        reject_artf below). If None, channel is ignored.
    epoch : str, optional
        If 'locked', returns epochs locked to staging. If 'unlocked', divides
        signal (with specified concatenation) into epochs of duration
        epoch_dur starting at first sample of every segment and discarding
        any remainder. If None, longest run of signal is returned.
    epoch_dur : float
        only for epoch='unlocked'. Duration of epochs returned, in seconds.
    epoch_overlap : float
        only for epoch='unlocked'. Ratio of overlap between two consecutive
        segments. Value between 0 and 1. Overridden by epoch_step.
    epoch_step : float
        only for epoch='unlocked'. Time between consecutive epoch starts, in
        seconds. Overrides epoch_overlap.
    reject_epoch : bool
        If True, epochs marked as 'Poor' quality or staged as 'Artefact' will
        be rejected (and the signal segmented in consequence). Has no effect
        on event selection.
    reject_artf : bool or str or list of str
        If True, excludes events marked as 'Artefact'. If chan_full is
        specified, only artefacts marked on a given channel are removed from
        that channel. Signal is segmented in consequence.
        If False (default), no artefact rejection is performed.
        If str or list of str, will reject the specified event types only.
    min_dur : float
        Minimum duration of segments returned, in seconds.
    buffer : float
        adds this many seconds of signal before and after each segment

    Returns
    -------
    instance of Segments
        metadata for all analysis segments
    """
    bundles = get_times(annot, evt_type=evt_type, stage=stage, cycle=cycle,
                        chan=chan_full, exclude=reject_epoch, buffer=buffer)

    # Remove artefacts
    if bundles and reject_artf is not False:
        s_freq = dataset.header['s_freq']
        two_sample_dur = 2 / s_freq  # min length to prevent begsam == endsam
        if isinstance(reject_artf, bool):
            # reject_artf is True: reject every artefact event type
            evt_type_name = None
        else:
            # reject only the named event type(s)
            evt_type_name = reject_artf
        for bund in bundles:
            bund['times'] = remove_artf_evts(bund['times'], annot,
                                             chan=bund['chan'],
                                             name=evt_type_name,
                                             min_dur=two_sample_dur)

    # Divide bundles into segments to be concatenated
    if bundles:
        if 'locked' == epoch:
            # one segment per staging epoch
            bundles = _divide_bundles(bundles)

        elif 'unlocked' == epoch:
            # fixed-duration epochs cut from the concatenated signal
            if epoch_step is not None:
                step = epoch_step
            else:
                step = epoch_dur - (epoch_dur * epoch_overlap)
            bundles = _concat(bundles, cat)
            bundles = _find_intervals(bundles, epoch_dur, step)

        elif not epoch:
            if evt_type:
                # keep individual events separate even when contiguous
                bundles = _concat(bundles, cat, concat_continuous=False)
            else:
                bundles = _concat(bundles, cat)

    # Minimum duration
    bundles = _longer_than(bundles, min_dur)

    segments = Segments(dataset)
    segments.segments = bundles

    return segments
def get_times(annot, evt_type=None, stage=None, cycle=None, chan=None,
              exclude=False, buffer=0):
    """Get start and end times for selected segments of data, bundled
    together with info.

    Parameters
    ----------
    annot : instance of Annotations
        The annotation file containing events and epochs
    evt_type : list of str, optional
        Enter a list of event types to get events; otherwise, epochs will
        be returned.
    stage : list of str, optional
        Stage(s) of interest. If None, stage is ignored.
    cycle : list of tuple of two float, optional
        Cycle(s) of interest, as start and end times in seconds from record
        start. If None, cycles are ignored.
    chan : list of str, optional
        Channel(s) of interest. Channel format is 'chan_name (group_name)'.
        If None, channel is ignored. The list is not modified.
    exclude : bool
        Exclude epochs by quality. If True, epochs marked as 'Poor' quality
        or staged as 'Artefact' will be rejected (and the signal segmented
        in consequence). Has no effect on event selection.
    buffer : float
        adds this many seconds of signal before and after each segment

    Returns
    -------
    list of dict
        Each dict has times (the start and end times of each segment, as
        list of tuple of float), stage, cycle, chan, name (event type,
        if applicable)

    Notes
    -----
    This function returns epoch or event start and end times, bundled
    together according to the specified parameters.
    Presently, setting exclude to True does not exclude events found in Poor
    signal epochs. The rationale is that events would never be marked in Poor
    signal epochs. If they were automatically detected, these epochs would
    have been left out during detection. If they were manually marked, then
    it must have been Good signal. At the moment, in the GUI, the exclude
    epoch option is disabled when analyzing events, but we could fix the code
    if we find a use case for rejecting events based on the quality of the
    epoch signal.
    """
    getter = annot.get_epochs
    last = annot.last_second

    if stage is None:
        stage = (None,)
    if cycle is None:
        cycle = (None,)
    if chan is None:
        chan = (None,)
    if evt_type is None:
        evt_type = (None,)
    elif isinstance(evt_type[0], str):
        getter = annot.get_events
        if chan != (None,):
            # also retrieve events marked on all channels; rebind to a new
            # list instead of append() so the caller's list is not mutated
            chan = list(chan) + ['']
    else:
        lg.error('Event type must be list/tuple of str or None')

    qual = None
    if exclude:
        qual = 'Good'

    bundles = []
    # one bundle per (event type, channel, cycle, stage) combination
    for et in evt_type:
        for ch in chan:
            for cyc in cycle:
                for ss in stage:
                    st_input = ss
                    if ss is not None:
                        st_input = (ss,)
                    evochs = getter(name=et, time=cyc, chan=(ch,),
                                    stage=st_input, qual=qual)
                    if evochs:
                        # apply buffer, clipped to the record boundaries
                        times = [(max(e['start'] - buffer, 0),
                                  min(e['end'] + buffer, last))
                                 for e in evochs]
                        times = sorted(times, key=lambda x: x[0])
                        one_bundle = {'times': times,
                                      'stage': ss,
                                      'cycle': cyc,
                                      'chan': ch,
                                      'name': et}
                        bundles.append(one_bundle)

    return bundles
def _longer_than(segments, min_dur):
"""Remove segments longer than min_dur."""
if min_dur <= 0.:
return segments
long_enough = []
for seg in segments:
if sum([t[1] - t[0] for t in seg['times']]) >= min_dur:
long_enough.append(seg)
return long_enough
def _concat(bundles, cat=(0, 0, 0, 0), concat_continuous=True):
    """Prepare event or epoch start and end times for concatenation.

    Parameters
    ----------
    bundles : list of dict
        as returned by get_times: each dict has 'times' (list of
        (start, end) tuples), 'chan', 'cycle', 'stage' and 'name' keys
    cat : tuple of int
        which dimensions to concatenate across: cat[0] cycles, cat[1]
        stages, cat[2] discontinuous signal, cat[3] event types
        (1 = concatenate)
    concat_continuous : bool
        only relevant when cat[2] is 0; if False, subsegments are split
        apart even when they are contiguous in time

    Returns
    -------
    list of dict
        regrouped bundles; bundles left with no times are dropped
    """
    # unique values present in the input, one list per grouping dimension
    chan = sorted(set([x['chan'] for x in bundles]))
    cycle = sorted(set([x['cycle'] for x in bundles]))
    stage = sorted(set([x['stage'] for x in bundles]))
    evt_type = sorted(set([x['name'] for x in bundles]))

    # labels used for a dimension once it is collapsed by concatenation
    all_cycle = None
    all_stage = None
    all_evt_type = None
    if cycle[0] is not None:
        all_cycle = ', '.join([str(c) for c in cycle])
    if stage[0] is not None:
        all_stage = ', '.join(stage)
    if evt_type[0] is not None:
        all_evt_type = ', '.join(evt_type)

    # collapse the dimensions marked for concatenation to a single value
    if cat[0]:
        cycle = [all_cycle]
    if cat[1]:
        stage = [all_stage]
    if cat[3]:
        evt_type = [all_evt_type]

    # regroup: pool times from all bundles matching each
    # (chan, cycle, stage, event type) combination; a bundle matches a
    # collapsed dimension via the all_* label
    to_concat = []
    for ch in chan:
        for cyc in cycle:
            for st in stage:
                for et in evt_type:
                    new_times = []
                    for bund in bundles:
                        chan_cond = ch == bund['chan']
                        cyc_cond = cyc in (bund['cycle'], all_cycle)
                        st_cond = st in (bund['stage'], all_stage)
                        et_cond = et in (bund['name'], all_evt_type)
                        if chan_cond and cyc_cond and st_cond and et_cond:
                            new_times.extend(bund['times'])
                    new_times = sorted(new_times, key=lambda x: x[0])
                    new_bund = {'times': new_times,
                                'chan': ch,
                                'cycle': cyc,
                                'stage': st,
                                'name': et
                                }
                    to_concat.append(new_bund)

    if not cat[2]:
        # split discontinuous signal: break each bundle into runs of
        # contiguous subsegments (or into single subsegments when
        # concat_continuous is False)
        to_concat_new = []
        for bund in to_concat:
            last = None
            # sentinel: guarantees the final run is flushed by the loop
            # below (it is never close to any real end time); the slice
            # new_times = bund['times'][start:i] always excludes it
            bund['times'].append((inf,inf))
            start = 0
            for i, j in enumerate(bund['times']):
                if last is not None:
                    # a gap larger than ~10 ms starts a new run
                    if not isclose(j[0], last, abs_tol=0.01) \
                            or not concat_continuous:
                        new_times = bund['times'][start:i]
                        new_bund = bund.copy()
                        new_bund['times'] = new_times
                        to_concat_new.append(new_bund)
                        start = i
                last = j[1]
        to_concat = to_concat_new

    # drop combinations that matched no times at all
    to_concat = [x for x in to_concat if x['times']]

    return to_concat
def _divide_bundles(bundles):
"""Take each subsegment inside a bundle and put it in its own bundle,
copying the bundle metadata."""
divided = []
for bund in bundles:
for t in bund['times']:
new_bund = bund.copy()
new_bund['times'] = [t]
divided.append(new_bund)
return divided
def _find_intervals(bundles, duration, step):
"""Divide bundles into segments of a certain duration and a certain step,
discarding any remainder."""
segments = []
for bund in bundles:
beg, end = bund['times'][0][0], bund['times'][-1][1]
if end - beg >= duration:
new_begs = arange(beg, end - duration, step)
for t in new_begs:
seg = bund.copy()
seg['times'] = [(t, t + duration)]
segments.append(seg)
return segments
def _create_data(data, active_chan, ref_chan=[], grp_name=None):
    """Create data after montage.

    Parameters
    ----------
    data : instance of ChanTime
        the raw data
    active_chan : list of str
        the channel(s) of interest, without reference or group
    ref_chan : list of str
        reference channel(s), without group
    grp_name : str
        name of channel group, if applicable

    Returns
    -------
    instance of ChanTime
        the re-referenced data, with channels labelled
        'chan_name (group_name)' when grp_name is given
    """
    output = ChanTime()
    output.s_freq = data.s_freq
    output.start_time = data.start_time
    output.axis['time'] = data.axis['time']
    output.axis['chan'] = empty(1, dtype='O')
    output.data = empty(1, dtype='O')
    # preallocate (n_active_chan, n_samples); rows are filled in the loop
    output.data[0] = empty((len(active_chan), data.number_of('time')[0]),
                           dtype='f')

    # keep only the channels needed, then re-reference against ref_chan
    sel_data = _select_channels(data, active_chan + ref_chan)
    data1 = montage(sel_data, ref_chan=ref_chan)
    data1.data[0] = nan_to_num(data1.data[0])  # replace NaN with 0

    all_chan_grp_name = []
    for i, chan in enumerate(active_chan):
        chan_grp_name = chan
        if grp_name:
            # label format used elsewhere in this module:
            # 'chan_name (group_name)'
            chan_grp_name = chan + ' (' + grp_name + ')'
        all_chan_grp_name.append(chan_grp_name)
        dat = data1(chan=chan, trial=0)
        output.data[0][i, :] = dat

    output.axis['chan'][0] = asarray(all_chan_grp_name, dtype='U')

    return output
def _create_subepochs(x, nperseg, step):
"""Transform the data into a matrix for easy manipulation
Parameters
----------
x : 1d ndarray
actual data values
nperseg : int
number of samples in each row to create
step : int
distance in samples between rows
Returns
-------
2d ndarray
a view (i.e. doesn't copy data) of the original x, with shape
determined by nperseg and step. You should use the last dimension
"""
axis = x.ndim - 1 # last dim
nsmp = x.shape[axis]
stride = x.strides[axis]
noverlap = nperseg - step
v_shape = *x.shape[:axis], (nsmp - noverlap) // step, nperseg
v_strides = *x.strides[:axis], stride * step, stride
v = as_strided(x, shape=v_shape, strides=v_strides,
writeable=False) # much safer
return v
def _select_channels(data, channels):
"""Select channels.
Parameters
----------
data : instance of ChanTime
data with all the channels
channels : list
channels of interest
Returns
-------
instance of ChanTime
data with only channels of interest
Notes
-----
This function does the same as wonambi.trans.select, but it's much faster.
wonambi.trans.Select needs to flexible for any data type, here we assume
that we have one trial, and that channel is the first dimension.
"""
output = data._copy()
chan_list = list(data.axis['chan'][0])
idx_chan = [chan_list.index(i_chan) for i_chan in channels]
output.data[0] = data.data[0][idx_chan, :]
output.axis['chan'][0] = asarray(channels)
return output