Vision Transformer(ViT)是一种将Transformer架构应用于计算机视觉领域的模型,其核心思想是将图像转换为序列数据,利用自注意力机制捕捉全局特征。以下是其详细介绍:

核心架构与工作流程

  1. 图像切块(Image Patches)
    将输入图像分割为固定大小的块(如16×16像素),每个块视为一个“token”。例如,一张224×224的图像会被切分为14×14=196个patch。

  2. 线性嵌入(Linear Embedding)
    每个patch被展平为一维向量(如16×16×3=768维),通过线性变换映射到固定维度(如768),形成嵌入向量序列。实际中常用卷积层实现(如16×16卷积核,步长16)。

  3. 分类token与位置编码

    • 分类token:在嵌入向量序列前添加一个可训练的特殊token(如[CLS]),用于最终分类任务。
    • 位置编码:通过可训练的参数或正弦函数为每个token添加位置信息,弥补Transformer对空间结构的不敏感。
  4. Transformer编码器

    • 由多个Encoder Block堆叠而成,每个Block包含:

    • 多头自注意力:并行计算多个子空间的注意力,捕捉不同特征间的依赖。

    • 残差连接与层归一化:稳定训练并提升模型深度。
    • 前馈神经网络:增强非线性表达能力。
  5. 分类头(MLP Head)
    提取分类token的输出,通过全连接层映射到类别空间,输出预测结果。

关键特点

  • 全局感知:自注意力机制直接建模patch间的长距离依赖,避免CNN逐层堆叠获取全局信息的局限性。

  • 灵活性:无需修改即可迁移到不同任务(如目标检测、语义分割),只需调整输入输出层。

  • 数据依赖性:需大规模数据(如JFT-300M)预训练,小数据场景下可能不如CNN。

代码实现示例(PyTorch)

import torch
import torch.nn as nn

class ViT(nn.Module):
    def __init__(self, img_size=224, patch_size=16, num_classes=1000):
        super().__init__()
        self.patch_size = patch_size
        self.embedding = nn.Conv2d(3, 768, kernel_size=patch_size, stride=patch_size)
        self.pos_embedding = nn.Parameter(torch.zeros(1, 197, 768))  # 196 patches + 1 class token
        self.class_token = nn.Parameter(torch.zeros(1, 1, 768))

        encoder_layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=12)
        self.fc = nn.Linear(768, num_classes)

    def forward(self, x):
        b, c, h, w = x.shape
        x = self.embedding(x).flatten(2).transpose(1, 2)  # (b, 196, 768)
        x = torch.cat([self.class_token.expand(b, -1, -1), x], dim=1)  # (b, 197, 768)
        x += self.pos_embedding
        x = self.transformer(x)
        x = self.fc(x[:, 0])  # 取class token输出
        return x

应用与发展

  • 图像分类:在ImageNet上达到88.55%准确率(预训练后微调)。

  • 混合模型:结合CNN与Transformer(如ConViT),平衡局部与全局特征。

  • 扩展任务:通过迁移学习应用于目标检测(DETR)、视频理解(VideoViT)等领域。

总结

ViT通过将图像转化为序列数据,成功将Transformer从NLP迁移到CV领域,凭借全局注意力机制在多项任务中超越传统CNN。其设计思想为后续视觉模型(如Swin Transformer)奠定了基础,推动了CV与NLP技术的融合发展。