Swin-Unet

摘要

​ 在过去几年中,卷积神经网络(CNN)在医学图像分析领域取得了里程碑式的进展,尤其是基于U型结构和跳跃连接的深度神经网络被广泛应用于各种医学图像任务中。然而,尽管CNN取得了优异的表现,但由于卷积操作的局部性,它无法很好地学习全局和长距离语义信息的交互。在本文中,我们提出了Swin-Unet,这是一个类似于Unet的纯Transformer,用于医学图像分割。 将标记化的图像块输入到基于Transformer的带有跳跃连接的U型编码器-解码器架构中进行局部全局语义特征学习。具体而言,我们使用带有移位窗口的分层Swin Transformer作为编码器来提取上下文特征。并设计了一个带有块扩展层的对称Swin Transformer解码器来执行上采样操作以恢复特征图的空间分辨率。在对输入和输出直接进行 4 倍下采样和上采样的情况下,在多器官和心脏分割任务上的实验表明,纯基于 Transformer 的 U 型编码器-解码器网络优于使用全卷积或 Transformer 与卷积组合的方法。代码和训练好的模型将在 https://github.com/HuCaoFighting/Swin-Unet 上公开发布。

介绍

​ 得益于深度学习的发展,计算机视觉技术在医学图像分析中得到了广泛的应用。图像分割是医学图像分析的重要组成部分。特别是准确、鲁棒的医学图像分割可以在计算机辅助诊断和图像引导临床手术中发挥基石作用。

​ 现有的医学图像分割方法主要依赖于具有 U 型结构的全卷积神经网络 (FCNN,顾名思义,就是神经网络全部由卷积层构成。与经典CNN网络的区别在于,它将CNN网络中的全连接层全部用卷积层替换。) 。典型的 U 型网络 U-Net 由具有跳跃连接的对称编码器-解码器组成。在编码器中,使用一系列卷积层和连续的下采样层来提取具有大感受野的深度特征。然后,解码器将提取的深度特征上采样到输入分辨率进行像素级语义预测,并将来自编码器的不同尺度的高分辨率特征与跳跃连接融合,以减轻下采样造成的空间信息丢失。凭借如此优雅的结构设计,U-Net 在各种医学成像应用中取得了巨大成功。遵循这一技术路线,已经开发出许多算法,例如 3D U-Net 、Res-UNet 、U-Net++ 和 UNet3+ ,用于各种医学成像模态的图像和体积分割。这些基于 FCNN 的方法在心脏分割、器官分割和病变分割中的优异表现证明了 CNN 具有强大的学习判别特征的能力。

​ 目前,虽然基于 CNN 的方法在医学图像分割领域取得了优异的表现,但仍然不能完全满足医学应用对于分割精度的严格要求。图像分割在医学图像分析中仍然是一项具有挑战性的任务。由于卷积运算的内在局部性,基于 CNN 的方法很难学习明确的全局和长距离语义信息交互 。一些研究试图通过使用空洞卷积层 、自注意机制 和图像金字塔 来解决这个问题。然而,这些方法在建模长距离依赖关系方面仍然存在局限性。最近,受到 Transformer 在自然语言处理 (NLP) 领域 巨大成功的启发,研究人员试图将 Transformer 带入视觉领域。在 [17] 中,提出了视觉变换器 (ViT) 来执行图像识别任务。以具有位置嵌入的 2D 图像块作为输入,并在大数据集上进行预训练,ViT 实现了与基于 CNN 的方法相当的性能。此外,[18] 提出了数据高效图像变换器 (DeiT),这表明 Transformer 可以在中等规模的数据集上进行训练,并且将其与蒸馏方法相结合可以获得更强大的 Transformer。[19] 开发了一种分层的 Swin Transformer。以 Swin Transformer 作为视觉主干,[19] 的作者在图像分类、目标检测和语义分割方面取得了最先进的性能。ViT、DeiT 和 Swin Transformer 在图像识别任务上的成功证明了 Transformer 在视觉领域的应用潜力。

