数据竞赛/阿里天池 NLP 入门赛 Bert 方案 -3 Bert 预训练与分类
前言
这篇文章用于记录阿里天池 NLP 入门赛,详细讲解了整个数据处理流程,以及如何从零构建一个模型,适合新手入门。
赛题以新闻数据为赛题数据,数据集报名后可见并可下载。赛题数据为新闻文本,并按照字符级别进行匿名处理。整合划分出14个候选分类类别:财经、彩票、房产、股票、家居、教育、科技、社会、时尚、时政、体育、星座、游戏、娱乐的文本数据。实质上是一个 14 分类问题。
赛题数据由以下几个部分构成:训练集20w条样本,测试集A包括5w条样本,测试集B包括5w条样本。
比赛地址:https://tianchi.aliyun.com/competition/entrance/531810/introduction
数据可以通过上面的链接下载。
代码地址:https://github.com/zhangxiann/Tianchi-NLP-Beginner
分为 3 篇文章介绍:
在上一篇文章中,我们介绍了 Bert 的源码。
这篇文章,我们来看下如何预训练 Bert,以及使用 Bert 进行分类。
训练 Bert
在前面,我们已经了解完了 Bert 的源码,现在我们我来看如何训练 Bert。
训练 Bert 对应的代码文件是 run_pretraining.py
。
脚本
训练脚本为:run_pretraining.sh
,内容如下:
1 | python run_pretraining.py |
训练过程主要用了estimator调度器。这个调度器支持自定义训练过程,将训练集传入之后自动训练。
对应的代码文件是 run_pretraining.py
。
主要函数是 model_fn_builder()
,get_masked_lm_output()
,get_next_sentence_output()
。
model_fn_builder()
在这个函数里创建 Bert 模型,得到输出,然后分别调用 get_masked_lm_output()
计算预测 mask 词的损失~~,调用 get_next_sentence_output()
计算预测前后句子的 loss*~~(这里不预测句子前后关系,因此不计算 loss)。
1 | def model_fn_builder(bert_config, init_checkpoint, learning_rate, |
get_masked_lm_output()
get_masked_lm_output()
的作用是计算 mask 预测的 loss。
输入参数:
- input_tensor:
BertModel
最后一层的输出,形状是[batch_size, seq_length, hidden_size]
。 - output_weights:形状是
[vocab_size, hidden_size]
。 - positions:表示 mask 的位置,形状是
[vocab_size, hidden_size]
。 - label_ids:表示 mask 对应的真实 token。
- label_weights:每个 mask 的权重。
流程如下:
- 从
input_tensor
中,根据positions
取出 mask 对应的输出。 - 将
input_tensor
经过一个全连接层和layer_norm
层,得到logits
,形状为[batch_size * max_predictions_per_seq, vocab_size]
。 - 将
logits
和output_weights
相乘,得到概率矩阵log_probs
,形状为[batch_size * max_predictions_per_seq, vocab_size]
,再经过 softmax。 - 将
log_probs
和真实标签one_hot_labels
计算加权 loss。
1 | # input_tensor: [batch_size, seq_length, hidden_size] |
训练完成后,会把训练好的模型保存到 output_dit
中。
转换为 PyTorch 模型
由于我们是使用 Tensorflow 来训练模型,而我们的文本分类模型是使用 PyTorch 的,因此需要把 Tensorflow 的模型,转换为 PyTorch 的模型。
这里使用 HuggingFace 提供的 转换代码。
代码文件为 convert_checkpoint.py
,脚本文件为 convert_checkpoint.sh
,脚本如下:
1 | export BERT_BASE_DIR=./bert-mini # 设置模型路径 |
注意,你需要先安装 tensorflow,pytorch,transformers。
微调 Bert 模型
在上一篇文章 阿里天池 NLP 入门赛 TextCNN 方案代码详细注释和流程讲解 中,我们使用 TextCNN 来训练模型,模型结构图如下:
图中的 WordCNNEncoder
就是TextCNN。
我们把 TextCNN 替换为 Bert。
模型结构图如下:
我们只关注如何使用 WordBertEncoder
,模型其他部分的细节与上一篇文章一样,请查看 阿里天池 NLP 入门赛 TextCNN 方案代码详细注释和流程讲解。
WordBertEncoder
代码如下。
首先加载转换好的 PyTorch 模型。
在 forward()
函数中,将 input_ids
和 token_type_ids
输入到 Bert 模型。
得到 sequence_output
(表示最后一个 Encoder 对应的 hidden-states),pooled_output
(表示最后一个 Encoder 的第一个 token 对应的 hidden-states)。
代码中有详细注释。
1 | # build word encoder |
如果你有疑问,欢迎留言。
参考
如果你觉得这篇文章对你有帮助,不妨点个赞,让我有更多动力写出好文章。