Cross attention(交叉注意力)是深度学习里的一个重要概念,在很多模型尤其是Transformer架构及其变体中被广泛应用。下面为你详细介绍:
概念
在注意力机制里,主要有自注意力(self-attention)和交叉注意力。自注意力关注的是序列内元素之间的关系,也就是让序列里的每个元素都和序列内的其他元素进行交互,以此捕获序列的上下文信息。而交叉注意力则是在两个不同的序列之间计算注意力权重,一个序列作为查询(query),另一个序列作为键(key)和值(value),从而让查询序列能够聚焦于键 - 值序列里的相关信息。
计算过程
交叉注意力的计算和自注意力类似,不过它的查询、键和值来自不同的序列。具体步骤如下:
-
线性变换:对查询序列 $Q$、键序列 $K$ 和值序列 $V$ 分别进行线性变换,得到 $Q'$、$K'$ 和 $V'$。
-
计算注意力分数:计算查询和键之间的相似度,一般使用点积来计算:$scores = Q'K'^T$。
-
应用掩码(可选):在某些情形下,需要应用掩码来屏蔽掉一些不需要关注的位置。
-
计算注意力权重:对分数应用 softmax 函数,从而得到注意力权重:$weights = softmax(scores / \sqrt{d_k})$,这里的 $d_k$ 是键的维度。
-
加权求和:用注意力权重对值进行加权求和,得到输出:$output = weightsV'$。
公式
交叉注意力的输出可以用以下公式表示:
代码示例
以下是使用 PyTorch 实现的简单交叉注意力代码:
import torch
import torch.nn as nn
class CrossAttention(nn.Module):
def __init__(self, input_dim, key_dim, value_dim):
super(CrossAttention, self).__init__()
self.query_proj = nn.Linear(input_dim, key_dim)
self.key_proj = nn.Linear(input_dim, key_dim)
self.value_proj = nn.Linear(input_dim, value_dim)
self.softmax = nn.Softmax(dim=-1)
def forward(self, query, key, value):
Q = self.query_proj(query)
K = self.key_proj(key)
V = self.value_proj(value)
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(Q.size(-1), dtype=torch.float32))
attention_weights = self.softmax(scores)
output = torch.matmul(attention_weights, V)
return output
# 示例使用
input_dim = 64
key_dim = 32
value_dim = 32
query = torch.randn(10, 5, input_dim) # 批次大小为 10,序列长度为 5
key = torch.randn(10, 8, input_dim) # 批次大小为 10,序列长度为 8
value = torch.randn(10, 8, input_dim) # 批次大小为 10,序列长度为 8
cross_attn = CrossAttention(input_dim, key_dim, value_dim)
output = cross_attn(query, key, value)
print(output.shape)
应用场景
-
机器翻译:在编码器 - 解码器架构里,解码器能够利用交叉注意力机制关注编码器输出的相关信息,从而生成准确的翻译结果。
-
图像生成:在文本到图像生成任务中,文本嵌入可以作为查询,图像特征作为键和值,通过交叉注意力让生成的图像和文本描述相匹配。
-
多模态任务:在处理多种模态数据(像文本和图像)时,交叉注意力可以帮助不同模态之间进行信息交互。