tensorflow slim 源码分析
py的源码看起来还是很愉快的……(虽然熟练程度完全不如cpp……)
datasets里是数据集相关
deployment是部署相关
nets里给了很多网络结构
preprocessing给了几种预处理的方式
这些都和slim没有太大关系,就不多废话了。
分析的部分见代码注释…
由于刚刚入门machine learning 一周…还有很多内容还没有从理论层面接触…所以源码的理解也十分有限…希望能以后有机会补充一波
1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 # ==============================================================================
r"""Downloads and converts a particular dataset.

Usage:
```shell

$ python download_and_convert_data.py \
    --dataset_name=mnist \
    --dataset_dir=/tmp/mnist

$ python download_and_convert_data.py \
    --dataset_name=cifar10 \
    --dataset_dir=/tmp/cifar10

$ python download_and_convert_data.py \
    --dataset_name=flowers \
    --dataset_dir=/tmp/flowers
```
"""
from __future__ import absolute_import #from __future__是为了解决python版本升级导致的兼容问题,没必要纠结
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from datasets import download_and_convert_cifar10
from datasets import download_and_convert_flowers
from datasets import download_and_convert_mnist
# FLAGS exposes the values of the command-line flags defined below.
FLAGS = tf.app.flags.FLAGS

# Each DEFINE_string call takes (flag name, default value, help text).
tf.app.flags.DEFINE_string(
    'dataset_name',
    None,
    'The name of the dataset to convert, one of "cifar10", "flowers", "mnist".')

tf.app.flags.DEFINE_string(
    'dataset_dir',
    None,
    'The directory where the output TFRecords and temporary files are saved.')
def main(_):
  """Downloads and converts the dataset selected by --dataset_name.

  Raises:
    ValueError: if --dataset_name or --dataset_dir is missing, or if the
      dataset name is not one of "cifar10", "flowers" or "mnist".
  """
  if not FLAGS.dataset_name:
    raise ValueError('You must supply the dataset name with --dataset_name')
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  # Map every supported dataset name to the module that converts it.
  converters = {
      'cifar10': download_and_convert_cifar10,
      'flowers': download_and_convert_flowers,
      'mnist': download_and_convert_mnist,
  }
  converter = converters.get(FLAGS.dataset_name)
  if converter is None:
    # The name is none of the three supported datasets.
    raise ValueError(
        'dataset_name [%s] was not recognized.' % FLAGS.dataset_name)
  converter.run(FLAGS.dataset_dir)


# The guard keeps the app from starting when this file is merely imported.
if __name__ == '__main__':
  tf.app.run()
# coding=utf-8
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generic evaluation script that evaluates a model using a given dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import tensorflow as tf
from datasets import dataset_factory
from nets import nets_factory
from preprocessing import preprocessing_factory
slim = tf.contrib.slim

# Command-line flags of the evaluation script. Each definition takes
# (flag name, default value, help text).
tf.app.flags.DEFINE_integer(
    'batch_size', 100, 'The number of samples in each batch.')

tf.app.flags.DEFINE_integer(
    'max_num_batches', None,
    'Max number of batches to evaluate by default use all.')

tf.app.flags.DEFINE_string(
    'master', '', 'The address of the TensorFlow master to use.')

tf.app.flags.DEFINE_string(
    'checkpoint_path', '/tmp/tfmodel/',
    'The directory where the model was written to or an absolute path to a '
    'checkpoint file.')

tf.app.flags.DEFINE_string(
    'eval_dir', '/tmp/tfmodel/', 'Directory where the results are saved to.')

tf.app.flags.DEFINE_integer(
    'num_preprocessing_threads', 4,
    'The number of threads used to create the batches.')

tf.app.flags.DEFINE_string(
    'dataset_name', 'imagenet', 'The name of the dataset to load.')

tf.app.flags.DEFINE_string(
    'dataset_split_name', 'test', 'The name of the train/test split.')

tf.app.flags.DEFINE_string(
    'dataset_dir', None, 'The directory where the dataset files are stored.')

