"""Module contains the different types of data formats.
The main class is Data, all the other classes should depend on it. The other
classes are given only as convenience, but they should not overwride
Data.__call__, which needs to be very general.
"""
from collections import OrderedDict
from collections.abc import Iterable
from copy import deepcopy
from logging import getLogger
from pathlib import Path
from numpy import arange, array, empty, ix_, NaN, squeeze, where
lg = getLogger()
[docs]class Data:
"""General class containing recordings.
Parameters
----------
data : ndarray
one matrix with dimension matching the number of axes. You can pass
only one trial.
s_freq : int
sampling frequency
axes : dict
dictionary where the key is the name of the axis and the values must be
a numpy vector with the actual values.
Attributes
----------
data : ndarray (dtype='O')
the data as trials. Each trial is a ndarray (dtype='d' or 'f')
axis : OrderedDict
dictionary with axiss (standard names are 'chan', 'time', 'freq');
values should be numpy array
s_freq : int
sampling frequency
start_time : instance of datetime.datetime
the start time of the recording
attr : dict
contains additional information about the dataset, with keys:
- surf
- chan
- scores
Notes
-----
Something which is not immediately clear for chan. dtype='U' (meaning
Unicode) actually creates string of type str\_, while if you use dtype='S'
(meaning String) it creates strings of type bytes\_.
"""
def __init__(self, data=None, s_freq=None, **kwargs):
self.s_freq = s_freq
if data is None:
self.data = array([], dtype='O')
else:
self.data = array((1, ), dtype='O')
self.data[0] = data
self.axis = OrderedDict()
if data is not None:
"""temporary solution until PEP0468
kwargs is a dict, so no order. We try to reconstruct order based
on number of values for each value, but not 100% reliable.
"""
count_kwargs = {len(v): k for k, v in kwargs.items()}
if len(set(count_kwargs)) != len(list(count_kwargs)):
lg.warning('Some arguments have the same length, so the order '
'of the axes might be incorrect')
axes = OrderedDict()
for n_dim in data.shape:
try:
axis_with_right_ndim = count_kwargs[n_dim]
axes[axis_with_right_ndim] = kwargs[axis_with_right_ndim]
except KeyError:
raise ValueError('Number of dimensions in axis does not '
'match number of dimensions in data')
for axis, value in axes.items():
self.axis[axis] = array((1,), dtype='O')
self.axis[axis][0] = value
self.start_time = None
self.attr = {'surf': None,
'chan': None,
'scores': None,
}
def __call__(self, trial=None, tolerance=None, **axes):
"""Return the recordings and their time stamps.
Parameters
----------
trial : list of int or ndarray (dtype='i') or int
which trials you want (if it's one int, it returns the actual
matrix).
**axes
Arbitrary axiss to select from. You specify the axis and
the values as list or tuple of the values that you want.
tolerance : float
if one of the axiss is a number, it specifies the tolerance to
consider one value as chosen (take into account floating-precision
errors).
Returns
-------
ndarray
ndarray containing the data with the same number of axiss as
the original data. The length of the axis is equal to the
length of the data, UNLESS you specify an axis with values. In
that case, the length is equal to the values that you want.
If you specify only one trial (as int, not as tuple or list), then
it returns the actual matrix. Otherwise, it returns a ndarray
(dtype='O') of length equal to the trials.
Notes
-----
You cannot specify intervals here, you can do it in Select.
"""
if trial is None:
trial = range(self.number_of('trial'))
squeeze_trial = False
try:
iter(trial)
except TypeError: # 'int' object is not iterable
trial = (trial, )
squeeze_trial = True
output = empty(len(trial), dtype='O')
for cnt, i in enumerate(trial):
output_shape = []
idx_data = []
idx_output = []
squeeze_axis = []
for axis, values in self.axis.items():
if axis in axes.keys():
selected_values = axes[axis]
if (isinstance(selected_values, Iterable) and
not isinstance(selected_values, str)):
n_values = len(selected_values)
else:
n_values = 1
selected_values = array([selected_values])
squeeze_axis.append(self.index_of(axis))
idx = _get_indices(values[i],
selected_values,
tolerance=tolerance)
if len(idx[0]) == 0:
lg.warning('No index was selected for ' + axis)
idx_data.append(idx[0])
idx_output.append(idx[1])
else:
n_values = len(values[i])
idx_data.append(arange(n_values))
idx_output.append(arange(n_values))
output_shape.append(n_values)
output[cnt] = empty(output_shape, dtype=self.data[i].dtype)
output[cnt].fill(NaN)
if all([len(x) > 0 for x in idx_data]):
ix_output = ix_(*idx_output)
ix_data = ix_(*idx_data)
output[cnt][ix_output] = self.data[i][ix_data]
if len(squeeze_axis) > 0:
output[cnt] = squeeze(output[cnt],
axis=tuple(squeeze_axis))
if squeeze_trial:
output = output[0]
return output
@property
def list_of_axes(self):
"""Return the name of all the axes in the data."""
return tuple(self.axis.keys())
[docs] def index_of(self, axis):
"""Return the index of a axis.
Parameters
----------
axis : str
Name of the axis (such as 'trial', 'time', etc)
Returns
-------
int or ndarray (dtype='int')
number of trial (as int) or number of element in the selected
axis (if any of the other axiss) as 1d array.
Raises
------
ValueError
If the requested axis is not in the data.
"""
return list(self.axis.keys()).index(axis)
[docs] def number_of(self, axis):
"""Return the number of in one axis, as generally as possible.
Parameters
----------
axis : str
Name of the axis (such as 'trial', 'time', etc)
Returns
-------
int or ndarray (dtype='int')
number of trial (as int) or number of element in the selected
axis (if any of the other axiss) as 1d array.
Raises
------
KeyError
If the requested axis is not in the data.
Notes
-----
or is it better to catch the exception?
"""
if axis == 'trial':
return len(self.data)
else:
n_trial = self.number_of('trial')
output = empty(n_trial, dtype='int')
for i in range(n_trial):
output[i] = len(self.axis[axis][i])
return output
def __getattr__(self, possible_axis):
"""Return the axis with a shorter syntax.
Parameters
----------
possible_axis : str
one of the axes
Returns
-------
value of the axis of interest
Notes
------
The if-statement "startswith" is necessary to avoid recursionerror
when loading the class.
"""
if possible_axis.startswith('__'):
raise AttributeError(possible_axis)
try:
return self.axis[possible_axis]
except KeyError:
raise AttributeError(possible_axis)
def __iter__(self):
"""Implement generator for each trial.
The generator returns the data for each trial. This is of course really
convenient for map and parallel processing.
Examples
--------
>>> from wonambi.trans import math
>>> for one_trial in iter(data):
>>> one_mean = math(one_trial, operator_name='mean', axis='time')
>>> print(one_mean.data[0])
"""
for trial in range(self.number_of('trial')):
output = self._copy(axis=False)
for one_axis in self.axis:
output.axis[one_axis] = empty(1, dtype='O')
output.data = empty(1, dtype='O')
output.data[0] = self.data[trial]
for one_axis in output.axis:
output.axis[one_axis][0] = self.axis[one_axis][trial]
yield output
def _copy(self, axis=True, attr=True, data=False):
"""Create a new instance of Data, but does not copy the data
necessarily.
Parameters
----------
axis : bool, optional
deep copy the axes (default: True)
attr : bool, optional
deep copy the attributes (default: True)
data : bool, optional
deep copy the data (default: False)
Returns
-------
instance of Data (or ChanTime, ChanFreq, ChanTimeFreq)
copy of the data, but without the actual data
Notes
-----
It's important that we copy all the relevant information here. If you
add new attributes, you should add them here.
Remember that it deep-copies all the information, so if you copy data
the size might become really large.
"""
cdata = type(self)() # create instance of the same class
cdata.s_freq = self.s_freq
cdata.start_time = self.start_time
if axis:
cdata.axis = deepcopy(self.axis)
else:
cdata_axis = OrderedDict()
for axis_name in self.axis:
cdata_axis[axis_name] = array([], dtype='O')
cdata.axis = cdata_axis
if attr:
cdata.attr = deepcopy(self.attr)
if data:
cdata.data = deepcopy(self.data)
else:
# empty data with the correct number of trials
cdata.data = empty(self.number_of('trial'), dtype='O')
return cdata
[docs] def export(self, filename, export_format='FieldTrip', **options):
"""Export data in other formats.
Parameters
----------
filename : path to file
file to write
export_format : str, optional
supported export format is currently FieldTrip, EDF, FIFF, Wonambi,
BrainVision
Notes
-----
'edf' takes an optional argument "physical_max", see write_edf.
'wonambi' takes an optional argument "subj_id", see write_wonambi.
wonambi format creates two files, one .won with the dataset info as json
file and one .dat with the memmap recordings.
'brainvision' takes an additional argument ("markers") which is a list
of dictionaries with fields:
"name" : str (name of the marker),
"start" : float (start time in seconds)
"end" : float (end time in seconds)
'bids' has an optional argument "markers", like in 'brainvision'
"""
filename = Path(filename)
filename.parent.mkdir(parents=True, exist_ok=True)
export_format = export_format.lower()
if export_format == 'edf':
from .ioeeg import write_edf # avoid circular import
write_edf(self, filename, **options)
elif export_format == 'fieldtrip':
from .ioeeg import write_fieldtrip # avoid circular import
write_fieldtrip(self, filename)
elif export_format == 'mnefiff':
from .ioeeg import write_mnefiff
write_mnefiff(self, filename)
elif export_format == 'wonambi':
from .ioeeg import write_wonambi
write_wonambi(self, filename, **options)
elif export_format == 'brainvision':
from .ioeeg import write_brainvision
write_brainvision(self, filename, **options)
elif export_format == 'bids':
from .ioeeg import write_bids
write_bids(self, filename, **options)
else:
raise ValueError('Cannot export to ' + export_format)
[docs]class ChanTime(Data):
"""Specific class for chan-time recordings, with axes:
chan : ndarray (dtype='O')
for each trial, channels in the data (dtype='U')
time : ndarray (dtype='O')
for each trial, 1d matrix with the time stamp (dtype='f')
"""
def __init__(self):
super().__init__()
self.axis['chan'] = array([], dtype='O')
self.axis['time'] = array([], dtype='O')
[docs]class ChanFreq(Data):
"""Specific class for channel-frequency recordings, with axes:
chan : ndarray (dtype='O')
for each trial, channels in the data (dtype='U')
freq : ndarray (dtype='O')
for each trial, 1d matrix with the frequency (dtype='f')
Notes
-----
Conceptually, it is reasonable that each trial has the same frequency band,
so it might be more convenient to use only one array, but it can happen
that different trials have different frequency bands, so we keep the format
more open.
"""
def __init__(self):
super().__init__()
self.axis['chan'] = array([], dtype='O')
self.axis['freq'] = array([], dtype='O')
[docs]class ChanTimeFreq(Data):
"""Specific class for channel-time-frequency representation, with axes:
chan : ndarray (dtype='O')
for each trial, channels in the data (dtype='U')
time : ndarray (dtype='O')
for each trial, 1d matrix with the time stamp (dtype='f')
freq : ndarray (dtype='O')
for each trial, 1d matrix with the frequency (dtype='f')
"""
def __init__(self):
super().__init__()
self.axis['chan'] = array([], dtype='O')
self.axis['time'] = array([], dtype='O')
self.axis['freq'] = array([], dtype='O')
def _get_indices(values, selected, tolerance):
"""Get indices based on user-selected values.
Parameters
----------
values : ndarray (any dtype)
values present in the axis.
selected : ndarray (any dtype) or tuple or list
values selected by the user
tolerance : float
avoid rounding errors.
Returns
-------
idx_data : list of int
indices of row/column to select the data
idx_output : list of int
indices of row/column to copy into output
Notes
-----
This function is probably not very fast, but it's pretty robust. It keeps
the order, which is extremely important.
If you use values in the self.axis, you don't need to specify tolerance.
However, if you specify arbitrary points, floating point errors might
affect the actual values. Of course, using tolerance is much slower.
Maybe tolerance should be part of Select instead of here.
"""
idx_data = []
idx_output = []
for idx_of_selected, one_selected in enumerate(selected):
if tolerance is None or values.dtype.kind == 'U':
idx_of_data = where(values == one_selected)[0]
else:
idx_of_data = where(abs(values - one_selected) <= tolerance)[0] # actual use min
if len(idx_of_data) > 0:
idx_data.append(idx_of_data[0])
idx_output.append(idx_of_selected)
return idx_data, idx_output