全局和长距离语义信息交互是指在图像或其他数据中的不同区域之间进行信息交流和理解的能力。在视觉任务中,尤其是图像分割和目标检测中,不同对象或区域之间可能存在重要的关联性和语义信息。

  1. 全局信息:指整个图像中包含的所有信息,这些信息常常是多个对象、背景以及它们之间的关系的综合表现。在处理复杂场景时,全局信息可以帮助模型更好地理解图像的整体结构和语义。
  2. 长距离依赖:在图像中,某些重要特征或对象可能相隔较远,传统的卷积神经网络由于其局部感受野的特性,可能难以捕捉这些远离的区域之间的依赖关系。长距离依赖主要体现在需要跨越较大空间距离的上下文信息,这对于理解图像中的复杂情况尤为重要。

通过引入如 Transformer 的自注意力机制,可以实现全局和长距离信息的交互。自注意力机制允许模型在处理某一位置的特征时,同时考虑到图像中其他所有位置的特征,从而有效地捕获整体语义和远距离的上下文信息。这种能力对于提高图像分析任务的准确性和鲁棒性具有重要意义

数据高效图像变换器 (DeiT) 是一种改进的视觉变换器,它的提出主要是为了解决在较小或中等规模数据集上训练 Transformer 模型时,数据需求量大的问题。DeiT 通过引入蒸馏方法,能够在较少的数据上实现良好的性能,具体来说,其工作原理如下:

  1. 蒸馏方法:DeiT 使用了一种新的训练策略,称为“知识蒸馏”(Knowledge Distillation)。在这个过程中,DeiT 模型通过与一个预训练的强大教师模型(通常是一个较大规模的 CNN 模型)进行学习,从教师模型中获取知识。教师模型的输出可以作为训练的参考,帮助学生模型(DeiT)更好地理解特征。
  2. 数据效率:通过这种蒸馏策略,DeiT 能够在有限的数据集上获得与大型数据集上训练的模型相当的性能。这样,研究人员和开发者在处理有限的数据时,不必担心模型性能下降。
  3. 模型架构:DeiT 的架构基于 ViT,但在设计上进行了调整,以便适应蒸馏过程,提高了模型的表示能力和训练效率。

​ 受 Swin Transformer 成功的启发,我们提出了 Swin-Unet,以利用 Transformer 的强大功能进行 2D 医学图像分割。据我们所知,Swin-Unet 是第一个纯基于 Transformer 的 U 形架构,由编码器、瓶颈、解码器和跳过连接组成。编码器、瓶颈和解码器均基于 Swin Transformer 块 构建。输入的医学图像被拆分成不重叠的图像块。每个patch被视为一个token,输入到基于Transformer的编码器中学习深度特征表示。提取出的上下文特征随后通过带有patch扩展层的解码器进行上采样,并通过跳跃连接与编码器的多尺度特征融合,恢复特征图的空间分辨率并进一步进行分割预测。在多器官和心脏分割数据集上的大量实验表明,该方法具有良好的分割精度和鲁棒的泛化能力。具体来说,我们的贡献可以概括为:(1)基于Swin Transformer模块,我们构建了一个带有跳跃连接的对称编码器-解码器架构。在编码器中,实现了从局部到全局的自注意力;在解码器中,将全局特征上采样到输入分辨率以进行相应的像素级分割预测。(2)开发了patch扩展层,无需使用卷积或插值运算即可实现上采样和特征维数增加。 (3)实验中发现skip connection对于Transformer同样有效,因此最终构建了一个纯基于Transformer的带有skip connection的U型Encoder-Decoder架构,命名为Swin-Unet。

相关工作

