Abstract

We present a simple masked separable attention (MSA) for camouflaged object detection. We first separate the multi-head self-attention into three parts that are responsible for distinguishing camouflaged objects from the background using different mask strategies. Furthermore, we capture high-resolution semantic representations progressively with a simple top-down decoder built on the proposed MSA to attain precise segmentation results. These structures, together with a backbone encoder, form a new model, dubbed CamoFormer.

Introduction

Our MSA is built upon the multi-head self-attention mechanism, but unlike traditional methods that use multiple attention heads simply to enhance the feature representations, we propose to leverage different attention heads to calculate pixel correlations for different regions. Specifically, we split the self-attention heads into three groups. We first use two groups of heads to compute pixel correlations of the foreground and background regions independently. Our goal is to use the attention scores built within the predicted foreground, generated by a prediction head, to index camouflaged objects from the full-value representations, and similarly for the background. In addition, we preserve a group of normal attention heads that compute pixel correlations over the full map, which helps distinguish the camouflaged objects from a global view. Thus, the three groups of heads are complementary.
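To make this concrete, below is a minimal single-head sketch of the idea. The helper masked_attention and the tensor shapes are illustrative only, not the actual implementation; the full multi-head version is given by the Attention and MSA_module classes in the next section.

import torch
import torch.nn.functional as F

def masked_attention(q, k, v, mask=None):
    # q, k, v: (B, C, H, W); mask: (B, 1, H, W) soft foreground probability, or None for the full map
    b, c, h, w = q.shape
    if mask is not None:
        q, k = q * mask, k * mask                        # restrict correlations to the masked region
    q = F.normalize(q.flatten(2), dim=-1)                # (B, C, HW)
    k = F.normalize(k.flatten(2), dim=-1)
    attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)     # channel-wise attention map, (B, C, C)
    return (attn @ v.flatten(2)).view(b, c, h, w)

x = torch.randn(1, 128, 16, 16)                          # feature map
m = torch.rand(1, 1, 16, 16)                             # predicted foreground probability
out_f = masked_attention(x, x, x, m)                     # foreground group
out_b = masked_attention(x, x, x, 1 - m)                 # background group
out_g = masked_attention(x, x, x, None)                  # normal group over the full map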

Model Framework

MSA

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

# LayerNorm and weight_init used below are assumed to be provided by the surrounding
# CamoFormer codebase; they are referenced but not defined in this snippet.

class FeedForward(nn.Module):
    # Gated depth-wise convolutional feed-forward network.
    def __init__(self, dim, ffn_expansion_factor, bias):
        super(FeedForward, self).__init__()
        hidden_features = int(dim * ffn_expansion_factor)
        self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1, groups=hidden_features * 2, bias=bias)
        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)  # split into content and gate branches
        x = F.gelu(x1) * x2                      # gating: GELU(content) * gate
        x = self.project_out(x)
        return x
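As a quick illustrative shape check (the input tensor is arbitrary), the gated feed-forward preserves the spatial size and channel count:

ffn = FeedForward(dim=128, ffn_expansion_factor=4, bias=False)
y = ffn(torch.randn(2, 128, 32, 32))   # y.shape == torch.Size([2, 128, 32, 32])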
class Attention(nn.Module):
    # Separable (channel-wise) multi-head self-attention with an optional spatial mask
    # that restricts the queries and keys to the predicted foreground or background region.
    def __init__(self, dim, num_heads, bias, mode):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        # learnable per-head temperature for scaling the attention logits
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.qkv_0 = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        self.qkv_1 = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        self.qkv_2 = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

        self.qkv1conv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias)
        self.qkv2conv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias)
        self.qkv3conv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias)

        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x, mask=None):
        b, c, h, w = x.shape
        # point-wise + depth-wise convolutions produce the query, key and value maps
        q = self.qkv1conv(self.qkv_0(x))
        k = self.qkv2conv(self.qkv_1(x))
        v = self.qkv3conv(self.qkv_2(x))
        if mask is not None:
            # suppress pixels outside the (soft) mask so attention is computed
            # only within the selected region
            q = q * mask
            k = k * mask

        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)
        # channel-to-channel attention map of shape (b, head, c, c)
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)
        out = attn @ v
        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
        out = self.project_out(out)
        return out

    def initialize(self):
        weight_init(self)
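For illustration, the same Attention module can be run with or without a mask (the tensors below are random placeholders). The mask is broadcast over channels, so a single-channel soft map suffices:

attn = Attention(dim=128, num_heads=8, bias=False, mode='dilation')
x = torch.randn(2, 128, 32, 32)
m = torch.rand(2, 1, 32, 32)     # soft foreground mask in [0, 1]
y_fg = attn(x, m)                # correlations restricted to the predicted foreground
y_all = attn(x)                  # ordinary attention over the full map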
class MSA_head(nn.Module):
    # A pre-norm transformer block: masked separable attention followed by the gated FFN.
    def __init__(self, mode='dilation', dim=128, num_heads=8, ffn_expansion_factor=4, bias=False, LayerNorm_type='WithBias'):
        super(MSA_head, self).__init__()
        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias, mode)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x, mask=None):
        x = x + self.attn(self.norm1(x), mask)
        x = x + self.ffn(self.norm2(x))
        return x

    def initialize(self):
        weight_init(self)
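Assuming the repository's LayerNorm implementation is available, an MSA_head behaves like a standard pre-norm transformer block (illustrative call):

block = MSA_head(dim=128, num_heads=8)
x = torch.randn(1, 128, 16, 16)
m = torch.rand(1, 1, 16, 16)
y = block(x, m)   # attention + FFN, each with a residual connection; y.shape == x.shape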
class MSA_module(nn.Module):
    # Groups three MSA heads: F_TA attends within the predicted foreground,
    # B_TA within the predicted background, and TA over the full map.
    def __init__(self, dim=128):
        super(MSA_module, self).__init__()
        self.B_TA = MSA_head()
        self.F_TA = MSA_head()
        self.TA = MSA_head()
        self.Fuse = nn.Conv2d(3 * dim, dim, kernel_size=3, padding=1)
        self.Fuse2 = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=1),
            nn.Conv2d(dim, dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(dim),
            nn.ReLU(inplace=True))

    def forward(self, x, side_x, mask):
        N, C, H, W = x.shape
        # resize the previous prediction to the current scale and turn it into a soft mask
        mask = F.interpolate(mask, size=x.size()[2:], mode='bilinear')
        mask_d = mask.detach()
        mask_d = torch.sigmoid(mask_d)
        xf = self.F_TA(x, mask_d)            # foreground-masked attention
        xb = self.B_TA(x, 1 - mask_d)        # background-masked attention
        x = self.TA(x)                       # ordinary attention over the full map
        x = torch.cat((xb, xf, x), 1)
        x = x.view(N, 3 * C, H, W)
        x = self.Fuse(x)                     # fuse the three complementary branches
        D = self.Fuse2(side_x + side_x * x)  # modulate the side feature with the fused attention
        return D
In the top-down decoder, each MSA_module is called with a feature to refine (x), an encoder side feature (side_x), and the prediction from the previous stage as the mask, for example: D4 = self.MSA5(E5, E4, P5).
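A self-contained call with random tensors (again assuming LayerNorm and weight_init from the repository; the shapes are illustrative). Note that the mask may come at a different resolution, since it is resized inside forward:

msa = MSA_module(dim=128)
x = torch.randn(1, 128, 16, 16)        # feature to be refined
side_x = torch.randn(1, 128, 16, 16)   # encoder side feature at the same scale
pred = torch.randn(1, 1, 64, 64)       # previous-stage prediction (logits)
d = msa(x, side_x, pred)               # d.shape == torch.Size([1, 128, 16, 16])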