摘要

  1. 由于卷积运算的固有局部性,U-Net 通常在显示建模长期依赖关系方面表现出局限性(卷积运算的固有局部性:卷积运算只能捕捉到输入数据中相邻区域之间的关系,这个区域大小取决于卷积核大小,这种局部性使得卷积神经网络能够很好地处理图像等具有局部相关性的数据,但也限制了它在处理远程依赖关系方面的能力。比如,一个像素的标签可能和它周围很远的像素有关,但卷积神经网络可能难以捕捉到这种远程依赖关系;虽然U-Net通过层次的堆叠来扩大感受野,但U-Net 仍然可能在长期依赖关系方面表现出局部性)

  2. Transformer具有全局的自注意力机制,但由于缺乏低层次细节,可能导致定位能力有限(低层次细节捕捉:可以通过卷积神经网络的下采样来扩大感受野,从而来捕捉到图像中的纹理、边缘和颜色变化等信息;定位能力有限:模型不能捕捉到图像的边缘信息,它很难准确地划分不同的物体)

  3. 提出了TransUNet,它兼有transformer和U-Net的优点,作为医学图像分割的强大替代方案。Transformer对来自卷积神经网络(CNN)特征映射的标记化图像补丁进行编码,作为提取全局上下文的输入序列;解码器对编码特征进行采样,然后将其与高分辨率CNN特征图相结合,以实现精确的定位

  4. Transformer可以作为医学图像分割任务的强编码器,结合U-Net通过恢复局部空间信息来增强更精细的细节

介绍

  1. 经过transformer得到特征直接进行上采样得到的结果令人不满意(原因:transformer 将输入的图像作为一维的序列,所有的阶段都在专注于获取全局上下文信息,缺乏对于详细定位信息的低分辨率特征,直接上采样到全分辨率不能有效恢复这些信息,导致分割结果粗糙)

  2. 提出了医学图像分割框架TransUNet,从序列到序列预测的角度建立了自关注机制。为了弥补transformer带来的特征分辨率损失,TransUNet 采用CNN-transformer混合架构,结合CNN特征的详细高分辨率空间信息和Transformer的全局上下文。

  3. 受U-Net结构启发,将transformer得到的self-attention 特征上采样后与从编码路径得到的不同高分辨率CNN特征相结合,从而实现精准定位。与基于CNN的self-attention方法相比,基于transformer的架构可以更好地利用self-attention。另外,更密集地结合低级特征通常会导致更好的分割精度。

相关工作

  1. 将CNN与自注意机制结合

  2. TransUNet是第一个基于transform的医学图像分割框架,它建立在非常成功的ViT之上

方法

网络结构

代码学习

TransNet 网络设置

def get_transNet(n_classes):
img_size = 512 #设置图像大小
vit_patches_size = 16 #设置patch 个数
vit_name = 'R50-ViT-B_16'

config_vit = CONFIGS_ViT_seg[vit_name] #获取配置
config_vit.n_classes = n_classes
config_vit.n_skip = 3
if vit_name.find('R50') != -1: #如果名称中包含R50
#重新设置patch大小
config_vit.patches.grid = (int(img_size / vit_patches_size), int(img_size / vit_patches_size))
#创建ViT_seg网络
net = ViT_seg(config_vit, img_size=img_size, num_classes=n_classes)
return net


Resnet50 + ViT-B/16 模型的配置

def get_r50_b16_config():
"""Returns the Resnet50 + ViT-B/16 configuration."""
config = get_b16_config()#获取Vit模型的基本配置
config.patches.grid = (16, 16)#设置patch大小
config.resnet = ml_collections.ConfigDict()#创建一个configdict
##来存储resnet
config.resnet.num_layers = (3, 4, 9)#设置resnet模型层数
config.resnet.width_factor = 1 #设置resnent的宽度因子

config.classifier = 'seg' #分割
#设置预训练模型路径
config.pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz'
config.decoder_channels = (256, 128, 64, 16)#设置解码器通道
config.skip_channels = [512, 256, 64, 16]#设置skip通道
config.n_classes = 2 #设置类别数
config.n_skip = 3 #设置skip 数目
config.activation = 'softmax' #设置激活函数

return config #返回配置结果


ViT-B/16模型配置

def get_b16_config():
"""Returns the ViT-B/16 configuration."""
config = ml_collections.ConfigDict() #创建一个CondigDict()
#对象来存储信息
#设置patch大小
config.patches = ml_collections.ConfigDict({'size': (16, 16)})
#设置隐藏层大小
config.hidden_size = 768
#创建一个新的CondigDict 对象来存储Transformer的设置
config.transformer = ml_collections.ConfigDict()
#设置Transformer 的MLP维度
config.transformer.mlp_dim = 3072
#设置头数
config.transformer.num_heads = 12
#设置层数
config.transformer.num_layers = 12
#设置dropout率
config.transformer.attention_dropout_rate = 0.0
config.transformer.dropout_rate = 0.1

