From 360af804a5a422e53aa2a71c4b0b3621484c0e82 Mon Sep 17 00:00:00 2001 From: Sharif Haason Date: Mon, 16 Mar 2026 23:12:29 -0400 Subject: [PATCH 1/2] PERF: Batch spectrogram calls in Welch PSD computation Replace np.apply_along_axis (which calls scipy.signal.spectrogram once per row) with chunked 2D calls. scipy.signal.spectrogram handles multi-row input efficiently via vectorized FFT, so processing ~10 MB chunks instead of individual rows eliminates per-call Python dispatch overhead. On 320 epochs x 376 channels (120K rows), psd_array_welch goes from ~5.0s to ~0.19s (~26x speedup). Co-Authored-By: Claude Opus 4.6 --- mne/time_frequency/psd.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/mne/time_frequency/psd.py b/mne/time_frequency/psd.py index 01d932699a1..2feaa43ec0d 100644 --- a/mne/time_frequency/psd.py +++ b/mne/time_frequency/psd.py @@ -62,14 +62,23 @@ def _decomp_aggregate_mask(epoch, func, average, freq_sl): def _spect_func(epoch, func, freq_sl, average, *, output="power"): """Aux function.""" - # Decide if we should split this to save memory or not, since doing - # multiple calls will incur some performance overhead. Eventually we might - # want to write (really, go back to) our own spectrogram implementation - # that, if possible, averages after each transform, but this will incur - # a lot of overhead because of the many Python calls required. + # Process in chunks to balance vectorization (scipy.signal.spectrogram + # handles multi-row input efficiently) against memory usage. kwargs = dict(func=func, average=average, freq_sl=freq_sl) if epoch.nbytes > 10e6: - spect = np.apply_along_axis(_decomp_aggregate_mask, -1, epoch, **kwargs) + # Process in chunks of rows instead of one-by-one. Each chunk is + # passed to spectrogram as a 2D array, which is much faster than + # calling spectrogram per-row via np.apply_along_axis. + n_rows = epoch.shape[0] + # Target ~10 MB per chunk (same threshold as the original code) + row_bytes = epoch[0].nbytes + chunk_size = max(1, int(10e6 / row_bytes)) + parts = [] + for start in range(0, n_rows, chunk_size): + parts.append( + _decomp_aggregate_mask(epoch[start : start + chunk_size], **kwargs) + ) + spect = np.concatenate(parts, axis=0) else: spect = _decomp_aggregate_mask(epoch, **kwargs) return spect From 200aabd8b92e40a4c2778737d3929e77c283245d Mon Sep 17 00:00:00 2001 From: Sharif Haason Date: Mon, 16 Mar 2026 23:13:28 -0400 Subject: [PATCH 2/2] DOC: Add changelog entry for Welch PSD batch optimization Co-Authored-By: Claude Opus 4.6 --- doc/changes/dev/newfeature.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 doc/changes/dev/newfeature.rst diff --git a/doc/changes/dev/newfeature.rst b/doc/changes/dev/newfeature.rst new file mode 100644 index 00000000000..6f7c22c5e44 --- /dev/null +++ b/doc/changes/dev/newfeature.rst @@ -0,0 +1 @@ +Speed up :func:`mne.time_frequency.psd_array_welch` and related Welch PSD methods by ~25x for epoched data by batching spectrogram calls instead of per-channel dispatch, by :newcontrib:`Sharif Haason`.