Getting to the Bottom of TensorFlow Estimator's train

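It started in a code review, when someone asked where exactly Estimator manages its Session. Um, I had never read that code carefully, so I honestly couldn't answer on the spot. Searching online took a while, too. Most answers just say you don't need to care about the Session: Estimator is already a high-level API, and it was designed precisely so that you don't have to deal with Session issues. In the end, though, the answer is in the source, starting with train: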

def train(self,
            input_fn,
            hooks=None,
            steps=None,
            max_steps=None,
            saving_listeners=None):
    """Trains a model given training data `input_fn`.

    Args:
      input_fn: A function that provides input data for training as minibatches.
        See [Premade Estimators](
        https://tensorflow.org/guide/premade_estimators#create_input_functions)
        for more information. The function should construct and return one of
        the following:  * A
        `tf.data.Dataset` object: Outputs of `Dataset` object must be a tuple
        `(features, labels)` with same constraints as below. * A tuple
        `(features, labels)`: Where `features` is a `tf.Tensor` or a dictionary
        of string feature name to `Tensor` and `labels` is a `Tensor` or a
        dictionary of string label name to `Tensor`. Both `features` and
        `labels` are consumed by `model_fn`. They should satisfy the expectation
        of `model_fn` from inputs.
      hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
        callbacks inside the training loop.
      steps: Number of steps for which to train the model. If `None`, train
        forever or train until `input_fn` generates the `tf.errors.OutOfRange`
        error or `StopIteration` exception. `steps` works incrementally. If you
        call two times `train(steps=10)` then training occurs in total 20 steps.
        If `OutOfRange` or `StopIteration` occurs in the middle, training stops
        before 20 steps. If you don't want to have incremental behavior please
        set `max_steps` instead. If set, `max_steps` must be `None`.
      max_steps: Number of total steps for which to train model. If `None`,
        train forever or train until `input_fn` generates the
        `tf.errors.OutOfRange` error or `StopIteration` exception. If set,
        `steps` must be `None`. If `OutOfRange` or `StopIteration` occurs in the
        middle, training stops before `max_steps` steps. Two calls to
        `train(steps=100)` means 200 training iterations. On the other hand, two
        calls to `train(max_steps=100)` means that the second call will not do
        any iteration since first call did all 100 steps.
      saving_listeners: list of `CheckpointSaverListener` objects. Used for
        callbacks that run immediately before or after checkpoint savings.

    Returns:
      `self`, for chaining.

    Raises:
      ValueError: If both `steps` and `max_steps` are not `None`.
      ValueError: If either `steps` or `max_steps <= 0`.
    """
    _estimator_api_gauge.get_cell('train').set(True)
    if self.config.task_type in (run_config.TaskType.EVALUATOR,
                                 run_config.TaskType.PS):
      raise ValueError(
          'Train has been called wrong configuration. Please use '
          'tf.estimator.train_and_evaluate which calls proper API according '
          'to given configuration. Current configuration: {}.'.format(
              self.config))

    with context.graph_mode():
      if (steps is not None) and (max_steps is not None):
        raise ValueError('Can not provide both steps and max_steps.')
      if steps is not None and steps <= 0:
        raise ValueError('Must specify steps > 0, given: {}'.format(steps))
      if max_steps is not None and max_steps <= 0:
        raise ValueError(
            'Must specify max_steps > 0, given: {}'.format(max_steps))

      if max_steps is not None:
        # Restore the global step from the checkpoint
        start_step = _load_global_step_from_checkpoint_dir(self._model_dir)
        if max_steps <= start_step:
          logging.info('Skipping training since max_steps has already saved.')
          return self
      # Check the hook types: each hook must inherit from SessionRunHook
      hooks = _check_hooks_type(hooks)
      hooks.extend(self._convert_train_steps_to_hooks(steps, max_steps))
      # saving_listeners holds the CheckpointSaverListener callbacks run around checkpoint saves
      saving_listeners = _check_listeners_type(saving_listeners)
      # _train_model is the key part; this is where the Session eventually turns up
      loss = self._train_model(input_fn, hooks, saving_listeners)
      logging.info('Loss for final step: %s.', loss)
      return self
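
The argument checks at the top of train are easy to skim past in the pasted source. Below is a minimal plain-Python sketch of the same logic, not TensorFlow code: check_train_args is a hypothetical helper, and start_step stands in for the value that _load_global_step_from_checkpoint_dir would read from the checkpoint.

```python
def check_train_args(steps=None, max_steps=None, start_step=0):
    """Mirror the early checks in Estimator.train (hypothetical helper).

    Returns True if training would run, False if train() would return early.
    """
    if steps is not None and max_steps is not None:
        raise ValueError('Can not provide both steps and max_steps.')
    if steps is not None and steps <= 0:
        raise ValueError('Must specify steps > 0, given: {}'.format(steps))
    if max_steps is not None and max_steps <= 0:
        raise ValueError('Must specify max_steps > 0, given: {}'.format(max_steps))
    # With max_steps, a checkpoint that is already far enough along means
    # there is nothing left to do, so train() logs and returns self.
    if max_steps is not None and max_steps <= start_step:
        return False
    return True


print(check_train_args(steps=10))                       # True
print(check_train_args(max_steps=100, start_step=100))  # False
try:
    check_train_args(steps=10, max_steps=100)
except ValueError as e:
    print(e)  # Can not provide both steps and max_steps.
```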

Parameters:

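input_fn: a function that supplies minibatches of data to the training process; see Premade Estimators for usage details. It must return one of the following:
(1) A tf.data.Dataset object: the Dataset's outputs must be a (features, labels) tuple, with the same constraints as in (2) below.
(2) A tuple (features, labels): features is a tf.Tensor or a dict mapping string feature names to Tensors, and labels works the same way. Both features and labels are consumed by model_fn, so they must satisfy model_fn's input requirements.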
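
To make the (features, labels) contract concrete without depending on TensorFlow, here is a plain-Python stand-in. toy_input_fn and count_steps_until_exhausted are hypothetical names. The first yields minibatches where features is a dict from feature name to values and labels is a list. Running out of data raises StopIteration, which plays the role that tf.errors.OutOfRange or StopIteration plays when train is called with neither steps nor max_steps.

```python
def toy_input_fn(rows, batch_size):
    """Yield (features, labels) minibatches (hypothetical stand-in for input_fn)."""
    for start in range(0, len(rows), batch_size):
        batch = rows[start:start + batch_size]
        # features: dict of feature name -> values for this minibatch
        features = {
            'x': [r['x'] for r in batch],
            'y': [r['y'] for r in batch],
        }
        labels = [r['label'] for r in batch]
        yield features, labels


def count_steps_until_exhausted(batches):
    """Consume batches until StopIteration, like train() with steps=None."""
    steps = 0
    it = iter(batches)
    while True:
        try:
            next(it)  # one (features, labels) pair per training step
        except StopIteration:
            return steps
        steps += 1


rows = [{'x': i, 'y': 2 * i, 'label': i % 2} for i in range(5)]
print(count_steps_until_exhausted(toy_input_fn(rows, batch_size=2)))  # 3
```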

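hooks: a list of tf.train.SessionRunHook subclass instances, used for callbacks inside the training loop. A bit of my own understanding: a hook is a class that serves the Estimator. It has the methods begin, after_create_session, before_run, after_run and end, which let you run whatever you need before the Session is created, right after it is created, before each Session run, after each Session run, and just before the Session closes. (See the reference code in the appendix.)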
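
The call order described above can be modelled in a few lines of plain Python. This is a simplified sketch, not the real monitored session. RecordingHook and run_with_hooks are hypothetical names. The real after_create_session, before_run, after_run and end also receive arguments (session, coord, run_context, run_values), which are omitted here.

```python
class RecordingHook:
    """Hypothetical hook that records which lifecycle method ran."""

    def __init__(self, log):
        self.log = log

    def begin(self):
        self.log.append('begin')                 # before the session exists

    def after_create_session(self):
        self.log.append('after_create_session')  # session was just created

    def before_run(self):
        self.log.append('before_run')            # before each session.run

    def after_run(self):
        self.log.append('after_run')             # after each session.run

    def end(self):
        self.log.append('end')                   # session is about to close


def run_with_hooks(hooks, num_steps):
    """Drive the hooks in the order a monitored session calls them."""
    for h in hooks:
        h.begin()
    for h in hooks:
        h.after_create_session()
    for _ in range(num_steps):
        for h in hooks:
            h.before_run()
        # ... the real session.run(train_op) would happen here ...
        for h in hooks:
            h.after_run()
    for h in hooks:
        h.end()


log = []
run_with_hooks([RecordingHook(log)], num_steps=2)
print(log)
```

In real code you would subclass tf.compat.v1.train.SessionRunHook and override only the methods you need, since the base class provides no-op defaults for all of them.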

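steps: the number of steps to train the model. If None, the model trains forever, or until input_fn raises a tf.errors.OutOfRange error or a StopIteration exception. steps is incremental: if you call train(steps=10) twice, training runs for 20 steps in total. If OutOfRange or StopIteration occurs along the way, training stops before reaching 20 steps. If you don't want incremental training, set max_steps instead. If steps is set, max_steps must be None.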

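max_steps: the total number of training steps. If None, the model trains forever, or until input_fn raises a tf.errors.OutOfRange error or a StopIteration exception. If this parameter is set, steps must be None. If OutOfRange or StopIteration occurs during training, training stops before max_steps is reached. Two calls to train(steps=100) mean 200 training steps in total, whereas two calls to train(max_steps=100) train only 100 steps, because the first call already reached the maximum.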
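
The difference between incremental steps and absolute max_steps comes down to how the stopping step is computed. In TensorFlow this is handled by the stop-at-step hook that _convert_train_steps_to_hooks adds. The plain-Python sketch below only models the arithmetic. train_once is a hypothetical function, and an integer stands in for the global step stored in the checkpoint. The sketch assumes the arguments have already passed train's validation.

```python
import math


def train_once(global_step, steps=None, max_steps=None, batches_available=math.inf):
    """Return the global step after one hypothetical train() call.

    global_step:       step restored from the checkpoint when the call starts.
    steps:             train this many more steps (incremental).
    max_steps:         stop once the global step reaches this value (absolute).
    batches_available: batches input_fn can still deliver before
                       OutOfRange/StopIteration ends training early.
    """
    if steps is not None:
        last_step = global_step + steps   # relative to where this call starts
    elif max_steps is not None:
        last_step = max_steps             # fixed target shared by all calls
    else:
        last_step = math.inf              # train until the input runs out
    if last_step == math.inf and batches_available == math.inf:
        raise ValueError('This call would train forever.')
    # Stop at last_step or when the input is exhausted, whichever comes first;
    # never move backwards (an already-reached max_steps means no training).
    return max(global_step, min(last_step, global_step + batches_available))


step = train_once(0, steps=100)
step = train_once(step, steps=100)
print(step)  # 200: steps accumulates across calls

step = train_once(0, max_steps=100)
step = train_once(step, max_steps=100)
print(step)  # 100: the second call has nothing left to do

print(train_once(0, steps=20, batches_available=15))  # 15: input ran out first
```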

saving_listeners: a list of CheckpointSaverListener objects, used for callbacks that run immediately before or after a checkpoint is saved.
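
Listeners do nothing on their own. The _train_with_estimator_spec code later in this post attaches them to the first CheckpointSaverHook and raises an error if there is none. The plain-Python sketch below models that wiring. PrintListener, ToySaverHook and attach_listeners are hypothetical. The real CheckpointSaverListener methods (begin, before_save, after_save, end) also receive the session, and the real hook decides when to save using steps or seconds rather than a simple modulo.

```python
class PrintListener:
    """Hypothetical listener; the real base class is CheckpointSaverListener."""

    def __init__(self, log):
        self.log = log

    def before_save(self, global_step):
        self.log.append(('before_save', global_step))

    def after_save(self, global_step):
        self.log.append(('after_save', global_step))


class ToySaverHook:
    """Hypothetical stand-in for CheckpointSaverHook: saves every N steps."""

    def __init__(self, save_steps, log):
        self.save_steps = save_steps
        self.log = log
        self._listeners = []

    def after_run(self, global_step):
        if global_step % self.save_steps == 0:
            for listener in self._listeners:
                listener.before_save(global_step)
            self.log.append(('save', global_step))  # the checkpoint write
            for listener in self._listeners:
                listener.after_save(global_step)


def attach_listeners(saver_hooks, saving_listeners):
    """Same rule as _train_with_estimator_spec: use the first saver hook."""
    if saving_listeners and not saver_hooks:
        raise ValueError('There should be a CheckpointSaverHook to use saving_listeners.')
    for listener in saving_listeners or []:
        if listener not in saver_hooks[0]._listeners:  # skip duplicates
            saver_hooks[0]._listeners.append(listener)


log = []
hook = ToySaverHook(save_steps=2, log=log)
listener = PrintListener(log)
attach_listeners([hook], [listener, listener])  # the duplicate is ignored
for step in range(1, 5):
    hook.after_run(step)
print(log)
```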


def _train_model(self, input_fn, hooks, saving_listeners):
    if self._train_distribution:
      return self._train_model_distributed(input_fn, hooks, saving_listeners)
    else:
      return self._train_model_default(input_fn, hooks, saving_listeners)

def _train_model_default(self, input_fn, hooks, saving_listeners):
    """Initiate training with `input_fn`, without `DistributionStrategies`.

    Args:
      input_fn: A function that provides input data for training as minibatches.
      hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
        callbacks inside the training loop.
      saving_listeners: list of `tf.train.CheckpointSaverListener` objects. Used
        for callbacks that run immediately before or after checkpoint savings.

    Returns:
      Loss from training
    """
    worker_hooks = []
    with tf.Graph().as_default() as g, g.device(self._device_fn):
      tf.compat.v1.random.set_random_seed(self._config.tf_random_seed)
      global_step_tensor = self._create_and_assert_global_step(g)

      # Skip creating a read variable if _create_and_assert_global_step
      # returns None (e.g. tf.contrib.estimator.SavedModelEstimator).
      if global_step_tensor is not None:
        training_util._get_or_create_global_step_read(g)  # pylint: disable=protected-access
      # Essentially, this pulls the data out of input_fn
      features, labels, input_hooks = (
          self._get_features_and_labels_from_input_fn(input_fn, ModeKeys.TRAIN))
      worker_hooks.extend(input_hooks)
      # Call the user-defined model_fn, which defines the configuration and logic for each mode
      estimator_spec = self._call_model_fn(features, labels, ModeKeys.TRAIN,
                                           self.config)
      global_step_tensor = tf.compat.v1.train.get_global_step(g)
      # Pass the EstimatorSpec on; the function called here is where the Session is set up
      return self._train_with_estimator_spec(estimator_spec, worker_hooks,
                                             hooks, global_step_tensor,
                                             saving_listeners)
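
_call_model_fn hands features, labels and the mode to the user's model_fn, and the returned EstimatorSpec must carry different fields per mode: loss and train_op for TRAIN, loss for EVAL, predictions for PREDICT. The plain-Python sketch below imitates that contract. Spec, toy_model_fn and validate_spec are hypothetical names, not the tf.estimator.EstimatorSpec API. The mode strings match the values of tf.estimator.ModeKeys.

```python
from collections import namedtuple

# Hypothetical, simplified stand-in for tf.estimator.EstimatorSpec.
Spec = namedtuple('Spec', ['mode', 'loss', 'train_op', 'predictions'])

TRAIN, EVAL, PREDICT = 'train', 'eval', 'infer'  # ModeKeys string values


def toy_model_fn(features, labels, mode):
    """One model_fn serves every mode; the mode decides what gets built."""
    preds = [2 * x for x in features['x']]        # the "model": y = 2x
    if mode == PREDICT:
        return Spec(mode, None, None, preds)      # no labels needed
    loss = sum((p - y) ** 2 for p, y in zip(preds, labels)) / len(labels)
    if mode == EVAL:
        return Spec(mode, loss, None, preds)
    train_op = 'apply_gradients'                  # placeholder for a real op
    return Spec(mode, loss, train_op, preds)


def validate_spec(spec):
    """Check the required fields for each mode."""
    if spec.mode == TRAIN and (spec.loss is None or spec.train_op is None):
        raise ValueError('TRAIN mode needs loss and train_op')
    if spec.mode == EVAL and spec.loss is None:
        raise ValueError('EVAL mode needs loss')
    if spec.mode == PREDICT and spec.predictions is None:
        raise ValueError('PREDICT mode needs predictions')
    return spec


spec = validate_spec(toy_model_fn({'x': [1, 2]}, [2, 5], TRAIN))
print(spec.loss, spec.train_op)  # 0.5 apply_gradients
```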

Drilling further down, we finally reach the code that controls the Session:

def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks,
                                 global_step_tensor, saving_listeners):
    """Train a model with the given Estimator Spec."""
    if (self._warm_start_settings and
        not tf.train.latest_checkpoint(self._model_dir)):
      tf.compat.v1.logging.info('Warm-starting with WarmStartSettings: %s' %
                                (self._warm_start_settings,))
      tf.compat.v1.train.warm_start(*self._warm_start_settings)
    # Check if the user created a loss summary, and add one if they didn't.
    # We assume here that the summary is called 'loss'. If it is not, we will
    # make another one with the name 'loss' to ensure it shows up in the right
    # graph in TensorBoard.
    if not any([
        x.op.name == 'loss' for x in ops.get_collection(ops.GraphKeys.SUMMARIES)
    ]):
      summary.scalar('loss', estimator_spec.loss)
    ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
    worker_hooks.extend(hooks)
    worker_hooks.append(tf.compat.v1.train.NanTensorHook(estimator_spec.loss))
    if self._config.log_step_count_steps is not None:
      worker_hooks.append(
          tf.compat.v1.train.LoggingTensorHook(
              {
                  'loss': estimator_spec.loss,
                  'step': global_step_tensor
              },
              every_n_iter=self._config.log_step_count_steps))
    worker_hooks.extend(estimator_spec.training_hooks)

    if not (estimator_spec.scaffold.saver or
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SAVERS)):
      tf.compat.v1.add_to_collection(
          tf.compat.v1.GraphKeys.SAVERS,
          tf.compat.v1.train.Saver(
              sharded=True,
              max_to_keep=self._config.keep_checkpoint_max,
              keep_checkpoint_every_n_hours=(
                  self._config.keep_checkpoint_every_n_hours),
              defer_build=True,
              save_relative_paths=True))

    if (self._config.cluster_spec and type(
        self._train_distribution).__name__ in ('CollectiveAllReduceStrategy',
                                               'CollectiveAllReduceStrategyV1',
                                               'MultiWorkerMirroredStrategy')):
      return self._train_with_estimator_spec_distributed(
          estimator_spec, worker_hooks, saving_listeners)

    chief_hooks = []
    all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
    saver_hooks = [
        h for h in all_hooks
        if isinstance(h, tf.compat.v1.train.CheckpointSaverHook)
    ]
    if (self._config.save_checkpoints_secs or
        self._config.save_checkpoints_steps):
      if not saver_hooks:
        chief_hooks = [
            tf.compat.v1.train.CheckpointSaverHook(
                self._model_dir,
                save_secs=self._config.save_checkpoints_secs,
                save_steps=self._config.save_checkpoints_steps,
                scaffold=estimator_spec.scaffold,
                save_graph_def=self._config.checkpoint_save_graph_def)
        ]
        saver_hooks = [chief_hooks[0]]
    if saving_listeners:
      if not saver_hooks:
        raise ValueError(
            'There should be a CheckpointSaverHook to use saving_listeners. '
            'Please set one of the RunConfig.save_checkpoints_steps or '
            'RunConfig.save_checkpoints_secs.')
      else:
        # It is expected to have one CheckpointSaverHook. If multiple, we pick
        # up the first one to add listener.
        for listener in saving_listeners:
          # pylint: disable=protected-access
          if listener not in saver_hooks[0]._listeners:
            saver_hooks[0]._listeners.append(listener)
          # pylint: disable=protected-access

    # Add summary hooks to worker 0 if we are running with a master, to ensure
    # that summaries are written at correct intervals even with long-running
    # evaluations.
    save_summary_steps = self._config.save_summary_steps
    log_step_count_steps = self._config.log_step_count_steps

    # Check existence of appropriate cluster spec fields, as well as master and
    # worker nodes. As master also performs evaluation, summary writing must
    # occur on a different node. The presence of a worker is also checked to
    # prevent reassigning hooks for single-replica jobs with just a master node.
    if (self._config.cluster_spec and self._config.cluster_spec.jobs and
        (run_config.TaskType.WORKER in self._config.cluster_spec.jobs) and
        (run_config.TaskType.MASTER in self._config.cluster_spec.jobs)):
      # Update config values to prevent the default hooks from being created on
      # the master or other workers.
      save_summary_steps = 0
      log_step_count_steps = None

      if (self._config.task_type == run_config.TaskType.WORKER and
          self._config.task_id == 0):
        if (self._config.save_summary_steps and
            self._config.save_summary_steps > 0):
          worker_hooks.append(
              tf.compat.v1.train.SummarySaverHook(
                  save_steps=self._config.save_summary_steps,
                  output_dir=self._config.model_dir,
                  scaffold=estimator_spec.scaffold))

        if (self._config.log_step_count_steps and
            self._config.log_step_count_steps > 0):
          worker_hooks.append(
              tf.compat.v1.train.StepCounterHook(
                  every_n_steps=self._config.log_step_count_steps,
                  output_dir=self._config.model_dir))
    # Ha, so it is MonitoredTrainingSession that controls the Session after all
    with training.MonitoredTrainingSession(
        master=self._config.master,
        is_chief=self._config.is_chief,
        checkpoint_dir=self._model_dir,
        scaffold=estimator_spec.scaffold,
        hooks=worker_hooks,
        chief_only_hooks=(tuple(chief_hooks) +
                          tuple(estimator_spec.training_chief_hooks)),
        save_checkpoint_secs=0,  # Saving is handled by a hook.
        save_summaries_steps=save_summary_steps,
        config=self._session_config,
        max_wait_secs=self._config.session_creation_timeout_secs,
        log_step_count_steps=log_step_count_steps,
        save_graph_def=self._config.checkpoint_save_graph_def) as mon_sess:
      loss = None
      any_step_done = False
      while not mon_sess.should_stop():
        _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
        any_step_done = True
    if not any_step_done:
      tf.compat.v1.logging.warn('Training with estimator made no steps. '
                                'Perhaps input is empty or misspecified.')
    return loss
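
Stripped of the hook and checkpoint plumbing, the Session part of _train_with_estimator_spec is a with block plus a while not should_stop() loop. The plain-Python model below shows that control flow. ToyMonitoredSession and train_loop are hypothetical. Unlike the real MonitoredTrainingSession, the toy has no graph, checkpointing or recovery. It keeps one real detail: on exit, the real session treats OutOfRangeError and StopIteration as a normal end of training and suppresses them. This is why running out of input ends the loop without an error.

```python
import logging


class ToyMonitoredSession:
    """Hypothetical stand-in for MonitoredTrainingSession's control flow."""

    def __init__(self, batches, last_step):
        self._batches = iter(batches)
        self._last_step = last_step
        self.global_step = 0

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Returning True suppresses the exception: input exhaustion is normal.
        return exc_type is StopIteration

    def should_stop(self):
        return self.global_step >= self._last_step  # like a stop-at-step hook

    def run(self):
        batch = next(self._batches)  # raises StopIteration when input runs out
        self.global_step += 1
        return float(sum(batch))     # pretend "loss" for this step


def train_loop(batches, last_step):
    """Mirror the loop at the end of _train_with_estimator_spec."""
    loss = None
    any_step_done = False
    with ToyMonitoredSession(batches, last_step) as mon_sess:
        while not mon_sess.should_stop():
            loss = mon_sess.run()
            any_step_done = True
    if not any_step_done:
        logging.warning('Training with estimator made no steps. '
                        'Perhaps input is empty or misspecified.')
    return loss


print(train_loop([[1, 2], [3, 4], [5, 6]], last_step=2))  # 7.0: step limit hit
print(train_loop([[1, 2]], last_step=10))                 # 3.0: input ran out
print(train_loop([], last_step=10))                       # None, plus a warning
```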

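The cluster check near the end of that function decides which task writes summaries. When the cluster spec has both master and worker jobs, the default summary and step-count settings passed to the session are turned off. Only worker 0 then adds its own SummarySaverHook and StepCounterHook. Below is a minimal plain-Python sketch of that decision. summary_plan is a hypothetical helper that returns hook names instead of real hook objects. The job and task strings match the values of run_config.TaskType.

```python
def summary_plan(jobs, task_type, task_id, save_summary_steps, log_step_count_steps):
    """Return (session_kwargs, extra_worker_hooks) for one task (hypothetical).

    jobs: job names from the cluster spec, e.g. ['master', 'worker', 'ps'].
    """
    kwargs = {'save_summaries_steps': save_summary_steps,
              'log_step_count_steps': log_step_count_steps}
    extra_hooks = []
    if jobs and 'worker' in jobs and 'master' in jobs:
        # The master also evaluates, so switch the default summary hooks off...
        kwargs = {'save_summaries_steps': 0, 'log_step_count_steps': None}
        # ...and let worker 0 write summaries and step counts instead.
        if task_type == 'worker' and task_id == 0:
            if save_summary_steps and save_summary_steps > 0:
                extra_hooks.append('SummarySaverHook')
            if log_step_count_steps and log_step_count_steps > 0:
                extra_hooks.append('StepCounterHook')
    return kwargs, extra_hooks


print(summary_plan(['master', 'worker'], 'worker', 0, 100, 100))
print(summary_plan(['chief', 'worker'], 'worker', 0, 100, 100))
```
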
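So in the end we found where Estimator controls its Session. A few questions remain, such as how Estimator switches between modes. I'm still digging into that, and I suspect the answer is in train_and_evaluate.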

admin
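Copyright notice: original article on this site, published by admin on 2020-12-26, 11,036 characters in total.
Reprint notice: unless otherwise stated, articles on this site are published under the CC-4.0 license; please credit the source when reposting.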