[PR12] PR-036: Learning to Remember Rare Events
TRANSCRIPT
Learning to Remember Rare Events
Paper appeared at ICLR 2017: https://arxiv.org/abs/1703.03129
Authors:
Łukasz Kaiser, Ofir Nachum, Aurko Roy, and Samy Bengio (Google Brain)
Reviewed by Taegyun Jeon
What can we learn from this paper?
1. Memory-augmented deep neural network
2. Two tasks:
One-shot learning (Omniglot dataset)
Life-long one-shot learning (large-scale machine translation)
3. TensorFlow implementation for the one-shot learning
Official code from Google Brain using TensorFlow
Problem Definition
rare events vs. the average case
Image from "One Shot Learning" (Jisung Kim @ TensorFlow-KR 2nd Meetup Lightning Talk)
8 Tactics To Combat Imbalanced Training Data
1. Collect More Data
2. Try Changing Your Performance Metric
3. Try Resampling Your Dataset
4. Try Generate Synthetic Samples
5. Try Different Algorithms
6. Try Penalized Models
7. Try a Different Perspective
8. Try Getting Creative
"8 Tactics to Combat Imbalanced Classes in Your Machine Learning Dataset" @ Machine Learning Mastery 4
Problem Definition (for rare events)
Deep Neural Networks:
Extend the training data
Re-train them to handle such rare or new events
Very SLOW!! (gradient-based optimization)
Humans (life-long fashion):
Learn from a single example
Key Concepts
Deep Neural Networks (+ Memory Module)
Previous Works
Meta-Learning with Memory-Augmented Neural Networks
Idea: write the pair of (image, label) into the memory
Matching Networks for One Shot Learning
Idea: train a fully end-to-end nearest-neighbor classifier
Note from A. Karpathy
Memory module
Define a memory of size memory-size as a triple:
M = (K_{m×k}, V_m, A_m)
where m = memory-size and k = key-size.
Key K: activations of a chosen layer of a neural network.
Value V: ground-truth targets for the given example.
Age A: tracks the ages of the items stored in memory.
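As a rough illustration, the memory triple can be held as three arrays. This is a minimal numpy sketch under assumed sizes; the official memory.py keeps them as TensorFlow variables.

import numpy as np

memory_size, key_size = 8192, 128          # m and k from the definition above
K = np.zeros((memory_size, key_size))      # keys: normalized embeddings
V = np.zeros(memory_size, dtype=np.int64)  # values: ground-truth labels
A = np.zeros(memory_size, dtype=np.int64)  # ages: steps since last update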
Memory module (query)
A memory query q is a vector of size key-size:
q ∈ R^k, ∣∣q∣∣ = 1
The nearest neighbor(*) of q in M:
NN(q, M) = argmax_i q ⋅ K[i]
Given a query q, the memory M computes the k nearest neighbors:
(n_1, ..., n_k) = NN_k(q, M)
and returns, as the main result, the value V[n_1].
(*) Since the keys are normalized, this is the nearest neighbor w.r.t. cosine similarity.
Memory module (query)
Cosine similarities: d_i = q ⋅ K[n_i]
Return softmax(d_1 ⋅ τ, ..., d_k ⋅ τ)
Inverse of the softmax temperature: τ = 40
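Put together, a query is a cosine nearest-neighbor lookup followed by a temperature-scaled softmax. A minimal numpy sketch, assuming normalized keys and an illustrative candidate count k:

import numpy as np

def memory_query(q, K, tau=40.0, k=256):
    # q and the rows of K are assumed L2-normalized, so the dot
    # product equals cosine similarity.
    sims = K @ q                          # similarities to all stored keys
    nn = np.argsort(-sims)[:k]            # indices n_1, ..., n_k
    d = sims[nn]                          # d_i = q . K[n_i]
    logits = d * tau                      # scale by inverse temperature
    p = np.exp(logits - logits.max())
    p /= p.sum()                          # softmax(d_1*tau, ..., d_k*tau)
    return nn[0], nn, p                   # main result is the value V[nn[0]]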
[Note] Softmax temperature, τ
The idea is to control the randomness of predictions:
High temperature: softmax outputs are closer to each other (more uniform).
Low temperature: softmax outputs become more and more "hardmax".
For a low temperature (τ → 0+), the probability of the output with the highest expected reward tends to 1.
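A toy numeric example of this effect (the values are made up; the multiplier plays the inverse-temperature role used in the query above):

import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

d = np.array([0.9, 0.8, 0.1])  # toy cosine similarities
print(softmax(d * 1.0))        # tau = 1 : outputs stay close together
print(softmax(d * 40.0))       # tau = 40: nearly all mass on the best match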
Memory module (episode)
Slide from "Meta-learning with memory-augmented neural networks" (Slideshare, H. Kim)
Memory module (train)
Memory loss
Given a query q and the correct desired (supervised) value v:
Classification: v would be the class label.
Seq2Seq: v would be the output token of the current timestep.
Memory module (train)
loss(q, v, M) = [q ⋅ K[n_b] − q ⋅ K[n_p] + α]_+
K[n_p]: positive neighbor, V[n_p] = v
K[n_b]: negative neighbor, V[n_b] ≠ v
α: margin; the loss becomes zero once the positive neighbor is closer than the negative one by at least α
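A hedged numpy sketch of this hinge loss; the dense search over all keys and the helper name are illustrative (the official code draws neighbors from a hint pool):

import numpy as np

def memory_loss(q, K, V, v, alpha=0.1):
    # Assumes at least one positive and one negative neighbor exist;
    # alpha is the margin hyperparameter.
    sims = K @ q
    order = np.argsort(-sims)
    n_p = next(i for i in order if V[i] == v)   # nearest positive neighbor
    n_b = next(i for i in order if V[i] != v)   # nearest negative neighbor
    return max(q @ K[n_b] - q @ K[n_p] + alpha, 0.0)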
Memory module (Update)
Case V[n_1] = v:
K[n_1] ← (q + K[n_1]) / ∣∣q + K[n_1]∣∣
A[n_1] ← 0
Case V[n_1] ≠ v:
If the memory has an empty slot n_empty, set n′ = n_empty;
if not, pick the oldest item: n′ = argmax_i A[i].
K[n′] ← q, V[n′] ← v, and A[n′] ← 0.
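A minimal numpy sketch of both cases, assuming the arrays from the memory definition above; empty-slot bookkeeping is simplified here, and the paper additionally randomizes the choice of oldest slot:

import numpy as np

def memory_update(q, v, n1, K, V, A):
    A += 1                                   # every item ages by one step
    if V[n1] == v:                           # case 1: nearest value is correct
        K[n1] = (q + K[n1]) / np.linalg.norm(q + K[n1])
        A[n1] = 0
    else:                                    # case 2: write to a new slot
        n_new = int(np.argmax(A))            # oldest slot stands in for n'
        K[n_new], V[n_new], A[n_new] = q, v, 0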
Memory module (train & update)
Experiments (Evaluation)
1. Evaluation on Omniglot dataset
2. Evaluation on synthetic task
3. Evaluation on English‑German translation model
Qualitative side: rarely‑occurring words
Quantitative side: BLEU score
Experiments (Omniglot Dataset)
Omniglot dataset
This dataset contains 1623 different handwritten characters from 50 different alphabets.
Each of the 1623 characters was drawn online via Amazon's Mechanical Turk by 20 different people.
Each image is paired with stroke data, a sequence of [x, y, t] coordinates with time t in milliseconds.
Stroke data is available in MATLAB files only.
Omniglot dataset for one-shot learning (GitHub): https://github.com/brendenlake/omniglot
Experiments (Omniglot Dataset)
CNN Architecture (sketched after this list):
(Conv, ReLU), (Conv, ReLU), pool, (Conv, ReLU), (Conv, ReLU), pool, FC, FC
Memory module
Output layer (Prediction)
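A hedged tf.keras sketch of such an embedder; the filter counts and the FC width are assumptions, while the 28×28 input matches the resized images and the final 128-d output matches rep_dim:

import tensorflow as tf

embedder = tf.keras.Sequential([
    tf.keras.layers.Conv2D(64, 3, activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(64, 3, activation='relu'),
    tf.keras.layers.MaxPool2D(),
    tf.keras.layers.Conv2D(128, 3, activation='relu'),
    tf.keras.layers.Conv2D(128, 3, activation='relu'),
    tf.keras.layers.MaxPool2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(256, activation='relu'),   # FC
    tf.keras.layers.Dense(128),                      # FC: key of size rep_dim
])
# The 128-d output is normalized and used as the memory key.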
Experiments (Omniglot Dataset)
way: number of distinct classes (characters) per episode
shot: number of examples already seen per class
Experiments (GNMT)
Decoder path
Key: result of the attention a_t
Combine the retrieved value with the LSTM output at decoder time-step t (see the sketch below)
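A heavily hedged sketch of one way such a combination could look; the embedding table, gate, and residual wiring here are assumptions for illustration, not the paper's exact formulation:

import tensorflow as tf

hidden = 512                                          # assumed decoder width
value_emb = tf.keras.layers.Embedding(32000, hidden)  # embeds retrieved token ids V[n_1]
gate_layer = tf.keras.layers.Dense(hidden, activation='sigmoid')

def combine(lstm_out, value_ids):
    v = value_emb(value_ids)                          # embed the memory value
    g = gate_layer(tf.concat([lstm_out, v], axis=-1)) # learned gate (assumption)
    return lstm_out + g * v                           # gated residual combination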
Experiments (GNMT)
Convolutional Gated Recurrent Unit (CGRU)
For more information: Read the Lunit tech blog
Conclusions
Long-term memory module
Embedding input with a simple CNN (LeNet)
Returning k-NN results could be used for other layers.
Code Review (GitHub)
1. data_utils.py: Data loading and other utilities.
2. train.py: Script for training the model.
3. memory.py: Memory module for storing "nearest neighbors".
4. model.py: Model using the memory component.
Quick Start
1) First, download and set up the Omniglot data by running:
python data_utils.py
2) Then run the training script:
python train.py --memory_size=8192 \
  --batch_size=16 --validation_length=50 \
  --episode_width=5 --episode_length=30
3) The first validation batch may look like this (although it is noisy):
0-shot: 0.040, 1-shot: 0.404, 2-shot: 0.516, 3-shot: 0.604, 4-shot: 0.656, 5-shot: 0.684
4) At step 500 you may see something like this:
0-shot: 0.036, 1-shot: 0.836, 2-shot: 0.900, 3-shot: 0.940, 4-shot: 0.944, 5-shot: 0.916
5) At step 4000 you may see something like this:
0-shot: 0.044, 1-shot: 0.960, 2-shot: 1.000, 3-shot: 0.988, 4-shot: 0.972, 5-shot: 0.992
0) Basic parameters
rep_dim: 128, dimension of keys to use in memory
episode_length: 100, length of episode
episode_width: 5, number of distinct labels in a single episode
memory_size: None, number of slots in memory.
batch_size: 16, batch size
num_episodes: 100000, number of training episodes
validation_frequency: 20, assess validation accuracy every so many training episodes
validation_length: 10, number of episodes used to compute validation accuracy
seed: 888, random seed for training sampling
save_dir: '', directory to save model to
use_lsh: False, use locality-sensitive hashing (NOTE: not fully tested)
1) data_utils.py
def preprocess_omniglot():
    # Download and prepare raw Omniglot data.

def maybe_download_data():
    # Download the Omniglot repo if it does not exist.

def write_datafiles():
    # Load and preprocess images from a directory and write them to a file.

def crawl_directory():
    # Crawl the data directory and return what it finds.

def resize_images():
    # Resize images to new dimensions.
1) Tips from data_utils.py
Messages are managed with logging; the log level is adjustable.
Data is dumped and loaded with pickle. (What about TFRecord and queues?)
Simple external commands are run with subprocess.
Only the train dataset is augmented with rotations (0, 90, 180, 270 degrees; sketched after this list).
Resizing is performed (original: 105, resized: 28).
OUTPUT: train_omni.pkl (733M), test_omni.pkl (126M)
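A minimal numpy sketch of the rotation augmentation described above (the function name is illustrative):

import numpy as np

def augment_rotations(image):
    # Train-only augmentation: four copies at 0/90/180/270 degrees.
    return [np.rot90(image, k) for k in range(4)]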
2) train.py
data_utils.get_data():
    # Get data in a form suitable for episodic training.
    # Returns: train and test data as dictionaries mapping
    # label to a list of examples.

class Trainer():
    def run():
        self.sample_episode_batch()
        outputs = self.model.episode_step()
2) Tips from train.py
Basic parameters are passed via tf.flags.
Training-related messages are reported via logging.
assert is used to catch episode-length errors.
Training and validation run together (at a 20:1 ratio).
3) model.py
class LeNet(object):
    # Standard CNN architecture.

class Model(object):
    # Model for coordinating between the CNN embedder and the memory module.
3) model.py
Lines 152-158, core_builder():

embeddings = self.embedder.core_builder(x)
if keep_prob < 1.0:
  embeddings = tf.nn.dropout(embeddings, keep_prob)
memory_val, _, teacher_loss = self.memory.query(
    embeddings, y, use_recent_idx=use_recent_idx)
loss, y_pred = self.classifier.core_builder(
    memory_val, x, y)
return loss + teacher_loss, y_pred
3) Tips from model.py
core_builder(): adds the memory module to the existing network.
For an input image x, an embedding vector is produced with LeNet.
Weights and biases are created in advance with tf.get_variable.
Each capability of the model is factored into its own function as much as possible.
4) memory.py
class Memory(object):
    def get_hint_pool_idxs(...):
        # Get a small set of idxs to compute nearest-neighbor queries on.
    def query(...):
        # Queries memory for the nearest neighbor.

class LSHMemory(Memory):
    # Memory employing locality-sensitive hashing.
    # Note: not fully tested.
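As a hedged illustration of the idea behind LSHMemory: random-hyperplane hashing buckets keys so a query only searches one bucket. The plane count and helper name are assumptions, not the official implementation:

import numpy as np

key_size, num_planes = 128, 16
planes = np.random.randn(num_planes, key_size)

def lsh_bucket(q):
    bits = (planes @ q > 0).astype(int)     # sign of q against each hyperplane
    return int("".join(map(str, bits)), 2)  # bucket id in [0, 2**num_planes)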
4) Tips from memory.py
You can choose between Memory and LSHMemory; Memory is recommended.
The paper's memory operations are implemented intuitively.
By changing only memory_size and key_size, the module can be attached to almost any network.
Appendix (Reviews)
1. Lunit Tech Blog (by Hyo-Eun Kim) (Link)
2. OpenReview (ICLR2017) (Link)
3. BAIR Blog: "Learning to Learn" (by Chelsea Finn) (Link)
4. Learning to Remember Rare Events (by Hongbae Kim) (Slideshare)
5. One Shot Learning (by Jisung Kim) (Slideshare)
Appendix (Implementations)
1. TensorFlow/models (Google Brain) (GitHub)