tf.app.flags.DEFINE_integer(
    'labels_offset', 0,
    'An offset for the labels in the dataset. This flag is primarily used to '
    'evaluate the VGG and ResNet architectures which do not use a background '
    'class for the ImageNet dataset.')

tf.app.flags.DEFINE_string(
    'model_name', 'inception_v3', 'The name of the architecture to evaluate.')

tf.app.flags.DEFINE_string(
    'preprocessing_name', None, 'The name of the preprocessing to use. If left '
    'as `None`, then the model_name flag is used.')

# Adjacent string literals are concatenated verbatim, so the first piece must
# end with a space to keep the two sentences apart in the help output.
tf.app.flags.DEFINE_float(
    'moving_average_decay', None,
    'The decay to use for the moving average. '
    'If left as None, then moving averages are not used.')

tf.app.flags.DEFINE_integer(
    'eval_image_size', None, 'Eval image size')

# FLAGS exposes the values of the flags defined above.
FLAGS = tf.app.flags.FLAGS
def main(_):
  """Evaluates one checkpoint of the selected model on the selected dataset.

  Builds the evaluation graph (dataset provider, preprocessing, network and
  metrics) and hands it to slim.evaluation.evaluate_once.

  Raises:
    ValueError: if --dataset_dir is not supplied.
  """
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  # Set the logging verbosity; one of DEBUG, INFO, WARN, ERROR or FATAL.
  tf.logging.set_verbosity(tf.logging.INFO)
  # Build everything in a fresh graph that overrides the current default graph
  # for the lifetime of this context.
  # NOTE(review): the original annotation says this override is not
  # thread-safe -- confirm against the TensorFlow documentation.
  with tf.Graph().as_default():
    # Global step tensor of the default graph; compare
    # tf.train.get_or_create_global_step.
    tf_global_step = slim.get_or_create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    ####################
    # Select the model #
    ####################
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        is_training=False)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    # Read single (image, label) examples, unshuffled.
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        shuffle=False,
        common_queue_capacity=2 * FLAGS.batch_size,
        common_queue_min=FLAGS.batch_size)
    [image, label] = provider.get(['image', 'label'])
    label -= FLAGS.labels_offset

    #####################################
    # Select the preprocessing function #
    #####################################
    # `a or b` returns the first truthy operand itself (not a bool), so an
    # unset --preprocessing_name falls back to the model name.
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=False)

    # Same `or` fallback: an unset --eval_image_size uses the network default.
    eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size

    image = image_preprocessing_fn(image, eval_image_size, eval_image_size)

    images, labels = tf.train.batch(
        [image, label],
        batch_size=FLAGS.batch_size,
        num_threads=FLAGS.num_preprocessing_threads,
        capacity=5 * FLAGS.batch_size)

    ####################
    # Define the model #
    ####################
    # network_fn returns a pair; only the logits are needed here.
    logits, _ = network_fn(images)

    # Moving average: see https://en.wikipedia.org/wiki/Moving_average
    if FLAGS.moving_average_decay:
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, tf_global_step)
      variables_to_restore = variable_averages.variables_to_restore(
          slim.get_model_variables())
      variables_to_restore[tf_global_step.op.name] = tf_global_step
    else:
      variables_to_restore = slim.get_variables_to_restore()

    # tf.argmax returns the index of the largest value along the given axis;
    # axis 1 is presumably the class axis of logits.
    predictions = tf.argmax(logits, 1)
    # Drop the size-1 dimensions of labels.
    labels = tf.squeeze(labels)

    # Define the metrics:
    # aggregate_metric_map is a shorthand for declaring many metrics at once;
    # it yields one name->value map and one name->update-op map.
    names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
        'Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
        'Recall_5': slim.metrics.streaming_recall_at_k(
            logits, labels, 5),
    })

    # Print the summaries to screen.
    for name, value in names_to_values.items():
      summary_name = 'eval/%s' % name
      # tf.summary.scalar outputs a Summary protocol buffer containing a
      # single scalar value.
      op = tf.summary.scalar(summary_name, value, collections=[])
      op = tf.Print(op, [value], summary_name)
      tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)

    # TODO(sguada) use num_epochs=1
    if FLAGS.max_num_batches:
      num_batches = FLAGS.max_num_batches
    else:
      # This ensures that we make a single pass over all of the data.
      num_batches = math.ceil(dataset.num_samples / float(FLAGS.batch_size))

    # --checkpoint_path may be a directory (use its latest checkpoint) or a
    # checkpoint file.
    if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
      checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
    else:
      checkpoint_path = FLAGS.checkpoint_path

    tf.logging.info('Evaluating %s' % checkpoint_path)

    # Evaluates the model at the given checkpoint path.
    slim.evaluation.evaluate_once(
        master=FLAGS.master,
        checkpoint_path=checkpoint_path,
        logdir=FLAGS.eval_dir,
        num_evals=num_batches,
        eval_op=list(names_to_updates.values()),
        variables_to_restore=variables_to_restore)


