Source: lsong/teaching/8803ml/lecture7.pdf · 2012-02-02
Approximate Inference in GMs Sampling
Machine Learning II: Advanced Topics CSE 8803ML, Spring 2012
Le Song
What we have learned so far
Variable elimination and junction tree
Exact inference
Exponential in tree-width
Mean field, loopy belief propagation
Approximate inference for marginals/conditionals
Fast, but can get inaccurate estimates
Sample-based inference
Approximate inference for marginals/conditionals
With “enough” samples, will converge to the right answer
Variational Inference
What is the approximating structure?
How to measure the goodness of the approximation of Q(X_1, …, X_n) to the original P(X_1, …, X_n)?
Reverse KL-divergence KL(Q‖P)
How to compute the new parameters?
Optimization: Q* = argmin_Q KL(Q‖P)
[Figure: the original distribution P is approximated by simpler candidate structures Q; mean field fits the new parameters of Q to P.]
Deriving mean field fixed point condition I
Approximate P(X) ∝ ∏_j Ψ(D_j) with Q(X) = ∏_i Q_i(X_i)
s.t. ∀ X_i: Q_i(X_i) > 0, ∫ Q_i(X_i) dX_i = 1
D(Q‖P) = −∫ Q(X) ln P(X) dX + ∫ Q(X) ln Q(X) dX + const.
= −∑_j ∫ ∏_i Q_i(X_i) ln Ψ(D_j) dX + ∑_i ∫ Q_i(X_i) ln Q_i(X_i) dX_i
= −∑_j ∫ ∏_{i∈D_j} Q_i(X_i) ln Ψ(D_j) dX_{D_j} + ∑_i ∫ Q_i(X_i) ln Q_i(X_i) dX_i
[Figure: pairwise MRF on X_1, …, X_7; the factors touching X_1 are Ψ(X_1, X_2), Ψ(X_1, X_3), Ψ(X_1, X_4), Ψ(X_1, X_5), and Q_1(X_1) is its mean field marginal.]
E.g., ∫ ∏_i Q_i(X_i) ln Ψ(X_1, X_5) dX = ∫ Q_1(X_1) Q_5(X_5) ln Ψ(X_1, X_5) dX_1 dX_5
Deriving mean field fixed point condition II
Let L be the Lagrangian function (note there is no multiplier for Q_i(X_i) > 0?)
L = −∑_j ∫ ∏_{i∈D_j} Q_i(X_i) ln Ψ(D_j) dX_{D_j} + ∑_i ∫ Q_i(X_i) ln Q_i(X_i) dX_i − ∑_i β_i (1 − ∫ Q_i(X_i) dX_i)
For KL divergence, the nonnegativity constraint comes for free
If Q is an optimum for the original constrained problem, then there exist multipliers β_i such that (Q, β_i) is a stationary point of the Lagrangian (derivative equal to 0).
Deriving mean field fixed point condition III
L = −∑_j ∫ ∏_{i∈D_j} Q_i(X_i) ln Ψ(D_j) dX_{D_j} + ∑_i ∫ Q_i(X_i) ln Q_i(X_i) dX_i − ∑_i β_i (1 − ∫ Q_i(X_i) dX_i)
∂L/∂Q_i(X_i) = −∑_{j: i∈D_j} ∫ ∏_{k∈D_j∖i} Q_k(X_k) ln Ψ(D_j) dX_{D_j∖i} + ln Q_i(X_i) + 1 + β_i = 0
⇒ Q_i(X_i) = exp( ∑_{j: i∈D_j} ∫ ∏_{k∈D_j∖i} Q_k(X_k) ln Ψ(D_j) dX_{D_j∖i} − 1 − β_i )
= (1/Z_i) exp( ∑_{j: i∈D_j} ∫ ∏_{k∈D_j∖i} Q_k(X_k) ln Ψ(D_j) dX_{D_j∖i} )
In expectation notation:
Q_i(X_i) = (1/Z_i) exp( ∑_{D_j: X_i∈D_j} E_Q[ln Ψ(D_j)] ), where the expectation is over all variables of D_j other than X_i
Mean Field Algorithm
Initialize Q(X_1, …, X_n) = ∏_i Q_i(X_i) (e.g., randomly or smartly)
Set all variables to unprocessed
Pick an unprocessed variable 𝑋𝑖
Update 𝑄𝑖:
Q_i(X_i) = (1/Z_i) exp( ∑_{D_j: X_i∈D_j} E_Q[ln Ψ(D_j)] )
Set variable X_i as processed
If 𝑄𝑖 changed
Set neighbors of 𝑋𝑖 to unprocessed
Guaranteed to converge
Understanding mean field algorithm
[Figure: pairwise MRF on X_1, …, X_7. Updating Q_1(X_1) takes expectations of ln Ψ over the mean field marginals Q(X_2), …, Q(X_5) of X_1's neighbors; updating Q_2(X_2) uses Q(X_1), Q(X_6), Q(X_7).]
Q_1(X_1) = (1/Z_1) exp( E_{Q(X_2)}[ln Ψ(X_1, X_2)] + E_{Q(X_3)}[ln Ψ(X_1, X_3)] + E_{Q(X_4)}[ln Ψ(X_1, X_4)] + E_{Q(X_5)}[ln Ψ(X_1, X_5)] )
Q_2(X_2) = (1/Z_2) exp( E_{Q(X_1)}[ln Ψ(X_1, X_2)] + E_{Q(X_6)}[ln Ψ(X_2, X_6)] + E_{Q(X_7)}[ln Ψ(X_2, X_7)] )
Example of Mean Field
Mean field equation
Q_i(X_i) = (1/Z_i) exp( ∑_{D_j: X_i∈D_j} E_Q[ln Ψ(D_j)] )
Pairwise Markov random field
P(X_1, …, X_n) ∝ exp( ∑_{ij} θ_ij X_i X_j + ∑_i θ_i X_i ) = ∏_{ij} exp(θ_ij X_i X_j) ∏_i exp(θ_i X_i)
The mean field equation has a simple form:
Q_i(X_i) = (1/Z_i) exp( ∑_{j∈N(i)} θ_ij X_i ∑_{x_j} x_j Q_j(x_j) + θ_i X_i )
= (1/Z_i) exp( ∑_{j∈N(i)} θ_ij X_i ⟨X_j⟩_{Q_j} + θ_i X_i )
[Figure: for pairwise potentials ln Ψ(X_i, X_j) = θ_ij X_i X_j and ln Ψ(X_i) = θ_i X_i, the update for Q_i(X_i) depends only on the neighbor means ⟨X_j⟩_{Q_j}, ⟨X_k⟩_{Q_k}, ⟨X_l⟩_{Q_l}, ⟨X_s⟩_{Q_s}.]
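The pairwise update above can be sketched in code. A minimal sketch (not from the slides) for binary variables X_i ∈ {−1, +1}, where Q_i(X_i) ∝ exp(X_i (∑_j θ_ij ⟨X_j⟩ + θ_i)) makes the mean of X_i a tanh of the neighbor means; the parameter matrices in the usage are hypothetical placeholders:

```python
import numpy as np

def mean_field_ising(theta_pair, theta_node, n_iters=50):
    """Coordinate mean field for a pairwise MRF with X_i in {-1, +1}.

    theta_pair: (n, n) symmetric edge parameters theta_ij (0 where
    there is no edge, 0 on the diagonal); theta_node: length-n node
    parameters. Returns m with m[i] = <X_i>_{Q_i}, the mean field means.
    """
    n = len(theta_node)
    m = np.zeros(n)                        # <X_i> = 0, i.e. uniform Q_i
    for _ in range(n_iters):
        for i in range(n):
            # Q_i(x_i) ∝ exp(x_i * (sum_j theta_ij <X_j> + theta_i))
            a = theta_pair[i] @ m + theta_node[i]
            m[i] = np.tanh(a)              # resulting mean of x_i in {-1, +1}
    return m
```

For a two-node model with θ_12 = 0.5 and a positive field on X_1, the fixed point gives both means positive, with X_1's the larger.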
Example of generalized mean field
Q(C_1) ∝ exp( ∑_{i∈C_1} θ_i X_i + ∑_{(ij)∈E, i∈C_1, j∈C_1} θ_ij X_i X_j + ∑_k ∑_{i∈C_1, j∈C_k, j∈MB(C_1)} θ_ij X_i ⟨X_j⟩_{Q(C_k)} )
[Figure: clusters C_1, …, C_4 with Markov blanket MB(C_1). The three terms are: node potentials within C_1, edge potentials within C_1, and edge potentials across C_1 and C_k weighted by the mean of the variables in the Markov blanket.]
Why Sampling
Previous inference tasks focus on obtaining the entire posterior distribution P(X_i | e)
Often we want to take expectations
Mean: μ_{X_i|e} = E[X_i | e] = ∫ x_i P(x_i | e) dx_i
Variance: σ²_{X_i|e} = E[(X_i − μ_{X_i|e})² | e] = ∫ (x_i − μ_{X_i|e})² P(x_i | e) dx_i
More generally, E[f] = ∫ f(X) P(X | e) dX, which can be difficult to compute analytically
Key idea: approximate expectation by sample average
E[f] ≈ (1/N) ∑_{i=1}^N f(x_i)
where x_1, …, x_N ∼ P(X | e) independently and identically distributed
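The sample-average idea fits in a few lines; here a hypothetical f(x) = x² under a standard Gaussian, where the true expectation is 1:

```python
import numpy as np

rng = np.random.default_rng(0)

# X ~ N(0, 1); estimate E[f(X)] for f(x) = x^2 (true value: 1)
# by the sample average (1/N) * sum_i f(x_i).
x = rng.standard_normal(100_000)
estimate = np.mean(x**2)
```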
Sampling
Samples: points from the domain of a distribution P(X)
The higher P(x), the more likely we see x in the sample
Mean: μ̂ = (1/N) ∑_i x_i
Variance: σ̂² = (1/N) ∑_i (x_i − μ̂)²
[Figure: density P(X) with samples x_1, …, x_6 concentrated where P is high.]
For n independent Gaussian variables with zero mean and unit variance, use x ∼ randn(n, 1) in Matlab
For a multivariate Gaussian 𝒩(μ, Σ):
y ∼ randn(n, 1)
Transform the sample: x = μ + Σ^{1/2} y
X takes a finite number of values, e.g., {1, 2, 3}, with P(X=1) = 0.7, P(X=2) = 0.2, P(X=3) = 0.1
Draw a sample from the uniform distribution U[0,1]: u ∼ rand
If 𝑢 falls into interval [0, 0.7), output sample 𝑥 = 1
If 𝑢 falls into interval [0.7, 0.9), output sample 𝑥 = 2
If 𝑢 falls into interval [0.9, 1], output sample 𝑥 = 3
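Both recipes translate directly to NumPy (as a stand-in for the Matlab calls above); the interval boundaries are the cumulative sums of the probabilities, and the μ, Σ values are made up for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)

# Discrete X in {1, 2, 3} with P = (0.7, 0.2, 0.1): partition [0, 1)
# into intervals of those lengths and see where u ~ U[0, 1) lands.
probs = np.array([0.7, 0.2, 0.1])
cdf = np.cumsum(probs)                       # [0.7, 0.9, 1.0]

def sample_discrete():
    u = rng.random()
    return int(np.searchsorted(cdf, u, side="right")) + 1   # 1, 2, or 3

# Multivariate Gaussian N(mu, Sigma): draw y ~ N(0, I), then transform
# x = mu + Sigma^{1/2} y (a Cholesky factor serves as Sigma^{1/2}).
mu = np.array([1.0, -1.0])                   # illustrative values
Sigma = np.array([[2.0, 0.3], [0.3, 1.0]])
y = rng.standard_normal(2)
x = mu + np.linalg.cholesky(Sigma) @ y
```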
Sampling Example
[Figure: pie chart partitioning [0, 1] into segments of length 0.7 (value 1), 0.2 (value 2), and 0.1 (value 3).]
Generate Samples from Bayesian Networks
BNs describe a generative process for observations
First, sort the nodes in topological order
Then, generate samples in this order according to the CPTs
Generate a set of samples for (A, F, S, N, H):
Sample 𝑎𝑖 ∼ 𝑃 𝐴
Sample 𝑓𝑖 ∼ 𝑃 𝐹
Sample 𝑠𝑖 ∼ 𝑃 𝑆 𝑎𝑖 , 𝑓𝑖
Sample 𝑛𝑖 ∼ 𝑃 𝑁 𝑠𝑖
Sample ℎ𝑖 ∼ 𝑃 𝐻 𝑠𝑖
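A sketch of this ancestral sampling loop; the CPT numbers below are made up for illustration (the slides do not give the actual tables), and only the sampling order matters:

```python
import numpy as np

rng = np.random.default_rng(0)

def bern(p):
    """One Bernoulli(p) draw as 0/1."""
    return int(rng.random() < p)

def sample_bn():
    """One ancestral sample (a, f, s, n, h), drawn in topological order."""
    a = bern(0.1)                        # a_i ~ P(A); number made up
    f = bern(0.2)                        # f_i ~ P(F)
    p_s = [[0.05, 0.8], [0.7, 0.95]]     # illustrative CPT for P(S=1 | A, F)
    s = bern(p_s[a][f])                  # s_i ~ P(S | a_i, f_i)
    n = bern(0.9 if s else 0.2)          # n_i ~ P(N | s_i)
    h = bern(0.7 if s else 0.1)          # h_i ~ P(H | s_i)
    return a, f, s, n, h
```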
[Figure: Bayesian network in which Allergy and Flu are parents of Sinus, and Sinus is the parent of Nose and Headache.]
Challenge in sampling
How to draw sample from a given distribution?
Not all distributions can be trivially sampled
How to make better use of samples (not all samples are equally useful)?
How do we know we've sampled enough?
Sampling Methods
Direct Sampling
Simple
Works only for easy distributions
Rejection Sampling
Create samples like direct sampling
Only count samples consistent with given evidence
Importance Sampling
Create samples like direct sampling
Assign weights to samples
Gibbs Sampling
Often used for high-dimensional problems
Uses a variable and its Markov blanket for sampling
Rejection sampling I
Want to sample from a distribution P(X) = (1/Z) f(X)
P(X) is difficult to sample from, but f(X) is easy to evaluate
Key idea: sample from a simpler distribution Q(X), and then select samples
Rejection sampling (choose 𝑀 s.t. 𝑓 𝑋 ≤ 𝑀 𝑄 𝑋 )
i = 1
Repeat until i = N:
x ∼ Q(X), u ∼ U[0,1]
If u ≤ f(x) / (M Q(x)), accept x and set i = i + 1; otherwise, reject x
A sample is rejected with probability 1 − f(x) / (M Q(x))
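A minimal sketch of the loop above, with an illustrative unnormalized bimodal target f and a Gaussian proposal Q; the constant M was chosen generously by inspection rather than derived:

```python
import numpy as np

rng = np.random.default_rng(0)

# Unnormalized target f (a bimodal shape) and Gaussian proposal Q.
def f(x):
    return np.exp(-x**2 / 2) + 0.5 * np.exp(-(x - 3)**2 / 2)

SIGMA = 3.0
def Q(x):                                 # density of N(1, SIGMA^2)
    return np.exp(-(x - 1)**2 / (2 * SIGMA**2)) / (SIGMA * np.sqrt(2 * np.pi))

M = 20.0                                  # chosen so f(x) <= M * Q(x) everywhere

def rejection_sample(n):
    samples = []
    while len(samples) < n:
        x = rng.normal(1.0, SIGMA)        # x ~ Q
        u = rng.random()                  # u ~ U[0, 1]
        if u <= f(x) / (M * Q(x)):        # accept with prob f(x) / (M Q(x))
            samples.append(x)
    return np.array(samples)
```

The accepted samples follow (1/Z) f(x); the mean of this mixture is 1.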
Rejection sampling II
Sample x ∼ Q(X) and reject with probability 1 − f(x) / (M Q(x))
The likelihood of seeing a particular x:
P(x) = [f(x) / (M Q(x))] Q(x) / ∫ [f(X) / (M Q(X))] Q(X) dX = (1/Z) f(x)
[Figure: target f(X) under the envelope M Q(X). A sample x_1 ∼ Q(X) and u_1 ∼ U[0,1] give the point u_1 M Q(x_1), accepted if it falls below f(x_1); the region between the envelope M Q(X) and f(X) is the rejection region.]
Drawback of rejection sampling
Can have very small acceptance rate
Adaptive rejection sampling: use an envelope function for Q
[Figure: when the envelope M Q(X) sits far above f(X), the rejection region is large and most samples x_1 ∼ Q(X) are rejected.]
Importance Sampling I
Suppose sampling from P(X) is hard
We can sample from a simpler proposal distribution 𝑄 𝑋
If Q dominates P (i.e., Q(X) > 0 whenever P(X) > 0), we can sample from Q and reweight:
E_P[f(X)] = ∫ f(X) P(X) dX = ∫ f(X) [P(X) / Q(X)] Q(X) dX
Sample x_i ∼ Q(X)
Reweight sample x_i with w_i = P(x_i) / Q(x_i)
Approximate the expectation with (1/N) ∑_i f(x_i) w_i
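A short sketch: estimating E_P[X²] for P = N(0, 1) with a wider Gaussian proposal Q = N(0, 2²) that dominates it (the choice of f, P, and Q here is illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)

def gauss_pdf(x, mu, sigma):
    return np.exp(-(x - mu)**2 / (2 * sigma**2)) / (sigma * np.sqrt(2 * np.pi))

# Target P = N(0, 1), proposal Q = N(0, 2^2) (Q > 0 wherever P > 0).
# Estimate E_P[f(X)] for f(x) = x^2, whose true value is 1.
N = 100_000
x = rng.normal(0.0, 2.0, size=N)                     # x_i ~ Q
w = gauss_pdf(x, 0.0, 1.0) / gauss_pdf(x, 0.0, 2.0)  # w_i = P(x_i) / Q(x_i)
estimate = np.mean(x**2 * w)                         # (1/N) * sum_i f(x_i) w_i
```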
Importance Sampling II
Instead of rejecting samples, reweight them
The efficiency of importance sampling depends on how close the proposal Q is to the target P
[Figure: samples x_1, x_2 ∼ Q(X) are reweighted by w_i = P(x_i) / Q(x_i); where P exceeds Q the weight is greater than 1, and where Q exceeds P it is less than 1.]
Unnormalized importance sampling
Suppose we can only evaluate P(X) = (1/Z) f(X) (e.g., grids or MRFs, where Z is difficult to compute), and we don't know Z
We can get around the normalization Z by normalizing the importance weights:
r(X) = f(X) / Q(X) ⇒ E_Q[r(X)] = ∫ [f(X) / Q(X)] Q(X) dX = ∫ f(X) dX = Z
E_P[g(X)] = ∫ g(X) P(X) dX = (1/Z) ∫ g(X) [f(X) / Q(X)] Q(X) dX = ∫ g(X) r(X) Q(X) dX / ∫ r(X) Q(X) dX ≈ ∑_i g(x_i) r(x_i) / ∑_i r(x_i)
Define w_i = r(x_i) / ∑_i r(x_i) as the new importance weight
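A sketch of this self-normalized estimator for a target known only up to Z; here f is an (illustrative) unnormalized standard Gaussian, so the average weight also recovers Z ≈ √(2π):

```python
import numpy as np

rng = np.random.default_rng(0)

# Target known only up to Z: P(x) = (1/Z) f(x), with f an unnormalized
# standard Gaussian (true Z = sqrt(2*pi), used here only for checking).
def f(x):
    return np.exp(-x**2 / 2)

SIGMA = 2.0
def Q(x):                                    # density of N(0, SIGMA^2)
    return np.exp(-x**2 / (2 * SIGMA**2)) / (SIGMA * np.sqrt(2 * np.pi))

N = 100_000
x = rng.normal(0.0, SIGMA, size=N)           # x_i ~ Q
r = f(x) / Q(x)                              # r_i = f(x_i) / Q(x_i)
w = r / r.sum()                              # normalized weights w_i
estimate = np.sum(x**2 * w)                  # E_P[g] for g(x) = x^2 (true: 1)
Z_hat = r.mean()                             # E_Q[r] = Z, no normalizer needed
```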
Difference between normalized and unnormalized
Unnormalized importance sampling is unbiased
Normalized importance sampling is biased, e.g., for N = 1:
E_Q[f(x_1) w_1 / w_1] = E_Q[f(x_1)] ≠ E_P[f(x_1)]
However, the normalized importance sampler usually has lower variance
But most importantly, the normalized importance sampler works for unnormalized distributions
With normalized weights, one can also do resampling based on w_i (a multinomial distribution with probability w_i for x_i)
Gibbs Sampling
Both rejection sampling and importance sampling do not scale well to high dimensions
Markov Chain Monte Carlo (MCMC) is an alternative
Key idea: Construct a Markov chain whose stationary distribution is the target distribution 𝑃 𝑋
Sampling process: random walk in the Markov chain
Gibbs sampling is a very special and simple MCMC method.
Markov Chain Monte Carlo
Want to sample from P(X); start with a random initial vector X
𝑋𝑡: 𝑋 at time step 𝑡
𝑋𝑡 transition to 𝑋𝑡+1 with probability
Q(X_{t+1} | X_t, …, X_1) = T(X_{t+1} | X_t)
The stationary distribution of T(X_{t+1} | X_t) is our target P(X)
Run for an initial M samples (burn-in time) until the chain converges/mixes/reaches the stationary distribution
Then collect N (correlated) samples as x_i
Key issues: designing the transition kernel, and diagnosing convergence
Gibbs Sampling
A very special transition kernel, works nicely with Markov blanket in GMs.
The procedure
We have a set of variables X = {X_1, …, X_K} in a GM
At each step, one variable X_i is selected (at random or in some fixed sequence); denote the remaining variables as X_{−i} and their current values as x_{−i}^t
Compute the conditional distribution P(X_i | x_{−i}^t)
A value x_i^t is sampled from this distribution
This sample x_i^t replaces the previously sampled value of X_i in X
Gibbs Sampling in formula
Gibbs sampling
X = x^0
For t = 1 to N:
x_1^t ∼ P(X_1 | x_2^{t−1}, …, x_K^{t−1})
x_2^t ∼ P(X_2 | x_1^t, x_3^{t−1}, …, x_K^{t−1})
…
x_K^t ∼ P(X_K | x_1^t, …, x_{K−1}^t)
Variants:
Randomly pick variable to sample
Sample block by block
[Figure: five-node graphical model X_1, …, X_5.]
For graphical models, we only need to condition on the variables in the Markov blanket
Gibbs Sampling: Examples
Pairwise Markov random field
P(X_1, …, X_K) ∝ exp( ∑_{ij} θ_ij X_i X_j + ∑_i θ_i X_i )
In the Gibbs sampling step, we need the conditional P(X_1 | X_2, …, X_K):
P(X_1 | X_2, …, X_K) ∝ exp( θ_12 X_1 X_2 + θ_13 X_1 X_3 + θ_14 X_1 X_4 + θ_15 X_1 X_5 + θ_1 X_1 )
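A sketch of a Gibbs sweep for this pairwise model with X_i ∈ {−1, +1}; the conditional above becomes P(X_i = 1 | rest) = 1/(1 + exp(−2(∑_j θ_ij x_j + θ_i))), and the θ values in the usage below are hypothetical:

```python
import numpy as np

rng = np.random.default_rng(0)

def gibbs_ising(theta_pair, theta_node, n_sweeps=2000, burn_in=200):
    """Gibbs sampling for a pairwise MRF with X_i in {-1, +1}.

    theta_pair: (n, n) symmetric edge parameters (zero diagonal);
    theta_node: length-n node parameters. Each conditional depends
    only on the Markov blanket of the variable being resampled.
    """
    n = len(theta_node)
    x = rng.choice([-1, 1], size=n)          # random initial state
    samples = []
    for t in range(n_sweeps):
        for i in range(n):
            a = theta_pair[i] @ x + theta_node[i]
            p = 1.0 / (1.0 + np.exp(-2.0 * a))   # P(x_i = +1 | x_{-i})
            x[i] = 1 if rng.random() < p else -1
        if t >= burn_in:                     # discard burn-in sweeps
            samples.append(x.copy())
    return np.array(samples)
```

For two nodes coupled by θ_12 = 1 and no fields, the sampled agreement E[X_1 X_2] should approach tanh(1) ≈ 0.76.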
[Figure: pairwise MRF on X_1, …, X_7; the neighbors of X_1 are X_2, X_3, X_4, X_5.]
Diagnose convergence
Good chain
[Figure: trace plot of sampled value vs. iteration number; the chain moves quickly across the sample space.]
Diagnose convergence
Bad chain
[Figure: trace plot of sampled value vs. iteration number; the chain moves slowly and lingers in one region.]
Sampling Methods
Direct Sampling
Simple
Works only for easy distributions
Rejection Sampling
Create samples like direct sampling
Only count samples consistent with given evidence
Importance Sampling
Create samples like direct sampling
Assign weights to samples
Gibbs Sampling
Often used for high-dimensional problems
Uses a variable and its Markov blanket for sampling