tensorflow 合并模型

在这里存个备份,还有些问题没有解决。

raise ValueError("GraphDef cannot be larger than 2GB.")

记录一些思路好了。现在是没有生成.meta文件,爆掉应该是因为所有的变量都加载到了默认图里。

也就是说我处理完checkpoint 0 之后开始处理checkpoint1,但是checkpoint0的那些变量还是存在的...所以越来越多?

目前有两个想法。第一个想法是受《TensorFlow极简教程:创建、保存和恢复机器学习模型》启发,用多个saver,每个saver指定要处理的图(但是这样好像要每个checkpoint都用不同的saver才有意义?)

第二个想法是,每次save完变量之后,将图恢复成默认状态(可以把图中所有变量清空)。

想法二大失败:

会遇到在 `if self.stack[-1] is not default:` 处抛出 `IndexError: list index out of range` 的问题。

根据 reset_default_graph awkwardly breaks graph nesting        

中提到了:reset_default_graph本身就不是被设计成放在graph上下文中清空变量用的,而且tf的代码也写得很不友好,没有指明这个错误的原因。

For historical context, `tf.reset_default_graph()` was never designed to be used with `with g.as_default():` context managers. I think the proper fix here is to make `tf.reset_default_graph()` fail with an informative error message when used inside a `with g.as_default():` context. I think this could be done by checking that `ops._default_graph_stack` is empty before resetting.
1import sys, getopt
2import argparse
3import tensorflow as tf
4import os
5import shutil
6import numpy as np
 1# fix out ,not log_
 2def fix_var_name(id,var_name): 
 3    prefix = var_name[0:3]
 4    if id<10:
 5      suffix = var_name[4:]
 6    if id>=10 and id<100:
 7      suffix = var_name[5:]
 8    if id>=100 and id<1000:
 9      suffix = var_name[6:]
10    if id>=1000 and id<10000:
11      suffix = var_name[7:]
12    ret = prefix + str(id+1) + suffix
13    print('id=%d var_name=%s prefix=%s suffix=%s ret=%s' %(id,var_name,prefix,suffix,ret))
14    return ret
# Only the fully connected ('out*') layer is concatenated; every other
# variable is taken from log_0 unchanged.
def merge_full_link_layer(checkpoint_list, dry_run=False):
    """Merge the ``out*`` layer of several checkpoints into one checkpoint.

    Every variable of ``log_0`` whose name does not start with ``out`` is
    re-created unchanged. For each ``out*`` variable of ``log_0`` the
    matching variable is loaded from ``log_0``, ``log_1``, ... (the name is
    advanced with ``fix_var_name``) and the arrays are concatenated: along
    axis 1 when the name contains ``weights``, along axis 0 otherwise. The
    result is stored as ``out/weights``, ``out/biases`` or the matching
    ``.../Momentum`` name and saved to ``./final_16_out``.

    Args:
        checkpoint_list: Checkpoint directories. Only the length is used;
            the directories actually read are ``log_0`` .. ``log_<n-1>``.
        dry_run: When true, everything is still loaded and concatenated, but
            the merged ``out/*`` variables are not created and nothing is
            saved.

    Raises:
        ValueError: If ``checkpoint_list`` is empty.
    """
    log_num = len(checkpoint_list)
    if log_num == 0:
        # Previously this surfaced later as a NameError on an unset local.
        raise ValueError("checkpoint_list is empty, nothing to merge")
    load_variable = tf.contrib.framework.load_variable
    with tf.Session() as sess:
        print("log_num:%d" % log_num)
        for base_name, _ in tf.contrib.framework.list_variables('log_0'):
            if not base_name.startswith('out'):
                # NOTE(review): tf.Variable(ndarray) stores the array in the
                # graph; large checkpoints may hit the 2GB GraphDef limit -
                # confirm for your model size.
                tf.Variable(load_variable('log_0', base_name), name=base_name)
                continue
            print("var_name:%s" % base_name)
            pieces = []
            current_name = base_name
            for index in range(log_num):
                # The name is advanced cumulatively: out0 -> out1 -> out2 ...
                if index != 0:
                    current_name = fix_var_name(index - 1, current_name)
                checkpoint_dir = 'log_' + str(index)
                print('checkpoint_dir:%s' % checkpoint_dir)
                pieces.append(load_variable(checkpoint_dir, current_name))
            is_weights = 'weights' in base_name
            is_momentum = 'Momentum' in base_name
            if len(pieces) == 1:
                merged = pieces[0]
            else:
                # One concatenate call instead of re-copying the growing
                # array once per checkpoint.
                merged = np.concatenate(pieces, axis=1 if is_weights else 0)
            if not dry_run:
                merged_name = 'out/' + ('weights' if is_weights else 'biases')
                if is_momentum:
                    merged_name += '/Momentum'
                tf.Variable(merged, name=merged_name)
        if not dry_run:
            print("writer running")
            saver = tf.train.Saver()
            # Variables must be initialized before Saver.save can read them;
            # this call used to be commented out.
            sess.run(tf.global_variables_initializer())
            saver.save(sess, './final_16_out', write_meta_graph=False)
