#!/bin/bash
#
# Run this from a directory that contains
#   * a directory called "original_data" with one sub-directory per experiment
#   * a template configuration file called "export_config.py"
#
# The template is copied into every sub-directory of original_data and the
# first EXPERIMENT_NAME placeholder on each line of the copy is replaced by
# the name of that sub-directory.

exportfile="export_config.py"

for dir in original_data/*/
do
    # When nothing matches, the glob is left unexpanded; skip it
    [ -d "$dir" ] || continue

    cp "$exportfile" "$dir"

    # original_data/<name>/  ->  <name>
    expname="${dir%/}"
    expname="${expname##*/}"

    # Escape the characters that are special in a sed replacement string
    escaped=$(printf '%s' "$expname" | sed 's/[&/\]/\\&/g')

    # Edit the copied file in place (the template itself is left untouched)
    sed -i "s/EXPERIMENT_NAME/${escaped}/" "${dir}${exportfile}"
done
fit_linear_leak(before_current[sweep, :], - voltages, times, - *ramp_bounds, - output_dir=out_dir, - save_fname=f"{well}_sweep{sweep}.png" - ) + before_params, before_leak = fit_linear_leak( + before_current[sweep, :], voltages, times, *ramp_bounds, + output_dir=out_dir, save_fname=f"{well}_sweep{sweep}.png") before_leak_currents.append(before_leak) - out_dir = os.path.join(savedir, - f"{saveID}-{savename}-leak_fit-after") + out_dir = os.path.join(savedir, f"{saveID}-{savename}-leak_fit-after") # Convert linear regression parameters into conductance and reversal row_dict['gleak_before'] = before_params[1] row_dict['E_leak_before'] = -before_params[0] / before_params[1] - after_params, after_leak = fit_linear_leak(after_current[sweep, :], - voltages, times, - *ramp_bounds, - save_fname=f"{well}_sweep{sweep}.png", - output_dir=out_dir) + after_params, after_leak = fit_linear_leak( + after_current[sweep, :], voltages, times, *ramp_bounds, + save_fname=f"{well}_sweep{sweep}.png", output_dir=out_dir) after_leak_currents.append(after_leak) @@ -660,24 +658,20 @@ def extract_protocol(readname, savename, time_strs, selected_wells, args): after_corrected = after_current[sweep, :] - after_leak before_corrected = before_current[sweep, :] - before_leak - E_rev_before = infer_reversal_potential(before_corrected, times, - desc, voltages, plot=True, - output_path=os.path.join(reversal_plot_dir, - f"{well}_{savename}_sweep{sweep}_before"), - known_Erev=args.Erev) - - E_rev_after = infer_reversal_potential(after_corrected, times, - desc, voltages, - plot=True, - output_path=os.path.join(reversal_plot_dir, - f"{well}_{savename}_sweep{sweep}_after"), - known_Erev=args.Erev) - - E_rev = infer_reversal_potential(subtracted_trace, times, desc, - voltages, plot=True, - output_path=os.path.join(reversal_plot_dir, - f"{well}_{savename}_sweep{sweep}_subtracted"), - known_Erev=args.Erev) + E_rev_before = infer_reversal_potential( + before_corrected, times, desc, voltages, plot=True, + 
output_path=os.path.join(reversal_plot_dir, f"{well}_{savename}_sweep{sweep}_before.png"), + known_Erev=args.Erev) + + E_rev_after = infer_reversal_potential( + after_corrected, times, desc, voltages, plot=True, + output_path=os.path.join(reversal_plot_dir, f"{well}_{savename}_sweep{sweep}_after.png"), + known_Erev=args.Erev) + + E_rev = infer_reversal_potential( + subtracted_trace, times, desc, voltages, plot=True, + output_path=os.path.join(reversal_plot_dir, f"{well}_{savename}_sweep{sweep}_subtracted.png"), + known_Erev=args.Erev) row_dict['R_leftover'] =\ np.sqrt(np.sum((after_corrected)**2)/(np.sum(before_corrected**2))) @@ -732,8 +726,8 @@ def extract_protocol(readname, savename, time_strs, selected_wells, args): row_dict['QC4'] = all([x for x, _ in qc4]) if args.output_traces: - out_fname = os.path.join(traces_dir, - f"{saveID}-{savename}-{well}-sweep{sweep}-subtracted.csv") + out_fname = os.path.join( + traces_dir, f"{saveID}-{savename}-{well}-sweep{sweep}-subtracted.csv") np.savetxt(out_fname, subtracted_trace.flatten()) rows.append(row_dict) @@ -746,13 +740,11 @@ def extract_protocol(readname, savename, time_strs, selected_wells, args): t_step = times[1] - times[0] row_dict['total before-drug flux'] = np.sum(current) * (1.0 / t_step) res = \ - get_time_constant_of_first_decay(subtracted_trace, times, desc, - args=args, - output_path=os.path.join(args.output_dir, - 'debug', - '-120mV time constant', - f"{savename}-{well}-sweep" - "{sweep}-time-constant-fit.pdf")) + get_time_constant_of_first_decay( + subtracted_trace, times, desc, args=args, + output_path=os.path.join( + args.output_dir, 'debug', '-120mV time constant', + f"{savename}-{well}-sweep{sweep}-time-constant-fit.pdf")) row_dict['-120mV decay time constant 1'] = res[0][0] row_dict['-120mV decay time constant 2'] = res[0][1] @@ -789,8 +781,8 @@ def extract_protocol(readname, savename, time_strs, selected_wells, args): voltages, ramp_bounds, well=well, protocol=savename) - 
fig.savefig(os.path.join(subtraction_plots_dir, - f"{saveID}-{savename}-{well}-sweep{sweep}-subtraction")) + fig.savefig(os.path.join( + subtraction_plots_dir, f"{saveID}-{savename}-{well}-sweep{sweep}-subtraction")) fig.clf() plt.close(fig) @@ -910,19 +902,15 @@ def run_qc_for_protocol(readname, savename, time_strs, args): before_raw = np.array(raw_before_all[well])[sweep, :] after_raw = np.array(raw_after_all[well])[sweep, :] - before_params1, before_leak = fit_linear_leak(before_raw, - voltage, - times, - *ramp_bounds, - save_fname=f"{well}-sweep{sweep}-before.png", - output_dir=savedir) + before_params1, before_leak = fit_linear_leak( + before_raw, voltage, times, *ramp_bounds, + save_fname=f"{well}-sweep{sweep}-before.png", + output_dir=savedir) - after_params1, after_leak = fit_linear_leak(after_raw, - voltage, - times, - *ramp_bounds, - save_fname=f"{well}-sweep{sweep}-after.png", - output_dir=savedir) + after_params1, after_leak = fit_linear_leak( + after_raw, voltage, times, *ramp_bounds, + save_fname=f"{well}-sweep{sweep}-after.png", + output_dir=savedir) before_currents_corrected[sweep, :] = before_raw - before_leak after_currents_corrected[sweep, :] = after_raw - after_leak @@ -1062,26 +1050,15 @@ def qc3_bookend(readname, savename, time_strs, args): save_fname = f"{well}_{savename}_before0.pdf" #  Plot subtraction - get_leak_corrected(first_before_current, - voltage, times, - *ramp_bounds, - save_fname=save_fname, - output_dir=output_directory) - - before_traces_first[well] = get_leak_corrected(first_before_current, - voltage, times, - *ramp_bounds) - - before_traces_last[well] = get_leak_corrected(last_before_current, - voltage, times, - *ramp_bounds) - - after_traces_first[well] = get_leak_corrected(first_after_current, - voltage, times, - *ramp_bounds) - after_traces_last[well] = get_leak_corrected(last_after_current, - voltage, times, - *ramp_bounds) + before_traces_first[well] = get_leak_corrected( + first_before_current, voltage, times, 
*ramp_bounds, + save_fname=save_fname, output_dir=output_directory) + before_traces_last[well] = get_leak_corrected( + last_before_current, voltage, times, *ramp_bounds) + after_traces_first[well] = get_leak_corrected( + first_after_current, voltage, times, *ramp_bounds) + after_traces_last[well] = get_leak_corrected( + last_after_current, voltage, times, *ramp_bounds) # Store subtracted traces first_processed[well] = before_traces_first[well] - after_traces_first[well] @@ -1104,21 +1081,13 @@ def qc3_bookend(readname, savename, time_strs, args): ax = fig.subplots() for well in args.wells: trace1 = hergqc.filter_capacitive_spikes( - first_processed[well], times, voltage_steps - ).flatten() - + first_processed[well], times, voltage_steps).flatten() trace2 = hergqc.filter_capacitive_spikes( - last_processed[well], times, voltage_steps - ).flatten() - + last_processed[well], times, voltage_steps).flatten() passed = hergqc.qc3(trace1, trace2)[0] - res_dict[well] = passed - - save_fname = os.path.join(args.output_dir, - 'debug', - f"debug_{well}_{savename}", - 'qc3_bookend') + save_fname = os.path.join( + args.output_dir, 'debug', f"debug_{well}_{savename}", 'qc3_bookend') ax.plot(times, trace1) ax.plot(times, trace2) diff --git a/pcpostprocess/subtraction_plots.py b/pcpostprocess/subtraction_plots.py index 23410760..d97704cb 100644 --- a/pcpostprocess/subtraction_plots.py +++ b/pcpostprocess/subtraction_plots.py @@ -1,5 +1,12 @@ +import os +import string + +import matplotlib.pyplot as plt import numpy as np +import pandas as pd from matplotlib.gridspec import GridSpec +from scipy.stats import pearsonr +from syncropatch_export.trace import Trace from .leak_correct import fit_linear_leak @@ -45,20 +52,23 @@ def do_subtraction_plot(fig, times, sweeps, before_currents, after_currents, axs = setup_subtraction_grid(fig, nsweeps) protocol_axs, before_axs, after_axs, corrected_axs, \ subtracted_ax, long_protocol_ax = axs - + first = True for ax in protocol_axs: 
ax.plot(times*1e-3, voltages, color='black') - ax.set_xlabel('time (s)') - ax.set_ylabel(r'$V_\mathrm{cmd}$ (mV)') + ax.set_xticklabels([]) + + if first: + ax.set_ylabel(r'$V_\mathrm{cmd}$ (mV)') + first = False all_leak_params_before = [] all_leak_params_after = [] for i in range(len(sweeps)): - before_params, _ = fit_linear_leak(before_currents, voltages, times, + before_params, _ = fit_linear_leak(before_currents[i, :], voltages, times, *ramp_bounds) all_leak_params_before.append(before_params) - after_params, _ = fit_linear_leak(after_currents, voltages, times, + after_params, _ = fit_linear_leak(after_currents[i, :], voltages, times, *ramp_bounds) all_leak_params_after.append(after_params) @@ -71,55 +81,71 @@ def do_subtraction_plot(fig, times, sweeps, before_currents, after_currents, b0, b1 = all_leak_params_before[i] gleak = b1 - Eleak = -b1/b0 + Eleak = -b0/b1 before_leak_currents[i, :] = gleak * (voltages - Eleak) b0, b1 = all_leak_params_after[i] gleak = b1 - Eleak = -b1/b0 + Eleak = -b0/b1 after_leak_currents[i, :] = gleak * (voltages - Eleak) + first = True for i, (sweep, ax) in enumerate(zip(sweeps, before_axs)): - gleak, Eleak = all_leak_params_before[i] + b0, b1 = all_leak_params_before[i] ax.plot(times*1e-3, before_currents[i, :], label=f"pre-drug raw, sweep {sweep}") ax.plot(times*1e-3, before_leak_currents[i, :], - label=r'$I_\mathrm{L}$.' f"g={gleak:1E}, E={Eleak:.1e}") - # ax.legend() - - if ax.get_legend(): - ax.get_legend().remove() - ax.set_xlabel('time (s)') - ax.set_ylabel(r'pre-drug trace') + label=r'$I_\mathrm{L}$.' 
f"g={b1:1E}, E={-b0/b1:.1e}") + ax.set_xticklabels([]) + + if first: + ax.set_ylabel(r'pre-drug trace') + first = False + else: + ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left') # ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e')) # ax.tick_params(axis='y', rotation=90) + first = True for i, (sweep, ax) in enumerate(zip(sweeps, after_axs)): - gleak, Eleak = all_leak_params_before[i] + b0, b1 = all_leak_params_after[i] ax.plot(times*1e-3, after_currents[i, :], label=f"post-drug raw, sweep {sweep}") ax.plot(times*1e-3, after_leak_currents[i, :], - label=r"$I_\mathrm{L}$." f"g={gleak:1E}, E={Eleak:.1e}") - # ax.legend() - if ax.get_legend(): - ax.get_legend().remove() - ax.set_xlabel('$t$ (s)') - ax.set_ylabel(r'post-drug trace') + label=r"$I_\mathrm{L}$." f"g={b1:1E}, E={-b0/b1:.1e}") + ax.set_xticklabels([]) + if first: + ax.set_ylabel(r'post-drug trace') + first = False + else: + ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left') # ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e')) # ax.tick_params(axis='y', rotation=90) + first = True for i, (sweep, ax) in enumerate(zip(sweeps, corrected_axs)): corrected_before_currents = before_currents[i, :] - before_leak_currents[i, :] corrected_after_currents = after_currents[i, :] - after_leak_currents[i, :] + corrb, _ = pearsonr(corrected_before_currents, voltages) ax.plot(times*1e-3, corrected_before_currents, - label=f"leak-corrected pre-drug trace, sweep {sweep}") + label=f"leak-corrected pre-drug trace, sweep {sweep}, PC={corrb:.2f}") + corra, _ = pearsonr(corrected_after_currents, voltages) ax.plot(times*1e-3, corrected_after_currents, - label=f"leak-corrected post-drug trace, sweep {sweep}") - ax.set_xlabel(r'$t$ (s)') - ax.set_ylabel(r'leak-corrected traces') + label=f"leak-corrected post-drug trace, sweep {sweep}, PC={corra:.2f}") + ax.set_xlabel('time (s)') + if first: + ax.set_ylabel(r'leak-corrected traces') + first = False + + # sortedy = 
def linear_reg(V, I_obs):
    """
    Fit the straight line I_obs = b_0 + b_1 * V by ordinary least squares.

    @param V: array-like holding the independent variable (one value per point)
    @param I_obs: array-like holding the observations, same length as V

    @returns b_0, b_1: the intercept and the gradient of the fitted line.
        When every entry of V is identical the gradient is undefined and the
        numpy division yields nan or inf rather than raising.
    """
    # Accept lists/tuples as well as arrays, and do the arithmetic in floating
    # point so integer inputs cannot overflow
    V = np.asarray(V, dtype=float)
    I_obs = np.asarray(I_obs, dtype=float)

    # mean of V and I vector
    m_V = np.mean(V)
    m_I = np.mean(I_obs)

    # cross-deviation and deviation about V, computed on mean-centred data to
    # avoid cancellation between two large sums
    dV = V - m_V
    SS_VI = np.sum(dV * (I_obs - m_I))
    SS_VV = np.sum(dV * dV)

    # calculating regression coefficients
    b_1 = SS_VI / SS_VV
    b_0 = m_I - b_1 * m_V

    # return intercept, gradient
    return b_0, b_1
