Cross attention(交叉注意力)是深度学习里的一个重要概念,在很多模型尤其是Transformer架构及其变体中被广泛应用。下面为你详细介绍:

概念

在注意力机制里,主要有自注意力(self-attention)和交叉注意力。自注意力关注的是序列内元素之间的关系,也就是让序列里的每个元素都和序列内的其他元素进行交互,以此捕获序列的上下文信息。而交叉注意力则是在两个不同的序列之间计算注意力权重,一个序列作为查询(query),另一个序列作为键(key)和值(value),从而让查询序列能够聚焦于键 - 值序列里的相关信息。

计算过程

交叉注意力的计算和自注意力类似,不过它的查询、键和值来自不同的序列。具体步骤如下:

  1. 线性变换:对查询序列 $Q$、键序列 $K$ 和值序列 $V$ 分别进行线性变换,得到 $Q'$、$K'$ 和 $V'$。

  2. 计算注意力分数:计算查询和键之间的相似度,一般使用点积来计算:$scores = Q'K'^T$。

  3. 应用掩码(可选):在某些情形下,需要应用掩码来屏蔽掉一些不需要关注的位置。

  4. 计算注意力权重:对分数应用 softmax 函数,从而得到注意力权重:$weights = softmax(scores / \sqrt{d_k})$,这里的 $d_k$ 是键的维度。

  5. 加权求和:用注意力权重对值进行加权求和,得到输出:$output = weightsV'$。

公式

交叉注意力的输出可以用以下公式表示:

代码示例

以下是使用 PyTorch 实现的简单交叉注意力代码:

import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    def __init__(self, input_dim, key_dim, value_dim):
        super(CrossAttention, self).__init__()
        self.query_proj = nn.Linear(input_dim, key_dim)
        self.key_proj = nn.Linear(input_dim, key_dim)
        self.value_proj = nn.Linear(input_dim, value_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, query, key, value):
        Q = self.query_proj(query)
        K = self.key_proj(key)
        V = self.value_proj(value)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(Q.size(-1), dtype=torch.float32))
        attention_weights = self.softmax(scores)
        output = torch.matmul(attention_weights, V)
        return output

# 示例使用
input_dim = 64
key_dim = 32
value_dim = 32
query = torch.randn(10, 5, input_dim)  # 批次大小为 10,序列长度为 5
key = torch.randn(10, 8, input_dim)    # 批次大小为 10,序列长度为 8
value = torch.randn(10, 8, input_dim)  # 批次大小为 10,序列长度为 8

cross_attn = CrossAttention(input_dim, key_dim, value_dim)
output = cross_attn(query, key, value)
print(output.shape)

应用场景

  • 机器翻译:在编码器 - 解码器架构里,解码器能够利用交叉注意力机制关注编码器输出的相关信息,从而生成准确的翻译结果。

  • 图像生成:在文本到图像生成任务中,文本嵌入可以作为查询,图像特征作为键和值,通过交叉注意力让生成的图像和文本描述相匹配。

  • 多模态任务:在处理多种模态数据(像文本和图像)时,交叉注意力可以帮助不同模态之间进行信息交互。