def merge_ckpt(checkpoint_dir, dry_run=False):
    """Merge the given checkpoint directories via ``merge_full_link_layer``.

    Args:
        checkpoint_dir: List of checkpoint directories to merge.
        dry_run: Forwarded to ``merge_full_link_layer``. It used to be
            ignored, with ``False`` always passed instead.
    """
    merge_full_link_layer(checkpoint_dir, dry_run)
def get_dir():
    """Return the ``log_<N>`` sub-directories of the current directory.

    Only directories named ``log_`` followed by decimal digits are returned,
    ordered by their numeric index. Other entries that merely start with
    ``log`` (for example ``logs``) and plain files are skipped, because the
    merge step reads ``log_0`` .. ``log_<n-1>`` based on the list length.

    Returns:
        List of directory names, e.g. ``['log_0', 'log_1', 'log_2']``.
    """
    marker = 'log_'
    indexed = []
    for entry in os.listdir('./'):
        if not entry.startswith(marker) or not os.path.isdir(entry):
            continue
        index_text = entry[len(marker):]
        # Accept plain ASCII digits only, so int() below cannot fail.
        if not index_text or any(ch not in '0123456789' for ch in index_text):
            continue
        indexed.append((int(index_text), entry))
    # os.listdir order is arbitrary; sort numerically so log_10 follows log_9.
    indexed.sort()
    return [entry for _, entry in indexed]
def main():
    """Discover the checkpoint directories in './' and merge them."""
    # Empty device list - presumably to keep this job off the GPUs.
    os.environ['CUDA_VISIBLE_DEVICES'] = ""
    # For a manual run, an explicit list such as ['log_0', 'log_1', 'log_2']
    # can be used here instead of the discovered one.
    log_dirs = get_dir()
    print(log_dirs)
    merge_ckpt(log_dirs, dry_run=False)
# Run the merge only when this file is executed as a script.
if __name__ == "__main__":
    main()

嘛。。先不管了。。。据数据那边说已经够用了。下面是最终版本,没有合并动量,因为对验证没有作用。

1import sys, getopt
2import argparse
3import tensorflow as tf
4import os
5import shutil
6import numpy as np
 1# fix out ,not log_
 2def fix_var_name(id,var_name): 
 3    prefix = var_name[0:3]
 4    if id<10:
 5      suffix = var_name[4:]
 6    if id>=10 and id<100:
 7      suffix = var_name[5:]
 8    if id>=100 and id<1000:
 9      suffix = var_name[6:]
