WheatField
WheatField

Application of Contrast Learning to Text Representation

November 28, 20221166 words, 6 min read
Authors

表征学习

文本表征学习就是将一段文本映射到低维向量空间,获取句子的语义表示,大致经历过四个阶段:

  • 统计类型,典型的方法是利用 TD-IDF 抽取关键词,用关键词表示表征整个句子。

  • 深度模型阶段,此阶段方式较多,自从 glove、word2vec 等词粒度的表示出现后,在此基础有比较多的延伸工作,从对句子中的词向量简单平均、到有偏平均 SIF[1],后来引入 CNN、LSTM 等模型利用双塔、单塔方式进行学习句子表示,比较典型的几个工作有

    • 微软在排序搜索场景的 DSSM[2],将 word 进行 hash 减少词汇个数,对 word 的表示进行平均得到句子原始表达,经过三层 MLP 获取句子表示。
    • 多伦多大学提出的 Skip-Thought[3],是 word2vec 的 skip-ngram 在句子表达的延伸,输入三个句子,用中间一句话,预测前后两句话。
    • IBM 的 Siam-CNN[4],提出了四种单塔、双塔不同的架构,利用 pairwise loss 作为损失函数。
    • Meta 的 InferSent[5],在双塔的表示基础上,增加了充分的交互。
  • Bert、Ernie 等预训练大模型阶段,在此阶段比较基础典型的工作有:

    • 由于 Bert 通过 SEP 分割,利用 CLS 运用到匹配任务场景存在计算量过大的问题,Sentence-BERT[6] 提出将句子拆开,每个句子单独过 encoder,借鉴 InferSent 的表示交互,来学习句子表达。
  • 20 年在图像领域兴起的对比学习引入到 NLP。

对比学习

对比学习(contrast learning)一般划分到无监督学习(USL)的范畴,典型范式就是:代理任务 + 目标函数,这两项也是对比学习与有监督学习(SL)最大的区别。

SL 中有输入 xx,有对应的 ground truth yy,计算模型输出的 ypy_pyy 通过目标函数计算损失,指导模型训练。对于 USL 来说,是没有 ground truth 的,而这里就是代理任务发挥作用的地方,代理任务的目标是学习到一个好的表征,使得这个表征可以轻松适应到下游任务。Pretext tasks 从未标注数据中采样输入与标签,再结合特定损失函数进行训练。SimCLR 的框架如下图,大体流程:

  1. 数据增强,采样,构造正负样本。
  2. 通过对比损失训练特征提取器 (encoder) ff 及一个映射头 gggg 用来将 ff 的输出映射到一个低维空间。在 SimCLR 中 projection head 是一个两层的 MLP,维度是 128 维。
  3. 在下游任务中,把 projection head 去掉,只保留 ff,用 ff 的输出作为特征,进行下游任务的训练。
SimCLR framework

SimCLR 采用的损失是 InfoNCE loss,这个损失函数的目标是最大化正样本的相似度,最小化负样本的相似度。具体的计算方式如下:

L=logexp(sim(zi,zj)/τ)k=12N1[ki]exp(sim(zi,zk)/τ)\begin{aligned} \mathcal{L} &= -\log \frac{\exp(\text{sim}(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k\neq i]}\exp(\text{sim}(z_i, z_k)/\tau)} \end{aligned}

其中 sim(zi,zj)=ziTzjzizj\text{sim}(z_i, z_j) = \frac{z_i^Tz_j}{\lVert z_i\rVert \lVert z_j\rVert}τ\tau 是一个温度参数,1[ki]\mathbb{1}_{[k\neq i]} 是一个指示函数,当 kik\neq i 时,1[ki]=1\mathbb{1}_{[k\neq i]}=1,否则为 0。NN 是 batch size,对一个 batch 内的 NN 个样本,通过数据增强的方式构造 2N2N 个样本,对于每一个样本 xx,都有一个正样本 x+x^+2N12N-1 个负样本 xx^-

在 NLP 中应用

对比学习的目标是使得相似的东西表示越相似,不相似的东西越不相似。一般训练过程:

  1. 通过数据增强的方式构造训练数据集,对于一条数据,数据集需要包含正例(相似的数据)和负例(不相似的数据)。
    1. 增强方式如,term 替换、随机删除、回译等
  2. 将正例和负例同时输入到 encoder 模型中。
  3. 最小化正例之间的距离,最大化负例之间的距离,进行参数更新。

在语义相似度任务中,一种基于对比学习的方法是 SimCSE

  1. 损失联合方式自监督:将 CL 的 loss 和其他 loss 混合,通过联合优化,使 CL 起到效果:CLEAR,DeCLUTER,SCCL。
  2. 非联合方法自监督:构造增强样本,微调模型:Bert-CT,ConSERT,SimCSE

参考