How to copy and modify a model in TensorFlow-Slim `nets`

想要复制并修改 TensorFlow-Slim 中 nets 里的某个模型,例如将新模型命名为 kk_v2.py。

先观察 train_image_classifier.py 中调用模型的部分:

# Build the model selected by --model_name. The number of output classes is
# the dataset's class count minus --labels_offset.
network_fn = nets_factory.get_network_fn(
    FLAGS.model_name,
    num_classes=(dataset.num_classes - FLAGS.labels_offset),
    weight_decay=FLAGS.weight_decay,
    is_training=True)

这里调用了 nets_factory.get_network_fn,get_network_fn 的定义如下:

def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False):
  """Returns a network_fn such as `logits, end_points = network_fn(images)`.

  `name` must be a key of both `networks_map` (which supplies the model
  function) and `arg_scopes_map` (which supplies the function that builds the
  `slim.arg_scope` wrapped around the model). Registering a new model such as
  'kk_v2' therefore requires an entry in both maps.

  Args:
    name: The name of the network.
    num_classes: The number of classes to use for classification.
    weight_decay: The l2 coefficient for the model weights.
    is_training: `True` if the model is being used for training and `False`
      otherwise.

  Returns:
    network_fn: A function that applies the model to a batch of images. It has
      the following signature:
        logits, end_points = network_fn(images)

  Raises:
    ValueError: If network `name` is not recognized, or has no entry in
      `arg_scopes_map`.
  """
  if name not in networks_map:
    raise ValueError('Name of network unknown %s' % name)
  # Validate the arg_scope registry here too: otherwise a missing entry only
  # shows up as a KeyError the first time network_fn is called.
  if name not in arg_scopes_map:
    raise ValueError('No arg_scope registered for network %s' % name)
  func = networks_map[name]

  @functools.wraps(func)
  def network_fn(images):
    arg_scope = arg_scopes_map[name](weight_decay=weight_decay)
    with slim.arg_scope(arg_scope):
      return func(images, num_classes, is_training=is_training)

  # Propagate the model's declared input size, when it has one.
  if hasattr(func, 'default_image_size'):
    network_fn.default_image_size = func.default_image_size
  return network_fn

可以看到,模型名称 name 是通过 networks_map 映射到对应的模型函数 func 的。

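因此需要为新模型 kk_v2 添加映射。下例中 kk_v2 暂时直接指向 inception.inception_resnet_v2;如果已经编写了自己的 kk_v2.py,则需要在 nets_factory.py 中导入它,并把 'kk_v2' 改为指向其中的模型函数。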

 1networks_map = {'alexnet_v2': alexnet.alexnet_v2,
 2                'cifarnet': cifarnet.cifarnet,
 3                'overfeat': overfeat.overfeat,
 4                'vgg_a': vgg.vgg_a,
 5                'vgg_16': vgg.vgg_16,
 6                'vgg_19': vgg.vgg_19,
 7                'inception_v1': inception.inception_v1,
 8                'inception_v2': inception.inception_v2,
 9                'inception_v3': inception.inception_v3,
10                'inception_v4': inception.inception_v4,
11                'inception_resnet_v2': inception.inception_resnet_v2,
12                'kk_v2':inception.inception_resnet_v2,
13                'lenet': lenet.lenet,
14                'resnet_v1_50': resnet_v1.resnet_v1_50,
15                'resnet_v1_101': resnet_v1.resnet_v1_101,
16                'resnet_v1_152': resnet_v1.resnet_v1_152,
17                'resnet_v1_200': resnet_v1.resnet_v1_200,
18                'resnet_v2_50': resnet_v2.resnet_v2_50,
19                'resnet_v2_101': resnet_v2.resnet_v2_101,
20                'resnet_v2_152': resnet_v2.resnet_v2_152,
21                'resnet_v2_200': resnet_v2.resnet_v2_200,
22                'mobilenet_v1': mobilenet_v1.mobilenet_v1,
23               }

由于 train_image_classifier.py 中定义了如下参数:

# Optional override; when left unset, the model_name flag is used instead.
tf.app.flags.DEFINE_string(
    'preprocessing_name', None, 'The name of the preprocessing to use. If left '
    'as `None`, then the model_name flag is used.')




# Prefer an explicit --preprocessing_name; otherwise fall back to --model_name,
# so a new model name is also looked up as a preprocessing name.
preprocessing_name = FLAGS.preprocessing_name
if not preprocessing_name:
  preprocessing_name = FLAGS.model_name

也就是说,未指定 --preprocessing_name 时会直接用 model_name(此处为 kk_v2)查找预处理函数,因此还需要在预处理映射表中添加 kk_v2:

 1 preprocessing_fn_map = {
 2      'cifarnet': cifarnet_preprocessing,
 3      'inception': inception_preprocessing,
 4      'inception_v1': inception_preprocessing,
 5      'inception_v2': inception_preprocessing,
 6      'inception_v3': inception_preprocessing,
 7      'inception_v4': inception_preprocessing,
 8      'inception_resnet_v2': inception_preprocessing,
 9      'kk_v2': inception_preprocessing,
10      'lenet': lenet_preprocessing,
11      'mobilenet_v1': inception_preprocessing,
12      'resnet_v1_50': vgg_preprocessing,
13      'resnet_v1_101': vgg_preprocessing,
14      'resnet_v1_152': vgg_preprocessing,
15      'resnet_v1_200': vgg_preprocessing,
16      'resnet_v2_50': vgg_preprocessing,
17      'resnet_v2_101': vgg_preprocessing,
18      'resnet_v2_152': vgg_preprocessing,
19      'resnet_v2_200': vgg_preprocessing,
20      'vgg': vgg_preprocessing,
21      'vgg_a': vgg_preprocessing,
22      'vgg_16': vgg_preprocessing,
23      'vgg_19': vgg_preprocessing,
24  }

还需要修改 arg_scopes_map,添加 kk_v2 这个 key(get_network_fn 通过 arg_scopes_map[name] 获取 arg_scope,缺少该 key 会报错):

 1arg_scopes_map = {'alexnet_v2': alexnet.alexnet_v2_arg_scope,
 2                  'cifarnet': cifarnet.cifarnet_arg_scope,
 3                  'overfeat': overfeat.overfeat_arg_scope,
 4                  'vgg_a': vgg.vgg_arg_scope,
 5                  'vgg_16': vgg.vgg_arg_scope,
 6                  'vgg_19': vgg.vgg_arg_scope,
 7                  'inception_v1': inception.inception_v3_arg_scope,
 8                  'inception_v2': inception.inception_v3_arg_scope,
 9                  'inception_v3': inception.inception_v3_arg_scope,
10                  'inception_v4': inception.inception_v4_arg_scope,
11                  'inception_resnet_v2':
12                  inception.inception_resnet_v2_arg_scope,
13                  'kk_v2':
14                  inception.inception_resnet_v2_arg_scope,
15                  'lenet': lenet.lenet_arg_scope,
16                  'resnet_v1_50': resnet_v1.resnet_arg_scope,
17                  'resnet_v1_101': resnet_v1.resnet_arg_scope,
18                  'resnet_v1_152': resnet_v1.resnet_arg_scope,
19                  'resnet_v1_200': resnet_v1.resnet_arg_scope,
20                  'resnet_v2_50': resnet_v2.resnet_arg_scope,
21                  'resnet_v2_101': resnet_v2.resnet_arg_scope,
22                  'resnet_v2_152': resnet_v2.resnet_arg_scope,
23                  'resnet_v2_200': resnet_v2.resnet_arg_scope,
24                  'mobilenet_v1': mobilenet_v1.mobilenet_v1_arg_scope,
25                 }