config.classifier = 'seg'
config.representation_size = None
config.resnet_pretrained_path = None
config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_16.npz'
config.patch_size = 16

config.decoder_channels = (256, 128, 64, 16)
config.n_classes = 2
config.activation = 'softmax'
return config

VisionTransformer

forward 方法接受一个输入张量 x,并返回分割结果。如果输入张量的通道数为 1,则将其重复三次以模拟 RGB 图像。然后,将输入张量传递给 Transformer 编码器,得到编码结果、注意力权重和特征图。接着,将编码结果和特征图传递给解码器,得到解码结果。最后,将解码结果传递给分割头,得到分割结果。

load_from 方法接受一个权重字典,并使用这些权重来初始化模型的权重。方法首先复制权重字典中的权重到模型的相应部分。然后,处理位置嵌入的权重。如果位置嵌入的大小与模型中位置嵌入的大小相同,则直接复制权重;否则,根据情况调整位置嵌入的大小并复制权重。接着,初始化编码器的权重。最后,如果模型包含混合模型,则初始化混合模型的权重。

class VisionTransformer(nn.Module):
def __init__(self, config, img_size=512, num_classes=2, zero_head=False, vis=False):
super(VisionTransformer, self).__init__()
self.num_classes = num_classes
self.zero_head = zero_head #设置分类器权重初始化是否为0
self.classifier = config.classifier
self.transformer = Transformer(config, img_size, vis)
self.decoder = DecoderCup(config)
#print("inchannel:",config['decoder_channels'][-1])
self.segmentation_head = SegmentationHead(
in_channels=config['decoder_channels'][-1],
out_channels=config['n_classes'],
kernel_size=3,
)
self.config = config

def forward(self, x):
if x.size()[1] == 1:#输入张量通道数为1
x = x.repeat(1,3,1,1) #将其重复三次
#将张量传递给Transformer编码器,得到编码结果,注意力权重和特征图
x, attn_weights, features = self.transformer(x) # (B, n_patch, hidden)
x = self.decoder(x, features)#将编码结果和特征图传给
#解码器,得到解码结果
#将解码结果传递给分割头,得到分割结果
logits = self.segmentation_head(x)
return logits

def load_from(self, weights):
with torch.no_grad():

res_weight = weights
self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))

self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))

posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])

posemb_new = self.transformer.embeddings.position_embeddings
if posemb.size() == posemb_new.size():
self.transformer.embeddings.position_embeddings.copy_(posemb)
elif posemb.size()[1]-1 == posemb_new.size()[1]:
posemb = posemb[:, 1:]
self.transformer.embeddings.position_embeddings.copy_(posemb)
else:
logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
ntok_new = posemb_new.size(1)
if self.classifier == "seg":
_, posemb_grid = posemb[:, :1], posemb[0, 1:]
gs_old = int(np.sqrt(len(posemb_grid)))
gs_new = int(np.sqrt(ntok_new))
print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
zoom = (gs_new / gs_old, gs_new / gs_old, 1)
posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1) # th2np
posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
posemb = posemb_grid
self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))

# Encoder whole
for bname, block in self.transformer.encoder.named_children():
for uname, unit in block.named_children():
unit.load_from(weights, n_block=uname)

if self.transformer.embeddings.hybrid:
self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(res_weight["conv_root/kernel"], conv=True))
gn_weight = np2th(res_weight["gn_root/scale"]).view(-1)
gn_bias = np2th(res_weight["gn_root/bias"]).view(-1)
self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)

for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():
for uname, unit in block.named_children():
unit.load_from(res_weight, n_block=bname, n_unit=uname)

Transformer模块

class Transformer(nn.Module):
def __init__(self, config, img_size, vis):
super(Transformer, self).__init__()
#创建一个嵌入层
self.embeddings = Embeddings(config, img_size=img_size)
#创建一个编码层
self.encoder = Encoder(config, vis)

def forward(self, input_ids):
#将输入张量传递给嵌入层,得到嵌入输出和特征图
embedding_output, features = self.embeddings(input_ids)
#将嵌入层输出传递给编码器,得到编码结果和注意力权重
encoded, attn_weights = self.encoder(embedding_output) # (B, n_patch, hidden)
#返回编码结果、注意力权重和特征图
return encoded, attn_weights, features