learning to learn by gradient descent by gradient descent
TRANSCRIPT
Learning to learn by gradient descent by gradient descent
@NIPS2016読み会
発表者 福田 宏幸
自己紹介
福田 宏幸
2000年4月 (株)電通 入社 コピーライター
2016年7月 (株)電通デジタル 出向
2016年9月 東京大学新領域創成科学研究科
博士課程1年
専門:バイオインフォマティクス
イントロダクション
本論文の要旨
Deep Learningが
学習法の学習も手に入れた。
学習法の学習?
• Deep Learningの成功=特徴量の自動学習
• しかし、 学習アルゴリズム(SGD等の最適化アルゴリズム)の設計は、未だ人間。
• 本論文では、学習アルゴリズム自体をディープラーニングで学習する方法を提案。
Alpha GoのDeepMind社の論文
著者のNIPSでの講演
https://www.youtube.com/watch?v=tPWGGwmgwG0
Deep Learningにおける学習
• コスト関数の設定 二乗誤差 クロスエントロピー
• 高次元、非凸関数。
• 勾配降下法(Gradient Descent)等により コスト関数を最小にするパラメーターを探索
勾配降下法(Gradient Descent)
• 傾きに沿ってパラメーターθを更新していく
• α:学習率 αが大きすぎる:振動する。 αが小さすぎる:更新が遅い。
http://prog3.com/sbdm/blog/google19890102/article/details/50276775
発展形
• パラメーターの更新の仕方がそれぞれ違う
• Adagrad(2011)
• RMSprop(2012)
• ADAM(2015)
http://postd.cc/optimizing-gradient-descent/#fnref:3
http://postd.cc/optimizing-gradient-descent/#fnref:3
学習法を学習する
何を学習すればよいのか?
• 更新量をディープラーニングが学習
• Optimizee: f(θ) 学習したい問題の誤差関数 Optimizer: g(φ) 更新量を出力するNN
Gradient Descent
Learning to Learn
誤差関数をどう定義するか?
• 「良い最適化」を定義したい
• 確率分布に従って生成された誤差関数を平均的に最適化する
θ*(f,φ) 最適化されたパラメータ
誤差関数をどう定義するか?
• 勾配の時系列を扱えるように定式化
• m:LSTM :時刻tの勾配 :時刻tの状態
:ウエイト。論文では1を使用。
ネットワーク
• LSTMに、時刻tでの勾配を入力していく。
※LSTMの方のパラメーター最適化は、Adam。
• 計算量を減らすために、それぞれのパラメータについて独立に更新していく。 (LSTMの重みのみ共通)
Coordinatewise LSTM
実験結果
実験
• Quadratic
• MNIST
• CIFAR10
• DEEP ART
• Optimizeeが、二次関数。
• Wとyの値をサンプリングで生成し、学習。
• 他のアルゴリズムに比べて早く収束。
Quadratic Function
MNIST
• 同じく早く収束。
• テスト時の活性化関数を ReLUに変えると収束しない。
MNIST
• 最終的な誤差の値も少さい。
CIFAR10
• CNNでも実験。 (3layers, 32hidden units)
• 畳込み層と全結合層では重みを共有しない。
• CIFAR5、CIFAR2への転移学習も上手くいく。
DEEP ART
• 1800の画像と、1つのスタイルを学習。
DEEP ART
• 高解像度画像への転移学習。 64×64pixels→128×128pixels
DEEP ART
• ディープラーニングが問題を 上手く一般化できている。
まとめ
まとめ
• ディープラーニングにより、学習アルゴリズム自体を学習する方法を提案
• 既存の最適化手法を上回る精度を得た。
• 問題の構造を上手く一般化することができた。
個人的感想
• 結局チューニングがありそう。
• 大きなデータセット、ネットワー?
• データマイニングのコンペティションで試してみたい。
実装も公開されています
• 本家 https://github.com/deepmind/learning-to-learn
• シンプルな実装 http://runopti.github.io/blog/2016/10/17/learningtolearn-1/