diff --git a/mne/viz/eyetracking/__init__.py b/mne/viz/eyetracking/__init__.py index 04bbb1c7216..c3547d0dcf8 100644 --- a/mne/viz/eyetracking/__init__.py +++ b/mne/viz/eyetracking/__init__.py @@ -4,4 +4,4 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. -from .heatmap import plot_gaze +from .heatmap import plot_gaze, aoi_dwell_time diff --git a/mne/viz/eyetracking/heatmap.py b/mne/viz/eyetracking/heatmap.py index 46aa34b7d31..790d026bda8 100644 --- a/mne/viz/eyetracking/heatmap.py +++ b/mne/viz/eyetracking/heatmap.py @@ -11,6 +11,68 @@ @fill_doc +def aoi_dwell_time( + epochs, + *, + xrange=None, + yrange=None, + avg_samples=False, +): + """ + Compute total dwell time within a rectangular area of interest (AOI). + + Parameters + ---------- + epochs : mne.Epochs + Epochs object containing gaze data and sampling frequency. + xrange : tuple of int | None + Minimum and maximum x-coordinates defining the AOI (xmin, xmax). + yrange : tuple of int | None + Minimum and maximum y-coordinates defining the AOI (ymin, ymax). + avg_samples : bool + If True, return the mean dwell time across trials. If False, return + dwell time per trial. + + Returns + ------- + dwell_times : ndarray, shape (n_trials,) | float + Total dwell time in seconds. Returns an array of per-trial dwell times + if `avg_samples=False`, or a single float (mean dwell time across trials) + if `avg_samples=True`. + """ + from mne._fiff.pick import _picks_to_idx + + # Get the gaze data + pos_picks = _picks_to_idx(epochs.info, "eyegaze") + gaze_data = epochs.get_data(picks=pos_picks) + gaze_ch_loc = np.array([epochs.info["chs"][idx]["loc"] for idx in pos_picks]) + x_data = gaze_data[:, np.where(gaze_ch_loc[:, 4] == -1)[0], :] + y_data = gaze_data[:, np.where(gaze_ch_loc[:, 4] == 1)[0], :] + + if x_data.shape[1] > 1: # binocular recording. Average across eyes + logger.info("Detected binocular recording. Averaging positions across eyes.") + x_data = np.nanmean(x_data, axis=1) # shape (n_epochs, n_samples) + y_data = np.nanmean(y_data, axis=1) + + # check if outside the aoi range + aoi_hitboxes = ( + (x_data >= xrange[0]) + & (x_data <= xrange[1]) + & (y_data >= yrange[0]) + & (y_data <= yrange[1]) + ).astype(int) + + if not avg_samples: + # aoi total dwell time per sample + dwell_times = np.sum(np.squeeze(aoi_hitboxes), axis=1) / epochs.info["sfreq"] + else: + dwell_times = ( + np.mean(np.sum(np.squeeze(aoi_hitboxes), axis=1)) / epochs.info["sfreq"] + ) + + return dwell_times + + def plot_gaze( epochs, *,