Adversarial Learning


Page 1: Adversarial Learning

Adversarial Learning
o What is a GAN?
o Some mathematical background
o Algorithms for training a GAN
o The Wasserstein GAN
o Conditional GANs

Page 2: Adversarial Learning

How do we sample from a distribution?

๐‘Œ โˆผ ๐‘! ๐‘ฆ!Parametric methods:

โ€“ Some known parametric distribution:

๐‘! ๐‘ฆ =1๐‘ exp โˆ’๐‘ข ๐‘ฆ

โ€“ Monte Carlo Markov Chain (MCMC) methodsโ€“ Hastings-Metropolis samplingโ€“ Requires a known distribution

!Non-parametric methods:โ€“ Provided with samples ๐‘ฆ", โ‹ฏ , ๐‘ฆ#$%โ€“ Infer distribution from samplesโ€“ Generator

๐‘Œ = โ„Ž" ๐‘


[Figure: a source of randomness $Z \sim N(0, I)$ is mapped by the generator to a random vector with the desired multivariate density in $\Re^N$. Source: https://sdv.dev/Copulas/tutorials/03_Multivariate_Distributions.html]
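As a small illustration (my own sketch, not from the slides), here is the generator view of sampling in Python: a hypothetical map `h_theta` pushes $Z \sim N(0, I)$ forward to samples with a desired distribution.

```python
# Minimal sketch of generator-based sampling, assuming only NumPy.
# h_theta is a hypothetical generator; a trained network would play this role.
import numpy as np

rng = np.random.default_rng(0)

def h_theta(z):
    # An affine map, so Y is Gaussian with a chosen mean and covariance.
    A = np.array([[1.0, 0.5],
                  [0.0, 2.0]])
    b = np.array([3.0, -1.0])
    return z @ A.T + b

Z = rng.standard_normal((10_000, 2))   # source of randomness, Z ~ N(0, I)
Y = h_theta(Z)                         # random vectors with the desired distribution
print(Y.mean(axis=0))                  # approximately b
print(np.cov(Y.T))                     # approximately A A^T
```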

Page 3: Adversarial Learning

Training a Generator

• Function of a GAN:
  – Generates samples, $\hat{y}_k$, with the same distribution as the reference samples $y_k$.
  – Can use drop-outs to generate randomness.

• Training algorithm:
  – Compare the distributions of $\hat{y}_k$ and $y_k$.
  – Feed back parameter corrections.

[Figure: training loop. An independent noise source $z_k$ feeds the generator, $\hat{y}_k = h_\theta(z_k)$; the training algorithm compares the generated samples $\hat{y}_k$ with the reference samples $y_k$ and feeds corrections back to $\theta$.]

Page 4: Adversarial Learning

The Bayes Discriminator

• Use Bayes' rule:
$$P(\mathrm{Class}=R \mid y) = \frac{P(y \mid \mathrm{Class}=R)\, P(\mathrm{Class}=R)}{P(y \mid \mathrm{Class}=R)\, P(\mathrm{Class}=R) + P(y \mid \mathrm{Class}=F)\, P(\mathrm{Class}=F)}$$

• Assuming $P(y \mid \mathrm{Class}=F) = p_{\theta_g}(y)$, $P(y \mid \mathrm{Class}=R) = p_r(y)$, and $P(\mathrm{Class}=R) = P(\mathrm{Class}=F) = \frac{1}{2}$, we have that
$$P(\mathrm{Class}=R \mid y) = \frac{p_r(y)\,\frac{1}{2}}{p_r(y)\,\frac{1}{2} + p_{\theta_g}(y)\,\frac{1}{2}} = \frac{1}{1 + R_{\theta_g}(y)}$$
  – where $R_{\theta_g}(y) = \frac{p_{\theta_g}(y)}{p_r(y)}$ is the likelihood ratio of the generated density to the reference density.

[Figure: the discriminator's question. Independent noise $z_k$ drives the generator, $\hat{y}_k = h_{\theta_g}(z_k)$, which produces "F" (fake) samples, while the reference samples $y_k$ are "R" (real). What is the probability that an observation, $y$, is real?]
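A small numeric illustration of this formula (my own example, not from the slides), using two 1-D Gaussians for the reference and generated densities and equal priors:

```python
import numpy as np

def normal_pdf(y, mu, sigma):
    return np.exp(-0.5 * ((y - mu) / sigma) ** 2) / (sigma * np.sqrt(2.0 * np.pi))

y = np.linspace(-3.0, 5.0, 9)
p_ref = normal_pdf(y, 0.0, 1.0)        # reference ("real") density p_r(y)
p_gen = normal_pdf(y, 2.0, 1.0)        # generated ("fake") density p_{theta_g}(y)

R = p_gen / p_ref                      # likelihood ratio R_{theta_g}(y)
posterior = 1.0 / (1.0 + R)            # P(Class = R | y) for equal priors

for yi, pi in zip(y, posterior):
    print(f"y = {yi:+.1f}   P(real | y) = {pi:.3f}")
```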

Page 5: Adversarial Learning

Implementing a Bayes Discriminator

• Bayesian discriminator:
$$f_{\theta_d}(y) \approx P(\mathrm{Class}=R \mid y) = \frac{p_r(y)}{p_r(y) + p_{\theta_g}(y)} = \frac{1}{1 + R_{\theta_g}(y)}$$
  – where the likelihood ratio is
$$R_{\theta_g}(y) = \frac{p_{\theta_g}(y)}{p_r(y)}$$

[Figure: the discriminator $f_{\theta_d}$ is applied to both streams. On generated samples, $\hat{p}_k = f_{\theta_d}(\hat{y}_k)$ should be mostly 0s; on reference samples, $p_k = f_{\theta_d}(y_k)$ should be mostly 1s.]

Page 6: Adversarial Learning

Bayes Discriminator and the Likelihood Ratio

• Bayesian discriminator:
$$f_{\theta_d}(y) \approx P(\mathrm{Class}=R \mid y) = \frac{p_r(y)}{p_r(y) + p_{\theta_g}(y)} = \frac{1}{1 + R_{\theta_g}(y)}$$

[Figure: plots versus $y$ of the generated distribution $p_{\theta_g}(y)$, the reference distribution $p_r(y)$, and the likelihood ratio $R_{\theta_g}(y)$; where $R_{\theta_g}(y) < 1$ the sample is classified as "real", and where $R_{\theta_g}(y) > 1$ as "fake".]

Page 7: Adversarial Learning

Training a Bayes Discriminator

• Discriminator loss function:
$$\hat{d}(\theta_g, \theta_d) = \frac{1}{K} \sum_{k=0}^{K-1} \left[ -\log f_{\theta_d}(y_k) - \log\left(1 - f_{\theta_d}(\hat{y}_k)\right) \right]$$

• Optimal discriminator parameter:
$$\theta_d^* = \arg\min_{\theta_d} \hat{d}(\theta_g, \theta_d)$$
  – Results in the ML estimate of the Bayes classifier parameters.

[Figure: discriminator training. The discriminator scores generated samples, $\hat{p}_k = f_{\theta_d}(\hat{y}_k)$, and reference samples, $p_k = f_{\theta_d}(y_k)$; the terms CrossEntropy($\hat{p}_k$, 0) and CrossEntropy($p_k$, 1) are summed to form $\hat{d}(\theta_g, \theta_d)$, with 0 = Generated and 1 = Reference.]

Page 8: Adversarial Learning

Training a Generator

[Figure: generator training. Noise $z_k$ drives the generator, $\hat{y}_k = h_{\theta_g}(z_k)$; the discriminator output $\hat{p}_k = f_{\theta_d}(\hat{y}_k)$ feeds the loss function $L(\hat{p}_k)$, which forms the generator loss $\hat{g}(\theta_g, \theta_d)$.]

• Big idea: Maximize the probability that outputs of the generator are classified as being from the reference distribution.
  – $\hat{p}_k$ should be large.
  – $L(\hat{p}_k)$ should be small when $\hat{p}_k$ is large.
  – $L(\hat{p}_k)$ should be a decreasing function of $\hat{p}_k$.

• Generator loss function:
$$\hat{g}(\theta_g, \theta_d) = \frac{1}{K} \sum_{k=0}^{K-1} L\left( f_{\theta_d}(\hat{y}_k) \right)$$

• Optimal generator parameter:
$$\theta_g^* = \arg\min_{\theta_g} \hat{g}(\theta_g, \theta_d)$$
  – The loss $L$ should encourage $\hat{p}_k$ to be large.

Page 9: Adversarial Learning

Generative Adversarial Network (GAN)*

[Figure: the full GAN. Noise $z_k$ drives the generator, $\hat{y}_k = h_{\theta_g}(z_k)$; the discriminator scores generated samples, $\hat{p}_k = f_{\theta_d}(\hat{y}_k)$, and reference samples, $p_k = f_{\theta_d}(y_k)$. CrossEntropy($\hat{p}_k$, 0) and CrossEntropy($p_k$, 1) form the discriminator loss $\hat{d}(\theta_g, \theta_d)$, while $L(\hat{p}_k)$ forms the generator loss $\hat{g}(\theta_g, \theta_d)$; 0 = Generated, 1 = Reference.]

• Generator loss function:
$$\hat{g}(\theta_g, \theta_d) = \frac{1}{K} \sum_{k=0}^{K-1} L\left( f_{\theta_d}(\hat{y}_k) \right)$$

• Discriminator loss function:
$$\hat{d}(\theta_g, \theta_d) = \frac{1}{K} \sum_{k=0}^{K-1} \left[ -\log f_{\theta_d}(y_k) - \log\left(1 - f_{\theta_d}(\hat{y}_k)\right) \right]$$

*Ian J. Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio, "Generative Adversarial Networks," Proc. of the International Conference on Neural Information Processing Systems (NIPS 2014).

Page 10: Adversarial Learning

GAN: Expected Loss Functions

[Figure: the same architecture in random-variable form. $\hat{Y}_{\theta_g} = h_{\theta_g}(Z)$ are the generated samples, and $\hat{P}_{\theta_g} = f_{\theta_d}(\hat{Y}_{\theta_g})$ and $P = f_{\theta_d}(Y)$ are the discriminator outputs; CrossEntropy($\hat{P}$, 0) and CrossEntropy($P$, 1) form $d(\theta_g, \theta_d)$, and $L(\hat{P}_{\theta_g})$ forms $g(\theta_g, \theta_d)$.]

• Generator loss function:
$$g(\theta_g, \theta_d) = E\left[ L\left( f_{\theta_d}(\hat{Y}_{\theta_g}) \right) \right]$$

• Discriminator loss function:
$$d(\theta_g, \theta_d) = E\left[ -\log f_{\theta_d}(Y) \right] + E\left[ -\log\left(1 - f_{\theta_d}(\hat{Y}_{\theta_g})\right) \right]$$
  – By the weak and strong laws of large numbers, $\lim_{K \to \infty} \hat{g} = g$ and $\lim_{K \to \infty} \hat{d} = d$.

Page 11: Adversarial Learning

Generator Loss Function Choices

• Option 0: Original loss function proposed in [1].
  – $L(p) = \log(1 - p)$, so $L(0) = 0$ and $L(1) = -\infty$
  – $g(\theta_g, \theta_d) = E\left[ \log\left(1 - f_{\theta_d}(\hat{Y}_{\theta_g})\right) \right]$
  – Presented as the key theoretically grounded approach in the Goodfellow paper.
  – Consistent with zero-sum game Nash equilibrium theory.
  – Almost no one uses it.

• Option 1: "Non-saturating" loss function, i.e., the "$-\log D$ trick"
  – $L(p) = -\log p$, so $L(0) = +\infty$ and $L(1) = 0$
  – $g(\theta_g, \theta_d) = E\left[ -\log f_{\theta_d}(\hat{Y}_{\theta_g}) \right]$
  – Mentioned as a trick in [1] to keep the training loss from "saturating".
  – This is what you get if you use the cross-entropy loss for the generator.
  – This is what is commonly done (a small numeric comparison follows below).

[1] Ian J. Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio, "Generative Adversarial Networks," Proc. of the International Conference on Neural Information Processing Systems (NIPS 2014).
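A quick numeric comparison of the two options (my own illustration, assuming the discriminator ends in a sigmoid so $p = \sigma(a)$ for a logit $a$): for a generated sample that the discriminator confidently rejects (large negative logit, $p$ near 0), the Option 0 gradient with respect to the logit vanishes, while the non-saturating Option 1 gradient stays near $-1$.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

a = np.array([-6.0, -2.0, 0.0, 2.0])   # discriminator logit on a generated sample
p = sigmoid(a)

# Option 0: L = log(1 - p)  =>  dL/da = -p        (vanishes as p -> 0)
grad_option0 = -p
# Option 1: L = -log(p)     =>  dL/da = p - 1     (stays near -1 as p -> 0)
grad_option1 = p - 1.0

for ai, pi, g0, g1 in zip(a, p, grad_option0, grad_option1):
    print(f"a = {ai:+.1f}  p = {pi:.3f}  dL/da: option 0 = {g0:+.4f}, option 1 = {g1:+.4f}")
```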

Page 12: Adversarial Learning

GAN Architecture (Non-Saturating)

[Figure: the non-saturating GAN. Same block diagram as before, but the generator loss $g(\theta_g, \theta_d)$ is now CrossEntropy($\hat{P}$, 1), i.e., the generated samples are scored against the "Reference" label.]

• Generator loss function:
$$g(\theta_g, \theta_d) = E\left[ -\log f_{\theta_d}(\hat{Y}_{\theta_g}) \right]$$

• Discriminator loss function:
$$d(\theta_g, \theta_d) = E\left[ -\log f_{\theta_d}(Y) \right] + E\left[ -\log\left(1 - f_{\theta_d}(\hat{Y}_{\theta_g})\right) \right]$$
  – By the weak and strong laws of large numbers, $\lim_{K \to \infty} \hat{g} = g$ and $\lim_{K \to \infty} \hat{d} = d$.

Page 13: Adversarial Learning

GAN Equilibrium Conditions (Non-Saturating)

• We would like to find the solution to:
$$\theta_g^* = \arg\min_{\theta_g} g(\theta_g, \theta_d^*) = \arg\min_{\theta_g} E\left[ -\log f_{\theta_d^*}(\hat{Y}_{\theta_g}) \right]$$
$$\theta_d^* = \arg\min_{\theta_d} d(\theta_g^*, \theta_d) = \arg\min_{\theta_d} E\left[ -\log f_{\theta_d}(Y) \right] + E\left[ -\log\left(1 - f_{\theta_d}(\hat{Y}_{\theta_g^*})\right) \right]$$
  – This is known as a Nash equilibrium.
  – We would like it to converge to $R^* = 1$ ⇒ (generated = reference distributions).

• How do we solve this?

• Will it converge?

Page 14: Adversarial Learning

Nash Equilibrium with Two Agents*

!Agent ๐บ:โ€“ Controls parameter ๐œƒ0โ€“ Goal is to minimize ๐‘” ๐œƒ0, ๐œƒ1โˆ—

!Agent ๐ทโ€“ Controls parameter ๐œƒ1โ€“ Goal is to minimize ๐‘‘ ๐œƒ0โˆ—, ๐œƒ1

[Figure: two stick-figure agents. The $G$ agent turns its knob $\theta_g$ to minimize its meter reading $g(\theta_g, \theta_d)$, and the $D$ agent turns its knob $\theta_d$ to minimize its meter reading $d(\theta_g, \theta_d)$. Each agent tries to minimize its own meter. *Graphics and art reproduced from the "stick figure" Wiki page.]

$$g(\theta_g^*, \theta_d^*) = \min_{\theta_g} g(\theta_g, \theta_d^*)$$
$$d(\theta_g^*, \theta_d^*) = \min_{\theta_d} d(\theta_g^*, \theta_d)$$

Page 15: Adversarial Learning

Zero-Sum Game: Special Nash Equilibrium*

!Agent ๐บ:โ€“ Goal is to minimize ๐‘” ๐œƒ0, ๐œƒ1โˆ—

โ€“ Goal is to maximize ๐‘‘ ๐œƒ0, ๐œƒ1โˆ—

!Agent ๐ทโ€“ Goal is to minimize ๐‘‘ ๐œƒ0โˆ—, ๐œƒ1

[Figure: in the zero-sum case, the $G$ agent turns its knob $\theta_g$ to maximize the meter $d(\theta_g, \theta_d)$, while the $D$ agent turns its knob $\theta_d$ to minimize the same meter: an adversarial relationship. *Graphics and art reproduced from the "stick figure" Wiki page.]

• Special case when $g(\theta_g, \theta_d) = -d(\theta_g, \theta_d)$:
$$d(\theta_g^*, \theta_d^*) = \max_{\theta_g} d(\theta_g, \theta_d^*)$$
$$d(\theta_g^*, \theta_d^*) = \min_{\theta_d} d(\theta_g^*, \theta_d)$$

Page 16: Adversarial Learning

Computing the GAN Equilibrium

• Reparameterizing the equations

• Alternating minimization ⇒ mode collapse

• Generator loss gradient descent

• Practical convergence issues

Page 17: Adversarial Learning

Reparameterize Loss Functions

• Goal: Replace $\theta_g$ and $\theta_d$ with $R$ and $f$.

• Generator parameter $R$:
  – The generated samples are distributed as
$$\hat{Y} \sim R(y)\, p_r(y) = p_{\theta_g}(y), \quad \text{where} \quad R(y) = \frac{p_{\theta_g}(y)}{p_r(y)}$$

• Discriminator parameter $f$:
  – The discriminator is
$$f(y) = P(\mathrm{Class}=R \mid y)$$

• Important facts:
  – $E[R(Y)] = 1$
  – $\Omega_g = \left\{ R : \Re^N \to [0, \infty) \ \text{such that} \ E[R(Y)] = 1 \right\}$
  – $\Omega_d = \left\{ f : \Re^N \to [0, 1] \right\}$
  – For any function $h(y)$, $E[h(\hat{Y})] = E[h(Y)\, R(Y)]$
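A small Monte Carlo check of the last identity (my own example with 1-D Gaussian reference and generated densities): averaging $h$ over generated samples should agree with the importance-weighted average $E[h(Y)\,R(Y)]$ over reference samples.

```python
import numpy as np

rng = np.random.default_rng(1)

def normal_pdf(y, mu, sigma):
    return np.exp(-0.5 * ((y - mu) / sigma) ** 2) / (sigma * np.sqrt(2.0 * np.pi))

mu_r, s_r = 0.0, 1.0          # reference density p_r
mu_g, s_g = 1.0, 0.8          # generated density p_{theta_g}

def h(y):                     # any test function
    return y ** 2

Y = rng.normal(mu_r, s_r, 200_000)         # Y ~ p_r
Y_hat = rng.normal(mu_g, s_g, 200_000)     # Y_hat ~ p_{theta_g}

R = normal_pdf(Y, mu_g, s_g) / normal_pdf(Y, mu_r, s_r)   # likelihood ratio R(Y)

print("E[h(Y_hat)]  ~", h(Y_hat).mean())   # direct average over generated samples
print("E[h(Y) R(Y)] ~", (h(Y) * R).mean()) # importance-weighted average over reference samples
print("E[R(Y)]      ~", R.mean())          # should be close to 1
```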

Page 18: Adversarial Learning

GAN Equilibrium Conditions

• We would like to find the solution to:
$$R^* = \arg\min_{R \in \Omega_g} g(R, f^*)$$
$$f^* = \arg\min_{f \in \Omega_d} d(R^*, f)$$
  – This is known as a Nash equilibrium.
  – We would like it to converge to $R^* = 1$ ⇒ (generated = reference distributions).

• How do we do this?

• Will it converge?

Page 19: Adversarial Learning

Reparameterized Loss Functions

• Generator loss function:
$$g(R, f) = E\left[ -\log f(\hat{Y}) \right] = E\left[ -R(Y) \log f(Y) \right]$$

• Discriminator loss function:
$$d(R, f) = E\left[ -\log f(Y) \right] + E\left[ -\log\left(1 - f(\hat{Y})\right) \right] = E\left[ -\log f(Y) \right] + E\left[ -R(Y) \log\left(1 - f(Y)\right) \right]$$

• Nash equilibrium:
$$R^* = \arg\min_{R \in \Omega_g} g(R, f^*)$$
$$f^* = \arg\min_{f \in \Omega_d} d(R^*, f)$$

Page 20: Adversarial Learning

Method 1: Alternating Minimization

• Algorithm:

  Repeat {
      $f^* \leftarrow \arg\min_{f \in \Omega_d} d(R^*, f)$
      $R^* \leftarrow \arg\min_{R \in \Omega_g} g(R, f^*)$
  }   (Doesn't work!)

• Discriminator update:
$$f^*(y) \leftarrow \frac{1}{1 + R(y)}$$

• Generator update:
$$R^*(y) \leftarrow \delta(y - y_0), \quad \text{where} \quad y_0 = \arg\max_y f(y)$$

• Problem:
  – This is called "mode collapse".
  – It only generates the single sample that the discriminator likes best (illustrated in the sketch below).
  – Intuition: "We come from France." "I like cheese steaks." "Too good to be true." "Too creepy to be real."
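Here is a toy 1-D illustration of that failure (my own sketch, exact updates on a discrete grid): the exact generator update puts all of its probability mass on whichever grid point the current discriminator likes best, so the generated distribution never spreads out to match the reference.

```python
import numpy as np

y = np.linspace(-4.0, 4.0, 81)
p_ref = np.exp(-0.5 * y ** 2)
p_ref /= p_ref.sum()                      # discrete reference distribution

p_gen = np.ones_like(y) / y.size          # start from a uniform generated distribution
for it in range(5):
    R = p_gen / p_ref                     # likelihood ratio on the grid
    f = 1.0 / (1.0 + R)                   # exact discriminator update
    best = np.argmax(f)                   # point the discriminator "likes best" (ties -> first)
    p_gen = np.zeros_like(y)              # exact generator update: a delta at that point
    p_gen[best] = 1.0
    print(f"iteration {it}: all generated mass at y = {y[best]:+.1f}")
```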

Page 21: Adversarial Learning

Method 2: Generator Loss Gradient Descent (GLGD)

• Algorithm ("GLGD" is my term):

  Repeat {
      $f^* \leftarrow \arg\min_{f \in \Omega_d} d(R, f)$
      $R \leftarrow R - \alpha P_\Omega \nabla_R\, g(R, f^*)$    (gradient descent step, with projection onto the valid parameter space)
  }

• Discriminator update:
$$f^*(y) \leftarrow \frac{1}{1 + R(y)}$$

• Generator update:
  – Take a step in the negative direction of the generator loss gradient.
  – $P_\Omega$ projects onto the allowed parameter space (this is not an issue in practice):
$$P_\Omega d(y) = d(y) - \frac{E[d(Y)]}{E[p_r(Y)]}\, p_r(y)$$

• Questions/comments:
  – Can be applied with a wide variety of generator/discriminator loss functions.
  – Does this converge?
  – If so, then what (if anything) is being minimized?

*Martin Arjovsky and Leon Bottou, "Towards Principled Methods for Training Generative Adversarial Networks," ICLR 2017.

Page 22: Adversarial Learning

GLGD Convergence for Non-Saturating GAN

*This is an equivalent expression to Theorem 2.5 of [1].
[1] Martin Arjovsky and Leon Bottou, "Towards Principled Methods for Training Generative Adversarial Networks," ICLR 2017.

• For the non-saturating GAN, when $f_R^*(y) = \arg\min_{f \in \Omega_d} d(R, f)$, it can be shown that
$$\nabla_R\, g(R, f^*) = \nabla_R C(R)$$
  where*
$$C(R) = E\left[ \left(1 + R(Y)\right) \log\left(1 + R(Y)\right) \right]$$

• Conclusions:
  – GLGD is really a gradient descent algorithm for the cost function $C(R)$.
  – $(1 + x)\log(1 + x)$ is a strictly convex function of $x$.
  – Therefore, we know that $C(R)$ has a unique global minimum at $R(Y) = 1$.
  – However, convergence tends to be slow.

Repeat {๐‘“โˆ— โ† arg min

)โˆˆ+9๐‘‘ ๐‘…, ๐‘“

๐‘… โ† ๐‘… โˆ’ ๐›ผ๐‘ƒ,โˆ‡%๐‘” ๐‘…, ๐‘“โˆ—}

Page 23: Adversarial Learning

More Details on GLGD

*Martin Arjovsky and Leon Bottou, "Towards Principled Methods for Training Generative Adversarial Networks," ICLR 2017.

• Arjovsky and Bottou showed that for the non-saturating GAN*
$$P_\Omega \nabla_R\, g(R, f^*) = \nabla_R \left[ KL\left(p_R \,\|\, p_r\right) - 2\, JSD\left(p_R \,\|\, p_r\right) \right]$$

• So from the previous identities, we have that:
$$P_\Omega \nabla_R\, g(R, f^*) = \nabla_R\, 2\, KL\left( \left(p_R + p_r\right)/2 \,\middle\|\, p_r \right) = P_\Omega \nabla_R\, E\left[ \left(1 + R(Y)\right) \log\left(1 + R(Y)\right) \right]$$

• Conclusions:
  – GLGD is really a gradient descent algorithm for the cost function
$$C(R) = 2\, KL\left( \left(p_R + p_r\right)/2 \,\middle\|\, p_r \right)$$
  – $C(R)$ has a unique global minimum at $p_R = p_r$.
  – However, convergence tends to be slow.

Repeat {๐‘“โˆ— โ† arg min

)โˆˆ+9๐‘‘ ๐‘…, ๐‘“

๐‘… โ† ๐‘… โˆ’ ๐›ผ๐‘ƒ,โˆ‡%๐‘” ๐‘…, ๐‘“โˆ—}

Page 24: Adversarial Learning

Convergence of GANs

• Generator and discriminator at convergence:
  – The reference and generated distributions are the same ⇒ $R^*(Y) = 1$.
  – The discriminator cannot distinguish the distributions ⇒ $f^*(y) = \frac{1}{1 + 1} = \frac{1}{2}$.

• At convergence the generated and reference distributions are identical.
  – Therefore, the likelihood ratio is $R(y) = 1$;
  – The generated (fake) and reference (real) distributions are identical;
  – The discriminator assigns a 50/50 probability to either case because they are indistinguishable.
  – Then both the generator and discriminator cross-entropy losses are $-\log \frac{1}{2} \approx 0.693$.

• In practice, things don't usually work out this well…

Page 25: Adversarial Learning

Method 2: Practical Algorithm

• Algorithm:

  Repeat {
      For $N_d$ iterations {
          $B \leftarrow \mathrm{GetRandomBatch}$
          $\theta_d \leftarrow \theta_d - \beta \nabla_{\theta_d} \hat{d}(\theta_g, \theta_d; B)$
      }
      $B \leftarrow \mathrm{GetRandomBatch}$
      $\theta_g \leftarrow \theta_g - \alpha \nabla_{\theta_g} \hat{g}(\theta_g, \theta_d; B)$
  }

• What you would like to see:

[Figure: loss versus iteration number; both the generator loss and half the discriminator loss settle near $-\log \frac{1}{2} \approx 0.693$.]

• Looks good, but… it could also result from a discriminator with insufficient capacity.
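A hedged sketch of this loop (assuming PyTorch; the module and batch-loader names are hypothetical, and the generator uses the non-saturating loss discussed earlier):

```python
import torch
import torch.nn.functional as F

def train_gan(h_theta_g, f_theta_d, get_random_batch, n_iters=10_000,
              n_d=1, alpha=2e-4, beta=2e-4, z_dim=64):
    """Practical GAN loop: n_d discriminator steps per generator step (sketch only)."""
    opt_g = torch.optim.Adam(h_theta_g.parameters(), lr=alpha, betas=(0.5, 0.999))
    opt_d = torch.optim.Adam(f_theta_d.parameters(), lr=beta, betas=(0.5, 0.999))

    for _ in range(n_iters):
        for _ in range(n_d):                          # discriminator inner loop
            y_real = get_random_batch()               # B <- GetRandomBatch
            z = torch.randn(y_real.shape[0], z_dim)
            y_fake = h_theta_g(z).detach()
            p_real, p_fake = f_theta_d(y_real), f_theta_d(y_fake)
            d_hat = (F.binary_cross_entropy(p_real, torch.ones_like(p_real)) +
                     F.binary_cross_entropy(p_fake, torch.zeros_like(p_fake)))
            opt_d.zero_grad(); d_hat.backward(); opt_d.step()

        y_real = get_random_batch()                   # fresh batch for the generator step
        z = torch.randn(y_real.shape[0], z_dim)
        p_fake = f_theta_d(h_theta_g(z))
        g_hat = F.binary_cross_entropy(p_fake, torch.ones_like(p_fake))   # -log f on fakes
        opt_g.zero_grad(); g_hat.backward(); opt_g.step()
```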

Page 26: Adversarial Learning

Failure Mode: Mode Collapse

• Algorithm (repeated):

  Repeat {
      For $i = 0$ to $N_d - 1$ {
          $\theta_d \leftarrow \theta_d - \beta \nabla_{\theta_d} \hat{d}(\theta_g, \theta_d)$
      }
      $\theta_g \leftarrow \theta_g - \alpha \nabla_{\theta_g} \hat{g}(\theta_g, \theta_d)$
  }

• Sometimes you get mode collapse:

[Figure: loss versus iteration number; the generator loss climbs away from $-\log \frac{1}{2} \approx 0.693$ while the discriminator dominates.]

• Might be caused by:
  – Overfitting by the discriminator
  – Insufficient number of discriminator updates
  – Insufficient generator capacity

Page 27: Adversarial Learning

Concept: Wasserstein GAN

*Martin Arjovsky, Soumith Chintala, and Leon Bottou, "Wasserstein Generative Adversarial Networks," ICML 2017.

• Problem with GAN training using the "$-\log D$ trick":
  – Slow and sometimes unstable convergence
  – Problems with vanishing gradients

• Conjecture:
  – The problem is caused by the discriminator function.
  – The Bayes classifier is:
    • Too sensitive and too nonlinear
    • Non-overlapping distributions create vanishing gradients that slow convergence.

• Base the discriminator on the Wasserstein distance (i.e., the earth mover distance).

[Figure reproduced from the paper.]

Page 28: Adversarial Learning

Wasserstein Fundamentals

• Based on the Kantorovich-Rubinstein duality (Villani, 2009)*:
$$W\left(p_r \,\|\, p_R\right) = \sup_{\|f\|_L \le 1} \; E\left[ f(Y) \right] - E\left[ f(\hat{Y}_R) \right]$$
  – where $\|f\|_L$ is the Lipschitz constant of $f$
  – $\|f\|_L \le 1$ is referred to as the 1-Lipschitz condition

*Villani, Cedric. Optimal Transport: Old and New. Grundlehren der mathematischen Wissenschaften. Springer, Berlin, 2009.

Page 29: Adversarial Learning

Wasserstein GAN*

• The fundamental result of Arjovsky et al.:
  – Define the set $R \in \Omega_g$ as usual, but define $f \in \Omega_d^W$ so that
$$\Omega_d^W = \left\{ f : \Re^N \to (-\infty, \infty) \ \text{s.t.} \ \|f\|_L \le 1 \right\}$$
  – Then define the Wasserstein generator and discriminator loss functions as
$$g_W(R, f) = E\left[ -f(\hat{Y}_R) \right]$$
$$d_W(R, f) = E\left[ f(\hat{Y}_R) \right] - E\left[ f(Y) \right]$$

• Key result from the Arjovsky paper*:
$$\nabla_R\, g_W(R, f^*) = \nabla_R\, W\left(p_r \,\|\, p_R\right)$$
  – where
$$f^* = \arg\min_{f \in \Omega_d^W} d_W(R, f)$$

*Martin Arjovsky, Soumith Chintala, and Leon Bottou, "Wasserstein Generative Adversarial Networks," ICML 2017.

Page 30: Adversarial Learning

Method 3: Wasserstein Algorithm

• Algorithm:

  Repeat {
      $f^* \leftarrow \arg\min_{f \in \Omega_d^W} d_W(R, f)$
      $R \leftarrow R - \alpha P_\Omega \nabla_R\, g_W(R, f^*)$
  }

• Discriminator update:
  – How do we solve the problem of minimizing the discriminator loss with the Lipschitz constraint?
  – Answer: We clip the discriminator weights during training.
  – Observation: Isn't this just regularization of the discriminator DNN?

• Generator update:
  – Take a step in the negative direction of the generator loss gradient.

*Martin Arjovsky and Leon Bottou, "Towards Principled Methods for Training Generative Adversarial Networks," ICLR 2017.

Page 31: Adversarial Learning

Method 3: Wasserstein Practical Algorithm

• Algorithm highlights (annotations on the practical pseudocode):
  – Make sure to get new batches.
  – Iterate the discriminator to approximate convergence.
  – Minimize the discriminator loss.
  – Clip the discriminator weights to approximate the Lipschitz constraint.
  – Take a gradient step of the generator loss.

• Observations:
  – Some people seem to feel the Wasserstein GAN has better convergence.
  – However, is this because of the Wasserstein metric?
  – Or is it because of the other algorithmic improvements?

*Martin Arjovsky and Leon Bottou, "Towards Principled Methods for Training Generative Adversarial Networks," ICLR 2017.

Page 32: Adversarial Learning

Conditional Generative Adversarial Network

• Generates samples from the conditional distribution of $Y$ given $X$.

• The discriminator takes $(y_k, x_k)$ input pairs for $k = 0, \cdots, K-1$.

[Figure: conditional GAN. The generator now takes the conditioning input as well as noise, $\hat{y}_k = h_{\theta_g}(x_k, z_k)$, and the discriminator scores pairs: $\hat{p}_k = f_{\theta_d}(\hat{y}_k, x_k)$ for generated samples (0 = Generated) and $p_k = f_{\theta_d}(y_k, x_k)$ for reference samples (1 = Reference), feeding the same cross-entropy losses $\hat{d}(\theta_g, \theta_d)$ and $g(\theta_g, \theta_d)$ as before.]
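A minimal sketch of how the conditioning might be wired in (assuming PyTorch; the class names, layer sizes, and concatenation-based conditioning are my own choices, not prescribed by the slides). Training then proceeds exactly as in the unconditional case, but every discriminator call receives the pair $(y, x)$.

```python
import torch
import torch.nn as nn

class CondGenerator(nn.Module):
    """Hypothetical conditional generator: y_hat = h_theta_g(x, z)."""
    def __init__(self, x_dim, z_dim, y_dim, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(x_dim + z_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, y_dim))

    def forward(self, x, z):
        return self.net(torch.cat([x, z], dim=1))    # condition by concatenation

class CondDiscriminator(nn.Module):
    """Hypothetical conditional discriminator: p = f_theta_d(y, x) in (0, 1)."""
    def __init__(self, x_dim, y_dim, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(x_dim + y_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, 1), nn.Sigmoid())

    def forward(self, y, x):
        return self.net(torch.cat([y, x], dim=1))
```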