How to copy and modify a model in TensorFlow-Slim's nets
想要复制并修改 TensorFlow-Slim 中 nets 目录下的某个模型,例如将其命名为 kk_v2.py。
先看 train_image_classifier.py 中调用模型的部分:
# Build the model function selected by --model_name (e.g. 'kk_v2'); the name
# is resolved inside nets_factory.get_network_fn.
network_fn = nets_factory.get_network_fn(
    FLAGS.model_name,
    num_classes=(dataset.num_classes - FLAGS.labels_offset),
    weight_decay=FLAGS.weight_decay,
    is_training=True)
这里调用了 nets_factory.get_network_fn,get_network_fn 的定义如下:
def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False):
    """Returns a network_fn such as `logits, end_points = network_fn(images)`.

    The model builder is looked up in `networks_map` and its arg_scope factory
    in `arg_scopes_map`, so a newly added model name (e.g. 'kk_v2') needs an
    entry in both maps.

    Args:
      name: The name of the network.
      num_classes: The number of classes to use for classification.
      weight_decay: The l2 coefficient for the model weights.
      is_training: `True` if the model is being used for training and `False`
        otherwise.

    Returns:
      network_fn: A function that applies the model to a batch of images. It has
        the following signature:
          logits, end_points = network_fn(images)
        If the underlying model function has a `default_image_size` attribute,
        it is copied onto network_fn.

    Raises:
      ValueError: If network `name` is not recognized.
    """
    if name not in networks_map:
        raise ValueError('Name of network unknown %s' % name)
    func = networks_map[name]

    @functools.wraps(func)
    def network_fn(images):
        # arg_scopes_map is indexed lazily, at call time: a name present in
        # networks_map but missing here passes the check above and only
        # fails (with KeyError) when network_fn is first called.
        arg_scope = arg_scopes_map[name](weight_decay=weight_decay)
        with slim.arg_scope(arg_scope):
            return func(images, num_classes, is_training=is_training)

    # Expose the model's preferred input size to callers, when it declares one.
    if hasattr(func, 'default_image_size'):
        network_fn.default_image_size = func.default_image_size
    return network_fn
可以看到,model name 是通过 networks_map 映射到具体的模型函数(func)的,
因此需要为新模型 kk_v2 添加映射:
# Maps each --model_name value to the function that builds that network.
networks_map = {
    'alexnet_v2': alexnet.alexnet_v2,
    'cifarnet': cifarnet.cifarnet,
    'overfeat': overfeat.overfeat,
    'vgg_a': vgg.vgg_a,
    'vgg_16': vgg.vgg_16,
    'vgg_19': vgg.vgg_19,
    'inception_v1': inception.inception_v1,
    'inception_v2': inception.inception_v2,
    'inception_v3': inception.inception_v3,
    'inception_v4': inception.inception_v4,
    'inception_resnet_v2': inception.inception_resnet_v2,
    # NOTE(review): this entry still points at the stock
    # inception.inception_resnet_v2, so edits made in a copied kk_v2.py are
    # presumably not used. To train the modified copy, import the new module
    # (e.g. `from nets import kk_v2`) and point this entry at the builder
    # function defined there.
    'kk_v2': inception.inception_resnet_v2,
    'lenet': lenet.lenet,
    'resnet_v1_50': resnet_v1.resnet_v1_50,
    'resnet_v1_101': resnet_v1.resnet_v1_101,
    'resnet_v1_152': resnet_v1.resnet_v1_152,
    'resnet_v1_200': resnet_v1.resnet_v1_200,
    'resnet_v2_50': resnet_v2.resnet_v2_50,
    'resnet_v2_101': resnet_v2.resnet_v2_101,
    'resnet_v2_152': resnet_v2.resnet_v2_152,
    'resnet_v2_200': resnet_v2.resnet_v2_200,
    'mobilenet_v1': mobilenet_v1.mobilenet_v1,
}
由于 train_image_classifier.py 中有如下参数:
tf.app.flags.DEFINE_string(
    'preprocessing_name', None, 'The name of the preprocessing to use. If left '
    'as `None`, then the model_name flag is used.')

