Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/dev/13754.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix ``mode`` parameter in :func:`~mne.minimum_norm.get_point_spread` and :func:`~mne.minimum_norm.get_cross_talk` to correctly map public mode names (``'max'``, ``'svd'``) to internal names, expose previously hidden modes ``'sum'`` and ``'maxval'`` in the public API, and raise :class:`ValueError` for invalid mode values, by :newcontrib:`Famous077`.
19 changes: 15 additions & 4 deletions mne/minimum_norm/resolution_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,16 @@ def _get_psf_ctf(
# summarise PSFs/CTFs across vertices if requested
pca_var = None # variances computed only if return_pca_vars=True
if mode is not None:
# mapping public mode names to internal names
_mode_map = {
"max": "maxnorm",
"svd": "pca",
"mean": "mean",
"sum": "sum",
"maxval": "maxval",
}
funcs, pca_var = _summarise_psf_ctf(
funcs, mode, n_comp, return_pca_vars, nn
funcs, _mode_map[mode], n_comp, return_pca_vars, nn
)

if not vector: # if one value per vertex requested
Expand Down Expand Up @@ -193,11 +201,14 @@ def _get_psf_ctf(

def _check_get_psf_ctf_params(mode, n_comp, return_pca_vars):
"""Check input parameters of _get_psf_ctf() for consistency."""
if mode in [None, "sum", "mean"] and n_comp > 1:
valid_modes = (None, "mean", "max", "svd", "sum", "maxval")
if mode not in valid_modes:
raise ValueError(f"mode must be one of {valid_modes}, got {mode!r} instead.")
if mode in [None, "mean", "sum"] and n_comp > 1:
msg = f"n_comp must be 1 for mode={mode}."
raise ValueError(msg)
if mode != "pca" and return_pca_vars:
msg = "SVD variances can only be returned if mode=pca."
if mode != "svd" and return_pca_vars:
msg = "SVD variances can only be returned if mode='svd'."
raise ValueError(msg)


Expand Down
14 changes: 7 additions & 7 deletions mne/minimum_norm/tests/test_resolution_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ def test_resolution_matrix_free(src_type, fwd_volume_small):
)
assert_array_almost_equal(rm_mne_free, rm_mne_free.T)
# check various summary and normalisation options
for mode in [None, "sum", "mean", "maxval", "maxnorm", "pca"]:
for mode in [None, "mean", "max", "svd", "sum", "maxval"]:
n_comps = [1, 3]
if mode in [None, "sum", "mean"]:
if mode in [None, "mean", "sum"]:
n_comps = [1]
for n_comp in n_comps:
for norm in [None, "max", "norm", True]:
Expand Down Expand Up @@ -114,7 +114,7 @@ def test_resolution_matrix_free(src_type, fwd_volume_small):
# There is an ambiguity in the sign flip from the PCA here.
# Ideally we would use the normals to fix it, but it's not
# trivial.
if mode == "pca" and n_comp == 3:
if mode == "svd" and n_comp == 3:
stc_psf_free = abs(stc_psf_free)
stc_ctf_free = abs(stc_psf_free)
assert_array_almost_equal(
Expand Down Expand Up @@ -184,9 +184,9 @@ def test_resolution_matrix_fixed():
# Some arbitrary vertex numbers
idx = [1, 100, 400]
# check various summary and normalisation options
for mode in [None, "sum", "mean", "maxval", "maxnorm", "pca"]:
for mode in [None, "mean", "max", "svd", "sum", "maxval"]:
n_comps = [1, 3]
if mode in [None, "sum", "mean"]:
if mode in [None, "mean", "sum"]:
n_comps = [1]
for n_comp in n_comps:
for norm in [None, "max", "norm", True]:
Expand Down Expand Up @@ -217,7 +217,7 @@ def test_resolution_matrix_fixed():
rm_mne,
forward_fxd["src"],
idx,
mode=mode,
mode="svd",
n_comp=n_comp,
norm="norm",
return_pca_vars=True,
Expand All @@ -226,7 +226,7 @@ def test_resolution_matrix_fixed():
rm_mne,
forward_fxd["src"],
idx,
mode=mode,
mode="svd",
n_comp=n_comp,
norm="norm",
return_pca_vars=True,
Expand Down
Loading