Skip to content

Commit 4067c9a

Browse files
committed
git checkout from upstream: Fix data loading bug related to cluster_groups and KSLabel df key
1 parent 2ea7e5f commit 4067c9a

1 file changed

Lines changed: 22 additions & 19 deletions

File tree

element_array_ephys/readers/kilosort.py

Lines changed: 22 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,10 @@
1-
import logging
2-
import pathlib
3-
import re
4-
from datetime import datetime
5 1
from os import path
6-
7-
import numpy as np
2+
from datetime import datetime
3+
import pathlib
8 4
import pandas as pd
9-
5+
import numpy as np
6+
import re
7+
import logging
10 8
from .utils import convert_to_number
11 9

12 10
log = logging.getLogger(__name__)
@@ -117,7 +115,8 @@ def _load(self):
117 115

118 116
# Read the Cluster Groups
119 117
for cluster_pattern, cluster_col_name in zip(
120-
["cluster_group.*", "cluster_KSLabel.*"], ["group", "KSLabel"]
118+
["cluster_group.*", "cluster_KSLabel.*", "cluster_group.*"],
119+
["group", "KSLabel", "KSLabel"],
121 120
):
122 121
try:
123 122
cluster_file = next(self._kilosort_dir.glob(cluster_pattern))
@@ -126,22 +125,26 @@ def _load(self):
126 125
else:
127 126
cluster_file_suffix = cluster_file.suffix
128 127
assert cluster_file_suffix in (".tsv", ".xlsx")
129-
break
128+
129+
if cluster_file_suffix == ".tsv":
130+
df = pd.read_csv(cluster_file, sep="\t", header=0)
131+
elif cluster_file_suffix == ".xlsx":
132+
df = pd.read_excel(cluster_file, engine="openpyxl")
133+
else:
134+
df = pd.read_csv(cluster_file, delimiter="\t")
135+
136+
try:
137+
self._data["cluster_groups"] = np.array(df[cluster_col_name].values)
138+
self._data["cluster_ids"] = np.array(df["cluster_id"].values)
139+
except KeyError:
140+
continue
141+
else:
142+
break
130 143
else:
131 144
raise FileNotFoundError(
132 145
'Neither "cluster_groups" nor "cluster_KSLabel" file found!'
133 146
)
134 147

135-
if cluster_file_suffix == ".tsv":
136-
df = pd.read_csv(cluster_file, sep="\t", header=0)
137-
elif cluster_file_suffix == ".xlsx":
138-
df = pd.read_excel(cluster_file, engine="openpyxl")
139-
else:
140-
df = pd.read_csv(cluster_file, delimiter="\t")
141-
142-
self._data["cluster_groups"] = np.array(df[cluster_col_name].values)
143-
self._data["cluster_ids"] = np.array(df["cluster_id"].values)
144-
145 148
def get_best_channel(self, unit):
146 149
template_idx = self.data["spike_templates"][
147 150
np.where(self.data["spike_clusters"] == unit)[0][0]

0 commit comments

Comments
 (0)