Equilibrium Graph Pooling

In graph-level prediction tasks, be it graph classification, graph regression, or something else, we usually apply some kind of graph pooling to aggregate node representations into a single vector. The pooling function has to be permutation-invariant, so there isn't much choice beyond the standard mean / max / sum / min / median. Fabian Fuchs asks in his new blog post: “Have we found the global optimum of how to do global aggregation or are we stuck in a local minimum?”

In the new work, the authors propose Equilibrium Aggregation for global graph pooling. The idea brings together two subfields of deep learning: Learning on Sets (you've probably heard about Janossy pooling, Deep Sets and Self-Attention) and Implicit layers (Equilibrium models and Neural ODEs, for example).

Equilibrium Aggregation returns the minimizer of an energy, y* = argmin_y E(X, y), where X = {x_1, ..., x_n} is the set of node representations and the energy E(X, y) is the sum of pairwise potentials F(x_i, y) over all nodes plus a regularizer term. The potential function is parameterized by a neural net and, for starters, might be implemented as a DeepSets-style MLP. By varying the potential function, you can also recover vanilla sum / max / mean / median pooling. Generally speaking, the idea of using DeepSets for aggregation can be traced back to GraphSAGE, but it didn't have much theoretical justification back then.

Experimentally, using equilibrium aggregation as the global pooling function (particularly with a GCN message-passing backbone) leads to significant improvements on MOL-PCBA and several graph-level toy tasks. So far, equilibrium aggregation does not bring much benefit as the message aggregation function inside a GNN layer, and it doesn't support edge features in global pooling, but those could be cool extensions and your next research project 😉 Check out Fabian's post for more details!
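
To make the "pooling as energy minimization" view a bit more concrete, here is a minimal NumPy sketch. It is not the authors' implementation: in the paper the potential is a learned neural net, while here the potentials are hand-written (an assumption made purely for illustration) so that the minimizer has a known closed form. The function name `equilibrium_aggregate`, the step count, and the learning rate are all hypothetical choices. The sketch minimizes E(X, y) = R(y) + sum_i F(x_i, y) by plain gradient descent and checks that a squared-distance potential gives mean pooling, an absolute-distance potential gives (approximately) median pooling, and a linear potential with a quadratic regularizer gives sum pooling.

```python
import numpy as np


def equilibrium_aggregate(x, grad_potential, grad_regularizer, steps=2000, lr=0.01):
    """Approximate argmin_y [R(y) + sum_i F(x_i, y)] with plain gradient descent.

    x: array of shape (n_nodes, dim) holding node representations.
    grad_potential(x, y): gradient of F w.r.t. y for every node, shape (n_nodes, dim).
    grad_regularizer(y): gradient of R w.r.t. y, shape (dim,).
    """
    y = np.zeros(x.shape[1])
    for _ in range(steps):
        grad = grad_regularizer(y) + grad_potential(x, y).sum(axis=0)
        y = y - lr * grad
    return y


# Toy "node embeddings": 5 nodes, 2 features each (made-up numbers).
x = np.array([[1.0, 10.0], [2.0, -3.0], [4.0, 0.5], [7.0, 2.0], [11.0, 6.0]])

no_regularizer = lambda y: np.zeros_like(y)  # R(y) = 0
quadratic_regularizer = lambda y: y          # R(y) = 0.5 * ||y||^2

# F(x, y) = 0.5 * (x - y)^2  ->  the minimizer is the mean.
mean_pool = equilibrium_aggregate(x, lambda x, y: y - x, no_regularizer)

# F(x, y) = |x - y|  ->  the minimizer is the median. This is subgradient
# descent with a fixed step, so y only gets to within about `lr` of it.
median_pool = equilibrium_aggregate(x, lambda x, y: np.sign(y - x), no_regularizer)

# F(x, y) = -x * y with R(y) = 0.5 * ||y||^2  ->  the minimizer is the sum.
sum_pool = equilibrium_aggregate(x, lambda x, y: -x, quadratic_regularizer)

print("mean  :", mean_pool, "vs", x.mean(axis=0))
print("median:", median_pool, "vs", np.median(x, axis=0))
print("sum   :", sum_pool, "vs", x.sum(axis=0))
```

Swapping the hand-written gradients for the gradient of a small neural net, and differentiating through (or around) the inner optimization loop, is what would turn this toy into a learnable pooling layer; how the paper handles that part is best read in Fabian's post.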