tensorflow slim 源码分析

Python 的源码读起来还是很愉快的……(虽然我对 Python 的熟练程度完全不如 C++……)

datasets里是数据集相关

deployment是部署相关

nets里给了很多网络结构

preprocessing给了几种预处理的方式

这些都和slim没有太大关系,就不多废话了。

分析的部分见代码注释…

由于刚入门 machine learning 一周,还有很多内容没有从理论层面接触过,所以对源码的理解也十分有限,希望以后有机会再补充一波。

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Downloads and converts a particular dataset.

Usage:
```shell

$ python download_and_convert_data.py \
    --dataset_name=mnist \
    --dataset_dir=/tmp/mnist

$ python download_and_convert_data.py \
    --dataset_name=cifar10 \
    --dataset_dir=/tmp/cifar10

$ python download_and_convert_data.py \
    --dataset_name=flowers \
    --dataset_dir=/tmp/flowers
```
"""
from __future__ import absolute_import    #from __future__是为了解决python版本升级导致的兼容问题,没必要纠结
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from datasets import download_and_convert_cifar10
from datasets import download_and_convert_flowers
from datasets import download_and_convert_mnist

FLAGS = tf.app.flags.FLAGS    # Parsed command-line flags; each value is read as FLAGS.<name>.

tf.app.flags.DEFINE_string(   # Arguments are ('flag name', default value, 'help text').
    'dataset_name',
    None,
    'The name of the dataset to convert, one of "cifar10", "flowers", "mnist".')

# Both flags default to None; main() rejects a run that leaves either unset.
tf.app.flags.DEFINE_string(
    'dataset_dir',
    None,
    'The directory where the output TFRecords and temporary files are saved.')


def main(_):
  """Downloads the dataset named by --dataset_name and converts it.

  The supported names are "cifar10", "flowers" and "mnist". The converted
  output is written under --dataset_dir.

  Raises:
    ValueError: if a required flag is missing or the dataset name is unknown.
  """
  # Both flags are mandatory, so fail fast before doing any work.
  if not FLAGS.dataset_name:
    raise ValueError('You must supply the dataset name with --dataset_name')
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  # Each supported dataset name maps to the module that downloads and
  # converts it.
  converters = {
      'cifar10': download_and_convert_cifar10,
      'flowers': download_and_convert_flowers,
      'mnist': download_and_convert_mnist,
  }
  converter = converters.get(FLAGS.dataset_name)
  if converter is None:
    raise ValueError(
        'dataset_name [%s] was not recognized.' % FLAGS.dataset_name)
  converter.run(FLAGS.dataset_dir)

if __name__ == '__main__':  # Only run when executed as a script, not when this module is imported.
  # tf.app.run() is expected to parse the flags above and then call main().
  tf.app.run()




# coding=utf-8
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generic evaluation script that evaluates a model using a given dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import tensorflow as tf

from datasets import dataset_factory
from nets import nets_factory
from preprocessing import preprocessing_factory

slim = tf.contrib.slim

# Command-line flags for the evaluation script. Each DEFINE_* call registers
# one flag as (name, default, help text); parsed values are read through FLAGS.

tf.app.flags.DEFINE_integer(
    'batch_size', 100, 'The number of samples in each batch.')

tf.app.flags.DEFINE_integer(
    'max_num_batches', None,
    'Max number of batches to evaluate; by default use all.')

tf.app.flags.DEFINE_string(
    'master', '', 'The address of the TensorFlow master to use.')

tf.app.flags.DEFINE_string(
    'checkpoint_path', '/tmp/tfmodel/',
    'The directory where the model was written to or an absolute path to a '
    'checkpoint file.')

tf.app.flags.DEFINE_string(
    'eval_dir', '/tmp/tfmodel/', 'Directory where the results are saved to.')

tf.app.flags.DEFINE_integer(
    'num_preprocessing_threads', 4,
    'The number of threads used to create the batches.')

tf.app.flags.DEFINE_string(
    'dataset_name', 'imagenet', 'The name of the dataset to load.')

tf.app.flags.DEFINE_string(
    'dataset_split_name', 'test', 'The name of the train/test split.')

tf.app.flags.DEFINE_string(
    'dataset_dir', None, 'The directory where the dataset files are stored.')

tf.app.flags.DEFINE_integer(
    'labels_offset', 0,
    'An offset for the labels in the dataset. This flag is primarily used to '
    'evaluate the VGG and ResNet architectures which do not use a background '
    'class for the ImageNet dataset.')

tf.app.flags.DEFINE_string(
    'model_name', 'inception_v3', 'The name of the architecture to evaluate.')

tf.app.flags.DEFINE_string(
    'preprocessing_name', None, 'The name of the preprocessing to use. If left '
    'as `None`, then the model_name flag is used.')

# Adjacent string literals are concatenated, so the first piece must end with
# a space; otherwise the help text reads "...average.If left...".
tf.app.flags.DEFINE_float(
    'moving_average_decay', None,
    'The decay to use for the moving average. '
    'If left as None, then moving averages are not used.')

tf.app.flags.DEFINE_integer(
    'eval_image_size', None, 'Eval image size')

FLAGS = tf.app.flags.FLAGS


def main(_):
  """Evaluates one model checkpoint on a dataset split with streaming metrics.

  Builds the input pipeline, the network and the Accuracy / Recall_5 metrics
  in a fresh graph. It then runs slim.evaluation.evaluate_once against the
  checkpoint named by --checkpoint_path.

  Args:
    _: Unused (presumably the remaining argv handed over by tf.app.run).

  Raises:
    ValueError: if --dataset_dir is not set.
  """
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.logging.set_verbosity(tf.logging.INFO) # Log level; one of DEBUG, INFO, WARN, ERROR or FATAL.
  with tf.Graph().as_default():  #overrides the current default graph for the lifetime of the context
                                    # Note: per the tf.Graph docs, graph construction is not thread-safe.
    tf_global_step = slim.get_or_create_global_step()
                        # Like tf.train.get_or_create_global_step: returns the global step
                        # tensor of the given graph, or of the default graph when none is
                        # passed, creating it if it does not exist yet.
    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    ####################
    # Select the model #
    ####################
    # The class count is reduced by labels_offset to match the shifted labels below.
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        is_training=False)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
                                    # Read (image, label) examples; no shuffling for evaluation.
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        shuffle=False,
        common_queue_capacity=2 * FLAGS.batch_size,
        common_queue_min=FLAGS.batch_size)
    [image, label] = provider.get(['image', 'label'])
    label -= FLAGS.labels_offset

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=False)
    # Python `or` idiom (used above and below): `a or b or ... or n` returns
    # the first truthy operand from left to right (the value itself, not a
    # bool), or the last operand if all are falsy. An unset flag therefore
    # falls back to the model-derived default.
    eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size

    image = image_preprocessing_fn(image, eval_image_size, eval_image_size)

    images, labels = tf.train.batch(
        [image, label],
        batch_size=FLAGS.batch_size,
        num_threads=FLAGS.num_preprocessing_threads,
        capacity=5 * FLAGS.batch_size)

    ####################
    # Define the model #
    ####################
    logits, _ = network_fn(images)  # Tuple unpacking; the second return value is not used here.

    # When --moving_average_decay is set, restore the model variables from
    # their exponential-moving-average values (see
    # tf.train.ExponentialMovingAverage.variables_to_restore). The global step
    # is restored under its own name. Background:
    # https://en.wikipedia.org/wiki/Moving_average
    if FLAGS.moving_average_decay:
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, tf_global_step)
      variables_to_restore = variable_averages.variables_to_restore(
          slim.get_model_variables())
      variables_to_restore[tf_global_step.op.name] = tf_global_step
    else:
      variables_to_restore = slim.get_variables_to_restore()

    predictions = tf.argmax(logits, 1)
    # tf.argmax(input, axis) returns the index of the largest value along
    # `axis`. With axis=1 it gives, for each row of logits, the index of that
    # row's highest score, i.e. the predicted class per example. This assumes
    # logits is [batch, num_classes] -- TODO confirm for every network.


    labels = tf.squeeze(labels) # Remove the size-1 dimensions of labels.

    # Define the metrics:
    names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
        'Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
        'Recall_5': slim.metrics.streaming_recall_at_k(
            logits, labels, 5),
    })
    # aggregate_metric_map splits {name: metric} into {name: value} and
    # {name: update op}. Presumably each streaming metric is a (value,
    # update_op) pair. This keeps long metric lists compact.


    # Print the summaries to screen.
    for name, value in names_to_values.items():
      summary_name = 'eval/%s' % name
      op = tf.summary.scalar(summary_name, value, collections=[])
      # tf.summary.scalar outputs a Summary protocol buffer holding one scalar.
      # collections=[] keeps the raw op out of any collection; only the
      # tf.Print-wrapped op below is added to SUMMARIES.
      op = tf.Print(op, [value], summary_name)
      tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)

    # TODO(sguada) use num_epochs=1
    if FLAGS.max_num_batches:
      num_batches = FLAGS.max_num_batches
    else:
      # This ensures that we make a single pass over all of the data.
      num_batches = math.ceil(dataset.num_samples / float(FLAGS.batch_size))

    if tf.gfile.IsDirectory(FLAGS.checkpoint_path):# A directory: evaluate its latest checkpoint.
      checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
    else:
      checkpoint_path = FLAGS.checkpoint_path

    tf.logging.info('Evaluating %s' % checkpoint_path)  # Log which checkpoint is evaluated.

    # Evaluate the model at the given checkpoint path once, running the
    # metric update ops for num_batches evaluation steps.
    slim.evaluation.evaluate_once(
        master=FLAGS.master,
        checkpoint_path=checkpoint_path,
        logdir=FLAGS.eval_dir,
        num_evals=num_batches,
        eval_op=list(names_to_updates.values()),
        variables_to_restore=variables_to_restore)


if __name__ == '__main__':  # Only run when executed as a script, not on import.
  # tf.app.run() is expected to parse the flags above and then call main().
  tf.app.run()




# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Saves out a GraphDef containing the architecture of the model.

To use it, run something like this, with a model name defined by slim:

bazel build tensorflow_models/slim:export_inference_graph
bazel-bin/tensorflow_models/slim/export_inference_graph \
--model_name=inception_v3 --output_file=/tmp/inception_v3_inf_graph.pb

If you then want to use the resulting model with your own or pretrained
checkpoints as part of a mobile model, you can run freeze_graph to get a graph
def with the variables inlined as constants using:

bazel build tensorflow/python/tools:freeze_graph
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=/tmp/inception_v3_inf_graph.pb \
--input_checkpoint=/tmp/checkpoints/inception_v3.ckpt \
--input_binary=true --output_graph=/tmp/frozen_inception_v3.pb \
--output_node_names=InceptionV3/Predictions/Reshape_1

The output node names will vary depending on the model, but you can inspect and
estimate them using the summarize_graph tool:

bazel build tensorflow/tools/graph_transforms:summarize_graph
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \
--in_graph=/tmp/inception_v3_inf_graph.pb

To run the resulting graph in C++, you can look at the label_image sample code:

bazel build tensorflow/examples/label_image:label_image
bazel-bin/tensorflow/examples/label_image/label_image \
--image=${HOME}/Pictures/flowers.jpg \
--input_layer=input \
--output_layer=InceptionV3/Predictions/Reshape_1 \
--graph=/tmp/frozen_inception_v3.pb \
--labels=/tmp/imagenet_slim_labels.txt \
--input_mean=0 \
--input_std=255

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow.python.platform import gfile
from datasets import dataset_factory
from nets import nets_factory


slim = tf.contrib.slim

# Command-line flags for exporting the inference graph. Each DEFINE_* call
# registers (name, default, help text); parsed values are read through FLAGS.
tf.app.flags.DEFINE_string(
    'model_name', 'inception_v3', 'The name of the architecture to save.')

tf.app.flags.DEFINE_boolean(
    'is_training', False,
    'Whether to save out a training-focused version of the model.')

# Fallback input size, used only when the network function has no
# default_image_size of its own.
tf.app.flags.DEFINE_integer(
    'default_image_size', 224,
    'The image size to use if the model does not define it.')

tf.app.flags.DEFINE_string('dataset_name', 'imagenet',
                           'The name of the dataset to use with the model.')

tf.app.flags.DEFINE_integer(
    'labels_offset', 0,
    'An offset for the labels in the dataset. This flag is primarily used to '
    'evaluate the VGG and ResNet architectures which do not use a background '
    'class for the ImageNet dataset.')

# Empty default; main() refuses to run without an output path.
tf.app.flags.DEFINE_string(
    'output_file', '', 'Where to save the resulting file to.')

tf.app.flags.DEFINE_string(
    'dataset_dir', '', 'Directory to save intermediate dataset files to')

FLAGS = tf.app.flags.FLAGS


def main(_):
  """Writes the GraphDef of the --model_name network to --output_file.

  The network is built on a float32 placeholder named 'input' with shape
  [1, image_size, image_size, 3]. Only the graph structure is serialized.
  Variables can be inlined afterwards with freeze_graph, as described in the
  module docstring.

  Args:
    _: Unused (presumably the remaining argv handed over by tf.app.run).

  Raises:
    ValueError: if --output_file is not set.
  """
  if not FLAGS.output_file:
    raise ValueError('You must supply the path to save to with --output_file')
  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default() as graph:
    # In this function the dataset is only used to size the classifier (num_classes).
    dataset = dataset_factory.get_dataset(FLAGS.dataset_name, 'train',
                                          FLAGS.dataset_dir)
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        is_training=FLAGS.is_training)

    # hasattr(object, name) -> bool tells whether `object` has attribute `name`.
    # Prefer the network's own input size, else fall back to the flag.
    if hasattr(network_fn, 'default_image_size'):
      image_size = network_fn.default_image_size
    else:
      image_size = FLAGS.default_image_size
    # The batch dimension is fixed at 1 in the exported graph.
    placeholder = tf.placeholder(name='input', dtype=tf.float32,
                                 shape=[1, image_size, image_size, 3])
    network_fn(placeholder)
    graph_def = graph.as_graph_def()
    # graph.as_graph_def() returns a serialized GraphDef representation of this graph.
    # The GraphDef can be imported into another Graph (tf.import_graph_def) or used with the C++ Session API.
    # Per the TensorFlow docs, this method is thread-safe.


    # gfile.GFile is a file I/O wrapper without thread locking; 'wb' because the proto is binary.
    with gfile.GFile(FLAGS.output_file, 'wb') as f:
      f.write(graph_def.SerializeToString())


if __name__ == '__main__':  # Only run when executed as a script, not on import.
  # tf.app.run() is expected to parse the flags above and then call main().
  tf.app.run()
```python

# coding=utf-8
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generic training script that trains a model using a given dataset."""
 23    
 24    from __future__ import absolute_import
 25    from __future__ import division
 26    from __future__ import print_function
 27    
 28    import tensorflow as tf
 29    
 30    from datasets import dataset_factory
 31    from deployment import model_deploy
 32    from nets import nets_factory
 33    from preprocessing import preprocessing_factory
 34    
 35    slim = tf.contrib.slim
 36    
 37    tf.app.flags.DEFINE_string(
 38        'master', '', 'The address of the TensorFlow master to use.')
 39    
 40    tf.app.flags.DEFINE_string(
 41        'train_dir', '/tmp/tfmodel/',
 42        'Directory where checkpoints and event logs are written to.')
 43    
 44    tf.app.flags.DEFINE_integer('num_clones', 1,
 45                                'Number of model clones to deploy.')
 46    
 47    tf.app.flags.DEFINE_boolean('clone_on_cpu', False,
 48                                'Use CPUs to deploy clones.')
 49    
 50    tf.app.flags.DEFINE_integer('worker_replicas', 1, 'Number of worker replicas.')
 51    
 52    tf.app.flags.DEFINE_integer(
 53        'num_ps_tasks', 0,
 54        'The number of parameter servers. If the value is 0, then the parameters '
 55        'are handled locally by the worker.')
 56    
 57    tf.app.flags.DEFINE_integer(
 58        'num_readers', 4,
 59        'The number of parallel readers that read data from the dataset.')
 60    
 61    tf.app.flags.DEFINE_integer(
 62        'num_preprocessing_threads', 4,
 63        'The number of threads used to create the batches.')
 64    
 65    tf.app.flags.DEFINE_integer(
 66        'log_every_n_steps', 10,
 67        'The frequency with which logs are print.')
 68    
 69    tf.app.flags.DEFINE_integer(
 70        'save_summaries_secs', 600,
 71        'The frequency with which summaries are saved, in seconds.')
 72    
 73    tf.app.flags.DEFINE_integer(
 74        'save_interval_secs', 600,
 75        'The frequency with which the model is saved, in seconds.')
 76    
 77    tf.app.flags.DEFINE_integer(
 78        'task', 0, 'Task id of the replica running the training.')
 79    
 80    ######################
 81    # Optimization Flags #
 82    ######################
 83    
 84    tf.app.flags.DEFINE_float(
 85        'weight_decay', 0.00004, 'The weight decay on the model weights.')
 86    
 87    tf.app.flags.DEFINE_string(
 88        'optimizer', 'rmsprop',
 89        'The name of the optimizer, one of "adadelta", "adagrad", "adam",'
 90        '"ftrl", "momentum", "sgd" or "rmsprop".')
 91    
 92    tf.app.flags.DEFINE_float(
 93        'adadelta_rho', 0.95,
 94        'The decay rate for adadelta.')
 95    
 96    tf.app.flags.DEFINE_float(
 97        'adagrad_initial_accumulator_value', 0.1,
 98        'Starting value for the AdaGrad accumulators.')
 99    