if __name__ == '__main__':
  tf.app.run()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Saves out a GraphDef containing the architecture of the model.
To use it, run something like this, with a model name defined by slim:
bazel build tensorflow_models/slim:export_inference_graph
bazel-bin/tensorflow_models/slim/export_inference_graph \
--model_name=inception_v3 --output_file=/tmp/inception_v3_inf_graph.pb
If you then want to use the resulting model with your own or pretrained
checkpoints as part of a mobile model, you can run freeze_graph to get a graph
def with the variables inlined as constants using:
bazel build tensorflow/python/tools:freeze_graph
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=/tmp/inception_v3_inf_graph.pb \
--input_checkpoint=/tmp/checkpoints/inception_v3.ckpt \
--input_binary=true --output_graph=/tmp/frozen_inception_v3.pb \
--output_node_names=InceptionV3/Predictions/Reshape_1
The output node names will vary depending on the model, but you can inspect and
estimate them using the summarize_graph tool:
bazel build tensorflow/tools/graph_transforms:summarize_graph
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \
--in_graph=/tmp/inception_v3_inf_graph.pb
To run the resulting graph in C++, you can look at the label_image sample code:
bazel build tensorflow/examples/label_image:label_image
bazel-bin/tensorflow/examples/label_image/label_image \
--image=${HOME}/Pictures/flowers.jpg \
--input_layer=input \
--output_layer=InceptionV3/Predictions/Reshape_1 \
--graph=/tmp/frozen_inception_v3.pb \
--labels=/tmp/imagenet_slim_labels.txt \
--input_mean=0 \
--input_std=255
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.platform import gfile
from datasets import dataset_factory
from nets import nets_factory
slim = tf.contrib.slim

# Command-line flags of the export script. Each definition takes
# (flag name, default value, help text).
tf.app.flags.DEFINE_string(
    'model_name', 'inception_v3', 'The name of the architecture to save.')

tf.app.flags.DEFINE_boolean(
    'is_training', False,
    'Whether to save out a training-focused version of the model.')

tf.app.flags.DEFINE_integer(
    'default_image_size', 224,
    'The image size to use if the model does not define it.')

tf.app.flags.DEFINE_string('dataset_name', 'imagenet',
                           'The name of the dataset to use with the model.')

tf.app.flags.DEFINE_integer(
    'labels_offset', 0,
    'An offset for the labels in the dataset. This flag is primarily used to '
    'evaluate the VGG and ResNet architectures which do not use a background '
    'class for the ImageNet dataset.')

tf.app.flags.DEFINE_string(
    'output_file', '', 'Where to save the resulting file to.')

tf.app.flags.DEFINE_string(
    'dataset_dir', '', 'Directory to save intermediate dataset files to')

