The metric package is where the measurement utilities live: the computations for mean, accuracy, and similar metrics are all implemented here. The computed metric values can then be written to a summary or printed from a logger hook. One spot in the metric computation is genuinely confusing, namely the pair of return values. On top of this, our SDK also modified distributed validation by adding an allreduce step, which is one step more than the native implementation.
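As a quick illustration of that workflow, here is a minimal sketch, assuming the TF 1.x graph-mode API; the placeholder, the scope name my_mean, and the logging frequency are made up for this example and are not part of the original text:

import tensorflow as tf  # TF 1.x graph mode assumed

values = tf.placeholder(tf.float32, shape=[None], name='batch_values')

# mean_value is the running mean; update_op folds the current batch into it.
mean_value, update_op = tf.metrics.mean(values, name='my_mean')

# Write the metric into the event files so TensorBoard can plot it.
tf.summary.scalar('mean_value', mean_value)

# Or print it every 100 steps through a logging hook.
log_hook = tf.train.LoggingTensorHook({'mean_value': mean_value}, every_n_iter=100)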
Taking mean as the example:
@tf_export(v1=['metrics.mean'])
def mean(values,
         weights=None,
         metrics_collections=None,
         updates_collections=None,
         name=None):
  """Computes the (weighted) mean of the given values.

  The `mean` function creates two local variables, `total` and `count`
  that are used to compute the average of `values`. This average is ultimately
  returned as `mean` which is an idempotent operation that simply divides
  `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `mean`.
  `update_op` increments `total` with the reduced sum of the product of `values`
  and `weights`, and it increments `count` with the reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `Tensor` of arbitrary dimensions.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `values`, and must be broadcastable to `values` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `values` dimension).
    metrics_collections: An optional list of collections that `mean`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    mean: A `Tensor` representing the current mean, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_value`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean is not supported when eager execution '
                       'is enabled.')

  with variable_scope.variable_scope(name, 'mean', (values, weights)):
    values = math_ops.cast(values, dtypes.float32)

    total = metric_variable([], dtypes.float32, name='total')
    count = metric_variable([], dtypes.float32, name='count')

    if weights is None:
      num_values = math_ops.cast(array_ops.size(values), dtypes.float32)
    else:
      values, _, weights = _remove_squeezable_dimensions(
          predictions=values, labels=None, weights=weights)
      weights = weights_broadcast_ops.broadcast_weights(
          math_ops.cast(weights, dtypes.float32), values)
      values = math_ops.multiply(values, weights)
      num_values = math_ops.reduce_sum(weights)

    update_total_op = state_ops.assign_add(total, math_ops.reduce_sum(values))
    with ops.control_dependencies([values]):
      update_count_op = state_ops.assign_add(count, num_values)

    def compute_mean(_, t, c):
      return math_ops.div_no_nan(t, math_ops.maximum(c, 0), name='value')

    mean_t = _aggregate_across_replicas(
        metrics_collections, compute_mean, total, count)
    update_op = math_ops.div_no_nan(
        update_total_op, math_ops.maximum(update_count_op, 0), name='update_op')

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean_t, update_op
As you can see above, the return values are mean_t and update_op.
mean_t
mean_t = _aggregate_across_replicas(
    metrics_collections, compute_mean, total, count)
This step simply computes total / count. It only reads the current values of the accumulator variables: if your data spans N batches, this value reflects just the batches that have already been folded in via update_op, so right after the first update it is only one batch's result, not the mean over all N batches (see the end-to-end sketch further below).
update_op
update_total_op = state_ops.assign_add(total, math_ops.reduce_sum(values))
The computation of update_op depends on the op above. assign_add returns an op: if you want the new value of total, you have to run that op, and the result is assigned back to total. At this point it becomes clear that total actually records the accumulated result of the K batches you have processed so far, and update_op returns total / count after folding in the current batch.
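Putting the two return values together, here is a minimal self-contained sketch, assuming TF 1.x graph mode; the three toy batches are made up for illustration. The usual pattern is to run update_op once per batch and then read mean_t to get the mean over all K batches:

import numpy as np
import tensorflow as tf  # TF 1.x graph mode assumed

values = tf.placeholder(tf.float32, shape=[None])
mean_t, update_op = tf.metrics.mean(values)

batches = [np.array([1.0, 2.0]), np.array([3.0, 4.0]), np.array([5.0, 6.0])]

with tf.Session() as sess:
  # total and count are local variables, so they need this initializer.
  sess.run(tf.local_variables_initializer())

  for batch in batches:
    # Each call adds the batch into total/count and returns the running mean.
    print(sess.run(update_op, feed_dict={values: batch}))
    # -> 1.5, then 2.5, then 3.5

  # mean_t just reads total/count: (1+2+3+4+5+6) / 6 = 3.5
  print(sess.run(mean_t))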
Postscript
Distributed validation was mentioned earlier; I won't paste that code here for now, just describe it. It relies on Allreduce to obtain the aggregated values across workers, and it is mainly used during validation.
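As a rough idea of what that looks like, here is a generic sketch only, not the SDK's actual code: Horovod is assumed as the allreduce backend, and the function name and variables are made up. The local total and count accumulators are summed across workers before the division:

import tensorflow as tf
import horovod.tensorflow as hvd

def allreduce_mean(total, count):
  """Sum the per-worker accumulators, then divide, so every worker
  sees the same global mean during validation."""
  # Depending on the Horovod version, the sum is requested with
  # op=hvd.Sum or with the older average=False argument.
  global_total = hvd.allreduce(total, op=hvd.Sum)
  global_count = hvd.allreduce(count, op=hvd.Sum)
  return tf.math.divide_no_nan(global_total, global_count)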