【论文笔记】MT-DNN

想写这篇蛮久的,但由于之前一直忙着搞别的事情(好吧就是懒),一直拖着。刚好最近有用这个方面的需求,就又读了一遍论文和github上的一些实现。

大部分的博客都只是粗略翻译论文,然而光看论文,往往会忽略一些实现细节,所以笔者最近在尝试将论文笔记和源码解析结合起来,就从这篇MT-DNN开始吧。希望大家多多提意见。

Paper:Multi-Task Deep Neural Networks for Natural Language Understanding

其实作者(Xiaodong Liu)早在15年就写过一篇Multi-task相关的论文,只不过当时还没有bert这样优秀的预训练表达层,在bert横扫各大榜单之后,作者将之前多任务的概念和bert相结合,duang~就出了这一篇在GLUE、SNLI和SciTail创下新的SOTA的论文。

Intuition

会滑雪的人,学滑冰要容易的多(笔者试过,反过来不大成立,手动狗头)。我觉得类比成九年义务教育更好,十门功课同步学,数学是你学物理的基础,历史知识能提高你作文的水平,etc。

Motivation

  • 监督学习需要大量监督数据,但正常情况下咱都是没有的。MTL(multi-task learning)可以提高low-resource任务的表现。

  • MTL能起到正则的作用,减轻模型对特定任务的过拟合。

  • bert之类的预训练模型充分利用了无监督数据。MTL作为补充,进一步利用了out-domain的监督数据。

Model

模型很简单,看一下这个图:

底层share了bert的表达层,输出层为每个任务设计了各自的输入形式和loss计算方式。

任务和loss计算

任务分类及数据

GLUE

  • 单句分类(Single-Sentence Classification):
    • CoLA(Corpus of Linguistic Acceptability):判断英语句子是否语法正确
    • SST-2(Stanford Sentiment Treebank):影评情感分类
  • 文本相似度(Text Similarity):
    • STS-B(Semantic Textual Similarity Bench-mark):人类标注的1-5的语义相似度数据集。
  • 对句分类(Pairwise Text Classification):
    • RTE(Recognizing Textual Entailment):entailment or not_entailment
    • MNLI(Multi-Genre Natural Language Inference):entailment,contradiction,neural
    • QQP(Quora Question Pairs):判断两个问题是否问的是同一内容。
    • MRPC(Microsoft Research Paraphrase Corpus):判断是否两个句子是语义相同的。
    • WNLI(Winograd NLI):Wino-grad Schema dataset得到的推理任务。
  • 相关性排序(Relevance Ranking):
    • QNLI(Stanford Question Answering):问答对数据集。

Out-domain

  • SNLI(Stanford Natural Language Inference):Flickr30里人工标注了hypotheses的推理数据集
  • SciTail(Science Question Answering Textual Entailment):科学问题的推理,更难。

loss计算

1. 单句分类

交叉熵:

2. 文本相似度

均方误差:

3. 对句分类

前面比较常见,这个对句分类作者处理的方式比较特殊,用了18年作者自己提出的一种叫 SAN(stochastic answer network)的输出层构建方式,推理过程有点繁琐,给大家贴个图。

注意图中的$m$,$n$ 都是sequence length。

总结起来就是,作者得到query和premise分别的token-wise的表达之后,在他们两个之间做了一个attention,然后开辟了一个新的状态维度做RNN,从而得到多次预测结果,再做平均(类似于人推理时,多次思考才能得到最终的判断)。作者在后面证明了 SAN 结构能带来0.1%~0.5%的提升。

loss也是交叉熵:

4. 相关性排序

作者的这个loss设计还是挺有意思的,不用简单的二分类来做这个任务,而是用learning2rank的范式,对于每个query $Q$ 采样 $N$ 个candidates,其中$A^+$是正确答案,其他的都是是错误答案。

负对数似然:

训练过程

训练过程就是把所有数据合并在一起,每个batch只有单一任务的数据,同时会带有一个task-type的标志,这样模型就知道该走哪一条loss计算的路径。

论文里并没有提及对于单个任务,之后还要不要再单独Fine-tune一下,不过参考github的FAQ,再FT一下,结果会更好。

实验

实现细节

Optimizer:Adamax(这个地方跟bert不太一样)
lr:5e-5
batch size:32
max_num of epochs:5
SAN steps:5
warm-up:0.1
clip gradient norm:1
max seq length:512

GLUE结果

从Table 2 可以看出来,MT-DNN在每一项都超过了bert,而且数据越少的任务,提升越明显,对于QQP和MNLI来说,提升就没那么明显了。

Table 3中的ST-DNN名字很玄乎,其实与bert不同的就是用了文中的复杂了一点的输出模块和loss的设计,比如SAN,learning2rank这些,单独训练各个任务。可见都是有一定程度的提升。所以MT-DNN相对于bert的提升其实来自于 multi task 和 special output module 两个部分。

SNLI 和 SciTail 结果

在得到mult-task训练后的ckpt后,用这个weights去fine tune新的任务,结果和GLUE的保持一致,都有提升,且小数据集任务的提升更明显。

Domain 适应性结果

