[1910.01075v1] Learning Neural Causal Models from Unknown Interventions
We believe that our approach of treating seemingly observational data as being derived from an environment with agents executing interventions could represent an important change in modelling perspective with deeper implications.

Abstract: Meta-learning over a set of distributions can be interpreted as learning
different types of parameters corresponding to short-term vs long-term aspects
of the mechanisms underlying the generation of data. These are respectively
captured by quickly-changing parameters and slowly-changing meta-parameters. We
present a new framework for meta-learning causal models where the relationship
between each variable and its parents is modeled by a neural network, modulated
by structural meta-parameters which capture the overall topology of a directed
graphical model. Our approach avoids a discrete search over models in favour of
a continuous optimization procedure. We study a setting where interventional
distributions are induced as a result of a random intervention on a single
unknown variable of an unknown ground truth causal model, and the observations
arising after such an intervention constitute one meta-example. To disentangle
the slow-changing aspects of each conditional from the fast-changing
adaptations to each intervention, we parametrize the neural network into fast
parameters and slow meta-parameters. We introduce a meta-learning objective
that favours solutions robust to frequent but sparse interventional
distribution changes, and which generalize well to previously unseen
interventions. Optimizing this objective is shown experimentally to recover the
structure of the causal graph.
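The core mechanism described above — structural meta-parameters γ whose sigmoid gives per-edge beliefs, from which a hard adjacency matrix is sampled and used to mask each variable's MLP inputs — can be illustrated with a minimal NumPy sketch. Variable names (`gamma`, `adj`, `masked_inputs`) and the 3-variable setup are illustrative assumptions, not the paper's actual code:

```python
import numpy as np

rng = np.random.default_rng(0)

M = 3                               # number of variables in a toy SCM
gamma = rng.normal(size=(M, M))     # structural meta-parameters (edge logits)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Sample a candidate adjacency matrix: adj[i, j] = 1 means
# "variable j is a parent of variable i" in this sampled configuration.
adj = (rng.random((M, M)) < sigmoid(gamma)).astype(float)
np.fill_diagonal(adj, 0.0)          # no self-loops

# Mask each variable's input so its conditional MLP sees only
# the values of its sampled parents; row i feeds variable i's MLP.
x = rng.normal(size=M)              # one joint observation of all variables
masked_inputs = adj * x             # shape (M, M): zeros where no edge
```

In the full model, gradients with respect to γ are estimated over many such sampled configurations, so the edge beliefs σ(γ) can be optimized continuously rather than searched over discretely.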

Figure 1: MLP model architecture for the M = 3, N = 2 (fork3) SCM. The model computes the conditional probabilities of Â, B̂, Ĉ given their parents using a stack of three independent MLPs. The MLP input layer uses an adjacency matrix sampled from Ber(σ(γ)) as an input mask, forcing the model to predict each node from its parent nodes only. (Causal Induction as an Optimization Problem)

Figure 2: Learned edges at three different stages of training. Left: chain graph with 4 variables. Right: fully-connected DAG with 4 variables. (Synthetic Datasets)

Figure 3: Earthquake: learned edges at three different stages of training. (Real-World Datasets: BnLearn)

Figure 4: Asia: learned edges at three different stages of training. (Real-World Datasets: BnLearn)

Figure 5: Left: cross-entropy (CE) between learned and ground-truth edge probabilities for all 3-variable SCMs. Middle: edge CE loss for the chain graph with 4–7 variables. Right: edge CE loss for 3-variable graphs with no dropout during pretraining, showing the importance of this dropout. (Real-World Datasets: BnLearn)

Figure 6: Ablation study on all possible 3-variable graphs. Both panels show the cross-entropy loss between learned and ground-truth edges over training time. Left: models that infer the intervention (prediction, bold) vs. models given the true intervention (ground truth, long dash) vs. models that use no knowledge of the intervention (no prediction, short dash); the result suggests inferring the intervention works almost as well as knowing the true intervention. Right: models trained with and without the DAG regularizer (L_DAG), showing that the regularizer helps convergence. (Real-World Datasets: BnLearn)

Figure 7: Left: every possible connected 3-variable DAG. Right: cross-entropy between learned and ground-truth edge probabilities for all 3-variable SCMs. (Synthetic data)

Figure 8: Ground-truth SCMs for Cancer (left), Earthquake (middle), and Asia (right). (BnLearn data repository)

Figure 9: Effect of sparsity on graphs with 5, 6, and 8 variables. (Effect of Sparsity)

Figure 10: Cross-entropy between learned and ground-truth edge probabilities for Cancer at varying temperatures. (Effect of Temperature)

Figure 11: Cross-entropy between learned and ground-truth edge probabilities. Left: the Earthquake dataset with 6 variables. Right: the Asia dataset with 8 variables. (Effect of Temperature)
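Figure 6 refers to a DAG regularizer (L_DAG) that encourages the learned edge beliefs toward acyclic configurations. The paper's exact penalty is not given here; the sketch below uses one common continuous acyclicity characterization from the NOTEARS/DAG-GNN line of work, h(A) = tr((I + A∘A/d)^d) − d, which is zero exactly when the (soft) adjacency matrix encodes a DAG. The function name `dag_penalty` and the two test matrices are illustrative assumptions:

```python
import numpy as np

def dag_penalty(soft_adj):
    """Continuous acyclicity penalty (NOTEARS/DAG-GNN-style polynomial form):
    h(A) = tr((I + (A ∘ A) / d)^d) - d.
    Evaluates to zero iff the weighted adjacency matrix encodes a DAG,
    and grows with the strength of any directed cycles."""
    d = soft_adj.shape[0]
    m = np.eye(d) + soft_adj * soft_adj / d   # elementwise square kills signs
    return np.trace(np.linalg.matrix_power(m, d)) - d

# An acyclic adjacency (single edge 0 -> 1) incurs no penalty;
# a 2-cycle (0 -> 1 and 1 -> 0) is penalized.
acyclic = np.array([[0.0, 1.0],
                    [0.0, 0.0]])
cyclic  = np.array([[0.0, 1.0],
                    [1.0, 0.0]])
```

Adding such a term to the meta-learning objective, weighted by a coefficient, nudges the edge-belief parameters γ toward configurations whose expected adjacency is acyclic, which is consistent with Figure 6's observation that the regularizer helps convergence.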