TensorFlow-Slim 源码分析

读 Python 源码还是很愉快的……(虽然我对 Python 的熟练程度完全不如 C++……)

datasets里是数据集相关

deployment是部署相关

nets里给了很多网络结构

preprocessing给了几种预处理的方式

这些都和slim没有太大关系,就不多废话了。

分析的部分见代码注释...

由于入门机器学习才一周……很多内容还没有从理论层面接触……所以对源码的理解也十分有限……希望以后有机会再补充。

  1    # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2    #
  3    # Licensed under the Apache License, Version 2.0 (the "License");
  4    # you may not use this file except in compliance with the License.
  5    # You may obtain a copy of the License at
  6    #
  7    #     http://www.apache.org/licenses/LICENSE-2.0
  8    #
  9    # Unless required by applicable law or agreed to in writing, software
 10    # distributed under the License is distributed on an "AS IS" BASIS,
 11    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12    # See the License for the specific language governing permissions and
 13    # limitations under the License.
 14    # ==============================================================================
 15    r"""Downloads and converts a particular dataset.
 16    
 17    Usage:
 18    ```shell
 19    
 20    $ python download_and_convert_data.py \
 21        --dataset_name=mnist \
 22        --dataset_dir=/tmp/mnist
 23    
 24    $ python download_and_convert_data.py \
 25        --dataset_name=cifar10 \
 26        --dataset_dir=/tmp/cifar10
 27    
 28    $ python download_and_convert_data.py \
 29        --dataset_name=flowers \
 30        --dataset_dir=/tmp/flowers
 31    ```
 32    """
 33    from __future__ import absolute_import    #from __future__是为了解决python版本升级导致的兼容问题,没必要纠结
 34    from __future__ import division
 35    from __future__ import print_function
 36    
 37    import tensorflow as tf
 38    
 39    from datasets import download_and_convert_cifar10
 40    from datasets import download_and_convert_flowers
 41    from datasets import download_and_convert_mnist
 42    
 43    FLAGS = tf.app.flags.FLAGS    #FLAGS 用来传递或者设置tensforflow的参数
 44    
 45    tf.app.flags.DEFINE_string(   #设置的格式为:('参数名称',参数值,'参数的解释')
 46        'dataset_name',
 47        None,
 48        'The name of the dataset to convert, one of "cifar10", "flowers", "mnist".')
 49    
 50    tf.app.flags.DEFINE_string(
 51        'dataset_dir',
 52        None,
 53        'The directory where the output TFRecords and temporary files are saved.')
 54    
 55    
 56    def main(_):
 57      if not FLAGS.dataset_name:
 58        raise ValueError('You must supply the dataset name with --dataset_name')
 59      if not FLAGS.dataset_dir:
 60        raise ValueError('You must supply the dataset directory with --dataset_dir')
 61    
 62      if FLAGS.dataset_name == 'cifar10':                #提供的三个数据集,[cifar10],[flowers],[mnist]
 63        download_and_convert_cifar10.run(FLAGS.dataset_dir)
 64      elif FLAGS.dataset_name == 'flowers':
 65        download_and_convert_flowers.run(FLAGS.dataset_dir)
 66      elif FLAGS.dataset_name == 'mnist':
 67        download_and_convert_mnist.run(FLAGS.dataset_dir)
 68      else:
 69        raise ValueError(
 70            'dataset_name [%s] was not recognized.' % FLAGS.dataset_name)  #数据经名字不属于上述三个
 71    
 72    if __name__ == '__main__':  #这种写法可以保证在该文件被import的时候不会执行main函数
 73      tf.app.run()
 74
 75
 76
 77    
 78    # coding=utf-8
 79    # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 80    #
 81    # Licensed under the Apache License, Version 2.0 (the "License");
 82    # you may not use this file except in compliance with the License.
 83    # You may obtain a copy of the License at
 84    #
 85    # http://www.apache.org/licenses/LICENSE-2.0
 86    #
 87    # Unless required by applicable law or agreed to in writing, software
 88    # distributed under the License is distributed on an "AS IS" BASIS,
 89    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 90    # See the License for the specific language governing permissions and
 91    # limitations under the License.
 92    # ==============================================================================
 93    """Generic evaluation script that evaluates a model using a given dataset."""
 94    
 95    from __future__ import absolute_import
 96    from __future__ import division
 97    from __future__ import print_function
 98    
 99    import math
100    import tensorflow as tf
101    
102    from datasets import dataset_factory
103    from nets import nets_factory
104    from preprocessing import preprocessing_factory
105    
slim = tf.contrib.slim

# Command-line flags of the evaluation script; after parsing, values are read
# as attributes of FLAGS. Each DEFINE_* call is (name, default, help text).

tf.app.flags.DEFINE_integer(
    'batch_size', 100, 'The number of samples in each batch.')

tf.app.flags.DEFINE_integer(
    'max_num_batches', None,
    'Max number of batches to evaluate; by default use all.')

tf.app.flags.DEFINE_string(
    'master', '', 'The address of the TensorFlow master to use.')

