self attention 使用的一小段经历

3,372次阅读
没有评论

共计 1559 个字符,预计需要花费 4 分钟才能阅读完成。

引子

大名鼎鼎的NLP论文《all in attention》诠释了attention的厉害。attention这种注意力机制的确在一些自然语言任务如机器翻译等取得了非凡的成就。Bert 的出现横扫业界主流评测数据集上的得分,当然这也跟transformer架构有关,这里面也是存在attention。

去年组内同事花了两个晚上大约4小时给我们讲解 bert  从算法到工程,当时以为听懂了,现在实际使用还是有点懵。真的是不用不知道为啥这么做。时间久了也会忘。

这篇文章其实注意点是在 mask 上面,其实对于muti-head 这些都是self attention中的会讲到,网上也有很多的资料。

 mask 小试牛刀

我是在一个信息流的任务上使用self  attention,对比机器翻译的数据一句话,我是使用用户点击的新闻列表,类似其他的列表特征都是可以的,只要与业务能够对的上。

用户的点击行为都是不确定的,你也不知道用户会点多少,所以和句子一样有长有短,但是你会发现在处理的时候还是需要保证相同的长度,这个时候处理的手段就是截断与补全。

如果是正常的长度截断的情况下,你可以正常使用,没有任何的问题,信息都是完整的。

但是当出现需要补全的时候,补全的信息其实误用的信息,而你不希望模型学到这些信息。

所以说到这,mask机制就要浮出水面了。mask本身就是掩膜遮盖的意思,不难理解就是要把这些无用的信息给遮盖掉,这样模型看不到也学不到这些信息。

而且在这讨论的是 padding mask  ,针对填充的数据mask处理。

def Mask(inputs, seq_len, mode='mul'):
        if seq_len == None:
            return inputs
        else:
            mask = tf.sequence_mask(seq_len[:, 0], tf.shape(inputs)[1])
            for _ in range(len(inputs.shape) - 2):
                mask = tf.expand_dims(mask, 2)
            if mode == 'mul':
                return inputs * mask
            if mode == 'add':
                return inputs - (1 - mask) * 1e12

上面函数中seq_len是指列表的实际长度,假设列表设定的长度是50,实际的长度是30,那么需要对后面20补充且需要进行mask处理。下面这张图是是attention的一个计算过程,这是普通的attention计算过程,self attention会沿用这个计算的逻辑。

self attention 使用的一小段经历

第一步  mask 是在softmax之前,所以再回头看下上面给出的mask函数,我们最后一个传参会用add的方式,这样会保证padding部分的数据都是一个负数极小值,这样经过softmax基本上就是为 0 了。是不是感受到了 mask 的灵魂了。

 

还有一步 mask 使用是在 query mask 上面,其实对于self attention上面的Q K V都是一样的,就是列表中的每一个元素都要计算与其他元素之间的联系。为什么原文要进行 query mask ?

在transform项目下看到这个讨论,作者给出的回答。链接在此

The encoder leaves artifacts, which are non-zeros, for paddings. 
It doesn't make sense, a query attends to those, so before applying softmax for getting score, 
I overwrote them with very very small numbers. As a results they will have score 0.
 Likewise, queries for paddings should not have any values, so they are masked with zeros.

 

正文完
请博主喝杯咖啡吧!
post-qrcode
 2
admin
版权声明:本站原创文章,由 admin 2020-04-04发表,共计1559字。
转载说明:除特殊说明外本站文章皆由CC-4.0协议发布,转载请注明出处。
评论(没有评论)
验证码