
Workshop: Mathematics of Modern Machine Learning (M3L)

Grokking modular arithmetic can be explained by margin maximization

Mohamad Amin Mohamadi · Zhiyuan Li · Lei Wu · Danica J. Sutherland

Abstract: We present a margin-based generalization theory explaining the “grokking” phenomenon (Power et al., 2022), in which a model generalizes long after overfitting to arithmetic datasets. Specifically, we study two-layer quadratic networks on mod-$p$ arithmetic problems and show that solutions with maximal margin, normalized by the $\ell_\infty$ norm, generalize with $\tilde O(p^{5/3})$ samples. To the best of our knowledge, this is the first sample complexity bound strictly better than the trivial $O(p^2)$ complexity for modular addition, whose input space has only $p^2$ distinct pairs. Empirically, we find that GD on unregularized $\ell_2$ or cross-entropy loss tends to maximize the margin. In contrast, we show that kernel-based models, such as networks that are well-approximated by their neural tangent kernel, need $\Omega(p^2)$ samples to achieve non-trivial $\ell_2$ loss. Our theory suggests that grokking might be caused by overfitting in the kernel regime of early training, followed by generalization as gradient descent eventually leaves the kernel regime and maximizes the normalized margin.
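The normalized margin mentioned in the abstract can be made concrete with a small sketch. Everything here is an assumption for illustration, not the authors' code. Inputs $(a, b)$ are encoded as concatenated one-hot vectors in $\mathbb{R}^{2p}$, each hidden unit computes $(u_j^\top x)^2$, and the logits are a linear readout $W$ of the hidden units. Because the output is homogeneous of degree 3 in the parameters (degree 2 in $U$, degree 1 in $W$), the sketch divides the minimum multiclass margin over all $p^2$ inputs by $\|\theta\|_\infty^3$; the paper's exact normalization may differ. The helper names (`encode`, `forward`, `normalized_margin`, `lookup_network`) are hypothetical.

```python
import itertools
import math
import random


def encode(a, b, p):
    """Hypothetical encoding: one-hot of a followed by one-hot of b (length 2p)."""
    x = [0.0] * (2 * p)
    x[a] = 1.0
    x[p + b] = 1.0
    return x


def forward(U, W, x):
    """Two-layer quadratic net: logit_c = sum_j W[c][j] * (u_j . x) ** 2."""
    hidden = [sum(ui * xi for ui, xi in zip(u, x)) ** 2 for u in U]
    return [sum(w * h for w, h in zip(row, hidden)) for row in W]


def normalized_margin(U, W, p):
    """Minimum multiclass margin over all p^2 inputs of (a + b) mod p,
    divided by ||theta||_inf ** 3 (an assumed normalization)."""
    worst = math.inf
    for a, b in itertools.product(range(p), repeat=2):
        logits = forward(U, W, encode(a, b, p))
        y = (a + b) % p
        runner_up = max(logits[c] for c in range(p) if c != y)
        worst = min(worst, logits[y] - runner_up)
    # Largest absolute entry across both weight matrices.
    theta_inf = max(abs(v) for row in U + W for v in row)
    return worst / theta_inf ** 3


def lookup_network(p):
    """Memorizing construction: one hidden neuron per input pair (a, b)."""
    pairs = list(itertools.product(range(p), repeat=2))
    U = [encode(a, b, p) for a, b in pairs]
    W = [[1.0 if (a + b) % p == c else 0.0 for a, b in pairs] for c in range(p)]
    return U, W


p = 5
U, W = lookup_network(p)
print(normalized_margin(U, W, p))  # prints 2.0

# Degree-3 homogeneity: scaling every parameter by 3 leaves the value unchanged.
rng = random.Random(0)
m = 8
U_rand = [[rng.gauss(0.0, 1.0) for _ in range(2 * p)] for _ in range(m)]
W_rand = [[rng.gauss(0.0, 1.0) for _ in range(m)] for _ in range(p)]
base = normalized_margin(U_rand, W_rand, p)
scaled = normalized_margin([[3 * v for v in r] for r in U_rand],
                           [[3 * v for v in r] for r in W_rand], p)
print(math.isclose(base, scaled, rel_tol=1e-9, abs_tol=1e-12))  # prints True
```

The `lookup_network` construction memorizes the whole table with $p^2$ neurons: the correct class receives logit 4 and every other class receives 2, so the normalized margin is 2. It only sanity-checks the margin computation and is not the generalizing solution the abstract describes. Cubing the denominator is what makes the quantity invariant to rescaling all parameters, as the random-weight check illustrates.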
