XLNet 原理解析

论文分析了以 ELMo/GPT 为代表的自回归语言模型与以 BERT 为代表的去噪自编码语言模型的不足,提出了结合 AR 与 AE 各自优点的乱序自回归语言模型,在多项 NLP 任务上达到了 SOTA。
论文:XLNet: Generalized Autoregressive Pretraining for Language Understanding

两种语言模型

自回归语言模型

对于一个给定的序列 $\mathbf{x}=\left[x_{1}, \cdots, x_{T}\right]$,自回归语言模型致力于对该序列的概率分布进行估计。
具体的,利用乘法公式对原始序列的似然函数进行分解:$p(\mathbf{x})=\prod_{t=1}^{T} p\left(x_{t} | \mathbf{x}_{<t}\right)$,可以通过极大似然来进行预训练:

此处的 $h_{\theta}\left(\mathbf{x}_{1: t-1}\right)$ 为神经网络输出的当前词的上文信息,而再通过一层 softmax 的网络即可对当前词进行预测。
缺点:AR 语言模型最大的弊端是非双向,在一些阅读理解等问题的时候,没有真双向信息会使得模型效果降低。而例如 ELMo 的伪双向,也只是一种缓解之计,并不能够学到两个语境间细致的依赖关系。

去噪自编码语言模型

Transformer 被提出之后,其强大的建模双向语境的能力将之前的 CNN/RNN 系列模型完全击败。而以 BERT 为首的去噪自编码语言模型应运而生。其做法是将部分待预测词替换为 MASK,通过建模双向语境的信息,来预测原来的词语:

此处的 $H_{\theta}$ 是 Transformer 的输出向量 $H_{\theta}(\mathbf{x})=\left[H_{\theta}(\mathbf{x})_{1}, H_{\theta}(\mathbf{x})_{2}, \cdots, H_{\theta}(\mathbf{x})_{T}\right]$。根据公式,我们可以发现 BERT 的一些优点和缺点:
优点:上下文相关的模型(context dependency),区别于 AR 模型的伪双向语境,由 Transformer 构建的去噪自编码语言模型建模了真双向语境。
缺点

  1. 独立性假设(independence assumption)。上式使用了 $\approx$ 符号,因为 AE 语言模型不是对原始的序列分布进行建模,没有使用乘法公式。BERT 对 mask 掉的目标单词分别进行预测,使用了独立性假设,即待预测的所有 MASK token 在未 mask 序列的条件下是独立的。
  2. pretrain-finetune 不匹配(discrepancy)。BERT 的输入使用了 MASK 符号,而该符号在 fine-tuning 阶段不会出现,这造成了 pretrain-finetune 不匹配,虽然 BERT 使用了策略缓解了一定的问题,但并未解决。

乱序语言模型

综合以上两种语言模型的弊端与优势,XLNet 最大的贡献即为提出了乱序语言模型。乱序语言模型使用一个序列的所有可能排序方式来构建一个 AR 语言模型,理论上,如果模型的参数在所有的顺序中共享,那么模型就能学到从所有位置收集上下文信息。
假设 $\mathscr{Z}_T$ 表示长度为 $T$ 的序列的所有可能的排序集合,对于一种排序方式 $\mathbf{z}\in\mathscr{Z}_T$,$z_t$ 与 $\mathbf{z}_{<t}$ 分别表示第 $t$ 个位置的元素与前 $t-1$ 个位置的元素。则乱序语言模型可以表示为:

这样一来,理论上,$x_t$ 可以看到任何 $x_{i} \neq x_{t}$ 的信息,因此便俘获了双向的上下文信息。另外,由于使用了 AR 的框架,也自然的避免了独立性假设与 pretrain-finetune 不匹配的问题。
同时,乱序语言模型使用的是原始位置的位置编码,而不是调整了原来句子的顺序,这得益于 Transformer 的 mask 机制来实现。
下图列出了长度为 4 的原始序列在某些排序方式下对 $x_3$ 的建模依赖关系:

双通道自注意力

普通自注意力失效

想要实现乱序的语言模型,如果直接对每个排序方式构建普通的 Transformer,并不能达到效果。原因如下。
假设我们需要预测下一个 token,

此处的 $h_{\theta}\left(\mathbf{x}_{\mathbf{z}<t}\right)$ 是当前时刻之前的序列经过 Transformer 之后的隐表示。但是,$h_{\theta}\left(\mathbf{x}_{\mathbf{z}<t}\right)$ 并没有依赖于待预测的当前位置,即 $z_t$。也就是说,无论要预测的是哪一个位置的词语,产生的分布是一样的,这是不合理的。所以,就有下式:

