
Structured VAEs: Composing Probabilistic Graphical Models and Variational Autoencoders

Matthew J. Johnson mattjj@seas.harvard.edu
Harvard University, Cambridge, MA 02138

David Duvenaud dduvenaud@seas.harvard.edu
Harvard University, Cambridge, MA 02138

Alexander B. Wiltschko awiltsch@fas.harvard.edu
Harvard University and Twitter, Cambridge, MA 02138

Sandeep R. Datta srdatta@hms.harvard.edu
Harvard Medical School, Boston, MA 02115

Ryan P. Adams rpa@seas.harvard.edu
Harvard University and Twitter, Cambridge, MA 02138

Abstract

We develop a new framework for unsupervised learning that composes probabilistic graphical models with deep learning methods and combines their respective strengths. Our method uses graphical models to express structured probability distributions and recent advances from deep learning to learn flexible feature models and bottom-up recognition networks. All components of these models are learned simultaneously using a single objective, and we develop scalable fitting algorithms that can leverage natural gradient stochastic variational inference, graphical model message passing, and backpropagation with the reparameterization trick. We illustrate this framework with a new structured time series model and an application to mouse behavioral phenotyping.

1. Introduction

Unsupervised probabilistic modeling often has two goals: first, to learn a model that is flexible enough to represent complex high-dimensional data, such as images or speech recordings, and second, to learn model structure that is interpretable, admits meaningful priors, and generalizes to new tasks. That is, it is often not enough just to learn the probability density of the data: we also want to learn a meaningful representation. Probabilistic graphical models (Koller & Friedman, 2009; Murphy, 2012) provide many tools to build such structured representations, but can be limited in their capacity and may require significant feature engineering before being applied to data. Alternatively, advances in deep learning have yielded not only flexible, scalable generative models for complex data like images but also new techniques for automatic feature learning and bottom-up inference (Kingma & Welling, 2014; Rezende et al., 2014).

Consider the problem of learning an unsupervised generative model for a depth video of a tracked freely-behaving mouse, illustrated in Figure 1. Learning interpretable representations for such data, and studying how those representations change as the animal's genetics are edited or its brain chemistry altered, can create powerful behavioral phenotyping tools for neuroscience and for high-throughput drug discovery (Wiltschko et al., 2015). Each frame of the video is a depth image of a mouse in a particular pose, and so even though each image is encoded as 30 × 30 = 900 pixels, the data lie near a low-dimensional nonlinear manifold. A good generative model must not only learn this manifold but also represent many other salient aspects of the data. For example, from one frame to the next the corresponding manifold points should be close to one another, and in fact the trajectory along the manifold may follow very structured dynamics. To inform the structure of these dynamics, a natural class of hypotheses used in ethology and neurobiology (Wiltschko et al., 2015) is that the mouse's behavior is composed of brief, reused actions, such as darts, rears, and grooming bouts. Therefore a natural representation would include discrete states, with each state representing the simple dynamics of a particular primitive action, a representation that would be difficult to encode in an unsupervised recurrent neural network model.


Figure 1. Illustration of unsupervised generative modeling applied to depth video of a mouse. Each video frame lies near a low-dimensional image manifold, and we wish to learn a parameterization of this manifold in which trajectories can be modeled with switching linear dynamics.

These two tasks, of learning the image manifold and learning a structured dynamics model, are complementary: we want to learn the image manifold not just as a set but in terms of manifold coordinates in which the structured dynamics model fits the data well. A similar modeling challenge arises in speech (Hinton et al., 2012), where high-dimensional data lie near a low-dimensional manifold because they are generated by a physical system with relatively few degrees of freedom (Deng, 1999) but also include the discrete latent dynamical structure of phonemes, words, and grammar (Deng, 2004).

To address these challenges, we propose a new method to design and learn such models. Our approach uses graphical models for representing structured probability distributions, and uses ideas from variational autoencoders (Kingma & Welling, 2014) for learning not only the nonlinear feature manifold but also bottom-up recognition networks to improve inference. Thus our method enables the combination of flexible deep learning feature models with structured Bayesian and even Bayesian nonparametric priors. Our approach yields a single variational inference objective in which all components of the model are learned simultaneously. Furthermore, we develop a scalable fitting algorithm that combines several advances in efficient inference, including stochastic variational inference (Hoffman et al., 2013), graphical model message passing (Koller & Friedman, 2009), and backpropagation with the reparameterization trick (Kingma & Welling, 2014). Thus our algorithm can leverage conjugate exponential family structure where it exists to efficiently compute natural gradients with respect to some variational parameters, enabling effective second-order optimization (Martens, 2015), while using backpropagation to compute gradients with respect to all other parameters. We refer to our general approach as the structured variational autoencoder (SVAE). In this paper we illustrate the SVAE using graphical models based on switching linear dynamical systems (SLDS) (Murphy, 2012; Fox et al., 2011).

Code is available at github.com/mattjj/svae.

2. Background

Here we briefly review some of the ideas on which this paper builds. This section also fixes notation that is reused throughout the paper.

2.1. Natural gradient stochastic variational inference

Stochastic variational inference (SVI) (Hoffman et al., 2013) applies stochastic gradient ascent to a mean field variational inference objective in a way that exploits exponential family conjugacy to efficiently compute natural gradients (Amari, 1998; Martens, 2015). Consider a model composed of global latent variables θ, local latent variables x = {x_n}_{n=1}^N, and data y = {y_n}_{n=1}^N,

p(θ, x, y) = p(θ) ∏_{n=1}^N p(x_n | θ) p(y_n | x_n, θ),    (1)

where p(θ) is the natural exponential family conjugate prior to the exponential family p(x_n, y_n | θ),

ln p(θ) = ⟨η_θ, t_θ(θ)⟩ − ln Z_θ(η_θ),    (2)
ln p(x_n, y_n | θ) = ⟨η_{xy}(θ), t_{xy}(x_n, y_n)⟩ − ln Z_{xy}(η_{xy}(θ))
                  = ⟨t_θ(θ), (t_{xy}(x_n, y_n), 1)⟩.    (3)

Figure 2 shows the graphical model. The mean field variational inference problem is to approximate the posterior p(θ, x | y) with a tractable distribution q(θ, x) by finding a local minimum of the KL divergence KL(q(θ, x) ‖ p(θ, x | y)) or, equivalently, using the identity

ln p(y) = KL(q(θ, x) ‖ p(θ, x | y)) + E_{q(θ,x)}[ln ( p(θ, x, y) / q(θ, x) )],

to choose q(θ, x) to maximize the objective

L[q(θ, x)] ≜ E_{q(θ,x)}[ln ( p(θ, x, y) / q(θ, x) )] ≤ ln p(y).    (4)

Consider the mean field family q(θ)q(x) = q(θ) ∏_n q(x_n). Because of the conjugate exponential family structure, the optimal global mean field factor q(θ) is in the same family as the prior p(θ),

ln q(θ) = ⟨η̃_θ, t_θ(θ)⟩ − ln Z_θ(η̃_θ).    (5)


Figure 2. Prototypical graphical model for SVI.

The mean field objective on the global variational parameters η̃_θ, optimizing out the local variational factors q(x), can then be written

L(η̃_θ) ≜ max_{q(x)} E_{q(θ)q(x)}[ln ( p(θ, x, y) / (q(θ) q(x)) )] ≤ ln p(y),    (6)

and the natural gradient of the objective (6) decomposes into a sum of local expected sufficient statistics (Hoffman et al., 2013):

∇̃_{η̃_θ} L(η̃_θ) = η_θ + ∑_{n=1}^N E_{q*(x_n)}[(t_{xy}(x_n, y_n), 1)] − η̃_θ,    (7)

where q*(x_n) is a locally optimal local mean field factor given η̃_θ. Thus we can compute a stochastic natural gradient update for our global mean field objective by sampling a data minibatch y_n, optimizing the local mean field factor q(x_n), and computing scaled expected sufficient statistics.
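To make the update in Eq. (7) concrete, the following is a minimal sketch of natural-gradient SVI on a toy conjugate model: a mixture of K fixed unit-variance Gaussian components with unknown mixing weights θ ∼ Dirichlet(α₀), where the local latent x_n is the component assignment. The model, dimensions, and step-size schedule are illustrative assumptions made for this sketch, not details from the paper.

```python
# Toy natural-gradient SVI (Eq. 7): Dirichlet prior on mixing weights, local
# cluster assignments optimized in closed form, minibatches of size one.
import numpy as np
from scipy.special import digamma
from scipy.stats import norm

rng = np.random.default_rng(0)
K, N = 3, 500
mu_fixed = np.array([-4.0, 0.0, 4.0])              # known component means
z_true = rng.choice(K, size=N, p=[0.2, 0.5, 0.3])
y = rng.normal(mu_fixed[z_true], 1.0)

alpha0 = np.ones(K)                                 # conjugate Dirichlet prior
a_tilde = np.ones(K)                                # variational Dirichlet parameters

for it in range(200):
    n = rng.integers(N)                             # sample a minibatch of one point
    # optimize the local factor q(x_n) in closed form given q(theta)
    E_log_theta = digamma(a_tilde) - digamma(a_tilde.sum())
    log_r = E_log_theta + norm.logpdf(y[n], mu_fixed, 1.0)
    r = np.exp(log_r - log_r.max()); r /= r.sum()   # expected sufficient statistics
    # natural gradient of Eq. (7), written in Dirichlet concentration coordinates
    nat_grad = alpha0 + N * r - a_tilde
    a_tilde = a_tilde + (10.0 / (10.0 + it)) * nat_grad

print("estimated mixing weights:", a_tilde / a_tilde.sum())
```

Because the step sizes lie in (0, 1], each update is a convex combination of positive quantities, so the variational parameters stay valid without any projection.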

2.2. Variational autoencoders

The variational autoencoder (VAE) (Kingma & Welling, 2014; Rezende et al., 2014) is a recent model and variational inference method that links neural network autoencoders (Vincent et al., 2008) with mean field variational Bayes. Given a high-dimensional dataset y = {y_n}_{n=1}^N, such as a collection of images, the VAE models each observation y_n in terms of a low-dimensional latent variable x_n and a nonlinear observation model with parameters ϑ,

x_n ~iid N(0, I),  n = 1, 2, ..., N,    (8)
y_n | x_n, ϑ ∼ N(μ(x_n; ϑ), Σ(x_n; ϑ)),    (9)

where μ(x_n; ϑ) and Σ(x_n; ϑ) might depend on x_n through a multilayer perceptron (MLP) with L layers,

h_ℓ(x_n) = f(W_ℓ h_{ℓ−1}(x_n) + b_ℓ),  ℓ = 1, 2, ..., L,    (10)
μ(x_n; ϑ) = W_μ h_L(x_n) + b_μ,    (11)
Σ(x_n; ϑ) = diag(exp(W_{σ²} h_L(x_n) + b_{σ²})),    (12)
ϑ = ({(W_ℓ, b_ℓ)}_{ℓ=1}^L, (W_μ, b_μ), (W_{σ²}, b_{σ²})),    (13)

with h_0(x_n) = x_n and f(x) = tanh(x) applied elementwise. Because we will reuse this particular MLP construction, we introduce the notation

(μ(x_n; ϑ), Σ(x_n; ϑ)) = MLP(x_n; ϑ).    (14)
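A minimal numpy sketch of the MLP construction in Eqs. (10)-(14) follows: a tanh network with two linear output heads mapping an input to a Gaussian mean and diagonal covariance. The helper names, layer sizes, and initialization are assumptions of this sketch, not the paper's implementation.

```python
# (mu(x; theta), Sigma(x; theta)) = MLP(x; theta), Eqs. (10)-(14), in plain numpy.
import numpy as np

def init_mlp(sizes, out_dim, rng):
    """Random tanh-MLP parameters plus two linear heads (mu, log sigma^2)."""
    layers = [(0.1 * rng.standard_normal((m, n)), np.zeros(n))
              for m, n in zip(sizes[:-1], sizes[1:])]
    W_mu, b_mu = 0.1 * rng.standard_normal((sizes[-1], out_dim)), np.zeros(out_dim)
    W_s,  b_s  = 0.1 * rng.standard_normal((sizes[-1], out_dim)), np.zeros(out_dim)
    return layers, (W_mu, b_mu), (W_s, b_s)

def mlp_gaussian(x, params):
    layers, (W_mu, b_mu), (W_s, b_s) = params
    h = x
    for W, b in layers:                      # Eq. (10): h_l = tanh(W_l h_{l-1} + b_l)
        h = np.tanh(h @ W + b)
    mu = h @ W_mu + b_mu                     # Eq. (11)
    Sigma = np.diag(np.exp(h @ W_s + b_s))   # Eq. (12): diagonal covariance
    return mu, Sigma

rng = np.random.default_rng(0)
params = init_mlp([2, 50], out_dim=10, rng=rng)   # latent dim 2 -> observation dim 10
mu, Sigma = mlp_gaussian(np.array([0.3, -1.2]), params)
```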

Figure 3. Graphical models for the variational autoencoder: (a) the VAE generative model; (b) the VAE variational family.

To approximate the posterior p(ϑ, x | y), the variational autoencoder uses the mean field family

q(ϑ)q(x | y) = q(ϑ) ∏_{n=1}^N q(x_n | y_n).    (15)

A key insight of the variational autoencoder is to use a conditional variational density q(x_n | y_n), where the parameters of the variational distribution on x_n depend on the corresponding data point y_n. In particular, we can take the mean and covariance parameters of q(x_n | y_n) to be μ(y_n; φ) and Σ(y_n; φ), respectively, where

(μ(y_n; φ), Σ(y_n; φ)) = MLP(y_n; φ),    (16)

and φ denotes a set of MLP parameters. Thus the variational distribution q(x_n | y_n) acts like a stochastic encoder from an observation to a distribution over latent variables, while the forward model p(y_n | x_n) acts as a stochastic decoder from a latent variable value to a distribution over observations. Furthermore, the functions μ(y_n; φ) and Σ(y_n; φ) act as a recognition or inference network for observation y_n, outputting a distribution over the latent variable x_n using a feed-forward neural network applied to y_n. See Figure 3 for graphical models of the VAE.

The resulting mean field objective expresses a variational Bayesian version of an autoencoder. Using η̃_ϑ to denote the variational parameters for the factor q(ϑ), the variational parameters are then the encoder parameters φ and the decoder parameters η̃_ϑ, and the objective is

L(η̃_ϑ, φ) ≜ E_{q(ϑ)q(x | y)}[ ln ( p(ϑ) / q(ϑ) ) + ∑_{n=1}^N ln ( p(x_n, y_n | ϑ) / q(x_n | y_n) ) ].

To optimize this objective efficiently, Kingma & Welling (2014) applies a reparameterization trick. To simplify notation and computation, we take q(ϑ) to be a singular factor q(ϑ) = δ_{ϑ*}(ϑ), so that the corresponding variational parameters η̃_ϑ can be denoted ϑ* and some terms can be dropped. First, we rewrite the objective as

L(ϑ*, φ) = E_{q(x | y)} ln p(y | x, ϑ*) − KL(q(x | y) ‖ p(x)).

The term KL(q(x | y) ‖ p(x)) is the KL divergence between two Gaussians and its gradient with respect to φ


can be computed in closed form. To compute stochastic gradients of the expectation term, since a random variable x_n ∼ q(x_n | y_n) can be parameterized as

x_n = g(φ, ε) ≜ μ_q(y_n; φ) + Σ_q(y_n; φ)^{1/2} ε,   ε ∼ N(0, I),

the expectation term can be rewritten in terms of g(φ, ε) and its gradient approximated via Monte Carlo over ε,

∇_{ϑ*, φ} E_q ln p(y | x, ϑ*) ≈ ∑_{n=1}^N ∇_{ϑ*, φ} ln p(y_n | g(φ, ε̂_n), ϑ*),

where ε̂_n ~iid N(0, I). Because g(φ, ε) is a differentiable function of φ, these gradients can be computed using standard backpropagation. For scalability, the sum over data points is also approximated via Monte Carlo. General nonsingular factors q(ϑ) can also be handled by the reparameterization trick (Kingma & Welling, 2014, Appendix F).
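The following small numpy illustration shows the reparameterization estimator on a one-dimensional example where the answer is known in closed form: estimating ∇ of E_{x∼N(μ,σ²)}[x²] with respect to (μ, σ) by writing x = μ + σε. The choice of test function and parameter values are assumptions of this sketch; the paper instead differentiates a log-likelihood with backpropagation.

```python
# Reparameterization-trick gradient estimates checked against closed form
# (d/dmu E[x^2] = 2*mu, d/dsigma E[x^2] = 2*sigma).
import numpy as np

rng = np.random.default_rng(0)
mu, sigma, S = 1.5, 0.7, 100_000

eps = rng.standard_normal(S)
x = mu + sigma * eps                    # x = g(phi, eps), differentiable in (mu, sigma)
grad_mu = np.mean(2 * x * 1.0)          # f'(x) * dg/dmu, with dg/dmu = 1
grad_sigma = np.mean(2 * x * eps)       # f'(x) * dg/dsigma, with dg/dsigma = eps

print(grad_mu, 2 * mu)                  # Monte Carlo estimate vs. exact 3.0
print(grad_sigma, 2 * sigma)            # Monte Carlo estimate vs. exact 1.4
```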

3. Generative model and variational family

In this section we develop an SVAE generative model and corresponding variational family. To be concrete we focus on a particular generative model for time series based on a switching linear dynamical system (SLDS) (Murphy, 2012; Fox et al., 2011), which illustrates how the SVAE can incorporate both discrete and continuous latent variables with rich probabilistic dependence. To simplify notation we consider modeling only one data sequence of length T, denoted y = y_{1:T}.

First, Section 3.1 describes the generative model, which illustrates the combination of a graphical model expressing latent structure with a flexible neural net to generate observations. Next, Section 3.2 describes the structured variational family, which leverages both graph-structured mean field approximations and flexible recognition networks. Section 3.3 discusses variations and extensions.

3.1. A switching linear dynamical system with nonlinear observations

A switching linear dynamical system (SLDS) represents data in terms of continuous latent states that evolve according to a discrete set of linear dynamics. At each time t ∈ {1, 2, ..., T} there is a discrete-valued latent state z_t ∈ {1, 2, ..., K}, which indexes the dynamical mode, and a continuous-valued latent state x_t ∈ R^M that evolves according to that mode's linear Gaussian dynamics:

x_{t+1} = A^{(z_t)} x_t + B^{(z_t)} u_t,   u_t ~iid N(0, I),    (17)

where A^{(k)}, B^{(k)} ∈ R^{M×M} for k = 1, 2, ..., K. The discrete latent state z_t evolves according to Markov dynamics,

z_{t+1} | z_t, π ∼ π^{(z_t)},    (18)

where π = {π^{(k)}}_{k=1}^K denotes the Markov transition matrix and π^{(k)} ∈ R_+^K is its kth row. The initial states are generated separately:

z_1 | π_init ∼ π_init,    (19)
x_1 | z_1, μ_init, Σ_init ∼ N(μ_init^{(z_1)}, Σ_init^{(z_1)}).    (20)

Thus inferring the latent variables and parameters of an SLDS identifies a set of reused dynamical modes, each described as a linear dynamical system on latent states, in addition to Markov switching between different linear dynamics. We use θ to denote all the dynamical parameters, θ = (π, π_init, {(A^{(k)}, B^{(k)}, μ_init^{(k)}, Σ_init^{(k)})}_{k=1}^K).

At each time t, the continuous latent state x_t gives rise to an observation y_t that is conditionally Gaussian with mean μ(x_t; ϑ) and covariance Σ(x_t; ϑ),

y_t | x_t, ϑ ∼ N(μ(x_t; ϑ), Σ(x_t; ϑ)).    (21)

In a typical SLDS (Fox et al., 2011), μ(x_t; ϑ) is linear in x_t and Σ(x_t; ϑ) is a constant function, so that we can take ϑ = (C, D) for some matrices C and D and write

y_t = C x_t + D v_t,   v_t ~iid N(0, I).    (22)

However, to enable flexible modeling of images and other complex features, we allow the dependence to be a more general nonlinear model. In particular, we consider μ(x_t; ϑ) and Σ(x_t; ϑ) to be MLPs with parameters ϑ as in Eqs. (10)-(13) of Section 2.2,

(μ(x_t; ϑ), Σ(x_t; ϑ)) = MLP(x_t; ϑ).    (23)

By allowing a nonlinear emission model, each low-dimensional state x_t can be mapped into a high-dimensional observation y_t, and hence the model can represent high-dimensional data in terms of structured dynamics on a low-dimensional manifold. The overall generative model is summarized in Figure 4a.
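As a concrete illustration of Eqs. (17)-(21), the following sketch samples a sequence from an SLDS. All dimensions and parameter choices are illustrative assumptions; a fixed linear map stands in for the MLP emission of Eq. (23), and the mode-specific initial distribution of Eq. (20) is replaced by a standard Gaussian.

```python
# Forward sampling from a toy SLDS: discrete Markov mode z_t, linear-Gaussian
# latent dynamics per mode, and a (here linear) emission into observation space.
import numpy as np

rng = np.random.default_rng(0)
K, M, P, T = 3, 2, 10, 200                 # modes, latent dim, obs dim, length

angles = rng.uniform(0.05, 0.3, size=K)    # slow rotations, scaled for stability
A = np.stack([0.98 * np.array([[np.cos(a), -np.sin(a)], [np.sin(a), np.cos(a)]])
              for a in angles])
B = 0.1 * np.stack([np.eye(M) for _ in range(K)])
pi = np.full((K, K), 0.02) + 0.94 * np.eye(K)   # sticky transition matrix
pi /= pi.sum(axis=1, keepdims=True)
pi_init = np.full(K, 1.0 / K)

C = rng.standard_normal((P, M))            # stand-in for MLP(x_t; theta)
obs_noise = 0.05

z = np.empty(T, dtype=int); x = np.empty((T, M)); y = np.empty((T, P))
z[0] = rng.choice(K, p=pi_init)
x[0] = rng.standard_normal(M)              # simplified initial state draw
for t in range(T):
    y[t] = C @ x[t] + obs_noise * rng.standard_normal(P)        # Eq. (21)
    if t + 1 < T:
        z[t + 1] = rng.choice(K, p=pi[z[t]])                     # Eq. (18)
        x[t + 1] = A[z[t]] @ x[t] + B[z[t]] @ rng.standard_normal(M)  # Eq. (17)
```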

Note that by construction the density p(z, x | θ) is an exponential family. We can choose the prior p(θ) to be a natural exponential family conjugate prior, writing

ln p(θ) = ⟨η_θ, t_θ(θ)⟩ − ln Z_θ(η_θ),    (24)
ln p(z, x | θ) = ⟨η_{zx}(θ), t_{zx}(z, x)⟩ − ln Z_{zx}(η_{zx}(θ))
              = ⟨t_θ(θ), (t_{zx}(z, x), 1)⟩.    (25)

We can also use a Bayesian nonparametric prior, choosing K = ∞ and generating the discrete state sequence z_{1:T} according to an HDP-HMM (Fox et al., 2011). Though we do not discuss the Bayesian nonparametric case further, the algorithms we develop here immediately extend to the HDP-HMM using the methods in Johnson & Willsky (2014).


Figure 4. Graphical models for the SLDS generative model and corresponding structured CRF variational family: (a) the SLDS generative model with nonlinear observation model parameterized by ϑ; (b) the structured CRF variational family with node potentials {ψ(x_t^{(n)}; y_t^{(n)}, φ)}_{t=1}^T parameterized by φ.

Figure 5. Special cases of the SLDS generative model: (a) LDS; (b) GMM; (c) G-HMM.

This construction contains the generative model of the VAE, described in Section 2.2, as a special case. Specifically, the VAE uses the same class of MLP observation models, but each latent value x_t is modeled as an independent and identically distributed Gaussian, x_t ~iid N(0, I), while the SVAE model proposed here allows the x_t to have a rich joint probabilistic structure. The SLDS generative model also includes as special cases the Gaussian mixture model (GMM), Gaussian-emission discrete-state HMM (G-HMM), and Gaussian linear dynamical system (LDS), and thus the algorithms we develop here for the SLDS directly specialize to these models. See Figure 5 for graphical models of these special cases.

While using conditionally linear dynamics within each state may seem limited, the flexible nonlinear observation distribution greatly extends the capacity of such models. Indeed, recent work on neural word embeddings (Mikolov et al., 2013) as well as neural image models (Radford et al., 2015) has demonstrated learned latent spaces in which linear structure corresponds to meaningful semantics. For example, addition and subtraction of word vectors can correspond to semantic relationships between words, and translation in an image model's latent space can correspond to an object's rotation. Therefore linear models in a learned latent space can yield significant expressiveness while enabling fast probabilistic inference, interpretable priors and parameters, and a host of other tools. In particular, linear dynamics allow us to learn or encode information about timescales and frequencies: the eigenvalue spectrum of each transition matrix A^{(k)} directly represents its characteristic timescales, and so we can control and interpret the structure of linear dynamics in ways that nonlinear dynamics models do not allow.

3.2. Variational family and CRF recognition networks

Here we describe a structured mean field family with which we can perform variational inference in the posterior distribution of the generative model from Section 3.1. This mean field family illustrates how an SVAE can not only leverage graphical model and exponential family structure but also learn bottom-up inference networks. As we show in Section 4, these structures allow us to compose several efficient inference algorithms, including SVI, message passing, backpropagation, and the reparameterization trick.

In mean field variational inference, one constructs a tractable variational family by breaking dependencies in the posterior (Wainwright & Jordan, 2008). To construct a structured mean field family for the generative model developed in Section 3.1, we break the posterior dependencies between the dynamics parameters θ, the observation parameters ϑ, the discrete state sequence z = z_{1:T}, and the continuous state sequence x = x_{1:T}, writing the corresponding factorized density as

q(θ, ϑ, z_{1:T}, x_{1:T}) = q(θ) q(ϑ) q(z_{1:T}) q(x_{1:T}).    (26)

Note that this structured mean field family does not break the dependencies among the discrete states z_{1:T} or among the continuous states x_{1:T} as in a naive mean field model, because these random variables are highly correlated in the posterior. By preserving joint dependencies across time, these structured factors provide a much more accurate representation of the posterior while still allowing tractable inference via graphical model message passing (Wainwright & Jordan, 2008).

To leverage bottom-up inference networks, we parameterize the factor q(x_{1:T}) as a conditional random field (CRF) (Murphy, 2012). That is, using the fact that the optimal factor q(x_{1:T}) is Markov according to a chain graph, we write


it in terms of pairwise potentials and node potentials as

q(x_{1:T}) ∝ ( ∏_{t=1}^{T−1} ψ(x_t, x_{t+1}) ) ( ∏_{t=1}^T ψ(x_t; y_t, φ) ),    (27)

where the node potential ψ(x_t; y_t, φ) is a function of the observation y_t. Specifically, we choose each node potential to be a Gaussian factor in which the precision matrix J(y_t; φ) and potential vector h(y_t; φ) depend on the corresponding observation y_t through an MLP,

ψ(x_t; y_t, φ) ∝ exp{ −(1/2) x_t^T J(y_t; φ) x_t + h(y_t; φ)^T x_t },
(h(y_t; φ), J(y_t; φ)) = MLP(y_t; φ),    (28)

using the notation from Section 2.2. These local recognition networks allow us to fit a regression from each observation y_t to a probabilistic guess at the corresponding latent state x_t. Using graphical model inference, these local guesses can be synthesized with the dynamics model into a coherent joint factor q(x_{1:T}) over the entire state sequence. While we could write q(x_{1:T} | y_{1:T}) to emphasize this dependence and match the notation in Section 2.2, we only write q(x_{1:T}) to simplify notation. The overall variational family is summarized in Figure 4b.
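A minimal numpy sketch of the recognition network in Eq. (28) follows: an MLP mapping each frame y_t to the natural parameters (h_t, J_t) of a Gaussian node potential on x_t. Restricting J_t to be diagonal and enforcing positivity with a softplus are assumptions of this sketch, as are the layer sizes and the helper name `recognize`.

```python
# CRF recognition network: y_t -> (h(y_t; phi), J(y_t; phi)) as in Eq. (28).
import numpy as np

def recognize(y, phi):
    """Return potential vectors h (T, M) and diagonal precisions J (T, M, M)."""
    (W1, b1), (Wh, bh), (Wj, bj) = phi
    hidden = np.tanh(y @ W1 + b1)                  # shared tanh hidden layer
    h = hidden @ Wh + bh                           # potential vector head
    diag = np.log1p(np.exp(hidden @ Wj + bj))      # softplus keeps precisions positive
    J = np.stack([np.diag(d) for d in diag])       # diagonal precision matrices
    return h, J

rng = np.random.default_rng(0)
T, P, H, M = 100, 30 * 30, 50, 10                  # frames, pixels, hidden units, latent dim
phi = ((0.05 * rng.standard_normal((P, H)), np.zeros(H)),
       (0.05 * rng.standard_normal((H, M)), np.zeros(M)),
       (0.05 * rng.standard_normal((H, M)), np.zeros(M)))
y = rng.standard_normal((T, P))                    # stand-in for extracted depth frames
h, J = recognize(y, phi)
```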

There are many alternative designs for such recognition networks; for example, the node potential on x_t need not depend only on the observation y_t. See Section 3.3.

This structured mean field family can be directly compared to the fully factorized family used in the variational autoencoder described in Section 2.2. That is, there is no graph structure among the latent variables of the VAE. The SVAE generalizes the VAE by allowing the output of the recognition network to be arbitrary potentials in a graphical model, such as the node potentials considered here. Furthermore, in the SVAE some of the graphical model potentials are induced by the probabilistic model rather than being the output of a recognition network; for example, the optimal pairwise potentials ψ(x_t, x_{t+1}) are induced by the variational factors on the dynamical parameters and latent discrete states, q(θ) and q(z_{1:T}), and the forward generative model p(z, x | θ) (see Section 4.2.1). Thus the SVAE provides a way to combine bottom-up information from flexible inference networks with top-down information from other latent variables in a structured probabilistic model.

When p(θ) is chosen to be a conjugate prior, as in Eq. (24), the optimal factor q(θ) is in the same exponential family:

q(θ) = exp{ ⟨η̃_θ, t_θ(θ)⟩ − ln Z_θ(η̃_θ) },    (29)

where η̃_θ denotes the variational natural parameters. To simplify notation, as in Section 2.2 we take the variational factor on the observation parameters to be a singular distribution, q(ϑ) = δ_{ϑ*}(ϑ). The mean field objective in terms

Figure 6. Alternative SLDS recognition models.

Figure 7. Variations on the SLDS generative model.

of the global variational parameters η̃_θ, ϑ*, and φ is then

L(η̃_θ, ϑ*, φ) ≜ max_{q(z)q(x)} E_{q(θ)q(z)q(x)}[ ln ( p(θ, z, x) p(y | x, ϑ*) / (q(θ) q(z) q(x)) ) ],    (30)

where, as in Eq. (6), the maximization is over the free parameters of the local variational factors q(z) and q(x) given fixed values of η̃_θ, ϑ*, and φ. In Section 4 we show how to optimize this variational objective.

3.3. Variations and extensions

The SVAE construction is general: it can admit many latent probabilistic models as well as many flexible observation models and recognition network designs. In this section we outline some of these possibilities.

While the SVAE recognition network described in Section 3.2 only produces node potentials depending on single data points, in general SVAE recognition networks can output potentials on more than one node or take as input more than one data point. For example, a recognition network could output node potentials of the form ψ(x_t; y_t, y_{t−1}) that depend also on the previous data point, as sketched in Figure 6a, or even depend on many data points through a recurrent neural network (RNN). Recognition networks may also output factors on discrete latent variables, as sketched in Figure 6b.

The observation model may also be extended, for example to allow data to be generated according to a nonlinear autoregression, as sketched in Figure 7a. This extension may be useful when modeling video so that a frame can be generated in part by reference to the preceding frame.

Finally, the SVAE also admits many generative models,


both structured probabilistic graphical models and more flexible alternatives. As mentioned in Section 3.2, the SLDS described here includes many common structured graphical models as special cases, shown in Figure 5. For example, by using a GMM inside an SVAE, one can express a warped mixture model (Iwata et al., 2013). The SLDS also lends itself to many extensions, including explicit-duration semi-Markov modeling of discrete states (Johnson & Willsky, 2013) or IO-HMMs (Bengio & Frasconi, 1995), sketched in Figure 7b. More generally, SVAEs provide a direct way to incorporate hierarchical Bayesian modeling, which enables many ways to share latent random variables and generalize across datasets. Because the SVAE can draw on the vast literature on probabilistic graphical models, there is a rich library of potential SVAEs. Finally, though we focus on traditional probabilistic graphical models here, the SVAE generative model can also be an RNN expressing flexible nonlinear dynamics (Archer et al., 2015), though such models would not in general admit the efficient inference algorithms and natural gradients provided by exponential families.

4. Learning and inference

In this section we give an efficient algorithm for computing stochastic gradients of the SVAE objective in Eq. (30). These stochastic gradients can be used in a generic optimization routine such as stochastic gradient ascent or Adam (Kingma & Ba, 2015).

As we show, the SVAE algorithm is essentially a combination of SVI (Hoffman et al., 2013) and AEVB (Kingma & Welling, 2014), described in Sections 2.1 and 2.2, respectively. By drawing on SVI, our algorithm is able to exploit exponential family conjugacy structure, when it is available, to efficiently compute natural gradients with respect to some variational parameters. Because natural gradients are adapted to the geometry of the variational family and are invariant to model reparameterizations (Amari & Nagaoka, 2007), natural gradient ascent provides an effective second-order optimization method (Martens & Grosse, 2015; Martens, 2015). By drawing on AEVB, our algorithm can fit both general nonlinear observation models and flexible bottom-up recognition networks.

The algorithm is split into two parts. First, in Section 4.1 we describe the general algorithm for computing gradients of the objective in terms of the results from a model-specific inference subroutine. Next, in Section 4.2 we detail this model inference subroutine for the SLDS.

4.1. SVAE algorithm

Here we show how to compute stochastic gradients of the SVAE mean field objective (30) using the results of a model inference subroutine. The algorithm is summarized in Algorithm 1 and implemented in svae.py.

For scalability, the stochastic gradients used here are computed on minibatches of data. To simplify notation, we assume the dataset is a collection of N sequences, {y^{(n)}}_{n=1}^N, each of length T. We sample one sequence y^{(n)} uniformly at random and compute a stochastic gradient with it. It is also possible to sample subsequences and compute controllably-biased stochastic gradients (Foti et al., 2014).

The SVAE algorithm computes a natural gradient with respect to η̃_θ and standard gradients with respect to ϑ* and φ. To compute these gradients, as in Section 2.2 we split the objective L(η̃_θ, ϑ*, φ) as

E_{q(x)} ln p(y | x, ϑ*) − KL(q(θ, z, x) ‖ p(θ, z, x)).    (31)

Note that only the second term depends on the variational dynamics parameter η̃_θ. Furthermore, it is a KL divergence between two members of the same exponential family (Eqs. (24) and (29)), and so as in Hoffman et al. (2013) and Section 2.1 we can write the natural gradient of (31) with respect to η̃_θ as

∇̃_{η̃_θ} L = η_θ + ∑_{n=1}^N E_{q(z)q(x)}[(t_{zx}(z^{(n)}, x^{(n)}), 1)] − η̃_θ,    (32)

where q(z) and q(x) are taken to be locally optimal local mean field factors as in Eq. (7). Therefore by sampling the sequence index n uniformly at random, an unbiased estimate of the natural gradient is given by

∇̃_{η̃_θ} L ≈ η_θ + N E_{q(z)q(x)}[(t_{zx}(z^{(n)}, x^{(n)}), 1)] − η̃_θ.    (33)

We abbreviate t̄_{zx}^{(n)} ≜ E_{q(z)q(x)} t_{zx}(z^{(n)}, x^{(n)}). These expected sufficient statistics are computed efficiently using the model inference subroutine described in Section 4.2.
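The estimate in Eq. (33) reduces to a few array operations once the expected sufficient statistics are in hand. The sketch below assumes the natural parameters have been flattened into vectors whose final slot pairs against the constant 1 in Eq. (25); the dimensions and values are hypothetical.

```python
# Unbiased natural gradient estimate of Eq. (33) from one sampled sequence.
import numpy as np

def natural_gradient_estimate(eta_prior, eta_tilde, t_bar_n, N):
    """eta_prior, eta_tilde: flat natural parameter vectors; t_bar_n: expected
    sufficient statistics of sequence n from the inference subroutine."""
    stats = np.concatenate([t_bar_n, [1.0]])        # the (t_zx, 1) pairing
    return eta_prior + N * stats - eta_tilde

# hypothetical usage with a 5-dimensional statistic and a step size of 0.1
eta_prior = 0.1 * np.ones(6)
eta_tilde = np.ones(6)
t_bar_n = np.array([0.2, -0.1, 0.5, 0.0, 0.3])
eta_tilde = eta_tilde + 0.1 * natural_gradient_estimate(eta_prior, eta_tilde, t_bar_n, N=1000)
```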

To compute the gradient of the SVAE objective with respect to the observation parameters ϑ*, note that only the first term in (31) depends on ϑ*. As in Section 2.2, we can compute a Monte Carlo approximation to this gradient using samples {x̂_s^{(n)}(φ)}_{s=1}^S, where x̂_s^{(n)}(φ) ~iid q(x^{(n)}) and we write x̂_s^{(n)}(φ) to note the dependence on φ. These samples are generated efficiently (and as differentiable functions of φ) by the model inference subroutine described in Section 4.2. To simplify notation we take S = 1.

Finally, we compute gradients with respect to the recognition network parameters φ. Both the first and second terms of (31) depend on φ, the first term through the sample x̂^{(n)}(φ) and the second term through the KL divergence KL(φ) ≜ KL(q(θ, z^{(n)}, x^{(n)}) ‖ p(θ, z^{(n)}, x^{(n)})). Thus we must differentiate through the procedures that the model inference subroutine uses to compute these quantities. As


Algorithm 1 Computing gradients of the SVAE objective

Input: Variational dynamics parameter η̃_θ of q(θ), observation model parameter ϑ*, recognition network parameters φ, sampled sequence y^{(n)}

function SVAEGRADIENTS(η̃_θ, ϑ*, φ, y^{(n)})
    ψ ← RECOGNIZE(y^{(n)}, φ)
    (x̂^{(n)}(φ), t̄_{zx}^{(n)}, KL(φ)) ← INFERENCE(η̃_θ, ψ)
    ∇̃_{η̃_θ} L ← η_θ + N (t̄_{zx}^{(n)}, 1) − η̃_θ
    ∇_{ϑ*, φ} L ← ∇_{ϑ*, φ} [ N ln p(y^{(n)} | x̂^{(n)}(φ), ϑ*) − KL(φ) ]
    return natural gradient ∇̃_{η̃_θ} L, gradient ∇_{ϑ*, φ} L
end function

described in Section 4.2, performing this differentiation efficiently for the SLDS corresponds to backpropagation through message passing.

To observe that both x̂^{(n)}(φ) and KL(φ) can be computed differentiably by the model inference subroutine of the SLDS, and to relate this process to the reparameterization trick of Section 2.2, note that the objective only depends on φ through the optimal factor q(x_{1:T}^{(n)}), which is a joint Gaussian factor over the latent continuous states. That is, we can write the sample x̂^{(n)}(φ) ∼ q(x_{1:T}^{(n)}) as

x̂^{(n)}(φ) = g(φ, ε) ≜ J(φ)^{−1} h(φ) + J(φ)^{−1/2} ε,    (34)

where ε ∼ N(0, I), h(φ) ∈ R^{TM}, and J(φ) ∈ R^{TM×TM}. The matrix J(φ) is a block tridiagonal matrix corresponding to the Gaussian LDS of (27), the block diagonal of which depends on φ. Since g(φ, ε) is differentiable with respect to φ, this representation shows that x̂^{(n)}(φ) is differentiable with respect to φ. Given a sample ε̂, to evaluate ∇_φ g(φ, ε̂) efficiently we can exploit the block tridiagonal structure of J(φ), which corresponds exactly to backpropagating through the message passing procedure used to sample q(x_{1:T}^{(n)}). Similarly, to evaluate the gradient of the second term of (31), which is computed analytically, we simply backpropagate through the message passing routine used to compute it. Both message passing computations are part of the model inference subroutine described in Section 4.2. Computing each of these gradients using standard backpropagation tools requires twice the time of the message passing algorithms used to compute them. When the reparameterization trick is not available, the score function estimator can be used instead (Kleijnen & Rubinstein, 1996; Gelman & Meng, 1998; Ranganath et al., 2014).
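The following sketch shows the reparameterized draw of Eq. (34) for a Gaussian given in information form (potential vector h, precision J), written through a Cholesky factor so that the sample is a differentiable function of (h, J). A dense factorization is used purely for illustration; the paper instead exploits the block tridiagonal structure of J via message passing.

```python
# Reparameterized sample from N with precision J and potential h (Eq. 34):
# if J = L L^T then x = J^{-1} h + L^{-T} eps has covariance J^{-1}.
import numpy as np
from scipy.linalg import solve_triangular

def sample_gaussian_info(h, J, eps):
    L = np.linalg.cholesky(J)                                    # J = L L^T
    mean = solve_triangular(L.T, solve_triangular(L, h, lower=True), lower=False)
    noise = solve_triangular(L.T, eps, lower=False)              # L^{-T} eps
    return mean + noise

rng = np.random.default_rng(0)
D = 6
A = rng.standard_normal((D, D))
J = A @ A.T + D * np.eye(D)                                      # positive definite precision
h = rng.standard_normal(D)
x = sample_gaussian_info(h, J, rng.standard_normal(D))
```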

4.2. Model inference subroutine

Here we describe the model inference subroutine used by the SVAE algorithm.

Because the VAE corresponds to a particular SVAE with limited latent probabilistic structure, this inference subroutine can be viewed as a generalization of two steps in the AEVB algorithm (Kingma & Welling, 2014).

Algorithm 2 Model inference subroutine for the SLDS

Input: Variational dynamics parameter η̃_θ of q(θ), node potentials {ψ(x_t; y_t)}_{t=1}^T from the recognition network

function INFERENCE(η̃_θ, {ψ(x_t; y_t)}_{t=1}^T)
    Initialize factor q(x)
    repeat
        q(z) ∝ exp{ E_{q(θ)q(x)} ln p(z, x | θ) }
        q(x) ∝ exp{ E_{q(θ)q(z)} ln p(x | z, θ) } ∏_t ψ(x_t; y_t)
    until q(z) and q(x) converge
    x̂ ← sample q(x)
    t̄_{zx} ← E_{q(z)q(x)} t_{zx}(z, x)
    KL ← KL(q(θ) ‖ p(θ)) + N E_{q(θ)} KL(q(z)q(x) ‖ p(z, x | θ))
    return sample x̂, expected stats t̄_{zx}, divergence KL
end function

Specifically, using the notation of Section 2.2, in the special case of the VAE the inference subroutine computes KL(q(x | y) ‖ p(x)) in closed form and generates samples x ∼ q(x | y) used to approximate E_{q(x)} ln p(y | x). However, the inference subroutine of an SVAE may in general perform other computations: first, because the SVAE can include other latent random variables and graph structure, the inference subroutine may optimize local mean field factors or run message passing. Second, because the SVAE can perform stochastic natural gradient updates on the global factor q(θ), the inference subroutine may also compute expected sufficient statistics.

In this subsection, we detail these computations for the SLDS model. Specifically, the inference subroutine must first optimize the local mean field factors q(z_{1:T}) and q(x_{1:T}), then compute and return a sample x̂_{1:T} ∼ q(x_{1:T}), expected sufficient statistics E_{q(z)q(x)} t_{zx}(z, x), and a KL divergence KL(q(θ, z, x) ‖ p(θ, z, x)). To simplify notation, we drop the sequence index n, writing y in place of y^{(n)}. The algorithm is summarized in Algorithm 2 and implemented in slds_svae.py.

4.2.1. OPTIMIZING LOCAL MEAN FIELD FACTORS

As in the SVI algorithm of Section 2.1, for a given data sequence y we need to optimize the local mean field factors q(z) and q(x). That is, for a fixed global parameter factor q(θ) with natural parameter η̃_θ and fixed node potentials ψ = {ψ(x_t; y_t)}_{t=1}^T output by the recognition network, we optimize the variational objective with respect to both the local variational factor on discrete latent states q(z_{1:T}) and the local variational factor on continuous latent states q(x_{1:T}). This optimization can be performed efficiently by exploiting the SLDS exponential family form and the structured variational family.


The algorithm proceeds by alternating between optimizing q(x_{1:T}) and optimizing q(z_{1:T}). Fixing the factor on the discrete states q(z_{1:T}), the optimal variational factor on the continuous states q(x_{1:T}) is proportional to

exp{ ln ψ(x_1) + ∑_{t=1}^{T−1} ln ψ(x_t, x_{t+1}) + ∑_{t=1}^T ln ψ(x_t; y_t) },

where

ln ψ(x_1) = E_{q(θ)q(z)} ln p(x_1 | z_1, θ),    (35)
ln ψ(x_t, x_{t+1}) = E_{q(θ)q(z)} ln p(x_{t+1} | x_t, z_t, θ).    (36)

Because the densities p(x_1 | z_1, θ) and p(x_{t+1} | x_t, z_t, θ) are Gaussian exponential families, the expectations in Eqs. (35)-(36) can be computed efficiently, yielding Gaussian potentials with natural parameters that depend on both q(θ) and q(z_{1:T}). Furthermore, because each ψ(x_t; y_t) is itself a Gaussian potential, the overall factor q(x_{1:T}) is a Gaussian linear dynamical system with natural parameters computed from the variational factor on the dynamical parameters q(θ), the variational factor on the discrete states q(z_{1:T}), and the recognition model potentials {ψ(x_t; y_t)}_{t=1}^T.

Because the optimal factor q(x_{1:T}) is a Gaussian linear dynamical system, we can use message passing to perform efficient inference. In particular, the expected sufficient statistics of q(x_{1:T}) needed for updating q(z_{1:T}) can be computed efficiently. Message passing can also be used to draw samples or compute the log partition function efficiently, as we describe in Section 4.2.2.

Similarly, fixing q(x_{1:T}), the optimal factor q(z_{1:T}) is proportional to

exp{ ln ψ(z_1) + ∑_{t=1}^{T−1} ln ψ(z_t, z_{t+1}) + ∑_{t=1}^T ln ψ(z_t) },

where

ln ψ(z_1) = E_{q(θ)} ln p(z_1 | θ) + E_{q(θ)q(x)} ln p(x_1 | z_1, θ),
ln ψ(z_t, z_{t+1}) = E_{q(θ)} ln p(z_{t+1} | z_t),    (37)
ln ψ(z_t) = E_{q(θ)q(x)} ln p(x_{t+1} | x_t, z_t, θ).

Again, because the densities involved are exponential families, these expectations can be computed efficiently. The resulting factor q(z_{1:T}) is an HMM with natural parameters that are functions of q(θ) and q(x_{1:T}). Both the expected sufficient statistics required for updating q(x_{1:T}) as well as the log partition function required in Section 4.2.2 can be computed efficiently by message passing.
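The message passing needed for the chain-structured factor q(z_{1:T}) is standard HMM forward-backward. The sketch below computes unary marginals, pairwise marginals, and the log partition function in log space; the random inputs are stand-ins for the expected log densities of Eq. (37), not quantities produced by the paper's code.

```python
# Forward-backward message passing for a discrete chain factor q(z_{1:T}).
import numpy as np
from scipy.special import logsumexp

def hmm_messages(log_init, log_trans, log_obs):
    """log_init: (K,), log_trans: (K, K), log_obs: (T, K) node potentials."""
    T, K = log_obs.shape
    alpha = np.zeros((T, K))                      # forward messages (log space)
    beta = np.zeros((T, K))                       # backward messages (log space)
    alpha[0] = log_init + log_obs[0]
    for t in range(1, T):
        alpha[t] = log_obs[t] + logsumexp(alpha[t - 1][:, None] + log_trans, axis=0)
    for t in range(T - 2, -1, -1):
        beta[t] = logsumexp(log_trans + (log_obs[t + 1] + beta[t + 1])[None, :], axis=1)
    logZ = logsumexp(alpha[-1])                   # log partition function of q(z)
    unary = np.exp(alpha + beta - logZ)           # q(z_t = k)
    pair = np.exp(alpha[:-1, :, None] + log_trans[None]
                  + (log_obs[1:] + beta[1:])[:, None, :] - logZ)   # q(z_t, z_{t+1})
    return unary, pair, logZ

rng = np.random.default_rng(0)
T, K = 50, 4
unary, pair, logZ = hmm_messages(np.log(np.full(K, 1.0 / K)),
                                 np.log(rng.dirichlet(np.ones(K), size=K)),
                                 rng.standard_normal((T, K)))
```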

4.2.2. SAMPLES, EXPECTED STATISTICS, AND KL

After optimizing the local variational factors, the model inference subroutine uses the optimized factors to draw samples, compute expected sufficient statistics, and compute a KL divergence. The results of these inference computations, which we detail here, are then used to compute gradients of the SVAE objective as described in Section 4.1.

To draw S samples {x̂^{(s)}}_{s=1}^S with x̂_{1:T}^{(s)} ~iid q(x_{1:T}), we can perform message passing in q(x_{1:T}), which is a Gaussian distribution that is Markov with respect to a chain graph. Given two neighboring nodes i and j, we define the message from i to j as

m_{i→j}(x_j) ≜ ∫ ψ(x_i; y_i) ψ(x_i, x_j) m_{k→i}(x_i) dx_i,    (38)

where k is the other neighbor of node i. We can pass messages backward once, running a Kalman filter, and then sample forward S times. That is, after computing the messages m_{t+1→t}(x_t) for t = T−1, ..., 1, we compute the marginal distribution on x_1 as

q(x_1) ∝ ψ(x_1) ψ(x_1; y_1) m_{2→1}(x_1),    (39)

and sample x̂_1 ∼ q(x_1). Iterating, given a sample of x̂_{t−1}, we sample x̂_t from the conditional distributions

q(x_t | x̂_{1:t−1}) ∝ ψ(x̂_{t−1}, x_t) ψ(x_t; y_t) m_{t+1→t}(x_t),
q(x_T | x̂_{1:T−1}) ∝ ψ(x̂_{T−1}, x_T) ψ(x_T; y_T),    (40)

thus constructing a joint sample x̂_{1:T} ∼ q(x_{1:T}).

To compute the expected sufficient statistics for the natural gradient step on η̃_θ, we can also use message passing, this time in both factors q(x_{1:T}) and q(z_{1:T}) separately. Recall the exponential family notation from Eqs. (24)-(25). The required expected sufficient statistics, which we abbreviate as t̄_{zx} ≜ E_{q(x)q(z)} t_{zx}(z, x), are of the form

E_{q(z)} I[z_t = i, z_{t+1} = j],   E_{q(z)} I[z_t = i],
E_{q(z)} I[z_t = k] E_{q(x)}[x_t x_{t+1}^T],    (41)
E_{q(z)} I[z_t = k] E_{q(x)}[x_t x_t^T],   E_{q(z)} I[z_1 = k] E_{q(x)}[x_1],

where I[·] denotes an indicator function. Each of these can be computed easily from the marginals q(x_t, x_{t+1}) and q(z_t, z_{t+1}) for t = 1, 2, ..., T−1, and these marginals can be computed in terms of the respective graphical model messages. For example, to compute the marginal q(x_t, x_{t+1}), we write

q(x_t, x_{t+1}) ∝ ψ(x_t; y_t) ψ(x_t, x_{t+1}) ψ(x_{t+1}; y_{t+1}) m_{t−1→t}(x_t) m_{t+2→t+1}(x_{t+1}).    (42)

Because each of the factors is a Gaussian potential, the product yields a Gaussian marginal with natural parameters computed from the sum of the factors' natural parameters. The corresponding expected sufficient statistics can then be computed by converting from natural parameters to mean parameters. The computations for q(z_t, z_{t+1}) are similar.


Finally, to compute the required KL divergence efficiently, we split KL(q(θ, z, x) ‖ p(θ, z, x)) into three terms:

KL(q(θ, z, x) ‖ p(θ, z, x)) = KL(q(θ) ‖ p(θ)) + E_{q(θ)} KL(q(z) ‖ p(z | θ)) + E_{q(θ)q(z)} KL(q(x) ‖ p(x | z, θ)).    (43)

Each term is a divergence between members of the same exponential family and thus can be computed in terms of natural parameters, expected sufficient statistics, and log partition functions. In particular, the first term is

E_{q(θ)}[ ln ( q(θ) / p(θ) ) ] = ⟨η̃_θ − η_θ, t̄_θ⟩ − (ln Z_θ(η̃_θ) − ln Z_θ(η_θ)),   t̄_θ ≜ E_{q(θ)} t_θ(θ).    (44)

Because the factors q(z) and q(x) are locally optimal for q(θ) (Beal, 2003), the second and third terms are simply the log partition functions of q(z) and q(x), respectively, each of which can be computed efficiently from message passing, along with some linear terms. Note that the third term in the KL divergence (43) is a function of the recognition network parameters φ through the conditional node potentials {ψ(x_t; y_t, φ)}_{t=1}^T in q(x).
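To illustrate Eq. (44), the sketch below computes a KL divergence between two members of the same exponential family from natural parameters, expected statistics, and log partition functions. The Dirichlet family is chosen as a concrete, easily verified example; using it here is an assumption of the sketch, with the paper's conjugate priors handled by the same formula.

```python
# KL(q || p) = <eta_q - eta_p, E_q[t]> - (log Z(eta_q) - log Z(eta_p)),
# demonstrated for the Dirichlet exponential family.
import numpy as np
from scipy.special import gammaln, digamma

def dirichlet_natural(alpha):
    return alpha - 1.0                                  # natural parameter

def dirichlet_logZ(alpha):
    return np.sum(gammaln(alpha)) - gammaln(np.sum(alpha))

def dirichlet_expected_stats(alpha):
    return digamma(alpha) - digamma(np.sum(alpha))      # E[ln theta_k]

def expfam_kl(eta_q, eta_p, t_bar_q, logZ_q, logZ_p):
    return np.dot(eta_q - eta_p, t_bar_q) - (logZ_q - logZ_p)

a = np.array([3.0, 1.5, 4.0])                           # variational parameters
b = np.array([1.0, 1.0, 1.0])                           # prior parameters
kl = expfam_kl(dirichlet_natural(a), dirichlet_natural(b),
               dirichlet_expected_stats(a), dirichlet_logZ(a), dirichlet_logZ(b))
print(kl)
```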

5. Related work

In addition to the papers already referenced, there are several recent papers to which this work is related. The two papers closest to this work are Krishnan et al. (2015) and Archer et al. (2015).

In Krishnan et al. (2015) the authors consider combining variational autoencoders with continuous state-space models, emphasizing the relationship to linear dynamical systems (also called Kalman filter models). They primarily focus on nonlinear dynamics and an RNN-based variational family, as well as allowing control inputs (analogous to the IO-HMM extension mentioned in Section 3.3 and in Figure 7b, though with no discrete states and with the covariates influencing the continuous states). The approach in Krishnan et al. (2015) covers many nonlinear probabilistic state-space models but does not extend to general graphical models and in particular discrete latent variables. That approach also does not leverage SVI, natural gradients, or efficient message passing inference in the case of (switching) linear dynamics. However, their nonlinear dynamics model is more general than the conditionally linear dynamics we explicitly consider here. The authors also emphasize causal and counterfactual inference.

In Archer et al. (2015) the authors also consider the problem of variational inference in general continuous state space models but focus on using a structured Gaussian variational family. They show how to extend the AEVB algorithm of Kingma & Welling (2014) to leverage block tridiagonal-structured precision (inverse covariance) matrices, and hence develop a method more suitable for time series and state space models. Leveraging this tridiagonal structure in the AEVB algorithm is the same as the backpropagation through message passing in q(x_{1:T}) discussed in Section 4.2.2, though our algorithm also backpropagates through message passing in the discrete state factor q(z_{1:T}). The authors do not discuss learning dynamics models; instead, they focus on the case of inference with known nonlinear dynamics. As with Krishnan et al. (2015), their model does not include discrete latent variables (or any latent variables other than the continuous states) and does not discuss SVI for learning. The overall algorithm of Archer et al. (2015) can be most directly compared to part of the model inference subroutine we discuss in Section 4.2, namely to the steps of optimizing the chain-structured factor q(x_{1:T}) and performing efficient inference in it, with the key difference of emphasizing nonlinear dynamics. Indeed, the inference algorithm of Archer et al. (2015) can be used in an SVAE to handle inference with conditionally nonlinear dynamics, and so these works are complementary. Notably, Section 4.1 and specifically Eq. (16) of Krishnan et al. (2015) demonstrate a recognition network that outputs pairwise potentials (i.e. dynamics potentials).

There are many other recent related works, especially leveraging new deep learning ideas. In particular, both Gregor et al. (2015) and Chung et al. (2015) extend the variational autoencoder framework to sequential models, though they focus on models based on RNNs rather than probabilistic graphical models. More classical approaches for nonlinear filtering, smoothing, and model fitting are surveyed in Krishnan et al. (2015).

6. Experiments

In this section we apply the SVAE to both synthetic and real data and demonstrate its ability to learn both rich feature representations and simple latent dynamics. First, we apply a linear dynamical system (LDS) SVAE to synthetic data and illustrate some aspects of its training dynamics. Second, we apply an LDS SVAE and switching linear dynamical system (SLDS) SVAE to modeling depth video recordings of mouse behavior.

6.1. Image of a moving dot

As a synthetic data example, consider a sequence of 1D images representing a dot bouncing from one side of the image to the other, as shown in the top panel of Figure 8a. While these images are simple, the task of modeling such sequences captures many salient aspects of the SVAE: to make predictions far into the future with coherent uncertainty estimates, we can use an LDS SVAE to find a low-dimensional latent state space representation, along with a nonlinear image model and a simple model of dynamics in the latent state space.

We trained the LDS SVAE on 80 random image sequences each of length 50, using one sequence per update, and show the model's future predictions given a prefix of a longer sequence. We used MLP image and recognition models, each with one hidden layer of 50 units. Figures 8a and 8b show predictions from an LDS SVAE at two stages of training. Taken from an early stage in training, Figure 8a shows that the LDS SVAE first learns an autoencoder that can represent the image data accurately. However, while the latent space allows the images to be encoded and decoded accurately, the latent space trajectories are not well modeled by a linear dynamical system and hence future predictions are highly uncertain. As training proceeds, the SVAE adjusts the latent space representation so that the trajectories can be accurately modeled and predicted using an LDS, as shown in Figure 8b, and hence long-range predictions become very accurate. Figure 9 shows the same experiment repeated with a higher-resolution image.

Figure 10 shows predictions from the fit LDS SVAE based on revealing data prefixes of different lengths. Note that when only a single frame is revealed, the model is uncertain as to whether the dot is moving upwards or downwards in the image. However, instead of showing a bifurcation in the predictions corresponding to these two distinct hypotheses, because the latent dynamics are linear the model's predictions encompass all possible trajectories between these two hypotheses. With a nonlinear dynamics model, such as a switching LDS (SLDS) or a more general recurrent neural network (RNN), a model that explicitly represents this bimodal hypothesis could be learned, at least in principle.

Finally, we also use this synthetic data problem to illustrate the significant optimization advantages that can be provided by the natural gradient updates with respect to the variational dynamics parameters. In Figure 11 we compare natural gradient updates with standard gradient updates at three different learning rates. The natural gradient algorithm not only learns much faster but also is less dependent on parameterization details: while the natural gradient update used an untuned stepsize of 0.1, the standard gradient dynamics at step sizes of both 0.1 and 0.05 resulted in some matrix parameters that are required to be positive definite being updated to indefinite values. This particular indefiniteness error with standard gradients can be avoided by working in a different parameterization of the positive definite parameter matrices, but we used a direct parameterization to make the learning rate comparison more direct. While a step size of 0.01 yielded a stable standard gradient update, training is significantly slower than with the natural gradient algorithm, and this speed advantage of natural gradients is often greater in larger models and higher dimensions (Martens & Grosse, 2015).

Figure 11. Comparison on the dot problem of natural gradient updates (blue) and standard gradient updates (orange); the horizontal axis is the training iteration and the vertical axis is the negative objective −L. The X's indicate that the algorithm was terminated early due to indefiniteness.

6.2. Application to mouse behavioral phenotyping

As a real data experiment, we apply the SVAE to modeling depth video of freely-behaving mice. The goal of behavioral phenotyping is to identify patterns of behavior and study how those patterns change when the animal's environment, genetics, or brain function are altered. The SVAE framework is well-suited to provide new modeling tools for this task because it can learn rich nonlinear image features while learning an interpretable representation of the latent dynamics.

We use a version of the dataset from Wiltschko et al. (2015), in which a mouse is recorded from above using a depth camera, as shown in Figure 12. The mouse's position and orientation are tracked and the depth image of the mouse's body is extracted, as shown in the insets. We apply the SVAE to the task of modeling these extracted videos. We used a subset of the dataset consisting of 8 recordings, each of a distinct mouse and each 20 minutes in length at 30 frames per second, for a total of 288000 video frames.

First, we show that the variational autoencoder (VAE) architecture allows us to accurately represent the manifold of depth images. We fit a standard VAE to the depth video frames, using two hidden layers of size 200 for the encoder and decoder along with a 10D latent space; we reuse this same architecture for the subsequent SVAE models. Figure 13 shows both samples of the real video frames (top panel) and images generated from the VAE by sampling from a standard Gaussian in the latent space (bottom panel), demonstrating that the VAE image model can generate very accurate synthetic mouse images.


Figure 8. Predictions from an LDS SVAE fit to 1D dot image data at two different stages of training: (a) after 200 training epochs; (b) after 1100 training epochs. In each subfigure, the top panel shows an example data sequence with time on the horizontal axis. The middle panel shows the noiseless LDS SVAE predictions given the data up to the vertical line, while the bottom panel shows the corresponding latent states.

Figure 9. As in Figure 8 but with a higher-resolution image sequence: (a) predictions after 200 training epochs; (b) predictions after 1100 training epochs.

Figure 10. As in Figure 8 but showing predictions given different data prefixes: (a) prediction given one frame of data; (b) prediction given 5 frames of data.


Figure 12. Mouse depth video acquisition. A freely-behaving mouse is recorded from above with a depth camera. The mouse position and orientation are tracked and the mouse body depth image is extracted, as shown in each inset. Our models are applied to these extracted image sequences after some filtering and downsampling.


Furthermore, to illustrate that the learned manifold naturally represents smooth variation in the mouse's pose, Figure 14 shows images generated from regular lattices on random 2D subspaces of the latent space.
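A rough sketch of how such an image grid can be produced follows. The lattice extent of ±2 and the 8 × 8 grid size are assumptions, and the commented-out decoding step stands in for the fitted VAE decoder (e.g., the mlp/decoder pair from the earlier sketch).

import numpy as np

rng = np.random.default_rng(1)
K = 10                                       # latent dimension

# Pick a random 2D subspace of the latent space via an orthonormal basis.
A = rng.standard_normal((K, 2))
Q, _ = np.linalg.qr(A)                       # Q has two orthonormal columns in R^K

# Lay down a regular lattice of coefficients on that subspace.
grid = np.linspace(-2.0, 2.0, 8)             # assumed lattice extent and resolution
lattice = np.array([u * Q[:, 0] + v * Q[:, 1] for u in grid for v in grid])

# Decoding each lattice point with the fitted VAE decoder yields one image per grid cell:
# images = np.array([mlp(decoder, z) for z in lattice]).reshape(8, 8, 30, 30)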

While the VAE can accurately model the marginal distribution on each image, to model entire depth videos jointly we use the SVAE to learn dynamics in the latent space. To show that a Gaussian linear dynamical system (LDS) can accurately represent these latent space dynamics, we fit an LDS SVAE to the depth video sequences. Figure 15 shows predictions from this model over 1-second horizons, illustrating that the LDS SVAE retains the flexible image modeling capabilities of the VAE while additionally modeling smooth dynamics that can generate plausible video sequences.
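These predictions can be thought of as a latent rollout: starting from a latent state inferred from the conditioning frames, the linear dynamics are iterated forward for 30 steps (one second at 30 frames per second) and each latent state is decoded into a depth image. The sketch below illustrates this with placeholder dynamics parameters, not the fitted values.

import numpy as np

rng = np.random.default_rng(2)
K, T = 10, 30                                # 10D latent space, 30 frames = 1 second at 30 fps

# Placeholder LDS parameters; in the fitted model these come from the SVAE.
A = 0.98 * np.eye(K) + 0.01 * rng.standard_normal((K, K))   # near-stable transition matrix
Q_chol = 0.05 * np.eye(K)                                    # Cholesky factor of process noise

z = rng.standard_normal(K)                   # stand-in for the state inferred from the conditioning frames
rollout = []
for t in range(T):
    z = A @ z + Q_chol @ rng.standard_normal(K)   # z_{t+1} = A z_t + noise
    rollout.append(z)
rollout = np.stack(rollout)                  # (T, K) latent trajectory

# Each latent state would then be decoded into a predicted depth frame:
# frames = np.array([mlp(decoder, z_t) for z_t in rollout])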

Finally, because latent linear dynamics paired with the nonlinear image model can accurately represent the depth videos over short timescales, extending the latent dynamics to a switching LDS (SLDS) lets us discover distinct dynamical modes that represent the syllables of behavior. Figure 16 shows some of the discrete states that arise from fitting an SLDS SVAE with 30 discrete states to the depth video data. The discrete states that emerge show a natural clustering of short-timescale patterns into behavioral units. Thus the SLDS SVAE provides an interpretable, structured representation of behavioral patterns.
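The generative picture behind this model is a switching LDS rollout: a discrete syllable label follows a Markov chain over the 30 states, and each label selects its own linear dynamics for the continuous latent state, which the VAE decoder then maps to images. The sketch below is a generative rollout under assumed placeholder parameters (the sticky transition matrix, noise scale, per-state dynamics, and the names s and x are all illustrative, not the fitted model).

import numpy as np

rng = np.random.default_rng(3)
M, K, T = 30, 10, 90                         # 30 discrete states, 10D latent space, 3 seconds

# Placeholder SLDS parameters (the fitted SVAE learns these from data).
pi = np.full((M, M), 0.2 / (M - 1))          # transition matrix with self-transition bias
np.fill_diagonal(pi, 0.8)
As = [0.97 * np.eye(K) + 0.02 * rng.standard_normal((K, K)) for _ in range(M)]
noise_scale = 0.05

s = rng.integers(M)                          # initial discrete state (behavioral syllable)
x = rng.standard_normal(K)                   # initial continuous latent state
states, latents = [], []
for t in range(T):
    s = rng.choice(M, p=pi[s])               # sample the next syllable from the Markov chain
    x = As[s] @ x + noise_scale * rng.standard_normal(K)   # syllable-specific linear dynamics
    states.append(s)
    latents.append(x)

# Decoding the continuous latents with the VAE decoder would give a generated video,
# while the sequence in `states` segments it into behavioral syllables.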

7. Conclusion

Structured variational autoencoders (SVAEs) provide a new set of tools for probabilistic modeling that leverage both graphical models and deep learning techniques. By building on variational autoencoders and neural network architectures, SVAEs can learn features and bottom-up inference networks automatically and represent complex high-dimensional data. By enabling the latent probabilistic models to be expressed in terms of graphical models, exponential families, and hierarchical structures, SVAEs also allow powerful modeling and algorithmic tools to be applied, including efficient natural gradient computations, graphical message passing, priors that encourage sparsity or represent domain knowledge, and discrete latent random variables. Furthermore, these structured latent representations can provide both interpretability and tractability for non-inference tasks, such as planning or control. Finally, because SVAE learning fits in the standard variational optimization framework, many recent advances in neural network optimization (Martens & Grosse, 2015) and variational inference (Burda et al., 2015) can be applied to SVAEs.

Acknowledgements

We thank Roger Grosse, Scott Linderman, Dougal Maclaurin, Andrew Miller, Oren Rippel, Yakir Reshef, Miguel Hernandez-Lobato, and the members of the Harvard Intelligent Probabilistic Systems (HIPS) Group for many helpful discussions and advice.

M.J.J. is supported by a fellowship from the Harvard/MIT Joint Grants program. A.B.W. is supported by an NSF Graduate Research Fellowship and is a Stuart H.Q. & Victoria Quan Fellow. S.R.D. is supported by fellowships from the Burroughs Wellcome Fund, the Vallee Foundation, the Khodadad Program, by grants DP2OD007109 and RO11DC011558 from the National Institutes of Health, and by the Global Brain Initiative from the Simons Foundation. R.P.A. is supported by NSF IIS-1421780.

References

Amari, Shun-Ichi. Natural gradient works efficiently in learning. Neural Computation, 10(2):251–276, 1998.

Amari, Shun-ichi and Nagaoka, Hiroshi. Methods of Information Geometry. American Mathematical Society, 2007.

Archer, Evan, Park, Il Memming, Buesing, Lars, Cunningham, John, and Paninski, Liam. Black box variational inference for state space models. arXiv preprint arXiv:1511.07367, 2015.

Beal, Matthew James. Variational algorithms for approximate Bayesian inference. University of London, 2003.

Bengio, Yoshua and Frasconi, Paolo. An input output HMM architecture. In Advances in Neural Information Processing Systems, pp. 427–434, 1995.

Burda, Yuri, Grosse, Roger, and Salakhutdinov, Ruslan. Importance weighted autoencoders. arXiv preprint arXiv:1509.00519, 2015.

Chung, Junyoung, Kastner, Kyle, Dinh, Laurent, Goel, Kratarth, Courville, Aaron C., and Bengio, Yoshua. A recurrent latent variable model for sequential data. In Advances in Neural Information Processing Systems, pp. 2962–2970, 2015.

Deng, Li. Computational models for speech production. In Computational Models of Speech Pattern Processing, pp. 199–213. Springer, 1999.

Deng, Li. Switching dynamic system models for speech articulation and acoustics. In Mathematical Foundations of Speech and Language Processing, pp. 115–133. Springer, 2004.


(a) Real frames from mouse depth video.

(b) Generated random samples from a VAE fit to mouse depth video.

Figure 13. Real data frames and generated images from a variational autoencoder (VAE) fit to mouse depth video data. In the second panel, images were generated by sampling latent coordinates from a standard 10D Gaussian.


Figure 14. Generated images from a variational autoencoder (VAE) fit to mouse depth video data. Each image grid was generated by choosing a random 2D subspace of the 10D latent space and sampling a regular 2D lattice of points on the subspace.


Figure 15. Examples of predictions from an LDS SVAE fit to depth video. In each panel, the top row is a sampled prediction from the LDS SVAE and the bottom row is real data. To the left of the line, the model is conditioned on the corresponding data frames and hence generates denoised versions of the same images. To the right of the line, the model is not conditioned on the data, thus illustrating the model's predictions. The frame sequences are temporally subsampled to reduce their length, showing one of every four video frames.


(a) Entering a rear

(b) Grooming

(c) Extension into running

(d) Fall from rear

Figure 16. Examples of behavior states inferred from depth video. For each state, four example frame sequences are shown, including frames during which the given state was most probable according to the variational distribution on the hidden state sequence. Each frame sequence is padded on both sides, with a square in the lower-right of a frame depicting that the state was active in that frame. The frame sequences are temporally subsampled to reduce their length, showing one of every four video frames. Examples were chosen to have durations close to the median duration for that state.


Foti, Nicholas, Xu, Jason, Laird, Dillon, and Fox, Emily. Stochastic variational inference for hidden Markov models. In Advances in Neural Information Processing Systems, pp. 3599–3607, 2014.

Fox, E. B., Sudderth, E. B., Jordan, M. I., and Willsky, A. S. Bayesian nonparametric inference of switching dynamic linear models. IEEE Transactions on Signal Processing, 59(4), 2011.

Gelman, Andrew and Meng, Xiao-Li. Simulating normalizing constants: From importance sampling to bridge sampling to path sampling. Statistical Science, pp. 163–185, 1998.

Gregor, Karol, Danihelka, Ivo, Graves, Alex, and Wierstra, Daan. DRAW: A recurrent neural network for image generation. arXiv preprint arXiv:1502.04623, 2015.

Hinton, Geoffrey, Deng, Li, Yu, Dong, Dahl, George E., Mohamed, Abdel-rahman, Jaitly, Navdeep, Senior, Andrew, Vanhoucke, Vincent, Nguyen, Patrick, Sainath, Tara N., et al. Deep neural networks for acoustic modeling in speech recognition: The shared views of four research groups. IEEE Signal Processing Magazine, 29(6):82–97, 2012.

Hoffman, Matthew D., Blei, David M., Wang, Chong, and Paisley, John. Stochastic variational inference. Journal of Machine Learning Research, 2013.

Iwata, Tomoharu, Duvenaud, David, and Ghahramani, Zoubin. Warped mixtures for nonparametric cluster shapes. In 29th Conference on Uncertainty in Artificial Intelligence, pp. 311–319, 2013.

Johnson, Matthew J. and Willsky, Alan S. Bayesian nonparametric hidden semi-Markov models. Journal of Machine Learning Research, 14:673–701, February 2013.

Johnson, Matthew J. and Willsky, Alan S. Stochastic variational inference for Bayesian time series models. In International Conference on Machine Learning, 2014.

Kingma, Diederik and Ba, Jimmy. Adam: A method for stochastic optimization. In International Conference on Learning Representations, 2015.

Kingma, Diederik P. and Welling, Max. Auto-encoding variational Bayes. In International Conference on Learning Representations, 2014.

Kleijnen, Jack P. C. and Rubinstein, Reuven Y. Optimization and sensitivity analysis of computer simulation models by the score function method. European Journal of Operational Research, 88(3):413–427, 1996.

Koller, Daphne and Friedman, Nir. Probabilistic Graphical Models: Principles and Techniques. MIT Press, 2009.

Krishnan, Rahul G., Shalit, Uri, and Sontag, David. Deep Kalman filters. arXiv preprint arXiv:1511.05121, 2015.

Martens, James. New insights and perspectives on the natural gradient method. arXiv preprint arXiv:1412.1193, 2015.

Martens, James and Grosse, Roger. Optimizing neural networks with Kronecker-factored approximate curvature. In Proceedings of the 32nd International Conference on Machine Learning, 2015.

Mikolov, Tomas, Sutskever, Ilya, Chen, Kai, Corrado, Greg S., and Dean, Jeff. Distributed representations of words and phrases and their compositionality. In Advances in Neural Information Processing Systems, pp. 3111–3119, 2013.

Murphy, Kevin P. Machine Learning: A Probabilistic Perspective. MIT Press, 2012.

Radford, Alec, Metz, Luke, and Chintala, Soumith. Unsupervised representation learning with deep convolutional generative adversarial networks. arXiv preprint arXiv:1511.06434, 2015.

Ranganath, Rajesh, Gerrish, Sean, and Blei, David M. Black box variational inference. In International Conference on Artificial Intelligence and Statistics, pp. 814–822, 2014.

Rezende, Danilo J., Mohamed, Shakir, and Wierstra, Daan. Stochastic backpropagation and approximate inference in deep generative models. In Proceedings of the 31st International Conference on Machine Learning, pp. 1278–1286, 2014.

Vincent, Pascal, Larochelle, Hugo, Bengio, Yoshua, and Manzagol, Pierre-Antoine. Extracting and composing robust features with denoising autoencoders. In Proceedings of the 25th International Conference on Machine Learning, pp. 1096–1103. ACM, 2008.

Wainwright, Martin J. and Jordan, Michael I. Graphical models, exponential families, and variational inference. Foundations and Trends in Machine Learning, 2008.

Wiltschko, Alexander B., Johnson, Matthew J., Iurilli, Giuliano, Peterson, Ralph E., Katon, Jesse M., Pashkovski, Stan L., Abraira, Victoria E., Adams, Ryan P., and Datta, Sandeep Robert. Mapping sub-second structure in mouse behavior. Neuron, 88(6):1121–1135, 2015.