
Page 1: Bayesian Learning and Inference in Recurrent Switching ...proceedings.mlr.press/v54/linderman17a/linderman17a-supp.pdf · Bayesian Learning and Inference in Recurrent Switching Linear

Bayesian Learning and Inference in

Recurrent Switching Linear Dynamical Systems

Scott W. Linderman*   Matthew J. Johnson*   Andrew C. Miller

Columbia University Harvard and Google Brain Harvard University

Ryan P. Adams David M. Blei Liam Paninski

Harvard and Google Brain Columbia University Columbia University

1 Stochastic Variational Inference

The main paper introduces a Gibbs sampling algorithm for the recurrent SLDS and its siblings, but it is straightforward to derive a mean field variational inference algorithm as well. From this, we can immediately derive a stochastic variational inference (SVI) [Hoffman et al., 2013] algorithm for conditionally independent time series.

We use a structured mean field approximation on the augmented model,

p(z_{1:T}, x_{1:T}, \omega_{1:T}, \theta \mid y_{1:T}) \approx q(z_{1:T}) \, q(x_{1:T}) \, q(\omega_{1:T}) \, q(\theta; \eta).

The first three factors will not be explicitly parameterized; rather, as with Gibbs sampling, we leverage standard message passing algorithms to compute the necessary expectations with respect to these factors. Moreover, q(\omega_{1:T}) further factorizes as,

q(\omega_{1:T}) = \prod_{t=1}^{T} \prod_{k=1}^{K-1} q(\omega_{t,k}).

To be concrete, we also expand the parameter factor,

q(\theta; \eta) = \prod_{k=1}^{K} q(R_k, r_k \mid \eta_{\mathrm{rec},k}) \, q(A_k, b_k, B_k; \eta_{\mathrm{dyn},k}) \times q(C_k, d_k, D_k; \eta_{\mathrm{obs},k}).

The algorithm proceeds by alternating between optimizing q(x_{1:T}), q(z_{1:T}), q(\omega_{1:T}), and q(\theta).

Updating q(x_{1:T}). Fixing the factor on the discrete states q(z_{1:T}), the optimal variational factor on the continuous states q(x_{1:T}) is determined by,

\ln q(x_{1:T}) = \psi(x_1) + \sum_{t=1}^{T-1} \psi(x_t, x_{t+1}) + \sum_{t=1}^{T-1} \psi(x_t, z_{t+1}, \omega_t) + \sum_{t=1}^{T} \psi(x_t; y_t) + c,

where

\psi(x_1) = E_{q(\theta) q(z)} \ln p(x_1 \mid z_1, \theta), \qquad (1)

\psi(x_t, x_{t+1}) = E_{q(\theta) q(z)} \ln p(x_{t+1} \mid x_t, z_t, \theta), \qquad (2)

\psi(x_t, z_{t+1}, \omega_t) = E_{q(\theta) q(z) q(\omega)} \ln p(z_{t+1} \mid x_t, z_t, \omega_t, \theta). \qquad (3)

Because the densities p(x_1 \mid z_1, \theta) and p(x_{t+1} \mid x_t, z_t, \theta) are Gaussian exponential families, the expectations in Eqs. (1)-(2) can be computed efficiently, yielding Gaussian potentials with natural parameters that depend on both q(\theta) and q(z_{1:T}). Furthermore, each \psi(x_t; y_t) is itself a Gaussian potential. As in the Gibbs sampler, the only non-Gaussian potential comes from the logistic stick-breaking model, but once again the Pólya-gamma augmentation scheme comes to the rescue. After augmentation, the potential as a function of x_t is,

E_{q(\theta) q(z) q(\omega)} \ln p(z_{t+1} \mid x_t, z_t, \omega_t, \theta) = -\tfrac{1}{2} \nu_{t+1}^\top \Omega_t \nu_{t+1} + \nu_{t+1}^\top \kappa(z_{t+1}) + c.

Since \nu_{t+1} = R_{z_t} x_t + r_{z_t} is linear in x_t, this is another Gaussian potential. As with the dynamics and observation potentials, the recurrence weights, (R_k, r_k), also have matrix normal factors, which are conjugate after augmentation. We also need access to E_q[\omega_{t,k}]; we discuss this computation below.

After augmentation, the overall factor q(x_{1:T}) is a Gaussian linear dynamical system with natural parameters computed from the variational factor on the dynamical parameters q(\theta), the variational factor on the discrete states q(z_{1:T}), the recurrence potentials \{\psi(x_t, z_t, z_{t+1})\}_{t=1}^{T-1}, and the observation model potentials \{\psi(x_t; y_t)\}_{t=1}^{T}.

Because the optimal factor q(x_{1:T}) is a Gaussian linear dynamical system, we can use message passing to perform efficient inference. In particular, the expected sufficient statistics of q(x_{1:T}) needed for updating q(z_{1:T}) can be computed efficiently.
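For reference, the underlying message passing is the standard Kalman filter followed by a Rauch-Tung-Striebel backward pass. The sketch below uses fixed, time-invariant parameters rather than the variational natural parameters described above; the function and variable names are ours, not from the paper:

```python
import numpy as np

def rts_smoother(A, Q, C, R, mu0, S0, ys):
    """Kalman filter + RTS smoother for x_{t+1} = A x_t + w, y_t = C x_t + v.
    Returns smoothed means and covariances of q(x_t)."""
    T, n = len(ys), mu0.shape[0]
    mus_f = np.zeros((T, n))
    Ss_f = np.zeros((T, n, n))
    mu_pred, S_pred = mu0, S0
    for t in range(T):
        # measurement update
        S_obs = C @ S_pred @ C.T + R
        K = S_pred @ C.T @ np.linalg.inv(S_obs)
        mus_f[t] = mu_pred + K @ (ys[t] - C @ mu_pred)
        Ss_f[t] = S_pred - K @ C @ S_pred
        # time update (predict)
        mu_pred = A @ mus_f[t]
        S_pred = A @ Ss_f[t] @ A.T + Q
    # backward (RTS) pass
    mus_s, Ss_s = mus_f.copy(), Ss_f.copy()
    for t in range(T - 2, -1, -1):
        S_pred = A @ Ss_f[t] @ A.T + Q
        G = Ss_f[t] @ A.T @ np.linalg.inv(S_pred)
        mus_s[t] = mus_f[t] + G @ (mus_s[t + 1] - A @ mus_f[t])
        Ss_s[t] = Ss_f[t] + G @ (Ss_s[t + 1] - S_pred) @ G.T
    return mus_s, Ss_s
```

The smoothed pairwise moments needed for E_q[x_t x_{t+1}^\top] follow from the same backward gains G.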

Updating q(\omega_{1:T}). We have,

\ln q(\omega_{t,k}) = E_q \ln p(z_{t+1} \mid \omega_t, x_t) + c
= -\tfrac{1}{2} E_q[\nu_{t+1,k}^2] \, \omega_{t,k} + E_{q(z_{1:T})} \ln p_{\mathrm{PG}}(\omega_{t,k} \mid I[z_{t+1} \ge k], 0) + c.

While the expectation with respect to q(z_{1:T}) makes this challenging, we can approximate it with a sample, z_{1:T} \sim q(z_{1:T}). Given a fixed value z_{1:T}, we have,

q(\omega_{t,k}) = p_{\mathrm{PG}}\big(\omega_{t,k} \mid I[z_{t+1} \ge k], \; (E_q[\nu_{t+1,k}^2])^{1/2}\big).

The expected value of this distribution is available in closed form:

E_q[\omega_{t,k}] = \frac{I[z_{t+1} \ge k]}{2 \, (E_q[\nu_{t+1,k}^2])^{1/2}} \tanh\!\Big( \tfrac{1}{2} (E_q[\nu_{t+1,k}^2])^{1/2} \Big).
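The hyperbolic-tangent form above has the limit b/4 as the tilting parameter tends to zero, which matters numerically when the expected squared activation is small. A small helper capturing E[\omega] for \omega \sim PG(b, c) (our naming, not from the paper):

```python
import numpy as np

def pg_mean(b, c):
    """E[omega] for omega ~ PG(b, c): b / (2c) * tanh(c / 2),
    with the c -> 0 limit equal to b / 4."""
    b = np.asarray(b, dtype=float)
    c = np.asarray(c, dtype=float)
    small = np.abs(c) < 1e-8
    c_safe = np.where(small, 1.0, c)  # avoid division by zero
    return np.where(small, b / 4.0, b / (2.0 * c_safe) * np.tanh(c_safe / 2.0))
```

Here b would be I[z_{t+1} >= k] and c the root of the expected squared activation.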

Updating q(z_{1:T}). Similarly, fixing q(x_{1:T}), the optimal factor q(z_{1:T}) is proportional to

\exp\Big\{ \psi(z_1) + \sum_{t=1}^{T-1} \psi(z_t, x_t, z_{t+1}) + \sum_{t=1}^{T} \psi(z_t) \Big\},

where

\psi(z_1) = E_{q(\theta)} \ln p(z_1 \mid \theta) + E_{q(\theta) q(x)} \ln p(x_1 \mid z_1, \theta),

\psi(z_t, x_t, z_{t+1}) = E_{q(\theta) q(x_{1:T})} \ln p(z_{t+1} \mid z_t, x_t),

\psi(z_t) = E_{q(\theta) q(x)} \ln p(x_{t+1} \mid x_t, z_t, \theta).

The first and third densities are exponential families; these expectations can be computed efficiently. The challenge is the recurrence potential,

\psi(z_t, x_t, z_{t+1}) = E_{q(\theta) q(x)} \ln \pi_{\mathrm{SB}}(\nu_{t+1}).

Since this is not available in closed form, we approximate this expectation with Monte Carlo over x_t, R_k, and r_k. The resulting factor q(z_{1:T}) is an HMM with natural parameters that are functions of q(\theta) and q(x_{1:T}), and the expected sufficient statistics required for updating q(x_{1:T}) can be computed efficiently by message passing in the same manner.
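The HMM message passing referred to here is the standard forward-backward recursion; a self-contained sketch with time-varying transition potentials, as arise from the recurrence terms (function and variable names are ours):

```python
import numpy as np

def logsumexp(a, axis=None, keepdims=False):
    """Numerically stable log-sum-exp."""
    m = np.max(a, axis=axis, keepdims=True)
    out = np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True)) + m
    return out if keepdims else np.squeeze(out, axis=axis)

def forward_backward(log_pi0, log_Ps, log_liks):
    """Forward-backward for an HMM with time-varying transition potentials
    log_Ps (T-1, K, K) and local potentials log_liks (T, K). Returns the
    single-state marginals gamma (T, K) and pairwise marginals xi (T-1, K, K)."""
    T, K = log_liks.shape
    alpha = np.zeros((T, K))
    beta = np.zeros((T, K))
    alpha[0] = log_pi0 + log_liks[0]
    for t in range(1, T):
        alpha[t] = log_liks[t] + logsumexp(alpha[t - 1][:, None] + log_Ps[t - 1], axis=0)
    for t in range(T - 2, -1, -1):
        beta[t] = logsumexp(log_Ps[t] + (log_liks[t + 1] + beta[t + 1])[None, :], axis=1)
    log_gamma = alpha + beta
    gamma = np.exp(log_gamma - logsumexp(log_gamma, axis=1, keepdims=True))
    xi = np.zeros((T - 1, K, K))
    for t in range(T - 1):
        m = alpha[t][:, None] + log_Ps[t] + (log_liks[t + 1] + beta[t + 1])[None, :]
        xi[t] = np.exp(m - logsumexp(m))
    return gamma, xi
```

gamma and xi supply exactly the indicator expectations used in the parameter updates below.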

Updating q(\theta). To compute the expected sufficient statistics for the mean field update on \eta, we can also use message passing, this time in both factors q(x_{1:T}) and q(z_{1:T}) separately. The required expected sufficient statistics are of the form

E_{q(z)} I[z_t = i, z_{t+1} = j], \quad E_{q(z)} I[z_t = i],

E_{q(z)} I[z_t = k] \, E_{q(x)}[x_t x_{t+1}^\top], \qquad (4)

E_{q(z)} I[z_t = k] \, E_{q(x)}[x_t x_t^\top], \quad E_{q(z)} I[z_1 = k] \, E_{q(x)}[x_1],

where I[ · ] denotes an indicator function. Each of these can be computed easily from the marginals q(x_t, x_{t+1}) and q(z_t, z_{t+1}) for t = 1, 2, ..., T-1, and these marginals can be computed in terms of the respective graphical model messages.

Given the conjugacy of the augmented model, the dynamics and observation factors will be MNIW distributions as well. These allow closed-form expressions for the required expectations,

E_q[A_k], \; E_q[b_k], \; E_q[A_k^\top B_k^{-1}], \; E_q[b_k^\top B_k^{-1}], \; E_q[B_k^{-1}],

E_q[C_k], \; E_q[d_k], \; E_q[C_k^\top D_k^{-1}], \; E_q[d_k^\top D_k^{-1}], \; E_q[D_k^{-1}].

Likewise, the conjugate matrix normal prior on (R_k, r_k) provides access to

E_q[R_k], \; E_q[R_k R_k^\top], \; E_q[r_k].

Stochastic Variational Inference. Given multiple, conditionally independent observations of time series, \{y^{(p)}_{1:T_p}\}_{p=1}^{P} (using the same notation as in the basketball experiment), it is straightforward to derive a stochastic variational inference (SVI) algorithm [Hoffman et al., 2013]. In each iteration, we sample a random time series; run message passing to compute the optimal local factors, q(z^{(p)}_{1:T_p}), q(x^{(p)}_{1:T_p}), and q(\omega^{(p)}_{1:T_p}); and then use expectations with respect to these local factors as unbiased estimates of expectations with respect to the complete dataset when updating the global parameter factor, q(\theta). Given a single, long time series, we can still derive efficient SVI algorithms that use subsets of the data, as long as we are willing to accept minor, controllable bias [Johnson and Willsky, 2014, Foti et al., 2014].
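At the level of the global factor, each SVI iteration amounts to a convex combination of the current natural parameters with a prior-plus-rescaled-statistics estimate. A generic sketch, assuming a flat vector of natural parameters and uniform sampling of series (names are illustrative, not from the paper):

```python
import numpy as np

def svi_step(eta, eta_prior, local_stats, num_series, step_size):
    """One SVI update on global natural parameters eta: move toward the
    prior plus num_series times the expected sufficient statistics of a
    single uniformly sampled time series (an unbiased estimate of the
    sum over the full dataset)."""
    eta_hat = eta_prior + num_series * local_stats
    return (1.0 - step_size) * eta + step_size * eta_hat
```

With a decaying step-size schedule satisfying the usual Robbins-Monro conditions, the iterates converge to a local optimum of the variational objective.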

2 Initialization

Given the complexity of these models, it is important to initialize the parameters and latent states with reasonable values. We used the following initialization procedure: (i) use probabilistic PCA or factor analysis to initialize the continuous latent states, x_{1:T}, and the observation parameters, C, D, and d; (ii) fit an AR-HMM to x_{1:T} in order to initialize the discrete latent states, z_{1:T}, and the dynamics models, \{A_k, Q_k, b_k\}; and then (iii) greedily fit a decision list with logistic regressions at each node in order to determine a permutation of the latent states most amenable to stick breaking. In practice, the last step alleviates the undesirable dependence on ordering that arises from the stick-breaking formulation.
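Step (i) can be sketched with plain (non-probabilistic) PCA via an SVD; a probabilistic PCA or factor analysis fit would additionally estimate the observation noise D. Names here are illustrative:

```python
import numpy as np

def pca_init(Y, dim):
    """Initialize continuous latents from observations Y (T, N): center,
    take the top `dim` singular directions as the loading C, and set
    x_{1:T} to the corresponding projections so that X @ C.T + d ~= Y."""
    T = Y.shape[0]
    d = Y.mean(axis=0)
    U, s, Vt = np.linalg.svd(Y - d, full_matrices=False)
    X = U[:, :dim] * (s[:dim] / np.sqrt(T))  # (T, dim) latent states
    C = Vt[:dim].T * np.sqrt(T)              # (N, dim) observation loadings
    return X, C, d
```

The split of the singular values between X and C is arbitrary up to an invertible linear transform, which the subsequent AR-HMM fit absorbs.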

As mentioned in Section 4, one of the less desirable features of the logistic stick-breaking regression model is its dependence on the ordering of the output dimensions; in our case, on the permutation of the discrete states \{1, 2, \ldots, K\}. To alleviate this issue, we first do a greedy search over permutations by fitting a decision list to ((x_t, z_t), z_{t+1}) pairs. A decision list is an iterative classifier of the form,

z_{t+1} = \begin{cases} o_1 & \text{if } p_1 \\ o_2 & \text{if } \neg p_1 \wedge p_2 \\ o_3 & \text{if } \neg p_1 \wedge \neg p_2 \wedge p_3 \\ \;\vdots & \\ o_K & \text{otherwise}, \end{cases}

where (o_1, \ldots, o_K) is a permutation of (1, \ldots, K), and p_1, \ldots, p_{K-1} are predicates that depend on (x_t, z_t) and evaluate to true or false. In our case, these predicates are given by logistic functions,

p_j = \sigma(r_j^\top x_t) > 0.
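To make the evaluation order of the decision list concrete, here is a minimal sketch; we read each predicate as the logistic score exceeding 1/2, i.e. r_j^\top x_t > 0 (names are ours):

```python
import numpy as np

def decision_list_predict(x, rs, outputs):
    """Evaluate a decision list: return the first output o_j whose
    predicate fires (r_j^T x > 0, equivalently sigma(r_j^T x) > 1/2);
    fall through to the final default output o_K."""
    for r, o in zip(rs, outputs[:-1]):
        if r @ x > 0:
            return o
    return outputs[-1]
```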

We fit the decision list using a greedy approach: to determine o_1 and r_1, we use maximum a posteriori estimation to fit logistic regressions for each of the K possible output values. For the k-th logistic regression, the inputs are x_{1:T} and the outputs are y_t = I[z_{t+1} = k]. We choose the best logistic regression (measured by log likelihood) as the first output. Then we remove those time points for which z_{t+1} = o_1 from the dataset and repeat, fitting K-1 logistic regressions in order to determine the second output, o_2, and so on.

After iterating through all K outputs, we have a permutation of the discrete states. Moreover, the predicate weights \{r_k\}_{k=1}^{K-1} serve as an initialization for the recurrence weights, R, in our model.

3 Bernoulli-Lorenz Details

The Pólya-gamma augmentation makes it easy to handle discrete observations in the rSLDS, as illustrated in the Bernoulli-Lorenz experiment. Since the Bernoulli likelihood is given by,

p(y_t \mid x_t, \theta) = \prod_{n=1}^{N} \mathrm{Bern}\big(\sigma(c_n^\top x_t + d_n)\big) = \prod_{n=1}^{N} \frac{(e^{c_n^\top x_t + d_n})^{y_{t,n}}}{1 + e^{c_n^\top x_t + d_n}},

we see that it matches the form of Eq. (7) of the main text with,

\nu_{t,n} = c_n^\top x_t + d_n, \qquad b(y_{t,n}) = 1, \qquad \kappa(y_{t,n}) = y_{t,n} - \tfrac{1}{2}.
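The equality between the Bernoulli form and the exponential form above can be checked numerically; a small sketch of the per-observation log likelihood in terms of \nu_{t,n}, computed stably with logaddexp (naming is ours):

```python
import numpy as np

def bernoulli_loglik(y, nu):
    """log prod_n Bern(y_n | sigma(nu_n)) in exponential form:
    sum_n [nu_n * y_n - log(1 + exp(nu_n))]."""
    y = np.asarray(y, dtype=float)
    nu = np.asarray(nu, dtype=float)
    return np.sum(nu * y - np.logaddexp(0.0, nu))
```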

Thus, we introduce an additional set of Pólya-gamma auxiliary variables,

\xi_{t,n} \sim \mathrm{PG}(1, 0),

to render the model conjugate. Given these auxiliary variables, the observation potential is proportional to a Gaussian distribution on x_t,

\psi(x_t; y_t) \propto \mathcal{N}\big(C x_t + d \mid \Xi_t^{-1} \kappa(y_t), \; \Xi_t^{-1}\big),

with

\Xi_t = \mathrm{diag}([\xi_{t,1}, \ldots, \xi_{t,N}]), \qquad \kappa(y_t) = [\kappa(y_{t,1}), \ldots, \kappa(y_{t,N})].
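Given sampled auxiliary variables \xi_t, assembling the Gaussian pseudo-observation \Xi_t^{-1} \kappa(y_t) and its precision \Xi_t is immediate; a minimal sketch (names are ours):

```python
import numpy as np

def pg_pseudo_observation(y_t, xi_t):
    """Binary observations y_t (N,) and PG draws xi_t (N,): the augmented
    observation potential is N(C x_t + d | Xi^{-1} kappa, Xi^{-1}) with
    kappa = y - 1/2."""
    kappa = y_t - 0.5
    pseudo_obs = kappa / xi_t     # Xi_t^{-1} kappa(y_t)
    precision = np.diag(xi_t)     # Xi_t
    return pseudo_obs, precision
```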

Again, this admits efficient message passing inference for x_{1:T}. In order to update the auxiliary variables, we sample from their conditional distribution, \xi_{t,n} \sim \mathrm{PG}(1, \nu_{t,n}).

This augmentation scheme also works for binomial, negative binomial, and multinomial observations [Polson et al., 2013].

4 Basketball Details

For completeness, Figures 1 and 2 show all K = 30inferred states of the rAR-HMM (ro) for the basketballdata.

References

Nicholas Foti, Jason Xu, Dillon Laird, and Emily Fox. Stochastic variational inference for hidden Markov models. In Advances in Neural Information Processing Systems, pages 3599–3607, 2014.

Matthew D. Hoffman, David M. Blei, Chong Wang, and John William Paisley. Stochastic variational inference. Journal of Machine Learning Research, 14(1):1303–1347, 2013.

Matthew J. Johnson and Alan S. Willsky. Stochastic variational inference for Bayesian time series models. In Proceedings of the 31st International Conference on Machine Learning (ICML-14), pages 1854–1862, 2014.

Nicholas G. Polson, James G. Scott, and Jesse Windle. Bayesian inference for logistic models using Pólya-gamma latent variables. Journal of the American Statistical Association, 108(504):1339–1349, 2013.


Figure 1: All of the inferred basketball states


Figure 2: All of the inferred basketball states