100    tf.app.flags.DEFINE_float(
101        'adam_beta1', 0.9,
102        'The exponential decay rate for the 1st moment estimates.')
103    
104    tf.app.flags.DEFINE_float(
105        'adam_beta2', 0.999,
106        'The exponential decay rate for the 2nd moment estimates.')
107    
108    tf.app.flags.DEFINE_float('opt_epsilon', 1.0, 'Epsilon term for the optimizer.')
109    
110    tf.app.flags.DEFINE_float('ftrl_learning_rate_power', -0.5,
111                              'The learning rate power.')
112    
113    tf.app.flags.DEFINE_float(
114        'ftrl_initial_accumulator_value', 0.1,
115        'Starting value for the FTRL accumulators.')
116    
117    tf.app.flags.DEFINE_float(
118        'ftrl_l1', 0.0, 'The FTRL l1 regularization strength.')
119    
120    tf.app.flags.DEFINE_float(
121        'ftrl_l2', 0.0, 'The FTRL l2 regularization strength.')
122    
123    tf.app.flags.DEFINE_float(
124        'momentum', 0.9,
125        'The momentum for the MomentumOptimizer and RMSPropOptimizer.')
126    
127    tf.app.flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.')
128    
129    #######################
130    # Learning Rate Flags #
131    #######################
132    
133    tf.app.flags.DEFINE_string(
134        'learning_rate_decay_type',
135        'exponential',
136        'Specifies how the learning rate is decayed. One of "fixed", "exponential",'
137        ' or "polynomial"')
138    
139    tf.app.flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
140    
141    tf.app.flags.DEFINE_float(
142        'end_learning_rate', 0.0001,
143        'The minimal end learning rate used by a polynomial decay learning rate.')
144    
145    tf.app.flags.DEFINE_float(
146        'label_smoothing', 0.0, 'The amount of label smoothing.')
147    
148    tf.app.flags.DEFINE_float(
149        'learning_rate_decay_factor', 0.94, 'Learning rate decay factor.')
150    
151    tf.app.flags.DEFINE_float(
152        'num_epochs_per_decay', 2.0,
153        'Number of epochs after which learning rate decays.')
154    
155    tf.app.flags.DEFINE_bool(
156        'sync_replicas', False,
157        'Whether or not to synchronize the replicas during training.')
158    
159    tf.app.flags.DEFINE_integer(
160        'replicas_to_aggregate', 1,
161        'The Number of gradients to collect before updating params.')
162    
163    tf.app.flags.DEFINE_float(
164        'moving_average_decay', None,
165        'The decay to use for the moving average.'
166        'If left as None, then moving averages are not used.')
167    
168    #######################
169    # Dataset Flags #
170    #######################
171    
172    tf.app.flags.DEFINE_string(
173        'dataset_name', 'imagenet', 'The name of the dataset to load.')
174    
175    tf.app.flags.DEFINE_string(
176        'dataset_split_name', 'train', 'The name of the train/test split.')
177    
178    tf.app.flags.DEFINE_string(
179        'dataset_dir', None, 'The directory where the dataset files are stored.')
180    
181    tf.app.flags.DEFINE_integer(
182        'labels_offset', 0,
183        'An offset for the labels in the dataset. This flag is primarily used to '
184        'evaluate the VGG and ResNet architectures which do not use a background '
185        'class for the ImageNet dataset.')
186    
187    tf.app.flags.DEFINE_string(
188        'model_name', 'inception_v3', 'The name of the architecture to train.')
189    
190    tf.app.flags.DEFINE_string(
191        'preprocessing_name', None, 'The name of the preprocessing to use. If left '
192        'as `None`, then the model_name flag is used.')
193    
194    tf.app.flags.DEFINE_integer(
195        'batch_size', 32, 'The number of samples in each batch.')
196    
197    tf.app.flags.DEFINE_integer(
198        'train_image_size', None, 'Train image size')
199    
200    tf.app.flags.DEFINE_integer('max_number_of_steps', None,
201                                'The maximum number of training steps.')
202    
203    #####################
204    # Fine-Tuning Flags #
205    #####################
206    
207    tf.app.flags.DEFINE_string(
208        'checkpoint_path', None,
209        'The path to a checkpoint from which to fine-tune.')
210    
211    tf.app.flags.DEFINE_string(
212        'checkpoint_exclude_scopes', None,
213        'Comma-separated list of scopes of variables to exclude when restoring '
214        'from a checkpoint.')
215    
216    tf.app.flags.DEFINE_string(
217        'trainable_scopes', None,
218        'Comma-separated list of scopes to filter the set of variables to train.'
219        'By default, None would train all the variables.')
220    
221    tf.app.flags.DEFINE_boolean(
222        'ignore_missing_vars', False,
223        'When restoring a checkpoint would ignore missing variables.')
224    
225    FLAGS = tf.app.flags.FLAGS
226    
227    
def _configure_learning_rate(num_samples_per_epoch, global_step):
  """Configures the learning rate.

  --learning_rate_decay_type selects the schedule:
    * "exponential": multiply --learning_rate by --learning_rate_decay_factor
      once every `decay_steps` global steps (staircase).
    * "fixed": keep --learning_rate constant.
    * "polynomial": decay linearly (power=1.0) from --learning_rate to
      --end_learning_rate over `decay_steps` global steps.

  Args:
    num_samples_per_epoch: The number of samples in each epoch of training.
    global_step: The global_step tensor.

  Returns:
    A `Tensor` representing the learning rate.

  Raises:
    ValueError: if FLAGS.learning_rate_decay_type is not recognized.
  """
  # Global steps between decays = (steps per epoch) * (epochs per decay).
  decay_steps = int(num_samples_per_epoch / FLAGS.batch_size *
                    FLAGS.num_epochs_per_decay)
  if FLAGS.sync_replicas:
    # With synchronous replicas one update collects --replicas_to_aggregate
    # gradients, so fewer global steps make up the same amount of data.
    # (True division is in effect, so decay_steps may become a float here.)
    decay_steps /= FLAGS.replicas_to_aggregate

  if FLAGS.learning_rate_decay_type == 'exponential':
    return tf.train.exponential_decay(FLAGS.learning_rate,
                                      global_step,
                                      decay_steps,
                                      FLAGS.learning_rate_decay_factor,
                                      staircase=True,
                                      name='exponential_decay_learning_rate')
  elif FLAGS.learning_rate_decay_type == 'fixed':
    return tf.constant(FLAGS.learning_rate, name='fixed_learning_rate')
  elif FLAGS.learning_rate_decay_type == 'polynomial':
    return tf.train.polynomial_decay(FLAGS.learning_rate,
                                     global_step,
                                     decay_steps,
                                     FLAGS.end_learning_rate,
                                     power=1.0,
                                     cycle=False,
                                     name='polynomial_decay_learning_rate')
  else:
    # Format the message here. ValueError does not %-format extra arguments,
    # so passing the value separately would leave a literal "[%s]".
    raise ValueError('learning_rate_decay_type [%s] was not recognized' %
                     FLAGS.learning_rate_decay_type)
