您的位置:

tf.summary的完全解读

TensorFlow提供了很多机制来对模型进行可视化,其中一个强大的机制是tf.summary。tf.summary提供了一系列的函数,可以记录summary信息,以便于在TensorBoard中进行展示和调试。在这篇文章中,我们将从多个方面对tf.summary进行详细的阐述。

一、tf.summary内存

tf.summary记录的信息在内存中存储,即在模型训练过程中会不断的产生summary信息,这些信息会被累积在内存中,分散输出并不容易管理。可以通过下面两种方式来避免内存过多:1.使用tf.summary.FileWriter,在定义的时候设置max_queue和flush_secs参数,可以让数据定期的写入磁盘,释放内存;2.使用tf.summary.create_no_op操作,来创建一个空的操作,避免数据占据过多内存。

二、tf.summary.merge_all()

tf.summary.merge_all()是一个核心函数,它可以将所有采集的事件合并成一个tensor,这样可以同时运行多个采集操作。tf.summary.merge_all()合并的过程是针对所有可采集的事件,而tf.summary.merge函数是针对单独一个事件进行合并。例如:

summary_op = tf.summary.merge_all() # 合并所有可采集事件
sum_train_op = tf.summary.merge([loss_summary, acc_summary]) # 合并单独事件

三、tf.summary.scalar

tf.summary.scalar用于记录标量信息,统计各个node输出tensor的标量值,在浏览器显示如下:

这里是一个示例代码,记录了模型的loss和acc信息:

loss_summary = tf.summary.scalar('loss', loss)
acc_summary = tf.summary.scalar('accuracy', accuracy)

四、tf.summary.histogram

tf.summary.histogram用于记录张量的取值分布,是一种记录时序数据的好方法。在TensorBoard的Scalar面板下会展示大量的数据,使用SketchPad的方式,可以;在TensorBoard的Histogram面板下会展示取值分布,如下图所示:

下面是一个示例代码,记录全局参数的取值分布:

for i, variable in enumerate(tf.trainable_variables()):
   tf.summary.histogram("weights_{}".format(i), variable)

五、tf.summary.image

在训练CNN的时候,可以使用tf.summary.image来记录输入图片和卷积层输出的feature map,使模型的可视化更为直观。在TensorBoard的Images面板下展示,如下图所示:

下面是一个示例代码,记录卷积层的输出:

conv1_reshape = tf.reshape(conv1_output, [-1, 28, 28, 32])
tf.summary.image('Conv1', conv1_reshape, max_outputs=10)

六、tf.summary.FileWriter

tf.summary.FileWriter用于将所有采集到的事件写入到磁盘中。其中graph参数用于显示计算图,logdir参数表示日志存储位置。在启动TensorBoard的命令时,可以指定一个或多个日志目录,TensorBoard会将所有日志组合在一起,并显示可视化的结果。

train_writer = tf.summary.FileWriter(log_dir + '/train', sess.graph)
test_writer = tf.summary.FileWriter(log_dir + '/test')
merged = tf.summary.merge_all()
session.run(tf.global_variables_initializer())
for i in range(1000):
    batch = mnist.train.next_batch(100)
    summary, _= session.run([merged,  optimizer], feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
    train_writer.add_summary(summary, i)
    if i % 10 == 0:
        summary, acc = session.run([merged, accuracy], feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})
        test_writer.add_summary(summary, i)
        print('accuracy at step %s: %s' % (i, acc))
train_writer.close()
test_writer.close()

七、tf.summary.create

在定义变量的时候,可以通过tf.summary.create来对关键节点进行诊断,帮助我们快速发现问题。下面是一个示例代码,记录梯度和权重信息:

# compute gradients
var_grads = tf.gradients(loss, variables)
for grad, var in zip(var_grads, variables):
    tf.summary.histogram(var.name + '/gradient', grad)
    tf.summary.histogram(var.name + '/weight', var)

结语

以上就是对于tf.summary的详细解读,通过使用tf.summary我们可以更加清晰的了解模型的训练情况,定位问题,提高模型的可视化能力和调试效率。