tf.app.flags.DEFINE_string(
    'checkpoint_path', '/tmp/tfmodel/',
    'The directory where the model was written to or an absolute path to a '
    'checkpoint file.')

tf.app.flags.DEFINE_string(
    'eval_dir', '/tmp/tfmodel/', 'Directory where the results are saved to.')

tf.app.flags.DEFINE_integer(
    'num_preprocessing_threads', 4,
    'The number of threads used to create the batches.')

tf.app.flags.DEFINE_string(
    'dataset_name', 'imagenet', 'The name of the dataset to load.')

tf.app.flags.DEFINE_string(
    'dataset_split_name', 'test', 'The name of the train/test split.')

tf.app.flags.DEFINE_string(
    'dataset_dir', None, 'The directory where the dataset files are stored.')

tf.app.flags.DEFINE_integer(
    'labels_offset', 0,
    'An offset for the labels in the dataset. This flag is primarily used to '
    'evaluate the VGG and ResNet architectures which do not use a background '
    'class for the ImageNet dataset.')

tf.app.flags.DEFINE_string(
    'model_name', 'inception_v3', 'The name of the architecture to evaluate.')

tf.app.flags.DEFINE_string(
    'preprocessing_name', None, 'The name of the preprocessing to use. If left '
    'as `None`, then the model_name flag is used.')

# Adjacent string literals are concatenated, so the first one must end with a
# space to keep the two sentences apart in --help output.
tf.app.flags.DEFINE_float(
    'moving_average_decay', None,
    'The decay to use for the moving average. '
    'If left as None, then moving averages are not used.')

tf.app.flags.DEFINE_integer(
    'eval_image_size', None, 'Eval image size')

FLAGS = tf.app.flags.FLAGS
161    
162    
def main(_):
  """Evaluates one checkpoint of a classification model on a dataset split.

  Builds an input pipeline for --dataset_name / --dataset_split_name, runs the
  network named by --model_name over it, and evaluates streaming accuracy and
  recall@5 once against the checkpoint selected by --checkpoint_path.

  Args:
    _: Unused argv forwarded by tf.app.run().

  Raises:
    ValueError: if --dataset_dir is not supplied, or if --checkpoint_path is a
      directory that contains no checkpoint.
  """
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  # Emit log messages of level INFO and above.
  tf.logging.set_verbosity(tf.logging.INFO)
  # Build the evaluation graph in a fresh graph that is the default graph for
  # the duration of this block.
  with tf.Graph().as_default():
    # Global step tensor of this graph (created if it does not exist yet); it
    # is restored together with the moving averages below.
    tf_global_step = slim.get_or_create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    ####################
    # Select the model #
    ####################
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        is_training=False)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    # Evaluation reads the split in order (shuffle=False).
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        shuffle=False,
        common_queue_capacity=2 * FLAGS.batch_size,
        common_queue_min=FLAGS.batch_size)
    [image, label] = provider.get(['image', 'label'])
    # Shift labels to match the (num_classes - labels_offset) network outputs.
    label -= FLAGS.labels_offset

    #####################################
    # Select the preprocessing function #
    #####################################
    # `a or b` evaluates to `a` when it is truthy and to `b` otherwise, so an
    # unset flag falls back to the value on its right.
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=False)

    eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size

    image = image_preprocessing_fn(image, eval_image_size, eval_image_size)

    images, labels = tf.train.batch(
        [image, label],
        batch_size=FLAGS.batch_size,
        num_threads=FLAGS.num_preprocessing_threads,
        capacity=5 * FLAGS.batch_size)

    ####################
    # Define the model #
    ####################
    # network_fn returns a pair; only the first element (the logits) is used.
    logits, _ = network_fn(images)

    if FLAGS.moving_average_decay:
      # Restore the moving-average values of the model variables (via
      # ExponentialMovingAverage.variables_to_restore) plus the global step.
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, tf_global_step)
      variables_to_restore = variable_averages.variables_to_restore(
          slim.get_model_variables())
      variables_to_restore[tf_global_step.op.name] = tf_global_step
    else:
      variables_to_restore = slim.get_variables_to_restore()

    # Predicted class: index of the largest logit along axis 1 (presumably the
    # class axis of a [batch, num_classes] tensor -- confirm per network).
    predictions = tf.argmax(logits, 1)
    # Drop every size-1 dimension of labels.
    labels = tf.squeeze(labels)

    # Define the metrics:
    # aggregate_metric_map splits {name: (value, update_op)} into one dict of
    # metric values and one dict of update ops.
    names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
        'Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
        'Recall_5': slim.metrics.streaming_recall_at_k(
            logits, labels, 5),
    })

    # Print the summaries to screen.
    for name, value in names_to_values.items():
      summary_name = 'eval/%s' % name
      # collections=[] keeps the bare summary out of any collection; the
      # tf.Print-wrapped op is registered instead so the value is also printed
      # whenever the summary is computed.
      op = tf.summary.scalar(summary_name, value, collections=[])
      op = tf.Print(op, [value], summary_name)
      tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)

    # TODO(sguada) use num_epochs=1
    if FLAGS.max_num_batches:
      num_batches = FLAGS.max_num_batches
    else:
      # This ensures that we make a single pass over all of the data.
      # int() because math.ceil returns a float on Python 2.
      num_batches = int(
          math.ceil(dataset.num_samples / float(FLAGS.batch_size)))

    if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
      # A directory: evaluate the most recent checkpoint recorded in it.
      checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
      if checkpoint_path is None:
        # latest_checkpoint returns None when nothing has been saved yet;
        # fail here with a clear message instead of inside evaluate_once.
        raise ValueError(
            'No checkpoint found in directory %s' % FLAGS.checkpoint_path)
    else:
      checkpoint_path = FLAGS.checkpoint_path

    tf.logging.info('Evaluating %s', checkpoint_path)

    # Evaluate the model once at the given checkpoint, running the metric
    # update ops num_batches times.
    slim.evaluation.evaluate_once(
        master=FLAGS.master,
        checkpoint_path=checkpoint_path,
        logdir=FLAGS.eval_dir,
        num_evals=num_batches,
        eval_op=list(names_to_updates.values()),
        variables_to_restore=variables_to_restore)


