1- import emcee
1+ import numpy as np
2+ import tensorflow as tf
23import numpy as np
34import tensorflow as tf
45from tensorflow import linalg as tfla
@@ -21,6 +22,17 @@ def __init__(self, interface: TfInterface, n_chains: int = 1024, circular_params
2122 self ._n_chains = n_chains
2223
2324 # Initial states for all chains
25+ initial_state = tf .convert_to_tensor (np .zeros (self ._n_dim ), dtype = tf .float32 )
26+ self ._chain_states = tf .Variable (tf .tile (tf .expand_dims (initial_state , axis = 0 ), [n_chains , 1 ]), dtype = tf .float32 )
27+
28+ # boundaries
29+ self ._upper_bounds = tf .convert_to_tensor (self ._interface .scale_data (self ._interface .chain .upper_bounds [:- 1 ].reshape (1 ,- 1 )), dtype = tf .float32 )
30+ self ._lower_bounds = tf .convert_to_tensor (self ._interface .scale_data (self ._interface .chain .lower_bounds [:- 1 ].reshape (1 ,- 1 )), dtype = tf .float32 )
31+
32+
33+ self ._circular_indices = self ._get_circular_indices (circular_params )
34+ print (self ._circular_indices )
35+
2436 initial_state = tf .convert_to_tensor (np .ones (self ._n_dim ), dtype = tf .float32 )
2537 self ._chain_states = tf .Variable (tf .tile (tf .expand_dims (initial_state , axis = 0 ), [n_chains , 1 ]), dtype = tf .float32 )
2638
@@ -47,6 +59,11 @@ def __init__(self, interface: TfInterface, n_chains: int = 1024, circular_params
4759 shapes = [(self ._n_chains , self ._n_dim )]
4860 )
4961
62+ def _get_circular_indices (self , circular_params : List [str ]):
63+ """Map circular params to indices in self._interface.chain.plot_branches."""
64+ return [self ._interface .chain .plot_branches .index (param ) for param in circular_params ]
65+
66+
5067 def _estimate_batch_size (self ):
5168 """Estimate batch size based on memory available to this process."""
5269 step_size_in_bytes = self ._n_chains * self ._n_dim * tf .float32 .size
@@ -74,6 +91,26 @@ def _calc_likelihood(self, states: tf.Tensor):
7491 def propose_step_gpu (self ):
7592 # Propose new states for all chains
7693 proposed_states = self ._matrix_handler .sample (self ._n_chains ) + self ._chain_states
94+
95+ def apply_circular_bounds (idx ):
96+ # Extract specific bounds for the circular parameter
97+ lower_bound = self ._lower_bounds [0 , idx ]
98+ upper_bound = self ._upper_bounds [0 , idx ]
99+ adjusted_values = lower_bound + tf .math .mod (proposed_states [:, idx ] - upper_bound , upper_bound - lower_bound )
100+ return tf .tensor_scatter_nd_update (
101+ proposed_states ,
102+ indices = [[chain_idx , idx ] for chain_idx in range (self ._n_chains )],
103+ updates = adjusted_values
104+ )
105+
106+ # Apply circular bounds to indices marked as circular
107+ for idx in self ._circular_indices :
108+ proposed_states = apply_circular_bounds (idx )
109+
110+
111+ # Apply boundary conditions
112+ proposed_states = tf .where (proposed_states < self ._lower_bounds , self ._chain_states , proposed_states )
113+ proposed_states = tf .where (proposed_states > self ._upper_bounds , self ._chain_states , proposed_states )
77114
78115 # Calculate log-likelihoods for proposed states
79116 proposed_loglikelihoods = self ._calc_likelihood (proposed_states )
@@ -133,24 +170,33 @@ def _flush_async(self, final_flush=False):
133170 steps_to_write = self ._queue .dequeue_many (self ._batch_size_steps )
134171 end_idx = self ._current_step
135172
136- self ._dataset [:end_idx , :] = steps_to_write
173+ self ._dataset [end_idx - len (steps_to_write ):end_idx , :] = steps_to_write
174+
137175
138176 def save_mcmc_chain_to_pdf (self , filename : str , output_pdf : str ):
# Read the 'chain' dataset from the HDF5 file `filename`, undo the interface's
# scaling, and write three PDFs with one page per parameter: trace plots to
# `output_pdf`, posterior histograms to f"posterior_{output_pdf}", and
# autocorrelation plots to f"ac_{output_pdf}". Every plot draws chain index 0 only.
# NOTE(review): the extra file names are built by prefixing the string, so an
# `output_pdf` containing a directory component would yield a path such as
# "posterior_dir/file.pdf" -- confirm callers only pass bare file names.
139177 # Open the HDF5 file and read the chain
140178 with h5py .File (filename , 'r' ) as f :
141179 chain = f ['chain' ][:]
142180
181+ # Need it to reflect the actual parameters in our fit so let's combine everything!
# chain[1000:, i] skips the first 1000 entries along axis 0 and selects chain i
# on axis 1 (the dataset is created as (n_steps, n_chains, n_dim) in __call__).
# NOTE(review): all self._n_chains chains are rescaled here although only
# rescaled_chain[0] is used in the visible part of this method.
182+ rescaled_chain = [self ._interface .invert_scaling (chain [1000 :,i ]) for i in range (self ._n_chains )]
# NOTE(review): combined_rescaled_chain is not read anywhere in the visible
# part of this method.
183+ combined_rescaled_chain = np .concatenate (rescaled_chain , axis = 0 )
184+
# chain.shape[1:] is (n_chains, n_params); the chain count is discarded.
143185 _ , n_params = chain .shape [1 :]
144186
145187 # Create a PdfPages object to save plots
188+ print ("Plotting traces" )
146189 with PdfPages (output_pdf ) as pdf :
190+
191+ # Rescale the chain
192+
147193 for i in tqdm (range (n_params )):
148194 fig , ax = plt .subplots (figsize = (10 , 6 ))
149195
150196 # Plot the chain for the i-th parameter
151- unscaled_data = self ._interface .invert_scaling (chain [:, 0 , i ])
152-
153- ax .plot (unscaled_data , lw = 0.5 , label = f'Chain { i } ' )
197+ # unscaled_data = self._interface.invert_scaling(chain[:, 0, i])
198+ # for n, r in enumerate(rescaled_chain):
199+ ax .plot (rescaled_chain [ 0 ][:, i ], lw = 0.5 , label = f'Chain 0 ' )
154200 ax .set_ylabel (self ._interface .chain .plot_branches [i ])
155201 ax .set_title (f"Parameter { self ._interface .chain .plot_branches [i ]} MCMC Chain" )
156202 ax .set_xlabel ('Step' )
# NOTE(review): the hunk boundary below hides two lines of this method
# (new-file lines 203-204); their content cannot be seen from this view.
@@ -159,27 +205,67 @@ def save_mcmc_chain_to_pdf(self, filename: str, output_pdf: str):
159205 pdf .savefig (fig )
160206 plt .close (fig ) # Close the figure to save memory
161207
208+
209+ # Create a PdfPages object to save plots
210+ print ("Plotting posteriors" )
211+ with PdfPages (f"posterior_{ output_pdf } " ) as pdf :
212+
213+ # Rescale the chain
214+
215+ for i in tqdm (range (n_params )):
216+ fig , ax = plt .subplots (figsize = (10 , 6 ))
217+
218+ # Plot the chain for the i-th parameter
219+ # unscaled_data = self._interface.invert_scaling(chain[:, 0, i])
# The histogram range comes from the interface chain's bounds for parameter i;
# 100 linspace points are passed as bin edges, i.e. 99 bins shared by both histograms.
220+ l = self ._interface .chain .lower_bounds [i ]
221+ u = self ._interface .chain .upper_bounds [i ]
222+ bins = np .linspace (l , u , 100 )
223+
224+ ax .hist (rescaled_chain [0 ][:, i ], color = 'b' , label = "ML Pred" , alpha = 0.3 , bins = bins , density = True )
# Overlay column i of the interface's test data, rows from position 10000 onward.
# NOTE(review): the 10000-row offset (like the 1000-step cut above) is hard-coded.
225+ ax .hist (self ._interface .test_data .iloc [10000 :,i ].to_numpy (), color = 'r' , label = "Real Result" , alpha = 0.3 , bins = bins , density = True )
226+
227+ ax .set_xlabel (self ._interface .chain .plot_branches [i ])
228+ ax .set_title (f"Parameter { self ._interface .chain .plot_branches [i ]} MCMC Chain" )
229+
230+ ax .legend ()
231+ # Save the current figure to the PDF
232+ pdf .savefig (fig )
233+ plt .close (fig ) # Close the figure to save memory
234+
235+ print ("Plotting AC" )
236+ with PdfPages (f"ac_{ output_pdf } " ) as pdf :
237+ for i in tqdm (range (n_params )):
238+ fig , ax = plt .subplots (figsize = (10 , 6 ))
239+
240+ # Plot the chain for the i-th parameter
241+ # unscaled_data = self._interface.invert_scaling(chain[:, 0, i])
242+ # for n, r in enumerate(rescaled_chain):
# Autocorrelation of chain 0 for parameter i. nlags is set to the full series
# length; column 1 is indexed only to obtain that length.
# NOTE(review): `sm` is presumably statsmodels.api -- its import is not visible
# in this chunk; confirm it exists at the top of the file.
243+ ac = sm .tsa .acf (rescaled_chain [0 ][:, i ], nlags = len (rescaled_chain [0 ][:, 1 ]))
# NOTE(review): ax.plot(ac) puts the lag index on x and the autocorrelation
# value on y, so the x label 'Autocorrelation' below appears to describe the
# y values -- confirm the intended axis labels.
244+ ax .plot (ac , lw = 0.5 , label = f'Chain 0' )
245+ ax .set_ylabel (self ._interface .chain .plot_branches [i ])
246+ ax .set_title (f"Parameter { self ._interface .chain .plot_branches [i ]} MCMC Chain" )
247+ ax .set_xlabel ('Autocorrelation' )
248+
249+ # Save the current figure to the PDF
250+ pdf .savefig (fig )
251+ plt .close (fig ) # Close the figure to save memory
252+
# Only the trace file name is reported; the posterior_ and ac_ files are not mentioned.
162253 print (f"MCMC chain plots saved to { output_pdf } " )
163254
def __call__(self, n_steps, output_file_name: str):
    """Run the sampler for ``n_steps`` steps and store the chain in HDF5.

    A fresh file is created at ``output_file_name`` (mode ``'w'``, so an
    existing file is truncated) holding one dataset, ``'chain'``, of shape
    ``(n_steps, n_chains, n_dim)``. Once sampling has finished and the file
    is closed, ``save_mcmc_chain_to_pdf`` is called with the trace file
    named ``"traces.pdf"``.

    Args:
        n_steps: Number of times ``propose_step`` is invoked.
        output_file_name: Path of the HDF5 file to create.
    """
    print(f"Running MCMC for {n_steps} steps with {self._n_chains} chains")

    # Mode 'w' always starts from an empty file, so the dataset is created
    # unconditionally; a pre-existing 'chain' can never be present here.
    with h5py.File(output_file_name, 'w') as f:
        self._dataset = f.create_dataset(
            'chain', (n_steps, self._n_chains, self._n_dim), chunks=True
        )

        for _ in tqdm(range(n_steps)):
            self.propose_step()

        # Ensure remaining steps are flushed to disk while the file is open
        self._flush_async(final_flush=True)

    # Plot only after the writer has closed the file: save_mcmc_chain_to_pdf
    # reopens it read-only and should see the completely written dataset.
    self.save_mcmc_chain_to_pdf(output_file_name, "traces.pdf")
0 commit comments