NLP/BERT Source Code Walkthrough 3

In the previous article, I went over the BERT-related models in transformers, the most central of which is BertModel. Today, I will walk through BertModel in detail.

BertModel

Constructor

Let's start with the constructor of BertModel.

def __init__(self, config):
    super().__init__(config)
    self.config = config

    self.embeddings = BertEmbeddings(config)
    self.encoder = BertEncoder(config)
    self.pooler = BertPooler(config)

    self.init_weights()

As you can see, BertModel consists of 3 modules: BertEmbeddings, BertEncoder, and BertPooler.

The forward() method

The forward() method proceeds as follows:

  • Preprocess the inputs: obtain input_shape, attention_mask, token_type_ids, and head_mask. If the model is a decoder, encoder_hidden_shape and encoder_extended_attention_mask are also needed (see the mask sketch after this list).
  • Call self.embeddings to get embedding_output, with shape (batch_size, sequence_length, hidden_size).
  • Feed embedding_output into self.encoder for encoding (or decoding), obtaining encoder_outputs, a tuple.
  • Take sequence_output = encoder_outputs[0]; its shape is (batch_size, sequence_length, hidden_size).
  • Feed sequence_output into self.pooler to obtain pooled_output, with shape (batch_size, hidden_size).
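About the attention_mask preprocessing mentioned above: as far as I know, get_extended_attention_mask reshapes the (batch_size, seq_length) mask of 1s and 0s into a broadcastable (batch_size, 1, 1, seq_length) tensor and turns masked positions into a large negative value that is later added to the raw attention scores. A minimal sketch of the idea (not the library's exact implementation):

import torch

# 1 = attend, 0 = padding (assumed example mask)
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])           # (batch_size, seq_length)
extended_mask = attention_mask[:, None, None, :].float()   # (batch_size, 1, 1, seq_length)
extended_mask = (1.0 - extended_mask) * -10000.0           # 0 where we attend, -10000 where masked
print(extended_mask)
# Added to the attention scores, the masked positions get a score of roughly -10000
# and effectively vanish after softmax.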

The code of the forward() method is as follows:

def forward(
    self,
    input_ids=None,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    output_attentions=None,
    output_hidden_states=None,
):
    # 1. Preprocess the inputs
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )

    # Get input_shape
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        input_shape = input_ids.size()
    elif inputs_embeds is not None:
        input_shape = inputs_embeds.size()[:-1]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    device = input_ids.device if input_ids is not None else inputs_embeds.device

    # Get attention_mask
    if attention_mask is None:
        attention_mask = torch.ones(input_shape, device=device)
    # Get token_type_ids
    if token_type_ids is None:
        token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

    # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
    # ourselves in which case we just need to make it broadcastable to all heads.
    # Broadcast attention_mask to obtain extended_attention_mask
    extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)

    # If a 2D or 3D attention mask is provided for the cross-attention
    # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
    # If this is a decoder, get encoder_hidden_shape and encoder_extended_attention_mask
    if self.config.is_decoder and encoder_hidden_states is not None:
        encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
        encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
        if encoder_attention_mask is None:
            encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
        encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
    else:
        encoder_extended_attention_mask = None

    # Prepare head mask if needed
    # 1.0 in head_mask indicate we keep the head
    # attention_probs has shape bsz x n_heads x N x N
    # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
    # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
    # Get head_mask
    head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

    # 2. Map each token to a vector through the embedding layer
    embedding_output = self.embeddings(
        input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
    )

    # 3. Feed the embeddings into the encoder for encoding (or decoding)
    encoder_outputs = self.encoder(
        embedding_output,
        attention_mask=extended_attention_mask,
        head_mask=head_mask,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_extended_attention_mask,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
    )
    # sequence_output: (batch_size, sequence_length, hidden_size)
    sequence_output = encoder_outputs[0]
    # 4. Feed sequence_output into the pooler to get pooled_output: (batch_size, hidden_size)
    pooled_output = self.pooler(sequence_output)

    outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]  # add hidden_states and attentions if they are here
    return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)

As you can see, the data flows through BertEmbeddings, BertEncoder, and BertPooler in turn. Let's now analyze these 3 classes.
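To make these shapes concrete, here is a small usage sketch. It assumes the transformers version this series is based on, where BertModel.forward() returns a plain tuple (newer versions wrap the outputs in a ModelOutput object), and uses bert-base-uncased purely as an example checkpoint:

import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

# encode() adds the [CLS] and [SEP] special tokens
input_ids = torch.tensor([tokenizer.encode("Hello, BERT!", add_special_tokens=True)])

with torch.no_grad():
    outputs = model(input_ids)

sequence_output, pooled_output = outputs[0], outputs[1]
print(sequence_output.shape)  # (batch_size, sequence_length, hidden_size)
print(pooled_output.shape)    # (batch_size, hidden_size)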

BertEmbeddings

Constructor

The constructor creates 3 embedding layers: word_embeddings, position_embeddings, and token_type_embeddings.

It then creates a BertLayerNorm layer and a Dropout layer.

The code is as follows:

def __init__(self, config):
    super().__init__()
    self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
    self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
    self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

    # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
    # any TensorFlow checkpoint file
    self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
    self.dropout = nn.Dropout(config.hidden_dropout_prob)

The forward() method

The forward() method proceeds as follows:

  • input_ids, position_ids, and token_type_ids each pass through their own embedding layer, producing inputs_embeds, position_embeddings, and token_type_embeddings.
  • The 3 embeddings are summed and then passed through the BertLayerNorm layer and the Dropout layer, giving an output with shape (batch_size, sequence_length, hidden_size).
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
    if input_ids is not None:
        input_shape = input_ids.size()
    else:
        input_shape = inputs_embeds.size()[:-1]

    seq_length = input_shape[1]
    device = input_ids.device if input_ids is not None else inputs_embeds.device
    if position_ids is None:
        position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
        position_ids = position_ids.unsqueeze(0).expand(input_shape)
    if token_type_ids is None:
        token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

    if inputs_embeds is None:
        inputs_embeds = self.word_embeddings(input_ids)
    position_embeddings = self.position_embeddings(position_ids)
    token_type_embeddings = self.token_type_embeddings(token_type_ids)

    embeddings = inputs_embeds + position_embeddings + token_type_embeddings
    embeddings = self.LayerNorm(embeddings)
    embeddings = self.dropout(embeddings)
    return embeddings
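To see the sum of the three embeddings in isolation, here is a minimal sketch that reimplements the idea with plain nn.Embedding layers (the sizes are assumed example values, roughly BERT-base; this is not the library class itself):

import torch
import torch.nn as nn

vocab_size, max_positions, type_vocab_size, hidden_size = 30522, 512, 2, 768

word_embeddings = nn.Embedding(vocab_size, hidden_size)
position_embeddings = nn.Embedding(max_positions, hidden_size)
token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size)
layer_norm = nn.LayerNorm(hidden_size, eps=1e-12)
dropout = nn.Dropout(0.1)

batch_size, seq_length = 2, 10
input_ids = torch.randint(0, vocab_size, (batch_size, seq_length))
position_ids = torch.arange(seq_length).unsqueeze(0).expand(batch_size, seq_length)
token_type_ids = torch.zeros(batch_size, seq_length, dtype=torch.long)

embeddings = (
    word_embeddings(input_ids)
    + position_embeddings(position_ids)
    + token_type_embeddings(token_type_ids)
)
embeddings = dropout(layer_norm(embeddings))
print(embeddings.shape)  # torch.Size([2, 10, 768])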

BertEncoder

BertEncoder is a stack of BertLayer modules; each BertLayer corresponds to one encoder layer, as shown in the figure below:


Stacking this structure multiple layers deep gives BertEncoder.

Constructor

The constructor is as follows:

def __init__(self, config):
    super().__init__()
    self.config = config
    self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

BertEncoder is simply a list of num_hidden_layers BertLayer modules.

The forward() method

The forward() method proceeds as follows:

  • Iterate over each BertLayer, obtaining layer_outputs and taking its first element, the hidden_states.
  • Feed hidden_states into the next BertLayer. For the first layer, hidden_states is the embedding_output produced by the embedding layer.
  • Return outputs, a tuple:
    • The 1st element is the last layer's hidden state, with shape (batch_size, sequence_length, hidden_size).
    • When output_hidden_states=True, the 2nd element is all hidden states.
    • When output_attentions=True, the 3rd element is all attentions.

The code is as follows:

def forward(
    self,
    hidden_states,
    attention_mask=None,
    head_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    output_attentions=False,
    output_hidden_states=False,
):
    all_hidden_states = ()
    all_attentions = ()
    for i, layer_module in enumerate(self.layer):
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if getattr(self.config, "gradient_checkpointing", False):

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs, output_attentions)

                return custom_forward

            layer_outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(layer_module),
                hidden_states,
                attention_mask,
                head_mask[i],
                encoder_hidden_states,
                encoder_attention_mask,
            )
        else:
            layer_outputs = layer_module(
                hidden_states,
                attention_mask,
                head_mask[i],
                encoder_hidden_states,
                encoder_attention_mask,
                output_attentions,
            )
        # Take this layer's hidden_states and feed it into the next layer
        hidden_states = layer_outputs[0]

        if output_attentions:
            all_attentions = all_attentions + (layer_outputs[1],)

    # Add last layer
    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)

    outputs = (hidden_states,)
    if output_hidden_states:
        outputs = outputs + (all_hidden_states,)
    if output_attentions:
        outputs = outputs + (all_attentions,)
    return outputs  # last-layer hidden state, (all hidden states), (all attentions)
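For orientation, the tuple structure follows directly from the loop above: all_hidden_states collects the input to every layer plus the final output, so it holds num_hidden_layers + 1 tensors, while all_attentions holds one tensor per layer. A small sketch to check this through BertModel, using a tiny randomly initialized config (assumed example sizes) and assuming the tuple-returning transformers version this series follows (newer versions return a ModelOutput, but integer indexing behaves the same way as far as I know):

import torch
from transformers import BertConfig, BertModel

config = BertConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=2,
    num_attention_heads=4, intermediate_size=64,
    output_hidden_states=True, output_attentions=True,
)
model = BertModel(config)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8))
with torch.no_grad():
    outputs = model(input_ids)

# outputs: (sequence_output, pooled_output, all_hidden_states, all_attentions)
print(len(outputs[2]))  # 3 = num_hidden_layers + 1
print(len(outputs[3]))  # 2 = num_hidden_layers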

BertPooler

BertPooler consists of a linear layer followed by a Tanh activation.

Constructor

def __init__(self, config):
    super().__init__()
    self.dense = nn.Linear(config.hidden_size, config.hidden_size)
    self.activation = nn.Tanh()

As the constructor shows, it is just a linear layer plus a Tanh activation.

The forward() method

The basic idea: take the hidden state corresponding to the first token ([CLS]) and pass it through the linear layer and the Tanh activation.

This gives pooled_output, with shape (batch_size, hidden_size).

The code is as follows:

def forward(self, hidden_states):
    # We "pool" the model by simply taking the hidden state corresponding
    # to the first token.
    # Take the hidden state of the first token
    # first_token_tensor: (batch_size, hidden_size)
    first_token_tensor = hidden_states[:, 0]
    pooled_output = self.dense(first_token_tensor)
    pooled_output = self.activation(pooled_output)
    return pooled_output
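The pooling step is small enough to reproduce by hand. A minimal sketch with dummy values (illustrative, not the library class itself):

import torch
import torch.nn as nn

hidden_size = 768
dense = nn.Linear(hidden_size, hidden_size)

# Pretend this came out of the encoder: (batch_size, sequence_length, hidden_size)
sequence_output = torch.randn(2, 10, hidden_size)

first_token_tensor = sequence_output[:, 0]             # hidden state of the first token ([CLS])
pooled_output = torch.tanh(dense(first_token_tensor))  # (batch_size, hidden_size)
print(pooled_output.shape)                             # torch.Size([2, 768])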

Recap

To recap: in BertModel, the data flows through BertEmbeddings, BertEncoder, and BertPooler to produce the final outputs, and BertEncoder is itself a stack of BertLayer modules. Next, let's analyze the BertLayer code.

BertLayer

A BertLayer corresponds to one encoder layer or one decoder layer, as shown in the figure below:


The decoder differs from the encoder in having an extra Encoder-Decoder Attention + Add & Normalize sublayer.

To keep you from getting lost in what follows, I have broken the upcoming classes down in more detail and drawn them into a single diagram to help you keep the overall structure in mind.


Constructor

In the constructor:

  • BertAttention corresponds to the Self-Attention + Add & Normalize sublayer.
  • If this is a decoder, another BertAttention is added, also called the Cross-Attention layer; it corresponds to the Encoder-Decoder Attention + Add & Normalize sublayer.
  • BertIntermediate and BertOutput together correspond to the Feed Forward (+ Add & Normalize) sublayer.
def __init__(self, config):
    super().__init__()
    # BertAttention corresponds to the Self-Attention + Add & Normalize sublayer in the figure above
    self.attention = BertAttention(config)
    self.is_decoder = config.is_decoder
    # If this is a decoder, add another BertAttention, also called the Cross-Attention layer
    if self.is_decoder:
        self.crossattention = BertAttention(config)
    # BertIntermediate (together with BertOutput below) corresponds to the Feed Forward sublayer
    self.intermediate = BertIntermediate(config)
    self.output = BertOutput(config)

The forward() method

The forward() method proceeds as follows:

  • The main input is hidden_states; for the first layer this is the embedding_output from the embedding layer, with shape (batch_size, sequence_length, hidden_size).
  • Feed hidden_states into the BertAttention layer, producing self_attention_outputs, a tuple containing the hidden_state and the attention weights.
  • Take attention_output = self_attention_outputs[0], i.e. the hidden_state, with shape (batch_size, sequence_length, hidden_size).
  • If the layer acts as a decoder, pass it through the cross-attention BertAttention as well and take attention_output = cross_attention_outputs[0].
  • Then pass it through the BertIntermediate layer, producing intermediate_output with shape (batch_size, sequence_length, intermediate_size).
  • Finally pass it through the BertOutput layer, producing layer_output with shape (batch_size, sequence_length, hidden_size).

The code is as follows:

def forward(
    self,
    hidden_states,
    attention_mask=None,
    head_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    output_attentions=False,
):
    # Feed hidden_states (embedding_output for the first layer) into the attention layer,
    # producing self_attention_outputs, a tuple
    self_attention_outputs = self.attention(
        hidden_states, attention_mask, head_mask, output_attentions=output_attentions,
    )
    # Take attention_output, i.e. the hidden_state: (batch_size, sequence_length, hidden_size)
    attention_output = self_attention_outputs[0]
    # Take the attention weights
    outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
    # If this is a decoder, apply another BertAttention (cross-attention)
    if self.is_decoder and encoder_hidden_states is not None:
        cross_attention_outputs = self.crossattention(
            attention_output,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            output_attentions,
        )
        # Take attention_output, i.e. the hidden_state: (batch_size, sequence_length, hidden_size)
        attention_output = cross_attention_outputs[0]
        # Take the attention weights
        outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights
    # Pass through the intermediate layer, intermediate_output: (batch_size, sequence_length, intermediate_size)
    intermediate_output = self.intermediate(attention_output)
    # Pass through the output layer, layer_output: (batch_size, sequence_length, hidden_size)
    layer_output = self.output(intermediate_output, attention_output)
    outputs = (layer_output,) + outputs
    return outputs

BertAttention

The BertAttention layer contains BertSelfAttention and BertSelfOutput.

Here is where BertAttention sits:


Constructor

The constructor is as follows:

def __init__(self, config):
    super().__init__()
    self.self = BertSelfAttention(config)
    self.output = BertSelfOutput(config)
    self.pruned_heads = set()

It defines a BertSelfAttention layer and a BertSelfOutput layer.

The forward() method

The forward() method proceeds as follows:

  • The input hidden_states has shape (batch_size, sequence_length, hidden_size).
  • It passes through the BertSelfAttention layer, producing self_outputs, a tuple whose first element self_outputs[0] is the hidden_state, with shape (batch_size, sequence_length, hidden_size); this element is then combined with the original hidden_states in the BertSelfOutput layer (the residual connection).

The code is as follows:

def forward(
    self,
    hidden_states,
    attention_mask=None,
    head_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    output_attentions=False,
):
    # Pass through the BertSelfAttention layer
    self_outputs = self.self(
        hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, output_attentions,
    )
    # Pass through the BertSelfOutput layer
    attention_output = self.output(self_outputs[0], hidden_states)
    outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
    return outputs

BertIntermediate

The BertIntermediate layer corresponds to the first half of the Feed Forward sublayer (the expansion from hidden_size to intermediate_size).

Here is where BertIntermediate sits:


Constructor

The constructor first defines a linear layer and then the activation function.

The code is as follows:

def __init__(self, config):
    super().__init__()
    # First define a linear layer
    self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
    # Then define the activation function
    if isinstance(config.hidden_act, str):
        self.intermediate_act_fn = ACT2FN[config.hidden_act]
    else:
        self.intermediate_act_fn = config.hidden_act
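A brief note on ACT2FN, which this article does not show: as far as I recall, it is simply a dictionary in the library that maps an activation name such as "gelu" to the corresponding function, which is why config.hidden_act may be either a string or a callable. A minimal sketch of the idea (illustrative, not the library's actual table):

import torch
import torch.nn.functional as F

# Illustrative ACT2FN-style lookup table (hypothetical, for demonstration only)
ACT2FN_SKETCH = {"gelu": F.gelu, "relu": F.relu, "tanh": torch.tanh}

hidden_act = "gelu"  # what config.hidden_act typically holds
intermediate_act_fn = ACT2FN_SKETCH[hidden_act] if isinstance(hidden_act, str) else hidden_act
print(intermediate_act_fn(torch.tensor([-1.0, 0.0, 1.0])))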

The forward() method

The input hidden_states passes through the linear layer and then the activation function.

The code is as follows:

def forward(self, hidden_states):
    hidden_states = self.dense(hidden_states)
    hidden_states = self.intermediate_act_fn(hidden_states)
    return hidden_states

BertOutput

The BertOutput layer corresponds to the second half of the Feed Forward sublayer (the projection back to hidden_size) together with the Add & Normalize step.

Here is where BertOutput sits:


Constructor

It defines a linear layer, a BertLayerNorm layer, and a Dropout layer.

The code is as follows:

def __init__(self, config):
    super().__init__()
    self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
    self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
    self.dropout = nn.Dropout(config.hidden_dropout_prob)

The forward() method

The input hidden_states is the output of BertIntermediate, with shape (batch_size, sequence_length, intermediate_size), while input_tensor is the attention output that entered the Feed Forward sublayer (the residual branch), with shape (batch_size, sequence_length, hidden_size).

hidden_states passes through the linear layer and the Dropout layer, is added to input_tensor, and the sum is fed into the BertLayerNorm layer to produce the final output.

The code is as follows:

def forward(self, hidden_states, input_tensor):
    # Pass through the linear layer
    hidden_states = self.dense(hidden_states)
    # Pass through the Dropout layer
    hidden_states = self.dropout(hidden_states)
    # Add the residual and pass through the BertLayerNorm layer
    hidden_states = self.LayerNorm(hidden_states + input_tensor)
    return hidden_states
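Putting BertIntermediate and BertOutput together, the whole Feed Forward sublayer therefore computes (with the activation typically being GELU):

\[
\mathrm{FFN}(x) = \mathrm{LayerNorm}\big(x + \mathrm{Dropout}(W_2\,\mathrm{gelu}(W_1 x + b_1) + b_2)\big)
\]

where \(W_1\) expands from hidden_size to intermediate_size and \(W_2\) projects back to hidden_size.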

BertSelfAttention

This is the most computationally heavy and most important class: it implements the parallel, multi-head attention computation.

Here is where BertSelfAttention sits:


Constructor

In the constructor, the number of attention heads is first used to compute the actual output width of the layer.

Here config.hidden_size is the desired output width, config.num_attention_heads is the number of attention heads, and self.all_head_size is the actual output width.

In other words, the desired output width and the actual output width are not necessarily the same. There are two cases.

  • Case 1: config.hidden_size is divisible by config.num_attention_heads. Suppose config.hidden_size=768 and config.num_attention_heads=8; then each head gets \(int(768 \div 8) = 96\) dimensions, and the actual output width is \(96 \times 8 = 768\). In this case the actual output width equals the desired one.
  • Case 2: config.hidden_size is not divisible by config.num_attention_heads. Suppose config.hidden_size=768 and config.num_attention_heads=10; then each head gets \(int(768 \div 10) = 76\) dimensions, and the actual output width would be \(76 \times 10 = 760\), which differs from the desired one. Note, however, that the constructor below raises a ValueError in this case (unless the config defines an embedding_size attribute), so in practice it does not occur.

Finally, the constructor also defines the weight matrices for query, key, and value.

The code is as follows:

def __init__(self, config):
    super().__init__()
    if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
        raise ValueError(
            "The hidden size (%d) is not a multiple of the number of attention "
            "heads (%d)" % (config.hidden_size, config.num_attention_heads)
        )

    self.num_attention_heads = config.num_attention_heads
    # Divide hidden_size by num_attention_heads to get the size of each head
    self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
    # Then compute the actual total output size, since the division above may not be exact
    self.all_head_size = self.num_attention_heads * self.attention_head_size

    # Define the weight matrices for query, key, and value
    self.query = nn.Linear(config.hidden_size, self.all_head_size)
    self.key = nn.Linear(config.hidden_size, self.all_head_size)
    self.value = nn.Linear(config.hidden_size, self.all_head_size)

    self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

The forward() method

The forward() method proceeds as follows:

  • First compute the query matrix.
  • Then check whether this module is doing cross-attention (decoder) or self-attention:
    • If encoder_hidden_states is given (cross-attention in a decoder), compute key and value from the encoder's hidden states.
    • Otherwise (self-attention in an encoder), compute key and value from the previous layer's hidden states.
  • At this point the query, key, and value matrices all have shape (batch_size, sequence_length, all_head_size).
  • Reshape query, key, and value according to the number of heads, from (batch_size, sequence_length, all_head_size) to (batch_size, num_attention_heads, sequence_length, attention_head_size).
  • Multiply query by key (transposed) and scale; after adding the mask and applying softmax, this gives attention_probs, with shape (batch_size, num_attention_heads, sequence_length, sequence_length) (see the formula after this list).
  • Multiply attention_probs by the value matrix to get the per-head attention output context_layer, with shape (batch_size, num_attention_heads, sequence_length, attention_head_size).
  • Finally reshape, which amounts to concatenating the heads: context_layer ends up with shape (batch_size, sequence_length, all_head_size).
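In equation form, each head computes the standard scaled dot-product attention (dropout omitted), where \(d_k\) is attention_head_size and \(M\) is the additive attention mask prepared in BertModel.forward():

\[
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V
\]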

The code is as follows:

def forward(
    self,
    hidden_states,
    attention_mask=None,
    head_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    output_attentions=False,
):
    # Compute the query matrix
    mixed_query_layer = self.query(hidden_states)

    # If this is instantiated as a cross-attention module, the keys
    # and values come from an encoder; the attention mask needs to be
    # such that the encoder's padding tokens are not attended to.
    if encoder_hidden_states is not None:
        # Cross-attention in a decoder: compute key and value from the encoder's hidden states
        mixed_key_layer = self.key(encoder_hidden_states)
        mixed_value_layer = self.value(encoder_hidden_states)
        attention_mask = encoder_attention_mask
    else:
        # Self-attention: compute key and value from the previous layer's hidden states
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

    # Reshape from (batch_size, sequence_length, all_head_size)
    # to (batch_size, num_attention_heads, sequence_length, attention_head_size)
    query_layer = self.transpose_for_scores(mixed_query_layer)
    key_layer = self.transpose_for_scores(mixed_key_layer)
    value_layer = self.transpose_for_scores(mixed_value_layer)

    # Take the dot product between "query" and "key" to get the raw attention scores.
    # attention_scores: (batch_size, num_attention_heads, sequence_length, sequence_length)
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    # Scale
    attention_scores = attention_scores / math.sqrt(self.attention_head_size)
    if attention_mask is not None:
        # Apply the attention mask (precomputed for all layers in BertModel forward() function)
        attention_scores = attention_scores + attention_mask

    # Normalize the attention scores to probabilities.
    # Apply softmax
    attention_probs = nn.Softmax(dim=-1)(attention_scores)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    # Apply dropout
    # attention_probs: (batch_size, num_attention_heads, sequence_length, sequence_length)
    attention_probs = self.dropout(attention_probs)

    # Mask heads if we want to
    if head_mask is not None:
        attention_probs = attention_probs * head_mask

    # context_layer: (batch_size, num_attention_heads, sequence_length, attention_head_size)
    context_layer = torch.matmul(attention_probs, value_layer)

    # context_layer: (batch_size, sequence_length, num_attention_heads, attention_head_size)
    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
    new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
    # This amounts to concatenating the outputs of all heads;
    # reshape back, context_layer: (batch_size, sequence_length, self.all_head_size)
    context_layer = context_layer.view(*new_context_layer_shape)

    outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
    return outputs
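The code above calls self.transpose_for_scores, which this article does not show. For reference, here is a sketch of what it does, written from my memory of the library source (treat the details as an approximation):

def transpose_for_scores(self, x):
    # x: (batch_size, seq_length, all_head_size)
    new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
    x = x.view(*new_x_shape)      # (batch_size, seq_length, num_heads, head_size)
    return x.permute(0, 2, 1, 3)  # (batch_size, num_heads, seq_length, head_size)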

BertSelfOutput

BertSelfOutput plays the same role as BertOutput: both implement the Add & Normalize step.

The two classes are nearly identical in code; the only real difference is that the dense layer here maps hidden_size to hidden_size, whereas in BertOutput it maps intermediate_size to hidden_size, which is why the library keeps them as two separate classes.
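For comparison, here is BertSelfOutput as I recall it from the library source of this vintage (treat it as a sketch rather than a verbatim copy):

class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states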

Summary

This article covered a lot of classes. I have captured the hierarchy of all of them in the diagram below so you can review it at a glance.


In the next article, I will move on to practice: using BertForSequenceClassification to train a sentiment classification model on my own dataset.


If you found this article helpful, consider giving it a like; it gives me more motivation to keep writing good articles.

My articles are published first on my WeChat official account, 张贤同学; you are welcome to scan the code and follow it.