if __name__ == '__main__':
  tf.app.run()
287
288
289
290    
291    # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
292    #
293    # Licensed under the Apache License, Version 2.0 (the "License");
294    # you may not use this file except in compliance with the License.
295    # You may obtain a copy of the License at
296    #
297    # http://www.apache.org/licenses/LICENSE-2.0
298    #
299    # Unless required by applicable law or agreed to in writing, software
300    # distributed under the License is distributed on an "AS IS" BASIS,
301    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
302    # See the License for the specific language governing permissions and
303    # limitations under the License.
304    # ==============================================================================
305    r"""Saves out a GraphDef containing the architecture of the model.
306    
307    To use it, run something like this, with a model name defined by slim:
308    
309    bazel build tensorflow_models/slim:export_inference_graph
310    bazel-bin/tensorflow_models/slim/export_inference_graph \
311    --model_name=inception_v3 --output_file=/tmp/inception_v3_inf_graph.pb
312    
313    If you then want to use the resulting model with your own or pretrained
314    checkpoints as part of a mobile model, you can run freeze_graph to get a graph
315    def with the variables inlined as constants using:
316    
317    bazel build tensorflow/python/tools:freeze_graph
318    bazel-bin/tensorflow/python/tools/freeze_graph \
319    --input_graph=/tmp/inception_v3_inf_graph.pb \
320    --input_checkpoint=/tmp/checkpoints/inception_v3.ckpt \
321    --input_binary=true --output_graph=/tmp/frozen_inception_v3.pb \
322    --output_node_names=InceptionV3/Predictions/Reshape_1
323    
324    The output node names will vary depending on the model, but you can inspect and
325    estimate them using the summarize_graph tool:
326    
327    bazel build tensorflow/tools/graph_transforms:summarize_graph
328    bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \
329    --in_graph=/tmp/inception_v3_inf_graph.pb
330    
331    To run the resulting graph in C++, you can look at the label_image sample code:
332    
333    bazel build tensorflow/examples/label_image:label_image
334    bazel-bin/tensorflow/examples/label_image/label_image \
335    --image=${HOME}/Pictures/flowers.jpg \
336    --input_layer=input \
337    --output_layer=InceptionV3/Predictions/Reshape_1 \
338    --graph=/tmp/frozen_inception_v3.pb \
339    --labels=/tmp/imagenet_slim_labels.txt \
340    --input_mean=0 \
341    --input_std=255
342    
343    """
344    
345    from __future__ import absolute_import
346    from __future__ import division
347    from __future__ import print_function
348    
349    import tensorflow as tf
350    
351    from tensorflow.python.platform import gfile
352    from datasets import dataset_factory
353    from nets import nets_factory
354    
355    
slim = tf.contrib.slim

# Parsed flag values are read as attributes of FLAGS.
FLAGS = tf.app.flags.FLAGS

# Which network to build, and in which mode.
tf.app.flags.DEFINE_string(
    'model_name', 'inception_v3',
    'The name of the architecture to save.')
tf.app.flags.DEFINE_boolean(
    'is_training', False,
    'Whether to save out a training-focused version of the model.')
tf.app.flags.DEFINE_integer(
    'default_image_size', 224,
    'The image size to use if the model does not define it.')

# The dataset whose class count sizes the network's output layer.
tf.app.flags.DEFINE_string(
    'dataset_name', 'imagenet',
    'The name of the dataset to use with the model.')
tf.app.flags.DEFINE_integer(
    'labels_offset', 0,
    'An offset for the labels in the dataset. This flag is primarily used to '
    'evaluate the VGG and ResNet architectures which do not use a background '
    'class for the ImageNet dataset.')

