在Tensorflow中使用 MonitoredTrainingSession(粗糙的记录)

这是一个非常粗糙的记录,如果你偶然进入这个页面,那么,请注意!如果没看懂,请立刻放弃,并寻找更好的例子

import tensorflow as tf

a = tf.Variable(1)
b = tf.Variable(2)
c = tf.add(a, b)

saver = tf.train.Saver()
saver_hook = tf.train.CheckpointSaverHook('./testpot/', 
                                          save_steps=2, 
                                          saver=saver)

global_step = tf.Variable(0, name='global_step', trainable=False)
summary_op = tf.summary.scalar('c', c)
summary_hook = tf.train.SummarySaverHook(save_steps=2, 
                                         summary_op=summary_op)
with tf.train.MonitoredTrainingSession(hooks=[saver_hook, summary_hook]) as sess:
    while not sess.should_stop():
        print(sess.run(c))

说明

tf.train.Supervisor已经被放弃,github上建议使用MonitoredTrainingSession替代Supervisor

在使用MonitoredTrainingSession前,需要有
1. saver
2. saver_hook
3. global_step
4. summary_op
5. summary_hook

这些变量,才能保证程序不报错。

这个程序用 Tensorboard 打开看到的图
简单的图例

随附参考:tf.train.MonitoredSession

3 Comments

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

此站点使用Akismet来减少垃圾评论。了解我们如何处理您的评论数据