TranUNet学习
摘要
由于卷积运算的固有局部性,U-Net 通常在显示建模长期依赖关系方面表现出局限性(卷积运算的固有局部性:卷积运算只能捕捉到输入数据中相邻区域之间的关系,这个区域大小取决于卷积核大小,这种局部性使得卷积神经网络能够很好地处理图像等具有局部相关性的数据,但也限制了它在处理远程依赖关系方面的能力。比如,一个像素的标签可能和它周围很远的像素有关,但卷积神经网络可能难以捕捉到这种远程依赖关系;虽然U-Net通过层次的堆叠来扩大感受野,但U-Net 仍然可能在长期依赖关系方面表现出局部性)
Transformer具有全局的自注意力机制,但由于缺乏低层次细节,可能导致定位能力有限(低层次细节捕捉:可以通过卷积神经网络的下采样来扩大感受野,从而来捕捉到图像中的纹理、边缘和颜色变化等信息;定位能力有限:模型不能捕捉到图像的边缘信息,它很难准确地划分不同的物体)
提出了TransUNet,它兼有transformer和U-Net的优点,作为医学图像分割的强大替代方案。Transformer对来自卷积神经网络(CNN)特征映射的标记化图像补丁进行编码,作为提取全局上下文的输入序列;解码器对编码特征进行采样,然后将其与高分辨率CNN特征图相结合,以实现精确的定位。
Transformer可以作为医学图像分割任务的强编码器,结合U-Net通过恢复局部空间信息来增强更精细的细节。
介绍
经过transformer得到特征直接进行上采样得到的结果令人不满意(原因:transformer 将输入的图像作为一维的序列,所有的阶段都在专注于获取全局上下文信息,缺乏对于详细定位信息的低分辨率特征,直接上采样到全分辨率不能有效恢复这些信息,导致分割结果粗糙)
提出了医学图像分割框架TransUNet,从序列到序列预测的角度建立了自关注机制。为了弥补transformer带来的特征分辨率损失,TransUNet 采用CNN-transformer混合架构,结合CNN特征的详细高分辨率空间信息和Transformer的全局上下文。
受U-Net结构启发,将transformer得到的self-attention 特征上采样后与从编码路径得到的不同高分辨率CNN特征相结合,从而实现精准定位。与基于CNN的self-attention方法相比,基于transformer的架构可以更好地利用self-attention。另外,更密集地结合低级特征通常会导致更好的分割精度。
相关工作
将CNN与自注意机制结合
TransUNet是第一个基于transform的医学图像分割框架,它建立在非常成功的ViT之上
方法
网络结构
代码学习
TransNet 网络设置
def get_transNet(n_classes): |
Resnet50 + ViT-B/16 模型的配置
def get_r50_b16_config(): |
ViT-B/16模型配置
def get_b16_config(): |
VisionTransformer
forward
方法接受一个输入张量 x
,并返回分割结果。如果输入张量的通道数为 1,则将其重复三次以模拟 RGB 图像。然后,将输入张量传递给 Transformer 编码器,得到编码结果、注意力权重和特征图。接着,将编码结果和特征图传递给解码器,得到解码结果。最后,将解码结果传递给分割头,得到分割结果。
load_from
方法接受一个权重字典,并使用这些权重来初始化模型的权重。方法首先复制权重字典中的权重到模型的相应部分。然后,处理位置嵌入的权重。如果位置嵌入的大小与模型中位置嵌入的大小相同,则直接复制权重;否则,根据情况调整位置嵌入的大小并复制权重。接着,初始化编码器的权重。最后,如果模型包含混合模型,则初始化混合模型的权重。
class VisionTransformer(nn.Module): |
Transformer模块
class Transformer(nn.Module): |