TensorFlow/TensorFlow 保存 pb 模型
我们在保存 TensorFlow 的模型时,一般都是把模型保存为 ckpt 文件。
但是如果把模型迁移到手机上,那么就要把模型保存为 pb 文件。模型保存为 pb 文件时候,模型的变量都会变成固定的,从而会大大减小模型的大小,适合在手机端运行。
这里来讲下如何把模型保存为 pb 文件,以及如何从 pb 文件中加载模型。
pb 的保存
下面是一个例子。
1 | import tensorflow as tf |
下面的代码中,计算了 \(result = v1 + v2\)。
但我只想保存 result
变量,关键的代码是
1 | output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add']) |
其中 ['add']
是一个 list,里面存放的是要保存的计算节点的名称,因为我们使用的加法,所以计算节点的名称就是 add
。
同理,如果你想保存乘法的计算节点,名称是 multiply
。
你还可以给每个计算节点起名称:
例如上面的
1 | result = v1 + v2 |
默认的名称是 add
,可以使用 TensorFlow 提供的 tf.add()
来替代,并自定义名称:
1
result = tf.add(v1,v2,name='result')
那么在保存的时候,add
就改为 result
。
1 | output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['result']) |
如果你想查看所有计算节点的名称,可以使用下面的语句:
1 | var_list = tf.global_variables() |
var_list
就保存了一部分张量的名称。
或者直接查看某个变量的 name
属性,例如查看 result
的张量名称:
1 | result = tf.multiply(v1,v2,name='result') |
输出是 result:0
。(注意,这里是 result:0
,而不是 result
)
这里需要注意,张量的名称和计算节点的名称不是一个概念。它们是从属关系。具体来说,一个计算节点(也被称为一个Operation)可以包括多个张量。具体到名称上,张量的名称 = 计算节点的名称 + ":"+编号
因此,上面打印了 result.name
的值是 result:0
,我们要取出前面的 result
,作为计算节点的名称。
1 | output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['result']) |
pb 的加载
代码如下:
1 | from tensorflow.python.platform import gfile |
注意这里读取的是张量的名称 "result:0"
。
下面使用一个实际的案例来讲解保存和加载 pb 模型。
保存 pb 模型
在这个例子中,我们定义了一个卷积神经网络。
一般来说,我们只需要保存输入的 placeholder
,以及最终输出的结果、loss 等数据。
1 | num_classes=20 |
在上面的代码中,我们保存的计算节点包括:['softmax_cross_entropy_loss/value','input', 'label', "probability"]
,分别对应 loss、输入、标签以及最终输出的概率,而不用关心中间的那些网络层。
加载 pb 模型
下面是加载 pb 模型的代码。
1 | sess = tf.Session() |
其中关键的代码是:
1 | input_image = sess.graph.get_tensor_by_name("input:0") |
分别把 loss、输入、标签以及最终输出的概率取出来,注意这里使用的是张量名称。然后就可以进行计算和推理了。
你学会了吗?
如果你还有不明白的地方,欢迎给我留言。
如果你觉得这篇文章对你有帮助,不妨给我点个赞,鼓励我写出更多好文章。
参考资料