《Continual Learning of a Mixed Sequence of Similar and Dissimilar Tasks》论文总结
一、研究背景
在许多应用中,系统需要逐步或持续学习一系列任务,这被称为持续学习(CL)或终身学习。现有CL模型大多专注于处理灾难性遗忘,但无法同时处理遗忘和进行知识转移。本文旨在解决这一问题,提出一种能够学习混合相似和不相似任务序列的技术。
二、相关工作
- 大多数现有CL模型专注于处理神经网络中的灾难性遗忘,如使用知识蒸馏、量化网络权重重要性、记忆训练示例、生成数据等方法。
- 一些工作致力于知识转移,但现有方法在处理遗忘和知识转移以改进新任务学习方面存在不足。
- 早期终身学习工作主要关注正向知识转移,部分使用神经网络的工作处理的任务非常相似,几乎没有遗忘问题,且没有处理混合任务序列的遗忘问题。
三、提出的CAT模型
- 模型结构:CAT模型由输入数据、任务ID、知识基(KB)、任务掩码(TM)、知识转移注意力(KTA)等组成。KB用于存储从所有任务中学到的知识并由所有任务共享,任务ID用于生成不同的任务ID嵌入,TM用于防止不相似任务的遗忘,KTA用于从相似任务中进行知识转移。
-
防止不相似任务的遗忘:任务掩码
- 原理:通过识别KB中被任务使用的单元,并阻止梯度流通过这些单元来克服遗忘。
- 实现:为每个任务在KB的每层训练一个任务掩码(二进制掩码),根据任务ID嵌入生成掩码,并使用掩码保护任务知识。在学习新任务时,使用之前不相似任务的掩码设置梯度,使重要参数不受影响。
- 训练技巧:采用退火策略使学习任务ID嵌入更容易,任务掩码与分类器一起通过最小化交叉熵进行训练。
-
从相似任务中进行知识转移:知识转移注意力
- 原理:由于每个任务可能有其特定领域知识,知识转移必须具有选择性,通过注意力机制给不同的先前任务赋予不同的重要性,以选择性地将其知识转移到新任务。
- 实现:计算先前相似任务在KB中的掩码输出,学习另一个任务ID嵌入并堆叠输出,通过自注意力计算注意力权重,加权求和得到相似任务的输出,然后输入到分类器中学习。
- 损失函数:训练分类器使用交叉熵损失函数。
-
任务相似性检测
- 定义:通过确定是否存在从先前任务到当前任务的正向知识转移来定义任务相似性。
- 实现:使用转移模型和参考模型进行验证,根据验证集判断任务是否相似,设置任务相似性向量(TSV)。转移模型在KB上训练一个小的读出函数,参考模型是单独为任务构建的模型。
四、实验
-
实验数据集
- 相似任务数据集:采用来自联邦学习的两个数据集,F - EMNIST(每个任务包含一个作者的手写数字/字符)和F - CelebA(每个任务包含一个名人的图像,根据是否微笑进行标注)。
- 不相似任务数据集:使用两个基准图像分类数据集EMNIST和CIFAR100,考虑两种分割场景,每种数据集准备两组任务,每组任务包含不同数量的类别。
- 混合序列数据集:由上述相似和不相似任务数据集构建四个混合序列数据集,用于实验学习混合任务序列。
-
对比基线:考虑十个任务持续学习(TCL)基线,包括EWC、HAT、UCL、HYP、HYP - R、PRO、PathNet、RPSNet、NCL和ONE。
-
网络和训练细节:使用2层全连接网络或基于CNN的AlexNet - like架构,设置任务ID嵌入和KB的维度,使用softmax输出的全连接层作为分类头,采用交叉熵损失函数,设置训练参数如学习率、早停策略、批大小等。
-
结果和分析
- 总体性能:CAT在所有任务的总体准确性方面优于所有基线,其他基线在避免遗忘或鼓励知识转移方面存在不足。
- 不相似任务性能:CAT在不相似任务上的表现优于大多数基线,与HAT相当,表明CAT能较好地处理遗忘。
- 相似任务性能:CAT在相似任务上明显优于HAT和其他基线,因为CAT能利用相似任务之间的共享知识。
- 知识转移有效性:向前知识转移非常有效,向后知识转移对F - MNIST略有改进,对F - CelebA显著改进。
-
消融实验:通过去除KTA或不检测任务相似性的实验表明,完整的CAT系统总是给出最好的总体精度,每个组件都对模型有贡献。
五、结论
本文提出的CAT模型能够实现持续学习系统的四个目标:不遗忘、正向知识转移、反向知识转移和学习相似和不相似任务的混合序列,实验结果表明CAT优于强大的基线。未来工作将专注于提高学习相似任务的准确性和效率,并计划探索在训练中使用更少标记数据的方法。