The idea behind Vision Transformer
Split the image into patches and feed the sequence of linear patch embeddings into a Transformer. Self-attention is global: every token attends to all other tokens.
Hierarchical design
A hierarchical design is adopted: as the feature layers get deeper, the height and width of the feature maps gradually shrink.
Shifted windows
Solve the drawback that windows cannot exchange information with each other.
Reference blog post and code: Swin-Transformer网络结构详解_swin transformer_太阳花的小绿豆的博客-CSDN博客
Attention in the human brain
Attention is a selection mechanism for allocating limited information-processing capacity. Take human vision as an example: the eyes first scan the whole scene, then locate a region of interest, focus on it, and examine it carefully to gather information. The brain does not attend to the whole scene uniformly; regions of interest are assigned more weight.
How self-attention is implemented
Self-attention uses weight matrices to discover, on its own, the relationships between elements.
The attention formula and its meaning

Attention(Q, K, V) = softmax(QK^T / √d_k) · V
Q: the query vector, used to query the other elements
K: the key, used to be queried by other elements
V: the value (content)
Every element has its own Q, K and V. An element multiplies its own Q with the K of other elements to obtain its similarity to them. This similarity can be read as an attention score: the higher the score, the more closely the two elements are related.
In the formula, Q, K, V are matrices: the stacked Q, K, V vectors of all elements. The product of Q with the transpose of K (the more similar Q and K are, the larger the dot product) gives the pairwise relationships between elements; after the softmax these become relevance scores mapped into (0, 1), and multiplying by V aggregates the related elements. Why divide by a scaling factor? d_k is the dimension of the keys: if d_k is large, the dot products also become large, and by the nature of the softmax function the result gets pushed into a region with extremely small gradients. The scaling factor √d_k counteracts this effect and keeps backpropagation well behaved.
After applying this formula, the result highlights the places that deserve attention (i.e. the positions with higher relevance).
Where Q, K, V come from
Q, K, V are obtained by multiplying the previous layer's output by their respective weight matrices. The weight matrices Wq, Wk, Wv are learnable, randomly initialised matrices that are continuously updated through backpropagation.
Multi-head attention
Analogous to using multiple convolution kernels in a CNN at the same time, several sets of Q, K, V weight matrices are used, forming several subspaces. This lets the model attend to different kinds of information, which is combined at the end, helping the network capture richer features.
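As a minimal sketch of the above (illustrative code only, not the Swin implementation; the class name SimpleSelfAttention and the dimensions are made up), the following computes scaled dot-product attention with learnable Wq, Wk, Wv projections and a simple multi-head split:

import torch
import torch.nn as nn

class SimpleSelfAttention(nn.Module):
    def __init__(self, embed_dim=96, num_heads=3):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Wq, Wk, Wv: learnable, randomly initialised, updated by backpropagation
        self.w_q = nn.Linear(embed_dim, embed_dim)
        self.w_k = nn.Linear(embed_dim, embed_dim)
        self.w_v = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):                      # x: [B, N, C]
        B, N, C = x.shape
        # project and split into heads: [B, num_heads, N, head_dim]
        q = self.w_q(x).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.w_k(x).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.w_v(x).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        # attention scores: QK^T / sqrt(d_k), then softmax -> weights in (0, 1)
        attn = (q @ k.transpose(-2, -1)) / self.head_dim ** 0.5
        attn = attn.softmax(dim=-1)
        # weighted sum of V, then merge the heads back together
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return out

x = torch.rand(1, 49, 96)                      # 49 tokens of dimension 96
print(SimpleSelfAttention()(x).shape)          # torch.Size([1, 49, 96])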
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
import numpy as np
import torch.utils.checkpoint as checkpoint
Patch Partition & Embedding
Patch Partition: the image is fed into the Patch Partition module and split into patches, with every 4×4 pixel block forming one patch. For an RGB three-channel input, the shape changes from [H, W, 3] to [H/4, W/4, 48].
Linear Embedding: a linear transformation is applied along the channel dimension of each token, so the shape becomes [H/4, W/4, C].
In the source code, Patch Partition and Linear Embedding are implemented together by a single convolution layer (kernel size and stride both equal to the patch size):
class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    """
    def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = (patch_size, patch_size)
        self.patch_size = patch_size
        self.in_chans = in_c
        self.embed_dim = embed_dim
        # Patch Partition + Linear Embedding in one conv: kernel_size = stride = patch_size
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _, _, H, W = x.shape

        # pad so that H and W are divisible by patch_size
        pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
        if pad_input:
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],
                          0, self.patch_size[0] - H % self.patch_size[0],
                          0, 0))

        x = self.proj(x)
        _, _, H, W = x.shape
        # [B, C, H, W] -> [B, HW, C]
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W
x = torch.rand(1, 3, 224, 224)   # assumed test input; it matches the printed output shape below
model = PatchEmbed()
y, H, W = model(x)
print(y.shape)
torch.Size([1, 3136, 96])
The role of the Layer Normalization module
Internal Covariate Shift: gradient descent keeps changing every layer's parameters, which in turn changes the distribution of every layer's linear and non-linear outputs. Later layers then have to keep adapting to this shifting distribution, which slows down learning for the whole network.
Fixing the distribution of each layer's inputs can therefore alleviate this problem.
If we think of x ∈ N×C×H×W as a stack of books, the stack contains N books, each book has C pages, each page has H lines, and each line has W numbers. When LN computes the mean, it adds up all the numbers on all pages of one book and divides by that book's total number count, C×H×W, i.e. it produces one mean per book (and likewise one standard deviation per book). In short: "per book". Reference: 深度学习笔记-11.Normalization[规范化]方法总结_业余狙击手19的博客-CSDN博客
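A small illustration (made-up tensor) of how LayerNorm is used later in this code: nn.LayerNorm(C) on a [B, N, C] token sequence normalises each token over its C channels.

import torch
import torch.nn as nn

x = torch.rand(1, 56 * 56, 96)          # [B, N, C] token sequence
ln = nn.LayerNorm(96)
y = ln(x)
# each token now has mean ~0 and std ~1 over its 96 channels (default affine parameters)
print(y.mean(dim=-1)[0, :3])            # values close to 0
print(y.std(dim=-1)[0, :3])             # values close to 1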
ViT
Swin Transformer improves on ViT: the W-MSA (window multi-head self-attention) module is introduced to reduce the amount of computation.
Computational cost of MSA and W-MSA
For global MSA on an h×w feature map with C channels: computing Q costs hw×C×C, and the same holds for K and V, giving 3hwC^2 in total. Computing Q·K^T costs (hw)^2·C; ignoring the softmax, multiplying by V costs another (hw)^2·C; and projecting the multi-head output back to C channels adds a further hwC^2. The total is therefore Ω(MSA) = 4hwC^2 + 2(hw)^2·C.
After partitioning into windows, the cost of a single M×M window is 4M^2C^2 + 2M^4C, and there are hw/M^2 windows, so the total is Ω(W-MSA) = 4hwC^2 + 2M^2hwC. Since M is fixed (7 by default), the term that was quadratic in hw becomes linear in hw.
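To get a feel for the numbers, here is a small helper of my own (not from the reference code) that plugs values into the two formulas above. For h = w = 56, C = 96, M = 7 this gives roughly 2.0e9 operations for global MSA versus roughly 1.5e8 for W-MSA.

# Hypothetical helper: evaluate the two complexity formulas above.
def omega_msa(h, w, C):
    return 4 * h * w * C ** 2 + 2 * (h * w) ** 2 * C

def omega_w_msa(h, w, C, M):
    return 4 * h * w * C ** 2 + 2 * M ** 2 * h * w * C

h = w = 56
C, M = 96, 7
print(f"MSA:   {omega_msa(h, w, C):.3e}")       # ~2.0e9
print(f"W-MSA: {omega_w_msa(h, w, C, M):.3e}")  # ~1.5e8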
SW-MSA
W-MSA only computes self-attention within each window, so no information can be exchanged between windows. SW-MSA (shifted-window MSA) is introduced to solve this.
W-MSA and SW-MSA are used in pairs (alternating blocks); the SW-MSA windows are shifted from the top-left corner by ⌊M/2⌋ pixels to the right and ⌊M/2⌋ pixels downward.
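The shift is implemented with torch.roll, exactly as in SwinTransformerBlock below. A tiny illustrative demo on a made-up 8×8 index map with M = 4 (offset ⌊M/2⌋ = 2):

import torch

M = 4
shift = M // 2
x = torch.arange(8 * 8).view(1, 8, 8, 1)                   # [B, H, W, C] index map
shifted = torch.roll(x, shifts=(-shift, -shift), dims=(1, 2))
print(x[0, :, :, 0])
print(shifted[0, :, :, 0])   # the first two rows/columns wrap around to the bottom/right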
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
            attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, attn_mask):
        H, W = self.H, self.W
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad the feature map so H and W are multiples of window_size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift for SW-MSA (no shift and no mask for W-MSA)
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
            attn_mask = None

        # partition windows and run (shifted) window attention
        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        attn_windows = self.attn(x_windows, mask=attn_mask)

        # merge windows back into a feature map
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)

        # reverse the cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        # remove padding
        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        # residual connections around attention and MLP
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
WindowAttention

class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # (Mh, Mw)
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # relative position bias table: (2*Mh-1) * (2*Mw-1) entries per head
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))

        # pre-compute the relative position index for every pair of tokens inside a window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask: Optional[torch.Tensor] = None):
        """
        Args:
            x: input features with shape of (num_windows*B, Mh*Mw, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        # qkv: [B_, N, C] -> [3, B_, num_heads, N, C // num_heads]
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        # scaled dot-product attention inside each window
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        # add the learnable relative position bias
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # mask out attention between tokens that came from different regions (SW-MSA)
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
Mask operation
After the cyclic shift, some windows contain tokens that came from regions that were not adjacent in the original feature map. The mask built in BasicLayer.create_mask below adds -100 to the attention scores between tokens from different regions, so the softmax effectively suppresses them.
Relative position encoding
Inside each window, the attention map is augmented with a learnable relative position bias taken from relative_position_bias_table and indexed by relative_position_index (see WindowAttention above).
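An illustrative sketch (same logic as the WindowAttention constructor above, just with a tiny 2×2 window so the index matrix can be printed):

import torch

Mh, Mw = 2, 2
coords = torch.stack(torch.meshgrid([torch.arange(Mh), torch.arange(Mw)], indexing="ij"))
coords_flatten = torch.flatten(coords, 1)                        # [2, Mh*Mw]
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # [Mh*Mw, Mh*Mw, 2]
relative_coords[:, :, 0] += Mh - 1                               # shift row offsets to start from 0
relative_coords[:, :, 1] += Mw - 1                               # shift col offsets to start from 0
relative_coords[:, :, 0] *= 2 * Mw - 1                           # encode (row, col) as a single number
relative_position_index = relative_coords.sum(-1)
print(relative_position_index)
# each entry indexes into a table of (2*Mh-1)*(2*Mw-1) = 9 learnable bias values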
MLP
Role: enhance the network's expressive power.
The role of the Dropout layer
Generally, when a relatively large model is used on a relatively small dataset, Dropout helps prevent overfitting and improves generalisation.
Essentially, it randomly removes some neurons (features) and their corresponding connections, applying a random perturbation to the network's feature extraction; this procedure is known as stochastic regularisation.
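A tiny illustration (made-up tensor) of this behaviour: during training each element is zeroed with probability p and the survivors are scaled by 1/(1 - p), so the expected activation is unchanged; at inference Dropout is a no-op.

import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
drop.train()                       # dropout is only active in training mode
x = torch.ones(1, 8)
print(drop(x))                     # roughly half the entries are 0, the rest are 2.0 (= 1 / (1 - p))
drop.eval()
print(drop(x))                     # identity at inference time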
class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
window_partition

def window_partition(x, window_size: int):
    """
    Split a feature map into non-overlapping windows of size window_size.
    Args:
        x: (B, H, W, C)
        window_size (int): window size (M)
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows
window_reverse

def window_reverse(windows, window_size: int, H: int, W: int):
    """
    Reassemble windows back into a feature map.
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size (M)
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
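A quick sanity check (my own example input) that the two helpers are inverses of each other when H and W are multiples of the window size:

x = torch.rand(1, 56, 56, 96)               # [B, H, W, C], H and W divisible by 7
windows = window_partition(x, window_size=7)
print(windows.shape)                        # torch.Size([64, 7, 7, 96]) -> 8*8 windows per image
x_back = window_reverse(windows, 7, 56, 56)
print(torch.equal(x, x_back))               # True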
DropPath

class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path_f(x, self.drop_prob, self.training)
def drop_path_f(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
    I've opted for changing the layer and argument names to 'drop path' rather than mix DropConnect
    as a layer name and use 'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()
    output = x.div(keep_prob) * random_tensor
    return output
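A short demo of the per-sample behaviour (made-up input): with stochastic depth, each sample in the batch is either dropped entirely (its residual branch becomes all zeros) or kept and rescaled by 1/keep_prob.

torch.manual_seed(0)                         # seed only so the demo is repeatable
x = torch.ones(4, 3)                         # 4 samples in the batch
y = drop_path_f(x, drop_prob=0.5, training=True)
print(y)                                     # each row is either all 0.0 or all 2.0 (= 1 / keep_prob)
print(drop_path_f(x, drop_prob=0.5, training=False))  # unchanged at inference time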
Patch Merging
Role: downsampling.
Suppose the input to Patch Merging is a 4×4 single-channel feature map. Patch Merging groups every 2×2 neighbourhood of pixels into a patch, then gathers the pixels that sit at the same position within their patch (the same colour in the usual illustration), producing four feature maps. These four feature maps are concatenated along the depth dimension and passed through a LayerNorm layer. Finally, a fully connected layer applies a linear transformation along the depth dimension, reducing the depth from 4C to 2C. As this example shows, after the Patch Merging layer the height and width of the feature map are halved and its depth is doubled.
class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """
        x: B, H*W, C
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # pad if H or W is odd
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        # gather the four 2x2 neighbours: [B, H/2, W/2, C] each
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = torch.cat([x0, x1, x2, x3], -1)   # [B, H/2, W/2, 4*C]
        x = x.view(B, -1, 4 * C)

        x = self.norm(x)
        x = self.reduction(x)                 # 4*C -> 2*C
        return x
x = torch.rand([1, 56 * 56, 96])
model1 = PatchMerging(96)
x = model1(x, 56, 56)   # H = W = 56, carried over from the PatchEmbed example above
print(x.shape)
torch.Size([1, 784, 192])
BasicLayer

class BasicLayer(nn.Module):
    """
    A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.window_size = window_size
        self.use_checkpoint = use_checkpoint
        self.shift_size = window_size // 2

        # alternate W-MSA (shift_size=0) and SW-MSA (shift_size=window_size//2) blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else self.shift_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer)
            for i in range(depth)])

        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def create_mask(self, x, H, W):
        # build the attention mask for SW-MSA
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        # label each region with a different id
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        # pairs of tokens with different region ids get -100 (suppressed by the softmax)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def forward(self, x, H, W):
        attn_mask = self.create_mask(x, H, W)
        for blk in self.blocks:
            blk.H, blk.W = H, W
            if not torch.jit.is_scripting() and self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)
        if self.downsample is not None:
            x = self.downsample(x, H, W)
            H, W = (H + 1) // 2, (W + 1) // 2
        return x, H, W
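A usage example for one stage, wiring together the components defined above (dimensions follow the first stage of Swin-T: 56×56 tokens, C = 96, 3 heads; the downsample argument is PatchMerging as in the full model below):

# Example: run one stage (two blocks: W-MSA then SW-MSA) followed by Patch Merging.
layer = BasicLayer(dim=96, depth=2, num_heads=3, window_size=7, downsample=PatchMerging)
x = torch.rand(1, 56 * 56, 96)
x, H, W = layer(x, 56, 56)
print(x.shape, H, W)            # torch.Size([1, 784, 192]) 28 28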
SwinTransformer

class SwinTransformer(nn.Module):
    r""" Swin Transformer
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
          https://arxiv.org/pdf/2103.14030

    Args:
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
                 window_size=7, mlp_ratio=4., qkv_bias=True,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        # number of channels output by the last stage
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        self.patch_embed = PatchEmbed(
            patch_size=patch_size, in_c=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # build the stages
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layers = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                                depth=depths[i_layer],
                                num_heads=num_heads[i_layer],
                                window_size=window_size,
                                mlp_ratio=self.mlp_ratio,
                                qkv_bias=qkv_bias,
                                drop=drop_rate,
                                attn_drop=attn_drop_rate,
                                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                                norm_layer=norm_layer,
                                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                                use_checkpoint=use_checkpoint)
            self.layers.append(layers)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        x, H, W = self.patch_embed(x)        # [B, L, C]
        x = self.pos_drop(x)

        for layer in self.layers:
            x, H, W = layer(x, H, W)

        x = self.norm(x)                     # [B, L, C]
        x = self.avgpool(x.transpose(1, 2))  # [B, C, 1]
        x = torch.flatten(x, 1)
        x = self.head(x)
        return x
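Finally, an end-to-end check with a 224×224 input and the default (Swin-T style) configuration defined above:

# End-to-end sanity check with the default configuration.
model = SwinTransformer()
x = torch.rand(1, 3, 224, 224)
y = model(x)
print(y.shape)                  # torch.Size([1, 1000])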