Adversarial Learning
TRANSCRIPT

Outline:
- What is a GAN?
- Some mathematical background
- Algorithms for training a GAN
- The Wasserstein GAN
- Conditional GANs
How Do We Sample from a Distribution?

Goal: generate samples $Y \sim p_\theta(y)$.

Parametric methods:
- Some known parametric distribution, e.g. $p_\theta(y) = \frac{1}{z} \exp\{-u(y)\}$
- Markov chain Monte Carlo (MCMC) methods
- Hastings-Metropolis sampling
- Requires a known distribution

Non-parametric methods:
- Provided with samples $y_0, \cdots, y_{K-1}$
- Infer the distribution from the samples
- Generator: $Y = h_\theta(Z)$, where $Z \sim N(0, I)$ is the source of randomness and $Y$ is a random vector with the desired distribution

[Figure: a multivariate random vector in $\mathbb{R}^n$ and its multivariate density in $\mathbb{R}^n$; see https://sdv.dev/Copulas/tutorials/03_Multivariate_Distributions.html]
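The non-parametric idea above can be sketched in a few lines: a generator is just a deterministic map applied to simple noise. A minimal sketch (not from the slides; the affine map and its parameters are illustrative choices), where $h(z) = Az + \mu$ turns $Z \sim N(0, I)$ into a correlated Gaussian:

```python
import numpy as np

# A generator is a deterministic map h applied to a simple noise source.
# Here h(z) = A z + mu turns Z ~ N(0, I) into Y ~ N(mu, A A^T).
rng = np.random.default_rng(0)

mu = np.array([1.0, -2.0])
A = np.array([[1.0, 0.0],
              [0.8, 0.6]])          # so Cov(Y) = A @ A.T

def h(z):
    """Generator: transform standard normal noise into the target distribution."""
    return z @ A.T + mu

Z = rng.standard_normal((100_000, 2))   # source of randomness
Y = h(Z)                                # samples with the desired distribution

print(Y.mean(axis=0))      # close to mu
print(np.cov(Y.T))         # close to A @ A.T
```

A trained GAN generator plays exactly this role, except that $h_\theta$ is a neural network fit to samples rather than a hand-picked affine map.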
Training a Generator

Function of a GAN:
- Generates samples, $\hat y_k$, with the same distribution as the $y_k$.
- Can use drop-outs to generate randomness.

Training algorithm:
- Compare the distributions of the $\hat y_k$ and the $y_k$.
- Feed back parameter corrections to $\theta$.

[Figure: an independent noise source $z_k$ drives the generator $\hat y_k = h_\theta(z_k)$; the training algorithm compares the generated samples $\hat y_k$ with the reference samples $y_k$ and feeds back corrections to $\theta$.]
The Bayes Discriminator

What is the probability that an observation, $y$, is real ("R" = real, "F" = fake)? Use Bayes' rule:

$P(Class = R \mid y) = \frac{p(y \mid Class = R)\, P(Class = R)}{p(y \mid Class = R)\, P(Class = R) + p(y \mid Class = F)\, P(Class = F)}$

- Assuming $p(y \mid Class = F) = p_{\theta_g}(y)$, $p(y \mid Class = R) = p_r(y)$, and $P(Class = R) = P(Class = F) = \frac{1}{2}$, we have that

$P(Class = R \mid y) = \frac{p_r(y)\,\tfrac{1}{2}}{p_r(y)\,\tfrac{1}{2} + p_{\theta_g}(y)\,\tfrac{1}{2}} = \frac{1}{1 + l_{\theta_g}(y)}$

- where $l_{\theta_g}(y) = \frac{p_{\theta_g}(y)}{p_r(y)}$ is the likelihood ratio.

[Figure: the generator $\hat y_k = h_{\theta_g}(z_k)$ produces generated samples from an independent noise source $z_k$, alongside the reference samples $y_k$.]
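For two known 1-D densities, the ideal Bayes discriminator $d(y) = 1/(1 + l(y))$ can be evaluated directly. A small sketch (not from the slides; the Gaussian means and shared variance are illustrative choices):

```python
import numpy as np

# Ideal Bayes discriminator for two known 1-D densities:
# d(y) = 1 / (1 + l(y)), with likelihood ratio l(y) = p_g(y) / p_r(y).
def gaussian_pdf(y, mu, sigma):
    return np.exp(-0.5 * ((y - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

def bayes_discriminator(y, mu_r=0.0, mu_g=2.0, sigma=1.0):
    l = gaussian_pdf(y, mu_g, sigma) / gaussian_pdf(y, mu_r, sigma)
    return 1.0 / (1.0 + l)

y = np.linspace(-3, 5, 9)
print(bayes_discriminator(y))  # near 1 where p_r dominates, near 0 where p_g dominates
```

At the midpoint $y = 1$ the two likelihoods are equal, $l(y) = 1$, and the discriminator outputs exactly $1/2$.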
Implementing a Bayes Discriminator

Bayesian discriminator: train a network $f_{\theta_d}$ so that

$f_{\theta_d}(y) \approx P(Class = R \mid y) = \frac{p_r(y)}{p_r(y) + p_{\theta_g}(y)} = \frac{1}{1 + l_{\theta_g}(y)}$

where

$l_{\theta_g}(y) = \frac{p_{\theta_g}(y)}{p_r(y)}$

is the likelihood ratio.

[Figure: the discriminator scores generated samples, $\hat p_k = f_{\theta_d}(\hat y_k)$, which should be mostly 0s, and reference samples, $p_k = f_{\theta_d}(y_k)$, which should be mostly 1s.]
Bayes Discriminator and the Likelihood Ratio

Bayesian discriminator:

$f_{\theta_d}(y) \approx P(Class = R \mid y) = \frac{p_r(y)}{p_r(y) + p_{\theta_g}(y)} = \frac{1}{1 + l_{\theta_g}(y)}$

[Figure: the generated distribution $p_{\theta_g}(y)$ and the reference distribution $p_r(y)$ plotted versus $y$, with the likelihood ratio $l_{\theta_g}(y)$ below; where $l_{\theta_g}(y) < 1$ the discriminator classifies $y$ as "real", and where $l_{\theta_g}(y) > 1$ it classifies $y$ as "fake".]
Training a Bayes Discriminator

Discriminator loss function:

$\hat D(\theta_g, \theta_d) = \frac{1}{K} \sum_{k=0}^{K-1} \left[ -\log f_{\theta_d}(y_k) - \log\left(1 - f_{\theta_d}(\hat y_k)\right) \right]$

Optimal discriminator parameter:

$\theta_d^* = \arg\min_{\theta_d} \hat D(\theta_g, \theta_d)$

- Results in the ML estimate of the Bayes classifier parameters.

[Figure: reference samples $y_k$ are scored as $p_k = f_{\theta_d}(y_k)$ and fed to CrossEntropy$(p_k, 1)$ (1 = reference); generated samples $\hat y_k = h_{\theta_g}(z_k)$ are scored as $\hat p_k = f_{\theta_d}(\hat y_k)$ and fed to CrossEntropy$(\hat p_k, 0)$ (0 = generated); the two terms sum to $\hat D(\theta_g, \theta_d)$.]
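The empirical discriminator loss above is just binary cross-entropy with label 1 on reference scores and label 0 on generated scores. A minimal sketch (not from the slides; the clipping constant is an illustrative numerical guard):

```python
import numpy as np

# Empirical discriminator loss: binary cross-entropy with label 1 for
# reference scores p_k = f(y_k) and label 0 for generated scores p̂_k = f(ŷ_k).
def discriminator_loss(p_ref, p_gen, eps=1e-12):
    p_ref = np.clip(p_ref, eps, 1 - eps)   # guard against log(0)
    p_gen = np.clip(p_gen, eps, 1 - eps)
    return np.mean(-np.log(p_ref)) + np.mean(-np.log(1.0 - p_gen))

# A near-perfect discriminator (≈1 on real, ≈0 on fake) drives the loss toward 0;
# a maximally confused one (0.5 everywhere) gives -2 log(1/2) ≈ 1.386.
print(discriminator_loss(np.full(4, 0.99), np.full(4, 0.01)))
print(discriminator_loss(np.full(4, 0.5), np.full(4, 0.5)))
```

The confused-discriminator value $-2\log\tfrac{1}{2}$ reappears later as the expected loss at the GAN equilibrium.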
Training a Generator

Big idea: maximize the probability that outputs of the generator are classified as being from the reference distribution.
- $\hat p_k$ should be large
- $L(\hat p_k)$ should be small when $\hat p_k$ is large
- $L(\hat p)$ should be a decreasing function of $\hat p$

Generator loss function:

$\hat G(\theta_g, \theta_d) = \frac{1}{K} \sum_{k=0}^{K-1} L\left( f_{\theta_d}(\hat y_k) \right)$

Optimal generator parameter:

$\theta_g^* = \arg\min_{\theta_g} \hat G(\theta_g, \theta_d)$

[Figure: the noise $z_k$ drives the generator $\hat y_k = h_{\theta_g}(z_k)$; the discriminator output $\hat p_k = f_{\theta_d}(\hat y_k)$ feeds the loss function $L(\hat p_k)$, which sums to $\hat G(\theta_g, \theta_d)$; the loss should encourage $\hat p_k$ to be large.]
Generative Adversarial Network (GAN)*

Generator loss function:

$\hat G(\theta_g, \theta_d) = \frac{1}{K} \sum_{k=0}^{K-1} L\left( f_{\theta_d}(\hat y_k) \right)$

Discriminator loss function:

$\hat D(\theta_g, \theta_d) = \frac{1}{K} \sum_{k=0}^{K-1} \left[ -\log f_{\theta_d}(y_k) - \log\left(1 - f_{\theta_d}(\hat y_k)\right) \right]$

[Figure: the full GAN architecture; generated samples $\hat y_k = h_{\theta_g}(z_k)$ and reference samples $y_k$ are each scored by the discriminator $f_{\theta_d}$; CrossEntropy$(\hat p_k, 0)$ (0 = generated) and CrossEntropy$(p_k, 1)$ (1 = reference) form the discriminator loss $\hat D(\theta_g, \theta_d)$, and $L(\hat p_k)$ forms the generator loss $\hat G(\theta_g, \theta_d)$.]

*Ian J. Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, Yoshua Bengio, "Generative Adversarial Networks", Proc. of the Intern. Conference on Neural Information Processing Systems (NIPS 2014).
GAN: Expected Loss Functions

Generator loss function:

$G(\theta_g, \theta_d) = E\left[ L\left( f_{\theta_d}(\hat Y_{\theta_g}) \right) \right]$

Discriminator loss function:

$D(\theta_g, \theta_d) = E\left[ -\log f_{\theta_d}(Y) \right] + E\left[ -\log\left(1 - f_{\theta_d}(\hat Y_{\theta_g})\right) \right]$

- By the weak and strong laws of large numbers, $\lim_{K\to\infty} \hat G = G$ and $\lim_{K\to\infty} \hat D = D$.

[Figure: the same GAN architecture drawn with random variables; noise $Z$ drives the generator $\hat Y_{\theta_g} = h_{\theta_g}(Z)$, the discriminator scores both $\hat Y_{\theta_g}$ and the reference $Y$, the cross-entropy terms form $D(\theta_g, \theta_d)$, and $L\left(f_{\theta_d}(\hat Y_{\theta_g})\right)$ forms $G(\theta_g, \theta_d)$.]
Generator Loss Function Choices

Option 0: original loss function proposed in [1].
- $L(p) = \log(1 - p)$, so $L(0) = 0$ and $L(1) = -\infty$
- $G(\theta_g, \theta_d) = E\left[ \log\left(1 - f_{\theta_d}(\hat Y_{\theta_g})\right) \right]$
- Presented as the key theoretically grounded approach in the Goodfellow paper.
- Consistent with zero-sum-game Nash equilibrium theory.
- Almost no one uses it.

Option 1: "non-saturating" loss function, i.e., the "$-\log D$ trick".
- $L(p) = -\log p$, so $L(0) = +\infty$ and $L(1) = 0$
- $G(\theta_g, \theta_d) = E\left[ -\log f_{\theta_d}(\hat Y_{\theta_g}) \right]$
- Mentioned as a trick in [1] to keep the training loss from "saturating".
- This is what you get if you use the cross-entropy loss for the generator.
- This is what is commonly done.

[1] Ian J. Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, Yoshua Bengio, "Generative Adversarial Networks", Proc. of the Intern. Conference on Neural Information Processing Systems (NIPS 2014).
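The "saturation" can be seen numerically. Early in training the discriminator confidently rejects generated samples, so $\hat p \approx 0$; there the Option 0 loss has almost no gradient while the non-saturating loss has a large one. A small sketch (my own, not from the slides):

```python
# Gradients of the two generator losses with respect to the discriminator score p:
#   Option 0:  L(p) = log(1 - p)  ->  dL/dp = -1 / (1 - p)
#   Option 1:  L(p) = -log(p)     ->  dL/dp = -1 / p
p = 0.01  # early training: discriminator confidently calls generated samples fake

grad_saturating = -1.0 / (1.0 - p)   # tiny learning signal near p = 0
grad_nonsaturating = -1.0 / p        # strong learning signal near p = 0

print(grad_saturating, grad_nonsaturating)
```

Near $p = 1$ the roles reverse, but that regime (the discriminator fooled) is where we want training to settle anyway.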
GAN Architecture (Non-Saturating)

Generator loss function:

$G(\theta_g, \theta_d) = E\left[ -\log f_{\theta_d}(\hat Y_{\theta_g}) \right]$

Discriminator loss function:

$D(\theta_g, \theta_d) = E\left[ -\log f_{\theta_d}(Y) \right] + E\left[ -\log\left(1 - f_{\theta_d}(\hat Y_{\theta_g})\right) \right]$

- By the weak and strong laws of large numbers, $\lim_{K\to\infty} \hat G = G$ and $\lim_{K\to\infty} \hat D = D$.

[Figure: the same GAN architecture, except that the generator loss is formed as CrossEntropy$(\hat P, 1)$, i.e., the generated scores are compared against the label 1.]
GAN Equilibrium Conditions (Non-Saturating)

We would like to find the solution to:

$\theta_g^* = \arg\min_{\theta_g} G(\theta_g, \theta_d^*) = \arg\min_{\theta_g} E\left[ -\log f_{\theta_d^*}(\hat Y_{\theta_g}) \right]$

$\theta_d^* = \arg\min_{\theta_d} D(\theta_g^*, \theta_d) = \arg\min_{\theta_d} E\left[ -\log f_{\theta_d}(Y) \right] + E\left[ -\log\left(1 - f_{\theta_d}(\hat Y_{\theta_g^*})\right) \right]$

- This is known as a Nash equilibrium.
- We would like it to converge to $l^* = 1$, i.e., generated = reference distribution.

How do we solve this? Will it converge?
Nash Equilibrium with Two Agents*

Agent $G$:
- Controls the parameter $\theta_g$
- Goal is to minimize $G(\theta_g, \theta_d^*)$

Agent $D$:
- Controls the parameter $\theta_d$
- Goal is to minimize $D(\theta_g^*, \theta_d)$

Each agent tries to minimize its own meter:

$G(\theta_g^*, \theta_d^*) = \min_{\theta_g} G(\theta_g, \theta_d^*)$

$D(\theta_g^*, \theta_d^*) = \min_{\theta_d} D(\theta_g^*, \theta_d)$

[Figure: two stick figures, each turning its own knob ($\theta_g$ for the $G$ agent, $\theta_d$ for the $D$ agent) to minimize its own meter.]

*Graphics and art reproduced from "stick figure" Wiki page.
Zero-Sum Game: Special Nash Equilibrium*

Special case when $G(\theta_g, \theta_d) = -D(\theta_g, \theta_d)$: an adversarial relationship.

Agent $G$:
- Goal is to minimize $G(\theta_g, \theta_d^*)$, i.e., to maximize $D(\theta_g, \theta_d^*)$

Agent $D$:
- Goal is to minimize $D(\theta_g^*, \theta_d)$

$D(\theta_g^*, \theta_d^*) = \max_{\theta_g} D(\theta_g, \theta_d^*)$

$D(\theta_g^*, \theta_d^*) = \min_{\theta_d} D(\theta_g^*, \theta_d)$

[Figure: the same two stick figures, but now the $G$ agent turns its knob $\theta_g$ to maximize the meter $D(\theta_g, \theta_d)$ while the $D$ agent turns $\theta_d$ to minimize it.]

*Graphics and art reproduced from "stick figure" Wiki page.
Computing the GAN Equilibrium
- Reparameterizing equations
- Alternating minimization → mode collapse
- Generator loss gradient descent
- Practical convergence issues
Reparameterize Loss Functions

Goal: replace $\theta_g$ and $\theta_d$ with $l$ and $d$.

Generator parameter $l$:
- Generated samples are $\hat Y \sim l(y)\, p_r(y) = p_{\theta_g}(y)$, where $l(y) = \frac{p_{\theta_g}(y)}{p_r(y)}$

Discriminator parameter $d$:
- Discriminator is $d(y) = P(Class = R \mid y)$

Important facts:
- $E[l(Y)] = 1$
- $\Omega_g = \left\{ l : \mathbb{R}^n \to [0, \infty) \text{ such that } E[l(Y)] = 1 \right\}$
- $\Omega_d = \left\{ d : \mathbb{R}^n \to [0, 1] \right\}$
- For any function $h(y)$: $E[h(\hat Y)] = E[h(Y)\, l(Y)]$
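The last fact, $E[h(\hat Y)] = E[h(Y)\,l(Y)]$, is just importance weighting and is easy to check by Monte Carlo. A small sketch (my own, not from the slides; the two unit-variance Gaussians and the test function are illustrative choices):

```python
import numpy as np

# Monte Carlo check of the identity E[h(Ŷ)] = E[h(Y) l(Y)], where
# Y ~ p_r, Ŷ ~ p_g, and l(y) = p_g(y)/p_r(y) (importance weighting).
rng = np.random.default_rng(1)

def pdf(y, mu):  # unit-variance Gaussian density
    return np.exp(-0.5 * (y - mu) ** 2) / np.sqrt(2 * np.pi)

mu_r, mu_g = 0.0, 0.5
h = lambda y: y ** 2                      # any test function

Y = rng.normal(mu_r, 1.0, 1_000_000)      # reference samples
Y_hat = rng.normal(mu_g, 1.0, 1_000_000)  # generated samples
l = pdf(Y, mu_g) / pdf(Y, mu_r)           # likelihood ratio at the reference samples

print(np.mean(h(Y_hat)))      # direct estimate of E[h(Ŷ)]
print(np.mean(h(Y) * l))      # reweighted estimate, E[h(Y) l(Y)]
```

Both estimates approach $E[\hat Y^2] = 1 + 0.5^2 = 1.25$; the same reweighting trick is what lets the losses below be written entirely as expectations over the reference distribution.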
GAN Equilibrium Conditions

We would like to find the solution to:

$l^* = \arg\min_{l \in \Omega_g} G(l, d^*)$

$d^* = \arg\min_{d \in \Omega_d} D(l^*, d)$

- This is known as a Nash equilibrium.
- We would like it to converge to $l^* = 1$, i.e., generated = reference distribution.

How do we do this? Will it converge?
Reparameterized Loss Functions

Generator loss function:

$G(l, d) = E\left[ -\log d(\hat Y) \right] = E\left[ -l(Y) \log d(Y) \right]$

Discriminator loss function:

$D(l, d) = E\left[ -\log d(Y) \right] + E\left[ -\log\left(1 - d(\hat Y)\right) \right] = E\left[ -\log d(Y) \right] + E\left[ -l(Y) \log\left(1 - d(Y)\right) \right]$

Nash equilibrium:

$l^* = \arg\min_{l \in \Omega_g} G(l, d^*)$

$d^* = \arg\min_{d \in \Omega_d} D(l^*, d)$
Method 1: Alternating Minimization

Algorithm:

Repeat {
    $d^* \leftarrow \arg\min_{d \in \Omega_d} D(l^*, d)$
    $l^* \leftarrow \arg\min_{l \in \Omega_g} G(l, d^*)$
}

Discriminator update:

$d^*(y) \leftarrow \frac{1}{1 + l(y)}$

Generator update:

$l^*(y) \leftarrow \delta(y - y')$ where $y' = \arg\max_y d(y)$

Doesn't work!

Problem:
- This is called "mode collapse".
- The generator only produces the single sample that the discriminator likes best.
- Intuition: "We come from France." "I like cheese steaks." "Too good to be true." "Too creepy to be real."
Method 2: Generator Loss Gradient Descent (GLGD)

Algorithm ("GLGD" is my term):

Repeat {
    $d^* \leftarrow \arg\min_{d \in \Omega_d} D(l, d)$
    $l \leftarrow P\left[ l - \alpha \nabla_l G(l, d^*) \right]$   (gradient descent step, then projection onto the valid parameter space)
}

Discriminator update:

$d^*(y) \leftarrow \frac{1}{1 + l(y)}$

Generator update:
- Take a step in the negative direction of the generator loss gradient.
- $P$: projection into the allowed parameter space $\Omega_g$, which enforces $E[l(Y)] = 1$. (This is not an issue in practice.)

Questions/comments:
- Can be applied with a wide variety of generator/discriminator loss functions.
- Does this converge?
- If so, then what (if anything) is being minimized?

*Martin Arjovsky and Leon Bottou, "Towards Principled Methods for Training Generative Adversarial Networks", ICLR 2017.
GLGD Convergence for the Non-Saturating GAN

For the non-saturating GAN, when $d^*(y) = \arg\min_{d \in \Omega_d} D(l, d)$, it can be shown that*

$\nabla_l G(l, d^*) = \nabla_l C(l)$

where

$C(l) = E\left[ \left(1 + l(Y)\right) \log\left(1 + l(Y)\right) \right]$

Conclusions:
- GLGD is really a gradient descent algorithm for the cost function $C(l)$.
- $(1 + x)\log(1 + x)$ is a strictly convex function of $x$.
- Therefore, we know that $C(l)$ has a unique global minimum at $l(y) = 1$.
- However, convergence tends to be slow.

Algorithm (repeated):

Repeat {
    $d^* \leftarrow \arg\min_{d \in \Omega_d} D(l, d)$
    $l \leftarrow P\left[ l - \alpha \nabla_l G(l, d^*) \right]$
}

*This is an equivalent expression to Theorem 2.5 of [1].
[1] Martin Arjovsky and Leon Bottou, "Towards Principled Methods for Training Generative Adversarial Networks", ICLR 2017.
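The claim that $C(l)$ is minimized at $l \equiv 1$ follows from Jensen's inequality over the constraint $E[l(Y)] = 1$, and is easy to sanity-check on a discrete distribution. A small sketch (my own, not from the slides; the support and probabilities are illustrative choices):

```python
import numpy as np

# Discrete sanity check: with reference probabilities p_r and any feasible
# likelihood ratio l (l >= 0, E[l(Y)] = sum_i p_i l_i = 1), the cost
# C(l) = E[(1 + l(Y)) log(1 + l(Y))] is minimized by l = 1.
rng = np.random.default_rng(2)
p = np.array([0.2, 0.3, 0.5])           # reference distribution p_r

def C(l):
    return np.sum(p * (1 + l) * np.log(1 + l))

l_flat = np.ones(3)                     # equilibrium: generated = reference
print(C(l_flat))                        # = 2 log 2 ≈ 1.386

for _ in range(5):
    l = rng.random(3)
    l /= np.sum(p * l)                  # rescale so that E[l(Y)] = 1
    assert C(l) >= C(l_flat)            # every feasible l costs at least as much
```

By strict convexity of $(1+x)\log(1+x)$, equality holds only for $l \equiv 1$, i.e., when the generated and reference distributions coincide.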
More Details on GLGD

Arjovsky and Bottou showed that for the non-saturating GAN*:

$\nabla_l G(l, d^*) = \nabla_l \left[ KL\left(p_l \,\|\, p_r\right) - 2\, JSD\left(p_l \,\|\, p_r\right) \right]$

So, from the previous identities, we have that:

$\nabla_l G(l, d^*) = \nabla_l \left[ 2\, KL\left( \tfrac{p_l + p_r}{2} \,\middle\|\, p_r \right) \right] = \nabla_l E\left[ \left(1 + l(Y)\right) \log\left(1 + l(Y)\right) \right]$

Conclusions:
- GLGD is really a gradient descent algorithm for the cost function $C(l) = 2\, KL\left( \tfrac{p_l + p_r}{2} \,\middle\|\, p_r \right)$.
- $C(l)$ has a unique global minimum at $p_l = p_r$.
- However, convergence tends to be slow.

Algorithm (repeated):

Repeat {
    $d^* \leftarrow \arg\min_{d \in \Omega_d} D(l, d)$
    $l \leftarrow P\left[ l - \alpha \nabla_l G(l, d^*) \right]$
}

*Martin Arjovsky and Leon Bottou, "Towards Principled Methods for Training Generative Adversarial Networks", ICLR 2017.
Convergence of GANs

Generator and discriminator at convergence:
- Reference and generated distributions are the same $\Rightarrow$ $l^*(y) = 1$.
- The discriminator cannot distinguish the distributions $\Rightarrow$ $d^*(y) = \frac{p_r(y)}{p_r(y) + p_l(y)} = \frac{1}{2}$.

At convergence the generated and reference distributions are identical:
- Therefore, the likelihood ratio is $l(y) = 1$;
- The generated (fake) and reference (real) distributions are identical;
- The discriminator assigns a 50/50 probability to either case because they are indistinguishable;
- Then both the generator and discriminator cross-entropy losses are $-\log\frac{1}{2} \approx 0.693$.

In practice, things don't usually work out this well…
Method 2: Practical Algorithm

Algorithm:

Repeat {
    For $N_d$ iterations {
        $B \leftarrow GetMiniBatch$
        $\theta_d \leftarrow \theta_d - \beta \nabla_{\theta_d} \hat D(\theta_g, \theta_d; B)$
    }
    $B \leftarrow GetMiniBatch$
    $\theta_g \leftarrow \theta_g - \alpha \nabla_{\theta_g} \hat G(\theta_g, \theta_d; B)$
}

What you would like to see:

[Figure: loss versus iteration number; the generator loss and one-half the discriminator loss both settle at $-\log\frac{1}{2} \approx 0.693$.]

Looks good, but… this could also result from a discriminator with insufficient capacity.
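The practical alternating algorithm above can be run end to end on a toy problem. A minimal sketch (my own, not from the slides): the generator is an affine map $\hat y = g_0 + g_1 z$, the discriminator a single logistic unit $f(y) = \sigma(wy + b)$, and all gradients are computed by hand; the data distribution and step sizes are illustrative choices.

```python
import numpy as np

# Minimal 1-D GAN trained with the practical alternating minibatch algorithm:
# N_d discriminator gradient steps per generator step (non-saturating loss).
rng = np.random.default_rng(0)
mu_true, sigma_true = 3.0, 0.5          # reference distribution N(3, 0.5^2)

g = np.array([0.0, 1.0])                # generator params [g0, g1]
w, b = 0.0, 0.0                         # discriminator params
alpha, beta, N_d, batch = 0.05, 0.1, 5, 64

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

for _ in range(2000):
    for _ in range(N_d):                # discriminator updates
        y = rng.normal(mu_true, sigma_true, batch)        # reference minibatch
        y_hat = g[0] + g[1] * rng.standard_normal(batch)  # generated minibatch
        p, p_hat = sigmoid(w * y + b), sigmoid(w * y_hat + b)
        # gradient of mean[-log f(y)] + mean[-log(1 - f(ŷ))] w.r.t. (w, b)
        dw = np.mean(-(1 - p) * y) + np.mean(p_hat * y_hat)
        db = np.mean(-(1 - p)) + np.mean(p_hat)
        w, b = w - beta * dw, b - beta * db
    z = rng.standard_normal(batch)      # generator update
    y_hat = g[0] + g[1] * z
    p_hat = sigmoid(w * y_hat + b)
    dyhat = -(1 - p_hat) * w            # d/dŷ of -log f(ŷ)
    g = g - alpha * np.array([np.mean(dyhat), np.mean(dyhat * z)])

print(g[0])   # generated mean, driven toward mu_true = 3
```

Even this tiny example shows the dynamic from the slide: the generator mean is pushed toward the reference mean, while a linear discriminator is too weak to police the variance, illustrating the "insufficient capacity" caveat.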
Failure Mode: Mode Collapse

Algorithm (repeated):

Repeat {
    For $i = 0$ to $N_d - 1$ {
        $\theta_d \leftarrow \theta_d - \beta \nabla_{\theta_d} \hat D(\theta_g, \theta_d)$
    }
    $\theta_g \leftarrow \theta_g - \alpha \nabla_{\theta_g} \hat G(\theta_g, \theta_d)$
}

Sometimes you get mode collapse:

[Figure: loss versus iteration number; the discriminator dominates, with one-half the discriminator loss falling well below $\approx 0.693$ while the generator loss climbs.]

Might be caused by:
- Overfitting by the discriminator
- Insufficient number of discriminator updates
- Insufficient generator capacity
Concept: Wasserstein GAN*

Problem with GAN training using the "$-\log D$ trick":
- Slow and sometimes unstable convergence
- Problems with vanishing gradients

Conjecture:
- The problem is caused by the discriminator function.
- The Bayes classifier is too sensitive and too nonlinear.
- Non-overlapping distributions create vanishing gradients that slow convergence.

Idea: base the discriminator on the Wasserstein distance (i.e., the earth mover's distance).

[Figure: reproduced from the paper.]

*Martin Arjovsky, Soumith Chintala and Leon Bottou, "Wasserstein Generative Adversarial Networks", ICML 2017.
Wasserstein Fundamentals

Based on the Kantorovich-Rubinstein duality (Villani, 2009)*:

$W(p_r, p_l) = \sup_{\|f\|_L \le 1} \left\{ E\left[ f(Y) \right] - E\left[ f(\hat Y_l) \right] \right\}$

- where $\|f\|_L$ is the Lipschitz constant of $f$
- $\|f\|_L \le 1$ is referred to as the 1-Lipschitz condition

*Villani, Cedric. Optimal Transport: Old and New. Grundlehren der mathematischen Wissenschaften. Springer, Berlin, 2009.
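In one dimension the Wasserstein-1 (earth mover) distance has a simple closed form that makes the duality easy to sanity-check: for two empirical distributions with equally many samples, it is the mean absolute difference between the sorted samples. A small sketch (my own, not from the slides):

```python
import numpy as np

# For 1-D empirical distributions with equally many samples, the
# Wasserstein-1 (earth mover) distance is the average absolute
# difference between the sorted samples.
def wasserstein_1d(a, b):
    return np.mean(np.abs(np.sort(a) - np.sort(b)))

rng = np.random.default_rng(3)
y_ref = rng.normal(0.0, 1.0, 100_000)   # reference samples
y_gen = rng.normal(2.0, 1.0, 100_000)   # generated samples

print(wasserstein_1d(y_ref, y_gen))     # ≈ |0 - 2| = 2 for equal-variance Gaussians
```

Unlike the Jensen-Shannon divergence, this distance varies smoothly with the mean shift even when the two distributions barely overlap, which is exactly the property the Wasserstein GAN exploits.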
Wasserstein GAN*

The fundamental result of Arjovsky et al.:
- Define the set $l \in \Omega_g$ as usual, but define $f \in \Omega_d^W$ so that

$\Omega_d^W = \left\{ f : \mathbb{R}^n \to \mathbb{R} \ \text{s.t.}\ \|f\|_L \le 1 \right\}$

- Then define the Wasserstein generator and discriminator loss functions as

$G_W(l, f) = E\left[ -f(\hat Y_l) \right]$

$D_W(l, f) = E\left[ f(\hat Y_l) \right] - E\left[ f(Y) \right]$

Key result from the Arjovsky paper*:

$\nabla_l G_W(l, f^*) = \nabla_l W(p_l, p_r)$

- where $f^* = \arg\min_{f \in \Omega_d^W} D_W(l, f)$

*Martin Arjovsky, Soumith Chintala, and Leon Bottou, "Wasserstein Generative Adversarial Networks", ICML 2017.
Method 3: Wasserstein Algorithm

Algorithm:

Repeat {
    $f^* = \arg\min_{f \in \Omega_d^W} D_W(l, f)$
    $l \leftarrow l - \alpha \nabla_l G_W(l, f^*)$
}

Discriminator update:
- How do we solve the problem of minimizing the discriminator loss with the Lipschitz constraint?
- Answer: we clip the discriminator weights during training.
- Observation: isn't this just regularization of the discriminator DNN??

Generator update:
- Take a step in the negative direction of the generator loss gradient.

*Martin Arjovsky and Leon Bottou, "Towards Principled Methods for Training Generative Adversarial Networks", ICLR 2017.
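One critic update with weight clipping can be sketched concretely. A minimal sketch (my own, not from the slides): the critic is a linear function $f(y) = wy + b$, purely for illustration, and clipping its weights to $[-c, c]$ with $c = 1$ crudely enforces the 1-Lipschitz constraint in one dimension.

```python
import numpy as np

# Sketch of Wasserstein critic training with weight clipping.
# Critic: f(y) = w*y + b; clipping |w| <= c enforces ||f||_L <= 1 in 1-D.
rng = np.random.default_rng(4)
w, b, c, beta = 0.0, 0.0, 1.0, 0.05

y = rng.normal(3.0, 0.5, 256)           # reference minibatch
y_hat = rng.normal(0.0, 1.0, 256)       # generated minibatch

for _ in range(100):
    # D_W(l, f) = E[f(Ŷ)] - E[f(Y)]; its gradient with respect to (w, b):
    dw = np.mean(y_hat) - np.mean(y)
    db = 0.0                             # the b terms cancel in D_W
    w, b = w - beta * dw, b - beta * db
    w = np.clip(w, -c, c)                # weight clipping
    b = np.clip(b, -c, c)

# The critic score gap approximates the Wasserstein distance (here ≈ |3 - 0| = 3).
gap = np.mean(w * y + b) - np.mean(w * y_hat + b)
print(gap)
```

The clipped weight saturates at $w = c = 1$, and the resulting score gap recovers the mean shift between the two distributions, as the duality predicts for 1-Lipschitz test functions.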
Method 3: Wasserstein Practical Algorithm

Algorithm annotations:
- Make sure to get new batches.
- Iterate the discriminator to approximate convergence.
- Minimize the discriminator loss.
- Clip the discriminator weights to approximate the Lipschitz constraint.
- Take a gradient step of the generator loss.

Observations:
- Some people seem to feel the Wasserstein GAN has better convergence.
- However, is this because of the Wasserstein metric?
- Or is it because of the other algorithmic improvements?

*Martin Arjovsky and Leon Bottou, "Towards Principled Methods for Training Generative Adversarial Networks", ICLR 2017.
Conditional Generative Adversarial Network

- Generates samples from the conditional distribution of $Y$ given $X$.
- The discriminator takes $(y_k, x_k)$ input pairs for $k = 0, \cdots, K-1$.

[Figure: the conditional GAN architecture; the generator $\hat y_k = h_{\theta_g}(x_k, z_k)$ takes both the conditioning input $x_k$ and the noise $z_k$; the discriminator scores pairs, $\hat p_k = f_{\theta_d}(\hat y_k, x_k)$ for generated samples (label 0 = generated) and $p_k = f_{\theta_d}(y_k, x_k)$ for pairs $(x_k, y_k)$ drawn from the reference distribution (label 1 = reference); the cross-entropy terms form the discriminator and generator losses $\hat D(\theta_g, \theta_d)$ and $\hat G(\theta_g, \theta_d)$ as before.]
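The conditional generator's signature can be sketched directly: the sample depends on both the conditioning input $x$ and the noise $z$. A minimal sketch (my own, not from the slides; the linear map $\hat y = a x + s z$ and its parameters are illustrative), generating, for each $x$, a draw from the conditional distribution $N(ax, s^2)$:

```python
import numpy as np

# Sketch of a conditional generator ŷ = h(x, z): the output depends on
# both the conditioning input x and the noise z. Here h(x, z) = a*x + s*z,
# so conditioned on x, ŷ ~ N(a*x, s^2).
rng = np.random.default_rng(5)
a, s = 2.0, 0.3

def h(x, z):
    """Conditional generator h_theta(x, z)."""
    return a * x + s * z

x = np.full(100_000, 1.5)               # condition on x = 1.5
y_hat = h(x, rng.standard_normal(x.size))

print(y_hat.mean(), y_hat.std())        # ≈ a*x = 3.0 and s = 0.3
```

In the same spirit, the conditional discriminator simply receives the pair, e.g. by concatenating $x$ and $y$ into one input vector, so that it judges not just whether $y$ looks real but whether it looks real *for that* $x$.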