基于CNN的方法:早期的医学图像分割方法主要是基于轮廓和基于传统机器学习的算法。随着深度CNN的发展,U-Net在[3]中被提出用于医学图像分割。由于U型结构的简单性和优越的性能,各种Unet-like方法不断涌现,如Res-UNet、Dense-UNet、U-Net++和UNet3+。并且也被引入到3D医学图像分割领域,如3D-Unet和V-Net。 目前,基于CNN的方法凭借其强大的表征能力,在医学图像分割领域取得了巨大的成功。

视觉变换器:Transformer 最早是在 [15] 中为机器翻译任务提出的。在 NLP 领域,基于 Transformer 的方法在各种任务中都取得了最先进的性能 。在 Transformer 成功的推动下,研究人员在 [17] 中引入了一种开创性的视觉变换器 (ViT),它在图像识别任务上实现了令人印象深刻的速度-准确度权衡。与基于 CNN 的方法相比,ViT 的缺点是它需要在自己的大型数据集上进行预训练。为了减轻训练 ViT 的难度,Deit 描述了几种训练策略,使 ViT 能够在 ImageNet 上进行良好的训练。最近,基于 ViT 已经完成了几项出色的工作 。值得一提的是,一种高效的分层视觉变换器,称为 Swin Transformer,在 [19] 中被提出作为视觉主干。基于移位窗口机制,Swin Transformer 在图像分类、目标检测、语义分割等多种视觉任务上取得了最佳性能。在本研究中,我们尝试使用 Swin Transformer 块作为基本单元,构建一个带有跳跃连接的 U 型编码器 - 解码器架构,用于医学图像分割,从而为 Transformer 在医学图像领域的发展提供一个基准比较。

自注意力/Transformer对CNN的补充:近年来,研究人员试图将自注意力机制引入CNN,以提高网络性能。在[12]中,带有加性注意门的跳跃连接被集成在U形架构中,以执行医学图像分割。然而,这仍然是基于CNN的方法。目前,人们正在努力将CNN和Transformer结合起来,以打破CNN在医学图像分割中的主导地位[2,27,1]。在[2]中,作者将Transformer与CNN结合起来,构成了一个用于二维医学图像分割的强编码器。与[2]类似,[27]和[28]利用Transformer和CNN的互补性来提高模型的分割能力。目前,Transformer与CNN的各种组合被应用于多模态脑肿瘤分割[29]和三维医学图像分割[1,30]。与上述方法不同,我们尝试探索纯Transformer在医学图像分割中的应用潜力

方法

架构概述

​ 所提出的 Swin-Unet 的整体架构如图 1 所示。 Swin-Unet 由编码器、瓶颈、解码器和跳过连接组成。Swin-Unet 的基本单元是 Swin Transformer 块 [19]。对于编码器,为了将输入转换为序列嵌入,医学图像被分割成不重叠的块,块大小为 4×4。通过这种划分方法,每个块的特征维度变为 4×4×3 = 48。此外,应用线性嵌入层将特征维度投影到任意维度(表示为 C)。转换后的patch tokens经过多个 Swin Transformer 块和Patch Merging以生成分层特征表示。具体而言,Patch Merging负责下采样和增加维度,Swin Transformer 块负责特征表示学习。受 U-Net 的启发,我们设计了一个基于对称变压器的解码器。解码器由 Swin Transformer 模块和 patch expanding层组成。提取的上下文特征通过跳跃连接与来自编码器的多尺度特征融合,以补充因下采样造成的空间信息丢失。与块合并层相比,专门设计了一个 patch expanding层来执行上采样。 patch expanding将相邻维度的特征图重塑为具有 2 倍上采样分辨率的大特征图。最后,使用最后一个 patch expanding层执行 4 倍上采样,将特征图的分辨率恢复为输入分辨率(W×H),然后在这些上采样特征上应用线性投影层以输出像素级分割预测。

image-20241013180328363

“Patch tokens” 是深度学习中的一个概念,尤其在视觉模型(如 Vision Transformer,ViT)中被广泛应用。这一概念主要涉及如何将图像转换为一种适合于 Transformer 处理的形式。以下是对“patch tokens”的详细解释:

