PyTorch/[PyTorch 学习笔记] 7.1 模型保存与加载
本章代码:
- https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/model_save.py
- https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/model_load.py
- https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/checkpoint_resume.py
- https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/save_checkpoint.py
这篇文章主要介绍了序列化与反序列化,以及 PyTorch 中的模型保存于加载的两种方式,模型的断点续训练。
序列化与反序列化
模型在内存中是以对象的逻辑结构保存的,但是在硬盘中是以二进制流的方式保存的。
序列化是指将内存中的数据以二进制序列的方式保存到硬盘中。PyTorch 的模型保存就是序列化。
反序列化是指将硬盘中的二进制序列加载到内存中,得到模型的对象。PyTorch 的模型加载就是反序列化。
PyTorch 中的模型保存与加载
torch.save
1 | torch.save(obj, f, pickle_module, pickle_protocol=2, _use_new_zipfile_serialization=False) |
主要参数:
- obj:保存的对象,可以是模型。也可以是 dict。因为一般在保存模型时,不仅要保存模型,还需要保存优化器、此时对应的 epoch 等参数。这时就可以用 dict 包装起来。
- f:输出路径
其中模型保存还有两种方式:
保存整个 Module
这种方法比较耗时,保存的文件大
1 | torch.savev(net, path) |
只保存模型的参数
推荐这种方法,运行比较快,保存的文件比较小
1 | state_sict = net.state_dict() |
下面是保存 LeNet 的例子。在网络初始化中,把权值都设置为 2020,然后保存模型。
1 | import torch |
运行完之后,文件夹中生成了`model.pkl
和model_state_dict.pkl
,分别保存了整个网络和网络的参数
torch.load
1 | torch.load(f, map_location=None, pickle_module, **pickle_load_args) |
主要参数:
- f:文件路径
- map_location:指定存在 CPU 或者 GPU。
加载模型也有两种方式
加载整个 Module
如果保存的时候,保存的是整个模型,那么加载时就加载整个模型。这种方法不需要事先创建一个模型对象,也不用知道模型的结构,代码如下:
1 | path_model = "./model.pkl" |
输出如下:
1 | LeNet2( |
只加载模型的参数
如果保存的时候,保存的是模型的参数,那么加载时就参数。这种方法需要事先创建一个模型对象,再使用模型的load_state_dict()
方法把参数加载到模型中,代码如下:
1 | path_state_dict = "./model_state_dict.pkl" |
模型的断点续训练
在训练过程中,可能由于某种意外原因如断点等导致训练终止,这时需要重新开始训练。断点续练是在训练过程中每隔一定次数的 epoch 就保存模型的参数和优化器的参数,这样如果意外终止训练了,下次就可以重新加载最新的模型参数和优化器的参数,在这个基础上继续训练。
下面的代码中,每隔 5 个 epoch 就保存一次,保存的是一个 dict,包括模型参数、优化器的参数、epoch。然后在 epoch 大于 5 时,就break
模拟训练意外终止。关键代码如下:
1 | if (epoch+1) % checkpoint_interval == 0: |
在 epoch 大于 5 时,就break
模拟训练意外终止
1 | if epoch > 5: |
断点续训练的恢复代码如下:
1 | path_checkpoint = "./checkpoint_4_epoch.pkl" |
需要注意的是,还要设置scheduler.last_epoch
参数为保存的 epoch。模型训练的起始 epoch 也要修改为保存的 epoch。
参考资料
如果你觉得这篇文章对你有帮助,不妨点个赞,让我有更多动力写出好文章。
我的文章会首发在公众号上,欢迎扫码关注我的公众号张贤同学。