[1912.02803] Neural Tangents: Fast and Easy Infinite Neural Networks in Python
By democratizing this previously challenging model family, we hope that researchers will begin to use infinite neural networks, in addition to their finite counterparts, when faced with a new problem domain, especially in cases that are data-limited.

NEURAL TANGENTS is a library designed to enable research into infinite-width neural networks. It provides a high-level API for specifying complex and hierarchical neural network architectures. These networks can then be trained and evaluated either at finite width as usual or in their infinite-width limit. Infinite-width networks can be trained analytically using exact Bayesian inference or using gradient descent via the Neural Tangent Kernel. Additionally, NEURAL TANGENTS provides tools to study gradient descent training dynamics of wide but finite networks in either function space or weight space. The entire library runs out-of-the-box on CPU, GPU, or TPU. All computations can be automatically distributed over multiple accelerators with near-linear scaling in the number of devices. NEURAL TANGENTS is available at www.github.com/google/neural-tangents. We also provide an accompanying interactive Colab notebook at colab.sandbox.google.com/github/google/neural-tangents/blob/master/notebooks/neural_tangents_cookbook.ipynb.
Figure 4: An example of the translation of a convolutional neural network into a sequence of kernel operations. We demonstrate how the compositional nature of a typical NN computation on its inputs induces a corresponding compositional computation on the NNGP and NT kernels. Presented is a 2-hidden-layer 1D CNN with nonlinearity φ, performing regression on the 10-dimensional outputs z² for each of the 4 inputs x from the dataset X. To declutter notation, unit weight and zero bias variances are assumed in all layers. Top: the recursive output (z²) computation in the CNN induces a corresponding recursive NNGP kernel (K̃² ⊗ I₁₀) computation (the NTK computation is similar; not shown). Bottom: explicit listing of tensor ops and corresponding kernel ops in each layer. See Table ?? for operation definitions. Illustration and description adapted from Figure 3 in ?. (Implementation: Transforming Tensor Ops to Kernel Ops)

Figure 5: Performance scaling with batch size (left) and number of GPUs (right). Shown is the time per entry needed to compute the analytic NNGP and NTK covariance matrices (using NEURAL TANGENTS) in a 21-layer ReLU network with global average pooling. Left: increasing the batch size when computing the covariance matrix in blocks yields a significant performance increase up to the threshold at which all cores in a single GPU are saturated. Simpler models are expected to scale better with batch size. Right: time per sample scales linearly with the number of GPUs, demonstrating near-perfect hardware utilization. (Performance)
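The layer-by-layer kernel computation sketched in Figure 4 can be illustrated in miniature for a fully connected ReLU network, where each layer's NNGP kernel update has a closed form (the arc-cosine kernel of Cho & Saul). The following is a minimal NumPy sketch under that assumption, not the library's implementation; the function names are ours.

```python
import numpy as np

def relu_expectation(k_xx, k_xy, k_yy):
    """Closed-form E[relu(u) relu(v)] for (u, v) ~ N(0, [[k_xx, k_xy], [k_xy, k_yy]])."""
    cos_t = np.clip(k_xy / np.sqrt(k_xx * k_yy), -1.0, 1.0)
    theta = np.arccos(cos_t)
    return np.sqrt(k_xx * k_yy) * (np.sin(theta) + (np.pi - theta) * cos_t) / (2 * np.pi)

def nngp_kernel(x, depth, sigma_w=np.sqrt(2.0), sigma_b=0.0):
    """Input kernel K⁰ = x xᵀ / d, then K^{l+1} = σ_w² E[φ(u) φ(v)] + σ_b² per layer."""
    k = x @ x.T / x.shape[1]
    for _ in range(depth):
        diag = np.diag(k).copy()
        k = sigma_w ** 2 * relu_expectation(diag[:, None], k, diag[None, :]) + sigma_b ** 2
    return k
```

With σ_w² = 2 (He scaling) the diagonal of the kernel is preserved from layer to layer, which is one easy way to sanity-check the recursion.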

Dense: h = (σ_ω / √n) W y + σ_b β
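As a concrete reading of this parameterization, here is a hedged NumPy sketch (a helper of our own, not the library's code) of an NTK-parameterized dense layer: the weights are scaled by σ_ω/√n so that pre-activations remain O(1) as the input width n grows.

```python
import numpy as np

def dense(W, beta, y, sigma_w=1.0, sigma_b=0.05):
    # NTK parameterization: h = (sigma_w / sqrt(n)) * W y + sigma_b * beta,
    # where n is the input width; the 1/sqrt(n) factor keeps pre-activations O(1).
    n = y.shape[0]
    return (sigma_w / np.sqrt(n)) * (W @ y) + sigma_b * beta
```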

Figure 6: Training a neural network and its various approximations using NEURAL TANGENTS. Presented is a 5-layer Erf neural network of width 512 trained on MNIST using SGD with momentum, along with its constant (0th-order), linear (1st-order), and quadratic (2nd-order) Taylor expansions about the initial parameters. As training progresses (left to right), lower-order expansions deviate from the original function faster than higher-order ones. (Infinite networks of any architecture through sampling)

Figure 7: Predictive negative log-likelihoods and condition numbers. Top: test negative log-likelihoods for the NNGP posterior and the Gaussian predictive distribution for the NTK at infinite training time on CIFAR-10 (test set of 2000 points). Fully connected (FC, Listing ??) and convolutional-without-pooling (CONV, Listing ??) models are selected based on train marginal negative log-likelihoods in Figure ??. Bottom: condition numbers for the covariance matrices corresponding to the NTK/NNGP, as well as the respective predictive covariances on the test set. Ill-conditioning of Wide Residual Network kernels due to pooling layers (?) could be the cause of numerical issues when evaluating the predictive NLL for these kernels. (Implemented and coming soon functionality)
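The Taylor expansions in Figure 6 can be reproduced in miniature. For a toy one-parameter model f(w) = tanh(w·x), the first-order (linearized) approximation about the initial parameter w₀ is f(w₀) + f′(w₀)(w − w₀); the deviation grows as w moves away from w₀, exactly the effect the figure shows. A hypothetical sketch (names ours, not the library's API):

```python
import numpy as np

def f(w, x):
    # Toy one-parameter "network".
    return np.tanh(w * x)

def f_lin(w, w0, x):
    # First-order Taylor expansion of f about the initial parameter w0.
    grad = x * (1.0 - np.tanh(w0 * x) ** 2)  # df/dw evaluated at w0
    return f(w0, x) + grad * (w - w0)
```

Near w₀ the linearization tracks f closely (error is second order in w − w₀); far from w₀ it drifts, which is why the lower-order expansions in Figure 6 deviate first.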