定义

  • Patch Tokens:在处理图像时,首先将图像划分为多个小块(patch),每个小块被称为一个“patch”。这些小块被展平并转换为一维向量,随后形成的向量序列便称为“patch tokens”。每个 patch token 表示相应小块的特征。

如何生成 Patch Tokens

  • 划分图像:例如,将一幅图像划分为若干个 大小的块(如 16x16 像素)。
  • 展平处理:将每个小块展平为一个向量。例如,一个 RGB 图像的小块 16x16 的图像会被展平为一个大小为16x16x3 的向量。
  • 嵌入:通过线性嵌入将这些展平向量转换为一个新的维度 ,从而生成一序列的 patch tokens。

在 Vision Transformer 中的作用

  • 输入准备:将图像转换为 patch tokens 使得 Transformer 模型能够处理这些输入,因为传统的 Transformer 处理的是序列数据。
  • 特征表示:每个 patch token 作为不同区域图像的特征表达,通过这序列输入到 Transformer 中进行进一步的特征学习和上下文理解。

优势

  • 捕捉局部特征:通过将图像分割为块,模型能够更好地捕捉局部特征和图像中的结构信息。
  • 减少计算复杂度:相比于处理整个图像,处理较小的 patch 可以显著降低计算消耗,使得 Transformer 适合处理高分辨率图像。

在其他任务中的应用

  • 除了 Vision Transformer,patch tokens 的概念也可以扩展到其他任务中,特别是在需要处理高维数据时,如时序数据中的滑动窗口或局部区域特征表示。

Swin Transformer 模块

​ 与传统的多头自注意力 (MSA) 模块不同,swin Transformer 模块是基于移位窗口构建的。图 2 中展示了两个连续的 swin Transformer 模块。每个 swin Transformer 模块由 LayerNorm (LN) 层、多头自注意力模块、残差连接和具有 GELU 非线性的 2 层 MLP 组成。窗口多头自注意力(W-MSA)模块和偏移窗口多头自注意力(SW-MSA)模块分别应用于两个连续的Transformer块中。论文详情:[2103.14030] Swin Transformer: Hierarchical Vision Transformer using Shifted Windows (arxiv.org)

image-20241014155459690

Encoder

​ 在编码器中,分辨率 $\frac H4 \times \frac W4 $的c维标记化输入被输入到两个连续的 Swin Transformer 块中进行表征学习,其中特征维度和分辨率保持不变。 同时, patch merging层将减少 token数量(2 倍下采样)并将特征维度增加到原始维度的 2 倍。此过程将在编码器中重复三次

Patch merging layer: 输入的图块被分成 4 个部分,并通过 patch merging layer连接在一起。通过这样的处理,特征分辨率将下采样 2 倍。而且,由于连接操作导致特征维度增加 4 倍,因此在连接的特征上应用线性层,将特征维度统一为原始维度的 2 倍。

Bottleneck

​ 由于Transformer太深而难以收敛,因此仅使用两个连续的Swin Transformer块来构建瓶颈来学习深度特征表示。在瓶颈中,特征维度和分辨率保持不变。

Decoder

​ 与编码器相对应,基于 Swin Transformer 模块构建了对称解码器。为此,与编码器中使用的块合并层不同,我们在解码器中使用patch expanding层对提取的深度特征进行上采样。patch expanding层将相邻维度的特征图重塑为更高分辨率的特征图(2 倍上采样),并相应地将特征维度降低为原始维度的一半。

Patch expanding layer:以第一个patch expanding层为例,在上采样之前,对输入特征($ \frac W{32} \times \frac H {32} \times 8C$ )应用一个线性层,将特征维度增加到原始维度的 2 倍($ \frac W{32} \times \frac H {32} \times 16C$ )。 然后,我们使用重排操作将输入特征的分辨率扩展为输入分辨率的 2 倍,并将特征维度降低为输入维度的四分之一($ \frac W{32} \times \frac H {32} \times 16C \longrightarrow \frac W{16} \times \frac H {16} \times 4C$)。

