在实践中,当给定相同的查询、键和值的集合时, 希望模型可以基于相同的注意力机制学习到不同的行为, 然后将不同的行为作为知识组合起来, 捕获序列内各种范围的依赖关系 (例如短距离依赖和长距离依赖关系)。 因此允许注意力机制组合使用查询、键和值的不同子空间表示(representation subspaces)可能是有益的。
为此与其只使用单独一个注意力汇聚, 我们可以用独立学习得到的 ℎ 组不同的线性投影(linear projections)来变换查询、键和值。 然后,这 ℎ 组变换后的查询、键和值将并行地送到注意力汇聚中。 最后,将这 ℎ 个注意力汇聚的输出拼接在一起, 并且通过另一个可以学习的线性投影进行变换, 以产生最终输出,这种设计被称为多头注意力(multihead attention)。 对于 ℎ 个注意力汇聚输出,每一个注意力汇聚都被称作一个头(head)。 下图展示了使用全连接层来实现可学习的线性变换的多头注意力。
在实现多头注意力之前,用数学语言将这个模型形式化地描述出来。给定查询 q ∈ R d q mathbf{q} in mathbb{R}^{d_q} q∈Rdq、键 k ∈ R d k mathbf{k} in mathbb{R}^{d_k} k∈Rdk和值 v ∈ R d v mathbf{v} in mathbb{R}^{d_v} v∈Rdv,每个注意力头 h i mathbf{h}_i hi( i = 1 , … , h i = 1, ldots, h i=1,…,h)的计算方法为:
h i = f ( W i ( q ) q , W i ( k ) k , W i ( v ) v ) ∈ R p v , mathbf{h}_i = f(mathbf W_i^{(q)}mathbf q, mathbf W_i^{(k)}mathbf k,mathbf W_i^{(v)}mathbf v) in mathbb R^{p_v}, hi=f(Wi(q)q,Wi(k)k,Wi(v)v)∈Rpv,
其中,可学习的参数包括 W i ( q ) ∈ R p q × d q mathbf W_i^{(q)}inmathbb R^{p_qtimes d_q} Wi(q)∈Rpq×dq、 W i ( k ) ∈ R p k × d k mathbf W_i^{(k)}inmathbb R^{p_ktimes d_k} Wi(k)∈Rpk×dk和 W i ( v ) ∈ R p v × d v mathbf W_i^{(v)}inmathbb R^{p_vtimes d_v} Wi(v)∈Rpv×dv,以及代表注意力汇聚的函数 f f f。 f f f可以是加性注意力和缩放点积注意力。多头注意力的输出需要经过另一个线性转换,它对应着 h h h个头连结后的结果,因此其可学习参数是 W o ∈ R p o × h p v mathbf W_oinmathbb R^{p_otimes h p_v} Wo∈Rpo×hpv:
W o [ h 1 ⋮ h h ] ∈ R p o . mathbf W_o [h1⋮hh]
in mathbb{R}^{p_o}. Wo⎣ ⎡h1⋮hh⎦ ⎤∈Rpo.
基于这种设计,每个头都可能会关注输入的不同部分,可以表示比简单加权平均值更复杂的函数。
在实现过程中,选择缩放点积注意力作为每一个注意力头。为了避免计算代价和参数代价的大幅增长,设定 p q = p k = p v = p o / h p_q = p_k = p_v = p_o / h pq=pk=pv=po/h。值得注意的是,如果将查询、键和值的线性变换的输出数量设置为 p q h = p k h = p v h = p o p_q h = p_k h = p_v h = p_o pqh=pkh=pvh=po,则可以并行计算 h h h个头。在下面的实现中, p o p_o po是通过参数num_hiddens指定的。
class MultiHeadAttention(nn.Module): """多头注意力""" def __init__(self,query_size,key_size,value_size,num_hiddens,num_heads,dropout,bias=False): super(MultiHeadAttention,self).__init__() self.num_heads = num_heads self.attention = d2l.torch.DotProductAttention(dropout) self.W_q = nn.Linear(query_size,num_hiddens,bias=bias) self.W_k = nn.Linear(key_size,num_hiddens,bias=bias) self.W_v = nn.Linear(value_size,num_hiddens,bias=bias) self.W_o = nn.Linear(num_hiddens,num_hiddens,bias=bias) def forward(self,queries,keys,values,valid_lens): # queries,keys,values的形状: # (batch_size,查询或者“键-值”对的个数,num_hiddens) # valid_lens 的形状: # (batch_size,)或(batch_size,查询的个数) # 经过变换后,输出的queries,keys,values 的形状: # (batch_size*num_heads,查询或者“键-值”对的个数, # num_hiddens/num_heads) queries = transpose_qkv(self.W_q(queries),self.num_heads) keys = transpose_qkv(self.W_k(keys),self.num_heads) values = transpose_qkv(self.W_v(values),self.num_heads) if valid_lens is not None: # 在轴0,将第一项(标量或者矢量)复制num_heads次, # 然后如此复制第二项,然后诸如此类。 valid_lens = torch.repeat_interleave(valid_lens,repeats=self.num_heads,dim=0) # output的形状:(batch_size*num_heads,查询的个数, # num_hiddens/num_heads) output = self.attention(queries,keys,values,valid_lens) # output_concat的形状:(batch_size,查询的个数,num_hiddens) output_concat = transpose_output(output,self.num_heads) return self.W_o(output_concat)
12345678910111213141516171819202122232425262728293031为了能够使多个头并行计算, 上面的MultiHeadAttention类将使用下面定义的两个转置函数。 具体来说transpose_output函数反转了transpose_qkv函数的操作。
def transpose_qkv(X,num_heads): """为了多注意力头的并行计算而变换形状""" # 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens) # 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads, # num_hiddens/num_heads) X = X.reshape(X.shape[0],X.shape[1],num_heads,-1) # 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数, # num_hiddens/num_heads) X = X.permute(0,2,1,3) # 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数, # num_hiddens/num_heads) return X.reshape(-1,X.shape[2],X.shape[3]) def transpose_output(X,num_heads): """逆转transpose_qkv函数的操作""" X = X.reshape(-1,num_heads,X.shape[1],X.shape[2]) X = X.permute(0,2,1,3) return X.reshape(X.shape[0],X.shape[1],-1)
1234567891011121314151617181920下面使用键和值相同的例子来测试编写的MultiHeadAttention类。 多头注意力输出的形状是(batch_size,num_queries,num_hiddens)。
num_hiddens,num_heads = 100,5 multiHeadAttention = MultiHeadAttention(num_hiddens,num_hiddens,num_hiddens,num_hiddens,5,0.5) multiHeadAttention.eval() 123
输出结果如下: MultiHeadAttention( (attention): DotProductAttention( (dropout): Dropout(p=0.5, inplace=False) ) (W_q): Linear(in_features=100, out_features=100, bias=False) (W_k): Linear(in_features=100, out_features=100, bias=False) (W_v): Linear(in_features=100, out_features=100, bias=False) (W_o): Linear(in_features=100, out_features=100, bias=False) ) 12345678910
batch_size,num_queries = 2,4 num_kvpairs,valid_lens = 6,torch.tensor([3,2]) Y = torch.ones(size=(batch_size,num_kvpairs,num_hiddens)) X = torch.ones(size=(batch_size,num_queries,num_hiddens)) multiHeadAttention(X,Y,Y,valid_lens).shape 12345
输出结果如下: torch.Size([2, 4, 100]) 12
import torch import d2l.torch from torch import nn def transpose_qkv(X, num_heads): """为了多注意力头的并行计算而变换形状""" # 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens) # 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads, # num_hiddens/num_heads) X = X.reshape(X.shape[0], X.shape[1], num_heads, -1) # 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数, # num_hiddens/num_heads) X = X.permute(0, 2, 1, 3) # 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数, # num_hiddens/num_heads) return X.reshape(-1, X.shape[2], X.shape[3]) def transpose_output(X, num_heads): """逆转transpose_qkv函数的操作""" X = X.reshape(-1, num_heads, X.shape[1], X.shape[2]) X = X.permute(0, 2, 1, 3) return X.reshape(X.shape[0], X.shape[1], -1) class MultiHeadAttention(nn.Module): """多头注意力""" def __init__(self, query_size, key_size, value_size, num_hiddens, num_heads, dropout, bias=False): super(MultiHeadAttention, self).__init__() self.num_heads = num_heads self.attention = d2l.torch.DotProductAttention(dropout) self.W_q = nn.Linear(query_size, num_hiddens, bias=bias) self.W_k = nn.Linear(key_size, num_hiddens, bias=bias) self.W_v = nn.Linear(value_size, num_hiddens, bias=bias) self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias) def forward(self, queries, keys, values, valid_lens): # queries,keys,values的形状: # (batch_size,查询或者“键-值”对的个数,num_hiddens) # valid_lens 的形状: # (batch_size,)或(batch_size,查询的个数) # 经过变换后,输出的queries,keys,values 的形状: # (batch_size*num_heads,查询或者“键-值”对的个数, # num_hiddens/num_heads) queries = transpose_qkv(self.W_q(queries), self.num_heads) keys = transpose_qkv(self.W_k(keys), self.num_heads) values = transpose_qkv(self.W_v(values), self.num_heads) if valid_lens is not None: # 在轴0,将第一项(标量或者矢量)复制num_heads次, # 然后如此复制第二项,然后诸如此类。 valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0) # output的形状:(batch_size*num_heads,查询的个数, # num_hiddens/num_heads) output = self.attention(queries, keys, values, valid_lens) # output_concat的形状:(batch_size,查询的个数,num_hiddens) output_concat = transpose_output(output, self.num_heads) return self.W_o(output_concat) num_hiddens, num_heads = 100, 5 multiHeadAttention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, 5, 0.5) multiHeadAttention.eval() batch_size, num_queries = 2, 4 num_kvpairs, valid_lens = 6, torch.tensor([3, 2]) Y = torch.ones(size=(batch_size, num_kvpairs, num_hiddens)) X = torch.ones(size=(batch_size, num_queries, num_hiddens)) multiHeadAttention(X, Y, Y, valid_lens).shape
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071注意力机制第一篇:李沐动手学深度学习V2-注意力机制
注意力机制第二篇:李沐动手学深度学习V2-注意力评分函数
注意力机制第三篇:李沐动手学深度学习V2-基于注意力机制的seq2seq
注意力机制第四篇:李沐动手学深度学习V2-自注意力机制之位置编码
注意力机制第五篇:李沐动手学深度学习V2-自注意力机制
注意力机制第六篇:李沐动手学深度学习V2-多头注意力机制和代码实现
注意力机制第七篇:李沐动手学深度学习V2-transformer和代码实现
相关知识
动手学深度学习笔记(一)
深度学习(花书)+ 动手学深度学习(李沐)资料链接整理
基于深度学习的病虫害智能化识别系统
沐尔囝熙与广州杨森药业达成品牌深度战略合作
深度学习及其应用
深度学习 花卉识别
深度学习应用开发
深度学习下的小样本玉米叶片病害识别研究
∫V2(x
基于深度学习技术的农作物病虫害检测识别系统的研究
网址: 李沐动手学深度学习V2 https://m.huajiangbk.com/newsview1841126.html
上一篇: Lambda表达式以及变量捕获( |
下一篇: 调控黄瓜花器官发育基因作用机制获 |