# Where the GraphDef and intermediate dataset files go.
tf.app.flags.DEFINE_string(
    'output_file', '',
    'Where to save the resulting file to.')
tf.app.flags.DEFINE_string(
    'dataset_dir', '',
    'Directory to save intermediate dataset files to')
385    
386    
def main(_):
  """Writes the GraphDef of the network named by --model_name.

  The graph feeds a float32 placeholder named 'input' of shape
  [1, image_size, image_size, 3] into the network, and is written in binary
  form to --output_file.

  Args:
    _: Unused argv forwarded by tf.app.run().

  Raises:
    ValueError: if --output_file is empty.
  """
  if not FLAGS.output_file:
    raise ValueError('You must supply the path to save to with --output_file')
  tf.logging.set_verbosity(tf.logging.INFO)

  with tf.Graph().as_default() as graph:
    # The dataset is only consulted for its class count.
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, 'train', FLAGS.dataset_dir)
    num_classes = dataset.num_classes - FLAGS.labels_offset
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=num_classes,
        is_training=FLAGS.is_training)

    # Prefer the input size the network declares; fall back to the flag.
    image_size = getattr(network_fn, 'default_image_size',
                         FLAGS.default_image_size)
    input_tensor = tf.placeholder(
        name='input', dtype=tf.float32,
        shape=[1, image_size, image_size, 3])
    network_fn(input_tensor)

    # Serializable GraphDef representation of everything built above.
    graph_def = graph.as_graph_def()

    # Write the GraphDef as binary protobuf bytes.
    with gfile.GFile(FLAGS.output_file, 'wb') as output:
      output.write(graph_def.SerializeToString())


if __name__ == "__main__":
  tf.app.run()
420
  1
  2
  3    
  4    # coding=utf-8
  5    # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  6    #
  7    # Licensed under the Apache License, Version 2.0 (the "License");
  8    # you may not use this file except in compliance with the License.
  9    # You may obtain a copy of the License at
 10    #
 11    # http://www.apache.org/licenses/LICENSE-2.0
 12    #
 13    # Unless required by applicable law or agreed to in writing, software
 14    # distributed under the License is distributed on an "AS IS" BASIS,
 15    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 16    # See the License for the specific language governing permissions and
 17    # limitations under the License.
 18    # ==============================================================================
 19    """Generic training script that trains a model using a given dataset."""
 20    
 21    from __future__ import absolute_import
 22    from __future__ import division
 23    from __future__ import print_function
 24    
 25    import tensorflow as tf
 26    
 27    from datasets import dataset_factory
 28    from deployment import model_deploy
 29    from nets import nets_factory
 30    from preprocessing import preprocessing_factory
 31    
 32    slim = tf.contrib.slim
 33    
 34    tf.app.flags.DEFINE_string(
 35        'master', '', 'The address of the TensorFlow master to use.')
 36    
 37    tf.app.flags.DEFINE_string(
 38        'train_dir', '/tmp/tfmodel/',
 39        'Directory where checkpoints and event logs are written to.')
 40    
 41    tf.app.flags.DEFINE_integer('num_clones', 1,
 42                                'Number of model clones to deploy.')
 43    
 44    tf.app.flags.DEFINE_boolean('clone_on_cpu', False,
 45                                'Use CPUs to deploy clones.')
 46    
 47    tf.app.flags.DEFINE_integer('worker_replicas', 1, 'Number of worker replicas.')
 48    
 49    tf.app.flags.DEFINE_integer(
 50        'num_ps_tasks', 0,
 51        'The number of parameter servers. If the value is 0, then the parameters '
 52        'are handled locally by the worker.')
 53    
 54    tf.app.flags.DEFINE_integer(
 55        'num_readers', 4,
 56        'The number of parallel readers that read data from the dataset.')
 57    
 58    tf.app.flags.DEFINE_integer(
 59        'num_preprocessing_threads', 4,
 60        'The number of threads used to create the batches.')
 61    
 62    tf.app.flags.DEFINE_integer(
 63        'log_every_n_steps', 10,
 64        'The frequency with which logs are print.')
 65    
 66    tf.app.flags.DEFINE_integer(
 67        'save_summaries_secs', 600,
 68        'The frequency with which summaries are saved, in seconds.')
 69    
 70    tf.app.flags.DEFINE_integer(
 71        'save_interval_secs', 600,
 72        'The frequency with which the model is saved, in seconds.')
 73    
 74    tf.app.flags.DEFINE_integer(
 75        'task', 0, 'Task id of the replica running the training.')
 76    
 77    ######################
 78    # Optimization Flags #
 79    ######################
 80    
 81    tf.app.flags.DEFINE_float(
 82        'weight_decay', 0.00004, 'The weight decay on the model weights.')
 83    
 84    tf.app.flags.DEFINE_string(
 85        'optimizer', 'rmsprop',
 86        'The name of the optimizer, one of "adadelta", "adagrad", "adam",'
 87        '"ftrl", "momentum", "sgd" or "rmsprop".')
 88    
 89    tf.app.flags.DEFINE_float(
 90        'adadelta_rho', 0.95,
 91        'The decay rate for adadelta.')
 92    
 93    tf.app.flags.DEFINE_float(
 94        'adagrad_initial_accumulator_value', 0.1,
 95        'Starting value for the AdaGrad accumulators.')
 96    
 97    tf.app.flags.DEFINE_float(
 98        'adam_beta1', 0.9,
 99        'The exponential decay rate for the 1st moment estimates.')
