cmps 6720: machine learning neural networksjhamm3/papers/lecture11-deep-generative-mo… ·...

56
1 CMPS 6720: Machine Learning Deep Generative Models Jihun Hamm

Upload: others

Post on 31-Aug-2020

3 views

Category:

Documents


0 download

TRANSCRIPT

Page 1: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

1

CMPS 6720: Machine Learning

Deep Generative Models

Jihun Hamm

Page 2: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Deep Generative Models

Since ~2010, Deep Neural Nets have made a great progress in supervised learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes from ImageNet State-of-the-art results

Since ~2014, Deep Neural Nets is also making a big impact in unsupervised learning problems “Deep generative models”

Generative Adversarial Nets [Goodfellow et al., 2014]

Variational Autoencoder [Kingma & Welling, 2014]

Modeling complex data distributions like faces, scenery, etc State-of-the-art results

2

Page 3: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Cf. Generative model for classification

Generative models can be used for both unsupervised and supervised learning

Classification using discriminative models Examples: Logistic regression, Conditional Random Fields Model 𝑃𝑃(𝑦𝑦|𝑥𝑥). Don’t care about 𝑃𝑃 𝑥𝑥 ⮕ Easy to learn ! Train: Learn params of 𝑃𝑃 𝑦𝑦 𝑥𝑥 from data Test: Infer 𝑦𝑦 = 𝑎𝑎𝑎𝑎𝑎𝑎𝑎𝑎𝑎𝑎𝑥𝑥𝑐𝑐𝑃𝑃 𝑦𝑦 = 𝑐𝑐 𝑥𝑥

Classification using generative models Examples: Bayes Nets, Markov Nets, Mixture models Model 𝑃𝑃(𝑥𝑥|𝑦𝑦) and 𝑃𝑃 𝑦𝑦 ⮕ Can generate new data ! Train: Learn params of 𝑃𝑃 𝑥𝑥 𝑦𝑦 , 𝑃𝑃(𝑦𝑦) from data Test: Infer 𝑦𝑦 = 𝑎𝑎𝑎𝑎𝑎𝑎𝑎𝑎𝑎𝑎𝑥𝑥𝑐𝑐𝑃𝑃 𝑦𝑦 = 𝑐𝑐 𝑥𝑥 = 𝑃𝑃 𝑥𝑥 𝑦𝑦 = 𝑐𝑐 𝑃𝑃(𝑦𝑦 = 𝑐𝑐)/𝑃𝑃(𝑐𝑐)

3

Page 4: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Generative Adversarial Nets

4

𝐷𝐷

𝐺𝐺

𝑧𝑧

𝑥𝑥

Page 5: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

GAN-related work

Variants Deep convolutional GAN (DCGAN) [Radford et al.’15]

Conditional GAN [Mirza et al.’14]

Adversarially learned inference (ALI) [Dumoulin et al.’16]

Adversarial autoencoder (AAE) [Makhzani et al.’15]

Energy-based GAN (EBGAN) [Zhao et al.’16]

Wasserstein GAN (WGAN) [Arjovsky et al.’17]

Boundary equilibrium GAN (BEGAN) [Berthelot et al.’17]

Bayesian GAN [Saatchi et al.’17] …

Applications…

5

Page 6: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

More examples of adversarial ML

GAN Domain adaptation Robust classification: narrow-sense adversarial learning Privacy-preservation / fair learning Attack on Deep NN Training data poisoning Hyperparameter learning Meta-learning …

6

Page 7: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

GAN architecture

𝐷𝐷 𝑥𝑥 :𝑋𝑋 → [0,1] is the discriminator Input: example 𝑥𝑥 Output: probability that 𝑥𝑥 is from 𝑃𝑃𝑑𝑑𝑑𝑑𝑑𝑑𝑑𝑑(𝑥𝑥) Usually a conv net, can be any structure

𝐺𝐺 𝑧𝑧 : 𝑅𝑅𝑑𝑑 → 𝑋𝑋 is the generator Input: random noise 𝑧𝑧~𝑃𝑃𝑧𝑧 𝑧𝑧

(typically 𝑃𝑃𝑧𝑧 𝑧𝑧 = 𝑁𝑁(0, 𝐼𝐼)) Output: fake data Usually a conv net, can be any structure Complex distribution 𝑃𝑃 𝑥𝑥 can be modeled using a large-enough 𝐺𝐺

7

𝐷𝐷

𝐺𝐺

𝑧𝑧

𝑥𝑥

Page 8: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

GAN objective

Find 𝐺𝐺 and 𝐷𝐷 which solves min𝐺𝐺

max𝐷𝐷

𝑉𝑉(𝐺𝐺,𝐷𝐷), where 𝑉𝑉 𝐺𝐺,𝐷𝐷 = 𝐸𝐸𝑥𝑥~𝑃𝑃𝑑𝑑𝑑𝑑𝑑𝑑𝑑𝑑[log𝐷𝐷 𝑥𝑥 ] + 𝐸𝐸𝑧𝑧~𝑃𝑃𝑧𝑧[log(1 − 𝐷𝐷(𝐺𝐺 𝑧𝑧 ))]

= 𝐸𝐸𝑥𝑥~𝑃𝑃𝑑𝑑𝑑𝑑𝑑𝑑𝑑𝑑[log𝐷𝐷 𝑥𝑥 ] + 𝐸𝐸𝑥𝑥~𝑃𝑃𝑓𝑓𝑑𝑑𝑓𝑓𝑓𝑓[log(1 − 𝐷𝐷(𝑥𝑥))]

In practice, use random samples 𝑥𝑥1, … , 𝑥𝑥𝑁𝑁 and 𝑧𝑧1, … , 𝑧𝑧𝑀𝑀:

𝑉𝑉 𝐺𝐺,𝐷𝐷 =1𝑁𝑁�𝑖𝑖

log𝐷𝐷 𝑥𝑥𝑖𝑖 +1𝑀𝑀�𝑗𝑗

log(1 − 𝐷𝐷(𝐺𝐺 𝑧𝑧𝑖𝑖) )

It is an instance of minimax problems

8

Page 9: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Minimax optimization

General form: min𝑢𝑢∈𝑈𝑈

max𝑣𝑣∈𝑉𝑉

𝑓𝑓(𝑢𝑢, 𝑣𝑣)

Zero-sum leader-follower games 𝑢𝑢: leader/min player, 𝑣𝑣: follower/max player 𝑓𝑓 𝑢𝑢, 𝑣𝑣 : payoff of follower/max player

9

Page 10: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Minimax optimization

General form: min𝑢𝑢∈𝑈𝑈

max𝑣𝑣∈𝑉𝑉

𝑓𝑓(𝑢𝑢, 𝑣𝑣)

Zero-sum leader-follower games 𝑢𝑢: leader/min player, 𝑣𝑣: follower/max player 𝑓𝑓 𝑢𝑢, 𝑣𝑣 : payoff of follower/max player

10

inner maximization

Page 11: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Minimax optimization

General form: min𝑢𝑢∈𝑈𝑈

max𝑣𝑣∈𝑉𝑉

𝑓𝑓(𝑢𝑢, 𝑣𝑣)

Zero-sum leader-follower games 𝑢𝑢: leader/min player, 𝑣𝑣: follower/max player 𝑓𝑓 𝑢𝑢, 𝑣𝑣 : payoff of follower/max player

11

inner maximization

outer minimization

Page 12: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

GAN is JS-divergence minimization

Proposition 1 [Goodfellow et al.,2014]

Given 𝐺𝐺, the optimal discriminator 𝐷𝐷 is 𝐷𝐷∗ 𝐺𝐺 = 𝑃𝑃𝑑𝑑𝑑𝑑𝑑𝑑𝑑𝑑 𝑥𝑥𝑃𝑃𝑑𝑑𝑑𝑑𝑑𝑑𝑑𝑑 𝑥𝑥 +𝑃𝑃𝑓𝑓𝑑𝑑𝑓𝑓𝑓𝑓(𝑥𝑥)

Proof: 𝑉𝑉 𝐺𝐺,𝐷𝐷 = ∫ 𝑃𝑃𝑑𝑑𝑑𝑑𝑑𝑑𝑑𝑑 𝑥𝑥 log𝐷𝐷 𝑥𝑥 + 𝑃𝑃𝑓𝑓𝑑𝑑𝑓𝑓𝑓𝑓 𝑥𝑥 log(1 − 𝐷𝐷 𝑥𝑥 ) 𝑑𝑑𝑥𝑥. For any 𝑎𝑎, 𝑏𝑏 ∈ 𝑅𝑅2 − (0,0), the function 𝑦𝑦 ↦ 𝑎𝑎 log 𝑦𝑦 + 𝑏𝑏 log(1 − 𝑦𝑦)achieves its maximum in [0,1] at 𝑑𝑑