Skip connection

​ 与 U-Net 类似,Skip connection用于将来自编码器的多尺度特征与上采样特征融合。将浅层特征和深层特征连接在一起,以减少下采样造成的空间信息丢失。紧接着一个线性层,连接特征的维度与上采样特征的维度保持不变。

实验

数据集中只有背景+目标两个类别

生成npz文件

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import glob
import cv2
import numpy as np
import torch
def npz():
#原图像路径
path = r'D:\BaiduNetdiskDownload\Pytorch-UNet-master\data\imgs\*.png'
#项目中存放训练所用的npz文件路径
path2 = r'D:\program\GitRepository\Swin-Unet\data\Synapse\train_npz\\'
for i,img_path in enumerate(glob.glob(path)):
#读入图像
image = cv2.imread(img_path)
image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
image = cv2.resize(image,(512,512))

#读入标签
label_path = img_path.replace('imgs','masks')
label = cv2.imread(label_path,flags=0)
label = cv2.resize(label,(512,512))
#将非目标像素设置为0
label[label!=255]=0
#将目标像素设置为1
label[label==255]=1
#保存npz
np.savez(path2+str(i),image=image,label=label)
print('------------',i)
print('ok')

#生成npz文件对应的txt文件
def write_name():
# npz文件路径
files = glob.glob(r'D:\program\GitRepository\Swin-Unet\data\Synapse\train_npz\*.npz')
# txt文件路径
f = open(r'D:\program\GitRepository\Swin-Unet\lists\lists_Synapse\train.txt', 'w')
for i in files:
name = i.split('\\')[-1]
name = name[:-4] + '\n'
f.write(name)

if __name__ == '__main__':
npz()

修改train.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import argparse
import logging
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from networks.vision_transformer import SwinUnet as ViT_seg
from trainer import trainer_synapse
from config import get_config

parser = argparse.ArgumentParser()
parser.add_argument('--root_path', type=str,
default='./data/Synapse', help='root dir for data')
parser.add_argument('--dataset', type=str,
default='Synapse', help='experiment_name')
parser.add_argument('--list_dir', type=str,
default='./lists/lists_Synapse', help='list dir')
parser.add_argument('--num_classes', type=int,
default=2, help='output channel of network')
parser.add_argument('--output_dir',default='./output', type=str, help='output dir')
parser.add_argument('--max_iterations', type=int,
default=30000, help='maximum epoch number to train')
parser.add_argument('--max_epochs', type=int,
default=100, help='maximum epoch number to train')
parser.add_argument('--batch_size', type=int,
default=4, help='batch_size per gpu')
parser.add_argument('--n_gpu', type=int, default=1, help='total gpu')
parser.add_argument('--deterministic', type=int, default=1,
help='whether use deterministic training')
parser.add_argument('--base_lr', type=float, default=0.01,
help='segmentation network learning rate')
parser.add_argument('--img_size', type=int,
default=512, help='input patch size of network input')
parser.add_argument('--seed', type=int,
default=1234, help='random seed')
parser.add_argument('--cfg', type=str,default=r'./configs/swin_tiny_patch4_window7_224_lite.yaml',
metavar="FILE", help='path to config file', )
parser.add_argument(
"--opts",
help="Modify config options by adding 'KEY VALUE' pairs. ",
default=None,
nargs='+',
)
parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
help='no: no cache, '
'full: cache all data, '
'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
parser.add_argument('--resume', help='resume from checkpoint')
parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
parser.add_argument('--use-checkpoint', action='store_true',
help="whether to use gradient checkpointing to save memory")
parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
help='mixed precision opt level, if O0, no amp is used')
parser.add_argument('--tag', help='tag of experiment')
parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
parser.add_argument('--throughput', action='store_true', help='Test throughput only')

args = parser.parse_args()
if args.dataset == "Synapse":
args.root_path = os.path.join(args.root_path, "train_npz")
config = get_config(args)