100    
101    tf.app.flags.DEFINE_float(
102        'adam_beta2', 0.999,
103        'The exponential decay rate for the 2nd moment estimates.')
104    
105    tf.app.flags.DEFINE_float('opt_epsilon', 1.0, 'Epsilon term for the optimizer.')
106    
107    tf.app.flags.DEFINE_float('ftrl_learning_rate_power', -0.5,
108                              'The learning rate power.')
109    
110    tf.app.flags.DEFINE_float(
111        'ftrl_initial_accumulator_value', 0.1,
112        'Starting value for the FTRL accumulators.')
113    
114    tf.app.flags.DEFINE_float(
115        'ftrl_l1', 0.0, 'The FTRL l1 regularization strength.')
116    
117    tf.app.flags.DEFINE_float(
118        'ftrl_l2', 0.0, 'The FTRL l2 regularization strength.')
119    
120    tf.app.flags.DEFINE_float(
121        'momentum', 0.9,
122        'The momentum for the MomentumOptimizer and RMSPropOptimizer.')
123    
124    tf.app.flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.')
125    
126    #######################
127    # Learning Rate Flags #
128    #######################
129    
130    tf.app.flags.DEFINE_string(
131        'learning_rate_decay_type',
132        'exponential',
133        'Specifies how the learning rate is decayed. One of "fixed", "exponential",'
134        ' or "polynomial"')
135    
136    tf.app.flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
137    
138    tf.app.flags.DEFINE_float(
139        'end_learning_rate', 0.0001,
140        'The minimal end learning rate used by a polynomial decay learning rate.')
141    
142    tf.app.flags.DEFINE_float(
143        'label_smoothing', 0.0, 'The amount of label smoothing.')
144    
145    tf.app.flags.DEFINE_float(
146        'learning_rate_decay_factor', 0.94, 'Learning rate decay factor.')
147    
148    tf.app.flags.DEFINE_float(
149        'num_epochs_per_decay', 2.0,
150        'Number of epochs after which learning rate decays.')
151    
152    tf.app.flags.DEFINE_bool(
153        'sync_replicas', False,
154        'Whether or not to synchronize the replicas during training.')
155    
156    tf.app.flags.DEFINE_integer(
157        'replicas_to_aggregate', 1,
158        'The Number of gradients to collect before updating params.')
159    
160    tf.app.flags.DEFINE_float(
161        'moving_average_decay', None,
162        'The decay to use for the moving average.'
163        'If left as None, then moving averages are not used.')
164    
165    #######################
166    # Dataset Flags #
167    #######################
168    
169    tf.app.flags.DEFINE_string(
170        'dataset_name', 'imagenet', 'The name of the dataset to load.')
171    
172    tf.app.flags.DEFINE_string(
173        'dataset_split_name', 'train', 'The name of the train/test split.')
174    
175    tf.app.flags.DEFINE_string(
176        'dataset_dir', None, 'The directory where the dataset files are stored.')
177    
178    tf.app.flags.DEFINE_integer(
179        'labels_offset', 0,
180        'An offset for the labels in the dataset. This flag is primarily used to '
181        'evaluate the VGG and ResNet architectures which do not use a background '
182        'class for the ImageNet dataset.')
183    
184    tf.app.flags.DEFINE_string(
185        'model_name', 'inception_v3', 'The name of the architecture to train.')
186    
187    tf.app.flags.DEFINE_string(
188        'preprocessing_name', None, 'The name of the preprocessing to use. If left '
189        'as `None`, then the model_name flag is used.')
190    
191    tf.app.flags.DEFINE_integer(
192        'batch_size', 32, 'The number of samples in each batch.')
193    
194    tf.app.flags.DEFINE_integer(
195        'train_image_size', None, 'Train image size')
196    
197    tf.app.flags.DEFINE_integer('max_number_of_steps', None,
198                                'The maximum number of training steps.')
199    
200    #####################
201    # Fine-Tuning Flags #
202    #####################
203    
204    tf.app.flags.DEFINE_string(
205        'checkpoint_path', None,
206        'The path to a checkpoint from which to fine-tune.')
207    
208    tf.app.flags.DEFINE_string(
209        'checkpoint_exclude_scopes', None,
210        'Comma-separated list of scopes of variables to exclude when restoring '
211        'from a checkpoint.')
212    
213    tf.app.flags.DEFINE_string(
214        'trainable_scopes', None,
215        'Comma-separated list of scopes to filter the set of variables to train.'
216        'By default, None would train all the variables.')
217    
218    tf.app.flags.DEFINE_boolean(
219        'ignore_missing_vars', False,
220        'When restoring a checkpoint would ignore missing variables.')
221    
222    FLAGS = tf.app.flags.FLAGS
223    
224    
def _configure_learning_rate(num_samples_per_epoch, global_step):
  """Configures the learning rate.

  The schedule is selected by --learning_rate_decay_type:
    * "exponential": --learning_rate multiplied by
      --learning_rate_decay_factor once every `decay_steps` steps (staircase).
    * "fixed": a constant --learning_rate.
    * "polynomial": decay with power 1.0 (i.e. linear) from --learning_rate to
      --end_learning_rate over `decay_steps` steps, without cycling.
  `decay_steps` is the integer number of global steps in
  --num_epochs_per_decay epochs.

  Args:
    num_samples_per_epoch: The number of samples in each epoch of training.
    global_step: The global_step tensor.

  Returns:
    A `Tensor` representing the learning rate.

  Raises:
    ValueError: if --learning_rate_decay_type is not recognized.
  """
  # True division (the file imports `division` from __future__): keep the
  # fractional value and truncate to int only once, at the end.
  steps_per_epoch = num_samples_per_epoch / FLAGS.batch_size
  if FLAGS.sync_replicas:
    # Each synchronized update collects --replicas_to_aggregate gradients, so
    # an epoch spans proportionally fewer global steps.
    steps_per_epoch /= FLAGS.replicas_to_aggregate

  # Integer number of global steps between learning-rate decays.
  decay_steps = int(steps_per_epoch * FLAGS.num_epochs_per_decay)

  if FLAGS.learning_rate_decay_type == 'exponential':
    return tf.train.exponential_decay(FLAGS.learning_rate,
                                      global_step,
                                      decay_steps,
                                      FLAGS.learning_rate_decay_factor,
                                      staircase=True,
                                      name='exponential_decay_learning_rate')
  elif FLAGS.learning_rate_decay_type == 'fixed':
    return tf.constant(FLAGS.learning_rate, name='fixed_learning_rate')
  elif FLAGS.learning_rate_decay_type == 'polynomial':
    return tf.train.polynomial_decay(FLAGS.learning_rate,
                                     global_step,
                                     decay_steps,
                                     FLAGS.end_learning_rate,
                                     power=1.0,
                                     cycle=False,
                                     name='polynomial_decay_learning_rate')
  else:
    # Format the message here; passing the value as a second ValueError
    # argument would leave "%s" uninterpolated.
    raise ValueError('learning_rate_decay_type [%s] was not recognized' %
                     FLAGS.learning_rate_decay_type)
