- [CLS]就是classification的意思,可以理解为用于下游的分类任务。
在BERT(Bidirectional Encoder Representations from Transformers)模型里,[CLS]
是一个特殊的分类标记(Classification Token),下面为你详细介绍它的作用和意义。
1. 位置与形式
在输入文本时,[CLS]
标记会被添加到序列的起始位置。比如,当输入一个句子 “I love natural language processing” 时,实际进入BERT模型的输入序列为 [CLS] I love natural language processing
。
2. 主要作用
文本分类任务
在文本分类任务中,[CLS]
标记起着关键作用。BERT模型会对输入序列进行编码,得到每个标记对应的隐藏状态。而 [CLS]
标记对应的隐藏状态会汇聚整个输入序列的信息,可被视作整个句子的特征表示。后续,该隐藏状态会被传入一个分类器(例如全连接层),以完成文本分类任务。例如,在情感分析任务里,分类器会依据 [CLS]
标记的隐藏状态来判断文本是积极情感还是消极情感。
预训练任务
在BERT的预训练阶段,[CLS]
标记参与了两个重要的预训练任务:掩码语言模型(Masked Language Model,MLM)和下一句预测(Next Sentence Prediction,NSP)。在NSP任务中,模型要判断输入的两个句子在原始文本中是否是连续的句子,[CLS]
标记的隐藏状态会用于这个二分类判断。
3. 代码示例
以下是使用Hugging Face的 transformers
库加载BERT模型,并获取 [CLS]
标记隐藏状态的简单示例:
from transformers import BertTokenizer, BertModel
import torch
# 加载预训练的分词器和模型
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
# 输入文本
text = "I love natural language processing"
# 添加 [CLS] 标记并进行分词
inputs = tokenizer(text, return_tensors='pt')
# 模型推理
with torch.no_grad():
outputs = model(**inputs)
# 获取 [CLS] 标记的隐藏状态
cls_hidden_state = outputs.last_hidden_state[:, 0, :]
print(cls_hidden_state.shape)
在这个示例中,outputs.last_hidden_state[:, 0, :]
就是 [CLS]
标记对应的隐藏状态,其中 [:, 0, :]
表示取每个批次中第一个标记(即 [CLS]
标记)的所有隐藏单元的值。
参考
- https://www.pianshen.com/article/5232700066/