Learning to Shard: RL for Co-optimizing the Parallelism Degrees and Per-operator Sharding Dimensions in Distributed LLM Inference
Ruokai Yin · Sattwik Mishra · Xuan Zuo · Hokchhay Tann · Apala Guha
Abstract
Distributed LLM inference requires careful coordination of parallelization strategies across hundreds to thousands of NPUs to meet production SLOs. Current systems like Megatron-LM rely on static heuristics that separately configure parallelism degrees and per-operator sharding dimensions, leaving significant performance on the table as models scale and hardware topologies diversify. We introduce Learn to Shard, to our knowledge, the first RL-based approach to co-optimize both coarse-grained parallelism degrees and fine-grained per-operator sharding dimensions for distributed LLM inference. Our method employs an attention-based policy over an elite history that learns from high-performing strategies to efficiently navigate the vast combinatorial search space. Evaluated on H100 clusters with MoE models up to 1.6T parameters, Learn to Shard achieves up to 3.5$\times$ throughput improvement over metaheuristic baselines and 1.06$\times$ over Megatron heuristics.
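The abstract describes an attention-based policy that conditions on an elite history of high-performing parallelization strategies. The paper page does not include code, so the following is only a minimal sketch of one way such a policy could be structured, assuming the state encodes a partial parallelization plan and the actions index discrete parallelism-degree or per-operator sharding choices; all class names, dimensions, and the action factorization are illustrative assumptions, not the authors' implementation.

```python
# Hypothetical sketch (not the authors' code): a policy that attends over an
# "elite history" of high-throughput strategies to score the next sharding choice.
import torch
import torch.nn as nn


class EliteAttentionPolicy(nn.Module):
    """Scores the next parallelism/sharding decision by attending over elite strategies."""

    def __init__(self, state_dim: int, action_dim: int, embed_dim: int = 128, num_heads: int = 4):
        super().__init__()
        self.state_proj = nn.Linear(state_dim, embed_dim)   # current partial strategy
        self.elite_proj = nn.Linear(state_dim, embed_dim)   # encoded elite strategies
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.head = nn.Linear(embed_dim, action_dim)        # logits over candidate choices

    def forward(self, state: torch.Tensor, elites: torch.Tensor) -> torch.Tensor:
        # state:  (batch, state_dim)     -- features of the partial parallelization plan
        # elites: (batch, k, state_dim)  -- k high-performing strategies from the elite buffer
        q = self.state_proj(state).unsqueeze(1)             # (batch, 1, embed_dim)
        kv = self.elite_proj(elites)                         # (batch, k, embed_dim)
        ctx, _ = self.attn(q, kv, kv)                        # attend over the elite history
        return self.head(ctx.squeeze(1))                     # (batch, action_dim) logits


if __name__ == "__main__":
    policy = EliteAttentionPolicy(state_dim=32, action_dim=8)
    state = torch.randn(2, 32)          # two partial plans
    elites = torch.randn(2, 16, 32)     # 16 elite strategies each
    probs = torch.softmax(policy(state, elites), dim=-1)
    action = torch.multinomial(probs, 1)  # sample the next sharding decision
    print(action.shape)                   # torch.Size([2, 1])
```

Under these assumptions, sampled actions would be assembled into a full strategy, evaluated for throughput, and the best-performing strategies fed back into the elite buffer that the attention module conditions on.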