269    
270    
271    
# Optimizer selection: maps --optimizer to one of TensorFlow's built-in
# gradient-based optimizers, configured from the optimization flags.
def _configure_optimizer(learning_rate):
  """Configures the optimizer used for training.

  Args:
    learning_rate: A scalar or `Tensor` learning rate.

  Returns:
    An instance of an optimizer.

  Raises:
    ValueError: if FLAGS.optimizer is not recognized.
  """
  if FLAGS.optimizer == 'adadelta':
    optimizer = tf.train.AdadeltaOptimizer(
        learning_rate,
        rho=FLAGS.adadelta_rho,
        epsilon=FLAGS.opt_epsilon)
  elif FLAGS.optimizer == 'adagrad':
    optimizer = tf.train.AdagradOptimizer(
        learning_rate,
        initial_accumulator_value=FLAGS.adagrad_initial_accumulator_value)
  elif FLAGS.optimizer == 'adam':
    optimizer = tf.train.AdamOptimizer(
        learning_rate,
        beta1=FLAGS.adam_beta1,
        beta2=FLAGS.adam_beta2,
        epsilon=FLAGS.opt_epsilon)
  elif FLAGS.optimizer == 'ftrl':
    optimizer = tf.train.FtrlOptimizer(
        learning_rate,
        learning_rate_power=FLAGS.ftrl_learning_rate_power,
        initial_accumulator_value=FLAGS.ftrl_initial_accumulator_value,
        l1_regularization_strength=FLAGS.ftrl_l1,
        l2_regularization_strength=FLAGS.ftrl_l2)
  elif FLAGS.optimizer == 'momentum':
    optimizer = tf.train.MomentumOptimizer(
        learning_rate,
        momentum=FLAGS.momentum,
        name='Momentum')
  elif FLAGS.optimizer == 'rmsprop':
    # --momentum is shared with the 'momentum' optimizer above.
    optimizer = tf.train.RMSPropOptimizer(
        learning_rate,
        decay=FLAGS.rmsprop_decay,
        momentum=FLAGS.momentum,
        epsilon=FLAGS.opt_epsilon)
  elif FLAGS.optimizer == 'sgd':
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  else:
    # Format the message here. ValueError does not %-format extra arguments.
    raise ValueError('Optimizer [%s] was not recognized' % FLAGS.optimizer)
  return optimizer
