from logging import getLogger
from numpy import (asarray,
c_,
dot,
nanmean,
nanmedian,
moveaxis,
where,
zeros,
)
from numpy.linalg import norm, lstsq
from ..attr import Channels
# module-level logger, named after this module
lg = getLogger(__name__)
def montage(data, ref_chan=None, ref_to_avg=False, bipolar=None,
            method='average'):
    """Apply linear transformation to the channels.

    Parameters
    ----------
    data : instance of DataRaw
        the data to filter
    ref_chan : list of str
        list of channels used as reference
    ref_to_avg : bool
        if re-reference to average or not
    bipolar : float
        distance in mm to consider two channels as neighbors and then compute
        the bipolar montage between them.
    method : str
        'average' or 'median' or 'regression'. 'average' / 'median' takes the
        mean / median across the channels selected as reference (it can be
        all) and subtracts it from each channel. 'regression' keeps the
        residuals after regressing out the mean across all the channels (the
        reference channels are not used). Only used when re-referencing.

    Returns
    -------
    filtered_data : instance of DataRaw
        filtered data

    Raises
    ------
    TypeError
        if both ref_to_avg and ref_chan are given, if ref_chan is not a list
        of str, or if re-referencing is combined with the bipolar montage
    ValueError
        if method is unknown (when re-referencing), if the bipolar montage is
        requested but data has no channel information in attr, or if the
        channel axis is not the first one (bipolar montage only)

    Notes
    -----
    If you don't change anything, it returns the same instance of data.
    The input data is never modified.
    """
    if ref_to_avg and ref_chan is not None:
        raise TypeError('You cannot specify reference to the average and '
                        'the channels to use as reference')

    if ref_chan is not None:
        if (not isinstance(ref_chan, (list, tuple)) or
                not all(isinstance(x, str) for x in ref_chan)):
            raise TypeError('chan should be a list of strings')

    # an empty ref_chan list means "no re-referencing"
    rereference = bool(ref_to_avg or ref_chan)

    if rereference and bipolar:
        # otherwise the re-referenced data would be labeled with bipolar
        # channels that do not describe it
        raise TypeError('You cannot combine re-referencing with the '
                        'bipolar montage')

    if rereference and method not in ('average', 'median', 'regression'):
        # otherwise the trials of the output copy would never be filled in
        raise ValueError("method should be 'average', 'median' or "
                         f"'regression', not {method!r}")

    if not (rereference or bipolar):
        return data

    if bipolar:
        if not data.attr.get('chan'):
            raise ValueError('Data should have Chan information in attr')
        if data.index_of('chan') != 0:
            raise ValueError('For matrix multiplication to work, '
                             'the first dimension should be chan')
        _assert_equal_channels(data.axis['chan'])
        chan_in_data = data.axis['chan'][0]
        chan = data.attr['chan'](lambda x: x.label in chan_in_data)
        # the columns of `trans` follow the order of `chan`, which can differ
        # from the order of the channels in the data: select them explicitly
        mono_labels = chan.return_label()
        bipolar_chan, trans = create_bipolar_chan(chan, bipolar)
        bipolar_labels = bipolar_chan.return_label()

    mdata = data._copy()
    idx_chan = mdata.index_of('chan')
    if bipolar:
        # only the copy gets the new channels; the input is left untouched
        mdata.attr['chan'] = bipolar_chan

    for i in range(mdata.number_of('trial')):
        if bipolar:
            mdata.data[i] = dot(trans, data(trial=i, chan=mono_labels))
            mdata.axis['chan'][i] = asarray(bipolar_labels, dtype='U')

        elif method == 'regression':
            mdata.data[i] = compute_average_regress(data(trial=i), idx_chan)

        else:
            ref = data.axis['chan'][i] if ref_to_avg else ref_chan
            center = nanmean if method == 'average' else nanmedian
            # keepdims so that the subtraction broadcasts whatever the
            # position of the channel axis
            ref_signal = center(data(trial=i, chan=ref), axis=idx_chan,
                                keepdims=True)
            mdata.data[i] = data(trial=i) - ref_signal

    return mdata
def _assert_equal_channels(axis):
    """Check that all the trials have the same channels, in the same order.

    Parameters
    ----------
    axis : ndarray of ndarray
        one of the data axis (one array of labels per trial)

    Raises
    ------
    ValueError
        if the labels of any trial differ from those of the first trial, in
        content, order or number
    """
    trials = list(axis)
    if not trials:
        return

    # compare as lists: element-wise `==` on label arrays of different length
    # either broadcasts (length-1 arrays) or fails instead of returning False.
    # Equality is transitive, so comparing against the first trial suffices.
    first = list(trials[0])
    for labels in trials[1:]:
        if list(labels) != first:
            raise ValueError('The channels for all the trials should have '
                             'the same labels, in the same order.')
