Chapter 7: Markov Chain Monte Carlo (MCMC) – Part I
For most interesting problems, the posterior is analytically intractable. MCMC methods are algorithms that draw samples from a posterior distribution, allowing us to approximate it.
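To make "approximate it" concrete, here is a minimal sketch of how a set of posterior draws is turned into summaries. It assumes a hypothetical Beta(8, 4) posterior that can be sampled directly, standing in for the draws an MCMC sampler would produce; the variable names and the 94% interval width are illustrative choices, not part of the chapter.

```python
import random

# Assumption: a hypothetical Beta(8, 4) posterior, sampled directly here
# as a stand-in for draws produced by an MCMC sampler.
rng = random.Random(0)
draws = [rng.betavariate(8, 4) for _ in range(50_000)]

# The posterior mean is approximated by the average of the draws.
post_mean = sum(draws) / len(draws)

# A central 94% credible interval is approximated by sample quantiles.
ordered = sorted(draws)
lower = ordered[int(0.03 * len(ordered))]
upper = ordered[int(0.97 * len(ordered))]

print(f"mean = {post_mean:.3f}, 94% interval = ({lower:.3f}, {upper:.3f})")
```

The exact mean of Beta(8, 4) is 8/12, so the sample average can be checked against a known answer. With a real MCMC trace the same averaging applies, but no closed-form answer is available for comparison.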
7.1 The Metropolis-Hastings Algorithm
Metropolis-Hastings is a general MCMC algorithm. It "walks" through the parameter space, and at each step:
It proposes a jump to a new parameter value.
It evaluates how "good" the new spot is by comparing its posterior probability to the current spot.
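It decides whether to jump or stay put. With a symmetric proposal, it always accepts jumps to spots with higher posterior probability, and it accepts jumps to worse spots with probability equal to the ratio of the new posterior density to the current one. This lets it visit every region of the distribution in proportion to its probability.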
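The steps above can be sketched as a random-walk Metropolis sampler, which is the symmetric-proposal special case of Metropolis-Hastings. This is a minimal illustration rather than a production sampler: the function names `log_target` and `metropolis`, the standard Normal target, the step size, and the burn-in length are all assumptions chosen for the example.

```python
import math
import random

def log_target(theta):
    # Hypothetical target: log-density of a standard Normal, up to a constant.
    return -0.5 * theta * theta

def metropolis(log_density, start, n_samples, step=1.0, seed=0):
    """Random-walk Metropolis with a symmetric Gaussian proposal."""
    rng = random.Random(seed)
    current = start
    current_lp = log_density(current)
    samples = []
    accepted = 0
    for _ in range(n_samples):
        # 1. Propose a jump to a new parameter value.
        proposal = current + rng.gauss(0.0, step)
        proposal_lp = log_density(proposal)
        # 2. Compare the density at the proposal with the current spot
        #    (on the log scale, to avoid numerical underflow).
        log_ratio = proposal_lp - current_lp
        # 3. Always accept a better spot; accept a worse one with
        #    probability equal to the density ratio.
        if log_ratio >= 0 or rng.random() < math.exp(log_ratio):
            current, current_lp = proposal, proposal_lp
            accepted += 1
        # A rejected proposal repeats the current value in the chain.
        samples.append(current)
    return samples, accepted / n_samples

samples, acceptance_rate = metropolis(log_target, start=5.0, n_samples=20_000)
kept = samples[2_000:]  # discard early draws influenced by the starting point
mean = sum(kept) / len(kept)
var = sum((x - mean) ** 2 for x in kept) / len(kept)
print(f"acceptance rate = {acceptance_rate:.2f}, mean = {mean:.2f}, var = {var:.2f}")
```

Only the ratio of densities enters the accept step, so the target needs to be known only up to a normalizing constant. That is why the method works for posteriors whose normalizing constant is intractable. An asymmetric proposal would need the additional Hastings correction term, which this sketch omits.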
7.2 Gibbs Sampling
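Gibbs sampling is an MCMC algorithm for multivariate problems; formally, it is a special case of Metropolis-Hastings in which every proposal is accepted. It works by iteratively sampling each parameter from its full conditional distribution (its distribution given the data and the current values of all other parameters).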
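As a rough illustration of sampling from full conditionals, the sketch below runs a Gibbs sampler on a hypothetical target: a bivariate standard Normal with correlation `rho`. For that target each full conditional is itself Normal, x | y ~ Normal(rho * y, 1 - rho^2), and symmetrically for y. The function name `gibbs_bivariate_normal` and the value `rho = 0.8` are assumptions made for the example.

```python
import math
import random

def gibbs_bivariate_normal(rho, n_samples, seed=0):
    """Gibbs sampler for a bivariate standard Normal with correlation rho."""
    rng = random.Random(seed)
    cond_sd = math.sqrt(1.0 - rho * rho)  # sd of each full conditional
    x, y = 0.0, 0.0
    draws = []
    for _ in range(n_samples):
        # Sample x from its full conditional given the current y.
        x = rng.gauss(rho * y, cond_sd)
        # Sample y from its full conditional given the updated x.
        y = rng.gauss(rho * x, cond_sd)
        draws.append((x, y))
    return draws

draws = gibbs_bivariate_normal(rho=0.8, n_samples=20_000)
xs = [d[0] for d in draws]
ys = [d[1] for d in draws]
mx = sum(xs) / len(xs)
my = sum(ys) / len(ys)
cov = sum((a - mx) * (b - my) for a, b in draws) / len(draws)
sd_x = math.sqrt(sum((a - mx) ** 2 for a in xs) / len(xs))
sd_y = math.sqrt(sum((b - my) ** 2 for b in ys) / len(ys))
sample_corr = cov / (sd_x * sd_y)
print(f"sample correlation = {sample_corr:.2f}")
```

There is no accept/reject step: every draw from a full conditional is kept. The trade-off is that each conditional must be a distribution we can sample from directly.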
7.3 Code Example: A Non-Conjugate Model
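Let's model data from a Normal distribution where both the mean μ and standard deviation σ are unknown. With independent Normal and HalfNormal priors on μ and σ, the posterior has no closed form, so we approximate it with MCMC.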
Python
import pymc as pm
import numpy as np
import arviz as az
import matplotlib.pyplot as plt

# Generate some synthetic data
np.random.seed(42)
data = np.random.normal(loc=10, scale=2, size=100)

with pm.Model() as non_conjugate_model:
    # Priors for unknown mean and standard deviation
    mu = pm.Normal('mu', mu=0, sigma=10)
    sigma = pm.HalfNormal('sigma', sigma=5)

    # Likelihood
    y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=data)

    # PyMC will automatically choose an appropriate MCMC sampler (NUTS)
    trace_mcmc = pm.sample(2000, tune=1000)

# We can now analyze the trace object to get posteriors for mu and sigma
az.plot_posterior(trace_mcmc)
plt.show()