Building Models that Learn to Discover Structure and Relations

Thomas Kipf, Ethan Fetaya, Jackson Wang, Max Welling, Rich Zemel (ML Research Partners from University of Amsterdam)

Some argue that a key component of human intelligence is our ability to reason about objects and their relations (e.g. [1,2]). This enables us, for example, to build rich compositional models of physics (how objects or particles interact) and intuitive theories of causation (what causes what) [3].

For artificial systems, these tasks remain a challenge. Most sophisticated pattern recognition models, e.g. those based on Convolutional Neural Networks (CNNs) or Recurrent Neural Networks (RNNs), lack a relational inductive bias [4], which impedes their ability to generalize well on problems with inherent compositional structure.

In our recent ICML 2018 paper, Neural Relational Inference for Interacting Systems, we explore a class of models called Graph Neural Networks (GNNs) that reflect the inherent structure of the problem domain in their architecture¹. This enables variants of GNNs, for example, to learn to predict the physical dynamics of an interacting system (e.g. billiard balls on a table) [5] or to reason about relations between objects in an image [6].
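To make the idea of reflecting graph structure in the architecture concrete, here is a minimal NumPy sketch of one round of message passing, the core operation in a GNN: each node aggregates transformed features from its neighbors along the edges of the graph and then updates its own representation. The function name, weight matrices, and toy graph are hypothetical illustrations, not the architecture from the paper.

```python
import numpy as np

def message_passing(node_feats, edges, w_msg, w_upd):
    """One round of message passing: each node sums transformed
    features from its in-neighbors, then updates its own state."""
    messages = np.zeros_like(node_feats)
    for src, dst in edges:                   # directed edge src -> dst
        messages[dst] += node_feats[src] @ w_msg
    # update: combine a node's own features with aggregated messages
    return np.tanh(node_feats @ w_upd + messages)

rng = np.random.default_rng(0)
feats = rng.normal(size=(4, 8))              # 4 nodes, 8-dim features
edges = [(0, 1), (1, 2), (2, 3), (3, 0)]     # a directed cycle
w_msg = rng.normal(size=(8, 8)) * 0.1
w_upd = rng.normal(size=(8, 8)) * 0.1
out = message_passing(feats, edges, w_msg, w_upd)
```

Because information only flows along edges, the model's computation mirrors the connectivity of the problem, which is exactly the relational inductive bias discussed above.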

Physical simulation of particles coupled by invisible springs. The connectivity is given by a hidden interaction graph. Unconnected particles do not interact.

In our work, we investigate whether this class of models is also capable of recognizing the underlying structure and types of relations in the data we observe in a completely unsupervised way (i.e. without ever showing it the ground truth relations or interactions).

To illustrate this task, let us consider a more concrete setting: we are looking at a physical simulation of balls rolling on a 2D surface (see Figure above), where some of the balls are connected by (invisible) springs that create an attractive force. If we knew the position and velocity of every ball, as well as their connectivity structure, we could predict where they are going to move next.
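As a sketch of what "knowing the structure lets you predict the dynamics" means, here is a toy Euler-integration step for 2D particles coupled by ideal zero-rest-length springs. The adjacency matrix plays the role of the hidden interaction graph; the spring constant, time step, and unit masses are illustrative assumptions, not the simulation parameters from the paper.

```python
import numpy as np

def step(pos, vel, adj, k=0.1, dt=0.01):
    """One Euler step: connected pairs (adj[i, j] == 1) attract with a
    linear spring force k * displacement; unconnected pairs exert none."""
    n = pos.shape[0]
    force = np.zeros_like(pos)
    for i in range(n):
        for j in range(n):
            if adj[i, j]:                    # spring between i and j
                force[i] += k * (pos[j] - pos[i])
    vel = vel + dt * force                   # unit mass assumed
    pos = pos + dt * vel
    return pos, vel

# three balls; only balls 0 and 1 are connected by a spring
adj = np.array([[0, 1, 0],
                [1, 0, 0],
                [0, 0, 0]])
pos = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
vel = np.zeros((3, 2))
for _ in range(100):
    pos, vel = step(pos, vel, adj)
```

Given the graph, rolling the dynamics forward is straightforward; the hard part, which the next paragraphs turn to, is that the graph itself is hidden.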

Can you guess which balls are connected by springs? The solution appears after a few seconds.

Without knowing this latent interaction graph, predicting the system’s dynamics can be quite difficult. Similarly, inferring which balls are connected by springs is itself a challenging task. Have a look at the video to the left (or above, in case you’re viewing this on your phone) and see whether you can correctly guess the interaction graph. You will notice that once the interaction structure is shown, the dynamics suddenly become much easier and more intuitive to understand.

In our work, we task GNNs with simultaneously inferring this latent interaction structure and predicting the dynamics of the interacting system. After being shown a wide range of simulations, the model can recognize these hidden relations and gives accurate predictions in 99.9% of the cases², interestingly without ever seeing a ground truth interaction graph during training.

Our Neural Relational Inference (NRI) model can be seen as an auto-encoder where the task of the encoder is to create a hypothesis about how the system interacts and the decoder learns a dynamical model of the interacting system constrained by the encoder’s “interaction hypothesis”. We frame this as a probabilistic model where the latent code corresponds to a distribution over relation types between objects (see Figure below).
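A key technical ingredient for such a discrete latent code is keeping the sampled relation types differentiable so the encoder can be trained end-to-end; the paper uses a Gumbel-softmax (concrete) relaxation for this. Below is a minimal NumPy sketch of that relaxation applied to hypothetical encoder logits; the shapes and the two edge types ("no interaction" vs. "spring") are illustrative assumptions.

```python
import numpy as np

def gumbel_softmax(logits, tau=0.5, rng=None):
    """Continuous relaxation of a categorical sample over edge types:
    add Gumbel noise to the logits, then take a tempered softmax."""
    if rng is None:
        rng = np.random.default_rng()
    u = rng.uniform(size=logits.shape)
    g = -np.log(-np.log(u))                    # Gumbel(0, 1) noise
    y = (logits + g) / tau                     # lower tau -> more discrete
    e = np.exp(y - y.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# hypothetical encoder output: logits over 2 edge types for each of
# the 6 directed pairs in a 3-particle system
rng = np.random.default_rng(0)
edge_logits = rng.normal(size=(6, 2))
edge_sample = gumbel_softmax(edge_logits, rng=rng)
```

Each row of `edge_sample` is a soft one-hot vector over relation types for one particle pair; the decoder can then condition its dynamics prediction on these (nearly discrete) edge assignments while gradients still flow back to the encoder.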