def create_bipolar_chan(chan, max_dist):
    """Create one bipolar channel for each pair of neighboring channels.

    Parameters
    ----------
    chan : instance of Channels
        channels with labels and xyz positions
    max_dist : float
        two channels are neighbors when the euclidean distance between their
        positions is strictly smaller than max_dist (same units as xyz)

    Returns
    -------
    bipolar : instance of Channels
        one channel per pair of neighbors, labeled 'label0-label1' and placed
        at the nan-mean of the two positions
    bipolar_trans : ndarray
        (n_pairs, n_chan) matrix; each row has +1 for the first channel of the
        pair and -1 for the second, so that dot(bipolar_trans, data) gives the
        bipolar signals when the channels are on the first axis of data

    Notes
    -----
    Pairs are ordered by first channel, then by second channel, following the
    order of chan.chan. Positions containing NaN are never neighbors, because
    comparisons with NaN are False.
    """
    channels = chan.chan

    bipolar_labels = []
    bipolar_xyz = []
    bipolar_trans = []
    for i0, chan0 in enumerate(channels):
        # only look at the pairs (i0, i1) with i1 > i0
        for i1 in range(i0 + 1, len(channels)):
            chan1 = channels[i1]
            if not norm(chan0.xyz - chan1.xyz) < max_dist:
                continue

            bipolar_labels.append(chan0.label + '-' + chan1.label)
            bipolar_xyz.append(nanmean(c_[chan0.xyz, chan1.xyz], axis=1))

            trans = zeros(chan.n_chan)
            trans[i0] = 1
            trans[i1] = -1
            bipolar_trans.append(trans)

    if bipolar_trans:
        bipolar_xyz = c_[bipolar_xyz]
        bipolar_trans = c_[bipolar_trans]
    else:
        # c_[[]] would give a (0, 1) array, which cannot multiply the data;
        # positions are assumed to be 3-D here (TODO confirm for 2-D montages)
        bipolar_xyz = zeros((0, 3))
        bipolar_trans = zeros((0, chan.n_chan))

    bipolar = Channels(bipolar_labels, bipolar_xyz)
    return bipolar, bipolar_trans
def compute_average_regress(x, idx_chan):
    """Take the mean across channels and regress it out of each channel.

    For each channel, the least-squares coefficient (without intercept) of
    the channel on the mean signal is computed, and the scaled mean is
    subtracted from the channel, leaving the residuals.

    Parameters
    ----------
    x : ndarray
        2d array with channels on one dimension
    idx_chan : int
        which axis contains channels

    Returns
    -------
    ndarray
        same shape as x, with the mean across channels regressed out of each
        channel

    Raises
    ------
    ValueError
        if x is not 2d
    """
    if x.ndim != 2:
        raise ValueError(f'The number of dimensions must be 2, not {x.ndim}')

    x = moveaxis(x, idx_chan, 0)  # channels first: (n_chan, n_samples)
    avg = nanmean(x, axis=0)

    # closed-form solution of the one-regressor least-squares problem,
    # computed for all the channels at once instead of one lstsq per channel
    avg_power = dot(avg, avg)
    if avg_power == 0:
        # an all-zero mean carries nothing to regress out
        coef = zeros(x.shape[0])
    else:
        coef = dot(x, avg) / avg_power

    residuals = x - coef[:, None] * avg[None, :]
    return moveaxis(residuals, 0, idx_chan)
def create_virtual_channel(data, new_chan_name='virtual', method='average'):
    """Create a virtual channel by combining all the channels in data.

    Parameters
    ----------
    data : instance of DataRaw
        the data to combine; all its channels are used, so select the
        channels of interest beforehand
    new_chan_name : str
        label for the virtual channel
    method : str
        'average' (nan-mean across channels) or 'median' (nan-median across
        channels)

    Returns
    -------
    mdata : instance of Data
        copy of data with a single channel, labeled new_chan_name

    Raises
    ------
    ValueError
        if method is not 'average' or 'median'
    """
    reducers = {'average': nanmean, 'median': nanmedian}
    if method not in reducers:
        raise ValueError("method should be 'average' or 'median', "
                         f"not {method!r}")
    reduce_chan = reducers[method]

    idx_chan = data.index_of('chan')
    mdata = data._copy()
    for i in range(mdata.number_of('trial')):
        mdata.axis['chan'][i] = asarray([new_chan_name], dtype='U')
        # keepdims preserves the (now length-1) channel dimension, so the
        # data keeps one dimension per axis, matching the single label above
        mdata.data[i] = reduce_chan(data(trial=i), axis=idx_chan,
                                    keepdims=True)

    return mdata