10    if id>=1000 and id<10000:
11      suffix = var_name[7:]
12    ret = prefix + str(id+1) + suffix
13    print('id=%d var_name=%s prefix=%s suffix=%s ret=%s' %(id,var_name,prefix,suffix,ret))
14    return ret
# Only the fully connected ('out*') layer is concatenated; momentum slots of
# that layer are dropped.
def merge_full_link_layer(checkpoint_list):
    """Merge the ``out*`` layer of several checkpoints and save ``./final_result``.

    Every variable of ``log_0`` whose name does not start with ``out`` is
    re-created unchanged. ``out*`` variables whose name contains
    ``Momentum`` are skipped. For each remaining ``out*`` variable the
    matching variable is loaded from ``log_0``, ``log_1``, ... (the name is
    advanced with ``fix_var_name``) and the arrays are concatenated: along
    axis 1 when the name contains ``weights``, along axis 0 otherwise. The
    result is stored as ``out/weights`` or ``out/biases``.

    Args:
        checkpoint_list: Checkpoint directories. Only the length is used;
            the directories actually read are ``log_0`` .. ``log_<n-1>``.

    Raises:
        ValueError: If ``checkpoint_list`` is empty.
    """
    log_num = len(checkpoint_list)
    if log_num == 0:
        # Previously this surfaced later as a NameError on an unset local.
        raise ValueError("checkpoint_list is empty, nothing to merge")
    load_variable = tf.contrib.framework.load_variable
    with tf.Session() as sess:
        print("log_num:%d" % log_num)
        for base_name, _ in tf.contrib.framework.list_variables('log_0'):
            if not base_name.startswith('out'):
                tf.Variable(load_variable('log_0', base_name), name=base_name)
                continue
            if 'Momentum' in base_name:
                continue
            print("var_name:%s" % base_name)
            pieces = []
            current_name = base_name
            for index in range(log_num):
                # The name is advanced cumulatively: out0 -> out1 -> out2 ...
                if index != 0:
                    current_name = fix_var_name(index - 1, current_name)
                checkpoint_dir = 'log_' + str(index)
                print('checkpoint_dir:%s' % checkpoint_dir)
                pieces.append(load_variable(checkpoint_dir, current_name))
            # Momentum names never reach this point, so only the weights vs.
            # biases distinction is left (the old Momentum branches were dead).
            is_weights = 'weights' in base_name
            if len(pieces) == 1:
                merged = pieces[0]
            else:
                merged = np.concatenate(pieces, axis=1 if is_weights else 0)
            tf.Variable(merged, name='out/weights' if is_weights else 'out/biases')
        print("writer running")
        saver = tf.train.Saver()
        # Variables must be initialized before the saver can read them.
        sess.run(tf.global_variables_initializer())
        saver.save(sess, './final_result', write_meta_graph=False)
def get_dir():
    """Return the ``log_<N>`` sub-directories of the current directory.

    Only directories named ``log_`` followed by decimal digits are kept,
    ordered by their numeric index. Entries that merely start with ``log``
    (for example ``logs``) and plain files are skipped, because the merge
    step reads ``log_0`` .. ``log_<n-1>`` based on the list length.

    Returns:
        List of directory names, e.g. ``['log_0', 'log_1', 'log_2']``.
    """
    marker = 'log_'
    indexed = []
    for entry in os.listdir('./'):
        if not entry.startswith(marker) or not os.path.isdir(entry):
            continue
        index_text = entry[len(marker):]
        # Accept plain ASCII digits only, so int() below cannot fail.
        if not index_text or any(ch not in '0123456789' for ch in index_text):
            continue
        indexed.append((int(index_text), entry))
    # os.listdir order is arbitrary; sort numerically so log_10 follows log_9.
    indexed.sort()
    return [entry for _, entry in indexed]
def main():
    """Discover the checkpoint directories in './' and merge them."""
    # Empty device list - presumably to keep this job off the GPUs.
    os.environ['CUDA_VISIBLE_DEVICES'] = ""
    # get_dir() returns all log directories in './'; they are named
    # 'log_%d' (log_0, log_1, ...). An explicit list such as
    # ['log_0', 'log_1', 'log_2'] can be substituted for a manual run.
    found_dirs = get_dir()
    print(found_dirs)
    merge_full_link_layer(found_dirs)
# Run the merge only when this file is executed as a script.
if __name__ == "__main__":
    main()

20170822 update:去掉了卷积层的动量,添加了一些超参数。

1import sys, getopt
2import argparse
3import tensorflow as tf
4import os
5import shutil
6import numpy as np
# Command-line flag holding the name prefix of the fully connected output
# layer's variables.
tf.flags.DEFINE_string(
    'fc_prefix',
    'out',
    """the prefix of full_link_layer output name """)
# Handle through which the functions below read the flag values.
FLAGS = tf.flags.FLAGS
 2# fix out ,not log_
 3def fix_var_name(id,var_name):
 4    len = len(FLAGS.fc_prefix)
 5    prefix = var_name[0:len]
 6    if id<10:
 7      suffix = var_name[len+1:]
 8    if id>=10 and id<100:
 9      suffix = var_name[len+2:]