# Falls back to --model_name when --preprocessing_name is unset, which is why a
# new model name such as 'kk_v2' also needs an entry in the preprocessing map.
preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
即未指定 preprocessing_name 时会直接使用 model_name 作为预处理名称,因此还需要在预处理映射表 preprocessing_fn_map 中添加 kk_v2:
# Maps a preprocessing name to its preprocessing module. Because the training
# script falls back to the model name when --preprocessing_name is unset,
# every model name (including 'kk_v2') should also appear here.
preprocessing_fn_map = {
    'cifarnet': cifarnet_preprocessing,
    'inception': inception_preprocessing,
    'inception_v1': inception_preprocessing,
    'inception_v2': inception_preprocessing,
    'inception_v3': inception_preprocessing,
    'inception_v4': inception_preprocessing,
    'inception_resnet_v2': inception_preprocessing,
    # kk_v2 reuses the same preprocessing as inception_resnet_v2.
    'kk_v2': inception_preprocessing,
    'lenet': lenet_preprocessing,
    'mobilenet_v1': inception_preprocessing,
    'resnet_v1_50': vgg_preprocessing,
    'resnet_v1_101': vgg_preprocessing,
    'resnet_v1_152': vgg_preprocessing,
    'resnet_v1_200': vgg_preprocessing,
    'resnet_v2_50': vgg_preprocessing,
    'resnet_v2_101': vgg_preprocessing,
    'resnet_v2_152': vgg_preprocessing,
    'resnet_v2_200': vgg_preprocessing,
    'vgg': vgg_preprocessing,
    'vgg_a': vgg_preprocessing,
    'vgg_16': vgg_preprocessing,
    'vgg_19': vgg_preprocessing,
}
还要在 arg_scopes_map 中添加 kk_v2 这个 key(get_network_fn 通过 arg_scopes_map[name] 获取 arg_scope,缺少该 key 会在调用 network_fn 时抛出 KeyError):
# Maps each --model_name value to the factory for its default slim arg_scope;
# get_network_fn calls the factory with weight_decay=... when building the net.
arg_scopes_map = {
    'alexnet_v2': alexnet.alexnet_v2_arg_scope,
    'cifarnet': cifarnet.cifarnet_arg_scope,
    'overfeat': overfeat.overfeat_arg_scope,
    'vgg_a': vgg.vgg_arg_scope,
    'vgg_16': vgg.vgg_arg_scope,
    'vgg_19': vgg.vgg_arg_scope,
    'inception_v1': inception.inception_v3_arg_scope,
    'inception_v2': inception.inception_v3_arg_scope,
    'inception_v3': inception.inception_v3_arg_scope,
    'inception_v4': inception.inception_v4_arg_scope,
    'inception_resnet_v2': inception.inception_resnet_v2_arg_scope,
    # NOTE(review): if the copied kk_v2.py defines its own arg_scope function,
    # point this entry at that function instead of the stock one.
    'kk_v2': inception.inception_resnet_v2_arg_scope,
    'lenet': lenet.lenet_arg_scope,
    'resnet_v1_50': resnet_v1.resnet_arg_scope,
    'resnet_v1_101': resnet_v1.resnet_arg_scope,
    'resnet_v1_152': resnet_v1.resnet_arg_scope,
    'resnet_v1_200': resnet_v1.resnet_arg_scope,
    'resnet_v2_50': resnet_v2.resnet_arg_scope,
    'resnet_v2_101': resnet_v2.resnet_arg_scope,
    'resnet_v2_152': resnet_v2.resnet_arg_scope,
    'resnet_v2_200': resnet_v2.resnet_arg_scope,
    'mobilenet_v1': mobilenet_v1.mobilenet_v1_arg_scope,
}