1- import logging
2- import pathlib
3- import re
4- from datetime import datetime
51from os import path
6-
7- import numpy as np
2+ from datetime import datetime
3+ import pathlib
84import pandas as pd
9-
5+ import numpy as np
6+ import re
7+ import logging
108from .utils import convert_to_number
119
1210log = logging .getLogger (__name__ )
@@ -117,7 +115,8 @@ def _load(self):
117115
118116 # Read the Cluster Groups
119117 for cluster_pattern , cluster_col_name in zip (
120- ["cluster_group.*" , "cluster_KSLabel.*" ], ["group" , "KSLabel" ]
118+ ["cluster_group.*" , "cluster_KSLabel.*" , "cluster_group.*" ],
119+ ["group" , "KSLabel" , "KSLabel" ],
121120 ):
122121 try :
123122 cluster_file = next (self ._kilosort_dir .glob (cluster_pattern ))
@@ -126,22 +125,26 @@ def _load(self):
126125 else :
127126 cluster_file_suffix = cluster_file .suffix
128127 assert cluster_file_suffix in (".tsv" , ".xlsx" )
129- break
128+
129+ if cluster_file_suffix == ".tsv" :
130+ df = pd .read_csv (cluster_file , sep = "\t " , header = 0 )
131+ elif cluster_file_suffix == ".xlsx" :
132+ df = pd .read_excel (cluster_file , engine = "openpyxl" )
133+ else :
134+ df = pd .read_csv (cluster_file , delimiter = "\t " )
135+
136+ try :
137+ self ._data ["cluster_groups" ] = np .array (df [cluster_col_name ].values )
138+ self ._data ["cluster_ids" ] = np .array (df ["cluster_id" ].values )
139+ except KeyError :
140+ continue
141+ else :
142+ break
130143 else :
131144 raise FileNotFoundError (
132145 'Neither "cluster_groups" nor "cluster_KSLabel" file found!'
133146 )
134147
135- if cluster_file_suffix == ".tsv" :
136- df = pd .read_csv (cluster_file , sep = "\t " , header = 0 )
137- elif cluster_file_suffix == ".xlsx" :
138- df = pd .read_excel (cluster_file , engine = "openpyxl" )
139- else :
140- df = pd .read_csv (cluster_file , delimiter = "\t " )
141-
142- self ._data ["cluster_groups" ] = np .array (df [cluster_col_name ].values )
143- self ._data ["cluster_ids" ] = np .array (df ["cluster_id" ].values )
144-
145148 def get_best_channel (self , unit ):
146149 template_idx = self .data ["spike_templates" ][
147150 np .where (self .data ["spike_clusters" ] == unit )[0 ][0 ]
0 commit comments