
Workshop: Program Transformations for ML

Skye Wanderman-Milne - JAX: accelerated machine-learning research via composable function transformations in Python

Skye Wanderman-Milne


JAX is a system for high-performance machine learning research. It offers the familiarity of Python+NumPy together with hardware acceleration, and it enables the definition and composition of user-wielded function transformations useful for machine learning programs. These transformations include automatic differentiation (grad), automatic batching (vmap), end-to-end compilation via XLA (jit), parallelization over multiple accelerators (pmap), and more. Composing these transformations is the key to JAX's power and simplicity.
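To make the idea of composition concrete, here is a minimal sketch that stacks three of the transformations named above: jax.grad differentiates a loss, jax.vmap batches that gradient over examples, and jax.jit compiles the result with XLA. The squared-error loss, the names loss, w, xs and ys, and the data values are our own illustrative assumptions, not taken from the talk. pmap is left out because it needs more than one device.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy loss: squared error of a linear model on ONE example.
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return (pred - y) ** 2

# grad: derivative of loss with respect to its first argument, w.
grad_loss = jax.grad(loss)

# vmap: apply grad_loss to a batch of (x, y) pairs while sharing w
# (in_axes=None means "do not map over w"; 0 maps over the leading axis).
per_example_grads = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# jit: compile the composed function end-to-end with XLA.
fast_per_example_grads = jax.jit(per_example_grads)

# Illustrative data: 3 examples with 2 features each.
w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([0.0, 0.0, 0.0])

g = fast_per_example_grads(w, xs, ys)
print(g.shape)  # (3, 2): one gradient vector per example
```

Because each transformation takes a function and returns a new function, they nest freely; per-example gradients, which are awkward to write by hand, fall out of a single vmap over grad.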