# FLAGS exposes the values of the flags defined above.
FLAGS = tf.app.flags.FLAGS
def main(_):
  """Builds the selected network and writes its GraphDef to --output_file.

  Raises:
    ValueError: if --output_file is not supplied.
  """
  if not FLAGS.output_file:
    raise ValueError('You must supply the path to save to with --output_file')

  tf.logging.set_verbosity(tf.logging.INFO)

  with tf.Graph().as_default() as graph:
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, 'train', FLAGS.dataset_dir)
    num_classes = dataset.num_classes - FLAGS.labels_offset
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=num_classes,
        is_training=FLAGS.is_training)

    # Use the network's own default_image_size attribute when it has one,
    # otherwise fall back to the --default_image_size flag.
    image_size = getattr(
        network_fn, 'default_image_size', FLAGS.default_image_size)

    # A single-image float32 placeholder named 'input' feeds the network.
    input_shape = [1, image_size, image_size, 3]
    network_input = tf.placeholder(
        name='input', dtype=tf.float32, shape=input_shape)
    network_fn(network_input)

    # as_graph_def() returns the GraphDef representation of this graph; its
    # serialized bytes are what gets written to the output file.
    graph_def = graph.as_graph_def()
    with gfile.GFile(FLAGS.output_file, 'wb') as output:
      output.write(graph_def.SerializeToString())


if __name__ == '__main__':
  tf.app.run()
