2020import matplotlib .pyplot as plt
2121from matplotlib .backends .backend_pdf import PdfPages
2222
23+ from rich import print
24+
2325class FileMLInterface (ABC ):
2426 white_viridis = LinearSegmentedColormap .from_list ('white_viridis' , [
2527 (0 , '#ffffff' ),
@@ -59,7 +61,7 @@ def __init__(self, chain: ChainHandler, prediction_variable: str, fit_name: str)
5961 self ._scaler = StandardScaler ()
6062 # self._pca_matrix = PCA(n_components=0.95)
6163
62- self ._label_scaler = MinMaxScaler ( feature_range = ( 0 , 1 ) )
64+ self ._label_scaler = StandardScaler ( )
6365
6466
6567
@@ -86,19 +88,19 @@ def set_training_test_set(self, test_size: float):
8688 self ._training_data , self ._test_data , self ._training_labels , self ._test_labels = train_test_split (features , labels , test_size = test_size )
8789
8890 # Fit scaling pre-processors. These get applied properly when scale_data is called
89- _ = self ._scaler .fit_transform (self ._training_data )
90- self ._label_scaler .fit_transform (self ._training_labels )
91+ self ._scaler .fit (self ._training_data )
92+ self ._label_scaler .fit (self ._training_labels )
9193
9294 # self._pca_matrix.fit(scaled_training)
9395
9496 def scale_data (self , input_data ):
9597 # Applies transformations to data set
9698 scale_data = self ._scaler .transform (input_data )
97- # scale_data = self._pca_matrix.transform(scale_data)
9899 return scale_data
99100
100101 def scale_labels (self , labels ):
101102 return self ._label_scaler .transform (labels )
103+ # return labels.values.reshape(-1, 1)
102104
103105 def invert_scaling (self , input_data ):
104106 # Inverts transform
@@ -193,7 +195,7 @@ def load_model(self, input_model: str):
193195 :param input_file: Pickled Model
194196 :type input_file: str
195197 """
196- print (f"Attempting to load file from { input_file } " )
198+ print (f"[spring_green1] Attempting to load file from[/spring_green1][bold red3] { input_file } " )
197199 with open (input_model , 'r' ) as f :
198200 self ._model = pickle .load (f )
199201
@@ -216,14 +218,18 @@ def test_model(self):
216218 train_as_numpy = self .scale_labels (self ._training_labels ).T [0 ]
217219 self .evaluate_model (train_prediction , train_as_numpy , "train_qq_plot.pdf" )
218220
219- print ("=====\n \n " )
221+ print ("=====" )
220222 print ("Testing Results!" )
221223
222224 test_prediction = self .model_predict (self ._test_data )
223225 test_as_numpy = self .scale_labels (self ._test_labels ).T [0 ]
224226
225227 self .evaluate_model (test_prediction , test_as_numpy , outfile = f"{ self ._fit_name } " )
226- print ("=====\n \n " )
228+ print ("=====" )
229+
230+
231+ def print_model_summary (self ):
232+ print ("Model Summary" )
227233
228234 def model_predict_single_sample (self , sample ):
229235 sample_shaped = sample .reshape (1 ,- 1 )
@@ -232,7 +238,7 @@ def model_predict_single_sample(self, sample):
232238 def get_maxlikelihood (self )-> OptimizeResult :
233239 init_vals = self .training_data .iloc [[1 ]].to_numpy ()[0 ]
234240
235- print ("Calculating max LLH" )
241+ print ("[bold purple] Calculating max LLH" )
236242 maximal_likelihood = minimize (self .model_predict_single_sample , init_vals , bounds = zip (self ._chain .lower_bounds [:- 1 ], self ._chain .upper_bounds [:- 1 ]), method = "L-BFGS-B" , options = {"disp" : True })
237243 return maximal_likelihood
238244
@@ -245,9 +251,9 @@ def run_likelihood_scan(self, n_divisions: int = 500):
245251
246252 errors = np .sqrt (np .diag (maximal_likelihood .hess_inv (np .identity (self .chain .ndim - 1 ))))
247253
248- print ("Maximal Pars :" )
254+ print ("[bold red3] Maximal Pars :" )
249255 for i in range (self .chain .ndim - 1 ):
250- print (f"Param : { self .chain .plot_branches [i ]} : { maximal_likelihood .x [i ]} ±{ errors [i ]} " )
256+ print (f"[bold red3] Param :[/bold red3] [yellow3] { self .chain .plot_branches [i ]} : { maximal_likelihood .x [i ]} ±{ errors [i ]} " )
251257
252258
253259 with PdfPages ("llh_scan.pdf" ) as pdf :
@@ -285,13 +291,14 @@ def evaluate_model(self, predicted_values: Iterable, true_values: Iterable, outf
285291 :type outfile: str, optional
286292 """
287293
288- print (predicted_values )
289- print (f"Mean Absolute Error : { metrics .mean_absolute_error (predicted_values ,true_values )} " )
290-
294+ print (f"[bold red3]Mean Absolute Error :[/bold red3] [yellow3]{ metrics .mean_absolute_error (predicted_values ,true_values )} " )
291295
296+ outfile_name = outfile .split ("." )[0 ]
297+ outfile = f"{ outfile_name } .pdf"
298+ warnings .filterwarnings ("ignore" , message = "Polyfit may be poorly conditioned" )
292299 lobf = np .poly1d (np .polyfit (predicted_values , true_values , 1 ))
293300
294- print (f"Line of best fit : y={ lobf .c [0 ]} x + { lobf .c [1 ]} " )
301+ print (f"[bold purple] Line of best fit :[/bold purple] [dodger_blue1] y={ lobf .c [0 ]} x + { lobf .c [1 ]} " )
295302
296303 fig = plt .figure ()
297304
@@ -322,11 +329,21 @@ def evaluate_model(self, predicted_values: Iterable, true_values: Iterable, outf
322329 ax .set_ylabel ("True Log Likelihood" )
323330
324331 fig .legend ()
332+
325333 if outfile == "" : outfile = f"evaluated_model_qq_tf.pdf"
326334
327- print (f"Saving QQ to { outfile } " )
335+ print (f"[bold spring_green1] Saving QQ to[/bold spring_green1][dodger_blue1] { outfile } " )
328336
329337 fig .savefig (outfile )
338+
339+ try :
340+ is_notebook = self .is_notebook ()
341+ if is_notebook :
342+ plt .show ()
343+ except Exception :
344+ ...
345+
346+
330347 plt .close ()
331348
332349 # Gonna draw a hist
@@ -335,4 +352,18 @@ def evaluate_model(self, predicted_values: Iterable, true_values: Iterable, outf
335352 plt .hist (difs , bins = 100 , density = True , range = (np .std (difs )* - 5 , np .std (difs )* 5 ))
336353 plt .xlabel ("True - Pred" )
337354 plt .savefig (f"diffs_5sigma_range_{ outfile } " )
338- plt .close ()
355+
356+ plt .close ()
357+
358+ @classmethod
359+ def is_notebook (cls ) -> bool :
360+ try :
361+ shell = get_ipython ().__class__ .__name__
362+ if shell == 'ZMQInteractiveShell' :
363+ return True # Jupyter notebook or qtconsole
364+ elif shell == 'TerminalInteractiveShell' :
365+ return False # Terminal running IPython
366+ else :
367+ return False # Other type (?)
368+ except NameError :
369+ return False # Probably standard Python interpreter
0 commit comments