1. Transformer
The Transformer consists of two parts, an encoding block and a decoding block: the encoding block is a stack of encoders, and the decoding block is likewise a stack of decoders.
Encoder: self-attention + feed-forward (fully connected) layer
Decoder: self-attention + encoder-decoder attention + feed-forward layer
In the encoder-decoder attention, K and V come from the output of the last encoder (see the sketch below).
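A minimal sketch of a single encoder layer, assuming the standard post-norm layout (self-attention, then feed-forward, each with a residual connection and LayerNorm). The class name EncoderLayer, the use of nn.MultiheadAttention, and the ffn_size parameter are illustrative assumptions, not part of the original notes.

# Sketch of one Transformer encoder layer (illustrative, assumptions noted above)
import torch
import torch.nn as nn

class EncoderLayer(nn.Module):
    def __init__(self, hidden_size, head_num, ffn_size):
        super().__init__()
        # Multi-head self-attention sublayer
        self.self_attn = nn.MultiheadAttention(hidden_size, head_num, batch_first=True)
        # Position-wise feed-forward sublayer
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, ffn_size),
            nn.ReLU(),
            nn.Linear(ffn_size, hidden_size),
        )
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)

    def forward(self, x):
        # Self-attention with residual connection and LayerNorm
        attn_out, _ = self.self_attn(x, x, x)
        x = self.norm1(x + attn_out)
        # Feed-forward with residual connection and LayerNorm
        x = self.norm2(x + self.ffn(x))
        return x

In a decoder layer, an additional encoder-decoder attention sublayer sits between the self-attention and the feed-forward network, with K and V taken from the last encoder's output as noted above.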
2. BERT
import math
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_size, head_num):
        super().__init__()
        assert hidden_size % head_num == 0
        self.hidden_size = hidden_size
        self.head_num = head_num
        self.head_size = hidden_size // head_num  # dimension of each attention head
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)

    def transpose_dim(self, x):
        # (B, N, hidden_size) -> (B, head_num, N, head_size)
        new_shape = x.size()[:-1] + (self.head_num, self.head_size)
        x = x.view(*new_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, x, attention_mask=None):
        query_layer = self.query(x)
        key_layer = self.key(x)
        value_layer = self.value(x)
        # Equivalent reshape without transpose_dim:
        #   B, N = query_layer.shape[0], query_layer.shape[1]
        #   multi_query = query_layer.view(B, N, self.head_num, self.head_size).transpose(1, 2)
        multi_query = self.transpose_dim(query_layer)
        multi_key = self.transpose_dim(key_layer)
        multi_value = self.transpose_dim(value_layer)
        # Scaled dot-product attention scores: (B, head_num, N, N)
        attention_scores = torch.matmul(multi_query, multi_key.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.head_size)
        if attention_mask is not None:
            # Mask positions carry large negative values so they vanish after softmax
            attention_scores = attention_scores + attention_mask
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        context_layer = torch.matmul(attention_probs, multi_value)
        # (B, head_num, N, head_size) -> (B, N, hidden_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_shape = context_layer.size()[:-2] + (self.hidden_size,)
        context_layer = context_layer.view(*new_shape)
        return context_layer
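A quick shape check of the module above; the batch size, sequence length, hidden size, and head count are arbitrary values chosen for illustration (BERT-base uses hidden_size=768 and 12 heads).

# Illustrative usage: 2 sequences of 8 tokens, hidden size 768, 12 heads
mha = MultiHeadAttention(hidden_size=768, head_num=12)
x = torch.randn(2, 8, 768)
out = mha(x, attention_mask=None)
print(out.shape)  # torch.Size([2, 8, 768]) -- same shape as the input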