TensorFlow Slim source code analysis

Reading Python source is still quite pleasant... (even though my Python fluency is nowhere near my C++...)

The repository layout:
datasets holds the dataset-related code
deployment covers model deployment
nets provides a collection of network architectures
preprocessing offers several preprocessing pipelines
These parts are not really about slim itself, so I won't dwell on them.

Below I walk through the four top-level scripts: download_and_convert_data.py, eval_image_classifier.py, export_inference_graph.py, and train_image_classifier.py. The analysis itself lives in the code comments.

Having started machine learning only a week ago, I haven't yet covered a lot of the theory, so my understanding of the source is quite limited... I hope to come back and fill in the gaps later.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Downloads and converts a particular dataset.

Usage:
```shell

$ python download_and_convert_data.py \
    --dataset_name=mnist \
    --dataset_dir=/tmp/mnist

$ python download_and_convert_data.py \
    --dataset_name=cifar10 \
    --dataset_dir=/tmp/cifar10

$ python download_and_convert_data.py \
    --dataset_name=flowers \
    --dataset_dir=/tmp/flowers
```
"""
from __future__ import absolute_import  # The __future__ imports smooth over Python 2/3
from __future__ import division         # compatibility issues; nothing to agonize over.
from __future__ import print_function

import tensorflow as tf

from datasets import download_and_convert_cifar10
from datasets import download_and_convert_flowers
from datasets import download_and_convert_mnist

FLAGS = tf.app.flags.FLAGS  # FLAGS passes around and stores TensorFlow command-line parameters.

tf.app.flags.DEFINE_string(  # The format is: ('flag name', default value, 'description of the flag').
    'dataset_name',
    None,
    'The name of the dataset to convert, one of "cifar10", "flowers", "mnist".')

tf.app.flags.DEFINE_string(
    'dataset_dir',
    None,
    'The directory where the output TFRecords and temporary files are saved.')


def main(_):
  if not FLAGS.dataset_name:
    raise ValueError('You must supply the dataset name with --dataset_name')
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  if FLAGS.dataset_name == 'cifar10':  # The three supported datasets: cifar10, flowers, mnist.
    download_and_convert_cifar10.run(FLAGS.dataset_dir)
  elif FLAGS.dataset_name == 'flowers':
    download_and_convert_flowers.run(FLAGS.dataset_dir)
  elif FLAGS.dataset_name == 'mnist':
    download_and_convert_mnist.run(FLAGS.dataset_dir)
  else:
    raise ValueError(
        'dataset_name [%s] was not recognized.' % FLAGS.dataset_name)  # The name is none of the three above.

if __name__ == '__main__':  # This guard keeps main from running when the file is imported.
  tf.app.run()
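
A quick aside on how tf.app.flags and tf.app.run fit together. This is a minimal toy script of my own (the flag name is made up for illustration, not part of slim): DEFINE_* registers a flag, FLAGS exposes the parsed value, and tf.app.run() parses the command line and then calls main.

import tensorflow as tf

tf.app.flags.DEFINE_string('greeting', 'hello', 'What to print.')  # hypothetical flag
FLAGS = tf.app.flags.FLAGS


def main(_):
  # Flag values are available once tf.app.run() has parsed sys.argv.
  print(FLAGS.greeting)


if __name__ == '__main__':
  tf.app.run()  # e.g. python toy_flags.py --greeting=hi  ->  prints "hi"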


Next comes slim's generic evaluation script, eval_image_classifier.py.

# coding=utf-8
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generic evaluation script that evaluates a model using a given dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import tensorflow as tf

from datasets import dataset_factory
from nets import nets_factory
from preprocessing import preprocessing_factory

slim = tf.contrib.slim

tf.app.flags.DEFINE_integer(
    'batch_size', 100, 'The number of samples in each batch.')

tf.app.flags.DEFINE_integer(
    'max_num_batches', None,
    'Max number of batches to evaluate by default use all.')

tf.app.flags.DEFINE_string(
    'master', '', 'The address of the TensorFlow master to use.')

tf.app.flags.DEFINE_string(
    'checkpoint_path', '/tmp/tfmodel/',
    'The directory where the model was written to or an absolute path to a '
    'checkpoint file.')

tf.app.flags.DEFINE_string(
    'eval_dir', '/tmp/tfmodel/', 'Directory where the results are saved to.')

tf.app.flags.DEFINE_integer(
    'num_preprocessing_threads', 4,
    'The number of threads used to create the batches.')

tf.app.flags.DEFINE_string(
    'dataset_name', 'imagenet', 'The name of the dataset to load.')

tf.app.flags.DEFINE_string(
    'dataset_split_name', 'test', 'The name of the train/test split.')

tf.app.flags.DEFINE_string(
    'dataset_dir', None, 'The directory where the dataset files are stored.')

tf.app.flags.DEFINE_integer(
    'labels_offset', 0,
    'An offset for the labels in the dataset. This flag is primarily used to '
    'evaluate the VGG and ResNet architectures which do not use a background '
    'class for the ImageNet dataset.')

tf.app.flags.DEFINE_string(
    'model_name', 'inception_v3', 'The name of the architecture to evaluate.')

tf.app.flags.DEFINE_string(
    'preprocessing_name', None, 'The name of the preprocessing to use. If left '
    'as `None`, then the model_name flag is used.')

tf.app.flags.DEFINE_float(
    'moving_average_decay', None,
    'The decay to use for the moving average. '
    'If left as None, then moving averages are not used.')

tf.app.flags.DEFINE_integer(
    'eval_image_size', None, 'Eval image size')

FLAGS = tf.app.flags.FLAGS


def main(_):
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')
  tf.logging.set_verbosity(tf.logging.INFO)  # Set the log level; one of DEBUG, INFO, WARN, ERROR, or FATAL.
  with tf.Graph().as_default():  # Overrides the current default graph for the lifetime of the context.
    # Note that this is not thread-safe.
    tf_global_step = slim.get_or_create_global_step()
    # slim.get_or_create_global_step mirrors tf.train.get_or_create_global_step:
    # it returns the global step tensor for the given graph; with no graph
    # argument, the default graph is used.
    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    ####################
    # Select the model #
    ####################
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        is_training=False)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    # Read the data.
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        shuffle=False,
        common_queue_capacity=2 * FLAGS.batch_size,
        common_queue_min=FLAGS.batch_size)
    [image, label] = provider.get(['image', 'label'])
    label -= FLAGS.labels_offset

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=False)
    # A note on Python's `or`: in `flag1 or flag2 or ... or flagn`, if the
    # overall result is truthy, the expression returns the first truthy value
    # (left to right) rather than a boolean; if every operand is falsy, it
    # returns the last falsy value.
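    # For example (hypothetical values):
    #   None or 'inception_v3'      -> 'inception_v3'
    #   'vgg_16' or 'inception_v3'  -> 'vgg_16'
    #   0 or '' or None             -> None  (all falsy: the last one wins)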
    eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size

    image = image_preprocessing_fn(image, eval_image_size, eval_image_size)

    images, labels = tf.train.batch(
        [image, label],
        batch_size=FLAGS.batch_size,
        num_threads=FLAGS.num_preprocessing_threads,
        capacity=5 * FLAGS.batch_size)

    ####################
    # Define the model #
    ####################
    logits, _ = network_fn(images)  # Python sequence unpacking; the second element is discarded.

    # Moving averages; see https://en.wikipedia.org/wiki/Moving_average
    if FLAGS.moving_average_decay:
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, tf_global_step)
      variables_to_restore = variable_averages.variables_to_restore(
          slim.get_model_variables())
      variables_to_restore[tf_global_step.op.name] = tf_global_step
    else:
      variables_to_restore = slim.get_variables_to_restore()
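    # The exponential moving average keeps a shadow copy of each variable,
    # updated as shadow = decay * shadow + (1 - decay) * value; evaluating
    # with the averaged weights is often slightly more accurate than with
    # the raw ones.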

    predictions = tf.argmax(logits, 1)
    # tf.argmax(input, axis=None, name=None, dimension=None)
    # Returns the index with the largest value across axes of a tensor.
    # Here: for each row of logits, the index of the largest entry along
    # axis 1, i.e. the predicted class for each example.
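    # For example, tf.argmax([[0.1, 0.7, 0.2]], 1) evaluates to [1].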


    labels = tf.squeeze(labels)  # Remove the dimensions of size 1 from labels.
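    # e.g. a labels tensor of shape [batch_size, 1] becomes shape [batch_size].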

    # Define the metrics:
    names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
        'Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
        'Recall_5': slim.metrics.streaming_recall_at_k(
            logits, labels, 5),
    })
    # metrics.aggregate_metric_map is just a convenient notation when the list
    # of metrics gets long. "Metrics" here is not a TensorFlow-specific concept;
    # they are simply the performance indicators being monitored.
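    # Each metric is a (value_op, update_op) pair: the update op accumulates
    # statistics batch by batch, while the value op reads the current result.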


    # Print the summaries to screen.
    for name, value in names_to_values.items():
      summary_name = 'eval/%s' % name
      op = tf.summary.scalar(summary_name, value, collections=[])
      # tf.summary.scalar: outputs a Summary protocol buffer containing a single scalar value.
      op = tf.Print(op, [value], summary_name)
      tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)

    # TODO(sguada) use num_epochs=1
    if FLAGS.max_num_batches:
      num_batches = FLAGS.max_num_batches
    else:
      # This ensures that we make a single pass over all of the data.
      num_batches = math.ceil(dataset.num_samples / float(FLAGS.batch_size))

    if tf.gfile.IsDirectory(FLAGS.checkpoint_path):  # True if the path is a directory.
      checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
    else:
      checkpoint_path = FLAGS.checkpoint_path

    tf.logging.info('Evaluating %s' % checkpoint_path)  # Log an info message.

    # Evaluates the model at the given checkpoint path.
    slim.evaluation.evaluate_once(
        master=FLAGS.master,
        checkpoint_path=checkpoint_path,
        logdir=FLAGS.eval_dir,
        num_evals=num_batches,
        eval_op=list(names_to_updates.values()),
        variables_to_restore=variables_to_restore)


if __name__ == '__main__':
  tf.app.run()
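
To make the (value_op, update_op) pairs above concrete, here is a minimal standalone sketch of my own (toy tensors, not code from slim) showing how a streaming metric accumulates state:

import tensorflow as tf

slim = tf.contrib.slim

predictions = tf.constant([1, 0, 1])
labels = tf.constant([1, 1, 1])
accuracy, update_op = slim.metrics.streaming_accuracy(predictions, labels)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())  # metric accumulators are local variables
  sess.run(update_op)                         # fold in one "batch"
  print(sess.run(accuracy))                   # 2 of 3 correct -> ~0.667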


The third script, export_inference_graph.py, saves the model architecture out as a GraphDef.

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Saves out a GraphDef containing the architecture of the model.

To use it, run something like this, with a model name defined by slim:

bazel build tensorflow_models/slim:export_inference_graph
bazel-bin/tensorflow_models/slim/export_inference_graph \
    --model_name=inception_v3 --output_file=/tmp/inception_v3_inf_graph.pb

If you then want to use the resulting model with your own or pretrained
checkpoints as part of a mobile model, you can run freeze_graph to get a graph
def with the variables inlined as constants using:

bazel build tensorflow/python/tools:freeze_graph
bazel-bin/tensorflow/python/tools/freeze_graph \
    --input_graph=/tmp/inception_v3_inf_graph.pb \
    --input_checkpoint=/tmp/checkpoints/inception_v3.ckpt \
    --input_binary=true --output_graph=/tmp/frozen_inception_v3.pb \
    --output_node_names=InceptionV3/Predictions/Reshape_1

The output node names will vary depending on the model, but you can inspect and
estimate them using the summarize_graph tool:

bazel build tensorflow/tools/graph_transforms:summarize_graph
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \
    --in_graph=/tmp/inception_v3_inf_graph.pb

To run the resulting graph in C++, you can look at the label_image sample code:

bazel build tensorflow/examples/label_image:label_image
bazel-bin/tensorflow/examples/label_image/label_image \
    --image=${HOME}/Pictures/flowers.jpg \
    --input_layer=input \
    --output_layer=InceptionV3/Predictions/Reshape_1 \
    --graph=/tmp/frozen_inception_v3.pb \
    --labels=/tmp/imagenet_slim_labels.txt \
    --input_mean=0 \
    --input_std=255

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow.python.platform import gfile
from datasets import dataset_factory
from nets import nets_factory


slim = tf.contrib.slim

tf.app.flags.DEFINE_string(
    'model_name', 'inception_v3', 'The name of the architecture to save.')

tf.app.flags.DEFINE_boolean(
    'is_training', False,
    'Whether to save out a training-focused version of the model.')

tf.app.flags.DEFINE_integer(
    'default_image_size', 224,
    'The image size to use if the model does not define it.')

tf.app.flags.DEFINE_string('dataset_name', 'imagenet',
                           'The name of the dataset to use with the model.')

tf.app.flags.DEFINE_integer(
    'labels_offset', 0,
    'An offset for the labels in the dataset. This flag is primarily used to '
    'evaluate the VGG and ResNet architectures which do not use a background '
    'class for the ImageNet dataset.')

tf.app.flags.DEFINE_string(
    'output_file', '', 'Where to save the resulting file to.')

tf.app.flags.DEFINE_string(
    'dataset_dir', '', 'Directory to save intermediate dataset files to')

FLAGS = tf.app.flags.FLAGS


def main(_):
  if not FLAGS.output_file:
    raise ValueError('You must supply the path to save to with --output_file')
  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default() as graph:
    dataset = dataset_factory.get_dataset(FLAGS.dataset_name, 'train',
                                          FLAGS.dataset_dir)
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        is_training=FLAGS.is_training)
    # hasattr is plain Python: hasattr(object, name) -> bool, i.e. whether
    # `object` has an attribute called `name`.
    if hasattr(network_fn, 'default_image_size'):
      image_size = network_fn.default_image_size
    else:
      image_size = FLAGS.default_image_size
    placeholder = tf.placeholder(name='input', dtype=tf.float32,
                                 shape=[1, image_size, image_size, 3])
    network_fn(placeholder)
    graph_def = graph.as_graph_def()
    # graph.as_graph_def(): returns a serialized GraphDef representation of
    # this graph. The serialized GraphDef can be imported into another Graph
    # (using tf.import_graph_def) or used with the C++ Session API.
    # This method is thread-safe.

    # gfile.GFile is an I/O wrapper without thread locking.
    with gfile.GFile(FLAGS.output_file, 'wb') as f:
      f.write(graph_def.SerializeToString())


if __name__ == '__main__':
  tf.app.run()
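
As the comment above notes, the exported GraphDef can be re-imported elsewhere with tf.import_graph_def. A minimal sketch of my own, reusing the output path from the docstring example:

import tensorflow as tf

graph_def = tf.GraphDef()
with tf.gfile.GFile('/tmp/inception_v3_inf_graph.pb', 'rb') as f:
  graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
  tf.import_graph_def(graph_def, name='')
  # The placeholder created by export_inference_graph is now addressable by name.
  input_tensor = graph.get_tensor_by_name('input:0')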


Last is the generic training script, train_image_classifier.py, which is by far the longest of the four.

# coding=utf-8
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generic training script that trains a model using a given dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from datasets import dataset_factory
from deployment import model_deploy
from nets import nets_factory
from preprocessing import preprocessing_factory

slim = tf.contrib.slim

tf.app.flags.DEFINE_string(
    'master', '', 'The address of the TensorFlow master to use.')

tf.app.flags.DEFINE_string(
    'train_dir', '/tmp/tfmodel/',
    'Directory where checkpoints and event logs are written to.')

tf.app.flags.DEFINE_integer('num_clones', 1,
                            'Number of model clones to deploy.')

tf.app.flags.DEFINE_boolean('clone_on_cpu', False,
                            'Use CPUs to deploy clones.')

tf.app.flags.DEFINE_integer('worker_replicas', 1, 'Number of worker replicas.')

tf.app.flags.DEFINE_integer(
    'num_ps_tasks', 0,
    'The number of parameter servers. If the value is 0, then the parameters '
    'are handled locally by the worker.')

tf.app.flags.DEFINE_integer(
    'num_readers', 4,
    'The number of parallel readers that read data from the dataset.')

tf.app.flags.DEFINE_integer(
    'num_preprocessing_threads', 4,
    'The number of threads used to create the batches.')

tf.app.flags.DEFINE_integer(
    'log_every_n_steps', 10,
    'The frequency with which logs are printed.')

tf.app.flags.DEFINE_integer(
    'save_summaries_secs', 600,
    'The frequency with which summaries are saved, in seconds.')

tf.app.flags.DEFINE_integer(
    'save_interval_secs', 600,
    'The frequency with which the model is saved, in seconds.')

tf.app.flags.DEFINE_integer(
    'task', 0, 'Task id of the replica running the training.')

######################
# Optimization Flags #
######################

tf.app.flags.DEFINE_float(
    'weight_decay', 0.00004, 'The weight decay on the model weights.')

tf.app.flags.DEFINE_string(
    'optimizer', 'rmsprop',
    'The name of the optimizer, one of "adadelta", "adagrad", "adam", '
    '"ftrl", "momentum", "sgd" or "rmsprop".')

tf.app.flags.DEFINE_float(
    'adadelta_rho', 0.95,
    'The decay rate for adadelta.')

tf.app.flags.DEFINE_float(
    'adagrad_initial_accumulator_value', 0.1,
    'Starting value for the AdaGrad accumulators.')

tf.app.flags.DEFINE_float(
    'adam_beta1', 0.9,
    'The exponential decay rate for the 1st moment estimates.')

tf.app.flags.DEFINE_float(
    'adam_beta2', 0.999,
    'The exponential decay rate for the 2nd moment estimates.')

tf.app.flags.DEFINE_float('opt_epsilon', 1.0, 'Epsilon term for the optimizer.')

tf.app.flags.DEFINE_float('ftrl_learning_rate_power', -0.5,
                          'The learning rate power.')

tf.app.flags.DEFINE_float(
    'ftrl_initial_accumulator_value', 0.1,
    'Starting value for the FTRL accumulators.')

tf.app.flags.DEFINE_float(
    'ftrl_l1', 0.0, 'The FTRL l1 regularization strength.')

tf.app.flags.DEFINE_float(
    'ftrl_l2', 0.0, 'The FTRL l2 regularization strength.')

tf.app.flags.DEFINE_float(
    'momentum', 0.9,
    'The momentum for the MomentumOptimizer and RMSPropOptimizer.')

tf.app.flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.')

#######################
# Learning Rate Flags #
#######################

tf.app.flags.DEFINE_string(
    'learning_rate_decay_type',
    'exponential',
    'Specifies how the learning rate is decayed. One of "fixed", "exponential",'
    ' or "polynomial"')

tf.app.flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')

tf.app.flags.DEFINE_float(
    'end_learning_rate', 0.0001,
    'The minimal end learning rate used by a polynomial decay learning rate.')

tf.app.flags.DEFINE_float(
    'label_smoothing', 0.0, 'The amount of label smoothing.')

tf.app.flags.DEFINE_float(
    'learning_rate_decay_factor', 0.94, 'Learning rate decay factor.')

tf.app.flags.DEFINE_float(
    'num_epochs_per_decay', 2.0,
    'Number of epochs after which learning rate decays.')

tf.app.flags.DEFINE_bool(
    'sync_replicas', False,
    'Whether or not to synchronize the replicas during training.')

tf.app.flags.DEFINE_integer(
    'replicas_to_aggregate', 1,
    'The number of gradients to collect before updating params.')

tf.app.flags.DEFINE_float(
    'moving_average_decay', None,
    'The decay to use for the moving average. '
    'If left as None, then moving averages are not used.')

#######################
# Dataset Flags #
#######################

tf.app.flags.DEFINE_string(
    'dataset_name', 'imagenet', 'The name of the dataset to load.')

tf.app.flags.DEFINE_string(
    'dataset_split_name', 'train', 'The name of the train/test split.')

tf.app.flags.DEFINE_string(
    'dataset_dir', None, 'The directory where the dataset files are stored.')

tf.app.flags.DEFINE_integer(
    'labels_offset', 0,
    'An offset for the labels in the dataset. This flag is primarily used to '
    'evaluate the VGG and ResNet architectures which do not use a background '
    'class for the ImageNet dataset.')

tf.app.flags.DEFINE_string(
    'model_name', 'inception_v3', 'The name of the architecture to train.')

tf.app.flags.DEFINE_string(
    'preprocessing_name', None, 'The name of the preprocessing to use. If left '
    'as `None`, then the model_name flag is used.')

tf.app.flags.DEFINE_integer(
    'batch_size', 32, 'The number of samples in each batch.')

tf.app.flags.DEFINE_integer(
    'train_image_size', None, 'Train image size')

tf.app.flags.DEFINE_integer('max_number_of_steps', None,
                            'The maximum number of training steps.')

#####################
# Fine-Tuning Flags #
#####################

tf.app.flags.DEFINE_string(
    'checkpoint_path', None,
    'The path to a checkpoint from which to fine-tune.')

tf.app.flags.DEFINE_string(
    'checkpoint_exclude_scopes', None,
    'Comma-separated list of scopes of variables to exclude when restoring '
    'from a checkpoint.')

tf.app.flags.DEFINE_string(
    'trainable_scopes', None,
    'Comma-separated list of scopes to filter the set of variables to train. '
    'By default, None would train all the variables.')

tf.app.flags.DEFINE_boolean(
    'ignore_missing_vars', False,
    'When restoring a checkpoint would ignore missing variables.')

FLAGS = tf.app.flags.FLAGS


def _configure_learning_rate(num_samples_per_epoch, global_step):
  """Configures the learning rate.

  Args:
    num_samples_per_epoch: The number of samples in each epoch of training.
    global_step: The global_step tensor.

  Returns:
    A `Tensor` representing the learning rate.

  Raises:
    ValueError: if FLAGS.learning_rate_decay_type is not recognized.
  """
  decay_steps = int(num_samples_per_epoch / FLAGS.batch_size *
                    FLAGS.num_epochs_per_decay)
  if FLAGS.sync_replicas:
    decay_steps /= FLAGS.replicas_to_aggregate

  # Below are the supported learning-rate schedules: exponential decay,
  # a fixed constant rate, or polynomial decay.
  if FLAGS.learning_rate_decay_type == 'exponential':
    return tf.train.exponential_decay(FLAGS.learning_rate,
                                      global_step,
                                      decay_steps,
                                      FLAGS.learning_rate_decay_factor,
                                      staircase=True,
                                      name='exponential_decay_learning_rate')
  elif FLAGS.learning_rate_decay_type == 'fixed':
    return tf.constant(FLAGS.learning_rate, name='fixed_learning_rate')
  elif FLAGS.learning_rate_decay_type == 'polynomial':
    return tf.train.polynomial_decay(FLAGS.learning_rate,
                                     global_step,
                                     decay_steps,
                                     FLAGS.end_learning_rate,
                                     power=1.0,
                                     cycle=False,
                                     name='polynomial_decay_learning_rate')
  else:
    raise ValueError('learning_rate_decay_type [%s] was not recognized',
                     FLAGS.learning_rate_decay_type)
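
# For reference: with staircase=True, exponential decay computes
#   lr = learning_rate * decay_factor ** floor(global_step / decay_steps),
# so with learning_rate=0.01, decay_factor=0.94 and decay_steps=1000
# (hypothetical values), step 2500 gives 0.01 * 0.94**2 ~= 0.0088.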


# Select the optimization method. TensorFlow ships many optimizers; they all
# implement some variant of gradient-based parameter updates (think gradient
# descent), and the numerical details can be treated as a black box here.
def _configure_optimizer(learning_rate):
  """Configures the optimizer used for training.

  Args:
    learning_rate: A scalar or `Tensor` learning rate.

  Returns:
    An instance of an optimizer.

  Raises:
    ValueError: if FLAGS.optimizer is not recognized.
  """
  if FLAGS.optimizer == 'adadelta':
    optimizer = tf.train.AdadeltaOptimizer(
        learning_rate,
        rho=FLAGS.adadelta_rho,
        epsilon=FLAGS.opt_epsilon)
  elif FLAGS.optimizer == 'adagrad':
    optimizer = tf.train.AdagradOptimizer(
        learning_rate,
        initial_accumulator_value=FLAGS.adagrad_initial_accumulator_value)
  elif FLAGS.optimizer == 'adam':
    optimizer = tf.train.AdamOptimizer(
        learning_rate,
        beta1=FLAGS.adam_beta1,
        beta2=FLAGS.adam_beta2,
        epsilon=FLAGS.opt_epsilon)
  elif FLAGS.optimizer == 'ftrl':
    optimizer = tf.train.FtrlOptimizer(
        learning_rate,
        learning_rate_power=FLAGS.ftrl_learning_rate_power,
        initial_accumulator_value=FLAGS.ftrl_initial_accumulator_value,
        l1_regularization_strength=FLAGS.ftrl_l1,
        l2_regularization_strength=FLAGS.ftrl_l2)
  elif FLAGS.optimizer == 'momentum':
    optimizer = tf.train.MomentumOptimizer(
        learning_rate,
        momentum=FLAGS.momentum,
        name='Momentum')
  elif FLAGS.optimizer == 'rmsprop':
    optimizer = tf.train.RMSPropOptimizer(
        learning_rate,
        decay=FLAGS.rmsprop_decay,
        momentum=FLAGS.momentum,
        epsilon=FLAGS.opt_epsilon)
  elif FLAGS.optimizer == 'sgd':
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  else:
    raise ValueError('Optimizer [%s] was not recognized', FLAGS.optimizer)
  return optimizer
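
# For intuition, the typical optimizer usage pattern looks like this
# (a standalone sketch with hypothetical `loss` and `global_step` tensors,
# not code from this script):
#   optimizer = tf.train.GradientDescentOptimizer(0.01)
#   train_op = optimizer.minimize(loss, global_step=global_step)
# Each session.run(train_op) applies one gradient step to the trainable
# variables; this script instead splits the step into optimize_clones()
# and apply_gradients() below.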


def _get_init_fn():
  """Returns a function run by the chief worker to warm-start the training.

  Note that the init_fn is only run when initializing the model during the very
  first global step.

  Returns:
    An init function run by the supervisor.
  """
  if FLAGS.checkpoint_path is None:
    return None

  # Warn the user if a checkpoint exists in the train_dir. Then we'll be
  # ignoring the checkpoint anyway.
  if tf.train.latest_checkpoint(FLAGS.train_dir):
    tf.logging.info(
        'Ignoring --checkpoint_path because a checkpoint already exists in %s'
        % FLAGS.train_dir)
    return None

  exclusions = []  # A plain Python list.
  if FLAGS.checkpoint_exclude_scopes:
    exclusions = [scope.strip()
                  for scope in FLAGS.checkpoint_exclude_scopes.split(',')]

  # TODO(sguada) variables.filter_variables()
  variables_to_restore = []
  for var in slim.get_model_variables():
    excluded = False
    for exclusion in exclusions:
      if var.op.name.startswith(exclusion):
        excluded = True
        break
    if not excluded:
      variables_to_restore.append(var)

  if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
    checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
  else:
    checkpoint_path = FLAGS.checkpoint_path

  tf.logging.info('Fine-tuning from %s' % checkpoint_path)

  return slim.assign_from_checkpoint_fn(
      checkpoint_path,
      variables_to_restore,
      ignore_missing_vars=FLAGS.ignore_missing_vars)


def _get_variables_to_train():
  """Returns a list of variables to train.

  Returns:
    A list of variables to train by the optimizer.
  """
  if FLAGS.trainable_scopes is None:
    return tf.trainable_variables()
  else:
    scopes = [scope.strip() for scope in FLAGS.trainable_scopes.split(',')]

  variables_to_train = []
  for scope in scopes:
    variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
    variables_to_train.extend(variables)
  return variables_to_train


def main(_):
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    #######################
    # Config model_deploy #
    #######################
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.worker_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks)

    # Create global_step
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    ######################
    # Select the network #
    ######################
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        weight_decay=FLAGS.weight_decay,
        is_training=True)

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=True)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    with tf.device(deploy_config.inputs_device()):
      provider = slim.dataset_data_provider.DatasetDataProvider(  # As before, this defines how the data is read.
          dataset,
          num_readers=FLAGS.num_readers,
          common_queue_capacity=20 * FLAGS.batch_size,
          common_queue_min=10 * FLAGS.batch_size)
      [image, label] = provider.get(['image', 'label'])
      label -= FLAGS.labels_offset

      train_image_size = FLAGS.train_image_size or network_fn.default_image_size

      image = image_preprocessing_fn(image, train_image_size, train_image_size)

      images, labels = tf.train.batch(
          [image, label],
          batch_size=FLAGS.batch_size,
          num_threads=FLAGS.num_preprocessing_threads,
          capacity=5 * FLAGS.batch_size)
      labels = slim.one_hot_encoding(  # One-hot is a vector encoding: in an n-dimensional vector, only the position matching the value is 1; all other entries are 0.
          labels, dataset.num_classes - FLAGS.labels_offset)
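      # e.g. with num_classes=4, label 2 becomes [0., 0., 1., 0.].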
      batch_queue = slim.prefetch_queue.prefetch_queue(
          [images, labels], capacity=2 * deploy_config.num_clones)

    ####################
    # Define the model #
    ####################
    # Parallelism comes from creating multiple copies of the network.
    def clone_fn(batch_queue):
      """Allows data parallelism by creating multiple clones of network_fn."""
      with tf.device(deploy_config.inputs_device()):
        images, labels = batch_queue.dequeue()  # dequeue pops one batch off the prefetch queue.
        logits, end_points = network_fn(images)

      #############################
      # Specify the loss function #
      #############################
      if 'AuxLogits' in end_points:
        tf.losses.softmax_cross_entropy(
            # The loss (cost) function paired with softmax is cross-entropy.
            logits=end_points['AuxLogits'], onehot_labels=labels,
            label_smoothing=FLAGS.label_smoothing, weights=0.4, scope='aux_loss')
      tf.losses.softmax_cross_entropy(
          logits=logits, onehot_labels=labels,
          label_smoothing=FLAGS.label_smoothing, weights=1.0)
      # Label smoothing regularization, a technique for reducing overfitting.
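      # With label_smoothing=s and num_classes=K, the one-hot targets are
      # softened to onehot * (1 - s) + s / K; e.g. s=0.1, K=10 turns each
      # "1" entry into 0.91 and each "0" entry into 0.01.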
      return end_points

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Add summaries for end_points.
    end_points = clones[0].outputs
    for end_point in end_points:
      x = end_points[end_point]
      summaries.add(tf.summary.histogram('activations/' + end_point, x))
      summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                      tf.nn.zero_fraction(x)))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Add summaries for variables.
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))
    # (And suddenly we're plotting everything.)

    #################################
    # Configure the moving averages #
    #################################
    # (See the Wikipedia article on moving averages.)
    if FLAGS.moving_average_decay:
      moving_average_variables = slim.get_model_variables()
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, global_step)
    else:
      moving_average_variables, variable_averages = None, None

    #########################################
    # Configure the optimization procedure. #
    #########################################
    with tf.device(deploy_config.optimizer_device()):
      learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
      optimizer = _configure_optimizer(learning_rate)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    # Synchronization when training on a distributed system...
    if FLAGS.sync_replicas:
      # If sync_replicas is enabled, the averaging will be done in the chief
      # queue runner.
      optimizer = tf.train.SyncReplicasOptimizer(
          opt=optimizer,
          replicas_to_aggregate=FLAGS.replicas_to_aggregate,
          variable_averages=variable_averages,
          variables_to_average=moving_average_variables,
          replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
          total_num_replicas=FLAGS.worker_replicas)
    elif FLAGS.moving_average_decay:
      # Update ops executed locally by trainer.
      update_ops.append(variable_averages.apply(moving_average_variables))

    # Variables to train.
    variables_to_train = _get_variables_to_train()

    # and returns a train_tensor and summary_op
    total_loss, clones_gradients = model_deploy.optimize_clones(
        clones,
        optimizer,
        var_list=variables_to_train)
    # Add total_loss to summary.
    summaries.add(tf.summary.scalar('total_loss', total_loss))

    # Create gradient updates.
    grad_updates = optimizer.apply_gradients(clones_gradients,
                                             global_step=global_step)
    update_ops.append(grad_updates)

    update_op = tf.group(*update_ops)
    with tf.control_dependencies([update_op]):
      train_tensor = tf.identity(total_loss, name='train_op')
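    # tf.control_dependencies forces update_op to run before train_tensor can
    # be evaluated, and tf.identity simply forwards total_loss; so fetching
    # 'train_op' in a session performs one full training step and returns the
    # current loss value.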

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')


    ###########################
    # Kicks off the training. #
    ###########################
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_dir,
        master=FLAGS.master,
        is_chief=(FLAGS.task == 0),
        init_fn=_get_init_fn(),
        summary_op=summary_op,
        number_of_steps=FLAGS.max_number_of_steps,
        log_every_n_steps=FLAGS.log_every_n_steps,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs,
        sync_optimizer=optimizer if FLAGS.sync_replicas else None)


if __name__ == '__main__':
  tf.app.run()