Paper title
AGRO: Adversarial Discovery of Error-prone groups for Robust Optimization
Paper authors
Abstract
Models trained via empirical risk minimization (ERM) are known to rely on spurious correlations between labels and task-independent input features, which leads to poor generalization under distribution shift. Group distributionally robust optimization (G-DRO) can alleviate this problem by minimizing the worst-case loss over a set of pre-defined groups of training examples. G-DRO successfully improves performance on the worst group, where the spurious correlation does not hold. However, G-DRO assumes that the spurious correlations and the associated worst groups are known in advance, which makes it difficult to apply to new tasks that may contain multiple unknown spurious correlations. We propose AGRO (Adversarial Group discovery for Distributionally Robust Optimization), an end-to-end approach that jointly identifies error-prone groups and improves accuracy on them. AGRO equips G-DRO with an adversarial slicing model that finds a group assignment for training examples which maximizes the worst-case loss over the discovered groups. On the WILDS benchmark, AGRO achieves, on average, 8% higher model performance on known worst groups than prior group discovery approaches used with G-DRO. AGRO also improves out-of-distribution performance on SST2, QQP, and MS-COCO, datasets whose potential spurious correlations have not yet been characterized. Human evaluation of AGRO groups shows that they contain well-defined, yet previously unstudied, spurious correlations that lead to model errors.
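The worst-case group loss at the center of the abstract can be made concrete with a small sketch. G-DRO, as described above, trains the model to minimize the largest average loss across groups, and AGRO's slicing model acts as the adversary that picks the group assignment maximizing that same quantity. The Python below is only an illustrative sketch, not the paper's implementation: the function names `group_losses` and `worst_group_loss` are hypothetical, group membership is assumed to be a per-example vector of soft group probabilities (with one-hot rows this reduces to fixed, pre-defined groups), and neither model training nor the slicing network is included.

```python
from typing import List, Sequence


def group_losses(losses: Sequence[float],
                 assign: Sequence[Sequence[float]]) -> List[float]:
    """Average per-example loss inside each group.

    assign[i][g] is an (assumed) soft probability that example i belongs
    to group g; one-hot rows correspond to fixed, pre-defined groups.
    """
    n_groups = len(assign[0])
    totals = [0.0] * n_groups   # assignment-weighted sum of losses per group
    weights = [0.0] * n_groups  # total assignment mass per group
    for loss, probs in zip(losses, assign):
        for g, p in enumerate(probs):
            totals[g] += p * loss
            weights[g] += p
    # Empty groups contribute a loss of 0 rather than dividing by zero.
    return [t / w if w > 0 else 0.0 for t, w in zip(totals, weights)]


def worst_group_loss(losses: Sequence[float],
                     assign: Sequence[Sequence[float]]) -> float:
    """Max over groups of the average loss.

    The model is trained to minimize this value; an adversarial slicer
    (the role AGRO's slicing model plays) would instead choose `assign`
    to maximize it.
    """
    return max(group_losses(losses, assign))


if __name__ == "__main__":
    losses = [0.5, 1.0, 2.0, 3.0]
    # Grouping that isolates the two highest-loss examples.
    isolating = [[1, 0], [1, 0], [0, 1], [0, 1]]
    # Grouping that mixes high- and low-loss examples.
    mixed = [[1, 0], [0, 1], [1, 0], [0, 1]]
    print(group_losses(losses, isolating))      # [0.75, 2.5]
    print(worst_group_loss(losses, isolating))  # 2.5
    print(worst_group_loss(losses, mixed))      # 2.0
```

With the same per-example losses, the grouping that isolates the two highest-loss examples gives a worst-group loss of 2.5, compared with 2.0 for a grouping that mixes them. That is the kind of assignment an adversarial slicer would favor, and the model would then be updated to reduce the loss on that group.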