266    
267    
268    
# Builds the tf.train optimizer named by --optimizer, configured from the
# optimizer-specific flags.
def _configure_optimizer(learning_rate):
  """Configures the optimizer used for training.

  Args:
    learning_rate: A scalar or `Tensor` learning rate.

  Returns:
    An instance of an optimizer.

  Raises:
    ValueError: if FLAGS.optimizer is not recognized.
  """
  if FLAGS.optimizer == 'adadelta':
    optimizer = tf.train.AdadeltaOptimizer(
        learning_rate,
        rho=FLAGS.adadelta_rho,
        epsilon=FLAGS.opt_epsilon)
  elif FLAGS.optimizer == 'adagrad':
    optimizer = tf.train.AdagradOptimizer(
        learning_rate,
        initial_accumulator_value=FLAGS.adagrad_initial_accumulator_value)
  elif FLAGS.optimizer == 'adam':
    optimizer = tf.train.AdamOptimizer(
        learning_rate,
        beta1=FLAGS.adam_beta1,
        beta2=FLAGS.adam_beta2,
        epsilon=FLAGS.opt_epsilon)
  elif FLAGS.optimizer == 'ftrl':
    optimizer = tf.train.FtrlOptimizer(
        learning_rate,
        learning_rate_power=FLAGS.ftrl_learning_rate_power,
        initial_accumulator_value=FLAGS.ftrl_initial_accumulator_value,
        l1_regularization_strength=FLAGS.ftrl_l1,
        l2_regularization_strength=FLAGS.ftrl_l2)
  elif FLAGS.optimizer == 'momentum':
    optimizer = tf.train.MomentumOptimizer(
        learning_rate,
        momentum=FLAGS.momentum,
        name='Momentum')
  elif FLAGS.optimizer == 'rmsprop':
    optimizer = tf.train.RMSPropOptimizer(
        learning_rate,
        decay=FLAGS.rmsprop_decay,
        momentum=FLAGS.momentum,
        epsilon=FLAGS.opt_epsilon)
  elif FLAGS.optimizer == 'sgd':
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  else:
    # Interpolate the name into the message (a second ValueError argument
    # would not be formatted).
    raise ValueError('Optimizer [%s] was not recognized' % FLAGS.optimizer)
  return optimizer