这个结果比较有趣,笔者认为也是比较重要的点,MT-DNN得到的weights相对于bert的weights能在很少的数据下达到不错的效果,且数据越少,相对bert的提升就越大。(甚至23个训练样本就能达到82.1%的准确率,amazing啊。)

Conclusion

打个总结:

  • MT-DNN的优点:
    • 数据要求少
    • 泛化能力强,不容易过拟合
  • MT-DNN的缺点:
    • 实用性:实际应用中也许并不能找到特别合适的,且高质量的多任务
    • 训练慢啊,MT-DNN作者用了4张v100,普通业务要不起这个条件,所以MT-DNN的定位其实类似于bert,训练好了就别乱动了,当pretrain-model用。
  • 作者认为的Further work
    • 更深度的share weights
    • 更有效的训练方法
    • 用更可控的方式融入文本的语言结构(这点个人感觉不适用于现在大刀阔斧搞预训练模型的情况)

Code

源码地址:https://github.com/namisan/mt-dnn

阅读源码前,我习惯思考一下如果我自己写,会怎么写:

  1. 首先咱肯定得分类一下数据,每一类任务对应一个数据流,不能每个任务写一个数据流,太累了。
  2. 新建模型的时候得知道有哪些任务,每个任务num_labels是多少,自动生成输出层集合和与id的映射,训练和推理的时候根据任务id选择输出层。
  3. 怎么保证一个epoch结束,所有任务数据集都用完了呢?
  4. max seq lengthlearning ratebatch size这些超参需要根据任务变化吗?不同任务的loss如何scale呢?

基本上,想到这,心里都有点数了,带着问题看源码实现。

官方的源码是用PyTorch实现的,包括了MT-DNN的训练,和一些下游任务的finetune,同时也提供了训练好的MT-DNN的模型。核心思想和步骤如下:

  1. prepro.py:预处理数据,将GLUE的多个任务分成四类,统一处理成 {'uid': ids, 'label': label, 'token_id': input_ids, 'type_id': type_ids}的形式,方便后面操作。
  2. mt-dnn/batcher.py:自定义的data iterator,将读到的数据处理成Tensor。
  3. mt-dnn/matcher.py:模型,之所以叫matcher,是因为模型有一个ModuleList,存放了不同任务对应的输出层,根据当前batch的任务类型match对应的输出层。
  4. mt-dnn/model.py:这里命名有点混淆,实际的模型是上面的matcher,这里做了一些模型前后的处理工作(ema,predict,save模型之类的)。

重点讲一下mt-dnn/batcher.pymt-dnn/matcher.py这两个部分。

batcher.py

忽略作者对于batch这个单词疯狂的拼写错误,相比于常规单任务的data_iterator,这个类除了iter数据,还要返回关于这个任务的必要信息,比如这个任务的id,任务的类型。make_baches 实现把数据打包成batch,reset用来在每个epoch之后重新shuffle并打包成batch。

matcher.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
class SANBertNetwork(nn.Module):
def __init__(self, opt, bert_config=None):
super(SANBertNetwork, self).__init__()
self.dropout_list = []
self.bert_config = BertConfig.from_dict(opt)
self.bert = BertModel(self.bert_config)
if opt['update_bert_opt'] > 0:
for p in self.bert.parameters():
p.requires_grad = False
mem_size = self.bert_config.hidden_size
self.decoder_opt = opt['answer_opt']
# 构建ModuleList,index为task_id
self.scoring_list = nn.ModuleList()
labels = [int(ls) for ls in opt['label_size'].split(',')]
task_dropout_p = opt['tasks_dropout_p']
self.bert_pooler = None

for task, lab in enumerate(labels):
decoder_opt = self.decoder_opt[task]
# 不同任务dropout也不一样
dropout = DropoutWrapper(task_dropout_p[task], opt['vb_dropout'])
self.dropout_list.append(dropout)
if decoder_opt == 1:
out_proj = SANClassifier(mem_size, mem_size, lab, opt, prefix='answer', dropout=dropout)
self.scoring_list.append(out_proj)
else:
out_proj = nn.Linear(self.bert_config.hidden_size, lab)
self.scoring_list.append(out_proj)

self.opt = opt
self._my_init()
self.set_embed(opt)

def forward(self, input_ids, token_type_ids, attention_mask, premise_mask=None, hyp_mask=None, task_id=0):
all_encoder_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
sequence_output = all_encoder_layers[-1]
if self.bert_pooler is not None:
pooled_output = self.bert_pooler(sequence_output)
decoder_opt = self.decoder_opt[task_id]
if decoder_opt == 1:
max_query = hyp_mask.size(1)
assert max_query > 0
assert premise_mask is not None
assert hyp_mask is not None
hyp_mem = sequence_output[:,:max_query,:]
# 通过任务id,索引到对应的输出层,搞定。
logits = self.scoring_list[task_id](sequence_output, hyp_mem, premise_mask, hyp_mask)
else:
pooled_output = self.dropout_list[task_id](pooled_output)
logits = self.scoring_list[task_id](pooled_output)
return logits

这里笔者删除了其他函数,只保留了__init__forward,正如我们看源码之前猜想的,作者就是通过构建一个ModuleList,根据各个任务的类型、label数等信息append输出层,index即任务id。