def _pair_recordings(time_strs):
    # Group the time stamps of one protocol into (before, after) pairs;
    # anything other than 2 or 4 recordings gives no pairs
    ordered = sorted(time_strs)
    if len(ordered) == 2:
        return [(ordered[0], ordered[1])]
    if len(ordered) == 4:
        return [(ordered[0], ordered[2]), (ordered[1], ordered[3])]
    return []


def _read_passed_wells(processed_path, exp):
    # Return the non-empty lines of <processed_path>/<exp>/passed_wells.txt
    with open(os.path.join(processed_path, exp, 'passed_wells.txt'), 'r') as file:
        return [x for x in file.read().split('\n') if x]


def regenerate_subtraction_plots(data_path='.', save_dir='.', processed_path=None,
                                 protocols_in=None, passed_only=False):
    """
    Redo the subtraction plot for all sweeps of all experiments in a directory
    and write the correlation value of every sweep to a CSV file.

    Every sub-directory of data_path is treated as one experiment. Entries of
    an experiment whose name contains one of protocols_in are grouped by the
    part of the name before the last underscore; the part after it is used as
    a time stamp. Groups with 2 time stamps give one (before, after) pair,
    groups with 4 give two pairs (1st with 3rd, 2nd with 4th), and any other
    count is skipped.

    The figures themselves are not written to disk; only
    <save_dir>/subtraction_results.csv is produced.

    @param data_path: directory containing one sub-directory per experiment.
        Nothing is done if this directory itself contains passed_wells.txt.
    @param save_dir: directory the CSV file is written to
    @param processed_path: optional directory holding
        <experiment>/passed_wells.txt (one well name per line). When given, a
        'passed' column with the values 'passed'/'failed' is added to the CSV.
    @param protocols_in: substrings selecting the recordings to process.
        Defaults to a built-in list of staircase protocol names.
    @param passed_only: with processed_path, restrict the wells to those
        listed in passed_wells.txt; otherwise the 384 wells A01..P24 are used

    @returns None
    """
    entries = os.listdir(data_path)
    if 'passed_wells.txt' in entries:
        return None
    experiments = [x for x in entries if os.path.isdir(os.path.join(data_path, x))]

    if protocols_in is None:
        protocols_in = ['staircaseramp', 'staircaseramp (2)', 'ProtocolChonStaircaseRamp',
                        'staircaseramp_2kHz_fixed_ramp', 'staircaseramp (2)_2kHz',
                        'staircase-ramp', 'Staircase_hERG']

    # Rows A-P, columns 01-24
    all_wells = [row + str(i).zfill(2)
                 for row in string.ascii_uppercase[:16] for i in range(1, 25)]

    results = {'exp': [], 'protocol': [], 'well': [], 'sweep': [], 'pc': []}
    if processed_path:
        # Decided once from the argument so that every column always has the
        # same number of rows, even if an experiment has no passed wells
        results['passed'] = []

    fig = plt.figure(figsize=[15, 24], layout='constrained')
    try:
        for exp in experiments:
            exp_files = [x for x in os.listdir(os.path.join(data_path, exp))
                         if any(y in x for y in protocols_in)]
            if not exp_files:
                continue
            protocols = {'_'.join(x.split('_')[:-1]) for x in exp_files}

            passed_wells = None
            wells = all_wells
            if processed_path:
                passed_wells = _read_passed_wells(processed_path, exp)
                if passed_only:
                    wells = passed_wells

            for prot in protocols:
                time_strs = [x.split('_')[-1] for x in exp_files
                             if prot + '_' + x.split('_')[-1] == x]
                for before_str, after_str in _pair_recordings(time_strs):
                    before_name = f"{prot}_{before_str}"
                    after_name = f"{prot}_{after_str}"
                    before_trace = Trace(os.path.join(data_path, exp, before_name),
                                         before_name)
                    after_trace = Trace(os.path.join(data_path, exp, after_name),
                                        after_name)

                    times = before_trace.get_times()
                    voltages = before_trace.get_voltage()
                    protocol_desc = before_trace.get_voltage_protocol().get_all_sections()
                    ramp_bounds = detect_ramp_bounds(times, protocol_desc)

                    # Scale by 1e-3 (original note: convert everything to nA)
                    before_current_all = {key: value * 1e-3 for key, value
                                          in before_trace.get_trace_sweeps().items()}
                    after_current_all = {key: value * 1e-3 for key, value
                                         in after_trace.get_trace_sweeps().items()}

                    for well in wells:
                        before_current = before_current_all[well]
                        after_current = after_current_all[well]
                        # do_subtraction_plot takes the len() of, and iterates
                        # over, its sweeps argument, so pass the sweep indices
                        # rather than their count
                        sweeps = list(range(before_current.shape[0]))
                        sweep_dict = do_subtraction_plot(
                            fig, times, sweeps, before_current, after_current,
                            voltages, ramp_bounds, well=None, protocol=None)

                        n_rows = len(sweep_dict['sweeps'])
                        results['exp'] += [exp] * n_rows
                        results['protocol'] += [prot] * n_rows
                        results['well'] += [well] * n_rows
                        results['sweep'] += list(sweep_dict['sweeps'])
                        results['pc'] += list(sweep_dict['pcs'])
                        if processed_path:
                            passed = 'passed' if well in passed_wells else 'failed'
                            results['passed'] += [passed] * n_rows

                        # Reuse the same figure for the next well
                        fig.clf()
    finally:
        # Release the figure even if a recording could not be processed
        plt.close(fig)

    outdf = pd.DataFrame.from_dict(results)
    outdf.to_csv(os.path.join(save_dir, 'subtraction_results.csv'))
