MT-UNet论文学习
摘要
指出U-Net在医学图像分割任务有巨大成功,但模型缺乏远程依赖能力
transformer 固有的self-attention模块能够拥有捕获远程相关性的能力,但transformer 通常依赖于大规模的预训练,并且有很高的计算复杂度,并且,self-attention 只能对单个样本的自注意力建模,忽略了整个数据集的潜在相关性
提出一个新的transformer模块MTM,用于同时进行样本内部和样本之间注意力学习;MTM首先通过精心设计的局部-全局高斯加自注意(LGG-SA)有效地计算注意力,然后通过外部注意EA挖局样本之间的相互联系
在MTM基础上,构建一个U型混合tansformer的MT-UNet模型,用于医学图像的精准分割
介绍
U-Net面临所有CNN都面临的问题:缺乏建模远程相关性的能力,主要是因为卷积运算的固有局部性。许多研究尝试用Transformer来解决这个问题。self-attention 是transformer 的关键组成部分,它可以对输入的tokens之间互相进行计算,所以transformer 可以处理远程依赖关系。虽然有一些工作取得了令人满意的结果,但transformer 通常依赖于大规模的预训练,对于给使用它带来了不便。并且,SA的计算复杂度为二次的,会降低图像的处理速度。同时,SA有忽略样本之间相关性的局限性,这有很大的提升空间。
为了获取更好的局部感知和更低的计算成本,作者重新设计了SA,再与外部注意EA集成,同时管理样本内和样本间的相关性。
作者提出在细粒度的局部上下文中执行局部SA,而全局SA仅在粗粒度的全局上下文中执行,因为在大多数视觉任务中,近区域之间的视觉依赖性比远区域之间的视觉依赖性强。局部上下文中的信息通常更加相关,因此需要更细致地处理。而全局上下文中的信息则相对不那么相关,因此可以使用粗粒度来处理。这样可以提高计算效率,同时保留重要的信息。例如,在一张图片中,图像中心的物体与周围物体之间的关系通常比与边缘物体之间的关系更为重要。(细粒度和粗粒度是指处理信息时所采用的精细程度。细粒度意味着对信息进行更细致、更详细的处理,而粗粒度则意味着对信息进行更粗略、更概括的处理。例如,在图像处理中,细粒度可能意味着对每个像素进行处理,而粗粒度则可能意味着对图像的整个区域进行处理。)
计算全局注意图时,作者使用轴向注意来减少计算量,并进一步引入可学习的高斯矩阵来增强附近标记的权重。
transformer 需要大规模预训练的主要原因在于它对问题的结构没有先验知识。所以作者在设计MT-UNet 时,使用convolution stem 作为浅层的特征提取器,为分割任务设置结构先验。
总结:设计MTM用于同时进行样本间和样本内部亲和力的学习;提出LGG-SA,在细粒部的局部上下文和粗粒度的局部上下文依次执行SA,还引入可学习的高斯矩阵来强调每个查询周围的附近区域;构建了一个用于医学图像分割的混合Transformer UNet 。
方法
整体结构设计
网络基于编码器-解码器结构,在解码时使用跳跃连接来保持低级特征。MTM仅用在空间尺寸较小的深层,用于减少计算成本,上层仍然使用经典的卷积运算(希望关注初始层上的局部关系,它们包含了更多高分辨率的细节,);在模型之前引入一些卷积操作,对于相对较小的医学图像数据集很有帮助; 对于所有的transformer模块,都遵循一个2步卷积/反卷积内核来实现下采样/上采样以及通道扩展/压缩。

Mixed Transformer Module
MTM由LGG-SA和EA组成,LGG-SA用于模拟不同粒度的短期和长期依赖关系,EA用于挖掘样本间相关性。该模块具有更好的视觉任务性能和更低的时间复杂度,可以取代原有的Transformer编码器。

Local-Global Gaussion-Weighted Self-Attention
LGG-SA 完美体现了集中计算的思想。与传统的SA对于所有的token给与同等的关注不同,LGG-SA 通过使用Local-Global 策略和高斯掩码,可以更多地关注附近区域。LGG-SA可以提高模型性能和节省计算资源。

Local-Global Self-Attention
SA的目的在于捕获输入序列中所有实体之间的相互联系。在计算机视觉中,近区域之间的相关性往往比远区域更重要,在计算注意力时,不需要为更远的区域付出同样的代价。提出Local-Global Self-Attention,LocalSA计算每个window内的注意力,然后将每个window 内的token 聚合为一个全局token,表示窗口的主要信息(对于聚合函数,作者尝试了stride convolution、Max Pooling等方法,其中轻量级动态卷积(Lightweight Dynamic convolution, LDConv)表现最好。在对整个特征映射进行下采样后,可以以更少的开销执行全局SA。)输入的特征映射为H*W *C,将窗口大小设置为P,整个过程如下图:

Gaussian-Weighted Aixal Attention
与使用原始SA的LSA不同,对于GSA,使用了GWAA,G-WAA 通过一个可学习的高斯矩阵来增强每个查询对于附件标记的感知,同时使用的轴向注意具有较低的时间复杂度。
External Attention
提出外部注意来解决SA无法利用不同样本之间的关系的问题,与self-attention 使用每个样本自己的线性变换来计算注意力分数不同,EA中,所有的样本共享两个记忆单元MK 和MV,用来描述整个数据集的最基本的信息。 另外,对Q使用额外的线性映射来扩大通道,提高该模块的表示学习能力。
由于EA的时间复杂度为O(n),因此MT-UNet的总体时间复杂度保持为O(n√n)。