import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

# NOTE: LayerNorm (selected by LayerNorm_type) and weight_init are expected to be
# provided elsewhere in this repository; they are not defined in this snippet.


class FeedForward(nn.Module):
    """Gated feed-forward network: 1x1 expansion, 3x3 depthwise conv, GELU gating."""

    def __init__(self, dim, ffn_expansion_factor, bias):
        super(FeedForward, self).__init__()
        hidden_features = int(dim * ffn_expansion_factor)
        # Expand to two branches at once (gating branch + value branch).
        self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
        # Depthwise 3x3 conv applied jointly to both branches.
        self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3,
                                stride=1, padding=1, groups=hidden_features * 2, bias=bias)
        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        # GELU-gated element-wise product of the two branches.
        x = F.gelu(x1) * x2
        x = self.project_out(x)
        return x


class Attention(nn.Module):
    """Multi-head channel-wise (transposed) self-attention with optional spatial mask gating."""

    def __init__(self, dim, num_heads, bias, mode):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        # Learnable per-head temperature that scales the attention logits.
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        # Separate 1x1 projections for query, key, and value.
        self.qkv_0 = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        self.qkv_1 = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        self.qkv_2 = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        # Depthwise 3x3 convolutions to mix local context into q, k, and v.
        self.qkv1conv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias)
        self.qkv2conv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias)
        self.qkv3conv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
    def forward(self, x, mask=None):
        b, c, h, w = x.shape
        q = self.qkv1conv(self.qkv_0(x))
        k = self.qkv2conv(self.qkv_1(x))
        v = self.qkv3conv(self.qkv_2(x))
        # Optionally gate queries and keys with a (broadcastable) spatial mask.
        if mask is not None:
            q = q * mask
            k = k * mask

        # Flatten the spatial dimensions; attention is computed across channels, not pixels.
        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

        # L2-normalize along the spatial dimension so the logits are cosine similarities.
        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)
        attn = (q @ k.transpose(-2, -1)) * self.temperature  # (b, head, c, c)
        attn = attn.softmax(dim=-1)
        out = attn @ v
        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
        out = self.project_out(out)
        return out
    def initialize(self):
        weight_init(self)


class MSA_head(nn.Module):
    """Pre-norm transformer block: masked channel attention followed by the gated FFN."""

    def __init__(self, mode='dilation', dim=128, num_heads=8, ffn_expansion_factor=4,
                 bias=False, LayerNorm_type='WithBias'):
        super(MSA_head, self).__init__()
        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias, mode)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
    def forward(self, x, mask=None):
        x = x + self.attn(self.norm1(x), mask)
        x = x + self.ffn(self.norm2(x))
        return x
    def initialize(self):
        weight_init(self)


class MSA_module(nn.Module):
    """Runs foreground-, background-, and unmasked-attention branches and fuses them with a side feature."""

    def __init__(self, dim=128):
        super(MSA_module, self).__init__()
        self.B_TA = MSA_head()  # background branch (attends with the inverted mask)
        self.F_TA = MSA_head()  # foreground branch (attends with the mask)
        self.TA = MSA_head()    # unmasked branch
        self.Fuse = nn.Conv2d(3 * dim, dim, kernel_size=3, padding=1)
        self.Fuse2 = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=1),
            nn.Conv2d(dim, dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(dim),
            nn.ReLU(inplace=True))
    def forward(self, x, side_x, mask):
        N, C, H, W = x.shape
        # Resize the coarse prediction to the feature resolution and squash it to (0, 1).
        mask = F.interpolate(mask, size=x.size()[2:], mode='bilinear')
        mask_d = mask.detach()
        mask_d = torch.sigmoid(mask_d)
        xf = self.F_TA(x, mask_d)      # foreground-guided attention
        xb = self.B_TA(x, 1 - mask_d)  # background-guided attention
        x = self.TA(x)                 # plain attention
        x = torch.cat((xb, xf, x), 1)
        x = x.view(N, 3 * C, H, W)
        x = self.Fuse(x)
        # Modulate the side feature with the fused attention map and refine it.
        D = self.Fuse2(side_x + side_x * x)
        return D

# Dangling fragment from the decoder elsewhere in this file, showing how the module is called:
#   D4 = self.MSA5(E5, E4, P5)
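
# --- Hedged decoder sketch (not part of the original file) -------------------
# Expands on the dangling `D4 = self.MSA5(E5, E4, P5)` fragment above: a coarse
# prediction P5 from a deeper stage guides the fusion of two feature maps. The
# names, channel count (128), the 1-channel prediction, and the assumption that
# both features share a spatial resolution are illustrative only. Constructing
# MSA_module needs the repository's LayerNorm, so this sketch is defined but
# never executed on import.
def _decoder_step_sketch(E5, E4, P5):
    """E5, E4: (N, 128, H, W) features at the same resolution; P5: (N, 1, h, w) coarse logits."""
    msa5 = MSA_module(dim=128)  # in a real decoder this would be built once in __init__
    D4 = msa5(E5, E4, P5)       # fused, refined feature at the features' resolution
    return D4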
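
# --- Hedged shape check (not part of the original file) ----------------------
# A minimal sketch exercising the channel-wise Attention block alone; it needs
# only torch and einops. The tensor sizes (batch 2, dim 128, 16x16) are
# arbitrary choices for illustration; the single-channel mask broadcasts over
# the feature channels when it gates q and k.
if __name__ == '__main__':
    x = torch.randn(2, 128, 16, 16)
    coarse_mask = torch.rand(2, 1, 16, 16)  # values in (0, 1), broadcast over channels
    attn = Attention(dim=128, num_heads=8, bias=False, mode=None)
    out = attn(x, mask=coarse_mask)
    print(out.shape)  # expected: torch.Size([2, 128, 16, 16])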