dropout distillation
TRANSCRIPT
0
Dropout Distillation Samuel Rota Bulò, Lorenzo Porzi , Peter Kontschieder
ICML2016読み会
紹介者:佐野正太郎
株式会社リクルートコミュニケーションズ
(C)Recruit Communications Co., Ltd.
背景:Dropout学習
• ニューラルネットワークの過学習を抑制する手法
• 学習ステップ毎にランダムに一部のユニットを落とす
• 暗に多数のネットワークのアンサンブルモデルを学習している
• [Srivastava et al., 2014]
1
学習対象のネットワーク 学習ステップ1 学習ステップ2
・・・ 学習時
(C)Recruit Communications Co., Ltd.
背景:Dropoutにおける予測計算
2
Doropout学習時にはネットワーク構造がランダム => 予測時にどの構造を採用するか?
理想:全てのDropoutパターンでの予測計算の期待値をとる
Standard Dropout [Srivastava et al., 2014]
• 予測時にはユニットを落とさない
• 各ユニットの出力を (1 – dropout率) でスケールすることで実用的な精度が得られる
Monte-Carlo Dropout [Gal & Ghahramani, 2015]
• 予測時に複数のDropoutパターンを試して平均をとる
• 予測の計算コストが高い代わりにStandard Dropoutよりも良い精度が得られる
(C)Recruit Communications Co., Ltd.
背景:Distillation
3
Distilling the knowledge in Neural Network [Hinton et al., 2014]
• distill = 蒸留する
• 複数のネットワークや複雑なネットワークを単一の小さなモデルに圧縮する手法
蒸留モデル
アンサンブルモデル
(C)Recruit Communications Co., Ltd.
提案手法:Dropout Distillation
概要
• Dropout学習が暗に獲得しているアンサンブルモデルを圧縮/蒸留(Distillation)する
• Dropout学習後モデルのMonte-Carlo予測を模倣する新しいモデルを学習する
利点
• Standard Dropoutと同じ予測計算コストでStandard Dropoutよりも高い予測精度
• 半教師あり学習への応用可能性:教師信号が欠損したデータをDistillationフェーズで活用できる
• モデル圧縮への応用可能性:Dropuoutで複雑なモデルを学習してDistillationフェーズで圧縮できる
欠点
• Distillationフェーズに余計な時間がかかる
4
(C)Recruit Communications Co., Ltd.
提案手法:Dropout Distillation
5
Dropout 学習済み モデル
生徒モデル
損失関数
Dropout パターン
(C)Recruit Communications Co., Ltd.
提案手法:Dropout Distillation
6
教師モデル (Dropout学習済み)
生徒モデル
Distillationフェーズでは 教師モデルの振る舞いを真似るよう
生徒モデルを学習する
通常のDropout学習で 教師となるモデルを獲得
(C)Recruit Communications Co., Ltd.
提案手法:Dropout Distillation
7
Distillation用 学習データ
(教師信号無し)
教師モデル (Dropout学習済み)
生徒モデル
生徒モデルの出力
出力間の損失を 埋めるように 生徒モデルの パラメタを更新
教師モデルの出力
生徒モデルには ドロップアウトをかけない
教師モデルにドロップアウトを かけながら出力データを生成
(C)Recruit Communications Co., Ltd.
提案手法:Dropout Distillation
8
Distillation用 学習データ
(教師信号無し)
教師モデル (Dropout学習済み)
生徒モデル
生徒モデルの出力
教師モデルの出力
教師モデルと生徒モデルの ネットワーク構造は違っていてもよい
データはDropoutフェーズから流用可 新しいデータを用意するのも可
(C)Recruit Communications Co., Ltd.
理想の予測関数
• 全てのDropoutパターンでの出力期待値
• Dropoutパターンはユニット数に対し指数関数的に増加するので事実上計算できない
問題設定
• 『理想の予測関数』を教師モデルとした生徒モデルを学習したい
どうやって『理想の予測関数』を計算に取り入れるか?
Dropout学習済みモデル
導出
9
理想の予測関数
損失関数
生徒モデル 評価できない
Dropoutパターン
(C)Recruit Communications Co., Ltd.
アプローチ
• 『理想の予測関数』をDropout学習済みモデルで置き換える
• 損失関数がBregmanダイバージェンスのとき以下の最小化問題が等価
Bregmanダイバージェンス
• 二乗損失・Logistic損失・KLダイバージェンスなどを一般化したもの
Dropout 学習済み モデル
導出
10
生徒モデル
微分可能な凸関数
Dropoutパターン この表現を形にしたのが スライド5〜8のアルゴリズム
(C)Recruit Communications Co., Ltd.
実験1:予測計算手法による性能比較
12
MNIST/CIFAR10/CIFAR100データセットで3予測手法のエラー率比較
• Standard Dropout
• Monte-Carlo Dropout(100サンプリング)
• Dropout Distillation
実験手順
1. Dropout学習でベースラインモデルを獲得(300エポック)
2. ベースラインモデルでStandard DropoutとMonte-Carlo Dropoutの性能評価
3. ベースラインモデルを教師としてDropout Distillation(30エポック)
– 生徒モデルのネットワーク構造はベースラインモデルと同様
– ベースラインモデルの学習後パラメタで生徒モデルを初期化
– ベースラインモデルの入力データを流用(pixel毎に確率0.2で値をゼロ化)
4. 生徒モデルでDropout Distillationの性能評価
(C)Recruit Communications Co., Ltd.
実験1:予測計算手法による性能比較
13
• 平均エラー率は Standard > Distillation > Monte-Carlo の順
• Monte-CarloよりDistillationの方がパフォーマンスの分散は低い
(C)Recruit Communications Co., Ltd.
実験2:Distillationに使うデータセットによる性能比較
14
Distillationフェーズの入力データについて3シナリオで性能比較
• [Train] 教師モデルのトレーニングセットをそのまま利用
• [Pert. Train] 教師モデルのトレーニングセットをピクセル毎に確率0.2で値をゼロ化
• [Test] テストデータを利用
どのシナリオが 優れているかは
場合による
(C)Recruit Communications Co., Ltd.
実験3:モデル圧縮への応用可能性
15
CIFAR10/Quickでユニット数を削減した場合のパフォーマンス変化
• [Baseline] Dropout学習のみで削減後モデルを学習
• [Distillation] Dropoutフェーズで削減前モデルを学習してDistillationフェーズで削減後モデルに圧縮
青枠内では『Dropoutフェーズで複雑なモデルを学習 => Distillationフェーズで圧縮』が有効に働いている
FC層からのみユニットを削った場合 全層からフィルタ/ユニットを削った場合