Skip to content

Commit 24a5634

Browse files
authored
[ENH] return input data if score is negative on first iter (#94)
1 parent ac12d71 commit 24a5634

2 files changed

Lines changed: 36 additions & 7 deletions

File tree

meegkit/dss.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -291,10 +291,10 @@ def dss_line_iter(data, fline, sfreq, win_sz=10, spot_sz=2.5,
291291
show: bool
292292
Produce a visual output of each iteration (default=False).
293293
dirname: str
294-
Path to the directory where visual outputs are saved when show is 'True'.
294+
Path to the directory where visual outputs are saved when show is 'True'.
295295
If 'None', does not save the outputs. (default=None)
296296
extension: str
297-
Extension of the images filenames. Must be compatible with plt.savefig()
297+
Extension of the images filenames. Must be compatible with plt.savefig()
298298
function. (default=".png")
299299
n_iter_max : int
300300
Maximum number of iterations (default=100).
@@ -313,11 +313,13 @@ def nan_basic_interp(array):
313313
array[nans] = np.interp(ix(nans), ix(~nans), array[~nans])
314314
return array
315315

316+
data_clean = data.copy()
317+
316318
freq_rn = [fline - win_sz, fline + win_sz]
317319
freq_sp = [fline - spot_sz, fline + spot_sz]
318320
freq, psd = welch(data, fs=sfreq, nfft=nfft, axis=0)
319321

320-
freq_rn_ix = np.logical_and(freq >= freq_rn[0],
322+
freq_rn_ix = np.logical_and(freq >= freq_rn[0],
321323
freq <= freq_rn[1])
322324
freq_used = freq[freq_rn_ix]
323325
freq_sp_ix = np.logical_and(freq_used >= freq_sp[0],
@@ -338,8 +340,8 @@ def nan_basic_interp(array):
338340
aggr_resid = []
339341
iterations = 0
340342
while iterations < n_iter_max:
341-
data, _ = dss_line(data, fline, sfreq, nfft=nfft, nremove=1)
342-
freq, psd = welch(data, fs=sfreq, nfft=nfft, axis=0)
343+
data_clean, _ = dss_line(data_clean, fline, sfreq, nfft=nfft, nremove=1)
344+
freq, psd = welch(data_clean, fs=sfreq, nfft=nfft, axis=0)
343345
if psd.ndim == 3:
344346
mean_psd = np.mean(psd, axis=(1, 2))[freq_rn_ix]
345347
elif psd.ndim == 2:
@@ -366,7 +368,7 @@ def nan_basic_interp(array):
366368
ax.flat[0].set_xlabel("Frequency (Hz)")
367369
ax.flat[0].set_ylabel("Power")
368370

369-
ax.flat[1].plot(freq_used, mean_psd_tf, c="gray",
371+
ax.flat[1].plot(freq_used, mean_psd_tf, c="gray",
370372
label="Interpolated mean PSD")
371373
ax.flat[1].plot(freq_used, mean_psd, c="blue", label="Mean PSD")
372374
ax.flat[1].plot(freq_used, clean_fit_line, c="red", label="Fitted polynomial")
@@ -396,6 +398,9 @@ def nan_basic_interp(array):
396398
plt.show()
397399

398400
if mean_score <= 0:
401+
# Return original data if score is negative on first iteration
402+
if iterations == 0:
403+
return data, 0
399404
break
400405

401406
iterations += 1
@@ -404,4 +409,4 @@ def nan_basic_interp(array):
404409
raise RuntimeError("Could not converge. Consider increasing the "
405410
"maximum number of iterations")
406411

407-
return data, iterations
412+
return data_clean, iterations

tests/test_dss.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,30 @@ def _plot(before, after):
176176
plt.close("all")
177177

178178

179+
def test_dss_line_iter_no_noise():
180+
"""
181+
Test that dss_line_iter returns original data unchanged when DSS
182+
cannot improve the signal.
183+
"""
184+
sr = 200
185+
fline = 50
186+
n_samples = 9000
187+
n_chans = 10
188+
rng = np.random.RandomState(42)
189+
190+
# create data without line noise at target frequency
191+
x = rng.randn(n_samples, n_chans)
192+
x_original = x.copy()
193+
194+
x_out, n_iters = dss.dss_line_iter(x, fline, sr, n_iter_max=10)
195+
196+
assert n_iters == 0, f"Expected 0 iterations (no improvement), got {n_iters}"
197+
assert np.allclose(x_out, x_original), (
198+
"When DSS cannot improve signal, should return original data unchanged"
199+
)
200+
assert np.allclose(x, x_original), "Input data should never be mutated"
201+
202+
179203
def profile_dss_line(nkeep):
180204
"""Test line noise removal."""
181205
import cProfile

0 commit comments

Comments
 (0)