323    
324    
def _get_init_fn():
  """Returns a function run by the chief worker to warm-start the training.

  The init function restores the model variables from --checkpoint_path,
  skipping any variable whose op name starts with one of the comma-separated
  --checkpoint_exclude_scopes. Note that the init_fn is only run when
  initializing the model during the very first global step.

  Returns:
    An init function run by the supervisor, or None when --checkpoint_path is
    unset or train_dir already contains a checkpoint.
  """
  if FLAGS.checkpoint_path is None:
    return None

  # A checkpoint already in train_dir means --checkpoint_path would be ignored
  # anyway, so just tell the user and skip warm-starting.
  if tf.train.latest_checkpoint(FLAGS.train_dir):
    tf.logging.info(
        'Ignoring --checkpoint_path because a checkpoint already exists in %s'
        % FLAGS.train_dir)
    return None

  excluded_prefixes = []
  if FLAGS.checkpoint_exclude_scopes:
    excluded_prefixes = [
        prefix.strip()
        for prefix in FLAGS.checkpoint_exclude_scopes.split(',')]

  # TODO(sguada) variables.filter_variables()
  variables_to_restore = [
      var for var in slim.get_model_variables()
      if not any(var.op.name.startswith(prefix)
                 for prefix in excluded_prefixes)]

  # A directory means "use the newest checkpoint inside it".
  checkpoint_path = FLAGS.checkpoint_path
  if tf.gfile.IsDirectory(checkpoint_path):
    checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)

  tf.logging.info('Fine-tuning from %s' % checkpoint_path)

  return slim.assign_from_checkpoint_fn(
      checkpoint_path,
      variables_to_restore,
      ignore_missing_vars=FLAGS.ignore_missing_vars)
