Fast Finite Width Neural Tangent Kernel
Roman Novak · Jascha Sohl-Dickstein · Samuel Schoenholz
Event URL: https://openreview.net/forum?id=_hB2LN5i_ae
The Neural Tangent Kernel (NTK), defined as the outer product of the neural network (NN) Jacobians, $\Theta_\theta(x_1, x_2) = \left[\partial f(\theta, x_1)\big/\partial \theta\right] \left[\partial f(\theta, x_2)\big/\partial \theta\right]^T$, has emerged as a central object of study in deep learning. In the infinite-width limit, the NTK can sometimes be computed analytically and is useful for understanding training and generalization of NN architectures. At finite widths, the NTK is also used to better initialize NNs, compare the conditioning across models, perform architecture search, and do meta-learning. Unfortunately, the finite-width NTK is notoriously expensive to compute, which severely limits its practical utility. We perform the first in-depth analysis of the compute and memory requirements for NTK computation in finite-width networks. Leveraging the structure of neural networks, we further propose two novel algorithms that change the exponent of the compute and memory requirements of the finite-width NTK, dramatically improving efficiency. We open-source (https://github.com/iclr2022anon/fast_finite_width_ntk) our two algorithms as general-purpose JAX function transformations that apply to any differentiable computation (convolutions, attention, recurrence, etc.) and introduce no new hyper-parameters.
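To make the definition in the abstract concrete, the following is a minimal JAX sketch of the direct (naive) way to evaluate $\Theta_\theta(x_1, x_2)$: form both Jacobians explicitly and contract them over the parameter axis. It is not one of the two algorithms the paper proposes; it only illustrates the baseline whose cost the paper analyzes. The two-layer MLP `f`, its initialization in `init_params`, and the helper names `flat_jacobian` and `ntk` are hypothetical choices made for illustration and do not come from the paper or its repository.

```python
import jax
import jax.numpy as jnp


def init_params(key, d_in=3, width=16, d_out=2):
    """Hypothetical two-layer MLP parameters (illustration only)."""
    k1, k2 = jax.random.split(key)
    return {
        "W1": jax.random.normal(k1, (d_in, width)) / jnp.sqrt(d_in),
        "W2": jax.random.normal(k2, (width, d_out)) / jnp.sqrt(width),
    }


def f(params, x):
    """Network output f(theta, x) for a single input x of shape (d_in,)."""
    return jnp.tanh(x @ params["W1"]) @ params["W2"]


def flat_jacobian(params, x):
    """Jacobian of f w.r.t. all parameters, flattened to (d_out, n_params)."""
    # jax.jacobian differentiates w.r.t. the first argument (params) and
    # returns a pytree whose leaves have shape (d_out, *param_shape).
    jac = jax.jacobian(f)(params, x)
    leaves = jax.tree_util.tree_leaves(jac)
    return jnp.concatenate(
        [leaf.reshape(leaf.shape[0], -1) for leaf in leaves], axis=1
    )


def ntk(params, x1, x2):
    """Naive finite-width NTK: contract the two Jacobians over parameters."""
    j1 = flat_jacobian(params, x1)  # (d_out, n_params)
    j2 = flat_jacobian(params, x2)  # (d_out, n_params)
    return j1 @ j2.T                # (d_out, d_out)


params = init_params(jax.random.PRNGKey(0))
x1 = jnp.array([1.0, 0.5, -0.3])
x2 = jnp.array([0.2, -1.0, 0.7])

theta_12 = ntk(params, x1, x2)  # cross kernel between two inputs
theta_11 = ntk(params, x1, x1)  # symmetric, positive semi-definite block
```

In this sketch each Jacobian has one row per output and one column per parameter, so both are materialized in full before the contraction. That explicit materialization is what makes the direct approach costly for larger networks and batches.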
Author Information
Roman Novak (Google Brain)
Jascha Sohl-Dickstein (Google)
Samuel Schoenholz (Google Brain)
More from the Same Authors
-
2020 : End-to-End Differentiability and Tensor Processing Unit Computing to Accelerate Materials’ Inverse Design »
HAN LIU · Yuhan Liu · Zhangji Zhao · Samuel Schoenholz · Ekin Dogus Cubuk · Mathieu Bauchy -
2022 Poster: Fast Neural Kernel Embeddings for General Activations »
Insu Han · Amir Zandieh · Jaehoon Lee · Roman Novak · Lechao Xiao · Amin Karbasi -
2021 Poster: Dataset Distillation with Infinitely Wide Convolutional Networks »
Timothy Nguyen · Roman Novak · Lechao Xiao · Jaehoon Lee -
2021 Poster: Reverse engineering learned optimizers reveals known and novel mechanisms »
Niru Maheswaranathan · David Sussillo · Luke Metz · Ruoxi Sun · Jascha Sohl-Dickstein -
2020 Poster: Finite Versus Infinite Neural Networks: an Empirical Study »
Jaehoon Lee · Samuel Schoenholz · Jeffrey Pennington · Ben Adlam · Lechao Xiao · Roman Novak · Jascha Sohl-Dickstein -
2020 Spotlight: Finite Versus Infinite Neural Networks: an Empirical Study »
Jaehoon Lee · Samuel Schoenholz · Jeffrey Pennington · Ben Adlam · Lechao Xiao · Roman Novak · Jascha Sohl-Dickstein -
2020 Poster: JAX MD: A Framework for Differentiable Physics »
Samuel Schoenholz · Ekin Dogus Cubuk -
2020 Spotlight: JAX MD: A Framework for Differentiable Physics »
Samuel Schoenholz · Ekin Dogus Cubuk -
2019 : Afternoon Coffee Break & Poster Session »
Heidi Komkov · Stanislav Fort · Zhaoyou Wang · Rose Yu · Ji Hwan Park · Samuel Schoenholz · Taoli Cheng · Ryan-Rhys Griffiths · Chase Shimmin · Surya Karthik Mukkavili · Philippe Schwaller · Christian Knoll · Yangzesheng Sun · Keiichi Kisamori · Gavin Graham · Gavin Portwood · Hsin-Yuan Huang · Paul Novello · Moritz Munchmeyer · Anna Jungbluth · Daniel Levine · Ibrahim Ayed · Steven Atkinson · Jan Hermann · Peter Grönquist · Priyabrata Saha · Yannik Glaser · Lingge Li · Yutaro Iiyama · Rushil Anirudh · Maciej Koch-Janusz · Vikram Sundar · Francois Lanusse · Auralee Edelen · Jonas Köhler · Jacky H. T. Yip · jiadong guo · Xiangyang Ju · Adi Hanuka · Adrian Albert · Valentina Salvatelli · Mauro Verzetti · Javier Duarte · Eric Moreno · Emmanuel de Bézenac · Athanasios Vlontzos · Alok Singh · Thomas Klijnsma · Brad Neuberg · Paul Wright · Mustafa Mustafa · David Schmidt · Steven Farrell · Hao Sun -
2019 : Lunch Break and Posters »
Xingyou Song · Elad Hoffer · Wei-Cheng Chang · Jeremy Cohen · Jyoti Islam · Yaniv Blumenfeld · Andreas Madsen · Jonathan Frankle · Sebastian Goldt · Satrajit Chatterjee · Abhishek Panigrahi · Alex Renda · Brian Bartoldson · Israel Birhane · Aristide Baratin · Niladri Chatterji · Roman Novak · Jessica Forde · YiDing Jiang · Yilun Du · Linara Adilova · Michael Kamp · Berry Weinstein · Itay Hubara · Tal Ben-Nun · Torsten Hoefler · Daniel Soudry · Hsiang-Fu Yu · Kai Zhong · Yiming Yang · Inderjit Dhillon · Jaime Carbonell · Yanqing Zhang · Dar Gilboa · Johannes Brandstetter · Alexander R Johansen · Gintare Karolina Dziugaite · Raghav Somani · Ari Morcos · Freddie Kalaitzis · Hanie Sedghi · Lechao Xiao · John Zech · Muqiao Yang · Simran Kaur · Qianli Ma · Yao-Hung Hubert Tsai · Ruslan Salakhutdinov · Sho Yaida · Zachary Lipton · Daniel Roy · Michael Carbin · Florent Krzakala · Lenka Zdeborová · Guy Gur-Ari · Ethan Dyer · Dilip Krishnan · Hossein Mobahi · Samy Bengio · Behnam Neyshabur · Praneeth Netrapalli · Kris Sankaran · Julien Cornebise · Yoshua Bengio · Vincent Michalski · Samira Ebrahimi Kahou · Md Rifat Arefin · Jiri Hron · Jaehoon Lee · Jascha Sohl-Dickstein · Samuel Schoenholz · David Schwab · Dongyu Li · Sang Keun Choe · Henning Petzka · Ashish Verma · Zhichao Lin · Cristian Sminchisescu -
2019 : JAX, M.D.: End-to-End Differentiable, Hardware Accelerated, Molecular Dynamics in Pure Python »
Samuel Schoenholz -
2019 Poster: Wide Neural Networks of Any Depth Evolve as Linear Models Under Gradient Descent »
Jaehoon Lee · Lechao Xiao · Samuel Schoenholz · Yasaman Bahri · Roman Novak · Jascha Sohl-Dickstein · Jeffrey Pennington -
2019 Poster: MetaInit: Initializing learning by learning to initialize »
Yann Dauphin · Samuel Schoenholz -
2018 : Poster Session 1 »
Stefan Gadatsch · Danil Kuzin · Navneet Kumar · Patrick Dallaire · Tom Ryder · Remus-Petru Pop · Nathan Hunt · Adam Kortylewski · Sophie Burkhardt · Mahmoud Elnaggar · Dieterich Lawson · Yifeng Li · Jongha (Jon) Ryu · Juhan Bae · Micha Livne · Tim Pearce · Mariia Vladimirova · Jason Ramapuram · Jiaming Zeng · Xinyu Hu · Jiawei He · Danielle Maddix · Arunesh Mittal · Albert Shaw · Tuan Anh Le · Alexander Sagel · Lisha Chen · Victor Gallego · Mahdi Karami · Zihao Zhang · Tal Kachman · Noah Weber · Matt Benatan · Kumar K Sricharan · Vincent Cartillier · Ivan Ovinnikov · Buu Phan · Mahmoud Hossam · Liu Ziyin · Valerii Kharitonov · Eugene Golikov · Qiang Zhang · Jae Myung Kim · Sebastian Farquhar · Jishnu Mukhoti · Xu Hu · Gregory Gundersen · Lavanya Sita Tekumalla · Paris Perdikaris · Ershad Banijamali · Siddhartha Jain · Ge Liu · Martin Gottwald · Katy Blumer · Sukmin Yun · Ranganath Krishnan · Roman Novak · Yilun Du · Yu Gong · Beliz Gokkaya · Jessica Ai · Daniel Duckworth · Johannes von Oswald · Christian Henning · Louis-Philippe Morency · Ali Ghodsi · Mahesh Subedar · Jean-Pascal Pfister · Rémi Lebret · Chao Ma · Aleksander Wieczorek · Laurence Perreault Levasseur -
2017 Poster: Mean Field Residual Networks: On the Edge of Chaos »
Ge Yang · Samuel Schoenholz -
2017 Poster: Resurrecting the sigmoid in deep learning through dynamical isometry: theory and practice »
Jeffrey Pennington · Samuel Schoenholz · Surya Ganguli -
2015 Workshop: Statistical Methods for Understanding Neural Systems »
Alyson Fletcher · Jakob H Macke · Ryan Adams · Jascha Sohl-Dickstein -
2012 Poster: Training sparse natural image models with a fast Gibbs sampler of an extended state space »
Lucas Theis · Jascha Sohl-Dickstein · Matthias Bethge