1
2
3```python
4
5
6
7 # coding=utf-8
8 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
9 #
10 # Licensed under the Apache License, Version 2.0 (the "License");
11 # you may not use this file except in compliance with the License.
12 # You may obtain a copy of the License at
13 #
14 # http://www.apache.org/licenses/LICENSE-2.0
15 #
16 # Unless required by applicable law or agreed to in writing, software
17 # distributed under the License is distributed on an "AS IS" BASIS,
18 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19 # See the License for the specific language governing permissions and
20 # limitations under the License.
21 # ==============================================================================
22 """Generic training script that trains a model using a given dataset."""
23
24 from __future__ import absolute_import
25 from __future__ import division
26 from __future__ import print_function
27
28 import tensorflow as tf
29
30 from datasets import dataset_factory
31 from deployment import model_deploy
32 from nets import nets_factory
33 from preprocessing import preprocessing_factory
34
slim = tf.contrib.slim

# Command-line flags of the training script. Each definition takes
# (flag name, default value, help text). Help texts split over several
# literals are concatenated verbatim, so every piece but the last must end
# with a space.

tf.app.flags.DEFINE_string(
    'master', '', 'The address of the TensorFlow master to use.')

tf.app.flags.DEFINE_string(
    'train_dir', '/tmp/tfmodel/',
    'Directory where checkpoints and event logs are written to.')

tf.app.flags.DEFINE_integer('num_clones', 1,
                            'Number of model clones to deploy.')

tf.app.flags.DEFINE_boolean('clone_on_cpu', False,
                            'Use CPUs to deploy clones.')

tf.app.flags.DEFINE_integer('worker_replicas', 1, 'Number of worker replicas.')

tf.app.flags.DEFINE_integer(
    'num_ps_tasks', 0,
    'The number of parameter servers. If the value is 0, then the parameters '
    'are handled locally by the worker.')

tf.app.flags.DEFINE_integer(
    'num_readers', 4,
    'The number of parallel readers that read data from the dataset.')

tf.app.flags.DEFINE_integer(
    'num_preprocessing_threads', 4,
    'The number of threads used to create the batches.')

tf.app.flags.DEFINE_integer(
    'log_every_n_steps', 10,
    'The frequency with which logs are print.')

tf.app.flags.DEFINE_integer(
    'save_summaries_secs', 600,
    'The frequency with which summaries are saved, in seconds.')

tf.app.flags.DEFINE_integer(
    'save_interval_secs', 600,
    'The frequency with which the model is saved, in seconds.')

tf.app.flags.DEFINE_integer(
    'task', 0, 'Task id of the replica running the training.')

######################
# Optimization Flags #
######################

tf.app.flags.DEFINE_float(
    'weight_decay', 0.00004, 'The weight decay on the model weights.')

tf.app.flags.DEFINE_string(
    'optimizer', 'rmsprop',
    'The name of the optimizer, one of "adadelta", "adagrad", "adam", '
    '"ftrl", "momentum", "sgd" or "rmsprop".')

tf.app.flags.DEFINE_float(
    'adadelta_rho', 0.95,
    'The decay rate for adadelta.')

tf.app.flags.DEFINE_float(
    'adagrad_initial_accumulator_value', 0.1,
    'Starting value for the AdaGrad accumulators.')

tf.app.flags.DEFINE_float(
    'adam_beta1', 0.9,
    'The exponential decay rate for the 1st moment estimates.')

tf.app.flags.DEFINE_float(
    'adam_beta2', 0.999,
    'The exponential decay rate for the 2nd moment estimates.')

tf.app.flags.DEFINE_float('opt_epsilon', 1.0, 'Epsilon term for the optimizer.')

tf.app.flags.DEFINE_float('ftrl_learning_rate_power', -0.5,
                          'The learning rate power.')

tf.app.flags.DEFINE_float(
    'ftrl_initial_accumulator_value', 0.1,
    'Starting value for the FTRL accumulators.')

tf.app.flags.DEFINE_float(
    'ftrl_l1', 0.0, 'The FTRL l1 regularization strength.')

tf.app.flags.DEFINE_float(
    'ftrl_l2', 0.0, 'The FTRL l2 regularization strength.')

tf.app.flags.DEFINE_float(
    'momentum', 0.9,
    'The momentum for the MomentumOptimizer and RMSPropOptimizer.')

tf.app.flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.')

#######################
# Learning Rate Flags #
#######################

tf.app.flags.DEFINE_string(
    'learning_rate_decay_type',
    'exponential',
    'Specifies how the learning rate is decayed. One of "fixed", "exponential",'
    ' or "polynomial"')

tf.app.flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')

tf.app.flags.DEFINE_float(
    'end_learning_rate', 0.0001,
    'The minimal end learning rate used by a polynomial decay learning rate.')

tf.app.flags.DEFINE_float(
    'label_smoothing', 0.0, 'The amount of label smoothing.')

tf.app.flags.DEFINE_float(
    'learning_rate_decay_factor', 0.94, 'Learning rate decay factor.')

tf.app.flags.DEFINE_float(
    'num_epochs_per_decay', 2.0,
    'Number of epochs after which learning rate decays.')

tf.app.flags.DEFINE_bool(
    'sync_replicas', False,
    'Whether or not to synchronize the replicas during training.')

tf.app.flags.DEFINE_integer(
    'replicas_to_aggregate', 1,
    'The Number of gradients to collect before updating params.')

tf.app.flags.DEFINE_float(
    'moving_average_decay', None,
    'The decay to use for the moving average. '
    'If left as None, then moving averages are not used.')

#######################
# Dataset Flags #
#######################

tf.app.flags.DEFINE_string(
    'dataset_name', 'imagenet', 'The name of the dataset to load.')

tf.app.flags.DEFINE_string(
    'dataset_split_name', 'train', 'The name of the train/test split.')

tf.app.flags.DEFINE_string(
    'dataset_dir', None, 'The directory where the dataset files are stored.')

tf.app.flags.DEFINE_integer(
    'labels_offset', 0,
    'An offset for the labels in the dataset. This flag is primarily used to '
    'evaluate the VGG and ResNet architectures which do not use a background '
    'class for the ImageNet dataset.')

tf.app.flags.DEFINE_string(
    'model_name', 'inception_v3', 'The name of the architecture to train.')

tf.app.flags.DEFINE_string(
    'preprocessing_name', None, 'The name of the preprocessing to use. If left '
    'as `None`, then the model_name flag is used.')

tf.app.flags.DEFINE_integer(
    'batch_size', 32, 'The number of samples in each batch.')

tf.app.flags.DEFINE_integer(
    'train_image_size', None, 'Train image size')

tf.app.flags.DEFINE_integer('max_number_of_steps', None,
                            'The maximum number of training steps.')

#####################
# Fine-Tuning Flags #
#####################

tf.app.flags.DEFINE_string(
    'checkpoint_path', None,
    'The path to a checkpoint from which to fine-tune.')

tf.app.flags.DEFINE_string(
    'checkpoint_exclude_scopes', None,
    'Comma-separated list of scopes of variables to exclude when restoring '
    'from a checkpoint.')

tf.app.flags.DEFINE_string(
    'trainable_scopes', None,
    'Comma-separated list of scopes to filter the set of variables to train. '
    'By default, None would train all the variables.')

tf.app.flags.DEFINE_boolean(
    'ignore_missing_vars', False,
    'When restoring a checkpoint would ignore missing variables.')

# FLAGS exposes the values of the flags defined above.
FLAGS = tf.app.flags.FLAGS
226
227
def _configure_learning_rate(num_samples_per_epoch, global_step):
  """Configures the learning rate.

  Args:
    num_samples_per_epoch: The number of samples in each epoch of training.
    global_step: The global_step tensor.

  Returns:
    A `Tensor` representing the learning rate.

  Raises:
    ValueError: if FLAGS.learning_rate_decay_type is not one of
      "exponential", "fixed" or "polynomial".
  """
  # Steps between two decays: steps per epoch times epochs per decay.
  decay_steps = int(num_samples_per_epoch / FLAGS.batch_size *
                    FLAGS.num_epochs_per_decay)
  if FLAGS.sync_replicas:
    decay_steps /= FLAGS.replicas_to_aggregate

  # The learning rate decays exponentially, stays fixed, or decays
  # polynomially, depending on --learning_rate_decay_type.
  if FLAGS.learning_rate_decay_type == 'exponential':
    return tf.train.exponential_decay(FLAGS.learning_rate,
                                      global_step,
                                      decay_steps,
                                      FLAGS.learning_rate_decay_factor,
                                      staircase=True,
                                      name='exponential_decay_learning_rate')
  elif FLAGS.learning_rate_decay_type == 'fixed':
    return tf.constant(FLAGS.learning_rate, name='fixed_learning_rate')
  elif FLAGS.learning_rate_decay_type == 'polynomial':
    return tf.train.polynomial_decay(FLAGS.learning_rate,
                                     global_step,
                                     decay_steps,
                                     FLAGS.end_learning_rate,
                                     power=1.0,
                                     cycle=False,
                                     name='polynomial_decay_learning_rate')
  else:
    # Format with `%`: passing the value as a second argument would leave the
    # literal "%s" in the message.
    raise ValueError('learning_rate_decay_type [%s] was not recognized' %
                     FLAGS.learning_rate_decay_type)
269
270
271
def _configure_optimizer(learning_rate):
  """Configures the optimizer used for training.

  Selects one of the built-in tf.train optimizers according to
  --optimizer and passes it the matching hyper-parameter flags.

  Args:
    learning_rate: A scalar or `Tensor` learning rate.

  Returns:
    An instance of an optimizer.

  Raises:
    ValueError: if FLAGS.optimizer is not recognized.
  """
  if FLAGS.optimizer == 'adadelta':
    optimizer = tf.train.AdadeltaOptimizer(
        learning_rate,
        rho=FLAGS.adadelta_rho,
        epsilon=FLAGS.opt_epsilon)
  elif FLAGS.optimizer == 'adagrad':
    optimizer = tf.train.AdagradOptimizer(
        learning_rate,
        initial_accumulator_value=FLAGS.adagrad_initial_accumulator_value)
  elif FLAGS.optimizer == 'adam':
    optimizer = tf.train.AdamOptimizer(
        learning_rate,
        beta1=FLAGS.adam_beta1,
        beta2=FLAGS.adam_beta2,
        epsilon=FLAGS.opt_epsilon)
  elif FLAGS.optimizer == 'ftrl':
    optimizer = tf.train.FtrlOptimizer(
        learning_rate,
        learning_rate_power=FLAGS.ftrl_learning_rate_power,
        initial_accumulator_value=FLAGS.ftrl_initial_accumulator_value,
        l1_regularization_strength=FLAGS.ftrl_l1,
        l2_regularization_strength=FLAGS.ftrl_l2)
  elif FLAGS.optimizer == 'momentum':
    optimizer = tf.train.MomentumOptimizer(
        learning_rate,
        momentum=FLAGS.momentum,
        name='Momentum')
  elif FLAGS.optimizer == 'rmsprop':
    optimizer = tf.train.RMSPropOptimizer(
        learning_rate,
        decay=FLAGS.rmsprop_decay,
        momentum=FLAGS.momentum,
        epsilon=FLAGS.opt_epsilon)
  elif FLAGS.optimizer == 'sgd':
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  else:
    # Format with `%`: passing the value as a second argument would leave the
    # literal "%s" in the message.
    raise ValueError('Optimizer [%s] was not recognized' % FLAGS.optimizer)
  return optimizer
323
324
def _get_init_fn():
  """Returns a function run by the chief worker to warm-start the training.

  Note that the init_fn is only run when initializing the model during the very
  first global step.

  Returns:
    An init function run by the supervisor, or None when --checkpoint_path is
    unset or a checkpoint already exists in --train_dir.
  """
  if FLAGS.checkpoint_path is None:
    return None

  # A checkpoint already present in train_dir takes precedence, so the
  # fine-tuning checkpoint is ignored; tell the user.
  if tf.train.latest_checkpoint(FLAGS.train_dir):
    tf.logging.info(
        'Ignoring --checkpoint_path because a checkpoint already exists in %s'
        % FLAGS.train_dir)
    return None

  # Scope prefixes whose variables must not be restored.
  excluded_prefixes = ()
  if FLAGS.checkpoint_exclude_scopes:
    excluded_prefixes = tuple(
        name.strip() for name in FLAGS.checkpoint_exclude_scopes.split(','))

  # TODO(sguada) variables.filter_variables()
  # str.startswith accepts a tuple of prefixes; an empty tuple matches nothing.
  variables_to_restore = [
      model_var for model_var in slim.get_model_variables()
      if not model_var.op.name.startswith(excluded_prefixes)]

  # --checkpoint_path may name a directory (use its latest checkpoint) or a
  # checkpoint file.
  checkpoint_path = (
      tf.train.latest_checkpoint(FLAGS.checkpoint_path)
      if tf.gfile.IsDirectory(FLAGS.checkpoint_path)
      else FLAGS.checkpoint_path)

  tf.logging.info('Fine-tuning from %s' % checkpoint_path)

  return slim.assign_from_checkpoint_fn(
      checkpoint_path,
      variables_to_restore,
      ignore_missing_vars=FLAGS.ignore_missing_vars)
372
373
def _get_variables_to_train():
  """Returns a list of variables to train.

  With --trainable_scopes unset every trainable variable is returned;
  otherwise the trainable variables of each comma-separated scope are
  collected in the order the scopes are listed.

  Returns:
    A list of variables to train by the optimizer.
  """
  if FLAGS.trainable_scopes is None:
    return tf.trainable_variables()

  scope_names = [name.strip() for name in FLAGS.trainable_scopes.split(',')]

  selected = []
  for scope_name in scope_names:
    selected += tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope_name)
  return selected
390
391
def main(_):
  """Builds the training graph and runs slim.learning.train on it.

  Sets up deployment, the dataset provider, the network clones, summaries,
  optional moving averages and the optimizer before starting training.

  Raises:
    ValueError: if --dataset_dir is not supplied.
  """
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    #######################
    # Config model_deploy #
    #######################
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.worker_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks)

    # Create global_step
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    ######################
    # Select the network #
    ######################
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        weight_decay=FLAGS.weight_decay,
        is_training=True)

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=True)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    with tf.device(deploy_config.inputs_device()):
      # Configure how examples are read from the dataset.
      provider = slim.dataset_data_provider.DatasetDataProvider(
          dataset,
          num_readers=FLAGS.num_readers,
          common_queue_capacity=20 * FLAGS.batch_size,
          common_queue_min=10 * FLAGS.batch_size)
      [image, label] = provider.get(['image', 'label'])
      label -= FLAGS.labels_offset

      train_image_size = FLAGS.train_image_size or network_fn.default_image_size

      image = image_preprocessing_fn(image, train_image_size, train_image_size)

      images, labels = tf.train.batch(
          [image, label],
          batch_size=FLAGS.batch_size,
          num_threads=FLAGS.num_preprocessing_threads,
          capacity=5 * FLAGS.batch_size)
      # One-hot encoding: a vector that is 1 at the position of the label and
      # 0 everywhere else.
      labels = slim.one_hot_encoding(
          labels, dataset.num_classes - FLAGS.labels_offset)
      batch_queue = slim.prefetch_queue.prefetch_queue(
          [images, labels], capacity=2 * deploy_config.num_clones)

    ####################
    # Define the model #
    ####################
    # Parallelism is achieved by building several copies of the network.
    def clone_fn(batch_queue):
      """Allows data parallelism by creating multiple clones of network_fn."""
      # NOTE(review): indentation was lost in the source paste; only the
      # dequeue is assumed to sit under the inputs device -- confirm.
      with tf.device(deploy_config.inputs_device()):
        # dequeue takes the next (images, labels) batch off the prefetch queue.
        images, labels = batch_queue.dequeue()
      logits, end_points = network_fn(images)

      #############################
      # Specify the loss function #
      #############################
      # Softmax outputs are paired with the cross-entropy loss.
      if 'AuxLogits' in end_points:
        tf.losses.softmax_cross_entropy(
            logits=end_points['AuxLogits'], onehot_labels=labels,
            label_smoothing=FLAGS.label_smoothing, weights=0.4, scope='aux_loss')
      # Label smoothing is a regularization technique against overfitting.
      tf.losses.softmax_cross_entropy(
          logits=logits, onehot_labels=labels,
          label_smoothing=FLAGS.label_smoothing, weights=1.0)
      return end_points

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Add summaries for end_points.
    end_points = clones[0].outputs
    for end_point in end_points:
      x = end_points[end_point]
      summaries.add(tf.summary.histogram('activations/' + end_point, x))
      summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                      tf.nn.zero_fraction(x)))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Add summaries for variables.
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))

    #################################
    # Configure the moving averages #
    #################################
    # See https://en.wikipedia.org/wiki/Moving_average
    if FLAGS.moving_average_decay:
      moving_average_variables = slim.get_model_variables()
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, global_step)
    else:
      moving_average_variables, variable_averages = None, None

    #########################################
    # Configure the optimization procedure. #
    #########################################
    with tf.device(deploy_config.optimizer_device()):
      learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
      optimizer = _configure_optimizer(learning_rate)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    # Synchronization of replicas when training on a distributed system.
    if FLAGS.sync_replicas:
      # If sync_replicas is enabled, the averaging will be done in the chief
      # queue runner.
      optimizer = tf.train.SyncReplicasOptimizer(
          opt=optimizer,
          replicas_to_aggregate=FLAGS.replicas_to_aggregate,
          variable_averages=variable_averages,
          variables_to_average=moving_average_variables,
          replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
          total_num_replicas=FLAGS.worker_replicas)
    elif FLAGS.moving_average_decay:
      # Update ops executed locally by trainer.
      update_ops.append(variable_averages.apply(moving_average_variables))

    # Variables to train.
    variables_to_train = _get_variables_to_train()

    #  and returns a train_tensor and summary_op
    total_loss, clones_gradients = model_deploy.optimize_clones(
        clones,
        optimizer,
        var_list=variables_to_train)
    # Add total_loss to summary.
    summaries.add(tf.summary.scalar('total_loss', total_loss))

    # Create gradient updates.
    grad_updates = optimizer.apply_gradients(clones_gradients,
                                             global_step=global_step)
    update_ops.append(grad_updates)

    # Evaluating train_tensor runs every update op first, then yields the loss.
    update_op = tf.group(*update_ops)
    with tf.control_dependencies([update_op]):
      train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    ###########################
    # Kicks off the training. #
    ###########################
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_dir,
        master=FLAGS.master,
        is_chief=(FLAGS.task == 0),
        init_fn=_get_init_fn(),
        summary_op=summary_op,
        number_of_steps=FLAGS.max_number_of_steps,
        log_every_n_steps=FLAGS.log_every_n_steps,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs,
        sync_optimizer=optimizer if FLAGS.sync_replicas else None)


if __name__ == '__main__':
  tf.app.run()
