@@ -291,10 +291,10 @@ def dss_line_iter(data, fline, sfreq, win_sz=10, spot_sz=2.5,
291291 show: bool
292292 Produce a visual output of each iteration (default=False).
293293 dirname: str
294- Path to the directory where visual outputs are saved when show is 'True'.
294+ Path to the directory where visual outputs are saved when show is 'True'.
295295 If 'None', does not save the outputs. (default=None)
296296 extension: str
297- Extension of the images filenames. Must be compatible with plt.savefig()
297+ Extension of the images filenames. Must be compatible with plt.savefig()
298298 function. (default=".png")
299299 n_iter_max : int
300300 Maximum number of iterations (default=100).
@@ -313,11 +313,13 @@ def nan_basic_interp(array):
313313 array [nans ] = np .interp (ix (nans ), ix (~ nans ), array [~ nans ])
314314 return array
315315
316+ data_clean = data .copy ()
317+
316318 freq_rn = [fline - win_sz , fline + win_sz ]
317319 freq_sp = [fline - spot_sz , fline + spot_sz ]
318320 freq , psd = welch (data , fs = sfreq , nfft = nfft , axis = 0 )
319321
320- freq_rn_ix = np .logical_and (freq >= freq_rn [0 ],
322+ freq_rn_ix = np .logical_and (freq >= freq_rn [0 ],
321323 freq <= freq_rn [1 ])
322324 freq_used = freq [freq_rn_ix ]
323325 freq_sp_ix = np .logical_and (freq_used >= freq_sp [0 ],
@@ -338,8 +340,8 @@ def nan_basic_interp(array):
338340 aggr_resid = []
339341 iterations = 0
340342 while iterations < n_iter_max :
341- data , _ = dss_line (data , fline , sfreq , nfft = nfft , nremove = 1 )
342- freq , psd = welch (data , fs = sfreq , nfft = nfft , axis = 0 )
343+ data_clean , _ = dss_line (data_clean , fline , sfreq , nfft = nfft , nremove = 1 )
344+ freq , psd = welch (data_clean , fs = sfreq , nfft = nfft , axis = 0 )
343345 if psd .ndim == 3 :
344346 mean_psd = np .mean (psd , axis = (1 , 2 ))[freq_rn_ix ]
345347 elif psd .ndim == 2 :
@@ -366,7 +368,7 @@ def nan_basic_interp(array):
366368 ax .flat [0 ].set_xlabel ("Frequency (Hz)" )
367369 ax .flat [0 ].set_ylabel ("Power" )
368370
369- ax .flat [1 ].plot (freq_used , mean_psd_tf , c = "gray" ,
371+ ax .flat [1 ].plot (freq_used , mean_psd_tf , c = "gray" ,
370372 label = "Interpolated mean PSD" )
371373 ax .flat [1 ].plot (freq_used , mean_psd , c = "blue" , label = "Mean PSD" )
372374 ax .flat [1 ].plot (freq_used , clean_fit_line , c = "red" , label = "Fitted polynomial" )
@@ -396,6 +398,9 @@ def nan_basic_interp(array):
396398 plt .show ()
397399
398400 if mean_score <= 0 :
401+ # Return original data if score is negative on first iteration
402+ if iterations == 0 :
403+ return data , 0
399404 break
400405
401406 iterations += 1
@@ -404,4 +409,4 @@ def nan_basic_interp(array):
404409 raise RuntimeError ("Could not converge. Consider increasing the "
405410 "maximum number of iterations" )
406411
407- return data , iterations
412+ return data_clean , iterations
0 commit comments