End to End Learning and Optimization on Graphs
Bryan Wilder¹, Eric Ewing², Bistra Dilkina², Milind Tambe¹
¹Harvard University
²University of Southern California
To appear at NeurIPS 2019
-
Data → ??? → Decisions
-
Data → $\arg\max_{x \in X} f(x, \theta)$ → Decisions
Standard two stage: predict then optimize
Training: maximize accuracy
Challenge: misalignment between “accuracy” and decision quality
-
Data → Decisions
Pure end to end
Training: maximize decision quality
Challenge: optimization is hard
-
Data → $\arg\max_{x \in X} f(x, \theta)$ → Decisions
Decision-focused learning: differentiable optimization during training
Training: maximize decision quality
Challenge: how to make optimization differentiable?
-
Relax + differentiate
Forward pass: run a solver
Backward pass: sensitivity analysis via KKT conditions
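A minimal, hedged sketch of this relax-and-differentiate pattern, written with the cvxpylayers library (an illustrative choice, not the solver layers used in the works cited on the next slide). A small quadratically smoothed decision problem over the simplex is solved in the forward pass, and gradients with respect to the predicted parameters flow back through the solver's optimality conditions. The problem, sizes, and names (`theta_hat`, `theta_true`) are assumptions for illustration.

```python
# Hedged sketch: differentiating through a small convex decision problem.
import cvxpy as cp
import torch
from cvxpylayers.torch import CvxpyLayer

n = 5
x = cp.Variable(n)
theta = cp.Parameter(n)  # predicted objective coefficients

# argmax_x theta^T x - 0.5*||x||^2  s.t.  x in the simplex (a smoothed LP)
objective = cp.Maximize(theta @ x - 0.5 * cp.sum_squares(x))
constraints = [x >= 0, cp.sum(x) == 1]
layer = CvxpyLayer(cp.Problem(objective, constraints), parameters=[theta], variables=[x])

theta_hat = torch.randn(n, requires_grad=True)   # stand-in for a model's prediction
(x_star,) = layer(theta_hat)                      # forward pass: run a solver
theta_true = torch.randn(n)                       # stand-in for the true parameters
decision_quality = theta_true @ x_star            # f(x*, theta_true)
decision_quality.backward()                       # backward pass: sensitivity analysis
print(theta_hat.grad)                             # gradient of decision quality w.r.t. prediction
```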
-
Relax + differentiate
• Convex QPs [Amos and Kolter 2018, Donti et al. 2018]
• Linear and submodular programs [Wilder, Dilkina, Tambe 2019]
• MAXSAT (via SDP relaxation) [Wang, Donti, Wilder, Kolter 2019]
• MIPs [Ferber, Wilder, Dilkina, Tambe 2019]
  • Monday @ 11am, Room 612
-
What’s wrong with relaxations?
• Some problems don’t have good ones
• Slow to solve the continuous optimization problem
• Slower to backprop through: $O(n^3)$
-
This work
• Alternative: include a solver for a simpler proxy problem
• Learn a representation that maps the hard problem to the simple one
• Instantiate this idea for a class of graph optimization problems
-
Graph learning + optimization
-
Problem classes
• Partition the nodes into K disjoint groups
  • Community detection, maxcut, …
• Select a subset of K nodes
  • Facility location, influence maximization, vaccination, …
• Methods of choice are often combinatorial/discrete
-
Approach
• Observation: grouping nodes into communities is a good heuristic
  • Partitioning: communities correspond to well-connected subgroups
  • Facility location: put one facility in each community
• Observation: graph learning approaches already embed nodes into $\mathbb{R}^n$
-
Approach
1. Start with a clustering algorithm (in $\mathbb{R}^n$)
   • Can (approximately) differentiate very quickly
2. Train embeddings (representation) to solve the particular problem
   • Automatically learning a good continuous relaxation!
-
ClusterNet Approach
Node embeddings (GCN) → K-means clustering → Locate 1 facility in each community
-
Differentiable K-means
Forward pass (repeat until the centers reach a fixed point):
• Softmax update to node assignments
• Update cluster centers
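A minimal sketch of this forward iteration in PyTorch (not the paper's exact implementation; the function name, inverse-temperature `beta`, and tensor shapes are illustrative assumptions):

```python
import torch

def soft_kmeans_step(x, mu, beta=10.0):
    """One forward iteration: soft assignments, then center update.

    x:  (n, d) node embeddings, mu: (K, d) current cluster centers.
    beta is an inverse temperature; larger values give harder assignments.
    """
    dist = torch.cdist(x, mu)               # (n, K) node-to-center distances
    r = torch.softmax(-beta * dist, dim=1)  # softmax update to node assignments
    # update cluster centers as assignment-weighted means of the embeddings
    mu_new = (r.t() @ x) / r.sum(dim=0, keepdim=True).t().clamp_min(1e-8)
    return r, mu_new

# usage: iterate to (approximate) convergence in the forward pass
x = torch.randn(100, 16)                     # toy embeddings
mu = x[torch.randperm(100)[:5]].clone()      # initialize 5 centers from the data
for _ in range(50):
    r, mu = soft_kmeans_step(x, mu)
```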
-
Differentiable K-means
Backward pass
• Option 1: differentiate through the fixed-point condition $\mu^t = \mu^{t+1}$
  • Prohibitively slow, memory-intensive
• Option 2: unroll the entire series of updates
  • Cost scales with the number of iterations
  • Have to stick to differentiable operations
• Option 3: get the solution, then unroll one update
  • Do anything to solve the forward pass
  • Linear time/memory, implemented in vanilla PyTorch
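A sketch of Option 3, reusing the hypothetical `soft_kmeans_step` helper from the earlier sketch: solve the forward pass with autograd switched off, then re-apply a single differentiable update so the backward pass only sees one step.

```python
import torch

def cluster(x, mu_init, iters=100, beta=10.0):
    """Differentiable clustering via the one-step-unroll trick (illustrative)."""
    mu = mu_init
    # Forward pass: solve however we like, with no gradient tracking.
    with torch.no_grad():
        for _ in range(iters):
            r, mu = soft_kmeans_step(x, mu, beta)
    # Backward pass: unroll exactly one update from the (detached) solution,
    # so gradients w.r.t. the embeddings x cost one step of time and memory.
    r, mu = soft_kmeans_step(x, mu.detach(), beta)
    return r, mu
```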
-
Differentiable K-means
Theorem [informal]: provided the clusters are sufficiently balanced and well-separated, the Option 3 approximate gradients converge exponentially quickly to the true ones.
Idea: show that this corresponds to approximating a particular term in the analytical fixed-point gradients.
-
ClusterNet Approach
GCN node embeddings → K-means clustering → Locate 1 facility in each community
Loss: quality of the facility assignment
Backward: differentiate through K-means → update GCN params
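Putting the pieces together, a minimal training-loop sketch under stated assumptions: `gcn`, `features`, `adj`, `K`, and `shortest_path_dist` are placeholders, `cluster`/`soft_kmeans_step` are the helpers sketched earlier, and `facility_loss` is an illustrative soft proxy for the facility-location objective, not the paper's exact loss (see the released code for that).

```python
import torch

def facility_loss(r, dist, beta=10.0):
    """Soft proxy for 'one facility per community' (illustrative assumption).

    r:    (n, K) soft assignments; dist: (n, n) node-to-node distances.
    Each cluster's 'facility' is an assignment-weighted mix of its nodes; we
    penalize a softened maximum over nodes of the soft-min facility distance.
    """
    fac_dist = dist @ r / r.sum(dim=0, keepdim=True).clamp_min(1e-8)   # (n, K)
    nearest = -torch.logsumexp(-beta * fac_dist, dim=1) / beta          # soft min over facilities
    return torch.logsumexp(beta * nearest, dim=0) / beta                # soft max over nodes

optimizer = torch.optim.Adam(gcn.parameters(), lr=1e-3)
for epoch in range(200):
    x = gcn(features, adj)                          # GCN node embeddings
    mu0 = x[torch.randperm(x.shape[0])[:K]].detach()
    r, mu = cluster(x, mu0)                         # differentiable K-means
    loss = facility_loss(r, shortest_path_dist)     # quality of the facility assignment
    optimizer.zero_grad()
    loss.backward()                                 # differentiate through K-means
    optimizer.step()                                # update GCN params
```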
-
Experiments
• Learning problem: link prediction
• Optimization: community detection and facility location problems
• Train GCNs as the predictive component
-
Example: community detection
Observe partial graph → Predict unseen edges → Find communities

$$\max_{r}\ \frac{1}{2m} \sum_{u,v \in V} \sum_{k=1}^{K} \left( A_{u,v} - \frac{d_u d_v}{2m} \right) r_{uk} r_{vk}$$
$$\text{s.t.}\quad r_{uk} \in \{0,1\} \quad \forall u \in V,\ k = 1, \dots, K$$
$$\sum_{k=1}^{K} r_{uk} = 1 \quad \forall u \in V$$
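Relaxing $r_{uk}$ to the soft assignments produced by the clustering layer turns this objective into a differentiable loss; a minimal sketch (the tensor names `r` and `adj` are assumptions):

```python
import torch

def modularity(r, adj):
    """Relaxed modularity: (1/2m) * sum_{u,v,k} (A_uv - d_u*d_v/(2m)) * r_uk * r_vk.

    r:   (n, K) soft community memberships (rows sum to 1)
    adj: (n, n) adjacency matrix
    """
    deg = adj.sum(dim=1)                       # node degrees d_u
    two_m = deg.sum()                          # 2m = total degree
    B = adj - torch.outer(deg, deg) / two_m    # modularity matrix
    return torch.trace(r.t() @ B @ r) / two_m

# training maximizes modularity, i.e. minimizes its negative:
# loss = -modularity(r, adj)
```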
-
Example: community detection
• Useful in scientific discovery (social groups, functional modules in biological networks)
• In applications, the two-stage approach is common
  [Yan & Gregory ’12, Burgess et al. ’16, Berlusconi et al. ’16, Tan et al. ’16, Bahulkar et al. ’18, …]
-
Experiments
• Learning problem: link prediction
• Optimization: community detection and facility location problems
• Train GCNs as the predictive component
• Comparison
  • Two stage: GCN + expert-designed algorithm (2Stage)
  • Pure end to end: deep GCN to predict the optimal solution (e2e)
-
Results: single-graph link prediction
[Figure: bar charts comparing ClusterNet, 2Stage, and e2e — community detection measured by modularity (higher is better) and facility location measured by max distance (lower is better).]
Representative example from cora, citeseer, protein interaction, facebook, and adolescent health networks.
-
Results: generalization across graphs
[Figure: bar charts comparing ClusterNet, 2Stage, and e2e — community detection measured by modularity (higher is better) and facility location measured by max distance (lower is better).]
ClusterNet learns generalizable strategies for optimization!
-
Takeaways
• Good decisions require integrating learning and optimization
• Pure end-to-end methods miss out on useful structure
• Even simple optimization primitives provide good inductive bias

NeurIPS ’19 paper: see bryanwilder.github.io
Code available at https://github.com/bwilder0/clusternet