JAX MD: A Framework for Differentiable Physics

Sam Schoenholz, Dogus Cubuk

Spotlight presentation: Orals & Spotlights Track 23: Graph/Meta Learning/Software
on 2020-12-09T20:00:00-08:00 - 2020-12-09T20:10:00-08:00
Poster Session 5
on 2020-12-09T21:00:00-08:00 - 2020-12-09T23:00:00-08:00
GatherTown: Application ( Town D0 - Spot A3 )
Abstract: We introduce JAX MD, a software package for performing differentiable physics simulations with a focus on molecular dynamics. JAX MD includes a number of statistical physics simulation environments as well as interaction potentials and neural networks that can be integrated into these environments without writing any additional code. Since the simulations themselves are differentiable functions, entire trajectories can be differentiated to perform meta-optimization. These features are built on primitive operations, such as spatial partitioning, that allow simulations to scale to hundreds of thousands of particles on a single GPU. These primitives are flexible enough that they can be used to scale up workloads outside of molecular dynamics. We present several examples that highlight the features of JAX MD, including integration of graph neural networks into traditional simulations, meta-optimization through minimization of particle packings, and a multi-agent flocking simulation. JAX MD is available at www.github.com/google/jax-md.
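
To make the idea of differentiating through an entire trajectory concrete, here is a minimal sketch written in plain JAX. It does not use the JAX MD API. The names `soft_sphere_energy`, `minimize`, and `final_energy`, the harmonic soft-sphere potential, the free-space (non-periodic) geometry, and all parameter values are assumptions chosen for illustration. The sketch relaxes a small 2D packing by gradient descent and then differentiates the relaxed energy with respect to the particle diameter, which is the kind of quantity a meta-optimization loop would consume.

```python
import jax
import jax.numpy as jnp


def soft_sphere_energy(positions, diameter):
    """Harmonic soft-sphere energy of a packing in free space (illustrative)."""
    n = positions.shape[0]
    # All pairwise displacement vectors, shape (n, n, dim).
    dr = positions[:, None, :] - positions[None, :, :]
    dist2 = jnp.sum(dr ** 2, axis=-1)
    # Add 1 on the diagonal so sqrt stays differentiable for self-pairs;
    # those entries are masked out of the sum below.
    dist = jnp.sqrt(dist2 + jnp.eye(n))
    overlap = jnp.maximum(1.0 - dist / diameter, 0.0)
    mask = 1.0 - jnp.eye(n)
    # 0.5 from the harmonic form, another 0.5 because each pair is counted twice.
    return 0.25 * jnp.sum(mask * overlap ** 2)


def minimize(positions, diameter, steps=50, step_size=0.1):
    """Plain gradient descent on the energy, unrolled with lax.scan."""
    grad_energy = jax.grad(soft_sphere_energy)  # gradient w.r.t. positions

    def step(pos, _):
        return pos - step_size * grad_energy(pos, diameter), None

    final_positions, _ = jax.lax.scan(step, positions, None, length=steps)
    return final_positions


def final_energy(diameter, positions):
    """Energy after relaxation, viewed as a function of the diameter."""
    relaxed = minimize(positions, diameter)
    return soft_sphere_energy(relaxed, diameter)


# 16 particles placed uniformly at random in a 3 x 3 region (arbitrary choice).
positions = jax.random.uniform(jax.random.PRNGKey(0), (16, 2)) * 3.0

# Differentiate through all 50 minimization steps with respect to the diameter.
energy, d_energy = jax.value_and_grad(final_energy)(1.0, positions)
print("relaxed energy:", energy)
print("d(relaxed energy)/d(diameter):", d_energy)
```

Because the whole minimization is an ordinary JAX function, `jax.value_and_grad` backpropagates through every descent step. The resulting derivative could then drive an outer optimizer over the diameter or any other simulation parameter.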
