共计 3119 个字符,预计需要花费 8 分钟才能阅读完成。
什么是checkpoint?
检查点checkpoint中存储着模型model所使用的的所有的 tf.Variable 对象以及模型结构的定义
checkpoint的一般格式如下:
(1)meta文件
.meta文件:一个协议缓冲,保存tensorflow中完整的graph、variables、operation、collection;这是我们恢复模型结构的参照;
meta文件保存的是图结构,通俗地讲就是神经网络的网络结构。我们可以使用下面的代码只在第一次保存meta文件。
saver.save(sess, 'my_model.ckpt', global_step=step, write_meta_graph=False)
在后面恢复整个graph的结构的时候,并且还可以使用
tf.train.import_meta_graph(‘xxxxxx.meta’)
能够导入图结构。
(2)data文件
keypoint_model.ckpt-9.data-00000-of-00001:数据文件,保存的是网络的权值,偏置,操作等等。
(3)index文件
keypoint_model.ckpt-9.index 是一个不可变得字符串字典,每一个键都是张量的名称,它的值是一个序列化的BundleEntryProto。 每个BundleEntryProto描述张量的元数据,所谓的元数据就是描述这个Variable 的一些信息的数据。 “数据”文件中的哪个文件包含张量的内容,该文件的偏移量,校验和,一些辅助数据等等。
Note: 以前的版本中tensorflow的model只保存一个文件中。
(4)checkpoint文件——文本文件
checkpoint是一个文本文件,记录了训练过程中在所有中间节点上保存的模型的名称,首行记录的是最后(最近)一次保存的模型名称。checkpoint是检查点文件,文件保存了一个目录下所有的模型文件列表;比如我上面的模型保存了最后的5份checkpoint,这里打开checkpoint查看得到如下内容:
model_checkpoint_path: "keypoint_model.ckpt-9" # 最新的那一份
all_model_checkpoint_paths: "keypoint_model.ckpt-5"
all_model_checkpoint_paths: "keypoint_model.ckpt-6"
all_model_checkpoint_paths: "keypoint_model.ckpt-7"
all_model_checkpoint_paths: "keypoint_model.ckpt-8"
all_model_checkpoint_paths: "keypoint_model.ckpt-9
实例
import tensorflow as tf
import numpy as np
# 1.准备数据:
x = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 4
# 2.构造一个线性模型
w = tf.Variable(tf.random_normal([1], -1, 1)) #创建新对象,当检测到命名冲突时,系统会自己处理
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b
# 3.求解模型
# 设置损失函数:误差的均方差
loss = tf.reduce_mean(tf.square(y - y_predict))
# 选择梯度下降的方法
optimizer = tf.train.GradientDescentOptimizer(0.5)
# 迭代的目标:最小化损失函数
train = optimizer.minimize(loss)
#参数定义声明
isTrain = True
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = 'F:\\\\vivocode\\\\tftestmodel\\\\'
saver = tf.train.Saver() # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))
############################################################
# 以下是用 tf 来解决上面的任务
# 1.初始化变量:tf 的必备步骤,主要声明了变量,就必须初始化才能用
# init = tf.global_variables_initializer()
# 设置tensorflow对GPU的使用按需分配
#config = tf.ConfigProto()
#config.gpu_options.allow_growth = True
# 2.启动图 (graph)
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
#判断当前工作状态
if isTrain: #isTrain:True表示训练;False:表示测试
# 3.迭代,反复执行上面的最小化损失函数这一操作(train op),拟合平面
for i in range(train_steps): #train_steps表示训练的次数,例子中使用1006666666
sess.run(train, feed_dict={x: x_data})
if (i + 1) % checkpoint_steps == 0: #表示训练多少次保存一下checkpoints,例子中使用50
print ('step: {} train_acc: {} loss: {}'.format(i, sess.run(w), sess.run(b)))
saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1) #表示checkpoints文件的保存路径,例子中使用当前路径
else: #如果isTrain=False,则进行测试
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path) #恢复变量
else:
pass
print(sess.run(w),sess.run(b))
生成的checkpoint 如下所示:
load 数据
如果你导入meta数据是不需要在定义图就可以做些运算
import tensorflow as tf
import numpy as np
checkpoint_dir = 'F:\\\\vivocode\\\\tftestmodel\\\\'
# 2.启动图 (graph)
# saver = tf.train.Saver()
with tf.Session() as sess:
model_path = tf.train.latest_checkpoint(checkpoint_dir) # 获取最新的模型,注意这里的是文件夹哦
saver=tf.train.import_meta_graph(model_path+'.meta')
saver.restore(sess,model_path)
graph = tf.get_default_graph()
X = graph.get_tensor_by_name('Variable:0')
print(sess.run(X))
[3.985119]