Checkpoint

3,073次阅读
没有评论

共计 3119 个字符,预计需要花费 8 分钟才能阅读完成。

什么是checkpoint?

检查点checkpoint中存储着模型model所使用的的所有的 tf.Variable 对象以及模型结构的定义

checkpoint的一般格式如下:

(1)meta文件

.meta文件:一个协议缓冲,保存tensorflow中完整的graph、variables、operation、collection;这是我们恢复模型结构的参照;

meta文件保存的是图结构,通俗地讲就是神经网络的网络结构。我们可以使用下面的代码只在第一次保存meta文件。

saver.save(sess, 'my_model.ckpt', global_step=step, write_meta_graph=False)

在后面恢复整个graph的结构的时候,并且还可以使用

tf.train.import_meta_graph(‘xxxxxx.meta’)

能够导入图结构。

(2)data文件

keypoint_model.ckpt-9.data-00000-of-00001:数据文件,保存的是网络的权值,偏置,操作等等。

(3)index文件

keypoint_model.ckpt-9.index  是一个不可变得字符串字典,每一个键都是张量的名称,它的值是一个序列化的BundleEntryProto。 每个BundleEntryProto描述张量的元数据,所谓的元数据就是描述这个Variable 的一些信息的数据。 “数据”文件中的哪个文件包含张量的内容,该文件的偏移量,校验和,一些辅助数据等等。

Note: 以前的版本中tensorflow的model只保存一个文件中。

(4)checkpoint文件——文本文件

checkpoint是一个文本文件,记录了训练过程中在所有中间节点上保存的模型的名称,首行记录的是最后(最近)一次保存的模型名称。checkpoint是检查点文件,文件保存了一个目录下所有的模型文件列表;比如我上面的模型保存了最后的5份checkpoint,这里打开checkpoint查看得到如下内容:

model_checkpoint_path: "keypoint_model.ckpt-9"      # 最新的那一份
all_model_checkpoint_paths: "keypoint_model.ckpt-5"
all_model_checkpoint_paths: "keypoint_model.ckpt-6"
all_model_checkpoint_paths: "keypoint_model.ckpt-7"
all_model_checkpoint_paths: "keypoint_model.ckpt-8"
all_model_checkpoint_paths: "keypoint_model.ckpt-9

实例

import tensorflow as tf
import numpy as np
 
# 1.准备数据: 
x = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 4
 
# 2.构造一个线性模型 
w = tf.Variable(tf.random_normal([1], -1, 1)) #创建新对象,当检测到命名冲突时,系统会自己处理
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b
 
# 3.求解模型
# 设置损失函数:误差的均方差 
loss = tf.reduce_mean(tf.square(y - y_predict))
# 选择梯度下降的方法
optimizer = tf.train.GradientDescentOptimizer(0.5)
# 迭代的目标:最小化损失函数
train = optimizer.minimize(loss)
 
#参数定义声明 
isTrain = True
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = 'F:\\\\vivocode\\\\tftestmodel\\\\'
saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))
 
############################################################
# 以下是用 tf 来解决上面的任务
# 1.初始化变量:tf 的必备步骤,主要声明了变量,就必须初始化才能用
# init = tf.global_variables_initializer() 
 
# 设置tensorflow对GPU的使用按需分配
#config  = tf.ConfigProto()
#config.gpu_options.allow_growth = True
 
# 2.启动图 (graph)
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables()) 
    #判断当前工作状态
    if isTrain: #isTrain:True表示训练;False:表示测试
        # 3.迭代,反复执行上面的最小化损失函数这一操作(train op),拟合平面
        for i in range(train_steps): #train_steps表示训练的次数,例子中使用1006666666
            sess.run(train, feed_dict={x: x_data})
            if (i + 1) % checkpoint_steps == 0: #表示训练多少次保存一下checkpoints,例子中使用50
                print ('step: {}  train_acc: {}  loss: {}'.format(i, sess.run(w), sess.run(b)))
                saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1) #表示checkpoints文件的保存路径,例子中使用当前路径
    else: #如果isTrain=False,则进行测试
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path) #恢复变量
        else:
            pass
        print(sess.run(w),sess.run(b))

生成的checkpoint 如下所示:

Checkpoint

 

load 数据

如果你导入meta数据是不需要在定义图就可以做些运算

import tensorflow as tf
import numpy as np
checkpoint_dir = 'F:\\\\vivocode\\\\tftestmodel\\\\'
# 2.启动图 (graph)
# saver = tf.train.Saver()
with tf.Session() as sess:
    
    model_path = tf.train.latest_checkpoint(checkpoint_dir)  # 获取最新的模型,注意这里的是文件夹哦
    saver=tf.train.import_meta_graph(model_path+'.meta')
    saver.restore(sess,model_path)
    graph = tf.get_default_graph()  
    X = graph.get_tensor_by_name('Variable:0')

    print(sess.run(X))

[3.985119]
正文完
请博主喝杯咖啡吧!
post-qrcode
 
admin
版权声明:本站原创文章,由 admin 2021-01-06发表,共计3119字。
转载说明:除特殊说明外本站文章皆由CC-4.0协议发布,转载请注明出处。
评论(没有评论)
验证码