if __name__ == "__main__":
if not args.deterministic:
cudnn.benchmark = True
cudnn.deterministic = False
else:
cudnn.benchmark = False
cudnn.deterministic = True

random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

dataset_name = args.dataset
dataset_config = {
'Synapse': {
'root_path': args.root_path,
'list_dir': './lists/lists_Synapse',
'num_classes': 2,
},
}

if args.batch_size != 24 and args.batch_size % 6 == 0:
args.base_lr *= args.batch_size / 24
args.num_classes = dataset_config[dataset_name]['num_classes']
args.root_path = dataset_config[dataset_name]['root_path']
args.list_dir = dataset_config[dataset_name]['list_dir']

if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
net = ViT_seg(config, img_size=args.img_size, num_classes=args.num_classes).cuda()
net.load_from(config)

trainer = {'Synapse': trainer_synapse,}
trainer[dataset_name](args, net, args.output_dir)

数据集大小为512$ \times$512,但是预训练的swin-transformer为224$\times$224。

修改dataset_synapse.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import random
import h5py
import numpy as np
import torch
from scipy import ndimage
from scipy.ndimage.interpolation import zoom
from torch.utils.data import Dataset


def random_rot_flip(image, label):
k = np.random.randint(0, 4)
image = np.rot90(image, k)
label = np.rot90(label, k)
axis = np.random.randint(0, 2)
image = np.flip(image, axis=axis).copy()
label = np.flip(label, axis=axis).copy()
return image, label


def random_rotate(image, label):
angle = np.random.randint(-20, 20)
image = ndimage.rotate(image, angle, order=0, reshape=False)
label = ndimage.rotate(label, angle, order=0, reshape=False)
return image, label


class RandomGenerator(object):
def __init__(self, output_size):
self.output_size = output_size

def __call__(self, sample):
image, label = sample['image'], sample['label']

if random.random() > 0.5:
image, label = random_rot_flip(image, label)
elif random.random() > 0.5:
image, label = random_rotate(image, label)
# 改
x, y,_ = image.shape
if x != self.output_size[0] or y != self.output_size[1]:
# 改
image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y,1), order=3) # why not 3?
label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0)

#改
image = torch.from_numpy(image.astype(np.float32))
# 改
image = image.permute(2, 0, 1)
label = torch.from_numpy(label.astype(np.float32))
sample = {'image': image, 'label': label.long()}
return sample


class Synapse_dataset(Dataset):
def __init__(self, base_dir, list_dir, split, transform=None):
self.transform = transform # using transform in torch!
self.split = split
self.sample_list = open(os.path.join(list_dir, self.split+'.txt')).readlines()
self.data_dir = base_dir

def __len__(self):
return len(self.sample_list)

def __getitem__(self, idx):
if self.split == "train":
slice_name = self.sample_list[idx].strip('\n')
data_path = os.path.join(self.data_dir, slice_name+'.npz')
data = np.load(data_path)
image, label = data['image'], data['label']
else:
slice_name = self.sample_list[idx].strip('\n')
data_path = os.path.join(self.data_dir, slice_name + '.npz')
data = np.load(data_path)
image, label = data['image'], data['label']
# 改,numpy转tensor
image = torch.from_numpy(image.astype(np.float32))
image = image.permute(2, 0, 1)
label = torch.from_numpy(label.astype(np.float32))

sample = {'image': image, 'label': label}
if self.transform:
sample = self.transform(sample)
sample['case_name'] = self.sample_list[idx].strip('\n')
return sample

详细实现步骤见:SwinUnet官方代码训练自己数据集(单通道灰度图像的分割)_swinunet代码-CSDN博客

实验结果

​ 在自己数据集上,模型效果并不优越,loss在0.12左右震荡(主要是代码中的Diceloss一直较高)。并且,分割图在224$\times$224时,由于图片与原图分辨率差距过大,实验结果图边缘明显齿状,如马赛克一样。