Source code for wonambi.trans.frequency

"""Module to compute frequency representation.
"""
from copy import deepcopy
from logging import getLogger
from functools import partial
from multiprocessing import Pool

from numpy import (arange, array, asarray, copy, empty, exp, isnan, log, max, mean,
                   median, moveaxis, NaN, pi, real, reshape, sqrt, swapaxes, zeros)
from numpy.linalg import norm
import numpy.fft as np_fft
from scipy import fftpack
from scipy.signal import windows, get_window, fftconvolve
from scipy.signal import detrend as detrend_func

from .extern.dpss import dpss_windows  # this will be in scipy v1.1
from ..datatype import ChanFreq, ChanTimeFreq, ChanTime
from .select import _create_subepochs

lg = getLogger(__name__)


[docs]def frequency(data, output='spectraldensity', scaling='power', sides='one', taper=None, halfbandwidth=3, NW=None, duration=None, overlap=0.5, step=None, detrend='linear', n_fft=None, log_trans=False, centend='mean'): """Compute the power spectral density (PSD, output='spectraldensity', scaling='power'), or energy spectral density (ESD, output='spectraldensity', scaling='energy') or the complex fourier transform (output='complex', sides='two') Parameters ---------- data : instance of ChanTime one of the datatypes detrend : str None (no detrending), 'constant' (remove mean), 'linear' (remove linear trend) output : str 'spectraldensity' or 'csd' or 'complex' 'spectraldensity' meaning the autospectrum or auto-spectral density, a special case of 'csd' (cross-spectral density), where the signal is cross-correlated with itself if 'csd', both channels in data are used as input sides : str 'one' or 'two', where 'two' implies negative frequencies scaling : str 'power' (units: V ** 2 / Hz), 'energy' (units: V ** 2), 'fieldtrip', 'chronux' taper : str Taper to use, commonly used tapers are 'boxcar', 'hann', 'dpss' halfbandwidth : int (only if taper='dpss') Half bandwidth (in Hz), frequency smoothing will be from +halfbandwidth to -halfbandwidth NW : int (only if taper='dpss') Normalized half bandwidth (NW = halfbandwidth * dur). Number of DPSS tapers is 2 * NW - 1. If specified, NW takes precedence over halfbandwidth duration : float, in s If not None, it divides the signal in epochs of this length (in seconds) and then average over the PSD / ESD (not the complex result) overlap : float, between 0 and 1 The amount of overlap between epochs (0.5 = 50%, 0.95 = almost complete overlap). step : float, in s step in seconds between epochs (alternative to overlap) n_fft: int Length of FFT, in samples. If less than input axis, input is cropped. If longer than input axis, input is padded with zeros. If None, FFT length set to axis length. log_trans : bool If True, spectral values will be natural log-transformed. The transformation is applied before averaging (or taking the median). centend : str (only if duration is not None). Central tendency measure to use, either mean (arithmetic) or median. Returns ------- instance of ChanFreq If output='complex', there is an additional dimension ('taper') which is useful for 'dpss' but it's also present for all the other tapers. Raises ------ TypeError If the data does not have a 'time' axis. It might work in the future on other axes, but I cannot imagine how. ValueError If you use duration (to create multiple epochs) and output='complex', because it does not average the complex output of multiple epochs. Notes ----- See extensive notes at wonambi.trans.frequency._fft It uses sampling frequency as specified in s_freq, it does not recompute the sampling frequency based on the time axis. Use of log or median for Welch's method is included based on recommendations from Izhikevich et al., bioRxiv, 2018. """ if output not in ('spectraldensity', 'complex', 'csd'): raise TypeError('output can be "spectraldensity", "complex" or "csd",' ' not "{output}"') if 'time' not in data.list_of_axes: raise TypeError('\'time\' is not in the axis ' + str(data.list_of_axes)) if len(data.list_of_axes) != data.index_of('time') + 1: raise TypeError('\'time\' should be the last axis') # this might be improved if duration is not None and output == 'complex': raise ValueError('cannot average the complex spectrum over multiple epochs') if output == 'csd' and data.number_of('chan') != 2: raise ValueError('CSD can only be computed between two channels') if duration is not None: nperseg = int(duration * data.s_freq) if step is not None: nstep = int(step * data.s_freq) else: nstep = nperseg - int(overlap * nperseg) freq = ChanFreq() freq.attr = deepcopy(data.attr) freq.s_freq = data.s_freq freq.start_time = data.start_time freq.axis['chan'] = copy(data.axis['chan']) freq.axis['freq'] = empty(data.number_of('trial'), dtype='O') if output == 'complex': freq.axis['taper'] = empty(data.number_of('trial'), dtype='O') freq.data = empty(data.number_of('trial'), dtype='O') for i in range(data.number_of('trial')): x = data(trial=i) if duration is not None: x = _create_subepochs(x, nperseg, nstep) f, Sxx = _fft(x, s_freq=data.s_freq, detrend=detrend, taper=taper, output=output, sides=sides, scaling=scaling, halfbandwidth=halfbandwidth, NW=NW, n_fft=n_fft) if log_trans: Sxx = log(Sxx) if duration is not None: if centend == 'mean': Sxx = Sxx.mean(axis=-2) elif centend == 'median': Sxx = median(Sxx, axis=-2) else: raise ValueError('Invalid central tendency measure. ' 'Use mean or median.') freq.axis['freq'][i] = f if output == 'complex': freq.axis['taper'][i] = arange(Sxx.shape[-1]) if output == 'csd': newchan = ' * '.join(freq.axis['chan'][i]) freq.axis['chan'][i] = asarray([newchan], dtype='U') freq.data[i] = Sxx return freq
[docs]def timefrequency(data, method='morlet', **options): """Compute the power spectrum over time. Parameters ---------- data : instance of ChanTime data to analyze method : str the method to compute the time-frequency representation, such as 'morlet' (wavelet using complex morlet window), 'spectrogram' (corresponds to 'spectraldensity' in frequency()), 'stft' (short-time fourier transform, corresponds to 'complex' in frequency()) options : dict options depend on the method used, see below. Returns ------- instance of ChanTimeFreq data in time-frequency representation. The exact output depends on the method. Using 'morlet', you get a complex output at each frequency where the wavelet was computed. Examples -------- The data in ChanTimeFreq are complex and they should stay that way. You can also get the magnitude or power the easy way using Math. >>> from wonambi.trans import math, timefreq >>> tf = timefreq(data, foi=(8, 10)) >>> tf_abs = math(tf, operator_name='abs') >>> tf_abs.data[0][0, 0, 0] 1737.4662329214384) Notes ----- It uses sampling frequency as specified in s_freq, it does not recompute the sampling frequency based on the time axis. For method 'morlet', the following options should be specified: foi : ndarray or list or tuple vector with frequency of interest ratio : float ratio for a wavelet family ( = freq / sigma_f) sigma_f : float standard deviation of the wavelet in frequency domain dur_in_sd : float duration of the wavelet, given as number of the standard deviation in the time domain, in one side. dur_in_s : float total duration of the wavelet, two-sided (i.e. from start to finish) normalization : str 'area' means that energy is normalized to 1, 'peak' means that the peak of the wavelet is set at 1, 'max' is a normalization used by nitime where the max value of the output of the convolution remains the same even if you change the sigma_f. zero_mean : bool make sure that the wavelet has zero mean (only relevant if ratio < 5) For method 'spectrogram' or 'stft', the following options should be specified: duraton : int duration of the window to compute the power spectrum, in s overlap : int amount of overlap (0 -> no overlap, 1 -> full overlap) """ implemented_methods = ('morlet', 'spectrogram', # this is output spectraldensity 'stft') # this is output complex if method not in implemented_methods: raise ValueError('Method ' + method + ' is not implemented yet.\n' 'Currently implemented methods are ' + ', '.join(implemented_methods)) if method == 'morlet': default_options = {'foi': None, 'ratio': 5, 'sigma_f': None, 'dur_in_sd': 4, 'dur_in_s': None, 'normalization': 'area', 'zero_mean': False, } elif method in ('spectrogram', 'stft'): default_options = {'duration': 1, 'overlap': 0.5, 'step': None, 'detrend': 'linear', 'taper': 'hann', 'sides': 'one', 'scaling': 'power', 'halfbandwidth': 2, 'NW': None, } default_options.update(options) options = default_options timefreq = ChanTimeFreq() timefreq.attr = deepcopy(data.attr) timefreq.s_freq = data.s_freq timefreq.start_time = data.start_time timefreq.axis['chan'] = data.axis['chan'] timefreq.axis['time'] = empty(data.number_of('trial'), dtype='O') timefreq.axis['freq'] = empty(data.number_of('trial'), dtype='O') if method == 'stft': timefreq.axis['taper'] = empty(data.number_of('trial'), dtype='O') timefreq.data = empty(data.number_of('trial'), dtype='O') if method == 'morlet': # we assume that the data is ChanTime assert data.index_of('chan') == 0 assert data.index_of('time') == 1 wavelets = _create_morlet(deepcopy(options), data.s_freq) for i in range(data.number_of('trial')): lg.info('Processing trial # {0: 6}'.format(i)) timefreq.axis['freq'][i] = array(options['foi']) timefreq.axis['time'][i] = data.axis['time'][i] timefreq.data[i] = empty((data.number_of('chan')[i], data.number_of('time')[i], len(options['foi'])), dtype='complex') data_i = data(trial=i) args = [] for i_chan in range(data.number_of('chan')[i]): for wavelet in wavelets: args.append((i_chan, wavelet)) with Pool() as p: result = p.starmap(partial(_convolve, dat=data_i), args) tf = reshape(array(result), (data.number_of('chan')[i], len(wavelets), -1)) timefreq.data[i] = moveaxis(tf, 2, 1) elif method in ('spectrogram', 'stft'): # TODO: add timeskip nperseg = int(options['duration'] * data.s_freq) if options['step'] is not None: nstep = int(options['step'] * data.s_freq) else: nstep = nperseg - int(options['overlap'] * nperseg) if method == 'spectrogram': output = 'spectraldensity' elif method == 'stft': output = 'complex' for i in range(data.number_of('trial')): t = _create_subepochs(data.time[i], nperseg, nstep).mean(axis=1) x = _create_subepochs(data(trial=i), nperseg, nstep) f, Sxx = _fft(x, s_freq=data.s_freq, detrend=options['detrend'], taper=options['taper'], output=output, sides=options['sides'], scaling=options['scaling'], halfbandwidth=options['halfbandwidth'], NW=options['NW']) timefreq.axis['time'][i] = t timefreq.axis['freq'][i] = f if method == 'stft': timefreq.axis['taper'][i] = arange(Sxx.shape[-1]) timefreq.data[i] = Sxx return timefreq
[docs]def band_power(data, freq, scaling='power', n_fft=None, detrend=None, array_out=False): """Compute power or energy acoss a frequency band, and its peak frequency. Power is estimated using the mid-point rectangle rule. Input can be ChanTime or ChanFreq. Parameters ---------- data : instance of ChanTime or ChanFreq data to be analyzed, one trial only freq : tuple of float Frequencies for band of interest. Power will be integrated across this band, inclusively, and peak frequency determined within it. If a value is None, the band is unbounded in that direction. input_type : str 'time' or 'spectrum' scaling : str 'power' or 'energy', only used if data is ChanTime n_fft : int length of FFT. if shorter than input signal, signal is truncated; if longer, signal is zero-padded to length array_out : bool if True, will return two arrays instead of two dict. Returns ------- dict of float, or ndarray keys are channels, values are power or energy dict of float, or ndarray keys are channels, values are respective peak frequency """ if not array_out: power = {} peakf = {} else: power = zeros((data.number_of('chan')[0], 1)) peakf = zeros((data.number_of('chan')[0], 1)) if isinstance(data, ChanFreq): Sxx = data elif isinstance(data, ChanTime): Sxx = frequency(data, scaling=scaling, n_fft=n_fft, detrend=detrend) else: raise ValueError('Invalid data type') if detrend is None: if 'power' == scaling: detrend = 'linear' elif 'energy' == scaling: detrend = None sf = Sxx.axis['freq'][0] f_res = sf[1] - sf[0] # frequency resolution if freq[0] is not None: idx_f1 = asarray([abs(x - freq[0]) for x in sf]).argmin() else: idx_f1 = 0 if freq[1] is not None: idx_f2 = min(asarray([abs(x - freq[1]) for x in sf]).argmin() + 1, len(sf) - 1) # inclusive, to follow convention else: idx_f2 = len(sf) - 1 for i, chan in enumerate(Sxx.axis['chan'][0]): s = Sxx(chan=chan)[0] pw = sum(s[idx_f1:idx_f2]) * f_res idx_peak = s[idx_f1:idx_f2].argmax() pf = sf[idx_f1:idx_f2][idx_peak] if array_out: power[i, 0] = pw peakf[i, 0] = pf else: power[chan] = pw peakf[chan] = pf return power, peakf
def _create_morlet(options, s_freq): """Create morlet wavelets, with scipy.signal doing the actual computation. Parameters ---------- foi : ndarray or list or tuple vector with frequency of interest s_freq : int or float sampling frequency of the data options : dict with 'M_in_s' (duration of the wavelet in seconds) and 'w' (Omega0) Returns ------- ndarray nFreq X nSamples matrix containing the complex Morlet wavelets. """ wavelets = [] foi = options.pop('foi') for f in foi: wavelets.append(morlet(f, s_freq, **options)) return wavelets
[docs]def morlet(freq, s_freq, ratio=5, sigma_f=None, dur_in_sd=4, dur_in_s=None, normalization='wonambi', zero_mean=False): """Create a Morlet wavelet. Parameters ---------- freq : float central frequency of the wavelet s_freq : int sampling frequency ratio : float ratio for a wavelet family ( = freq / sigma_f) sigma_f : float standard deviation of the wavelet in frequency domain dur_in_sd : float duration of the wavelet, given as number of the standard deviation in the time domain, in one side. dur_in_s : float total duration of the wavelet, two-sided (i.e. from start to finish) normalization : str 'wonambi' (default) returns an amplitude of 1 in frequency-domain for a sine wave of amplitude 1 in the time-domain; 'juniper' returns amplitude which is dependent on sampling frequency; 'area' normalizes the area of the Gaussian envelope to be 1; 'peak' normalizes the peak of the Gaussian envelope to be 1. Note that the frequency-domain values for 'area' and 'peak' will depend on the carrier frequency of the wavelet (they depend on sigma_f). zero_mean : bool make sure that the wavelet has zero mean (only relevant if ratio < 5) Returns ------- ndarray vector containing the complex Morlet wavelets Notes ----- 'ratio' and 'sigma_f' are mutually exclusive. If you use 'sigma_f', the standard deviation stays the same for all the frequency. It's more common to specify a constant ratio for the wavelet family, so that the frequency resolution changes with the frequency of interest. 'dur_in_sd' and 'dur_in_s' are mutually exclusive. 'dur_in_s' specifies the total duration (from start to finish) of the window. 'dur_in_sd' calculates the total duration as the length in standard deviations in the time domain: dur_in_s = dur_in_sd * 2 * sigma_t, with sigma_t = 1 / (2 * pi * sigma_f) """ if sigma_f is None: sigma_f = freq / ratio else: ratio = freq / sigma_f sigma_t = 1 / sigma_f if ratio < 5 and not zero_mean: lg.info('The wavelet won\'t have zero mean, set zero_mean=True to ' 'correct it') if dur_in_s is None: dur_in_s = sigma_t * dur_in_sd * 2 t = arange(-dur_in_s / 2, dur_in_s / 2, 1 / s_freq) w = exp(1j * 2 * pi * freq * t) if zero_mean: w -= exp(-1 / 2 * ratio ** 2) w *= exp(-t ** 2 / (2 * sigma_t ** 2)) if normalization == 'wonambi': w /= sqrt(pi / 2) * sigma_t * s_freq elif normalization == 'juniper': w /= sqrt(2 * pi) * sigma_t elif normalization == 'area': w /= sqrt(sqrt(pi) * sigma_t * s_freq) elif normalization == 'peak': pass lg.info('At freq {0: 9.3f}Hz, sigma_f={1: 9.3f}Hz, sigma_t={2: 9.3f}s, ' 'total duration={3: 9.3f}s'.format(freq, sigma_f, sigma_t, dur_in_s)) lg.debug(' Real peak={0: 9.3f}, Mean={1: 12.6f}, ' 'Energy={2: 9.3f}'.format(max(real(w)), mean(w), norm(w) ** 2)) return w
def _fft(x, s_freq, detrend='linear', taper=None, output='spectraldensity', sides='one', scaling='power', halfbandwidth=4, NW=None, n_fft=None): """ Core function taking care of computing the power spectrum / power spectral density or the complex representation. Parameters ---------- x : 1d or 2d numpy array input data (fft will be computed on the last dimension) s_freq : int sampling frequency detrend : str None (no detrending), 'constant' (remove mean), 'linear' (remove linear trend) output : str 'spectraldensity' (= 'psd' in scipy) or 'complex' (for complex output) sides : str 'one' or 'two', where 'two' implies negative frequencies scaling : str 'power' (= 'density' in scipy, units: uV ** 2 / Hz), 'energy' (= 'spectrum' in scipy, units: uV ** 2), 'fieldtrip', 'chronux' taper : str Taper to use, commonly used tapers are 'boxcar', 'hann', 'dpss' (see below) halfbandwidth : int (only if taper='dpss') Half bandwidth (in Hz), frequency smoothing will be from +halfbandwidth to -halfbandwidth NW : int (only if taper='dpss') Normalized half bandwidth (NW = halfbandwidth * dur). Number of DPSS tapers is 2 * NW - 1. If specified, NW takes precedence over halfbandwidth n_fft: int Length of FFT, in samples. If less than input axis, input is cropped. If longer than input axis, input is padded with zeros. If None, FFT length set to axis length. Returns ------- freqs : 1d ndarray vector with frequencies at which the PSD / ESD / complex fourier was computed result: ndarray PSD / ESD / complex fourier. It has the same number of dim as the input. Frequency transform is computed on the last dimension. If output='complex', there is one additional dimension with the taper(s). Notes ----- The nomenclature of the frequency-domain analysis is not very consistent across packages / toolboxes. The convention used here is based on `wikipedia`_ So, you can have the spectral density (called sometimes power spectrum) or a complex output. Conceptually quite different but they can both be computed using the fft algorithm, so we do both here. Regarding the spectral density, you can have the power spectral density (PSD) or the energy spectral density (ESD). PSD should be used for stationary signals (gamma activity), while ESD should be used for signals that have a clear beginning and end (spindles). ESD gives the energy over the whole duration of the input window, while PSD is normalized by the window length. Parseval's theorem says that the energy of the signal in the time-domain must be equal to the energy in the frequency domain. All the tapers are correct to comply with this theorem (see tests/test_trans_frequency.py for all the examples). Note that packages such as 'FieldTrip' and 'Chronux' do not usually respect this convention (and use some ad-hoc convention). You can use the scaling of these packages to compare the results from those matlab toolboxes, but note that the results probably don't satisty Parseval's theorem. Note that scipy.signal is not consistent with these names, but the formulas are the same. Also, scipy (v1.1 at least) does not handle dpss. Finally, the complex output has an additional dimension (taper), for each taper (even for the boxcar or hann taper). This is useful for multitaper analysis (DPSS), where it doesn't make sense to average complex results. .. _wikipedia: https://en.wikipedia.org/wiki/Spectral_density TODO ---- Scipy v1.1 can generate dpss tapers. Once scipy v1.1 is available, use that instead of the extern folder. """ if output == 'complex' and sides == 'one': print('complex always returns both sides') sides = 'two' axis = x.ndim - 1 n_smp = x.shape[axis] if n_fft is None: n_fft = n_smp if sides == 'one': freqs = np_fft.rfftfreq(n_fft, 1 / s_freq) elif sides == 'two': freqs = fftpack.fftfreq(n_fft, 1 / s_freq) if taper is None: taper = 'boxcar' if taper == 'dpss': if NW is None: NW = halfbandwidth * n_smp / s_freq tapers, eig = dpss_windows(n_smp, NW, 2 * NW - 1) if scaling == 'chronux': tapers *= sqrt(s_freq) else: if taper == 'hann': tapers = windows.hann(n_smp, sym=False)[None, :] else: # TODO: it'd be nice to use sym=False if possible, but the difference is very small tapers = get_window(taper, n_smp)[None, :] if scaling == 'energy': rms = sqrt(mean(tapers ** 2)) tapers /= rms * sqrt(n_smp) elif scaling != 'chronux': # idk how chronux treats other windows apart from dpss tapers /= norm(tapers) if detrend is not None: has_nan = isnan(x).any(axis=axis) if has_nan.any(): x = x.copy() x[has_nan] = 0 x = detrend_func(x, axis=axis, type=detrend) if has_nan.any(): x[has_nan] = NaN tapered = tapers * x[..., None, :] if sides == 'one': result = np_fft.rfft(tapered, n=n_fft) elif sides == 'two': result = fftpack.fft(tapered, n=n_fft) if scaling == 'chronux': result /= s_freq elif scaling == 'fieldtrip': result *= sqrt(2 / n_smp) if output == 'spectraldensity': result = (result.conj() * result) elif output == 'csd': result = (result[None, 0, ...].conj() * result[None, 1, ...]) if (sides == 'one' and output in ('spectraldensity', 'csd') and scaling != 'chronux'): if n_fft % 2: result[..., 1:] *= 2 else: # Last point is unpaired Nyquist freq point, don't double result[..., 1:-1] *= 2 if scaling == 'power': scale = 1.0 / s_freq elif scaling == 'energy': scale = 1.0 / n_smp else: scale = 1 if output == 'complex' and scaling in ('power', 'energy'): scale = sqrt(scale) result *= scale if scaling == 'fieldtrip' and output in ('spectraldensity', 'csd'): # fieldtrip uses only one side result /= 2 if output in ('spectraldensity', 'csd'): if output == 'spectraldensity': result = result.real result = mean(result, axis=axis) elif output == 'complex': # dpss should be last dimension in complex, no mean result = swapaxes(result, axis, -1) return freqs, result def _convolve(i_chan, wavelet, dat): tf = fftconvolve(dat[i_chan, :], wavelet, 'same') return tf