320    
321    
def _get_init_fn():
  """Returns a function run by the chief worker to warm-start the training.

  Note that the init_fn is only run when initializing the model during the very
  first global step.

  Model variables whose op name starts with any of the comma-separated prefixes
  in --checkpoint_exclude_scopes are not restored. Blank entries in that list
  (e.g. produced by a trailing comma) are ignored; an empty prefix would
  otherwise match every variable name and exclude the whole model.

  Returns:
    An init function run by the supervisor, or None when no warm start should
    happen (--checkpoint_path is unset, or --train_dir already holds a
    checkpoint).

  Raises:
    ValueError: if --checkpoint_path is a directory containing no checkpoint.
  """
  if FLAGS.checkpoint_path is None:
    return None

  # Warn the user if a checkpoint exists in the train_dir. Then we'll be
  # ignoring the checkpoint anyway.
  if tf.train.latest_checkpoint(FLAGS.train_dir):
    tf.logging.info(
        'Ignoring --checkpoint_path because a checkpoint already exists in %s'
        % FLAGS.train_dir)
    return None

  # Parse the exclusion prefixes, dropping blank entries: '' is a prefix of
  # every string, so keeping it would exclude every variable.
  exclusions = tuple(
      scope.strip()
      for scope in (FLAGS.checkpoint_exclude_scopes or '').split(',')
      if scope.strip())

  # str.startswith accepts a tuple of prefixes; an empty tuple matches nothing,
  # so all model variables are restored when no exclusions are given.
  variables_to_restore = [
      var for var in slim.get_model_variables()
      if not var.op.name.startswith(exclusions)]

  if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
    checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
    if checkpoint_path is None:
      # Fail early with a clear message instead of handing None to
      # assign_from_checkpoint_fn.
      raise ValueError('No checkpoint found in directory %s'
                       % FLAGS.checkpoint_path)
  else:
    checkpoint_path = FLAGS.checkpoint_path

  tf.logging.info('Fine-tuning from %s' % checkpoint_path)

  return slim.assign_from_checkpoint_fn(
      checkpoint_path,
      variables_to_restore,
      ignore_missing_vars=FLAGS.ignore_missing_vars)
369    
370    
def _get_variables_to_train():
  """Returns a list of variables to train.

  When --trainable_scopes is unset, empty, or contains only blank entries,
  every trainable variable is returned. Otherwise the trainable variables
  selected by each comma-separated scope (passed as the `scope` filter of
  tf.get_collection) are concatenated in the order the scopes are listed.
  Blank entries (e.g. from a trailing comma) are ignored, and a variable
  selected by more than one scope is listed only once, so the optimizer never
  receives the same variable twice.

  Returns:
    A list of variables to train by the optimizer.
  """
  scopes = [scope.strip()
            for scope in (FLAGS.trainable_scopes or '').split(',')
            if scope.strip()]
  if not scopes:
    return tf.trainable_variables()

  variables_to_train = []
  seen_names = set()
  for scope in scopes:
    for variable in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope):
      # Overlapping scopes (e.g. 'A' and 'A/B') would otherwise add the same
      # variable twice and make it receive two gradient updates per step.
      if variable.name not in seen_names:
        seen_names.add(variable.name)
        variables_to_train.append(variable)
  return variables_to_train
