Merging TensorFlow models

Keeping a backup here; there are still some unresolved problems, namely:

`raise ValueError("GraphDef cannot be larger than 2GB.")`

Let me jot down a few thoughts. Right now no .meta file is being written, and the blow-up is most likely because every variable ends up loaded into the default graph.

In other words, after I finish processing checkpoint 0 and move on to checkpoint 1, checkpoint 0's variables are still there… so the graph just keeps growing?
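
One way to check this theory (a minimal sketch of my own, not part of the scripts below) is to watch how large the serialized GraphDef gets as variables are created; the 2GB ValueError fires exactly when that protobuf exceeds the protocol-buffer size limit:

```python
import tensorflow as tf

# Hypothetical helper: print the current size of the default graph's
# serialized GraphDef. If it keeps growing while checkpoints are processed,
# the "variables accumulate in the default graph" theory is confirmed.
def report_graph_size(tag):
    graph_def = tf.get_default_graph().as_graph_def()
    print('%s: GraphDef size = %.1f MB' % (tag, graph_def.ByteSize() / 1e6))
```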

I have two ideas at the moment. The first, inspired by the article 《TensorFlow极简教程:创建、保存和恢复机器学习模型》 (a minimal TensorFlow tutorial on creating, saving and restoring models), is to use multiple savers, with each saver told exactly which graph/variables it is responsible for (although that only seems to make sense if every checkpoint gets its own saver?). A rough sketch of what that could look like follows.
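
This is only a sketch of idea one under my own naming (the function and variables here are hypothetical, not from the scripts below): hand a dedicated `tf.train.Saver` only the variables that belong to one part, via `var_list`:

```python
import tensorflow as tf

# Sketch of idea one: a dedicated Saver per checkpoint that only knows
# about that checkpoint's variables (names here are hypothetical).
def save_part(sess, part_vars, out_path):
    sess.run(tf.variables_initializer(part_vars))     # init just these variables
    saver = tf.train.Saver(var_list=part_vars)        # restrict the saver to them
    saver.save(sess, out_path, write_meta_graph=False)
```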

The second idea: after each save of the variables, reset the graph back to its default state (which would clear all the variables out of the graph).

Idea two turned out to be a complete failure:

It runs into `if self.stack[-1] is not default:` raising `IndexError: list index out of range`.

As noted in the GitHub issue "reset_default_graph awkwardly breaks graph nesting", `tf.reset_default_graph()` was simply never designed to be used inside a graph context to clear out variables, and TensorFlow's code is not exactly friendly about it either: the error message gives no hint at the real cause. From the issue:

For historical context, `tf.reset_default_graph()` was never designed to be used with `with g.as_default():` context managers. I think the proper fix here is to make `tf.reset_default_graph()` fail with an informative error message when used inside a `with g.as_default():` context. I think this could be done by checking that `ops._default_graph_stack` is empty before resetting.
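
The workaround I did not get around to trying: instead of resetting the default graph, give every checkpoint its own `tf.Graph` and `tf.Session`, so its variables simply disappear when the `with` block exits (a sketch under that assumption, not the script actually used below):

```python
import tensorflow as tf

# Sketch: one fresh Graph/Session per checkpoint, so variables created while
# handling checkpoint k never linger in the default graph for checkpoint k+1.
def copy_checkpoint(checkpoint_dir, out_prefix):
    with tf.Graph().as_default():
        with tf.Session() as sess:
            for name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
                value = tf.contrib.framework.load_variable(checkpoint_dir, name)
                tf.Variable(value, name=name)
            sess.run(tf.global_variables_initializer())
            tf.train.Saver().save(sess, out_prefix, write_meta_graph=False)
```

In any case, here is the first (still broken) version I'm backing up: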
```python
import sys, getopt
import argparse
import tensorflow as tf
import os
import shutil
import numpy as np

# Rewrite the variable-name prefix: 'out<k>/...' -> 'out<id+1>/...'
# (this fixes the 'out' prefix, not the 'log_' directory name).
def fix_var_name(id,var_name):
    prefix = var_name[0:3]
    if id<10:
      suffix = var_name[4:]
    if id>=10 and id<100:
      suffix = var_name[5:]
    if id>=100 and id<1000:
      suffix = var_name[6:]
    if id>=1000 and id<10000:
      suffix = var_name[7:]
    ret = prefix + str(id+1) + suffix
    print('id=%d var_name=%s prefix=%s suffix=%s ret=%s' %(id,var_name,prefix,suffix,ret))
    return ret

# Only the fully connected ('out') layer is concatenated;
# all other variables are copied from log_0 as they are.
def merge_full_link_layer(checkpoint_list,dry_run=False):
    with tf.Session() as sess:
      log_num = len(checkpoint_list)  # an int in [0, 1000)
      print("log_num:%d"%log_num)
      for var_name,_ in tf.contrib.framework.list_variables('log_0'):
        if not var_name.startswith('out'):
          var_tmp = tf.contrib.framework.load_variable('log_0',var_name)
          var = tf.Variable(var_tmp,name=var_name)
          continue
        print("var_name:%s"%var_name)
        for id in range(0,log_num):
          # rename the string: out0 -> out1, out2, out3 ... out15
          if id!=0:
            var_name = fix_var_name(id-1,var_name)
          checkpoint_dir = 'log_'+str(id)
          print('checkpoint_dir:%s'%checkpoint_dir)

       # for id,checkpoint_dir in enumerate(checkpoint_list):
       #  var_name = fix_var_name(id+1,var_name)
          var_tmp = tf.contrib.framework.load_variable(checkpoint_dir,var_name)
          #print("type(var_tmp):%s"%type(var_tmp))
       #   print(var_tmp)
          if 'weights' in var_name:
            if 'Momentum' in var_name:
              if id == 0:
                mom_weights = var_tmp
                #print("mom_weights:%s"%type(mom_weights))
              else:
                mom_weights = np.concatenate((mom_weights,var_tmp),axis=1)
            else:
              if id == 0:
                weights = var_tmp
              else:
                weights = np.concatenate((weights,var_tmp),axis=1)
          else:
            if 'Momentum' in var_name:
              if id == 0:
                mom_biases = var_tmp
              else:
                mom_biases = np.concatenate((mom_biases,var_tmp),axis=0)
            else:
              if id == 0:
                biases = var_tmp
              else:
                biases = np.concatenate((biases,var_tmp),axis=0)
        if not dry_run:
            flag1 = 'weights' in var_name
            flag2 = 'Momentum' in var_name
            if flag1 and flag2:
              mom_weights = tf.Variable(mom_weights, name='out/weights/Momentum' )
            if flag1 and not flag2:
              weights = tf.Variable(weights,name='out/weights')
            if not flag1 and flag2:
              mom_biases = tf.Variable(mom_biases,name='out/biases/Momentum')
            if not flag1 and not flag2:
              biases = tf.Variable(biases,name='out/biases')
      if not dry_run:
        print("writer running")
        #writer = tf.summary.FileWriter('./graphs', sess.graph)
        saver = tf.train.Saver()
        #sess.run(tf.global_variables_initializer())
        saver.save(sess,'./final_16_out',write_meta_graph=False)
    #writer.close()

def merge_ckpt(checkpoint_dir, dry_run=False):
  merge_full_link_layer(checkpoint_dir,False)

def get_dir():
  checkpoint_list=[]
  dir_list = os.listdir('./')
  for line in dir_list:
    if line.startswith('log') and os.path.isdir(line):
      checkpoint_list.append(line)
  return checkpoint_list

def main():
  os.environ['CUDA_VISIBLE_DEVICES']=""
  checkpoint_dir = get_dir()
  #checkpoint_dir = ['log_0','log_1','log_2']
  print (checkpoint_dir)
  merge_ckpt(checkpoint_dir, dry_run=False)

if __name__ == '__main__':
  main()
```

Well, I'll leave it at that for now; the data folks say the current result is already good enough. Below is the final version. It does not merge the momentum variables, since they are of no use for evaluation.

```python
import sys, getopt
import argparse
import tensorflow as tf
import os
import shutil
import numpy as np

# Rewrite the variable-name prefix: 'out<k>/...' -> 'out<id+1>/...'
# (this fixes the 'out' prefix, not the 'log_' directory name).
def fix_var_name(id,var_name):
    prefix = var_name[0:3]
    if id<10:
      suffix = var_name[4:]
    if id>=10 and id<100:
      suffix = var_name[5:]
    if id>=100 and id<1000:
      suffix = var_name[6:]
    if id>=1000 and id<10000:
      suffix = var_name[7:]
    ret = prefix + str(id+1) + suffix
    print('id=%d var_name=%s prefix=%s suffix=%s ret=%s' %(id,var_name,prefix,suffix,ret))
    return ret

# Only the fully connected ('out') layer is concatenated;
# all other variables are copied from log_0 as they are.
def merge_full_link_layer(checkpoint_list):
    with tf.Session() as sess:
      log_num = len(checkpoint_list)  # an int in [0, 1000)
      print("log_num:%d"%log_num)
      for var_name,_ in tf.contrib.framework.list_variables('log_0'):
        if not var_name.startswith('out'):
          var_tmp = tf.contrib.framework.load_variable('log_0',var_name)
          var = tf.Variable(var_tmp,name=var_name)
          continue
        if 'Momentum' in var_name:
          continue
        print("var_name:%s"%var_name)
        for id in range(0,log_num):
          # rename the string: out0 -> out1, out2, out3 ... out15
          if id!=0:
            var_name = fix_var_name(id-1,var_name)
          checkpoint_dir = 'log_'+str(id)
          print('checkpoint_dir:%s'%checkpoint_dir)
          var_tmp = tf.contrib.framework.load_variable(checkpoint_dir,var_name)
          if 'weights' in var_name:
            if 'Momentum' in var_name:
              if id == 0:
                mom_weights = var_tmp
              else:
                mom_weights = np.concatenate((mom_weights,var_tmp),axis=1)
            else:
              if id == 0:
                weights = var_tmp
              else:
                weights = np.concatenate((weights,var_tmp),axis=1)
          else:
            if 'Momentum' in var_name:
              if id == 0:
                mom_biases = var_tmp
              else:
                mom_biases = np.concatenate((mom_biases,var_tmp),axis=0)
            else:
              if id == 0:
                biases = var_tmp
              else:
                biases = np.concatenate((biases,var_tmp),axis=0)
        flag1 = 'weights' in var_name
        flag2 = 'Momentum' in var_name
        if flag1 and not flag2:
          weights = tf.Variable(weights,name='out/weights')
        if not flag1 and not flag2:
          biases = tf.Variable(biases,name='out/biases')
      print("writer running")
      #writer = tf.summary.FileWriter('./graphs', sess.graph)
      saver = tf.train.Saver()
      sess.run(tf.global_variables_initializer())
      saver.save(sess,'./final_result',write_meta_graph=False)
    #writer.close()

def get_dir():
  checkpoint_list=[]
  dir_list = os.listdir('./')
  for line in dir_list:
    if line.startswith('log') and os.path.isdir(line):
      checkpoint_list.append(line)
  return checkpoint_list

def main():
  os.environ['CUDA_VISIBLE_DEVICES']=""
  checkpoint_dir = get_dir()
  # get_dir returns all the log dirs in './'; the dir format is 'log_%d', e.g. log_0, log_1
  #checkpoint_dir=['log_0','log_1','log_2']
  #checkpoint_dir = ['log_0','log_1','log_2','log_3','log_4','log_5','log_6','log_7','log_8','log_9','log_10','log_11']
  print (checkpoint_dir)
  merge_full_link_layer(checkpoint_dir)

if __name__ == '__main__':
  main()
```
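
A sanity check I'd run on the result (a sketch, not part of the original script): read the merged checkpoint back and confirm that out/weights was concatenated along axis 1 and out/biases along axis 0, i.e. for N merged parts the second dimension of out/weights and the length of out/biases should both be N times the per-part output size.

```python
import tensorflow as tf

# Sketch: list every variable in the merged checkpoint with its shape.
# For N merged parts, out/weights should be [input_dim, N * per_part_outputs]
# and out/biases should be [N * per_part_outputs].
reader = tf.train.NewCheckpointReader('./final_result')
for name, shape in sorted(reader.get_variable_to_shape_map().items()):
    print('%s: %s' % (name, shape))
```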

2017-08-22 update: the convolutional layers' momentum variables are now dropped as well, and a few hyperparameters (flags) were added.

```python
import sys, getopt
import argparse
import tensorflow as tf
import os
import shutil
import numpy as np

tf.flags.DEFINE_string('fc_prefix', 'out',
                       """prefix of the fully connected output layer's variable names""")

FLAGS = tf.flags.FLAGS

# Rewrite the variable-name prefix: '<fc_prefix><k>/...' -> '<fc_prefix><id+1>/...'
# (this fixes the fc prefix, not the 'log_' directory name).
def fix_var_name(id,var_name):
    prefix_len = len(FLAGS.fc_prefix)   # do not shadow the built-in len()
    prefix = var_name[0:prefix_len]
    if id<10:
      suffix = var_name[prefix_len+1:]
    if id>=10 and id<100:
      suffix = var_name[prefix_len+2:]
    if id>=100 and id<1000:
      suffix = var_name[prefix_len+3:]
    if id>=1000 and id<10000:
      suffix = var_name[prefix_len+4:]
    ret = prefix + str(id+1) + suffix
    print('id=%d var_name=%s prefix=%s suffix=%s ret=%s' %(id,var_name,prefix,suffix,ret))
    return ret

# Only the fully connected (FLAGS.fc_prefix) layer is concatenated;
# all other non-Momentum variables are copied from log_0 as they are.
def merge_full_link_layer(checkpoint_list):
    with tf.Session() as sess:
      log_num = len(checkpoint_list)  # an int in [0, 1000)
      print("log_num:%d"%log_num)
      for var_name,_ in tf.contrib.framework.list_variables('log_0'):
        if 'Momentum' in var_name:
          continue
        if not var_name.startswith(FLAGS.fc_prefix):
          var_tmp = tf.contrib.framework.load_variable('log_0',var_name)
          var = tf.Variable(var_tmp,name=var_name)
          continue
        print("var_name:%s"%var_name)
        for id in range(0,log_num):
          # rename the string: out0 -> out1, out2, out3 ... out15
          if id!=0:
            var_name = fix_var_name(id-1,var_name)
          checkpoint_dir = 'log_'+str(id)
          print('checkpoint_dir:%s'%checkpoint_dir)
          var_tmp = tf.contrib.framework.load_variable(checkpoint_dir,var_name)
          if 'weights' in var_name:
            if 'Momentum' in var_name:
              if id == 0:
                mom_weights = var_tmp
              else:
                mom_weights = np.concatenate((mom_weights,var_tmp),axis=1)
            else:
              if id == 0:
                weights = var_tmp
              else:
                weights = np.concatenate((weights,var_tmp),axis=1)
          else:
            if 'Momentum' in var_name:
              if id == 0:
                mom_biases = var_tmp
              else:
                mom_biases = np.concatenate((mom_biases,var_tmp),axis=0)
            else:
              if id == 0:
                biases = var_tmp
              else:
                biases = np.concatenate((biases,var_tmp),axis=0)
        flag1 = 'weights' in var_name
        flag2 = 'Momentum' in var_name
        if flag1 and not flag2:
          weights = tf.Variable(weights,name='%s/weights'%FLAGS.fc_prefix)
        if not flag1 and not flag2:
          biases = tf.Variable(biases,name='%s/biases'%FLAGS.fc_prefix)
      print("writer running")
      #writer = tf.summary.FileWriter('./graphs', sess.graph)
      saver = tf.train.Saver()
      sess.run(tf.global_variables_initializer())
      saver.save(sess,'./final_result',write_meta_graph=False)
    #writer.close()

def get_dir():
  checkpoint_list=[]
  dir_list = os.listdir('./')
  for line in dir_list:
    if line.startswith('log') and os.path.isdir(line):
      checkpoint_list.append(line)
  return checkpoint_list

def main():
  os.environ['CUDA_VISIBLE_DEVICES']=""
  checkpoint_dir = get_dir()
  # get_dir returns all the log dirs in './'; the dir format is 'log_%d', e.g. log_0, log_1
  #checkpoint_dir=['log_0','log_1','log_2']
  #checkpoint_dir = ['log_0','log_1','log_2','log_3','log_4','log_5','log_6','log_7','log_8','log_9','log_10','log_11']
  print (checkpoint_dir)
  merge_full_link_layer(checkpoint_dir)

if __name__ == '__main__':
  main()
```
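
Assuming the script is saved as, say, merge_fc.py (the file name is just an example), it should be run from the directory that contains the log_* checkpoint folders, with the fc prefix overridable on the command line, e.g. `python merge_fc.py --fc_prefix=out`; `fc_prefix` defaults to `out`, matching the earlier versions.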