TRANSCRIPT
-
Inference in Probabilistic Graphical Models by Graph Neural Networks
Author: KiJung Yoon, Renjie Liao, Yuwen Xiong, Lisa Zhang, Ethan Fetaya, Raquel Urtasun, Richard Zemel, Xaq Pitkow
Presenter: Shihao Niu, Zhe Qu, Siqi Liu, Jules Ahmar
-
TL;DR: Use Graph Neural Networks (GNNs) to learn a message-passing algorithm that solves inference tasks in probabilistic graphical models.
Motivation
● Inference is difficult for probabilistic graphical models.
● Message passing algorithms, such as belief propagation, struggle when the graph contains loops
○ Loopy belief propagation: convergence is not guaranteed.
-
Why GNNs
● Essentially an extension of recurrent neural networks (RNNs) to graph inputs.
● Central idea is to update hidden states at each node iteratively, by aggregating incoming messages.
● Have a structure similar to a message passing algorithm.
-
Factor graph and belief propagation
● Recall that the distribution of a factor graph is
○ p(x) = (1/Z) ∏_α f_α(x_α)
● Recall the formulas of the belief propagation algorithm
○ variable to factor: μ_{i→α}(x_i) = ∏_{β∈N(i)∖α} μ_{β→i}(x_i)
○ factor to variable: μ_{α→i}(x_i) = ∑_{x_α∖x_i} f_α(x_α) ∏_{j∈N(α)∖i} μ_{j→α}(x_j)
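As a concrete illustration of these formulas, here is a minimal sketch of sum-product message passing on a three-variable chain (a tree, where BP is exact). The factor values are made up; brute-force enumeration confirms the belief matches the true marginal.

```python
import itertools

# Pairwise factors f(x_i, x_j) on a chain x1 - x2 - x3, binary states {0, 1}.
# Factor values are made up for illustration.
f12 = {(0, 0): 1.0, (0, 1): 0.5, (1, 0): 0.5, (1, 1): 2.0}
f23 = {(0, 0): 2.0, (0, 1): 1.0, (1, 0): 1.0, (1, 1): 0.5}

# Factor-to-variable messages into x2 (the leaf variables send uniform messages).
# mu_{f12->2}(x2) = sum_{x1} f12(x1, x2)
mu_1to2 = {x2: sum(f12[(x1, x2)] for x1 in (0, 1)) for x2 in (0, 1)}
# mu_{f23->2}(x2) = sum_{x3} f23(x2, x3)
mu_3to2 = {x2: sum(f23[(x2, x3)] for x3 in (0, 1)) for x2 in (0, 1)}

# Belief at x2 is the normalized product of incoming messages.
b2 = {x2: mu_1to2[x2] * mu_3to2[x2] for x2 in (0, 1)}
Z = sum(b2.values())
b2 = {k: v / Z for k, v in b2.items()}

# Brute-force marginal for comparison.
joint = {(x1, x2, x3): f12[(x1, x2)] * f23[(x2, x3)]
         for x1, x2, x3 in itertools.product((0, 1), repeat=3)}
Zj = sum(joint.values())
p2 = {x2: sum(v for (a, m, c), v in joint.items() if m == x2) / Zj
      for x2 in (0, 1)}
```

On a tree like this chain the belief equals the exact marginal; on loopy graphs the same updates give only an approximation, which motivates the learned alternative.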
-
BP to GNNs: mapping the messages
● BP is recursive and graph-based. Naturally, we could map the messages to GNN nodes, and use neural networks to describe the nonlinear updates.
-
BP to GNNs: mapping the variable nodes
-
BP to GNNs: mapping the variable nodes
Marginal probability of x_i in an MRF: p(x_i) ∝ ∏_{α∈N(i)} μ_{α→i}(x_i)
Marginal joint probability of x_α in a factor graph: b_α(x_α) ∝ f_α(x_α) ∏_{i∈N(α)} μ_{i→α}(x_i)
● All of the messages depend on only one variable node at a time
● The nonlinear functions between GNN nodes can account for these products AFTER equilibrium is reached.
-
Preliminaries for model
● Binary MRF, aka Ising model: p(x) ∝ exp(∑_i b_i x_i + ∑_{(i,j)} J_ij x_i x_j), with x_i ∈ {−1, +1}
● Couplings J_ij and biases b_i are specified randomly, and are provided as input for GNN inference.
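On tiny graphs, the ground-truth marginals for such a model can be computed by brute-force enumeration. A sketch with hypothetical random J and b (not necessarily the paper's exact sampling scheme):

```python
import itertools, math, random

random.seed(0)
n = 4  # tiny graph so exact enumeration is feasible
# Hypothetical random biases b_i and couplings J_ij on a fully connected graph.
b = [random.gauss(0, 1) for _ in range(n)]
J = {(i, j): random.gauss(0, 1) for i in range(n) for j in range(i + 1, n)}

def unnorm(x):
    """Unnormalized probability of a spin configuration x in {-1, +1}^n."""
    e = sum(b[i] * x[i] for i in range(n))
    e += sum(Jij * x[i] * x[j] for (i, j), Jij in J.items())
    return math.exp(e)

states = list(itertools.product((-1, 1), repeat=n))
Z = sum(unnorm(x) for x in states)  # partition function by enumeration
# Exact marginals p(x_i = +1): the ground truth a GNN could be trained on.
marg = [sum(unnorm(x) for x in states if x[i] == 1) / Z for i in range(n)]
```

Enumeration costs 2^n, which is exactly why scalable approximate inference (BP, or a learned GNN) is needed for larger graphs.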
-
GNN Recap
Update the state embedding h_v of node v based on
- the feature x_v of v
- the features x_co[v] of the edges of v
- the state embeddings h_ne[v] of the neighbors of v
- the features x_ne[v] of the neighbors of v
h_v = f(x_v, x_co[v], h_ne[v], x_ne[v])
Local output function: o_v = g(h_v, x_v)
-
GNN Recap (Cont.)
Scarselli, Franco, et al. "The graph neural network model."
Decompose the state update function into a sum of per-edge terms:
h_v = ∑_{u∈ne[v]} f(x_v, x_(v,u), h_u, x_u)
-
Message Passing Neural Networks
An abstraction of several GNN variants
Phase 1: Message Passing
Define the message from i to j at time t+1 as: m_{ij}^{t+1} = M(h_i^t, h_j^t, e_{ij})
Step 1: Aggregate all incoming messages into a single message at the destination node: m_j^{t+1} = ∑_{i∈N(j)} m_{ij}^{t+1}
Step 2: Update the hidden state based on the current hidden state and the aggregated message: h_j^{t+1} = U(h_j^t, m_j^{t+1})
-
Message Passing Neural Networks (Cont.)
Phase 2: Readout Phase: ŷ = R({h_v^T : v ∈ G})
The message function, node update function, and readout function can have different settings.
MPNN can generalize several different models.
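The two phases can be sketched with placeholder linear message and update functions; the weights below are random stand-ins for learned parameters, and the graph is a toy triangle.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 3, 8                         # 3 nodes, hidden-state dimension 8
edges = [(0, 1), (1, 2), (2, 0)]
edges += [(j, i) for i, j in edges]  # treat edges as bidirectional

# Random stand-ins for the learned message, update, and readout functions.
W_msg = rng.normal(size=(2 * d, d)) * 0.1
W_upd = rng.normal(size=(2 * d, d)) * 0.1
W_out = rng.normal(size=(d, 1)) * 0.1

h = rng.normal(size=(n, d))  # initial hidden states

# Phase 1: message passing for a few rounds.
for t in range(3):
    m = np.zeros((n, d))
    for i, j in edges:
        # Message from i to j depends on both endpoint states (Step 1: sum).
        m[j] += np.tanh(np.concatenate([h[i], h[j]]) @ W_msg)
    # Step 2: update each hidden state from its state and aggregated message.
    h = np.tanh(np.concatenate([h, m], axis=1) @ W_upd)

# Phase 2: readout, here a per-node output squashed to (0, 1).
out = 1 / (1 + np.exp(-(h @ W_out)))
```

Swapping in different M, U, and R functions recovers the specific variants (GG-NN, interaction networks, etc.) that MPNN abstracts.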
-
GG-NN (Gated Graph Neural Network)
Source: Zhou, Jie, et al. "Graph neural networks: A review of methods and applications."
Gated Recurrent Units (GRU)
-
GG-NN (Cont.)
Readout Phase: ŷ = ∑_v σ(i(h_v^T, x_v)) ⊙ tanh(j(h_v^T, x_v)), where i and j are neural networks and σ gates each node's contribution.
-
GG-NN (Cont.)
Gated Recurrent Units (GRU)
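In GG-NN, the node update function is a GRU that combines a node's state with its aggregated incoming message. A minimal numpy sketch of one such step, with random stand-ins for the learned, shared weights:

```python
import numpy as np

rng = np.random.default_rng(1)
d = 4
# Random stand-ins for learned GRU weights, shared across all nodes.
Wz, Uz = rng.normal(size=(d, d)), rng.normal(size=(d, d))
Wr, Ur = rng.normal(size=(d, d)), rng.normal(size=(d, d))
Wh, Uh = rng.normal(size=(d, d)), rng.normal(size=(d, d))

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def gru_update(h, m):
    """One GRU step: new hidden state from old state h and aggregated message m."""
    z = sigmoid(m @ Wz + h @ Uz)              # update gate
    r = sigmoid(m @ Wr + h @ Ur)              # reset gate
    h_tilde = np.tanh(m @ Wh + (r * h) @ Uh)  # candidate state
    return (1 - z) * h + z * h_tilde          # gated interpolation

h = rng.normal(size=d)   # current node state
m = rng.normal(size=d)   # aggregated incoming message
h_new = gru_update(h, m)
```

The gating lets the node keep or overwrite its state depending on the incoming message, which helps propagate information over many rounds without it washing out.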
-
Two mappings between Factor graph and GNN
message-GNN and node-GNN perform similarly, and much better than belief propagation
-
Mapping I: Message-GNN
● Mapping (graphical model → GNN):
○ Message μ_ij between node i and j → GNN node v
○ Message nodes ij and jk → GNN nodes v and w connected
● Motivation: conforms closely to the structure of conventional belief propagation, and reflects how messages depend on each other.
-
Mapping I: Message-GNN
1. If connected, the message from node v to node w is computed by a multi-layer perceptron with ReLU activation functions
2. Each GNN node then updates its hidden state with a neural network (GRU)
3. Readout function to extract marginals or MAP:
a. First aggregate all GNN nodes with the same target (nodes in the graphical model) by summation
b. Then apply a shared readout function (another MLP with sigmoid activation function)
-
Mapping II: Node-GNN
● Mapping (graphical model → GNN): variable nodes → GNN nodes
1. Message function: m_{ij}^{t+1} = M(h_i^t, h_j^t, e_{ij})
2. Aggregate messages: m_j^{t+1} = ∑_{i∈N(j)} m_{ij}^{t+1}
3. Node update function: h_j^{t+1} = U(h_j^t, m_j^{t+1})
4. Readout is generated directly from hidden states: p̂_j = r(h_j^T)
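The four steps can be sketched end-to-end on a toy graph, with the couplings J_ij entering the message function as edge features. All weights, sizes, and coupling values below are illustrative stand-ins, not the paper's trained model.

```python
import numpy as np

rng = np.random.default_rng(2)
n, d = 4, 6
# A toy 4-node cycle with couplings J_ij as edge features (made-up values).
edges = [(0, 1), (1, 2), (2, 3), (3, 0)]
J = {e: float(rng.normal()) for e in edges}
J.update({(j, i): J[(i, j)] for i, j in edges})  # symmetric couplings

# Random stand-ins for learned weights; the message MLP sees [h_src, h_dst, J_ij].
W_msg = rng.normal(size=(2 * d + 1, d)) * 0.1
W_upd = rng.normal(size=(2 * d, d)) * 0.1
W_read = rng.normal(size=(d, 1)) * 0.1

h = rng.normal(size=(n, d))  # one hidden state per variable node
for t in range(5):
    m = np.zeros((n, d))
    for (i, j), Jij in J.items():
        # Steps 1-2: ReLU message with the coupling as input, aggregated by sum.
        m[j] += np.maximum(0, np.concatenate([h[i], h[j], [Jij]]) @ W_msg)
    # Step 3: node update from current state and aggregated message.
    h = np.tanh(np.concatenate([h, m], axis=1) @ W_upd)

# Step 4: readout directly from hidden states -> estimated marginals p(x_i = 1).
marginals = (1 / (1 + np.exp(-(h @ W_read)))).ravel()
```

Training would backpropagate through all unrolled rounds so that these readouts match the true marginals.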
-
Message-GNN and Node-GNN
● Objective: backpropagation to minimize the total cross-entropy loss function
L = −∑_i ∑_{x_i} p_i(x_i) log q_i(x_i)
p --- ground truth, q --- estimated result
Message Passing Function (General):
● Receives external inputs about the couplings on edges
● Depends on the hidden states of the source and destination nodes at the previous time step.
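For binary variables the objective reduces to a sum of per-node binary cross-entropies. A sketch with made-up marginals:

```python
import math

# Ground-truth marginals p and estimated marginals q for three binary
# variables (made-up numbers for illustration).
p = [0.9, 0.2, 0.6]   # p_i(x_i = 1), ground truth
q = [0.8, 0.3, 0.5]   # q_i(x_i = 1), model estimate

# Total cross-entropy over nodes, summing both states of each variable.
loss = -sum(pi * math.log(qi) + (1 - pi) * math.log(1 - qi)
            for pi, qi in zip(p, q))

# Cross-entropy is minimized when q matches p exactly.
loss_at_truth = -sum(pi * math.log(pi) + (1 - pi) * math.log(1 - pi)
                     for pi in p)
```

The gap between `loss` and `loss_at_truth` (the entropy of p) is the KL divergence the training actually drives down.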
-
Experiments
● In each experiment, two types of GNNs are tested:
○ Variable nodes (node-GNN)
○ Message nodes (msg-GNN)
● Examine generalization of the model when...
○ Testing on unseen graphs of the same structure
○ Testing on completely random graphs
○ Testing on graphs with the same size
○ Testing on graphs with larger size
● Analyze performance in estimating both marginal probabilities and the MAP state
-
Training Graphs
-
Larger, Novel Test Graphs
-
Marginal Inference Accuracy
-
Random Graphs
-
Generalization Performance on Random Graphs
-
Convergence of Inference Dynamics
-
MAP Estimation
-
Conclusion
● Experiments showed that GNNs provide a flexible learning method for inference in probabilistic graphical models
● Showed that learned representations and nonlinear transformations on edges generalize to larger graphs with different structures
● Examined two possible representations of graphical models within GNNs: variable nodes and message nodes
● Experimental results support GNNs as a promising framework for solving hard inference problems
● Future work: train and test on larger and more diverse graphs, as well as broader classes of graphical models
-
References
1. Zhou, Jie, et al. "Graph neural networks: A review of methods and applications." arXiv preprint arXiv:1812.08434 (2018).
2. Gilmer, Justin, et al. "Neural message passing for quantum chemistry." Proceedings of the 34th International Conference on Machine Learning, Volume 70. JMLR.org, 2017.
3. Scarselli, Franco, et al. "The graph neural network model." IEEE Transactions on Neural Networks 20.1 (2008): 61-80.
4. Li, Yujia, et al. "Gated graph sequence neural networks." arXiv preprint arXiv:1511.05493 (2015).
5. Wu, Zonghan, et al. "A comprehensive survey on graph neural networks." arXiv preprint arXiv:1901.00596 (2019).
-
Homework
1. Where do GNNs outperform belief propagation? Where does belief propagation outperform GNNs?
2. Given the following factor graph, draw the GNN using the Message-GNN mapping: