论文标题
卷积神经网络中自我注意的彩票票证假设
The Lottery Ticket Hypothesis for Self-attention in Convolutional Neural Network
论文作者
论文摘要
储层计算是预测湍流的有力工具,其简单的架构具有处理大型系统的计算效率。然而,其实现通常需要完整的状态向量测量和系统非线性知识。我们使用非线性投影函数将系统测量扩展到高维空间,然后将其输入到储层中以获得预测。我们展示了这种储层计算网络在时空混沌系统上的应用,该系统模拟了湍流的若干特征。我们表明,使用径向基函数作为非线性投影器,即使只有部分观测并且不知道控制方程,也能稳健地捕捉复杂的系统非线性。最后,我们表明,当测量稀疏、不完整且带有噪声,甚至控制方程变得不准确时,我们的网络仍然可以产生相当准确的预测,从而为实际湍流系统的无模型预测铺平了道路。
Recently many plug-and-play self-attention modules (SAMs) are proposed to enhance the model generalization by exploiting the internal information of deep convolutional neural networks (CNNs). In general, previous works ignore where to plug in the SAMs since they connect the SAMs individually with each block of the entire CNN backbone for granted, leading to incremental computational cost and the number of parameters with the growth of network depth. However, we empirically find and verify some counterintuitive phenomena that: (a) Connecting the SAMs to all the blocks may not always bring the largest performance boost, and connecting to partial blocks would be even better; (b) Adding the SAMs to a CNN may not always bring a performance boost, and instead it may even harm the performance of the original CNN backbone. Therefore, we articulate and demonstrate the Lottery Ticket Hypothesis for Self-attention Networks: a full self-attention network contains a subnetwork with sparse self-attention connections that can (1) accelerate inference, (2) reduce extra parameter increment, and (3) maintain accuracy. In addition to the empirical evidence, this hypothesis is also supported by our theoretical evidence. Furthermore, we propose a simple yet effective reinforcement-learning-based method to search the ticket, i.e., the connection scheme that satisfies the three above-mentioned conditions. Extensive experiments on widely-used benchmark datasets and popular self-attention networks show the effectiveness of our method. Besides, our experiments illustrate that our searched ticket has the capacity of transferring to some vision tasks, e.g., crowd counting and segmentation.