Workshop: Distribution shifts: connecting methods and applications (DistShift)

Correct-N-Contrast: A Contrastive Approach for Improving Robustness to Spurious Correlations

Michael Zhang · Nimit Sohoni · Hongyang Zhang · Chelsea Finn · Christopher Ré

Abstract: We propose Correct-N-Contrast (CNC), a contrastive learning method to improve robustness to spurious correlations when training group labels are unknown. Our motivating observation is that worst-group performance is related to a representation alignment loss, which measures the distance in feature space between different groups within each class. We prove that the gap between worst-group and average loss for each class is upper bounded by this alignment loss for that class. Thus, CNC aims to improve representation alignment via contrastive learning. First, CNC uses an ERM model to infer the group information. Second, with a careful sampling scheme, CNC trains a contrastive model to encourage similar representations for groups in the same class. We show that CNC significantly improves worst-group accuracy over existing state-of-the-art methods on popular benchmarks, e.g., achieving $7.7\%$ absolute lift in worst-group accuracy on the CelebA dataset, and performs almost as well as methods trained with group labels. CNC also learns better-aligned representations between different groups in each class, reducing the alignment loss substantially compared to prior methods.
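The two-stage recipe in the abstract — use an ERM model's predictions as proxy group labels, then sample contrastive pairs that cross inferred groups within a class — can be illustrated with a minimal sketch. The function name and toy data below are my own, not from the paper; the pairing rule (positives share the anchor's class but differ in ERM prediction, negatives differ in class but share the ERM prediction) follows the sampling idea the abstract describes.

```python
import numpy as np

# Hypothetical sketch of CNC-style stage-2 pair sampling.
# Stage 1 (assumed done): an ERM model is trained on the data and its
# predictions serve as inferred group labels, since points relying on the
# spurious feature tend to be predicted "correctly" while minority-group
# points tend to be misclassified.

def sample_cnc_pairs(labels, erm_preds, anchor_idx):
    """For an anchor, return indices of positives (same class, different
    ERM prediction -> likely a different group) and negatives (different
    class, same ERM prediction -> likely sharing the spurious feature)."""
    y_a = labels[anchor_idx]
    p_a = erm_preds[anchor_idx]
    positives = np.where((labels == y_a) & (erm_preds != p_a))[0]
    negatives = np.where((labels != y_a) & (erm_preds == p_a))[0]
    return positives, negatives

# Toy example: two classes; the ERM model misclassifies indices 2 and 5,
# marking them as likely minority-group points.
labels    = np.array([0, 0, 0, 1, 1, 1])
erm_preds = np.array([0, 0, 1, 1, 1, 0])
pos, neg = sample_cnc_pairs(labels, erm_preds, anchor_idx=2)
print(pos.tolist(), neg.tolist())  # -> [0, 1] [3, 4]
```

Pulling such positives together (and pushing such negatives apart) in a contrastive loss is what drives down the representation alignment loss between groups within each class.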