diff --git a/CITATION.cff b/CITATION.cff index 1da3c61..a90ecc3 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -55,6 +55,10 @@ authors: family-names: Veillette affiliation: 'Department of Psychology, University of Chicago, Chicago, IL, USA' orcid: 'https://orcid.org/0000-0002-0332-4372' + - given-names: Roy Eric + family-names: Wieske + affiliation: 'Biopsychology and Neuroergonomics, Technische Universität Berlin, Berlin, Germany' + orcid: 'https://orcid.org/0009-0006-2018-1074' type: software repository-code: 'https://github.com/sappelhoff/pyprep' license: MIT diff --git a/docs/authors.rst b/docs/authors.rst index 0ddb8f6..bccf1ed 100644 --- a/docs/authors.rst +++ b/docs/authors.rst @@ -11,3 +11,4 @@ .. _Victor Xiang: https://github.com/Nick3151 .. _Yorguin Mantilla: https://github.com/yjmantilla .. _John Veillette: https://github.com/john-veillette +.. _Roy Eric Wieske: https://github.com/Randomidous diff --git a/docs/changelog.rst b/docs/changelog.rst index 0d7a00f..e72a62a 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -26,6 +26,7 @@ Version 0.6.0 (unreleased) Changelog ~~~~~~~~~ +- Added :meth:`~pyprep.NoisyChannels.find_bad_by_PSD` method for detecting channels with abnormally high power spectral density.
This is a PyPREP-only feature not present in MATLAB PREP, by `Roy Eric Wieske`_ (:gh:`145`) - Users can now determine whether or not to use ``correlation`` as a method for finding bad channels in :meth:`~pyprep.NoisyChannels.find_all_bads` (defaults to True), by `Stefan Appelhoff`_ (:gh:`169`) - Manually marked bad channels are ignored for finding further bads (just like NaN and flat channels) in :meth:`~pyprep.NoisyChannels.find_all_bads`, by `Stefan Appelhoff`_ (:gh:`168`) diff --git a/docs/matlab_differences.rst b/docs/matlab_differences.rst index 9e3b668..10d0f4d 100644 --- a/docs/matlab_differences.rst +++ b/docs/matlab_differences.rst @@ -233,6 +233,34 @@ MATLAB PREP, PyPREP will use a Python reimplementation of ``eeg_interp`` instead when the ``matlab_strict`` parameter is set to ``True``. +PyPREP-Only Features +-------------------- + +The following features are available in PyPREP but are not present in the +original MATLAB PREP implementation. + + +Bad channel detection by PSD +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The :meth:`~pyprep.NoisyChannels.find_bad_by_PSD` method detects channels with +abnormally high power spectral density (PSD) compared to other channels. +This method is not part of the original MATLAB PREP pipeline, but can be +considered a refinement of the ``bad_by_hfnoise`` detection in MATLAB PREP, +which flags channels based on the ratio of high-frequency power (>50 Hz) to +total power. + +A channel is considered "bad-by-PSD" if its PSD (computed using Welch's +method over a configurable frequency range, defaulting to 1-45 Hz to exclude +line noise) is abnormally high in any of three bands (1-15, 15-30, 30-45 Hz) +by robust Z-scoring based on the median absolute deviation (MAD), or if its +high band has more power than its low band. Low power alone is not flagged. + +This method is called by :meth:`~pyprep.NoisyChannels.find_all_bads` by default, +but is skipped when ``matlab_strict=True`` to maintain equivalence with the +original MATLAB PREP pipeline.
def find_bad_by_PSD(self, zscore_threshold=3.0, fmin=1.0, fmax=45.0):
    """Detect channels with abnormally high power spectral density.

    This is a PyPREP-only method not present in the original MATLAB PREP.

    The PSD of every usable channel is computed with Welch's method on the
    data returned by ``_get_filtered_data``, converted to decibels, and
    averaged within three frequency bands (low: ``fmin``-15 Hz,
    mid: 15-30 Hz, high: 30-``fmax`` Hz). A channel is considered
    "bad-by-psd" if:

    1. Its power in any band is abnormally HIGH compared to other channels
       (robust z-score greater than ``zscore_threshold``), OR
    2. Its high-frequency band has more power than its low-frequency band
       (violating the typical 1/f spectral profile of EEG).

    Only excess power (positive z-scores) is flagged, as abnormally low
    power could reflect normal topographic variation.

    The default range (1-45 Hz) excludes line noise frequencies (50/60 Hz).

    Parameters
    ----------
    zscore_threshold : float, optional
        The robust z-score of a channel's band power above which the channel
        is considered bad-by-psd. Defaults to ``3.0``.
    fmin : float, optional
        The lower frequency bound (in Hz) for PSD computation.
        Defaults to ``1.0``.
    fmax : float, optional
        The upper frequency bound (in Hz) for PSD computation. The default
        of ``45.0`` excludes 50/60 Hz line noise from the analysis.

    Raises
    ------
    ValueError
        If ``fmin`` is not smaller than ``fmax``.

    Notes
    -----
    Band power is the *mean* (not the sum) of the log-PSD within a band, so
    that the comparison between bands depends neither on how many frequency
    bins each band happens to contain nor on the units of the data. Robust
    z-scores across channels are unaffected by this choice.

    A band that contains no frequency bins (e.g. the high band when
    ``fmax`` is below 30 Hz) is reported as NaN, gets z-scores of zero, and
    disables the 1/f criterion instead of flagging every channel.

    The names of the detected channels are stored in ``self.bad_by_psd``.
    In ``self._extra_info["bad_by_psd"]``, ``"psd_zscore"`` has one entry per
    original channel, whereas all other arrays have one entry per usable
    channel (i.e. per row of the filtered data).

    """
    MAD_TO_SD = 1.4826  # Scales units of MAD to units of SD, assuming normality
    # Reference: https://stat.ethz.ch/R-manual/R-devel/library/stats/html/mad.html

    # Fixed inner band edges (in Hz); the outer edges are fmin and fmax
    LOW_MID_EDGE = 15.0  # low band ~ delta, theta, alpha
    MID_HIGH_EDGE = 30.0  # mid band ~ beta, high band ~ gamma

    if fmin >= fmax:
        raise ValueError(f"fmin ({fmin}) must be smaller than fmax ({fmax}).")

    if self.EEGFiltered is None:
        self.EEGFiltered = self._get_filtered_data()

    # Create a temporary Raw object from filtered data for PSD computation
    info = mne.create_info(
        ch_names=self.ch_names_new.tolist(),
        sfreq=self.sample_rate,
        ch_types="eeg",
    )
    raw_filtered = mne.io.RawArray(self.EEGFiltered, info, verbose=False)

    # Compute PSD using Welch method and convert to log scale (dB)
    spectrum = raw_filtered.compute_psd(
        method="welch", fmin=fmin, fmax=fmax, verbose=False
    )
    psd_data = spectrum.get_data()
    freqs = spectrum.freqs
    # Clip at the smallest positive float so exact zeros do not become -inf,
    # which would otherwise turn the median/MAD below into NaN
    log_psd = 10 * np.log10(np.maximum(psd_data, np.finfo(float).tiny))
    n_usable = log_psd.shape[0]

    # Boolean frequency masks for each band (the high band includes fmax)
    band_masks = {
        "low": (freqs >= fmin) & (freqs < LOW_MID_EDGE),
        "mid": (freqs >= LOW_MID_EDGE) & (freqs < MID_HIGH_EDGE),
        "high": (freqs >= MID_HIGH_EDGE) & (freqs <= fmax),
    }

    def band_mean(mask):
        """Average the log-PSD over a band, or return NaN if it is empty."""
        if not mask.any():
            return np.full(n_usable, np.nan)
        return log_psd[:, mask].mean(axis=1)

    def robust_zscore(values):
        """Compute robust z-scores using MAD (zeros if undefined)."""
        if values.size == 0 or not np.all(np.isfinite(values)):
            return np.zeros_like(values)
        median = np.median(values)
        sd = np.median(np.abs(values - median)) * MAD_TO_SD
        if sd > 0:
            return (values - median) / sd
        return np.zeros_like(values)

    band_power = {name: band_mean(mask) for name, mask in band_masks.items()}
    band_zscore = {name: robust_zscore(pw) for name, pw in band_power.items()}

    # Criterion 1: Outlier with abnormally HIGH power in any band
    # Note: Only positive z-scores (excess power) are flagged, as low power
    # could reflect normal topographic variation rather than a bad channel
    bad_by_band = np.zeros(n_usable, dtype=bool)
    for zscores in band_zscore.values():
        bad_by_band |= zscores > zscore_threshold

    # Criterion 2: 1/f violation (high band has more power than low band)
    # This is unusual for normal EEG and suggests muscle artifact or bad
    # contact. It can only be evaluated if both bands contain frequencies.
    if band_masks["low"].any() and band_masks["high"].any():
        bad_by_1f_violation = band_power["high"] > band_power["low"]
    else:
        bad_by_1f_violation = np.zeros(n_usable, dtype=bool)

    # Criterion 3 (diagnostic only): abnormal band ratios across channels
    # Use small epsilon to avoid division by zero
    eps = np.finfo(float).eps
    with np.errstate(divide="ignore", invalid="ignore"):
        band_ratios = (
            band_power["low"] / (band_power["mid"] + eps),
            band_power["low"] / (band_power["high"] + eps),
            band_power["mid"] / (band_power["high"] + eps),
        )
    bad_by_ratio = np.zeros(n_usable, dtype=bool)
    for ratio in band_ratios:
        bad_by_ratio |= np.abs(robust_zscore(ratio)) > zscore_threshold

    # Combine criteria (bad if ANY criterion is met)
    # Note: bad_by_ratio is computed for diagnostics but not used in final
    # decision as it tends to be overly sensitive and theoretically debatable
    bad_by_psd_usable = bad_by_band | bad_by_1f_violation

    # Map back to original channel indices
    psd_channel_mask = np.zeros(self.n_chans_original, dtype=bool)
    psd_channel_mask[self.usable_idx] = bad_by_psd_usable
    abnormal_psd_channels = self.ch_names_original[psd_channel_mask]

    # Compute combined z-score for reporting (max absolute z-score across bands)
    psd_zscore = np.zeros(self.n_chans_original)
    psd_zscore[self.usable_idx] = np.max(
        np.abs([band_zscore[name] for name in ("low", "mid", "high")]), axis=0
    )

    # Update names of bad channels by abnormal PSD & save additional info
    self.bad_by_psd = abnormal_psd_channels.tolist()
    self._extra_info["bad_by_psd"].update(
        {
            "psd_zscore": psd_zscore,
            "band_power_low": band_power["low"],
            "band_power_mid": band_power["mid"],
            "band_power_high": band_power["high"],
            "zscore_low": band_zscore["low"],
            "zscore_mid": band_zscore["mid"],
            "zscore_high": band_zscore["high"],
            "bad_by_band": bad_by_band,
            "bad_by_1f_violation": bad_by_1f_violation,
            "bad_by_ratio": bad_by_ratio,
        }
    )
def test_bad_by_PSD_1f_violation(raw_tmp):
    """Check that a channel breaking the 1/f spectral profile is bad-by-PSD."""
    channel_count = raw_tmp.get_data().shape[0]
    target_idx = int(rng.integers(0, channel_count, 1)[0])
    target_name = raw_tmp.ch_names[target_idx]

    # Overwrite the chosen channel with a strong 32-44 Hz signal, so that it
    # carries more power in the 30-45 Hz band than in the 1-15 Hz band
    # (the opposite of what normal EEG looks like)
    raw_tmp._data[target_idx, :] = (
        _generate_signal(32, 44, raw_tmp.times, fcount=10) * 50
    )

    detector = NoisyChannels(raw_tmp, do_detrend=False)
    detector.find_bad_by_PSD()

    # The manipulated channel must be reported as bad-by-PSD
    assert target_name in detector.bad_by_psd

    # The per-criterion flags are stored per usable channel, so translate the
    # channel name into its position among the usable channels first
    psd_info = detector._extra_info["bad_by_psd"]
    usable_names = [raw_tmp.ch_names[i] for i in np.where(detector.usable_idx)[0]]
    if target_name in usable_names:
        assert psd_info["bad_by_1f_violation"][usable_names.index(target_name)]