
Page 1: [PR12] PR-036 Learning to Remember Rare Events

Learning to Remember Rare Events

Paper appeared at ICLR 2017, https://arxiv.org/abs/1703.03129

Authors:

Łukasz Kaiser, Ofir Nachum, Aurko Roy, and Samy Bengio (Google Brain)

 Reviewed by Taegyun Jeon 


Page 2: [PR12] PR-036 Learning to Remember Rare Events

What can we learn from this paper?

1. Memory-augmented deep neural network

2. Two tasks:
One-shot learning (Omniglot dataset)
Life-long one-shot learning (large-scale machine translation)

3. TensorFlow implementation for the one-shot learning
Official code from Google Brain using TensorFlow


Page 3: [PR12] PR-036 Learning to Remember Rare Events

Problem Definition

rare events vs. the average case

Image from "One Shot Learning" (Jisung Kim @ TensorFlow-KR 2nd Meetup Lightning Talk)

Page 4: [PR12] PR-036 Learning to Remember Rare Events

8 Tactics To Combat Imbalanced Training Data

1. Collect More Data

2. Try Changing Your Performance Metric

3. Try Resampling Your Dataset

4. Try Generate Synthetic Samples

5. Try Different Algorithms

6. Try Penalized Models

7. Try a Different Perspective

8. Try Getting Creative

"8 Tactics to Combat Imbalanced Classes in Your Machine Learning Dataset" @ Machine Learning Mastery 4

Page 5: [PR12] PR-036 Learning to Remember Rare Events

Problem Definition (for rare events)

Deep Neural Networks:
Extend the training data
Re-train them to handle such rare or new events
Very SLOW!! (gradient-based optimization)

Humans (life-long fashion): learn from a single example


Page 6: [PR12] PR-036 Learning to Remember Rare Events

Key Concepts

Deep Neural Networks (+ Memory Module)


Page 7: [PR12] PR-036 Learning to Remember Rare Events

Previous Works

Meta-Learning with Memory-Augmented Neural Networks
Idea: Write the pair of "image and label" into the memory

Matching Networks for One Shot Learning
Idea: Train a fully end-to-end nearest-neighbor classifier

Note from A. Karpathy


Page 8: [PR12] PR-036 Learning to Remember Rare Events

Memory module

Define a memory of size memory-size as a triple:

M = (K_{m×k}, V_m, A_m),  with m = memory-size and k = key-size

Key: activations of a chosen layer of a neural network.

Value: ground-truth target for the given example.

Age: tracks the age of each item stored in memory.
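For concreteness, a minimal NumPy sketch of this triple (illustrative only; the official implementation keeps M in TensorFlow variables, and the sizes below follow the flags shown later in this review):

import numpy as np

memory_size, key_size = 8192, 128          # m and k

K = np.zeros((memory_size, key_size))      # keys: normalized embeddings
V = np.zeros(memory_size, dtype=np.int64)  # values: e.g. class labels
A = np.zeros(memory_size, dtype=np.int64)  # ages of the stored items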

Page 9: [PR12] PR-036 Learning to Remember Rare Events

Memory module (query)

A memory query q is a vector of size key-size:

q ∈ R^k,  ||q|| = 1

The nearest neighbor(*) of q in M:

NN(q, M) = argmax_i q · K[i]

Given a query q, the memory M computes the k nearest neighbors:

(n_1, ..., n_k) = NN_k(q, M)

and returns, as the main result, the value V[n_1].

(*) Since the keys are normalized, this is the nearest neighbor w.r.t. cosine similarity.

Page 10: [PR12] PR-036 Learning to Remember Rare Events

Memory module (query)

Cosine similarity: d_i = q · K[n_i]

Return softmax(d_1 · τ, ..., d_k · τ)

Inverse of softmax temperature: τ = 40
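Putting pages 9 and 10 together, a minimal NumPy sketch of the query path, assuming K holds normalized keys and V the stored values; the choice k=256 here is illustrative:

import numpy as np

def memory_query(q, K, V, k=256, tau=40.0):
    # Return the main result V[n_1] plus softmax weights over the k-NN.
    q = q / np.linalg.norm(q)        # enforce ||q|| = 1
    sims = K @ q                     # q . K[i] for every slot i
    nn = np.argsort(-sims)[:k]       # (n_1, ..., n_k) = NN_k(q, M)
    d = sims[nn] * tau               # scale by the inverse temperature
    w = np.exp(d - d.max())
    w /= w.sum()                     # softmax(d_1*tau, ..., d_k*tau)
    return V[nn[0]], w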

Page 11: [PR12] PR-036 Learning to Remember Rare Events

[Note] Softmax temperature, τ

The idea is to control the randomness of predictions:

High temperature: softmax outputs are closer to one another (more uniform).

Low temperature: softmax outputs become more and more "hardmax".

For a low temperature (τ → 0⁺), the probability of the output with the highest expected reward tends to 1.
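A tiny numeric demo of this effect (illustrative values; multiplying the scores by a large τ corresponds to a low temperature):

import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

d = np.array([0.9, 0.8, 0.1])  # hypothetical similarity scores
print(softmax(d * 1.0))        # high temperature: outputs close together
print(softmax(d * 40.0))       # tau = 40: nearly "hardmax"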

Page 12: [PR12] PR-036 Learning to Remember Rare Events

Memory module (episode)

Slide from "Meta-learning with memory-augmented neural networks" (Slideshare, H. Kim)

Page 13: [PR12] PR-036 Learning to Remember Rare Events

Memory module (train)

Memory loss

Query q and the correct desired (supervised) value v.

Classification: v would be the class label.

Seq2Seq: v would be the output token of the current timestep.


Page 14: [PR12] PR-036 Learning to Remember Rare Events

Memory module (train)

loss(q, v, M) = [q · K[n_b] − q · K[n_p] + α]_+

K[n_p]: positive neighbor, V[n_p] = v

K[n_b]: negative neighbor, V[n_b] ≠ v

α: margin; the loss is zero once q · K[n_p] exceeds q · K[n_b] by at least α
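A minimal sketch of this margin loss, simplifying by taking the positive and negative neighbors to be the first matching and non-matching entries of the returned neighbor list (the margin value 0.1 is illustrative):

import numpy as np

def memory_loss(q, K, V, v, nn, alpha=0.1):
    # nn: indices (n_1, ..., n_k) returned by the query.
    positives = [i for i in nn if V[i] == v]
    negatives = [i for i in nn if V[i] != v]
    if not positives or not negatives:
        return 0.0                   # sketch: skip degenerate cases
    n_p, n_b = positives[0], negatives[0]
    return max(q @ K[n_b] - q @ K[n_p] + alpha, 0.0)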

Page 15: [PR12] PR-036 Learning to Remember Rare Events

Memory module (update)

Case V[n_1] = v:

K[n_1] ← (q + K[n_1]) / ||q + K[n_1]||

A[n_1] ← 0

Case V[n_1] ≠ v:

if the memory has an empty slot at n_empty, set n′ = n_empty;

if not, pick the oldest item: n′ = argmax_i A[i]

K[n′] ← q,  V[n′] ← v,  and  A[n′] ← 0.
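A minimal NumPy sketch of this update, with empty-slot bookkeeping omitted and the paper's random tie-breaking among the oldest items replaced by a plain argmax:

import numpy as np

def memory_update(q, v, K, V, A, n1):
    # n1: index of the top-1 neighbor from the preceding query.
    A += 1                           # every item ages by one step
    if V[n1] == v:                   # case V[n_1] = v
        K[n1] = (q + K[n1]) / np.linalg.norm(q + K[n1])
        A[n1] = 0
    else:                            # case V[n_1] != v
        n = int(np.argmax(A))        # oldest slot stands in for n'
        K[n], V[n], A[n] = q, v, 0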

Page 16: [PR12] PR-036 Learning to Remember Rare Events

Memory module (train & update)


Page 17: [PR12] PR-036 Learning to Remember Rare Events

Experiments (Evaluation)

1. Evaluation on Omniglot dataset

2. Evaluation on synthetic task

3. Evaluation on English‑German translation model

Qualitative side: rarely‑occurring words

Quantitative side: BLEU score


Page 18: [PR12] PR-036 Learning to Remember Rare Events

Experiments (Omniglot Dataset)


Page 19: [PR12] PR-036 Learning to Remember Rare Events

Experiments (Omniglot Dataset)

Omniglot dataset

This dataset contains 1623 different handwritten characters from 50 different alphabets.

Each of the 1623 characters was drawn online via Amazon's Mechanical Turk by 20 different people.

Each image is paired with stroke data, a sequence of [x, y, t] coordinates with time (t) in milliseconds.

Stroke data is available in MATLAB files only.

Omniglot dataset for one-shot learning (GitHub): https://github.com/brendenlake/omniglot

Page 20: [PR12] PR-036 Learning to Remember Rare Events

Experiments (Omniglot Dataset)

CNN Architecture:
(Conv, ReLU), (Conv, ReLU), pool, (Conv, ReLU), (Conv, ReLU), pool, FC, FC

Memory module

Output layer (Prediction)

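A hedged tf.keras sketch of an embedder with this Conv/pool/FC layout; the filter counts and hidden width are assumptions, and only the final key size follows rep_dim = 128 from the flags on page 29:

import tensorflow as tf

embedder = tf.keras.Sequential([
    tf.keras.layers.Conv2D(64, 3, activation='relu',
                           input_shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(64, 3, activation='relu'),
    tf.keras.layers.MaxPool2D(),
    tf.keras.layers.Conv2D(128, 3, activation='relu'),
    tf.keras.layers.Conv2D(128, 3, activation='relu'),
    tf.keras.layers.MaxPool2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dense(128),  # key; L2-normalized before memory use
])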

Page 21: [PR12] PR-036 Learning to Remember Rare Events

Experiments (Omniglot Dataset)

way: the number of distinct classes (characters) in an episode

shot: the number of examples seen per class before testing
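A minimal sketch of sampling one such episode from the label-to-examples dictionaries that get_data() returns; way and shot play the roles of episode_width and the per-class example count:

import random

def sample_episode(data, way=5, shot=1):
    # data: dict mapping label -> list of examples.
    labels = random.sample(list(data), way)
    episode = [(x, lbl) for lbl in labels
               for x in random.sample(data[lbl], shot)]
    random.shuffle(episode)
    return episode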


Page 22: [PR12] PR-036 Learning to Remember Rare Events

Experiments (GNMT)

Decoder path

Key: result of the attention a_t

Combine the value and the LSTM output at decoder time-step t


Page 23: [PR12] PR-036 Learning to Remember Rare Events

Experiments (GNMT)


Page 24: [PR12] PR-036 Learning to Remember Rare Events

Experiments (GNMT)

Convolutional Gated Recurrent Unit (CGRU)

For more information: Read the Lunit tech blog


Page 25: [PR12] PR-036 Learning to Remember Rare Events

Conclusions

Long-term memory module

Embedding the input with a simple CNN (LeNet)

The returned k-NN values could be used by other layers as well.


Page 26: [PR12] PR-036 Learning to Remember Rare Events

Code Review (GitHub)

1. data_utils.py: Data loading and other utilities.

2.  train.py : Script for training model.

3.  memory.py : Memory module for storing "nearest neighbors".

4.  model.py : Model using memory component.


Page 27: [PR12] PR-036 Learning to Remember Rare Events

Quick Start

1) First download and set up Omniglot data by running:

python data_utils.py

2) Then run the training script:

python train.py --memory_size=8192 \
  --batch_size=16 --validation_length=50 \
  --episode_width=5 --episode_length=30


Page 28: [PR12] PR-036 Learning to Remember Rare Events

3) The first validation batch may look like this (although it is noisy):

0-shot: 0.040, 1-shot: 0.404, 2-shot: 0.516, 3-shot: 0.604, 4-shot: 0.656, 5-shot: 0.684

4) At step 500 you may see something like this:

0-shot: 0.036, 1-shot: 0.836, 2-shot: 0.900, 3-shot: 0.940, 4-shot: 0.944, 5-shot: 0.916

5) At step 4000 you may see something like this:

0-shot: 0.044, 1-shot: 0.960, 2-shot: 1.000, 3-shot: 0.988, 4-shot: 0.972, 5-shot: 0.992


Page 29: [PR12] PR-036 Learning to Remember Rare Events

0) Basic parameters

rep_dim: 128, dimension of keys to use in memory

episode_length: 100, length of episode

episode_width: 5, number of distinct labels in a single episode

memory_size: None, number of slots in memory.

batch_size: 16, batch size

num_episodes: 100000, number of training episodes

validation_frequency: 20, every so many training episodes, assess validation accuracy

validation_length: 10, number of episodes to use to compute validation accuracy

seed: 888, random seed for training sampling

save_dir: '', directory to save model to

use_lsh: False, use locality-sensitive hashing (NOTE: not fully tested)
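For reference, this is roughly how such flags are declared with the TF 1.x tf.flags API the script uses (a sketch of a few of them; see train.py for the actual definitions):

import tensorflow as tf

tf.flags.DEFINE_integer('rep_dim', 128,
                        'dimension of keys to use in memory')
tf.flags.DEFINE_integer('episode_width', 5,
                        'number of distinct labels in a single episode')
tf.flags.DEFINE_integer('memory_size', None,
                        'number of slots in memory')
tf.flags.DEFINE_bool('use_lsh', False,
                     'use locality-sensitive hashing')
FLAGS = tf.flags.FLAGS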

Page 30: [PR12] PR-036 Learning to Remember Rare Events

1) data_utils.py

def preprocess_omniglot():
    # Download and prepare raw Omniglot data.

def maybe_download_data():
    # Download Omniglot repo if it does not exist.

def write_datafiles():
    # Load and preprocess images from a directory and
    # write them to a file.

def crawl_directory():
    # Crawls data directory and returns stuff.

def resize_images():
    # Resize images to new dimensions.


Page 31: [PR12] PR-036 Learning to Remember Rare Events

1) Tips from data_utils.py

Messages are managed with logging; the level is adjustable.

Data is dumped and loaded with pickle. (What about TFRecord and queues?)

Simple external commands are run with subprocess.

Only the train dataset is augmented with rotations (0, 90, 180, 270 degrees).

Resizing is performed (original: 105, resized: 28).

OUTPUT: train_omni.pkl (733M), test_omni.pkl (126M)
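A rough sketch of that preprocessing (the PIL-based resize and the label handling for rotated copies are assumptions; the repo's actual file layout may differ):

import pickle
import numpy as np
from PIL import Image

def preprocess(images, labels, path, size=28):
    # Resize 105x105 images to size x size, add 0/90/180/270-degree
    # rotations (train split only), and dump the result with pickle.
    out_images, out_labels = [], []
    for img, lbl in zip(images, labels):
        small = np.asarray(Image.fromarray(img).resize((size, size)))
        for k in range(4):
            out_images.append(np.rot90(small, k))
            out_labels.append(lbl)
    with open(path, 'wb') as f:
        pickle.dump((out_images, out_labels), f)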


Page 32: [PR12] PR-036 Learning to Remember Rare Events

2) train.py

# data_utils.py
def get_data():
    # Get data in form suitable for episodic training.
    # Returns: Train and test data as dictionaries mapping
    # label to list of examples.

# train.py
class Trainer(object):
    def run(self):
        self.sample_episode_batch()
        outputs = self.model.episode_step()


Page 33: [PR12] PR-036 Learning to Remember Rare Events

2) Tips from train.py

Basic parameters are passed via tf.flags.

Training-related messages are reported through logging.

assert is used to check for episode-length errors.

Training and validation run concurrently (at a 20:1 ratio).


Page 34: [PR12] PR-036 Learning to Remember Rare Events

3) model.py

class LeNet(object):
    # Standard CNN architecture.

class Model(object):
    # Model for coordinating between CNN embedder and
    # memory module.


Page 35: [PR12] PR-036 Learning to Remember Rare Events

3) model.py

Lines 152-158, core_builder():

embeddings = self.embedder.core_builder(x)

if keep_prob < 1.0: embeddings = tf.nn.dropout(embeddings, keep_prob)memory_val, _, teacher_loss = self.memory.query( embeddings, y, use_recent_idx=use_recent_idx)loss, y_pred = self.classifier.core_builder( memory_val, x, y)

return loss + teacher_loss, y_pred


Page 36: [PR12] PR-036 Learning to Remember Rare Events

3) Tips from model.py

core_builder(): adds the memory module to the existing network.

An embedding vector is generated from the input image x using LeNet.

weight and bias are created in advance with tf.get_variable.

Each function of the model is factored out as finely as possible.


Page 37: [PR12] PR-036 Learning to Remember Rare Events

4) memory.py

class Memory(object):
    def get_hint_pool_idxs(...):
        # Get small set of idxs to compute nearest neighbor
        # queries on.

    def query(...):
        # Queries memory for nearest neighbor.

class LSHMemory(Memory):
    # Memory employing locality sensitive hashing.
    # Note: Not fully tested.


Page 38: [PR12] PR-036 Learning to Remember Rare Events

4) Tips from memory.py

You can choose between Memory and LSHMemory; Memory is recommended.

The paper's memory operations are implemented in an intuitive way.

By changing only memory_size and key_size, the module can be attached to almost any network.


Page 39: [PR12] PR-036 Learning to Remember Rare Events

Appendix (Reviews)

1. Lunit Tech Blog (by Hyo-Eun Kim) (Link)

2. OpenReview (ICLR2017) (Link)

3. BAIR Blog: "Learning to Learn" (by Chelsea Finn) (Link)

4. Learning to Remember Rare Events (by Hongbae Kim) (Slideshare)

5. One Shot Learning (by Jisung Kim) (Slideshare)


Page 40: [PR12] PR-036 Learning to Remember Rare Events

Appendix (Implementations)

1. TensorFlow/models (Google Brain) (GitHub)