𝑑𝑑+𝑏𝑏. The discriminator does not need to be

defined outside of 𝑆𝑆𝑢𝑢𝑆𝑆𝑆𝑆 𝑃𝑃𝑑𝑑𝑑𝑑𝑑𝑑𝑑𝑑 ∪ 𝑆𝑆𝑢𝑢𝑆𝑆𝑆𝑆 𝑃𝑃𝑓𝑓𝑑𝑑𝑓𝑓𝑓𝑓 , concluding the proof.

12

Page 13: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Cont’d

Def: 𝐽𝐽𝑆𝑆𝐷𝐷(𝑆𝑆| 𝑞𝑞 = 12𝐾𝐾𝐾𝐾(𝑆𝑆|| 𝑝𝑝+𝑞𝑞

2) + 1

2𝐾𝐾𝐾𝐾(𝑞𝑞|| 𝑝𝑝+𝑞𝑞

2)

Theorem 1. max𝐷𝐷

𝑉𝑉(𝐺𝐺,𝐷𝐷) = − log 4 + 2 ⋅ 𝐽𝐽𝑆𝑆𝐷𝐷(𝑃𝑃𝑑𝑑𝑑𝑑𝑑𝑑𝑑𝑑||𝑃𝑃𝑓𝑓𝑑𝑑𝑓𝑓𝑓𝑓)

Corollary: GAN objective is min𝐺𝐺

𝐽𝐽𝑆𝑆𝐷𝐷(𝑃𝑃𝑑𝑑𝑑𝑑𝑑𝑑𝑑𝑑||𝑃𝑃𝑓𝑓𝑑𝑑𝑓𝑓𝑓𝑓), that is, find the generator 𝐺𝐺 so that real and synthetic data are indistinguishable

Proof: max𝐷𝐷

𝑉𝑉 𝐺𝐺,𝐷𝐷 = 𝑉𝑉 𝐺𝐺,𝐷𝐷∗

= 𝐸𝐸𝑥𝑥~𝑃𝑃𝑑𝑑𝑑𝑑𝑑𝑑𝑑𝑑[log𝐷𝐷∗ 𝑥𝑥 ] + 𝐸𝐸𝑥𝑥~𝑃𝑃𝑓𝑓𝑑𝑑𝑓𝑓𝑓𝑓 log 1 − 𝐷𝐷∗ 𝑥𝑥

= 𝐸𝐸𝑥𝑥~𝑃𝑃𝑑𝑑𝑑𝑑𝑑𝑑𝑑𝑑[log𝑃𝑃𝑑𝑑𝑑𝑑𝑑𝑑𝑑𝑑 𝑥𝑥

𝑃𝑃𝑑𝑑𝑑𝑑𝑑𝑑𝑑𝑑 𝑥𝑥 + 𝑃𝑃𝑓𝑓𝑑𝑑𝑓𝑓𝑓𝑓 𝑥𝑥] + 𝐸𝐸𝑥𝑥~𝑃𝑃𝑓𝑓𝑑𝑑𝑓𝑓𝑓𝑓[log

𝑃𝑃𝑓𝑓𝑑𝑑𝑓𝑓𝑓𝑓 𝑥𝑥𝑃𝑃𝑑𝑑𝑑𝑑𝑑𝑑𝑑𝑑 𝑥𝑥 + 𝑃𝑃𝑓𝑓𝑑𝑑𝑓𝑓𝑓𝑓 𝑥𝑥

]

= − log 4 + 𝐾𝐾𝐾𝐾(𝑃𝑃𝑑𝑑𝑑𝑑𝑑𝑑𝑑𝑑||𝑃𝑃𝑑𝑑𝑑𝑑𝑑𝑑𝑑𝑑 + 𝑃𝑃𝑓𝑓𝑑𝑑𝑓𝑓𝑓𝑓

2 ) + 𝐾𝐾𝐾𝐾(𝑃𝑃𝑓𝑓𝑑𝑑𝑓𝑓𝑓𝑓||𝑃𝑃𝑑𝑑𝑑𝑑𝑑𝑑𝑑𝑑 + 𝑃𝑃𝑓𝑓𝑑𝑑𝑓𝑓𝑓𝑓

2 )

13

Page 14: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Optimization by Gradient Descent

Simplest way to solve min (or max) approximately Alternating

𝑢𝑢 ← 𝑢𝑢 − 𝜌𝜌𝛻𝛻𝑢𝑢𝑓𝑓, 𝑣𝑣 ← 𝑣𝑣 + 𝜂𝜂𝛻𝛻𝑣𝑣𝑓𝑓 Simultaneous

𝑢𝑢𝑣𝑣 ← 𝑢𝑢

𝑣𝑣 + −𝜌𝜌𝐼𝐼 00 𝜂𝜂𝐼𝐼 𝛻𝛻𝑓𝑓

GAN optimization (original version) Assume parameters 𝐺𝐺(𝑧𝑧;𝑢𝑢) and 𝐷𝐷(𝑥𝑥;𝑣𝑣) Alternate between 𝒖𝒖 ← 𝒖𝒖 − 𝝆𝝆𝜵𝜵𝒖𝒖𝑽𝑽 and 𝒗𝒗 ← 𝒗𝒗 + 𝜼𝜼𝜵𝜵𝒗𝒗𝑽𝑽 Note that it isn’t exactly solving a minimax problem

14

Page 15: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Gradient descent is unstable

Gradient descent does not always converge

→ Modify gradient descent[Mescheder et al.’17; Nagarajan et al.’17; Roth et al.’17]

15

𝑓𝑓 𝑢𝑢, 𝑣𝑣 = 𝑢𝑢𝑣𝑣

[Usman et al.’18]

Page 16: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Gradient descent is unstable

Gradient descent does not always converge

→ Modify gradient descent[Mescheder et al.’17; Nagarajan et al.’17; Roth et al.’17]

GAN training fails frequently→ Change the GAN objective

[Uehara et al.’16; Nowozin et al.’16;Arjovsky et al.’17]

16

𝑓𝑓 𝑢𝑢, 𝑣𝑣 = 𝑢𝑢𝑣𝑣

[Usman et al.’18]

vs

[Mets et al.’17]

Page 17: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Alternative losses 𝑉𝑉(𝐺𝐺,𝐷𝐷)

Minimax (original) GAN min

𝐺𝐺max𝐷𝐷

𝐸𝐸𝑃𝑃𝑑𝑑𝑑𝑑𝑑𝑑𝑑𝑑[log𝐷𝐷 𝑥𝑥 ] + 𝐸𝐸𝑃𝑃𝑓𝑓𝑑𝑑𝑓𝑓𝑓𝑓[log(1 − 𝐷𝐷(𝑥𝑥))]

Non-saturating GAN max

𝐷𝐷𝐸𝐸𝑃𝑃𝑑𝑑𝑑𝑑𝑑𝑑𝑑𝑑[log𝐷𝐷 𝑥𝑥 ] + 𝐸𝐸𝑃𝑃𝑓𝑓𝑑𝑑𝑓𝑓𝑓𝑓[log(1 − 𝐷𝐷(𝑥𝑥))]

min𝐺𝐺

−𝐸𝐸𝑃𝑃𝑓𝑓𝑑𝑑𝑓𝑓𝑓𝑓[log𝐷𝐷 𝑥𝑥 ]

Wasserstein GAN min

𝐺𝐺max𝐷𝐷

𝐸𝐸𝑃𝑃𝑑𝑑𝑑𝑑𝑑𝑑𝑑𝑑 𝐷𝐷 𝑥𝑥 − 𝐸𝐸𝑃𝑃𝑓𝑓𝑑𝑑𝑓𝑓𝑓𝑓 𝐷𝐷 𝑥𝑥 , 𝐷𝐷 is 1-Lipschitz, real-valued

Least-squares GAN max

𝐷𝐷𝐸𝐸𝑃𝑃𝑑𝑑𝑑𝑑𝑑𝑑𝑑𝑑 𝐷𝐷 𝑥𝑥 − 1 2 − 𝐸𝐸𝑃𝑃𝑓𝑓𝑑𝑑𝑓𝑓𝑓𝑓 𝐷𝐷 𝑥𝑥 2 , 𝐷𝐷 is real-valued

min𝐺𝐺

−𝐸𝐸𝑃𝑃𝑓𝑓𝑑𝑑𝑓𝑓𝑓𝑓[ 𝐷𝐷 𝑥𝑥 − 1 2]

Additional regularizers: ∇𝐷𝐷 𝑥𝑥 2 or ∇𝐷𝐷 𝑥𝑥 − 1 2

17

Page 18: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Progress on face generation

18

Goodfellow.,2019

Page 19: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Progress on ImageNet

19

Page 20: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

GAN variants that use labels

20

https://machinelearningmastery.com/how-to-develop-an-auxiliary-classifier-gan-ac-gan-from-scratch-with-keras/

Page 21: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

GAN vs C-GAN

21

GAN C-GAN

Page 22: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Different from C-GAN No label input to 𝐷𝐷 𝐷𝐷 output both real/fake and class Label info has to be encoded in 𝑥𝑥𝑓𝑓𝑑𝑑𝑓𝑓𝑓𝑓

C-GAN loss max

𝐷𝐷𝐸𝐸𝑃𝑃𝑑𝑑𝑑𝑑𝑑𝑑𝑑𝑑 log𝐷𝐷 𝑥𝑥 𝑐𝑐 + 𝐸𝐸𝑃𝑃𝑓𝑓𝑑𝑑𝑓𝑓𝑓𝑓[log(1 − 𝐷𝐷 𝑥𝑥 𝑐𝑐 )]

min𝐺𝐺

−𝐸𝐸𝑃𝑃𝑓𝑓𝑑𝑑𝑓𝑓𝑓𝑓[log𝐷𝐷(𝑥𝑥|𝑐𝑐)]

AC-GAN loss max

𝐷𝐷⋯ + 𝐸𝐸𝑃𝑃𝑑𝑑𝑑𝑑𝑑𝑑𝑑𝑑 log𝑃𝑃 𝑐𝑐 𝑥𝑥 + 𝐸𝐸𝑃𝑃𝑓𝑓𝑑𝑑𝑓𝑓𝑓𝑓[log𝑃𝑃 (𝑐𝑐|𝑥𝑥)]

min𝐺𝐺

⋯ − 𝐸𝐸𝑃𝑃𝑓𝑓𝑑𝑑𝑓𝑓𝑓𝑓[log𝑃𝑃(𝑐𝑐|𝑥𝑥)]

AC-GAN

22

Page 23: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

InfoGAN

Different from C-GAN Label c of examples is not given, but is random MNIST example, c=[digits (0-9), rotation, thickness]T

Label info has to be encoded in 𝑥𝑥𝑓𝑓𝑑𝑑𝑓𝑓𝑓𝑓, measured by mutual information 𝐼𝐼(𝑐𝑐, 𝑥𝑥𝑓𝑓𝑑𝑑𝑓𝑓𝑓𝑓 = 𝐺𝐺 𝑧𝑧, 𝑐𝑐 )

InfoGAN loss max

𝐷𝐷𝐸𝐸𝑥𝑥~𝑃𝑃𝑑𝑑𝑑𝑑𝑑𝑑𝑑𝑑 log𝐷𝐷(𝑥𝑥) + 𝐸𝐸 𝑧𝑧,𝑐𝑐 ~𝑃𝑃𝑧𝑧,𝑐𝑐[log(1 − 𝐷𝐷(𝐺𝐺 𝑧𝑧, 𝑐𝑐 ))]

+ 𝜆𝜆𝐼𝐼(𝑐𝑐,𝐺𝐺 𝑧𝑧, 𝑐𝑐 ) min

𝐺𝐺,𝑄𝑄−𝐸𝐸 𝑧𝑧,𝑐𝑐 ~𝑃𝑃𝑧𝑧,𝑐𝑐[log𝐷𝐷(𝐺𝐺(𝑧𝑧, 𝑐𝑐))] − 𝜆𝜆𝐼𝐼(𝑐𝑐,𝐺𝐺 𝑧𝑧, 𝑐𝑐 )

Lemma: 𝐼𝐼 𝑐𝑐;𝐺𝐺 𝑧𝑧, 𝑐𝑐 ≥ 𝐸𝐸𝑐𝑐~𝑃𝑃 𝑐𝑐 , 𝑥𝑥~𝐺𝐺 𝑧𝑧,𝑐𝑐 log𝑄𝑄 𝑐𝑐 𝑥𝑥for any 𝑄𝑄(𝑐𝑐|𝑥𝑥).

That is, we don’t need to estimate mutual information directly !!!

23

Page 24: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

24

Page 25: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

25

Page 26: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Image-to-image translation

26

Yi et al.,2019

Page 27: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Pix2pix [Isola et al.,17]

𝐺𝐺(𝑥𝑥): takes an image and outputs another image 𝐷𝐷(𝑥𝑥): use both 𝑥𝑥 and 𝑦𝑦 (similar to C-GAN) Loss

min𝐺𝐺

max𝐷𝐷

𝐸𝐸𝑥𝑥,𝑦𝑦~𝑃𝑃𝑑𝑑𝑑𝑑𝑑𝑑𝑑𝑑 log𝐷𝐷(𝑥𝑥,𝑦𝑦) + 𝐸𝐸 𝑥𝑥,𝑧𝑧 ~𝑃𝑃𝑥𝑥,𝑧𝑧[log(1 − 𝐷𝐷(𝑥𝑥,𝐺𝐺 𝑥𝑥, 𝑧𝑧 ))]+ 𝜆𝜆𝐸𝐸(𝑥𝑥,𝑦𝑦,𝑧𝑧) 𝑦𝑦 − 𝐺𝐺 𝑧𝑧, 𝑥𝑥 1

27

Page 28: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

28

Page 29: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

CycleGAN [Zhu et al.,’17]

Does not require paired data

29

Page 30: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Two domains: 𝑋𝑋 and 𝑌𝑌 Two generators: 𝐹𝐹 and 𝐺𝐺 Two discriminators: 𝐷𝐷𝑋𝑋 and 𝐷𝐷𝑌𝑌𝑇𝑇 Loss: 𝐾𝐾𝐺𝐺𝐺𝐺𝑁𝑁 𝐺𝐺,𝐹𝐹,𝐷𝐷𝑌𝑌,𝑋𝑋,𝑌𝑌 + 𝐾𝐾𝐺𝐺𝐺𝐺𝑁𝑁 𝐹𝐹,𝐷𝐷𝑋𝑋,𝑌𝑌,𝑋𝑋 + 𝜆𝜆𝐾𝐾𝐶𝐶𝑌𝑌𝐶𝐶(𝐺𝐺,𝐹𝐹) where

𝐾𝐾𝐶𝐶𝑌𝑌𝐶𝐶 𝐺𝐺,𝐹𝐹 = 𝐸𝐸𝑥𝑥~𝑃𝑃𝑑𝑑𝑑𝑑𝑑𝑑𝑑𝑑(𝑥𝑥) 𝐹𝐹 𝐺𝐺 𝑥𝑥 − 𝑥𝑥 1 + 𝐸𝐸𝑦𝑦~𝑃𝑃𝑑𝑑𝑑𝑑𝑑𝑑𝑑𝑑(𝑦𝑦) 𝐺𝐺 𝐹𝐹 𝑦𝑦 − 𝑦𝑦 1

Cyclic consistency

30

Page 31: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

31

Page 32: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Medical Imaging [Yi et al.’19]

32

Denoising [Yi et al.’18] Modality transfer [Wolterink et al.’17]

Anomaly detection [Schlegl et al.’17]

Page 33: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Summary

GAN can be extremely useful for data generation Many interesting properties Underlying mechanism not completely understood Optimization still difficult Lacking universal metric of data generation performance Topic of intense ongoing research

33

Page 34: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

References

Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., Courville, A. and Bengio, Y., 2014. Generative adversarial nets, NIPS

Radford, A., Metz, L. and Chintala, S., 2015. Unsupervised representation learning with deep convolutional generative adversarial networks. arXiv preprint arXiv:1511.06434.

Nowozin, S., Cseke, B. and Tomioka, R., 2016. f-gan: Training generative neural samplers using variational divergence minimization, NIPS

Odena, A., 2016. Semi-supervised learning with generative adversarial networks. arXiv preprint arXiv:1606.01583. Salimans, T., Goodfellow, I., Zaremba, W., Cheung, V., Radford, A. and Chen, X., 2016. Improved techniques for training

gans, NIPS Chen, X., Duan, Y., Houthooft, R., Schulman, J., Sutskever, I. and Abbeel, P., 2016. Infogan: Interpretable representation

learning by information maximizing generative adversarial nets, NIPS Arjovsky, M., Chintala, S., and Bottou, L., 2017. Wasserstein gan. arXiv preprint arXiv:1701.07875. Nagarajan, V. and Kolter, J. Z., 2017. Gradient descent gan optimization is locally stable, NIPS Mescheder, L., Nowozin, S., and Geiger, A. 2017. The numerics of gans, NIPS Zhu, J.Y., Park, T., Isola, P. and Efros, A.A., 2017. Unpaired image-to-image translation using cycle-consistent adversarial

networks, CVPR Isola, P., Zhu, J.Y., Zhou, T. and Efros, A.A., 2017. Image-to-image translation with conditional adversarial networks, CVPR Odena, A., Olah, C. and Shlens, J., 2017, August. Conditional image synthesis with auxiliary classifier gans, ICML Gulrajani, I., Ahmed, F., Arjovsky, M., Dumoulin, V. and Courville, A.C., 2017. Improved training of wasserstein gans, NIPS Mao, X., Li, Q., Xie, H., Lau, R.Y., Wang, Z. and Smolley, P., 2017. Least squares generative adversarial networks, CVPR Huang, H., Yu, P.S. and Wang, C., 2018. An introduction to image synthesis with generative adversarial nets. arXiv preprint

arXiv:1803.04469. Yi, X., Walia, E. and Babyn, P., 2019. Generative adversarial network in medical imaging: A review. Medical image analysis

34

Page 35: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Back to generative model

GAN as generative model 𝑥𝑥 is determined by latent variable 𝑧𝑧 by

decoder 𝑥𝑥 = 𝐷𝐷𝐷𝐷𝑐𝑐 𝑧𝑧

𝑆𝑆 𝑥𝑥 = 𝑆𝑆𝑍𝑍(𝐷𝐷𝐷𝐷𝑐𝑐−1 𝑥𝑥 )det 𝑑𝑑𝐷𝐷𝑓𝑓𝑐𝑐𝑑𝑑𝑧𝑧

−1

Difficult to compute

Parametric generative model 𝑥𝑥 is sampled from 𝑆𝑆(𝑥𝑥|𝑧𝑧;𝜃𝜃) by

decoder 𝑥𝑥 ~ 𝐷𝐷𝐷𝐷𝑐𝑐 𝑧𝑧 = 𝑆𝑆 𝑥𝑥 𝑧𝑧; 𝜃𝜃 𝑆𝑆 𝑥𝑥 = ∫ 𝑆𝑆 𝑥𝑥 𝑧𝑧;𝜃𝜃 𝑆𝑆 𝑧𝑧; 𝜃𝜃 𝑑𝑑𝑧𝑧

May be easy or difficult to compute

Caution: These are NOT graphical models !

35

Decoder

𝑆𝑆 𝑧𝑧 = 𝑁𝑁 0, 𝐼𝐼𝑧𝑧

𝑥𝑥 𝑥𝑥 ~ 𝑆𝑆 𝑥𝑥 𝑧𝑧; 𝜃𝜃

Decoder

𝑆𝑆 𝑧𝑧 = 𝑁𝑁 0, 𝐼𝐼𝑧𝑧

𝑥𝑥𝑆𝑆𝑋𝑋 𝑥𝑥 = 𝑆𝑆𝑍𝑍(𝐷𝐷𝐷𝐷𝑐𝑐−1 𝑥𝑥 )

𝑑𝑑𝐷𝐷𝐷𝐷𝑐𝑐𝑑𝑑𝑧𝑧

−1𝑥𝑥 = 𝐷𝐷𝐷𝐷𝑐𝑐(𝑧𝑧)

𝑆𝑆 𝑥𝑥 = ∫ 𝑆𝑆 𝑥𝑥 𝑧𝑧; 𝜃𝜃 𝑆𝑆 𝑧𝑧; 𝜃𝜃 𝑑𝑑𝑧𝑧

Page 36: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Marginal likelihood

Graphic model for generative model Short-hand notations:

𝑆𝑆𝜃𝜃 𝑧𝑧 ≔ 𝑆𝑆 𝑧𝑧;𝜃𝜃 , 𝑆𝑆𝜃𝜃 𝑥𝑥 𝑧𝑧 : = 𝑆𝑆(𝑥𝑥|𝑧𝑧;𝜃𝜃) Marginal likelihood evaluation

𝑆𝑆𝜃𝜃 𝑥𝑥 = ∫ 𝑆𝑆𝜃𝜃 𝑥𝑥 𝑧𝑧 𝑆𝑆𝜃𝜃 𝑧𝑧 𝑑𝑑𝑧𝑧 = 𝐸𝐸𝑧𝑧[𝑆𝑆𝜃𝜃(𝑥𝑥|𝑧𝑧)] Problem

Only in simple cases, (e.g., Gaussian) 𝑆𝑆(𝑥𝑥) is exactly computable All other complex (and therefore interesting) cases, 𝑆𝑆(𝑥𝑥) is computable

only approximately (e.g., by Monte-Carlo integration)

𝑆𝑆𝜃𝜃 𝑥𝑥 = 𝐸𝐸𝑧𝑧 𝑆𝑆𝜃𝜃 𝑥𝑥 𝑧𝑧 ≅ 1𝑁𝑁∑𝑖𝑖 𝑆𝑆𝜃𝜃 𝑥𝑥 𝑧𝑧𝑖𝑖 , 𝑧𝑧𝑖𝑖~𝑆𝑆𝜃𝜃 𝑧𝑧 , 𝑖𝑖𝑖𝑖𝑑𝑑

Better: use high-probability samples 𝑆𝑆𝜃𝜃 𝑥𝑥 ≅ 𝐸𝐸𝑧𝑧~𝑞𝑞𝜙𝜙[𝑆𝑆𝜃𝜃(𝑥𝑥|𝑧𝑧)] ≅ 1

𝑁𝑁∑𝑖𝑖 𝑆𝑆𝜃𝜃 𝑥𝑥 𝑧𝑧𝑖𝑖 , 𝑧𝑧𝑖𝑖~𝑞𝑞𝜙𝜙 𝑧𝑧|𝑥𝑥 , 𝑖𝑖𝑖𝑖𝑑𝑑,

where 𝑞𝑞𝜙𝜙 𝑧𝑧 𝑥𝑥 ≅ 𝑆𝑆𝜃𝜃(𝑧𝑧)

36

𝑥𝑥

𝑧𝑧N

𝜃𝜃

Page 37: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Generative model Assume 𝑆𝑆 𝑧𝑧 = 𝑁𝑁 0, 𝐼𝐼 Decoder (with parameter 𝜃𝜃)

takes 𝑧𝑧 as input and outputsmean 𝑢𝑢(𝑧𝑧) and covariance 𝜎𝜎2 𝑧𝑧

𝑥𝑥 is sampled from 𝑁𝑁(𝑢𝑢 𝑧𝑧 ,𝜎𝜎 𝑧𝑧 𝐼𝐼) 𝑆𝑆(𝑥𝑥, 𝑧𝑧) is not jointly Gaussian.

cf. linear-Gaussian

Variational Autoencoder (VAE) [Kingma & Welling’14]

Encoder

Decoder

𝑥𝑥

𝜇𝜇(𝑥𝑥)

𝑆𝑆 𝑧𝑧 = 𝑁𝑁 0, 𝐼𝐼

𝑞𝑞 𝑧𝑧 𝑥𝑥;𝜙𝜙 = 𝑁𝑁 𝜇𝜇 𝑥𝑥 ,𝜎𝜎2 𝑥𝑥 𝐼𝐼𝜎𝜎(𝑥𝑥)

𝑧𝑧

𝜇𝜇(𝑧𝑧) 𝜎𝜎(𝑧𝑧)

𝑥𝑥

𝑆𝑆𝜃𝜃 𝑥𝑥 𝑧𝑧 = 𝑁𝑁 𝜇𝜇 𝑧𝑧 ,𝜎𝜎2 𝑧𝑧 𝐼𝐼

Generative model

Recognition model

Page 38: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Recognition/inference model Encoder (with parameter 𝜙𝜙)

takes 𝑥𝑥 as input and outputsmean 𝑢𝑢(𝑥𝑥) and covariance 𝜎𝜎2 𝑥𝑥

𝑧𝑧 is sampled from 𝑁𝑁(𝑢𝑢 𝑥𝑥 ,𝜎𝜎2 𝑥𝑥 𝐼𝐼) Note 𝑞𝑞𝜙𝜙 𝑧𝑧 𝑥𝑥 is not 𝑆𝑆𝜃𝜃 𝑧𝑧 𝑥𝑥 Will be explained shortly

Caution VAE looks similar to AE, but

it’s a very different model AE: deterministic 𝑥𝑥 → 𝑧𝑧 → 𝑥𝑥 VAE: probabilistic 𝑥𝑥 → 𝑧𝑧 → 𝑥𝑥

Variational Autoencoder (VAE) [Kingma & Welling’14]

Encoder

Decoder

𝑥𝑥

𝜇𝜇(𝑥𝑥)

𝑆𝑆 𝑧𝑧 = 𝑁𝑁 0, 𝐼𝐼

𝑞𝑞𝜙𝜙 𝑧𝑧 𝑥𝑥 = 𝑁𝑁 𝜇𝜇 𝑥𝑥 ,𝜎𝜎2 𝑥𝑥 𝐼𝐼𝜎𝜎(𝑥𝑥)

𝑧𝑧

𝜇𝜇(𝑧𝑧) 𝜎𝜎(𝑧𝑧)

𝑥𝑥

𝑆𝑆𝜃𝜃 𝑥𝑥 𝑧𝑧 = 𝑁𝑁 𝜇𝜇 𝑧𝑧 ,𝜎𝜎2 𝑧𝑧 𝐼𝐼

Generative model

Recognition model

Page 39: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

(Previously) Proof: EM increases log-likelihood

Start with the log likelihood:log𝑃𝑃(𝑋𝑋;𝜃𝜃) = log𝑃𝑃(𝑋𝑋,𝑍𝑍;𝜃𝜃) − log𝑃𝑃 𝑍𝑍 𝑋𝑋;𝜃𝜃

Take expectation over Z w.r.t. q(Z) on both sides: 𝐾𝐾𝐿𝐿𝑆𝑆 = ∑𝑧𝑧 𝑞𝑞 𝑍𝑍 log𝑃𝑃(𝑋𝑋;𝜃𝜃) = log𝑃𝑃(𝑋𝑋;𝜃𝜃) 𝑅𝑅𝐿𝐿𝑆𝑆 = ∑𝑧𝑧 𝑞𝑞(𝑍𝑍) log𝑃𝑃(𝑋𝑋,𝑍𝑍;𝜃𝜃) − ∑𝑧𝑧 𝑞𝑞 𝑍𝑍 log𝑃𝑃 𝑍𝑍 𝑋𝑋;𝜃𝜃

= ⋯ − 𝑞𝑞 𝑍𝑍 log𝑞𝑞 𝑍𝑍 + 𝑞𝑞 𝑍𝑍 log𝑞𝑞 𝑍𝑍= ∑𝑧𝑧 𝑞𝑞 𝑍𝑍 log 𝑃𝑃(𝑋𝑋,𝑍𝑍;𝜃𝜃)

𝑞𝑞(𝑍𝑍)− ∑𝑧𝑧 𝑞𝑞 𝑍𝑍 log 𝑝𝑝(𝑍𝑍|𝑋𝑋;𝜃𝜃)

𝑞𝑞 𝑍𝑍= 𝑙𝑙 𝑞𝑞,𝜃𝜃 + 𝐾𝐾𝐾𝐾(𝑞𝑞||𝑆𝑆)

Therefore 𝐾𝐾𝐿𝐿𝑆𝑆 = log𝑃𝑃(𝑋𝑋;𝜃𝜃) = 𝑅𝑅𝐿𝐿𝑆𝑆 = 𝑙𝑙 𝑞𝑞, 𝜃𝜃 + 𝐾𝐾𝐾𝐾(𝑞𝑞||𝑆𝑆) ≥ 𝑙𝑙 𝑞𝑞,𝜃𝜃 𝑙𝑙 𝑞𝑞, 𝜃𝜃 is called ELBO: a lower bound of the evidence log𝑆𝑆(𝑋𝑋; 𝜃𝜃) Maximizing 𝑙𝑙 𝑞𝑞,𝜃𝜃 also maximizes the evidence This holds for any 𝑞𝑞 Variational method

39

Page 40: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

ELBO for VAE

Start with the log likelihood:log𝑆𝑆𝜃𝜃(𝑥𝑥) = 𝑙𝑙 𝑞𝑞, 𝜃𝜃 + 𝐾𝐾𝐾𝐾(𝑞𝑞||𝑆𝑆) ≥ 𝑙𝑙(𝑞𝑞,𝜃𝜃)

Let’s choose 𝑞𝑞 to be 𝑞𝑞𝜙𝜙(𝑧𝑧|𝑥𝑥𝑖𝑖) Then,

𝑙𝑙 𝜙𝜙,𝜃𝜃 = 𝐸𝐸𝑞𝑞𝜙𝜙 𝑧𝑧|𝑥𝑥 log 𝑝𝑝𝜃𝜃(𝑥𝑥,𝑧𝑧)𝑞𝑞𝜙𝜙(𝑧𝑧|𝑥𝑥)

= 𝐸𝐸𝑞𝑞𝜙𝜙 𝑧𝑧|𝑥𝑥 log 𝑝𝑝𝜃𝜃(𝑥𝑥|𝑧𝑧)𝑝𝑝𝜃𝜃(𝑧𝑧)𝑞𝑞𝜙𝜙(𝑧𝑧|𝑥𝑥)

= 𝐸𝐸𝑞𝑞𝜙𝜙 𝑧𝑧|𝑥𝑥 log𝑆𝑆𝜃𝜃(𝑥𝑥|𝑧𝑧) + 𝐸𝐸𝑞𝑞𝜙𝜙 𝑧𝑧|𝑥𝑥 log𝑆𝑆𝜃𝜃(𝑧𝑧)𝑞𝑞𝜙𝜙 𝑧𝑧|𝑥𝑥

= 𝐸𝐸𝑞𝑞𝜙𝜙 𝑧𝑧|𝑥𝑥 log𝑆𝑆𝜃𝜃(𝑥𝑥|𝑧𝑧) − 𝐾𝐾𝐾𝐾(𝑞𝑞𝜙𝜙 𝑧𝑧|𝑥𝑥 ||𝑆𝑆𝜃𝜃(𝑧𝑧))likelihood / reconstruction error prior

We maximize this lower bound 𝑙𝑙(𝜙𝜙,𝜃𝜃) instead of the intractable marginal likelihood log𝑆𝑆𝜃𝜃(𝑧𝑧)

40

Page 41: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Likelihood/reconstruction error term

𝐸𝐸𝑞𝑞𝜙𝜙 𝑧𝑧|𝑥𝑥 log𝑆𝑆𝜃𝜃(𝑥𝑥|𝑧𝑧) ≅ 1𝐿𝐿∑𝑙𝑙 log𝑆𝑆𝜃𝜃(𝑥𝑥|𝑧𝑧𝑙𝑙) = 1

𝐿𝐿∑𝑙𝑙

− 𝑥𝑥−𝜇𝜇 𝑧𝑧𝑙𝑙2

𝜎𝜎2(𝑧𝑧𝑙𝑙),

L2 distance between two images 𝑥𝑥 and 𝜇𝜇

Prior term 𝐾𝐾𝐾𝐾(𝑞𝑞𝜙𝜙(𝑧𝑧|𝑥𝑥)||𝑆𝑆𝜃𝜃(𝑧𝑧)) computable in closed form (Was a HW problem) 𝑞𝑞 = 𝑁𝑁(𝜇𝜇 𝑥𝑥 ,𝜎𝜎2 𝑥𝑥 𝐼𝐼) and 𝑆𝑆 = 𝑁𝑁 0, 𝐼𝐼

And therefore 𝐾𝐾𝐾𝐾 ≅ 12∑𝑗𝑗 1 + log 𝜎𝜎𝑗𝑗2 − 𝜇𝜇𝑗𝑗2 − 𝜎𝜎𝑗𝑗2

Note that goodness of the bound is determined by 𝐾𝐾𝐾𝐾 𝑞𝑞𝜙𝜙 𝑧𝑧 𝑥𝑥 ||𝑆𝑆 𝑧𝑧 𝑥𝑥;𝜃𝜃

41

Page 42: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

VAE loss revisited

42

Encoder

Decoder

𝑥𝑥

𝜇𝜇(𝑥𝑥)

𝑆𝑆 𝑧𝑧 = 𝑁𝑁 0, 𝐼𝐼

𝑞𝑞𝜙𝜙 𝑧𝑧 𝑥𝑥 = 𝑁𝑁 𝜇𝜇 𝑥𝑥 ,𝜎𝜎2 𝑥𝑥 𝐼𝐼𝜎𝜎(𝑥𝑥)

𝑧𝑧

𝜇𝜇(𝑧𝑧) 𝜎𝜎(𝑧𝑧)

𝑥𝑥

𝑆𝑆𝜃𝜃 𝑥𝑥 𝑧𝑧 = 𝑁𝑁 𝜇𝜇 𝑧𝑧 ,𝜎𝜎2 𝑧𝑧 𝐼𝐼

Generative model

Recognition model

Likelihood/reconstruction error:𝐸𝐸𝑞𝑞𝜙𝜙 𝑧𝑧|𝑥𝑥 log𝑆𝑆𝜃𝜃(𝑥𝑥|𝑧𝑧)

Prior:− 𝐾𝐾𝐾𝐾(𝑞𝑞𝜙𝜙 𝑧𝑧|𝑥𝑥 ||𝑆𝑆𝜃𝜃(𝑧𝑧))

Page 43: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Training

Reparameterization trick Sampling 𝑧𝑧 from 𝑞𝑞𝜃𝜃(𝑧𝑧|𝑥𝑥) is originally

non-differentiable procedure Trick: 𝑧𝑧 = 𝜇𝜇 𝑥𝑥 + 𝜎𝜎 𝑥𝑥 ∗ 𝜖𝜖, 𝜖𝜖~𝑁𝑁 0, 𝐼𝐼

Use same trick for sampling 𝑥𝑥 from 𝑆𝑆𝜃𝜃(𝑥𝑥|𝑧𝑧)

Can now train the model end-to-end with gradient descent using backpropagation

43

𝜇𝜇(𝑥𝑥)∗ 𝜎𝜎(𝑥𝑥)

𝑧𝑧

𝜖𝜖

+

Encoder

Decoder

𝑥𝑥

𝜇𝜇(𝑥𝑥) 𝜎𝜎(𝑥𝑥)

𝑧𝑧

𝜇𝜇(𝑧𝑧) 𝜎𝜎(𝑧𝑧)

𝑥𝑥

Generative model

Recognition model

Page 44: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Training

Loss𝑙𝑙 𝜙𝜙, 𝜃𝜃 = 𝐸𝐸𝑞𝑞𝜙𝜙 𝑧𝑧|𝑥𝑥 log𝑆𝑆𝜃𝜃(𝑥𝑥|𝑧𝑧) − 𝐾𝐾𝐾𝐾(𝑞𝑞𝜙𝜙 𝑧𝑧|𝑥𝑥 ||𝑆𝑆𝜃𝜃(𝑧𝑧))

likelihood / reconstruction error prior

Repeat Sample M points 𝑥𝑥1,⋯ , 𝑥𝑥𝑀𝑀 Sample M noise 𝜖𝜖1, … , 𝜖𝜖𝑀𝑀 from 𝑁𝑁 0, 𝐼𝐼 Perform forward- and backward-propagation 𝜃𝜃 ← 𝜃𝜃 − ∇𝜃𝜃𝑙𝑙(𝜙𝜙,𝜃𝜃) 𝜙𝜙 ← 𝜙𝜙 − ∇𝜙𝜙𝑙𝑙(𝜙𝜙,𝜃𝜃)

44

Page 45: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Generating data after training is finished

Encoderr

Decoder

𝑥𝑥

𝜇𝜇(𝑥𝑥)

𝑆𝑆 𝑧𝑧 = 𝑁𝑁 0, 𝐼𝐼

𝑞𝑞 𝑧𝑧 𝑥𝑥;𝜙𝜙 = 𝑁𝑁 𝜇𝜇 𝑥𝑥 ,𝜎𝜎2 𝑥𝑥 𝐼𝐼𝜎𝜎(𝑥𝑥)

𝑧𝑧

𝜇𝜇(𝑧𝑧) 𝜎𝜎(𝑧𝑧)

𝑥𝑥

𝑆𝑆𝜃𝜃 𝑥𝑥 𝑧𝑧 = 𝑁𝑁 𝜇𝜇 𝑧𝑧 ,𝜎𝜎2 𝑧𝑧 𝐼𝐼

Generative model

Recognition model

Page 46: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Generating data after training is finished

Kingma & Welling, 2014

Data manifold for 2-d z

z1

z2

Encoderr

Decoder

𝑥𝑥

𝜇𝜇(𝑥𝑥)

𝑆𝑆 𝑧𝑧 = 𝑁𝑁 0, 𝐼𝐼

𝑞𝑞 𝑧𝑧 𝑥𝑥;𝜙𝜙 = 𝑁𝑁 𝜇𝜇 𝑥𝑥 ,𝜎𝜎2 𝑥𝑥 𝐼𝐼𝜎𝜎(𝑥𝑥)

𝑧𝑧

𝜇𝜇(𝑧𝑧) 𝜎𝜎(𝑧𝑧)

𝑥𝑥

𝑆𝑆𝜃𝜃 𝑥𝑥 𝑧𝑧 = 𝑁𝑁 𝜇𝜇 𝑧𝑧 ,𝜎𝜎2 𝑧𝑧 𝐼𝐼

Generative model

Recognition model

Page 47: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Diagonal prior on z 𝑧𝑧 = 𝑧𝑧1, … , 𝑧𝑧𝑑𝑑 ~ 𝑁𝑁 0, 𝐼𝐼 𝑧𝑧𝑖𝑖’s are independent Each 𝑧𝑧𝑖𝑖 encode interpretable factors of variation

47

z1

z2

z1:degree of

smile

z2 :head pose

Page 48: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

More examples

48

32x32 CIFAR-10 Labeled Faces in the Wild

Page 49: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

VAE variants

Using label information Semi-supervised VAE [Kingma et al.,’14][Siddarth et al.,’17]

Structured-output VAE [Sohn et al.,’15]

Mixture of Gaussian VAE [Dilokthanakul et al.,’16][Gosh et al.,’19]

Disentangling with 𝛽𝛽-VAE [Higgins et al.,’17]

𝐸𝐸𝑞𝑞𝜙𝜙 𝑧𝑧 𝑥𝑥 log𝑆𝑆𝜃𝜃 𝑥𝑥 𝑧𝑧 − 𝛽𝛽 ∗ 𝐾𝐾𝐾𝐾(𝑞𝑞𝜙𝜙 𝑧𝑧 𝑥𝑥 ||𝑆𝑆𝜃𝜃 𝑧𝑧 ), 𝛽𝛽 > 1

VQ-VAE [van den Oord et al.,’17], TD-VAE [Gregor et al.,’19]

49

Page 50: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Conditional VAE

Conditional VAE (supervised) Label 𝑐𝑐 observed during training/testing Both encoder and decoder use the label 𝑐𝑐 𝑆𝑆𝜃𝜃 𝑧𝑧 𝑐𝑐 = 𝑁𝑁(0, 𝐼𝐼) : same as before Loss:

log 𝑆𝑆𝜃𝜃 𝑥𝑥|𝑐𝑐 ≥ 𝐸𝐸𝑞𝑞𝜙𝜙 𝑧𝑧|𝑥𝑥,𝑐𝑐 log 𝑆𝑆𝜃𝜃(𝑥𝑥|𝑧𝑧, 𝑐𝑐)

− 𝐾𝐾𝐾𝐾(𝑞𝑞𝜙𝜙 𝑧𝑧 𝑥𝑥, 𝑐𝑐 ||𝑆𝑆𝜃𝜃 𝑧𝑧 𝑐𝑐 ) Example:

(Cf. GAN vs C-GAN)

“Conditional VAE” [Sohn et al.,’15] Learn a map from image 𝑥𝑥 to image 𝑦𝑦

50

𝑧𝑧

𝑥𝑥

𝑐𝑐

Encoder

Decoder

𝑥𝑥

𝜇𝜇(𝑥𝑥, c)

𝑆𝑆𝜃𝜃 𝑧𝑧|𝑐𝑐 = 𝑁𝑁 0, 𝐼𝐼

𝑞𝑞𝜙𝜙 𝑧𝑧 𝑥𝑥, 𝑐𝑐 = 𝑁𝑁 𝜇𝜇 𝑥𝑥, 𝑐𝑐 ,𝜎𝜎2 𝑥𝑥, 𝑐𝑐 𝐼𝐼𝜎𝜎(𝑥𝑥, 𝑐𝑐)

𝜇𝜇(𝑧𝑧, c) 𝜎𝜎(𝑧𝑧, 𝑐𝑐) 𝑆𝑆𝜃𝜃 𝑥𝑥 𝑧𝑧, 𝑐𝑐 = 𝑁𝑁 𝜇𝜇 𝑥𝑥, 𝑐𝑐 ,𝜎𝜎2 𝑥𝑥, 𝑐𝑐 𝐼𝐼

Page 51: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Semi-supervised VAE

Semi-supervised VAE (M2) C-VAE cannot be used with

mixed labeled and unlabeled examples Decoder uses true/predicted label

𝑆𝑆𝜃𝜃 𝑥𝑥 𝑧𝑧, 𝑐𝑐 = 𝑁𝑁 𝜇𝜇 𝑥𝑥, 𝑐𝑐 ,𝜎𝜎2 𝑥𝑥, 𝑐𝑐 𝐼𝐼 Encoder predicts 𝑐𝑐 and 𝑧𝑧

𝑞𝑞𝜙𝜙 𝑧𝑧, 𝑐𝑐 𝑥𝑥 = 𝑞𝑞𝜙𝜙 𝑧𝑧 𝑥𝑥 𝑞𝑞𝜙𝜙 𝑐𝑐 𝑥𝑥𝑞𝑞𝜙𝜙 (𝑧𝑧|𝑥𝑥, 𝑐𝑐) = 𝑁𝑁 𝜇𝜇 𝑥𝑥, 𝑐𝑐 ,𝜎𝜎2 𝑥𝑥, 𝑐𝑐 𝐼𝐼𝑞𝑞𝜙𝜙 𝑐𝑐 𝑥𝑥 = 𝐶𝐶𝑎𝑎𝐶𝐶(𝜋𝜋(𝑥𝑥))

Loss for labeled example log𝑆𝑆𝜃𝜃(𝑥𝑥, 𝑐𝑐) ≥ 𝐸𝐸𝑞𝑞𝜙𝜙(𝑧𝑧|𝑥𝑥,𝑐𝑐) [log𝑆𝑆𝜃𝜃 𝑥𝑥 𝑐𝑐, 𝑧𝑧 +

log𝑆𝑆𝜃𝜃 𝑐𝑐 + log𝑆𝑆 𝑧𝑧 − log𝑞𝑞𝜙𝜙 𝑧𝑧 𝑥𝑥, 𝑐𝑐 ]

Loss for unlabeled example log𝑆𝑆𝜃𝜃(𝑥𝑥) ≥ 𝐸𝐸𝑞𝑞𝜙𝜙(𝑧𝑧,𝑐𝑐|𝑥𝑥) log𝑆𝑆𝜃𝜃 𝑥𝑥 𝑐𝑐, 𝑧𝑧 + log𝑆𝑆𝜃𝜃 𝑦𝑦 + log𝑆𝑆 𝑧𝑧 − log𝑞𝑞𝜙𝜙(𝑧𝑧, 𝑐𝑐|𝑥𝑥)

Several models possible [Kingma et al.,’14][Siddharth et al.,’17], also see https://pyro.ai/examples/ss-vae.html

51

𝑧𝑧

𝑥𝑥

𝑐𝑐

Encoder

Decoder

𝑥𝑥

𝜇𝜇(𝑥𝑥, 𝑐𝑐)

𝑆𝑆𝜃𝜃 𝑧𝑧 = 𝑁𝑁 0, 𝐼𝐼𝑆𝑆𝜃𝜃 𝑐𝑐 = 𝐶𝐶𝑎𝑎𝐶𝐶(𝜋𝜋)𝑞𝑞𝜙𝜙 𝑧𝑧, 𝑐𝑐 𝑥𝑥 = 𝑞𝑞𝜙𝜙 𝑧𝑧 𝑥𝑥 𝑞𝑞𝜙𝜙 𝑐𝑐 𝑥𝑥𝑞𝑞𝜙𝜙(𝑧𝑧|𝑥𝑥, 𝑐𝑐) = 𝑁𝑁 𝜇𝜇 𝑥𝑥, 𝑐𝑐 ,𝜎𝜎2 𝑥𝑥, 𝑐𝑐 𝐼𝐼𝑞𝑞𝜙𝜙 𝑐𝑐 𝑥𝑥 = 𝐶𝐶𝑎𝑎𝐶𝐶(𝜋𝜋(𝑥𝑥))

𝜎𝜎(𝑥𝑥, 𝑐𝑐)

𝜇𝜇(𝑧𝑧, c) 𝜎𝜎(𝑧𝑧, 𝑐𝑐) 𝑆𝑆𝜃𝜃 𝑥𝑥 𝑧𝑧, 𝑐𝑐 = 𝑁𝑁 𝜇𝜇 𝑥𝑥, 𝑐𝑐 ,𝜎𝜎2 𝑥𝑥, 𝑐𝑐 𝐼𝐼

𝜋𝜋(𝑥𝑥)

𝑐𝑐

Page 52: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Mixture of Gaussian VAE for classification

MoG VAE [Gosh et al.,’19]

Label 𝑐𝑐 observed during training only Decoder predicts 𝑥𝑥 from 𝑧𝑧

𝑆𝑆𝜃𝜃 𝑥𝑥, 𝑐𝑐 𝑧𝑧 = 𝑆𝑆𝜃𝜃 𝑥𝑥 𝑧𝑧 𝑆𝑆𝜃𝜃(𝑐𝑐|𝑧𝑧) 𝑆𝑆𝜃𝜃 𝑧𝑧 = ∑𝑐𝑐𝑤𝑤𝑐𝑐 𝑁𝑁(𝜇𝜇𝑐𝑐,𝜎𝜎𝑐𝑐2𝐼𝐼): mixture of Gaussians One Gaussian per class, centered at fixed 𝜇𝜇𝑐𝑐

Losslog 𝑆𝑆𝜃𝜃 𝑥𝑥, 𝑐𝑐 ≥ 𝐸𝐸𝑞𝑞 𝑧𝑧|𝑥𝑥;𝜙𝜙 log 𝑆𝑆𝜃𝜃 𝑥𝑥, 𝑐𝑐 𝑧𝑧

− 𝐾𝐾𝐾𝐾(𝑞𝑞𝜙𝜙 𝑧𝑧 𝑥𝑥 ||𝑆𝑆𝜃𝜃(𝑧𝑧)) Test time: approximate classification

argmaxc

𝑆𝑆𝜃𝜃(𝑐𝑐|𝑥𝑥) = argmaxc

𝑆𝑆𝜃𝜃(𝑥𝑥, 𝑐𝑐)

= argmaxc

∫ 𝑆𝑆𝜃𝜃 𝑥𝑥, 𝑐𝑐 𝑧𝑧 𝑆𝑆𝜃𝜃(𝑧𝑧)𝑑𝑑𝑧𝑧

= ⋯ = argmaxc

∫ 𝑆𝑆𝜃𝜃(𝑧𝑧 𝑥𝑥 𝑆𝑆𝜃𝜃 𝑐𝑐 𝑧𝑧 𝑑𝑑𝑧𝑧

≅ argmaxc

𝐸𝐸𝑞𝑞𝜙𝜙(𝑧𝑧|𝑥𝑥) 𝑆𝑆𝜃𝜃 𝑐𝑐 𝑧𝑧 ≅ argmaxc

1𝑀𝑀∑𝑖𝑖 𝑆𝑆𝜃𝜃 𝑐𝑐 𝑧𝑧𝑖𝑖 (use Monte-Carlo sampling again)

52

𝑧𝑧

𝑥𝑥

𝑐𝑐

Encoder

Decoder

𝑥𝑥

𝜇𝜇(𝑥𝑥)

𝑆𝑆𝜃𝜃 𝑧𝑧|𝑐𝑐 = 𝑁𝑁 𝜇𝜇𝑐𝑐, 𝐼𝐼 where𝜇𝜇𝑐𝑐 is one-hot encoding

𝑆𝑆𝜃𝜃 𝑐𝑐 𝑧𝑧 = 𝑝𝑝𝜃𝜃 𝑧𝑧 𝑐𝑐 𝑝𝑝𝜃𝜃(𝑐𝑐)∑𝑐𝑐 𝑝𝑝𝜃𝜃 𝑧𝑧 𝑐𝑐 𝑝𝑝𝜃𝜃(𝑐𝑐)

𝑞𝑞𝜙𝜙 𝑧𝑧 𝑥𝑥 = 𝑁𝑁 𝜇𝜇 𝑥𝑥 ,𝜎𝜎2 𝑥𝑥 𝐼𝐼𝜎𝜎(𝑥𝑥)

𝜇𝜇(𝑧𝑧) 𝜎𝜎(𝑧𝑧) 𝑆𝑆𝜃𝜃 𝑥𝑥 𝑧𝑧 = 𝑁𝑁 𝜇𝜇 𝑧𝑧 ,𝜎𝜎2 𝑧𝑧 𝐼𝐼

Page 53: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

Mixture of Gaussian VAE for clustering

MoG VAE [Zhao et al.,’19]

Unsupervised learning Label/cluster number 𝑐𝑐 not observed Decoder predicts 𝑥𝑥 from 𝑧𝑧

𝑆𝑆𝜃𝜃 𝑥𝑥, 𝑧𝑧 𝑐𝑐 = 𝑆𝑆𝜃𝜃 𝑥𝑥 𝑧𝑧 𝑆𝑆𝜃𝜃(𝑧𝑧|𝑐𝑐) 𝑆𝑆𝜃𝜃 𝑧𝑧 = ∑𝑐𝑐𝑤𝑤𝑐𝑐 𝑁𝑁(𝜇𝜇𝑐𝑐,𝜎𝜎𝑐𝑐2𝐼𝐼): mixture of Gaussians One Gaussian per class, centered at 𝜇𝜇𝑐𝑐

Encoder predicts 𝑐𝑐 and 𝑧𝑧 from 𝑥𝑥 𝑞𝑞𝜙𝜙 𝑧𝑧, 𝑐𝑐 𝑥𝑥 = 𝑞𝑞𝜙𝜙 𝑧𝑧 𝑥𝑥 𝑞𝑞𝜙𝜙 𝑐𝑐 𝑥𝑥 𝑞𝑞𝜙𝜙 𝑧𝑧 𝑥𝑥 = 𝑁𝑁 𝜇𝜇 𝑥𝑥 ,𝜎𝜎2 𝑥𝑥 𝐼𝐼 𝑞𝑞𝜙𝜙 𝑐𝑐 𝑥𝑥 = 𝐶𝐶𝑎𝑎𝐶𝐶(𝜋𝜋(𝑥𝑥))

Loss

log𝑆𝑆𝜃𝜃 𝑥𝑥 ≥ ∑𝑓𝑓 𝑞𝑞𝜙𝜙 𝑐𝑐𝑓𝑓 𝑥𝑥 log 𝜋𝜋𝑓𝑓

𝑞𝑞𝜙𝜙 𝑐𝑐𝑓𝑓 𝑥𝑥 + ∑𝑓𝑓 𝑞𝑞𝜙𝜙 𝑐𝑐𝑓𝑓 𝑥𝑥 𝑙𝑙𝑓𝑓(𝑥𝑥),

where 𝑙𝑙𝑓𝑓 𝑥𝑥 = 𝐸𝐸𝑞𝑞𝜙𝜙(𝑧𝑧|𝑥𝑥) 𝑆𝑆𝜃𝜃 𝑥𝑥 𝑧𝑧 − 𝐾𝐾𝐾𝐾(𝑞𝑞𝜙𝜙 𝑧𝑧 𝑥𝑥 || 𝑆𝑆𝜃𝜃(𝑧𝑧|𝑐𝑐𝑓𝑓))

53

𝑆𝑆𝜃𝜃 𝑧𝑧|𝑐𝑐 = 𝑁𝑁 𝜇𝜇𝑐𝑐, 𝐼𝐼𝑆𝑆𝜃𝜃 𝑐𝑐 = 𝐶𝐶𝑎𝑎𝐶𝐶 𝜋𝜋

𝑆𝑆𝜃𝜃 𝑐𝑐 𝑧𝑧 = 𝑝𝑝 𝑧𝑧 𝑐𝑐 𝑝𝑝(𝑐𝑐)∑𝑐𝑐 𝑝𝑝 𝑧𝑧 𝑐𝑐 𝑝𝑝(𝑐𝑐)

𝑧𝑧

𝑥𝑥

𝑐𝑐

Encoder

Decoder

𝑥𝑥

𝜇𝜇(𝑥𝑥) 𝑞𝑞𝜙𝜙 𝑧𝑧, 𝑦𝑦 𝑥𝑥 = 𝑞𝑞𝜙𝜙 𝑧𝑧 𝑥𝑥 𝑞𝑞𝜙𝜙 𝑐𝑐 𝑥𝑥𝑞𝑞𝜙𝜙(𝑧𝑧|𝑥𝑥) = 𝑁𝑁 𝜇𝜇 𝑥𝑥 ,𝜎𝜎2 𝑥𝑥 𝐼𝐼𝑞𝑞𝜙𝜙 𝑐𝑐 𝑥𝑥 = 𝐶𝐶𝑎𝑎𝐶𝐶(𝜋𝜋(𝑥𝑥))

𝜎𝜎(𝑥𝑥)

𝜇𝜇(𝑧𝑧) 𝜎𝜎(𝑧𝑧) 𝑆𝑆𝜃𝜃 𝑥𝑥 𝑧𝑧 = 𝑁𝑁 𝜇𝜇 𝑥𝑥 ,𝜎𝜎2 𝑥𝑥 𝐼𝐼

𝜋𝜋(𝑥𝑥)

Page 54: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

VAE vs GAN

Pros of VAE Principled approach to generative data Can use all previous knowledge about probabilistic models Allows inference of 𝑞𝑞𝜙𝜙(𝑧𝑧|𝑥𝑥), can be useful feature representation

for other tasks Cons of VAE

Maximizes lower bound of likelihood. Cannot evaluate the likelihood directly

Samples blurrier and lower quality compared to state-of-the-art (GANs)

54

Page 55: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

VAE+GAN

VAEGAN [Larsen et al.,’16]

Pixelwise likelihood term 𝐸𝐸𝑞𝑞 log 𝑆𝑆 𝑥𝑥 𝑧𝑧 causes blurriness Use GAN to measure whole-image likelihood VAE: ELBO = Lprior + Lpixel-likelihood

VAEGAN: ELBO = Lprior + Ldisc-likelihood + LGAN

Adversarial autoencoder [Makehazani et al.,’14] Unsupervised/supervised Semi-supervised Dimensionality-reduction

Energy-based GAN [Zhao et al.,’17 ], …

55

𝑐𝑐

Page 56: CMPS 6720: Machine Learning Neural Networksjhamm3/papers/Lecture11-Deep-Generative-Mo… · learning problems Conv Nets, VGG, Inception, ResNets, U-Nets, … Classification of 1000-classes

References

Kingma, D., Welling, M. (2014) Auto-Encoding Variational Bayes, ICLR Kingma, D., Rezende, D., Mohamed, S., Welling, M. (2014) Semi-Supervised Learning with Deep Generative

Models, NIPS Sohn, K., Lee, H., & Yan, X. (2015) Learning structured output representation using deep conditional generative

models, NIPS Makhzani, A., Shlens, J., Jaitly, N., Goodfellow, I., & Frey, B. (2015) Adversarial autoencoders. arXiv preprint

arXiv:1511.05644 Dilokthanakul, N., Mediano, P. A., Garnelo, M., Lee, M. C., Salimbeni, H., Arulkumaran, K., & Shanahan, M.

(2016). Deep unsupervised clustering with gaussian mixture variational autoencoders. arXiv preprint arXiv:1611.02648.

Larsen, A., Sønderby, A., Larochelle, H., Winther, O. (2016) Autoencoding beyond pixels using a learned similarity metric, ICML

Zhao, J., Mathieu, M., & LeCun, Y. (2016). Energy-based generative adversarial network. arXiv preprint arXiv:1609.03126.

Siddharth, N., Paige, B., Van de Meent, J. W., Desmaison, A., Goodman, N., Kohli, P., ... & Torr, P. (2017). Learning disentangled representations with semi-supervised deep generative models. NIPS

Burgess, C. P., Higgins, I., Pal, A., Matthey, L., Watters, N., Desjardins, G., & Lerchner, A. (2018). Understanding disentangling in beta-VAE. arXiv preprint arXiv:1804.03599

Zhao, Q., Honnorat, N., Adeli, E., Pfefferbaum, A., Sullivan, E. V., & Pohl, K. M. (2019, June). Variational Autoencoder with Truncated Mixture of Gaussians for Functional Connectivity Analysis, IPMI

Ghosh, P., Losalka, A., & Black, M. J. (2019, July). Resisting adversarial attacks using gaussian mixture variational autoencoders. AAAI

56