def detect_ramp_bounds(times, voltage_sections, ramp_no=0):
    """
    Find the sample indices that bound the nth ramp in the protocol.

    A ramp is any section whose start and end voltages differ.

    @param times: np.array containing the time at which each sample was taken
    @param voltage_sections: 2d np.array (or sequence of 4-tuples) where each
        row describes a segment of the protocol: (tstart, tend, vstart, vend)
    @param ramp_no: the index of the ramp to select. Defaults to 0 - the first ramp

    @returns [istart, iend]: indices into times of the first sample taken
        strictly after the start, and strictly after the end, of the selected
        ramp. An index is 0 when no sample lies after the corresponding time.

    @raises IndexError: if the protocol has no ramp with index ramp_no
    """

    ramps = [(tstart, tend, vstart, vend) for tstart, tend, vstart, vend
             in voltage_sections if vstart != vend]
    try:
        ramp = ramps[ramp_no]
    except IndexError:
        # Fail here with a clear message instead of carrying on without a ramp
        raise IndexError(f"Requested ramp {ramp_no+1} (ramp_no={ramp_no}),"
                         f" but there are only {len(ramps)} ramps") from None

    tstart, tend = ramp[:2]

    # argmax of a boolean array is the position of its first True entry
    ramp_bounds = [np.argmax(times > tstart), np.argmax(times > tend)]
    return ramp_bounds