此处的 $g_{\theta}\left(\mathbf{x}_{\mathbf{z}<t}, z_{t}\right)$ 额外添加了待预测的目标位置 $z_{t}$ 作为输入。

双通道自注意力

需要建模的两种 token 为:

  1. $x_{z_t}$。此时只能使用位置信息 $z_t$ 而不能使用内容信息 $x_{z_t}$,否则就暴露标签了;
  2. $x_{x_j},\quad j>t$。对于后面的 token,需要使用完全的上下文内容信息 $x_{z_t}$。

因此,干脆直接分离开来,使用两种 hidden representations,一个包含了内容信息,一个只包含位置信息。

  1. query representation。$g_{\theta}\left(\mathbf{x}_{\mathbf{z}<t}, z_{t}\right)$,简写为 $g_{z_{t}}$。仅使用了上下文信息 $\mathbf{x}_{\mathbf{z}_{<t}}$ 与位置信息 $z_{t}$,没有使用当前的内容信息 $x_{z_{t}}$;
  2. content representation。$h_{\theta}\left(\mathbf{x}_{\leq t}\right)$,简写为 $h_{z_{t}}$。与标准的 Transformer 相同,同时使用了上下文信息与 $x_{z_{t}}$ 自己。

第一层初始化方法:query representation 用一个可训练向量 $g_{i}^{(0)}=w$,content representation 用单词本身的 embedding 来初始化 $h_{i}^{(0)}=e\left(x_{i}\right)$。
对于 self-attention 的后续层 $m=1, \ldots, M$,两个表示分别按照下式进行更新:

content representation 图示:

query representation 图示:

可以看到,content representation 通道与标准的 Transformer 是相同的,因此在 fine-tuning 阶段,可以简单的直接将 query stream 去掉。

局部预测

由于使用了乱序的语言模型,部分出现在排列前部分的 token 的上下文信息很弱,导致增加了模型优化的困难度。因此在一个排列中,我们选择只预测后部分的一些 token。
对于一个排列 $\mathbf{z}$,我们将其分为两部分:非目标子序列部分 $\mathbf{z}_{\leq c}$ 和目标子序列部分 $\mathbf{z}_{>c}$。将对目标子序列进行预测,因为其有当前排列下的最长上下文信息。

使用超参数 $K$ 来决定目标子序列的比例,即 $|\mathbf{z}| /(|\mathbf{z}|-c) \approx K$。对于非目标子序列,query representatio 可以不用计算。

Incorporating Ideas from Transformer-XL

模型还整合了 Transformer-XL 的两个重要技术。

relative positional encoding scheme

[TODO]

segment recurrence mechanism

[TODO]

relative segment encodings

在 fine-tuning 阶段,有很多任务是需要输入多 segments 的,如双句分类、阅读理解、智能问答等。传统的 BERT 直接使用了绝对编码,即将 $e_A$ 与 $e_B$ 直接赋值给句子内的每个 token。
参考 Transformers 的 relative encodings,我们对于 segments 进行相对编码。当两个位置 $i$ 与 $j$ 处于同一个 segment 中时,$\mathbf{s}_{ij}=\mathbf{s}_+$,否则 $\mathbf{s}_{ij}=\mathbf{s}_-$。即我们只关心两个位置是不是属于同一个 segment。当位置 $i$ 去关注 $j$ 时,可以计算出一个 attention weight $a_{ij}=(\mathbf{q}_i+\mathbf{b})\top\mathbf{s}_{ij}$,最后直接将该值加到普通的 attention weight 上去。
这样做带来了两个好处:

  1. 提升了泛化能力;
  2. 可以实现 fine-tuning 的时候支持多个 segments,而不是绝对编码时固定的两个。

其他

  1. 作者分析了 XLNet 与 BERT 的不同点,证明了这种框架下 XLNet 能比 BERT 学到更多的上下文信息;
  2. XLNet 使用了更多的语料;
  3. 采用了 span-based prediction。即先随机采样一个长度 $L\in[1,\dots,5]$,然后随机选择连续的 $L$ 个 tokens 作为预测对象。

参考

官方 GitHub:https://github.com/zihangdai/xlnet
https://www.jiqizhixin.com/articles/2019-06-29-3