src/spikeinterface/metrics/quality/pca_metrics.py (33 additions, 64 deletions)
@@ -103,57 +103,27 @@ def _nn_one_unit(args):
     return unit_id, nn_hit_rate, nn_miss_rate


-def _nearest_neighbor_metric_function(sorting_analyzer, unit_ids, tmp_data, job_kwargs, **metric_params):
+def _nearest_neighbor_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params):
     nn_result = namedtuple("NearestNeighborResult", ["nn_hit_rate", "nn_miss_rate"])

     # Use pre-computed PCA data
     pca_data_per_unit = tmp_data["pca_data_per_unit"]

-    # Extract job parameters
-    n_jobs = job_kwargs.get("n_jobs", 1)
-    mp_context = job_kwargs.get("mp_context", None)
-
     nn_hit_rate_dict = {}
     nn_miss_rate_dict = {}

-    if n_jobs == 1:
-        # Sequential processing
-        units_loop = unit_ids
-
-        for unit_id in units_loop:
-            pcs_flat = pca_data_per_unit[unit_id]["pcs_flat"]
-            labels = pca_data_per_unit[unit_id]["labels"]
-
-            try:
-                nn_hit_rate, nn_miss_rate = nearest_neighbors_metrics(pcs_flat, labels, unit_id, **metric_params)
-            except:
-                nn_hit_rate = np.nan
-                nn_miss_rate = np.nan
-
-            nn_hit_rate_dict[unit_id] = nn_hit_rate
-            nn_miss_rate_dict[unit_id] = nn_miss_rate
-    else:
-        if mp_context is not None and platform.system() == "Windows":
-            assert mp_context != "fork", "'fork' mp_context not supported on Windows!"
-        elif mp_context == "fork" and platform.system() == "Darwin":
-            warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS')
-
-        # Prepare arguments - only pass pickle-able data
-        args_list = []
-        for unit_id in unit_ids:
-            pcs_flat = pca_data_per_unit[unit_id]["pcs_flat"]
-            labels = pca_data_per_unit[unit_id]["labels"]
-            args_list.append((unit_id, pcs_flat, labels, metric_params))
+    for unit_id in unit_ids:
+        pcs_flat = pca_data_per_unit[unit_id]["pcs_flat"]
+        labels = pca_data_per_unit[unit_id]["labels"]

-        with ProcessPoolExecutor(
-            max_workers=n_jobs,
-            mp_context=mp.get_context(mp_context) if mp_context else None,
-        ) as executor:
-            results = executor.map(_nn_one_unit, args_list)
+        try:
+            nn_hit_rate, nn_miss_rate = nearest_neighbors_metrics(pcs_flat, labels, unit_id, **metric_params)
+        except:
+            nn_hit_rate = np.nan
+            nn_miss_rate = np.nan

-        for unit_id, nn_hit_rate, nn_miss_rate in results:
-            nn_hit_rate_dict[unit_id] = nn_hit_rate
-            nn_miss_rate_dict[unit_id] = nn_miss_rate
+        nn_hit_rate_dict[unit_id] = nn_hit_rate
+        nn_miss_rate_dict[unit_id] = nn_miss_rate

     return nn_result(nn_hit_rate=nn_hit_rate_dict, nn_miss_rate=nn_miss_rate_dict)
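Note: with job_kwargs gone (and the needs_job_kwargs flag dropped below), this function no longer manages its own process pool; only the sequential per-unit loop remains. If per-unit parallelism is still wanted, it can presumably live with the caller. A minimal sketch under that assumption — _one_unit and run_all_units are hypothetical names, not SpikeInterface API, and the import path is inferred from this diff's file header:

# Hedged sketch: driver-side parallelism over pickle-able per-unit data,
# mirroring the shapes visible in this diff.
from concurrent.futures import ProcessPoolExecutor

import numpy as np

from spikeinterface.metrics.quality.pca_metrics import nearest_neighbors_metrics


def _one_unit(args):
    # Unpack only pickle-able per-unit data, as _nn_one_unit does above.
    unit_id, pcs_flat, labels, metric_params = args
    try:
        hit, miss = nearest_neighbors_metrics(pcs_flat, labels, unit_id, **metric_params)
    except Exception:
        hit, miss = np.nan, np.nan
    return unit_id, hit, miss


def run_all_units(pca_data_per_unit, metric_params, n_jobs=4):
    # Fan per-unit work out across processes; results come back keyed by unit.
    args = [(u, d["pcs_flat"], d["labels"], metric_params) for u, d in pca_data_per_unit.items()]
    with ProcessPoolExecutor(max_workers=n_jobs) as pool:
        return {unit_id: (hit, miss) for unit_id, hit, miss in pool.map(_one_unit, args)}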

@@ -169,7 +139,6 @@ class NearestNeighbor(BaseMetric):
     }
     depend_on = ["principal_components"]
     needs_tmp_data = True
-    needs_job_kwargs = True


 def _nn_advanced_one_unit(args):
@@ -394,10 +363,10 @@ def mahalanobis_metrics(all_pcs, all_labels, this_unit_id):
     import scipy.stats
     import scipy.spatial.distance

-    pcs_for_this_unit = all_pcs[all_labels == this_unit_id, :]
-    pcs_for_other_units = all_pcs[all_labels != this_unit_id, :]
+    pcs_for_this_unit = all_pcs[all_labels == this_unit_id]
+    pcs_for_other_units = all_pcs[all_labels != this_unit_id]

-    mean_value = np.expand_dims(np.mean(pcs_for_this_unit, 0), 0)
+    mean_value = np.mean(pcs_for_this_unit, 0, keepdims=True)

     try:
         VI = np.linalg.inv(np.cov(pcs_for_this_unit.T))
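The expand_dims-to-keepdims changes in this file are purely cosmetic: np.mean(a, 0, keepdims=True) returns the same (1, n_features) row vector as the old construction. A quick check:

import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(50, 4))  # 50 spikes, 4 PC features

old = np.expand_dims(np.mean(a, 0), 0)  # shape (1, 4)
new = np.mean(a, 0, keepdims=True)      # shape (1, 4)
assert old.shape == new.shape == (1, 4)
assert np.array_equal(old, new)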
@@ -407,14 +376,14 @@ def mahalanobis_metrics(all_pcs, all_labels, this_unit_id):

     mahalanobis_other = np.sort(scipy.spatial.distance.cdist(mean_value, pcs_for_other_units, "mahalanobis", VI=VI)[0])

-    mahalanobis_self = np.sort(scipy.spatial.distance.cdist(mean_value, pcs_for_this_unit, "mahalanobis", VI=VI)[0])
-
     # number of spikes
-    n = np.min([pcs_for_this_unit.shape[0], pcs_for_other_units.shape[0]])
+    num_spikes_self = pcs_for_this_unit.shape[0]
+    num_spikes_other = pcs_for_other_units.shape[0]
+    n = min(num_spikes_self, num_spikes_other)

     if n >= 2:
         dof = pcs_for_this_unit.shape[1]  # number of features
-        l_ratio = np.sum(1 - scipy.stats.chi2.cdf(pow(mahalanobis_other, 2), dof)) / mahalanobis_self.shape[0]
+        l_ratio = np.sum(1 - scipy.stats.chi2.cdf(pow(mahalanobis_other, 2), dof)) / num_spikes_self
         isolation_distance = pow(mahalanobis_other[n - 1], 2)
         # if math.isnan(l_ratio):
         #     print("NaN detected", mahalanobis_other, VI)
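For orientation, the two quantities computed above, isolation distance and L-ratio, reduce to a few lines on synthetic data. A minimal sketch using the same formulas as this function (the toy cluster sizes and separation are arbitrary):

import numpy as np
import scipy.stats
import scipy.spatial.distance

rng = np.random.default_rng(42)
pcs_self = rng.normal(0.0, 1.0, size=(200, 3))   # spikes from the unit
pcs_other = rng.normal(3.0, 1.0, size=(300, 3))  # spikes from all other units

mean_value = np.mean(pcs_self, 0, keepdims=True)
VI = np.linalg.inv(np.cov(pcs_self.T))  # inverse covariance of the unit's own cloud

mahal_other = np.sort(
    scipy.spatial.distance.cdist(mean_value, pcs_other, "mahalanobis", VI=VI)[0]
)

n = min(pcs_self.shape[0], pcs_other.shape[0])
dof = pcs_self.shape[1]  # number of features

# Isolation distance: squared distance to the n-th closest "other" spike.
isolation_distance = mahal_other[n - 1] ** 2
# L-ratio: chi2 tail mass of the other spikes' squared distances, per own spike.
l_ratio = np.sum(1 - scipy.stats.chi2.cdf(mahal_other**2, dof)) / pcs_self.shape[0]
print(isolation_distance, l_ratio)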
@@ -451,18 +420,17 @@ def d_prime_metric(all_pcs, all_labels, this_unit_id) -> float:

     X = all_pcs

-    y = np.zeros((X.shape[0],), dtype="bool")
-    y[all_labels == this_unit_id] = True
+    y = all_labels == this_unit_id

     lda = LinearDiscriminantAnalysis(n_components=1)

     X_flda = lda.fit_transform(X, y)

-    flda_this_cluster = X_flda[np.where(y)[0]]
-    flda_other_cluster = X_flda[np.where(np.invert(y))[0]]
+    flda_this_cluster = X_flda[y]
+    flda_other_cluster = X_flda[~y]

     d_prime = (np.mean(flda_this_cluster) - np.mean(flda_other_cluster)) / np.sqrt(
-        0.5 * (np.std(flda_this_cluster) ** 2 + np.std(flda_other_cluster) ** 2)
+        (np.var(flda_this_cluster) + np.var(flda_other_cluster)) / 2
     )

     return d_prime
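The denominator rewrite is algebraic, not behavioral: 0.5 * (np.std(a) ** 2 + np.std(b) ** 2) equals (np.var(a) + np.var(b)) / 2 exactly, since np.var is the square of np.std and both default to the population form (ddof=0). A quick check:

import numpy as np

rng = np.random.default_rng(1)
a, b = rng.normal(size=100), rng.normal(1.5, 2.0, size=100)

old = 0.5 * (np.std(a) ** 2 + np.std(b) ** 2)
new = (np.var(a) + np.var(b)) / 2
assert np.isclose(old, new)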
@@ -518,22 +486,23 @@ def nearest_neighbors_metrics(all_pcs, all_labels, this_unit_id, max_spikes, n_neighbors):
         return 1.0, 0.0

     this_unit = all_labels == this_unit_id
-    this_unit_pcs = all_pcs[this_unit, :]
-    other_units_pcs = all_pcs[np.invert(this_unit), :]
+    this_unit_pcs = all_pcs[this_unit]
+    other_units_pcs = all_pcs[~this_unit]
     X = np.concatenate((this_unit_pcs, other_units_pcs), 0)

     num_obs_this_unit = np.sum(this_unit)

     if ratio < 1:
+        # Subsample spikes
         inds = np.arange(0, X.shape[0] - 1, 1 / ratio).astype("int")
         X = X[inds, :]
         num_obs_this_unit = int(num_obs_this_unit * ratio)

-    nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm="ball_tree").fit(X)
-    distances, indices = nbrs.kneighbors(X)
+    nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(X)
+    indices = nbrs.kneighbors(return_distance=False)  # don't feed X so it won't return itself as neighbor

-    this_cluster_nearest = indices[:num_obs_this_unit, 1:].flatten()
-    other_cluster_nearest = indices[num_obs_this_unit:, 1:].flatten()
+    this_cluster_nearest = indices[:num_obs_this_unit].flatten()
+    other_cluster_nearest = indices[num_obs_this_unit:].flatten()

     hit_rate = np.mean(this_cluster_nearest < num_obs_this_unit)
     miss_rate = np.mean(other_cluster_nearest < num_obs_this_unit)
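The kneighbors change leans on documented scikit-learn behavior: calling kneighbors() without a query array treats each fitted point as its own query and excludes it from its neighbor list, so the old [:, 1:] slice that dropped the self-match is no longer needed. A small demonstration:

import numpy as np
from sklearn.neighbors import NearestNeighbors

X = np.array([[0.0], [0.1], [5.0], [5.1]])
nbrs = NearestNeighbors(n_neighbors=1).fit(X)

with_self = nbrs.kneighbors(X, return_distance=False)   # each point finds itself first
without_self = nbrs.kneighbors(return_distance=False)   # self is excluded automatically

assert np.array_equal(with_self[:, 0], np.arange(len(X)))
assert not np.any(without_self[:, 0] == np.arange(len(X)))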
@@ -970,17 +939,17 @@ def simplified_silhouette_score(all_pcs, all_labels, this_unit_id):
     """
     import scipy.spatial.distance

-    pcs_for_this_unit = all_pcs[all_labels == this_unit_id, :]
-    centroid_for_this_unit = np.expand_dims(np.mean(pcs_for_this_unit, 0), 0)
+    pcs_for_this_unit = all_pcs[all_labels == this_unit_id]
+    centroid_for_this_unit = np.mean(pcs_for_this_unit, 0, keepdims=True)
     distances_for_this_unit = scipy.spatial.distance.cdist(centroid_for_this_unit, pcs_for_this_unit)
     distance = np.inf

     # find centroid of other cluster and measure distances to that rather than pairwise
     # if less than current minimum distance update
     for label in np.unique(all_labels):
         if label != this_unit_id:
-            pcs_for_other_cluster = all_pcs[all_labels == label, :]
-            centroid_for_other_cluster = np.expand_dims(np.mean(pcs_for_other_cluster, 0), 0)
+            pcs_for_other_cluster = all_pcs[all_labels == label]
+            centroid_for_other_cluster = np.mean(pcs_for_other_cluster, 0, keepdims=True)
             distances_for_other_cluster = scipy.spatial.distance.cdist(centroid_for_other_cluster, pcs_for_this_unit)
             mean_distance_for_other_cluster = np.mean(distances_for_other_cluster)
             if mean_distance_for_other_cluster < distance:
@@ -54,7 +54,7 @@ def test_compute_pc_metrics_multi_processing(small_sorting_analyzer, tmp_path):


 if __name__ == "__main__":
-    from spikeinterface.metrics.tests.conftest import make_small_analyzer
+    from spikeinterface.metrics.conftest import make_small_analyzer

     small_sorting_analyzer = make_small_analyzer()
     test_compute_pc_metrics_multi_processing(small_sorting_analyzer)