Abstract

We incorporate a Group multi-axis Hadamard Product Attention module (GHPA) and a Group Aggregation Bridge module (GAB) in a lightweight manner. GHPA groups the input features and performs the Hadamard Product Attention mechanism (HPA) on different axes to extract pathological information from diverse perspectives. GAB effectively fuses multi-scale information by grouping low-level features, high-level features, and a mask generated by the decoder at each stage. This is the first model with a parameter count limited to just 50KB.

Code

GitHub - JCruan519/EGE-UNet: (MICCAI23) This is the official code repository for “EGE-UNet: an Efficient Group Enhanced UNet for skin lesion segmentation”. (https://github.com/JCruan519/EGE-UNet)

Introduction

Specifically, EGE-UNet leverages two key modules: the Group multi-axis Hadamard Product Attention module (GHPA) and the Group Aggregation Bridge module (GAB). HPA employs a learnable weight and performs a Hadamard product operation with the input to obtain the output. Inspired by the multi-head mode in MHSA, we propose GHPA, which divides the input into different groups and performs HPA within each group. Notably, HPA is performed on different axes in different groups, which helps to obtain information from diverse perspectives. On the other hand, for GAB, since the size and shape of segmentation targets in medical images are inconsistent, obtaining multi-scale information is essential. GAB therefore integrates high-level and low-level features of different sizes based on group aggregation, and additionally introduces mask information to assist feature fusion. By combining the above two modules with UNet, we propose EGE-UNet, which achieves excellent segmentation performance with an extremely low parameter count and computational cost. Unlike previous approaches that focus solely on improving performance, our model also prioritizes usability in real-world environments.
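Since HPA is the core operation, a minimal standalone sketch may help before the full module below. This is illustrative only: SimpleHPA and its parameter names are ours, not from the repository, which folds this logic into GHPA.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class SimpleHPA(nn.Module):
        """Minimal Hadamard Product Attention sketch: output = x * DW(resize(p))."""
        def __init__(self, channels, param_size=8):
            super().__init__()
            # Learnable tensor p, smaller than the feature map and resized on the fly
            self.p = nn.Parameter(torch.ones(1, channels, param_size, param_size))
            # Depth-wise separable convolution: depth-wise 3x3 followed by point-wise 1x1
            self.dw = nn.Sequential(
                nn.Conv2d(channels, channels, 3, padding=1, groups=channels),
                nn.Conv2d(channels, channels, 1),
            )

        def forward(self, x):
            # Bilinear interpolation resizes p to match x; then element-wise (Hadamard) product
            p = F.interpolate(self.p, size=x.shape[2:], mode='bilinear', align_corners=True)
            return x * self.dw(p)

Because the output is a single element-wise product per position, the cost grows linearly with the number of pixels, unlike the quadratic cost of self-attention over all pixel pairs.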

Main Contributions

(1) GHPA and GAB are proposed: the former efficiently acquires and integrates multi-perspective information, while the latter accepts features at different scales, along with an auxiliary mask, for efficient multi-scale feature fusion.

(2) We propose EGE-UNet, an extremely lightweight model designed for skin lesion segmentation.

(3) We conduct extensive experiments, which demonstrate the effectiveness of our methods in achieving state-of-the-art performance with significantly lower resource requirements.

Model Architecture

GHPA Module

Group multi-axis Hadamard Product Attention module. To overcome the quadratic complexity of MHSA, we propose HPA with linear complexity. Given an input x and a randomly initialized learnable tensor p, bilinear interpolation is first utilized to resize p to match the size of x. We then apply a depth-wise separable convolution (DW) to p, followed by a Hadamard product between x and p to obtain the output. However, simple HPA alone is insufficient to extract information from multiple perspectives, leading to unsatisfactory results. Motivated by the multi-head mode in MHSA, we introduce GHPA, built on HPA: the input is divided equally into four groups along the channel dimension, and HPA is performed on the height-width, channel-height, and channel-width axes for the first three groups, respectively. For the last group, only DW is applied to the feature map. Finally, the four groups are concatenated along the channel dimension and another DW is applied to integrate the information from the different perspectives. Note that all kernel sizes employed in DW are 3.
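The modules below reference a LayerNorm with a data_format argument, which torch.nn does not provide; the official repository defines its own channels-first variant. A ConvNeXt-style implementation matching the call sites (an assumption on our part; consult the repo for the exact helper) is:

    # Uses the torch / nn / F imports from the snippet above.
    class LayerNorm(nn.Module):
        """LayerNorm over channels, for (B, H, W, C) or (B, C, H, W) tensors."""
        def __init__(self, normalized_shape, eps=1e-6, data_format='channels_last'):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(normalized_shape))
            self.bias = nn.Parameter(torch.zeros(normalized_shape))
            self.eps = eps
            self.data_format = data_format
            if self.data_format not in ['channels_last', 'channels_first']:
                raise NotImplementedError
            self.normalized_shape = (normalized_shape,)

        def forward(self, x):
            if self.data_format == 'channels_last':
                return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
            # channels_first: normalize over the channel dimension manually
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            return self.weight[:, None, None] * x + self.bias[:, None, None]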

class Grouped_multi_axis_Hadamard_Product_Attention(nn.Module):
    def __init__(self, dim_in, dim_out, x=8, y=8):
        super().__init__()

        c_dim_in = dim_in // 4  # channels per group after splitting into four groups
        k_size = 3
        pad = (k_size - 1) // 2

        # Learnable tensor and DW for the height-width (xy) group
        self.params_xy = nn.Parameter(torch.Tensor(1, c_dim_in, x, y), requires_grad=True)
        nn.init.ones_(self.params_xy)
        self.conv_xy = nn.Sequential(
            nn.Conv2d(c_dim_in, c_dim_in, kernel_size=k_size, padding=pad, groups=c_dim_in),
            nn.GELU(),
            nn.Conv2d(c_dim_in, c_dim_in, 1))

        # Learnable tensor and DW for the channel-height (zx) group
        self.params_zx = nn.Parameter(torch.Tensor(1, 1, c_dim_in, x), requires_grad=True)
        nn.init.ones_(self.params_zx)
        self.conv_zx = nn.Sequential(
            nn.Conv1d(c_dim_in, c_dim_in, kernel_size=k_size, padding=pad, groups=c_dim_in),
            nn.GELU(),
            nn.Conv1d(c_dim_in, c_dim_in, 1))

        # Learnable tensor and DW for the channel-width (zy) group
        self.params_zy = nn.Parameter(torch.Tensor(1, 1, c_dim_in, y), requires_grad=True)
        nn.init.ones_(self.params_zy)
        self.conv_zy = nn.Sequential(
            nn.Conv1d(c_dim_in, c_dim_in, kernel_size=k_size, padding=pad, groups=c_dim_in),
            nn.GELU(),
            nn.Conv1d(c_dim_in, c_dim_in, 1))

        # Plain DW branch for the fourth group
        self.dw = nn.Sequential(
            nn.Conv2d(c_dim_in, c_dim_in, 1),
            nn.GELU(),
            nn.Conv2d(c_dim_in, c_dim_in, kernel_size=3, padding=1, groups=c_dim_in)
        )

        self.norm1 = LayerNorm(dim_in, eps=1e-6, data_format='channels_first')
        self.norm2 = LayerNorm(dim_in, eps=1e-6, data_format='channels_first')

        # Final DW that integrates the four concatenated groups
        self.ldw = nn.Sequential(
            nn.Conv2d(dim_in, dim_in, kernel_size=3, padding=1, groups=dim_in),
            nn.GELU(),
            nn.Conv2d(dim_in, dim_out, 1),
        )

    def forward(self, x):
        x = self.norm1(x)
        x1, x2, x3, x4 = torch.chunk(x, 4, dim=1)  # split into four groups along the channel dimension
        B, C, H, W = x1.size()
        #----------group 1: HPA on the height-width axes----------#
        params_xy = self.params_xy  # learnable tensor
        x1 = x1 * self.conv_xy(F.interpolate(params_xy, size=x1.shape[2:4], mode='bilinear', align_corners=True))
        #----------group 2: HPA on the channel-height axes----------#
        x2 = x2.permute(0, 3, 1, 2)
        params_zx = self.params_zx
        x2 = x2 * self.conv_zx(F.interpolate(params_zx, size=x2.shape[2:4], mode='bilinear', align_corners=True).squeeze(0)).unsqueeze(0)
        x2 = x2.permute(0, 2, 3, 1)
        #----------group 3: HPA on the channel-width axes----------#
        x3 = x3.permute(0, 2, 1, 3)
        params_zy = self.params_zy
        x3 = x3 * self.conv_zy(F.interpolate(params_zy, size=x3.shape[2:4], mode='bilinear', align_corners=True).squeeze(0)).unsqueeze(0)
        x3 = x3.permute(0, 2, 1, 3)
        #----------group 4: plain DW----------#
        x4 = self.dw(x4)
        #----------concatenate the four groups----------#
        x = torch.cat([x1, x2, x3, x4], dim=1)
        #----------final DW integrates the perspectives----------#
        x = self.norm2(x)
        x = self.ldw(x)
        return x
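A quick shape check (illustrative; dim_in should be divisible by 4 so the channel chunks match the per-group convolutions):

    ghpa = Grouped_multi_axis_Hadamard_Product_Attention(dim_in=32, dim_out=32)
    x = torch.randn(2, 32, 64, 64)  # (B, C, H, W)
    print(ghpa(x).shape)            # torch.Size([2, 32, 64, 64])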

GAB Module

Group Aggregation Bridge module. The acquisition of multi-scale information is deemed pivotal for dense prediction tasks such as medical image segmentation. Hence, we introduce GAB, which takes three inputs: low-level features, high-level features, and a mask. First, depth-wise separable convolution (DW) and bilinear interpolation are employed to adjust the size of the high-level features to match that of the low-level features. Second, both feature maps are partitioned into four groups along the channel dimension, and one group from the low-level features is concatenated with one from the high-level features to obtain four groups of fused features; the mask is then concatenated to each group. Next, dilated convolutions with a kernel size of 3 and different dilation rates of {1, 2, 5, 7} are applied to the different groups in order to extract information at different scales. Finally, the four groups are concatenated along the channel dimension, followed by a plain convolution with a kernel size of 1 to enable interaction among features at different scales.
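A side note on the padding used in the code below: a dilated convolution with kernel size k and dilation d has an effective kernel size k_eff = k + (k-1)(d-1), so the expression (k + (k-1)*(d-1)) // 2 gives the "same" padding for each branch. A small check:

    k, d_list = 3, [1, 2, 5, 7]
    for d in d_list:
        k_eff = k + (k - 1) * (d - 1)      # effective extent of the dilated kernel
        pad = (k + (k - 1) * (d - 1)) // 2  # the padding expression used in the module
        print(d, k_eff, pad)                # (1, 3, 1) (2, 5, 2) (5, 11, 5) (7, 15, 7)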

class group_aggregation_bridge(nn.Module):
    def __init__(self, dim_xh, dim_xl, k_size=3, d_list=[1, 2, 5, 7]):
        super().__init__()
        # 1x1 conv projects high-level features to the low-level channel count
        self.pre_project = nn.Conv2d(dim_xh, dim_xl, 1)
        group_size = dim_xl // 2  # each fused group: dim_xl//4 (high) + dim_xl//4 (low)
        # Four grouped dilated convolutions, one per fused group; the padding
        # (k_size + (k_size-1)*(d-1)) // 2 keeps the spatial size unchanged
        self.g0 = nn.Sequential(
            LayerNorm(normalized_shape=group_size + 1, data_format='channels_first'),
            nn.Conv2d(group_size + 1, group_size + 1, kernel_size=k_size, stride=1,
                      padding=(k_size + (k_size - 1) * (d_list[0] - 1)) // 2,
                      dilation=d_list[0], groups=group_size + 1)
        )
        self.g1 = nn.Sequential(
            LayerNorm(normalized_shape=group_size + 1, data_format='channels_first'),
            nn.Conv2d(group_size + 1, group_size + 1, kernel_size=k_size, stride=1,
                      padding=(k_size + (k_size - 1) * (d_list[1] - 1)) // 2,
                      dilation=d_list[1], groups=group_size + 1)
        )
        self.g2 = nn.Sequential(
            LayerNorm(normalized_shape=group_size + 1, data_format='channels_first'),
            nn.Conv2d(group_size + 1, group_size + 1, kernel_size=k_size, stride=1,
                      padding=(k_size + (k_size - 1) * (d_list[2] - 1)) // 2,
                      dilation=d_list[2], groups=group_size + 1)
        )
        self.g3 = nn.Sequential(
            LayerNorm(normalized_shape=group_size + 1, data_format='channels_first'),
            nn.Conv2d(group_size + 1, group_size + 1, kernel_size=k_size, stride=1,
                      padding=(k_size + (k_size - 1) * (d_list[3] - 1)) // 2,
                      dilation=d_list[3], groups=group_size + 1)
        )
        # 1x1 conv fuses the four groups: 4 * (group_size + 1) = dim_xl * 2 + 4 channels
        self.tail_conv = nn.Sequential(
            LayerNorm(normalized_shape=dim_xl * 2 + 4, data_format='channels_first'),
            nn.Conv2d(dim_xl * 2 + 4, dim_xl, 1)
        )

    def forward(self, xh, xl, mask):
        xh = self.pre_project(xh)
        # Upsample high-level features to the low-level spatial size
        xh = F.interpolate(xh, size=[xl.size(2), xl.size(3)], mode='bilinear', align_corners=True)
        xh = torch.chunk(xh, 4, dim=1)  # four groups along the channel dimension
        xl = torch.chunk(xl, 4, dim=1)
        # Fuse one high-level group, one low-level group, and the mask per branch
        x0 = self.g0(torch.cat((xh[0], xl[0], mask), dim=1))
        x1 = self.g1(torch.cat((xh[1], xl[1], mask), dim=1))
        x2 = self.g2(torch.cat((xh[2], xl[2], mask), dim=1))
        x3 = self.g3(torch.cat((xh[3], xl[3], mask), dim=1))
        x = torch.cat((x0, x1, x2, x3), dim=1)
        x = self.tail_conv(x)
        return x
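A quick shape check for GAB (illustrative shapes; in the encoder-decoder the high-level features typically have more channels and half the spatial resolution of the low-level ones, the mask is single-channel at the low-level resolution, and dim_xl should be divisible by 4):

    gab = group_aggregation_bridge(dim_xh=64, dim_xl=32)
    xh = torch.randn(2, 64, 16, 16)    # high-level features from a deeper stage
    xl = torch.randn(2, 32, 32, 32)    # low-level features
    mask = torch.randn(2, 1, 32, 32)   # decoder-generated mask at the low-level resolution
    print(gab(xh, xl, mask).shape)     # torch.Size([2, 32, 32, 32])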