387    
388    
def main(_):
  """Builds the image-classification training graph and runs training.

  Steps, in order: configure model_deploy, create the global step, select the
  dataset, network and preprocessing function, build the input pipeline,
  replicate the model over clones, collect summaries, configure optional
  moving averages and the optimizer (optionally wrapped for synchronous
  replicas), build the train op, and hand it to slim.learning.train.

  Args:
    _: Unused; the command-line arguments forwarded by tf.app.run.

  Raises:
    ValueError: if --dataset_dir is not set.
  """
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    #######################
    # Config model_deploy #
    #######################
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.worker_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks)

    # Create global_step
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    ######################
    # Select the network #
    ######################
    # The network predicts num_classes - labels_offset classes; the labels are
    # shifted by the same offset in the input pipeline below.
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        weight_decay=FLAGS.weight_decay,
        is_training=True)

    #####################################
    # Select the preprocessing function #
    #####################################
    # Defaults to the preprocessing registered under the model's name.
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=True)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    with tf.device(deploy_config.inputs_device()):
      # Configures how samples are read from the dataset: the number of
      # parallel readers and the size of the queue they share.
      provider = slim.dataset_data_provider.DatasetDataProvider(
          dataset,
          num_readers=FLAGS.num_readers,
          common_queue_capacity=20 * FLAGS.batch_size,
          common_queue_min=10 * FLAGS.batch_size)
      [image, label] = provider.get(['image', 'label'])
      label -= FLAGS.labels_offset

      train_image_size = FLAGS.train_image_size or network_fn.default_image_size

      image = image_preprocessing_fn(image, train_image_size, train_image_size)

      images, labels = tf.train.batch(
          [image, label],
          batch_size=FLAGS.batch_size,
          num_threads=FLAGS.num_preprocessing_threads,
          capacity=5 * FLAGS.batch_size)
      # One-hot encoding: each label becomes a vector of length
      # num_classes - labels_offset with a 1 at the label's index and 0
      # everywhere else.
      labels = slim.one_hot_encoding(
          labels, dataset.num_classes - FLAGS.labels_offset)
      # Prefetch up to two batches per clone so clones do not wait on input.
      batch_queue = slim.prefetch_queue.prefetch_queue(
          [images, labels], capacity=2 * deploy_config.num_clones)

    ####################
    # Define the model #
    ####################
    # Data parallelism: model_deploy.create_clones (below) calls clone_fn once
    # per clone, replicating the network.
    def clone_fn(batch_queue):
      """Allows data parallelism by creating multiple clones of network_fn.

      Args:
        batch_queue: the prefetch queue of (images, one-hot labels) batches.

      Returns:
        The network's end_points, used below for activation summaries.
      """
      with tf.device(deploy_config.inputs_device()):
        # dequeue() takes the next (images, labels) batch off the prefetch
        # queue (it is a FIFO pop, not a double-ended queue).
        images, labels = batch_queue.dequeue()
      logits, end_points = network_fn(images)

      #############################
      # Specify the loss function #
      #############################
      # Softmax outputs are trained with the cross-entropy loss. The return
      # values are discarded: tf.losses.* records each loss in the
      # tf.GraphKeys.LOSSES collection, which is what is read back later.
      # label_smoothing applies label-smoothing regularization (softened
      # one-hot targets) to reduce overfitting.
      if 'AuxLogits' in end_points:
        # Networks with an auxiliary classifier head get an extra loss term
        # weighted at 0.4.
        tf.losses.softmax_cross_entropy(
            logits=end_points['AuxLogits'], onehot_labels=labels,
            label_smoothing=FLAGS.label_smoothing, weights=0.4, scope='aux_loss')
      tf.losses.softmax_cross_entropy(
          logits=logits, onehot_labels=labels,
          label_smoothing=FLAGS.label_smoothing, weights=1.0)
      return end_points

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Add summaries for end_points.
    end_points = clones[0].outputs
    for end_point in end_points:
      x = end_points[end_point]
      summaries.add(tf.summary.histogram('activations/' + end_point, x))
      summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                      tf.nn.zero_fraction(x)))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Add summaries for variables.
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))
    # All of these summaries are merged into summary_op below and written by
    # slim.learning.train every save_summaries_secs.

    #################################
    # Configure the moving averages #
    #################################
    # Optional exponential moving averages of the model variables
    # (see e.g. the Wikipedia article on moving averages).
    if FLAGS.moving_average_decay:
      moving_average_variables = slim.get_model_variables()
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, global_step)
    else:
      moving_average_variables, variable_averages = None, None

    #########################################
    # Configure the optimization procedure. #
    #########################################
    with tf.device(deploy_config.optimizer_device()):
      learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
      optimizer = _configure_optimizer(learning_rate)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))


    # Synchronous training across distributed worker replicas.
    if FLAGS.sync_replicas:
      # If sync_replicas is enabled, the averaging will be done in the chief
      # queue runner.
      optimizer = tf.train.SyncReplicasOptimizer(
          opt=optimizer,
          replicas_to_aggregate=FLAGS.replicas_to_aggregate,
          variable_averages=variable_averages,
          variables_to_average=moving_average_variables,
          replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
          total_num_replicas=FLAGS.worker_replicas)
    elif FLAGS.moving_average_decay:
      # Update ops executed locally by trainer.
      update_ops.append(variable_averages.apply(moving_average_variables))

    # Variables to train.
    variables_to_train = _get_variables_to_train()

    # optimize_clones computes the total loss and the gradients over all
    # clones for variables_to_train; the gradients are applied once below.
    total_loss, clones_gradients = model_deploy.optimize_clones(
        clones,
        optimizer,
        var_list=variables_to_train)
    # Add total_loss to summary.
    summaries.add(tf.summary.scalar('total_loss', total_loss))

    # Create gradient updates.
    grad_updates = optimizer.apply_gradients(clones_gradients,
                                             global_step=global_step)
    update_ops.append(grad_updates)

    # train_tensor yields total_loss only after every update op (batch-norm
    # updates, optional moving averages, the gradient step) has run.
    update_op = tf.group(*update_ops)
    with tf.control_dependencies([update_op]):
      train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')


    ###########################
    # Kicks off the training. #
    ###########################
    # Task 0 is the chief, the worker that runs init_fn (see _get_init_fn).
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_dir,
        master=FLAGS.master,
        is_chief=(FLAGS.task == 0),
        init_fn=_get_init_fn(),
        summary_op=summary_op,
        number_of_steps=FLAGS.max_number_of_steps,
        log_every_n_steps=FLAGS.log_every_n_steps,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs,
        sync_optimizer=optimizer if FLAGS.sync_replicas else None)
586    
587    
if __name__ == '__main__':
  # Script entry point: tf.app.run parses the command-line flags and then
  # calls main() with the remaining arguments.
  tf.app.run(main=main)
590
591