tensorflow 合并模型
在这里存个备份,还有些问题没有解决。
raise ValueError("GraphDef cannot be larger than 2GB.")
记录一些思路好了。现在是没有生成.meta文件,爆掉应该是因为所有的变量都加载到了默认图里。
也就是说我处理完checkpoint 0 之后开始处理checkpoint1,但是checkpoint0的那些变量还是存在的…所以越来越多?
目前有两个想法,第一个想法是受TensorFlow极简教程:创建、保存和恢复机器学习模型 中启发,用多个saver,每个saver指定要搞的图(但是这样好像要每个checkpoint都是不同的saver才有意义?)
第二个想法是,每次save完变量之后,将图恢复成默认状态(可以把图中所有变量清空。。
想法二大失败:
会遇到if self.stack[-1] is not default: │ IndexError: list index out of range 的问题。。
根据 reset_default_graph awkwardly breaks graph nesting
中提到了。。。reset_default_graph本身就不是被设计成放在graph中清空变量用的。。。然后tf的代码也写得很不友好。。。没有指明这个错误的原因。。。
For historical context, `tf.reset_default_graph()` was never designed to be used with `with g.as_default():` context managers. I think the proper fix here is to make `tf.reset_default_graph()` fail with an informative error message when used inside a `with g.as_default():` context. I think this could be done by checking that `ops._default_graph_stack` is empty before resetting.
1import sys, getopt
2import argparse
3import tensorflow as tf
4import os
5import shutil
6import numpy as np
7
# Renames the fc-layer 'out' prefix, not the 'log_' directory names.
def fix_var_name(id, var_name):
    """Rename a checkpoint variable 'out{id}...' to 'out{id+1}...'.

    E.g. fix_var_name(0, 'out0/weights') -> 'out1/weights'.

    Args:
        id: current numeric suffix after 'out' (any non-negative int; the
            original hard-coded digit ranges failed for id >= 10000).
            NOTE: parameter name shadows the builtin `id`; kept for
            backward compatibility with existing callers.
        var_name: variable name starting with 'out' followed by str(id).

    Returns:
        var_name with the numeric part incremented by one.
    """
    prefix = var_name[0:3]  # the literal 'out' prefix
    # Skip prefix plus however many digits `id` occupies — replaces the
    # original if-chain over digit ranges and works for any id >= 0.
    suffix = var_name[3 + len(str(id)):]
    ret = prefix + str(id + 1) + suffix
    print('id=%d var_name=%s prefix=%s suffix=%s ret=%s' % (id, var_name, prefix, suffix, ret))
    return ret
# only concat full_link_layer
def merge_full_link_layer(checkpoint_list,dry_run=False):
    """Merge the fully-connected 'out*' layer of several checkpoints.

    Variables whose names do not start with 'out' are copied once from
    'log_0' unchanged. For the 'out{i}/...' weights/biases (and their
    Momentum slots), the value is loaded from directory 'log_{i}' and
    concatenated: weights along axis=1, biases along axis=0. Unless
    dry_run, the merged arrays become single 'out/...' variables and the
    session is saved to './final_16_out' (write_meta_graph=False).

    NOTE(review): only len(checkpoint_list) is used; the directories are
    hard-coded as 'log_%d', so the list entries themselves are ignored.
    """
    with tf.Session() as sess:
        log_num = len(checkpoint_list)  # an int in range [0,1000)
        print("log_num:%d"%log_num)
        for var_name,_ in tf.contrib.framework.list_variables('log_0'):
            # Non-fc variables: materialize once from log_0 and move on.
            if not var_name.startswith('out'):
                var_tmp = tf.contrib.framework.load_variable('log_0',var_name)
                var = tf.Variable(var_tmp,name=var_name)
                continue
            print("var_name:%s"%var_name)
            for id in range(0,log_num):
                # Rewrite the numeric suffix out0 -> out1 -> ... step by
                # step; var_name is deliberately mutated across iterations.
                if id!=0:
                    var_name = fix_var_name(id-1,var_name)
                checkpoint_dir = 'log_'+str(id)
                print('checkpoint_dir:%s'%checkpoint_dir)

                var_tmp = tf.contrib.framework.load_variable(checkpoint_dir,var_name)
                if 'weights' in var_name:
                    if 'Momentum' in var_name:
                        if id == 0:
                            mom_weights = var_tmp
                        else:
                            # fc weights are joined column-wise (axis=1).
                            mom_weights = np.concatenate((mom_weights,var_tmp),axis=1)
                    else:
                        if id == 0:
                            weights = var_tmp
                        else:
                            weights = np.concatenate((weights,var_tmp),axis=1)
                else:
                    if 'Momentum' in var_name:
                        if id == 0:
                            mom_biases = var_tmp
                        else:
                            # biases are joined along axis=0.
                            mom_biases = np.concatenate((mom_biases,var_tmp),axis=0)

                    else:
                        if id == 0:
                            biases = var_tmp
                        else:
                            biases = np.concatenate((biases,var_tmp),axis=0)
            # var_name now names the last checkpoint's variable, but the
            # 'weights'/'Momentum' substrings are untouched by fix_var_name,
            # so these flags still classify the variable correctly.
            if not dry_run:
                flag1 = 'weights' in var_name
                flag2 = 'Momentum' in var_name
                if flag1 and flag2:
                    mom_weights = tf.Variable(mom_weights, name='out/weights/Momentum' )
                if flag1 and not flag2:
                    weights = tf.Variable(weights,name='out/weights')
                if not flag1 and flag2:
                    mom_biases = tf.Variable(mom_biases,name='out/biases/Momentum')
                if not flag1 and not flag2:
                    biases = tf.Variable(biases,name='out/biases')
        if not dry_run:
            print("writer running")
            #writer = tf.summary.FileWriter('./graphs', sess.graph)
            saver = tf.train.Saver()
            # NOTE(review): the initializer call is commented out, so the
            # freshly created tf.Variable objects may be uninitialized when
            # saved — confirm; later versions of this script do run it.
            #sess.run(tf.global_variables_initializer())
            saver.save(sess,'./final_16_out',write_meta_graph=False)
            #writer.close()
def merge_ckpt(checkpoint_dir, dry_run=False):
    """Thin wrapper around merge_full_link_layer.

    Bug fix: the original ignored its dry_run argument and always passed
    False; the flag is now forwarded to merge_full_link_layer.
    """
    merge_full_link_layer(checkpoint_dir, dry_run)
def get_dir():
    """Return every sub-directory of '.' whose name starts with 'log'."""
    return [entry for entry in os.listdir('./')
            if entry.startswith('log') and os.path.isdir(entry)]
def main():
    """Collect the local 'log_*' checkpoint dirs and merge their fc layers."""
    # CPU is sufficient for loading and re-saving checkpoint variables.
    os.environ['CUDA_VISIBLE_DEVICES'] = ""
    checkpoint_dirs = get_dir()
    print(checkpoint_dirs)
    merge_ckpt(checkpoint_dirs, dry_run=False)


if __name__ == '__main__':
    main()
嘛。。先不管了。。。据数据那边说已经够用了。下面是最终版本,没有合并动量,因为对验证没有作用。
1import sys, getopt
2import argparse
3import tensorflow as tf
4import os
5import shutil
6import numpy as np
7
# Renames the fc-layer 'out' prefix, not the 'log_' directory names.
def fix_var_name(id, var_name):
    """Rename a checkpoint variable 'out{id}...' to 'out{id+1}...'.

    E.g. fix_var_name(0, 'out0/weights') -> 'out1/weights'.

    Args:
        id: current numeric suffix after 'out' (any non-negative int; the
            original hard-coded digit ranges failed for id >= 10000).
            NOTE: parameter name shadows the builtin `id`; kept for
            backward compatibility with existing callers.
        var_name: variable name starting with 'out' followed by str(id).

    Returns:
        var_name with the numeric part incremented by one.
    """
    prefix = var_name[0:3]  # the literal 'out' prefix
    # Skip prefix plus however many digits `id` occupies — replaces the
    # original if-chain over digit ranges and works for any id >= 0.
    suffix = var_name[3 + len(str(id)):]
    ret = prefix + str(id + 1) + suffix
    print('id=%d var_name=%s prefix=%s suffix=%s ret=%s' % (id, var_name, prefix, suffix, ret))
    return ret
# only concat full_link_layer
def merge_full_link_layer(checkpoint_list):
    """Merge the fully-connected 'out*' layer of several checkpoints.

    Final version: Momentum slots of the fc layer are skipped entirely
    (useless for validation). Variables not starting with 'out' are copied
    once from 'log_0'; each 'out{i}/weights' ('out{i}/biases') is loaded
    from 'log_{i}' and concatenated along axis=1 (axis=0). The result is
    initialized and saved to './final_result' without a .meta file.

    NOTE(review): the 'not startswith out' copy runs before the Momentum
    skip, so non-fc Momentum variables from log_0 ARE still copied here.
    Only len(checkpoint_list) is used; directories are hard-coded 'log_%d'.
    """
    with tf.Session() as sess:
        log_num = len(checkpoint_list)  # an int in range [0,1000)
        print("log_num:%d"%log_num)
        for var_name,_ in tf.contrib.framework.list_variables('log_0'):
            # Non-fc variables: materialize once from log_0 and move on.
            if not var_name.startswith('out'):
                var_tmp = tf.contrib.framework.load_variable('log_0',var_name)
                var = tf.Variable(var_tmp,name=var_name)
                continue
            # fc Momentum slots are not merged in this version.
            if 'Momentum' in var_name:
                continue
            print("var_name:%s"%var_name)
            for id in range(0,log_num):
                # Rewrite the numeric suffix out0 -> out1 -> ... step by
                # step; var_name is deliberately mutated across iterations.
                if id!=0:
                    var_name = fix_var_name(id-1,var_name)
                checkpoint_dir = 'log_'+str(id)
                print('checkpoint_dir:%s'%checkpoint_dir)
                var_tmp = tf.contrib.framework.load_variable(checkpoint_dir,var_name)
                if 'weights' in var_name:
                    # The Momentum branches below are unreachable here
                    # (filtered above); kept as-is from the earlier version.
                    if 'Momentum' in var_name:
                        if id == 0:
                            mom_weights = var_tmp
                        else:
                            mom_weights = np.concatenate((mom_weights,var_tmp),axis=1)
                    else:
                        if id == 0:
                            weights = var_tmp
                        else:
                            # fc weights are joined column-wise (axis=1).
                            weights = np.concatenate((weights,var_tmp),axis=1)
                else:
                    if 'Momentum' in var_name:
                        if id == 0:
                            mom_biases = var_tmp
                        else:
                            mom_biases = np.concatenate((mom_biases,var_tmp),axis=0)
                    else:
                        if id == 0:
                            biases = var_tmp
                        else:
                            # biases are joined along axis=0.
                            biases = np.concatenate((biases,var_tmp),axis=0)
            # 'weights'/'Momentum' substrings survive fix_var_name, so these
            # flags still classify the (mutated) var_name correctly.
            flag1 = 'weights' in var_name
            flag2 = 'Momentum' in var_name
            if flag1 and not flag2:
                weights = tf.Variable(weights,name='out/weights')
            if not flag1 and not flag2:
                biases = tf.Variable(biases,name='out/biases')
        print("writer running")
        #writer = tf.summary.FileWriter('./graphs', sess.graph)
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        saver.save(sess,'./final_result',write_meta_graph=False)
        #writer.close()
def get_dir():
    """Return every sub-directory of '.' whose name starts with 'log'."""
    return [entry for entry in os.listdir('./')
            if entry.startswith('log') and os.path.isdir(entry)]
def main():
    """Collect the local 'log_*' checkpoint dirs and merge their fc layers."""
    # CPU is sufficient for loading and re-saving checkpoint variables.
    os.environ['CUDA_VISIBLE_DEVICES'] = ""
    # get_dir() returns all 'log_%d' sub-directories of '.', e.g. log_0, log_1.
    checkpoint_dirs = get_dir()
    print(checkpoint_dirs)
    merge_full_link_layer(checkpoint_dirs)


if __name__ == '__main__':
    main()
20170822update:去掉了卷积层的动量,添加了一些超参
1import sys, getopt
2import argparse
3import tensorflow as tf
4import os
5import shutil
6import numpy as np
7
8
9
10
# Hyper-parameter: name prefix of the fully-connected output variables in
# the checkpoints ('out' matches out0, out1, ...). Override with --fc_prefix.
tf.flags.DEFINE_string('fc_prefix', 'out',
                       """the prefix of full_link_layer output name """)


FLAGS = tf.flags.FLAGS
# Renames the fc-layer prefix, not the 'log_' directory names.
def fix_var_name(id, var_name, prefix=None):
    """Rename a checkpoint variable '{prefix}{id}...' to '{prefix}{id+1}...'.

    Bug fix: the original wrote `len = len(FLAGS.fc_prefix)`, which makes
    `len` a local variable and raises UnboundLocalError on the right-hand
    call before it is ever assigned. The builtin is no longer shadowed.

    Args:
        id: current numeric suffix after the prefix (any non-negative int;
            the original hard-coded digit ranges also failed for
            id >= 10000). NOTE: parameter name shadows the builtin `id`;
            kept for backward compatibility with existing callers.
        var_name: variable name assumed to start with the prefix followed
            by str(id).
        prefix: fc-layer prefix; defaults to FLAGS.fc_prefix when None,
            preserving the original behavior.

    Returns:
        var_name with the numeric part incremented by one.
    """
    if prefix is None:
        prefix = FLAGS.fc_prefix
    # Skip prefix plus however many digits `id` occupies — replaces the
    # original if-chain over digit ranges and works for any id >= 0.
    suffix = var_name[len(prefix) + len(str(id)):]
    ret = prefix + str(id + 1) + suffix
    print('id=%d var_name=%s prefix=%s suffix=%s ret=%s' % (id, var_name, prefix, suffix, ret))
    return ret
# only concat full_link_layer
def merge_full_link_layer(checkpoint_list):
    """Merge the fully-connected '{FLAGS.fc_prefix}*' layer of checkpoints.

    20170822 version: ALL Momentum slots (including convolutional layers)
    are skipped — the Momentum check now runs before the copy branch, and
    the fc prefix is configurable via the fc_prefix flag. Non-prefix
    variables are copied once from 'log_0'; each '{prefix}{i}/weights'
    ('/biases') is loaded from 'log_{i}' and concatenated along axis=1
    (axis=0). The result is initialized and saved to './final_result'.

    NOTE(review): only len(checkpoint_list) is used; directories are
    hard-coded as 'log_%d'.
    """
    with tf.Session() as sess:
        log_num = len(checkpoint_list)  # an int in range [0,1000)
        print("log_num:%d"%log_num)
        for var_name,_ in tf.contrib.framework.list_variables('log_0'):
            # Drop every optimizer slot, conv layers included.
            if 'Momentum' in var_name:
                continue
            # Non-fc variables: materialize once from log_0 and move on.
            if not var_name.startswith(FLAGS.fc_prefix):
                var_tmp = tf.contrib.framework.load_variable('log_0',var_name)
                var = tf.Variable(var_tmp,name=var_name)
                continue
            print("var_name:%s"%var_name)
            for id in range(0,log_num):
                # Rewrite the numeric suffix out0 -> out1 -> ... step by
                # step; var_name is deliberately mutated across iterations.
                if id!=0:
                    var_name = fix_var_name(id-1,var_name)
                checkpoint_dir = 'log_'+str(id)
                print('checkpoint_dir:%s'%checkpoint_dir)
                var_tmp = tf.contrib.framework.load_variable(checkpoint_dir,var_name)
                if 'weights' in var_name:
                    # The Momentum branches below are unreachable here
                    # (filtered above); kept as-is from the earlier version.
                    if 'Momentum' in var_name:
                        if id == 0:
                            mom_weights = var_tmp
                        else:
                            mom_weights = np.concatenate((mom_weights,var_tmp),axis=1)
                    else:
                        if id == 0:
                            weights = var_tmp
                        else:
                            # fc weights are joined column-wise (axis=1).
                            weights = np.concatenate((weights,var_tmp),axis=1)
                else:
                    if 'Momentum' in var_name:
                        if id == 0:
                            mom_biases = var_tmp
                        else:
                            mom_biases = np.concatenate((mom_biases,var_tmp),axis=0)
                    else:
                        if id == 0:
                            biases = var_tmp
                        else:
                            # biases are joined along axis=0.
                            biases = np.concatenate((biases,var_tmp),axis=0)
            # 'weights'/'Momentum' substrings survive fix_var_name, so these
            # flags still classify the (mutated) var_name correctly.
            flag1 = 'weights' in var_name
            flag2 = 'Momentum' in var_name
            if flag1 and not flag2:
                weights = tf.Variable(weights,name='%s/weights'%FLAGS.fc_prefix)
            if not flag1 and not flag2:
                biases = tf.Variable(biases,name='%s/biases'%FLAGS.fc_prefix)
        print("writer running")
        #writer = tf.summary.FileWriter('./graphs', sess.graph)
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        saver.save(sess,'./final_result',write_meta_graph=False)
        #writer.close()
def get_dir():
    """Return every sub-directory of '.' whose name starts with 'log'."""
    return [entry for entry in os.listdir('./')
            if entry.startswith('log') and os.path.isdir(entry)]
def main():
    """Collect the local 'log_*' checkpoint dirs and merge their fc layers."""
    # CPU is sufficient for loading and re-saving checkpoint variables.
    os.environ['CUDA_VISIBLE_DEVICES'] = ""
    # get_dir() returns all 'log_%d' sub-directories of '.', e.g. log_0, log_1.
    checkpoint_dirs = get_dir()
    print(checkpoint_dirs)
    merge_full_link_layer(checkpoint_dirs)


if __name__ == '__main__':
    main()