372    
373    
def _get_variables_to_train():
  """Returns a list of variables to train.

  Without --trainable_scopes every trainable variable is returned. Otherwise
  the result is the trainable variables of each comma-separated scope,
  concatenated in the order the scopes are listed.

  Returns:
    A list of variables to train by the optimizer.
  """
  if FLAGS.trainable_scopes is None:
    return tf.trainable_variables()

  variables_to_train = []
  for raw_scope in FLAGS.trainable_scopes.split(','):
    variables_to_train.extend(
        tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, raw_scope.strip()))
  return variables_to_train
390    
391    
def main(_):
  """Builds the (optionally replicated) training graph and runs training.

  Every setting is read from FLAGS. The graph is built from these pieces:
  the deployment config, the dataset, the network, the preprocessing
  function, an input pipeline ending in a prefetch queue, per-clone losses,
  summaries, optional moving averages, and the optimizer. The resulting
  train op is then handed to slim.learning.train.

  Raises:
    ValueError: if --dataset_dir is not supplied.
  """
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    #######################
    # Config model_deploy #
    #######################
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.worker_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks)

    # Create global_step
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)
    # Labels are shifted down by --labels_offset below, so the network and the
    # one-hot encoding both use this reduced class count.
    num_classes = dataset.num_classes - FLAGS.labels_offset

    ######################
    # Select the network #
    ######################
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=num_classes,
        weight_decay=FLAGS.weight_decay,
        is_training=True)

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=True)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    with tf.device(deploy_config.inputs_device()):
      # Reads (image, label) examples from the dataset using
      # --num_readers parallel readers and a common queue sized relative to
      # the batch size.
      provider = slim.dataset_data_provider.DatasetDataProvider(
          dataset,
          num_readers=FLAGS.num_readers,
          common_queue_capacity=20 * FLAGS.batch_size,
          common_queue_min=10 * FLAGS.batch_size)
      [image, label] = provider.get(['image', 'label'])
      label -= FLAGS.labels_offset

      train_image_size = FLAGS.train_image_size or network_fn.default_image_size

      image = image_preprocessing_fn(image, train_image_size, train_image_size)

      images, labels = tf.train.batch(
          [image, label],
          batch_size=FLAGS.batch_size,
          num_threads=FLAGS.num_preprocessing_threads,
          capacity=5 * FLAGS.batch_size)
      # One-hot encoding: each label becomes a num_classes-long vector that is
      # 1 at the label's index and 0 elsewhere.
      labels = slim.one_hot_encoding(labels, num_classes)
      batch_queue = slim.prefetch_queue.prefetch_queue(
          [images, labels], capacity=2 * deploy_config.num_clones)

    ####################
    # Define the model #
    ####################
    # Data parallelism: model_deploy.create_clones calls clone_fn once per
    # clone, and each call builds its own copy of the network.
    def clone_fn(batch_queue):
      """Allows data parallelism by creating multiple clones of network_fn."""
      with tf.device(deploy_config.inputs_device()):
        # Take (dequeue) one prefetched batch off the queue. Note that this
        # is "remove from the queue", not a double-ended queue (deque).
        images, labels = batch_queue.dequeue()
      logits, end_points = network_fn(images)

      #############################
      # Specify the loss function #
      #############################
      # Softmax cross-entropy losses. tf.losses registers them in the LOSSES
      # collection, which is read back below for the loss summaries. A
      # non-zero label_smoothing softens the one-hot targets (label-smoothing
      # regularization).
      if 'AuxLogits' in end_points:
        # The auxiliary head is down-weighted relative to the main logits.
        tf.losses.softmax_cross_entropy(
            logits=end_points['AuxLogits'], onehot_labels=labels,
            label_smoothing=FLAGS.label_smoothing, weights=0.4, scope='aux_loss')
      tf.losses.softmax_cross_entropy(
          logits=logits, onehot_labels=labels,
          label_smoothing=FLAGS.label_smoothing, weights=1.0)
      return end_points

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Add summaries for end_points.
    end_points = clones[0].outputs
    for end_point in end_points:
      x = end_points[end_point]
      summaries.add(tf.summary.histogram('activations/' + end_point, x))
      summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                      tf.nn.zero_fraction(x)))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Add a histogram summary for every model variable.
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))

    #################################
    # Configure the moving averages #
    #################################
    # Exponential moving averages of the model variables, enabled only when
    # --moving_average_decay is set.
    if FLAGS.moving_average_decay:
      moving_average_variables = slim.get_model_variables()
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, global_step)
    else:
      moving_average_variables, variable_averages = None, None

    #########################################
    # Configure the optimization procedure. #
    #########################################
    with tf.device(deploy_config.optimizer_device()):
      learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
      optimizer = _configure_optimizer(learning_rate)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    # Synchronous training across distributed worker replicas.
    if FLAGS.sync_replicas:
      # If sync_replicas is enabled, the averaging will be done in the chief
      # queue runner.
      # TF 1.x's SyncReplicasOptimizer no longer accepts a `replica_id`
      # argument (passing it raises TypeError), so it is not passed here.
      optimizer = tf.train.SyncReplicasOptimizer(
          opt=optimizer,
          replicas_to_aggregate=FLAGS.replicas_to_aggregate,
          total_num_replicas=FLAGS.worker_replicas,
          variable_averages=variable_averages,
          variables_to_average=moving_average_variables)
    elif FLAGS.moving_average_decay:
      # Update ops executed locally by trainer.
      update_ops.append(variable_averages.apply(moving_average_variables))

    # Variables to train.
    variables_to_train = _get_variables_to_train()

    # Computes the total loss and the gradients summed over all clones.
    total_loss, clones_gradients = model_deploy.optimize_clones(
        clones,
        optimizer,
        var_list=variables_to_train)
    # Add total_loss to summary.
    summaries.add(tf.summary.scalar('total_loss', total_loss))

    # Create gradient updates.
    grad_updates = optimizer.apply_gradients(clones_gradients,
                                             global_step=global_step)
    update_ops.append(grad_updates)

    # train_tensor evaluates to total_loss, but only after every update op
    # (gradient step, batch-norm / moving-average updates) has run.
    update_op = tf.group(*update_ops)
    with tf.control_dependencies([update_op]):
      train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    ###########################
    # Kicks off the training. #
    ###########################
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_dir,
        master=FLAGS.master,
        is_chief=(FLAGS.task == 0),
        init_fn=_get_init_fn(),
        summary_op=summary_op,
        number_of_steps=FLAGS.max_number_of_steps,
        log_every_n_steps=FLAGS.log_every_n_steps,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs,
        sync_optimizer=optimizer if FLAGS.sync_replicas else None)
590    
if __name__ == '__main__':
  # Script entry point: tf.app.run parses the command-line flags and then
  # calls main(argv). Naming main explicitly is equivalent to the default
  # lookup of __main__.main.
  tf.app.run(main=main)