
  • Automated Inference with Adaptive Batches

    Soham De

    Joint work with Abhay Yadav, David Jacobs, Tom Goldstein

    University of Maryland

  • OVERVIEW

    Most machine learning models use SGD for training

    BUT… noisy gradients & large variance

    Hard to automate stepsize selection & stopping conditions

    In this work: Big Batch SGD

    Adaptively grow the batch size based on the amount of noise in the gradients

    Easy automated stepsize selection

  • MOST MODEL FITTING PROBLEMS LOOK LIKE THIS

    \min_x \ell(x) := \frac{1}{N} \sum_{i=1}^{N} f(x, z_i)

    Applications: SVM, logistic regression, matrix factorization, neural nets, …

    We also consider the more general problem:

    \min_x \ell(x) := \mathbb{E}_{z \sim p}[f(x, z)]

  • SGD

    Each iteration: select data, compute gradient, update.

    With the full gradient:

    g_t = \frac{1}{N} \sum_{i=1}^{N} \nabla f(x_t, z_i), \qquad x_{t+1} = x_t - \alpha_t g_t

    SGD instead uses a cheap, noisy estimate from a randomly selected sample:

    g_t \approx \nabla f(x_t, z_i), \qquad x_{t+1} = x_t - \alpha_t g_t

    Far from optimal: the noisy steps still improve the solution.

    Close to optimal: the noise dominates and the solution gets worse.

    The error must decrease as we approach the solution.

    Classical solution: shrink the stepsize so that $\lim_{t \to \infty} \alpha_t = 0$, but the stepsize schedule is hard to pick.
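    For concreteness, a minimal NumPy sketch of this classical recipe is below. The per-sample gradient function grad_f and the particular 1/t decay constants are illustrative assumptions, not part of the talk.

    import numpy as np

    def sgd(grad_f, x0, data, alpha0=0.1, decay=1e-3, epochs=10, seed=0):
        """Plain SGD with a hand-tuned decaying stepsize (the schedule Big Batch SGD avoids)."""
        rng = np.random.default_rng(seed)
        x, t = x0.copy(), 0
        for _ in range(epochs):
            for i in rng.permutation(len(data)):
                alpha_t = alpha0 / (1.0 + decay * t)   # alpha_t -> 0 as t -> infinity
                x = x - alpha_t * grad_f(x, data[i])   # noisy single-sample gradient step
                t += 1
        return x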

  • USE BIGGER BATCHES?

    Each iteration: select data, compute gradient, update, now with the gradient on a batch B:

    g_t \approx \nabla \ell_B(x_t), \qquad x_{t+1} = x_t - \alpha_t g_t

    Lower variance, so the solution keeps getting better even close to the optimum.

  • GROWING BATCHES

    Regime 1: far from optimal

      Noisy gradients improve the solution.
      Large batches do unnecessary work + get stuck in local minima.
      Small batches work well!

    Regime 2: close to optimal

      Noisy gradients with high variance worsen the solution.
      Small batches require stepsize decay (hard to tune).
      Large batches work well!

    Adaptively grow batches over time!

  • PRELIMINARIES

    Standard result in stochastic optimization:

    Lemma: a sufficient condition for $-\nabla \ell_B(x)$ to be a descent direction is

    \| \nabla \ell_B(x) - \nabla \ell(x) \|^2 < \| \nabla \ell_B(x) \|^2

    i.e. the error (left-hand side) is small relative to the approximate gradient (right-hand side).

    How big is this error? How large does the batch need to be to guarantee this?

  • ERROR BOUND

    Theorem: assume $f$ has $L_z$-Lipschitz dependence on the data $z$. Then the expected error is uniformly bounded by:

    \mathbb{E}_B \| \nabla \ell_B(x) - \nabla \ell(x) \|^2
      = \operatorname{Tr} \operatorname{Var}_B(\nabla \ell_B(x))
      \le \frac{1}{|B|} \cdot 4 L_z^2 \operatorname{Tr} \operatorname{Var}_z(z)

    Expected error (the batch-gradient variance) is at most the data variance divided by the batch size: a higher batch size means lower gradient variance.

    From the Lemma and the Theorem, we expect $-\nabla \ell_B(x)$ to be a descent direction reasonably often provided

    \theta^2 \, \mathbb{E} \| \nabla \ell_{B_t}(x_t) \|^2 \ \ge\ \frac{1}{|B_t|} \operatorname{Tr} \operatorname{Var}_z \nabla f(x_t, z),
    \qquad \text{where } \theta \in (0, 1).

    We use this observation in Big Batch SGD.
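    As a rough illustration of how this test can be carried out in practice, the sketch below computes per-sample gradients on a batch, forms the batch gradient and the trace of the sample covariance, and checks the condition above. The helper name descent_test and the gradient function grad_f are hypothetical; this is a sketch of the idea, not the paper's exact estimator.

    import numpy as np

    def descent_test(grad_f, x, batch, theta=0.1):
        """Check theta^2 * ||g_B||^2 >= (1/|B|) * Tr Var_z grad f(x, z), estimated from the batch."""
        grads = np.stack([grad_f(x, z) for z in batch])   # |B| x d matrix of per-sample gradients
        g_B = grads.mean(axis=0)                          # batch gradient, approximates grad l_B(x)
        trace_var = grads.var(axis=0, ddof=1).sum()       # trace of the sample covariance
        signal = theta**2 * float(np.dot(g_B, g_B))
        noise = trace_var / len(batch)                    # estimated error of the batch gradient
        return signal >= noise, g_B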

  • BIG BATCH SGD

    Pick a batch B and estimate the size of the gradient error using the variance of the batch gradients.

    If signal > noise (gradient vs. variance): update.

    Otherwise: increase the batch size.

  • BIG BATCH SGD

    On each iteration t:

    • Estimate the size of the gradient error by computing the variance (the extra step relative to SGD).

    • Pick a batch $B_t$ large enough that

      \theta^2 \, \mathbb{E} \| \nabla \ell_{B_t}(x_t) \|^2 \ \ge\ \frac{1}{|B_t|} \operatorname{Tr} \operatorname{Var}_z \nabla f(x_t, z),
      \qquad \text{where } \theta \in (0, 1)

      (the right-hand side can be estimated using the batch $B_t$ itself).

    • Choose a stepsize $\alpha_t$.

    • Update:  x_{t+1} = x_t - \alpha_t \nabla \ell_{B_t}(x_t)
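    A minimal sketch of this loop follows, assuming a hypothetical per-sample gradient function grad_f and a fixed stepsize alpha; the initial batch size and the constant growth factor are illustrative choices rather than values from the talk.

    import numpy as np

    def big_batch_sgd(grad_f, x0, data, alpha=0.1, theta=0.1, b0=32, growth=1.1, iters=1000, seed=0):
        """Adaptive-batch SGD: grow |B_t| until theta^2*||g_B||^2 >= TrVar/|B_t|, then take a step."""
        rng = np.random.default_rng(seed)
        x, batch_size = x0.copy(), b0
        for _ in range(iters):
            while True:
                idx = rng.choice(len(data), size=min(batch_size, len(data)), replace=False)
                grads = np.stack([grad_f(x, data[i]) for i in idx])
                g_B = grads.mean(axis=0)
                noise = grads.var(axis=0, ddof=1).sum() / len(idx)       # estimated gradient error
                if theta**2 * float(np.dot(g_B, g_B)) >= noise or batch_size >= len(data):
                    break                                                # signal > noise: accept batch
                batch_size = int(np.ceil(growth * batch_size))           # otherwise: increase batch size
            x = x - alpha * g_B                                          # update with the batch gradient
        return x

    In the full method, alpha itself would be chosen automatically (e.g. by backtracking or curvature estimates) rather than fixed, as described in the following slides.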

  • CONVERGENCE

    Assumption: $\ell$ has $L$-Lipschitz gradients,

    \ell(x) \le \ell(y) + \nabla \ell(y)^T (x - y) + \frac{L}{2} \| x - y \|^2

    Assumption: $\ell$ satisfies the Polyak-Lojasiewicz (PL) inequality,

    \| \nabla \ell(x) \|^2 \ge 2 \mu \, (\ell(x) - \ell(x^\star))

    Theorem: Big Batch SGD converges linearly,

    \mathbb{E}[\ell(x_{t+1}) - \ell(x^\star)] \le \Bigl(1 - \frac{\mu}{\gamma L}\Bigr) \, \mathbb{E}[\ell(x_t) - \ell(x^\star)],

    with optimal step size $\alpha = \frac{1}{\gamma L}$, where $\gamma = \frac{\theta^2 + (1-\theta)^2}{(1-\theta)^2}$.

    This is a per-iteration convergence rate.

    We also show that the error falls below $\epsilon$ once $|B| > O(1/\epsilon)$: optimal convergence in the infinite-data case.
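    For intuition, unrolling this per-iteration rate gives the usual iteration count for linear convergence; the step below is a standard consequence of the bound above (using $1 - u \le e^{-u}$), written out here for convenience.

    \mathbb{E}[\ell(x_t) - \ell(x^\star)]
      \le \Bigl(1 - \frac{\mu}{\gamma L}\Bigr)^{t} \bigl(\ell(x_0) - \ell(x^\star)\bigr)
      \le \epsilon
    \quad \text{once} \quad
    t \ \ge\ \frac{\gamma L}{\mu} \, \log \frac{\ell(x_0) - \ell(x^\star)}{\epsilon}.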

  • ADVANTAGES

    • Controlling the noise enables automated stepsize selection

      • Backtracking line search works well!

      • Stepsize schemes using curvature estimates also work!

    • More accurate gradients enable automatic stopping conditions

    • Bigger batches are better in parallel/distributed settings

  • BACKTRACKING LINE SEARCH

    Also referred to as Armijo line search: it checks a sufficient decrease condition on the objective function.

    For regular (deterministic) gradient descent:

    \ell(x_{t+1}) \le \ell(x_t) - c \, \alpha_t \, \| \nabla \ell(x_t) \|^2

    (new objective on the left; current objective minus a sufficient decrease on the right)

    If this fails, decrease the stepsize and check again.
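    A standard sketch of this rule is below; the shrink factor 0.5 and the constant c = 1e-4 are conventional defaults used for illustration, not values from the talk.

    import numpy as np

    def backtracking(loss, grad, x, alpha0=1.0, c=1e-4, shrink=0.5, max_tries=50):
        """Armijo backtracking: shrink alpha until l(x - alpha*g) <= l(x) - c*alpha*||g||^2."""
        g = grad(x)
        gg = float(np.dot(g, g))
        alpha = alpha0
        for _ in range(max_tries):
            if loss(x - alpha * g) <= loss(x) - c * alpha * gg:
                return alpha                # sufficient decrease holds: accept this stepsize
            alpha *= shrink                 # otherwise decrease the stepsize and check again
        return alpha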

  • BACKTRACKING WITH SGD

    Naive SGD backtracking measures the sufficient decrease condition on individual functions, so it moves toward the optimum of individual functions, not the global average.

    Big Batch SGD gets a better estimate of the approximate decrease of the original objective.

  • BACKTRACKING WORKS WITH BIG BATCH SGD

    Measure a condition of sufficient decrease using the batch B: decrease the stepsize until

    \ell_B(x_{t+1}) \le \ell_B(x_t) - c \, \alpha_t \, \| \nabla \ell_B(x_t) \|^2

    Theorem: Big Batch SGD with backtracking line search converges linearly,

    \mathbb{E}[\ell(x_{t+1}) - \ell(x^\star)] \le \Bigl(1 - \frac{c \mu}{\gamma L}\Bigr) \, \mathbb{E}[\ell(x_t) - \ell(x^\star)],

    with the initial step size set large enough that $\alpha_0 \ge \frac{1}{2 \gamma L}$.

    We also derive optimal stepsizes for Big Batch SGD using Barzilai-Borwein (BB) curvature estimates, with provable guarantees.

    Check the paper for details.
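    For reference, a Barzilai-Borwein step estimate computed from two consecutive iterates and their batch gradients might look like the sketch below. This is the textbook BB1 formula shown only as an illustration; the paper's exact big-batch variant and its safeguards may differ.

    import numpy as np

    def bb_stepsize(x_prev, x_curr, g_prev, g_curr, fallback=0.1):
        """BB1 estimate: alpha = <s, s> / <s, y> with s = x_t - x_{t-1}, y = g_t - g_{t-1}."""
        s = x_curr - x_prev
        y = g_curr - g_prev
        sy = float(np.dot(s, y))
        if sy <= 0:                         # unreliable curvature estimate: fall back to a default
            return fallback
        return float(np.dot(s, s)) / sy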

  • CONVEX EXPERIMENTS

    Models: ridge regression & logistic regression.

    Proposed methods (BBS = Big Batch SGD): blue curves use a fixed stepsize, red use backtracking (Armijo) line search, green use optimal stepsizes from BB curvature estimates.

    The plots (not reproduced here) compare BBS with a fixed step size, BBS with backtracking, and BBS with BB-curvature stepsizes, and also show how the batch size increases and which stepsizes are used over the course of training.

  • 4-LAYER CNN

    [Plots not reproduced: mean class accuracy vs. number of epochs (0-40) on CIFAR-10, SVHN, and MNIST, on both the train and test sets, comparing Adadelta, BB+Adadelta, SGD+Mom (fine tuned), SGD+Mom (fixed LR), BBS+Mom (fixed LR), and BBS+Mom+Armijo.]

    Figure 2: Neural Network Experiments. The three columns from left to right correspond to results for CIFAR-10, SVHN, and MNIST, respectively. The top row presents classification accuracies on the training set, while the bottom row presents classification accuracies on the test set.

    1998) (ConvNet) to classify three benchmark image datasets: CIFAR-10 (Krizhevsky and Hinton, 2009), SVHN (Netzer et al., 2011), and MNIST (LeCun et al., 1998). Our ConvNet is composed of 4 layers, excluding the input layer, with over 4.3 million weights. To compare against fine-tuned SGD, we used a comprehensive grid search on the stepsize schedule to identify the optimal schedule. Fixed stepsize methods use the default decay rule of the Torch library: $\alpha_t = \alpha_0 / (1 + 10^{-7} t)$, where $\alpha_0$ was chosen to be the stepsize used in the fine-tuned experiments. We also tune the hyper-parameter $\rho$ in the Adadelta algorithm. Details of the ConvNet and the exact hyper-parameters used for training are presented in the supplemental.

    We plot the accuracy on the train and test sets vs. the number of epochs (full passes through the dataset) in Figure 2. We notice that big batch SGD with backtracking performs better than both Adadelta and SGD (Fixed LR) in terms of both train and test error. Big batch SGD even performs comparably to fine-tuned SGD, but without the trouble of fine tuning. This is interesting because most state-of-the-art deep networks (like AlexNet (Krizhevsky et al., 2012), VGG Net (Simonyan and Zisserman, 2014), and ResNets (He et al., 2016)) were trained by their creators using standard SGD with momentum, and training parameters were tuned over long periods of time (sometimes months). Finally, we note that big batch AdaDelta performs consistently better than plain AdaDelta on both large-scale problems (SVHN and CIFAR-10), and performance is nearly identical on the small-scale MNIST problem.

    7 Conclusion

    We analyzed and studied the behavior of alternative SGD methods in which the batch size increases over time. Unlike classical SGD methods, in which stochastic gradients quickly become swamped with noise, these “big batch” methods maintain a nearly constant signal-to-noise ratio in the approximate gradient. As a result, big batch methods are able to adaptively adjust batch sizes without user oversight. The proposed automated methods are shown to be empirically comparable to or better performing than other standard methods, but without requiring an expert user to choose learning rates and decay parameters.

    Acknowledgements

    This work was supported by the US Office of Naval Research (N00014-17-1-2078) and the National Science Foundation (CCF-1535902 and IIS-1526234). A. Yadav and D. Jacobs were supported by the Office of the Director of National Intelligence (ODNI), Intelligence Advanced Research Projects Activity (IARPA), via IARPA R&D Contract No. 2014-14071600012. The views and conclusions contained herein are those of the authors and should not be interpreted as necessarily representing the official policies or endorsements, either expressed or implied, of the ODNI, IARPA, or the U.S. Government. The U.S. Government is authorized to reproduce and distribute reprints for Governmental purposes notwithstanding any copyright annotation thereon.

    Proposed Methods

    Big Batch SGD can also be used with AdaGrad, AdaDelta…

  • TAKEAWAYS

    We introduce: Big Batch SGD

    Adaptively grows the batch size over time to maintain a nearly constant signal-to-noise ratio in the gradients

    Better control of the noise makes it easy to automate

    Adaptive stepsize methods work well with this method

    Better for parallel/distributed settings

  • THANKS! Feel free to get in touch!

    Extended version on arXiv: “Big Batch SGD: Automated Inference using Adaptive Batch Sizes”, https://arxiv.org/abs/1610.05792

    email: [email protected]   website: https://cs.umd.edu/~sohamde/

    Soham De, Abhay Yadav, David Jacobs, Tom Goldstein
