This started during a code review, when someone asked where exactly the Estimator manages its Session. Emmm, I had never read that part of the code carefully, so honestly I did not know at the time. Searching online took a while too: most answers simply say you don't need to care about the Session, meaning that since Estimator is a high-level API, Session management is not your problem. That is indeed part of the Estimator design, but in the end you still have to dig into the source code to find the answer.
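Before digging into the source, it helps to see why the question comes up at all: typical Estimator code never touches a Session. A minimal sketch (the canned LinearRegressor and the toy data are mine, purely for illustration):

import tensorflow as tf

def input_fn():
    # Toy dataset; train() requires a (features, labels) tuple, or a Dataset yielding one.
    features = {'x': [[1.0], [2.0], [3.0], [4.0]]}
    labels = [[2.0], [4.0], [6.0], [8.0]]
    return tf.data.Dataset.from_tensor_slices((features, labels)).repeat().batch(2)

estimator = tf.estimator.LinearRegressor(
    feature_columns=[tf.feature_column.numeric_column('x')])
estimator.train(input_fn=input_fn, steps=100)  # no tf.Session anywhere in user code

So where does the Session live? Let's start from the train() entry point: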
def train(self,
input_fn,
hooks=None,
steps=None,
max_steps=None,
saving_listeners=None):
"""Trains a model given training data `input_fn`.
Args:
input_fn: A function that provides input data for training as minibatches.
See [Premade Estimators](
https://tensorflow.org/guide/premade_estimators#create_input_functions)
for more information. The function should construct and return one of
the following: * A
`tf.data.Dataset` object: Outputs of `Dataset` object must be a tuple
`(features, labels)` with same constraints as below. * A tuple
`(features, labels)`: Where `features` is a `tf.Tensor` or a dictionary
of string feature name to `Tensor` and `labels` is a `Tensor` or a
dictionary of string label name to `Tensor`. Both `features` and
`labels` are consumed by `model_fn`. They should satisfy the expectation
of `model_fn` from inputs.
hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
callbacks inside the training loop.
steps: Number of steps for which to train the model. If `None`, train
forever or train until `input_fn` generates the `tf.errors.OutOfRange`
error or `StopIteration` exception. `steps` works incrementally. If you
call two times `train(steps=10)` then training occurs in total 20 steps.
If `OutOfRange` or `StopIteration` occurs in the middle, training stops
before 20 steps. If you don't want to have incremental behavior please
set `max_steps` instead. If set, `max_steps` must be `None`.
max_steps: Number of total steps for which to train model. If `None`,
train forever or train until `input_fn` generates the
`tf.errors.OutOfRange` error or `StopIteration` exception. If set,
`steps` must be `None`. If `OutOfRange` or `StopIteration` occurs in the
middle, training stops before `max_steps` steps. Two calls to
`train(steps=100)` means 200 training iterations. On the other hand, two
calls to `train(max_steps=100)` means that the second call will not do
any iteration since first call did all 100 steps.
saving_listeners: list of `CheckpointSaverListener` objects. Used for
callbacks that run immediately before or after checkpoint savings.
Returns:
`self`, for chaining.
Raises:
ValueError: If both `steps` and `max_steps` are not `None`.
ValueError: If either `steps` or `max_steps <= 0`.
"""
_estimator_api_gauge.get_cell('train').set(True)
if self.config.task_type in (run_config.TaskType.EVALUATOR,
run_config.TaskType.PS):
raise ValueError(
'Train has been called wrong configuration. Please use '
'tf.estimator.train_and_evaluate which calls proper API according '
'to given configuration. Current configuration: {}.'.format(
self.config))
with context.graph_mode():
if (steps is not None) and (max_steps is not None):
raise ValueError('Can not provide both steps and max_steps.')
if steps is not None and steps <= 0:
  raise ValueError('Must specify steps > 0, given: {}'.format(steps))
if max_steps is not None and max_steps <= 0:
  raise ValueError('Must specify max_steps > 0, given: {}'.format(max_steps))
if max_steps is not None:
# Restore the global step from the checkpoint
start_step = _load_global_step_from_checkpoint_dir(self._model_dir)
if max_steps <= start_step:
logging.info('Skipping training since max_steps has already saved.')
return self
# Check the hook types, making sure they all inherit from SessionRunHook
hooks = _check_hooks_type(hooks)
hooks.extend(self._convert_train_steps_to_hooks(steps, max_steps))
# saving_listeners holds the CheckpointSaverListener callbacks for checkpoint saving
saving_listeners = _check_listeners_type(saving_listeners)
# _train_model is the key call; the Session turns out to be created further down inside it
loss = self._train_model(input_fn, hooks, saving_listeners)
logging.info('Loss for final step: %s.', loss)
return self
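Before moving on, a quick usage sketch of the steps / max_steps semantics described above (estimator and input_fn as in the example at the top; the numbers are illustrative):

# Incremental: two calls add up to 20 steps in total.
estimator.train(input_fn=input_fn, steps=10)
estimator.train(input_fn=input_fn, steps=10)

# Absolute: the second call is a no-op, because the global step restored
# from the checkpoint has already reached max_steps (see the early return above).
estimator.train(input_fn=input_fn, max_steps=100)
estimator.train(input_fn=input_fn, max_steps=100)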
Parameters:
input_fn: a function that provides the training data as minibatches; see Premade Estimators for more details. The function must construct and return one of the following:
(1) A tf.data.Dataset object: the outputs of the Dataset must be a (features, labels) tuple with the same constraints as below.
(2) A tuple (features, labels): where features is a tf.Tensor or a dictionary mapping string feature names to Tensors, and labels likewise. Both features and labels are consumed by model_fn, so they must satisfy model_fn's input expectations.
hooks: a list of tf.train.SessionRunHook subclass instances, used for callbacks inside the training loop. A bit of personal understanding: a hook is a class that serves the Estimator; it provides begin, after_create_session, before_run, after_run and end methods, which are called before the Session is created, right after the Session is created, before each Session run, after each Session run, and just before the Session is closed, to perform whatever extra work you need (a minimal sketch follows this list; see also the reference code in the appendix).
steps: the number of steps to train the model. If None, the model trains forever, or until input_fn raises the tf.errors.OutOfRange error or the StopIteration exception. steps works incrementally: if you call train(steps=10) twice, training runs for 20 steps in total. If OutOfRange or StopIteration occurs in the middle, training stops before reaching 20 steps. If you do not want this incremental behavior, set max_steps instead. If steps is set, max_steps must be None.
max_steps: the total number of steps to train the model. If None, the model trains forever, or until input_fn raises tf.errors.OutOfRange or StopIteration. If set, steps must be None. If OutOfRange or StopIteration occurs in the middle, training stops before max_steps. Calling train(steps=100) twice means 200 training steps in total, whereas calling train(max_steps=100) twice only trains 100 steps, because the first call already reached the maximum.
saving_listeners: a list of CheckpointSaverListener objects, used for callbacks that run immediately before or after checkpoint saving.
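As mentioned under hooks above, a custom hook simply implements some of the SessionRunHook lifecycle methods. A minimal sketch (the hook name and what it logs are my own assumptions, purely illustrative):

import tensorflow as tf

class StepLoggingHook(tf.estimator.SessionRunHook):
    def begin(self):
        # Called once, before the graph is finalized and the Session is created.
        self._step = 0

    def after_create_session(self, session, coord):
        # Called right after the (Monitored)Session has been created.
        tf.compat.v1.logging.info('Session created, training starts.')

    def before_run(self, run_context):
        # Called before every Session.run; may return SessionRunArgs to add fetches.
        self._step += 1
        return None

    def after_run(self, run_context, run_values):
        # Called after every Session.run.
        if self._step % 100 == 0:
            tf.compat.v1.logging.info('Finished step %d.', self._step)

    def end(self, session):
        # Called just before the Session is closed.
        tf.compat.v1.logging.info('Training done after %d steps.', self._step)

# Passed in via the hooks argument:
# estimator.train(input_fn=input_fn, steps=1000, hooks=[StepLoggingHook()])

Back to the source: train() delegates the real work to _train_model: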
def _train_model(self, input_fn, hooks, saving_listeners):
if self._train_distribution:
return self._train_model_distributed(input_fn, hooks, saving_listeners)
else:
return self._train_model_default(input_fn, hooks, saving_listeners)
def _train_model_default(self, input_fn, hooks, saving_listeners):
"""Initiate training with `input_fn`, without `DistributionStrategies`.
Args:
input_fn: A function that provides input data for training as minibatches.
hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
callbacks inside the training loop.
saving_listeners: list of `tf.train.CheckpointSaverListener` objects. Used
for callbacks that run immediately before or after checkpoint savings.
Returns:
Loss from training
"""
worker_hooks = []
with tf.Graph().as_default() as g, g.device(self._device_fn):
tf.compat.v1.random.set_random_seed(self._config.tf_random_seed)
global_step_tensor = self._create_and_assert_global_step(g)
# Skip creating a read variable if _create_and_assert_global_step
# returns None (e.g. tf.contrib.estimator.SavedModelEstimator).
if global_step_tensor is not None:
training_util._get_or_create_global_step_read(g) # pylint: disable=protected-access
# Essentially fetches the data from input_fn
features, labels, input_hooks = (
self._get_features_and_labels_from_input_fn(input_fn, ModeKeys.TRAIN))
worker_hooks.extend(input_hooks)
# Call the user-defined model_fn, which defines the graph and behavior for each mode
estimator_spec = self._call_model_fn(features, labels, ModeKeys.TRAIN,
self.config)
global_step_tensor = tf.compat.v1.train.get_global_step(g)
# Pass the EstimatorSpec on; the Session is set up inside this next function
return self._train_with_estimator_spec(estimator_spec, worker_hooks,
hooks, global_step_tensor,
saving_listeners)
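For reference, a hedged sketch of what the user-defined model_fn called above might look like (a toy linear model; every name and constant here is illustrative):

import tensorflow as tf

def model_fn(features, labels, mode, params):
    # Build the model graph from the features produced by input_fn.
    w = tf.compat.v1.get_variable('w', shape=[1])
    b = tf.compat.v1.get_variable('b', shape=[1])
    predictions = features['x'] * w + b

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)

    loss = tf.compat.v1.losses.mean_squared_error(labels, predictions)

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.01)
        train_op = optimizer.minimize(
            loss, global_step=tf.compat.v1.train.get_global_step())
        # estimator_spec.train_op and estimator_spec.loss are exactly what the
        # training loop further down feeds to the MonitoredTrainingSession.
        return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

    return tf.estimator.EstimatorSpec(mode, loss=loss)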
Drilling further down, we find where the Session is actually controlled:
def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks,
global_step_tensor, saving_listeners):
"""Train a model with the given Estimator Spec."""
if (self._warm_start_settings and
not tf.train.latest_checkpoint(self._model_dir)):
tf.compat.v1.logging.info('Warm-starting with WarmStartSettings: %s' %
(self._warm_start_settings,))
tf.compat.v1.train.warm_start(*self._warm_start_settings)
# Check if the user created a loss summary, and add one if they didn't.
# We assume here that the summary is called 'loss'. If it is not, we will
# make another one with the name 'loss' to ensure it shows up in the right
# graph in TensorBoard.
if not any([
x.op.name == 'loss' for x in ops.get_collection(ops.GraphKeys.SUMMARIES)
]):
summary.scalar('loss', estimator_spec.loss)
ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
worker_hooks.extend(hooks)
worker_hooks.append(tf.compat.v1.train.NanTensorHook(estimator_spec.loss))
if self._config.log_step_count_steps is not None:
worker_hooks.append(
tf.compat.v1.train.LoggingTensorHook(
{
'loss': estimator_spec.loss,
'step': global_step_tensor
},
every_n_iter=self._config.log_step_count_steps))
worker_hooks.extend(estimator_spec.training_hooks)
if not (estimator_spec.scaffold.saver or
tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SAVERS)):
tf.compat.v1.add_to_collection(
tf.compat.v1.GraphKeys.SAVERS,
tf.compat.v1.train.Saver(
sharded=True,
max_to_keep=self._config.keep_checkpoint_max,
keep_checkpoint_every_n_hours=(
self._config.keep_checkpoint_every_n_hours),
defer_build=True,
save_relative_paths=True))
if (self._config.cluster_spec and type(
self._train_distribution).__name__ in ('CollectiveAllReduceStrategy',
'CollectiveAllReduceStrategyV1',
'MultiWorkerMirroredStrategy')):
return self._train_with_estimator_spec_distributed(
estimator_spec, worker_hooks, saving_listeners)
chief_hooks = []
all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
saver_hooks = [
h for h in all_hooks
if isinstance(h, tf.compat.v1.train.CheckpointSaverHook)
]
if (self._config.save_checkpoints_secs or
self._config.save_checkpoints_steps):
if not saver_hooks:
chief_hooks = [
tf.compat.v1.train.CheckpointSaverHook(
self._model_dir,
save_secs=self._config.save_checkpoints_secs,
save_steps=self._config.save_checkpoints_steps,
scaffold=estimator_spec.scaffold,
save_graph_def=self._config.checkpoint_save_graph_def)
]
saver_hooks = [chief_hooks[0]]
if saving_listeners:
if not saver_hooks:
raise ValueError(
'There should be a CheckpointSaverHook to use saving_listeners. '
'Please set one of the RunConfig.save_checkpoints_steps or '
'RunConfig.save_checkpoints_secs.')
else:
# It is expected to have one CheckpointSaverHook. If multiple, we pick
# up the first one to add listener.
for listener in saving_listeners:
# pylint: disable=protected-access
if listener not in saver_hooks[0]._listeners:
saver_hooks[0]._listeners.append(listener)
# pylint: disable=protected-access
# Add summary hooks to worker 0 if we are running with a master, to ensure
# that summaries are written at correct intervals even with long-running
# evaluations.
save_summary_steps = self._config.save_summary_steps
log_step_count_steps = self._config.log_step_count_steps
# Check existence of appropriate cluster spec fields, as well as master and
# worker nodes. As master also performs evaluation, summary writing must
# occur on a different node. The presence of a worker is also checked to
# prevent reassigning hooks for single-replica jobs with just a master node.
if (self._config.cluster_spec and self._config.cluster_spec.jobs and
(run_config.TaskType.WORKER in self._config.cluster_spec.jobs) and
(run_config.TaskType.MASTER in self._config.cluster_spec.jobs)):
# Update config values to prevent the default hooks from being created on
# the master or other workers.
save_summary_steps = 0
log_step_count_steps = None
if (self._config.task_type == run_config.TaskType.WORKER and
self._config.task_id == 0):
if (self._config.save_summary_steps and
self._config.save_summary_steps > 0):
worker_hooks.append(
tf.compat.v1.train.SummarySaverHook(
save_steps=self._config.save_summary_steps,
output_dir=self._config.model_dir,
scaffold=estimator_spec.scaffold))
if (self._config.log_step_count_steps and
self._config.log_step_count_steps > 0):
worker_hooks.append(
tf.compat.v1.train.StepCounterHook(
every_n_steps=self._config.log_step_count_steps,
output_dir=self._config.model_dir))
# Ha, so it is MonitoredTrainingSession that controls the Session after all
with training.MonitoredTrainingSession(
master=self._config.master,
is_chief=self._config.is_chief,
checkpoint_dir=self._model_dir,
scaffold=estimator_spec.scaffold,
hooks=worker_hooks,
chief_only_hooks=(tuple(chief_hooks) +
tuple(estimator_spec.training_chief_hooks)),
save_checkpoint_secs=0, # Saving is handled by a hook.
save_summaries_steps=save_summary_steps,
config=self._session_config,
max_wait_secs=self._config.session_creation_timeout_secs,
log_step_count_steps=log_step_count_steps,
save_graph_def=self._config.checkpoint_save_graph_def) as mon_sess:
loss = None
any_step_done = False
while not mon_sess.should_stop():
_, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
any_step_done = True
if not any_step_done:
tf.compat.v1.logging.warn('Training with estimator made no steps. '
'Perhaps input is empty or misspecified.')
return loss
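Stripped of the Estimator machinery, the same training loop can be reproduced by hand. A self-contained sketch (the toy graph and the checkpoint directory are made up for illustration):

import tensorflow as tf

tf.compat.v1.disable_eager_execution()

# A trivial graph so the example runs on its own.
global_step = tf.compat.v1.train.get_or_create_global_step()
x = tf.compat.v1.get_variable('x', initializer=10.0)
loss = tf.square(x)
tf.compat.v1.summary.scalar('loss', loss)
train_op = tf.compat.v1.train.GradientDescentOptimizer(0.1).minimize(
    loss, global_step=global_step)

# MonitoredTrainingSession wires up the Scaffold, hooks, checkpointing and
# summaries, just like _train_with_estimator_spec does above.
with tf.compat.v1.train.MonitoredTrainingSession(
        checkpoint_dir='/tmp/mts_demo',   # hypothetical directory
        hooks=[tf.compat.v1.train.StopAtStepHook(last_step=100)],
        save_checkpoint_secs=600,
        save_summaries_steps=50) as mon_sess:
    while not mon_sess.should_stop():
        _, loss_value = mon_sess.run([train_op, loss])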
So in the end we have found where the Estimator's Session is controlled. There are still some open questions, for example how the Estimator switches between modes; I am still looking into that, and I suspect the answer is in train_and_evaluate.
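For what it's worth, a rough sketch of how train_and_evaluate is usually invoked (the specs, numbers, and the train_input_fn / eval_input_fn names are illustrative assumptions):

train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=10000)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, steps=100, throttle_secs=600)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)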