10    if id>=100 and id<1000:
11      suffix = var_name[len+3:]
12    if id>=1000 and id<10000:
13      suffix = var_name[len+4:]
14    ret = prefix + str(id+1) + suffix
15    print('id=%d var_name=%s prefix=%s suffix=%s ret=%s' %(id,var_name,prefix,suffix,ret))
16    return ret
# Only the fully connected layer (names starting with FLAGS.fc_prefix) is
# concatenated; all momentum slots are dropped.
def merge_full_link_layer(checkpoint_list):
    """Merge the fully connected layer of several checkpoints and save ``./final_result``.

    Variables of ``log_0`` whose name contains ``Momentum`` are skipped.
    Every other variable whose name does not start with ``FLAGS.fc_prefix``
    is re-created unchanged. For each remaining variable the matching
    variable is loaded from ``log_0``, ``log_1``, ... (the name is advanced
    with ``fix_var_name``) and the arrays are concatenated: along axis 1
    when the name contains ``weights``, along axis 0 otherwise. The result
    is stored as ``<fc_prefix>/weights`` or ``<fc_prefix>/biases``.

    Args:
        checkpoint_list: Checkpoint directories. Only the length is used;
            the directories actually read are ``log_0`` .. ``log_<n-1>``.

    Raises:
        ValueError: If ``checkpoint_list`` is empty.
    """
    log_num = len(checkpoint_list)
    if log_num == 0:
        # Previously this surfaced later as a NameError on an unset local.
        raise ValueError("checkpoint_list is empty, nothing to merge")
    load_variable = tf.contrib.framework.load_variable
    with tf.Session() as sess:
        print("log_num:%d" % log_num)
        for base_name, _ in tf.contrib.framework.list_variables('log_0'):
            if 'Momentum' in base_name:
                continue
            if not base_name.startswith(FLAGS.fc_prefix):
                tf.Variable(load_variable('log_0', base_name), name=base_name)
                continue
            print("var_name:%s" % base_name)
            pieces = []
            current_name = base_name
            for index in range(log_num):
                # The name is advanced cumulatively: out0 -> out1 -> out2 ...
                if index != 0:
                    current_name = fix_var_name(index - 1, current_name)
                checkpoint_dir = 'log_' + str(index)
                print('checkpoint_dir:%s' % checkpoint_dir)
                pieces.append(load_variable(checkpoint_dir, current_name))
            # Momentum names never reach this point, so only the weights vs.
            # biases distinction is left (the old Momentum branches were dead).
            is_weights = 'weights' in base_name
            if len(pieces) == 1:
                merged = pieces[0]
            else:
                merged = np.concatenate(pieces, axis=1 if is_weights else 0)
            if is_weights:
                tf.Variable(merged, name='%s/weights' % FLAGS.fc_prefix)
            else:
                tf.Variable(merged, name='%s/biases' % FLAGS.fc_prefix)
        print("writer running")
        saver = tf.train.Saver()
        # Variables must be initialized before the saver can read them.
        sess.run(tf.global_variables_initializer())
        saver.save(sess, './final_result', write_meta_graph=False)
def get_dir():
    """Return the ``log_<N>`` sub-directories of the current directory.

    Only directories named ``log_`` followed by decimal digits are kept,
    ordered by their numeric index. Entries that merely start with ``log``
    (for example ``logs``) and plain files are skipped, because the merge
    step reads ``log_0`` .. ``log_<n-1>`` based on the list length.

    Returns:
        List of directory names, e.g. ``['log_0', 'log_1', 'log_2']``.
    """
    marker = 'log_'
    indexed = []
    for entry in os.listdir('./'):
        if not entry.startswith(marker) or not os.path.isdir(entry):
            continue
        index_text = entry[len(marker):]
        # Accept plain ASCII digits only, so int() below cannot fail.
        if not index_text or any(ch not in '0123456789' for ch in index_text):
            continue
        indexed.append((int(index_text), entry))
    # os.listdir order is arbitrary; sort numerically so log_10 follows log_9.
    indexed.sort()
    return [entry for _, entry in indexed]
def main():
    """Discover the checkpoint directories in './' and merge them."""
    # Empty device list - presumably to keep this job off the GPUs.
    os.environ['CUDA_VISIBLE_DEVICES'] = ""
    # get_dir() returns all log directories in './'; they are named
    # 'log_%d' (log_0, log_1, ...). An explicit list such as
    # ['log_0', 'log_1', 'log_2'] can be substituted for a manual run.
    found_dirs = get_dir()
    print(found_dirs)
    merge_full_link_layer(found_dirs)
# Run the merge only when this file is executed as a script.
if __name__ == "__main__":
    main()