Merging models in TensorFlow
Stashing a backup here; one problem is still unsolved:
raise ValueError("GraphDef cannot be larger than 2GB.")
Let me jot down some thoughts. The script doesn't write a .meta file at the moment, and the blow-up is presumably because every variable gets loaded into the default graph.
That is, after I finish processing checkpoint 0 and move on to checkpoint 1, checkpoint 0's variables are still sitting in the graph... so it just keeps growing?
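A toy illustration of that suspicion (not from the script, just my guess in a few lines): tf.Variable always registers its ops in the current default graph, so a loop over checkpoints keeps inflating one and the same graph.

import tensorflow as tf

for i in range(3):  # stand-in for the per-checkpoint loop
    tf.Variable(0, name='v%d' % i)
# the node count keeps growing across iterations:
print(len(tf.get_default_graph().as_graph_def().node))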
I have two ideas at the moment. The first, inspired by the article TensorFlow极简教程:创建、保存和恢复机器学习模型 (a minimal tutorial on creating, saving and restoring models), is to use multiple savers, each told exactly which variables it handles, as in the sketch below (though this only seems meaningful if each checkpoint gets its own saver?).
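A sketch of idea one (untested; the variable names are made up): give each Saver an explicit var_list so it only touches its own checkpoint's variables.

import tensorflow as tf

w0 = tf.Variable(tf.zeros([4, 2]), name='out0/weights')  # hypothetical vars
w1 = tf.Variable(tf.zeros([4, 2]), name='out1/weights')
saver_0 = tf.train.Saver(var_list=[w0])  # only serializes/restores w0
saver_1 = tf.train.Saver(var_list=[w1])  # only serializes/restores w1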
The second idea: after each save, reset the graph back to a pristine state (i.e., clear all the variables out of it).
Idea two was a total failure:
it runs into `if self.stack[-1] is not default:` followed by `IndexError: list index out of range`.
The GitHub issue "reset_default_graph awkwardly breaks graph nesting" mentions exactly this: reset_default_graph was never designed to be used inside a graph context to clear variables. And TensorFlow's code is quite unfriendly here, giving no hint about the cause of the error:
For historical context, `tf.reset_default_graph()` was never designed to be used with `with g.as_default():` context managers. I think the proper fix here is to make `tf.reset_default_graph()` fail with an informative error message when used inside a `with g.as_default():` context. I think this could be done by checking that `ops._default_graph_stack` is empty before resetting.
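A minimal repro of the failure mode (my reconstruction; on the TF 1.x builds of that era the informative check proposed above did not exist yet):

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    tf.reset_default_graph()  # wipes the default-graph stack from inside the context
# exiting the `with` block then pops an empty stack:
#   if self.stack[-1] is not default: IndexError: list index out of range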
import sys, getopt
import argparse
import tensorflow as tf
import os
import shutil
import numpy as np

# fix the 'out' index in the variable name, not the log_ directory name
# e.g. fix_var_name(0, 'out0/weights') -> 'out1/weights'
def fix_var_name(id, var_name):
    prefix = var_name[0:3]
    if id < 10:
        suffix = var_name[4:]
    if id >= 10 and id < 100:
        suffix = var_name[5:]
    if id >= 100 and id < 1000:
        suffix = var_name[6:]
    if id >= 1000 and id < 10000:
        suffix = var_name[7:]
    ret = prefix + str(id + 1) + suffix
    print('id=%d var_name=%s prefix=%s suffix=%s ret=%s' % (id, var_name, prefix, suffix, ret))
    return ret

# only concat the fully connected layer
def merge_full_link_layer(checkpoint_list, dry_run=False):
    with tf.Session() as sess:
        log_num = len(checkpoint_list)  # an int in [0, 1000)
        print("log_num:%d" % log_num)
        for var_name, _ in tf.contrib.framework.list_variables('log_0'):
            if not var_name.startswith('out'):
                # not an FC variable: copy it over from log_0 unchanged
                var_tmp = tf.contrib.framework.load_variable('log_0', var_name)
                var = tf.Variable(var_tmp, name=var_name)
                continue
            print("var_name:%s" % var_name)
            for id in range(0, log_num):
                # need to change the string out0 -> out1, out2, out3 ... out15
                if id != 0:
                    var_name = fix_var_name(id - 1, var_name)
                checkpoint_dir = 'log_' + str(id)
                print('checkpoint_dir:%s' % checkpoint_dir)
                var_tmp = tf.contrib.framework.load_variable(checkpoint_dir, var_name)
                if 'weights' in var_name:
                    if 'Momentum' in var_name:
                        if id == 0:
                            mom_weights = var_tmp
                        else:
                            mom_weights = np.concatenate((mom_weights, var_tmp), axis=1)
                    else:
                        if id == 0:
                            weights = var_tmp
                        else:
                            weights = np.concatenate((weights, var_tmp), axis=1)
                else:
                    if 'Momentum' in var_name:
                        if id == 0:
                            mom_biases = var_tmp
                        else:
                            mom_biases = np.concatenate((mom_biases, var_tmp), axis=0)
                    else:
                        if id == 0:
                            biases = var_tmp
                        else:
                            biases = np.concatenate((biases, var_tmp), axis=0)
            if not dry_run:
                flag1 = 'weights' in var_name
                flag2 = 'Momentum' in var_name
                if flag1 and flag2:
                    mom_weights = tf.Variable(mom_weights, name='out/weights/Momentum')
                if flag1 and not flag2:
                    weights = tf.Variable(weights, name='out/weights')
                if not flag1 and flag2:
                    mom_biases = tf.Variable(mom_biases, name='out/biases/Momentum')
                if not flag1 and not flag2:
                    biases = tf.Variable(biases, name='out/biases')
        if not dry_run:
            print("writer running")
            #writer = tf.summary.FileWriter('./graphs', sess.graph)
            saver = tf.train.Saver()
            sess.run(tf.global_variables_initializer())  # must run before save, or saving uninitialized variables fails
            saver.save(sess, './final_16_out', write_meta_graph=False)
            #writer.close()

def merge_ckpt(checkpoint_dir, dry_run=False):
    merge_full_link_layer(checkpoint_dir, dry_run)  # pass dry_run through instead of hard-coding False

def get_dir():
    checkpoint_list = []
    dir_list = os.listdir('./')
    for line in dir_list:
        if line.startswith('log') and os.path.isdir(line):
            checkpoint_list.append(line)
    return checkpoint_list

def main():
    os.environ['CUDA_VISIBLE_DEVICES'] = ""
    checkpoint_dir = get_dir()
    #checkpoint_dir = ['log_0','log_1','log_2']
    print(checkpoint_dir)
    merge_ckpt(checkpoint_dir, dry_run=False)

if __name__ == '__main__':
    main()
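For intuition about the axes: every sub-model's FC layer shares the same input features but owns its own slice of outputs, so weight matrices (in_dim x out_dim) are stacked side by side and bias vectors end to end. A toy example with hypothetical shapes:

import numpy as np

# two hypothetical sub-models, each mapping 512 features to 10 classes
w0 = np.zeros((512, 10)); w1 = np.ones((512, 10))
b0 = np.zeros(10);        b1 = np.ones(10)
w = np.concatenate((w0, w1), axis=1)  # shape (512, 20): outputs side by side
b = np.concatenate((b0, b1), axis=0)  # shape (20,)
print(w.shape, b.shape)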
Well... I'll leave that for now; the data folks say it's already good enough. Below is the final version. It doesn't merge the Momentum slots, since they contribute nothing to validation.
import sys, getopt
import argparse
import tensorflow as tf
import os
import shutil
import numpy as np

# fix the 'out' index in the variable name, not the log_ directory name
def fix_var_name(id, var_name):
    prefix = var_name[0:3]
    if id < 10:
        suffix = var_name[4:]
    if id >= 10 and id < 100:
        suffix = var_name[5:]
    if id >= 100 and id < 1000:
        suffix = var_name[6:]
    if id >= 1000 and id < 10000:
        suffix = var_name[7:]
    ret = prefix + str(id + 1) + suffix
    print('id=%d var_name=%s prefix=%s suffix=%s ret=%s' % (id, var_name, prefix, suffix, ret))
    return ret

# only concat the fully connected layer
def merge_full_link_layer(checkpoint_list):
    with tf.Session() as sess:
        log_num = len(checkpoint_list)  # an int in [0, 1000)
        print("log_num:%d" % log_num)
        for var_name, _ in tf.contrib.framework.list_variables('log_0'):
            if not var_name.startswith('out'):
                # not an FC variable: copy it over from log_0 unchanged
                var_tmp = tf.contrib.framework.load_variable('log_0', var_name)
                var = tf.Variable(var_tmp, name=var_name)
                continue
            if 'Momentum' in var_name:
                # FC Momentum slots are useless for validation, skip them
                continue
            print("var_name:%s" % var_name)
            for id in range(0, log_num):
                # need to change the string out0 -> out1, out2, out3 ... out15
                if id != 0:
                    var_name = fix_var_name(id - 1, var_name)
                checkpoint_dir = 'log_' + str(id)
                print('checkpoint_dir:%s' % checkpoint_dir)
                var_tmp = tf.contrib.framework.load_variable(checkpoint_dir, var_name)
                if 'weights' in var_name:
                    if 'Momentum' in var_name:  # unreachable now that Momentum is skipped above
                        if id == 0:
                            mom_weights = var_tmp
                        else:
                            mom_weights = np.concatenate((mom_weights, var_tmp), axis=1)
                    else:
                        if id == 0:
                            weights = var_tmp
                        else:
                            weights = np.concatenate((weights, var_tmp), axis=1)
                else:
                    if 'Momentum' in var_name:  # unreachable as well
                        if id == 0:
                            mom_biases = var_tmp
                        else:
                            mom_biases = np.concatenate((mom_biases, var_tmp), axis=0)
                    else:
                        if id == 0:
                            biases = var_tmp
                        else:
                            biases = np.concatenate((biases, var_tmp), axis=0)
            flag1 = 'weights' in var_name
            flag2 = 'Momentum' in var_name  # always False here
            if flag1 and not flag2:
                weights = tf.Variable(weights, name='out/weights')
            if not flag1 and not flag2:
                biases = tf.Variable(biases, name='out/biases')
        print("writer running")
        #writer = tf.summary.FileWriter('./graphs', sess.graph)
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        saver.save(sess, './final_result', write_meta_graph=False)
        #writer.close()

def get_dir():
    checkpoint_list = []
    dir_list = os.listdir('./')
    for line in dir_list:
        if line.startswith('log') and os.path.isdir(line):
            checkpoint_list.append(line)
    return checkpoint_list

def main():
    os.environ['CUDA_VISIBLE_DEVICES'] = ""
    # get_dir returns all the log dirs under './'; their format is 'log_%d', e.g. log_0, log_1
    checkpoint_dir = get_dir()
    #checkpoint_dir = ['log_0','log_1','log_2']
    #checkpoint_dir = ['log_0','log_1','log_2','log_3','log_4','log_5','log_6','log_7','log_8','log_9','log_10','log_11']
    print(checkpoint_dir)
    merge_full_link_layer(checkpoint_dir)

if __name__ == '__main__':
    main()
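To sanity-check what actually landed in the merged file, TF 1.x ships an inspection helper; a quick look (hedged, but the tool does exist under tensorflow.python.tools):

# list every tensor stored in the merged checkpoint
from tensorflow.python.tools import inspect_checkpoint as chkp
chkp.print_tensors_in_checkpoint_file('./final_result', tensor_name='', all_tensors=True)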
Update 2017-08-22: the conv-layer Momentum variables are now dropped as well, and a hyperparameter (fc_prefix) was added.
import sys, getopt
import argparse
import tensorflow as tf
import os
import shutil
import numpy as np

tf.flags.DEFINE_string('fc_prefix', 'out',
                       """the prefix of the fully connected layer's scope name""")
FLAGS = tf.flags.FLAGS

# fix the fc_prefix index in the variable name, not the log_ directory name
def fix_var_name(id, var_name):
    plen = len(FLAGS.fc_prefix)  # renamed from `len`: shadowing the builtin raised UnboundLocalError
    prefix = var_name[0:plen]
    if id < 10:
        suffix = var_name[plen + 1:]
    if id >= 10 and id < 100:
        suffix = var_name[plen + 2:]
    if id >= 100 and id < 1000:
        suffix = var_name[plen + 3:]
    if id >= 1000 and id < 10000:
        suffix = var_name[plen + 4:]
    ret = prefix + str(id + 1) + suffix
    print('id=%d var_name=%s prefix=%s suffix=%s ret=%s' % (id, var_name, prefix, suffix, ret))
    return ret

# only concat the fully connected layer
def merge_full_link_layer(checkpoint_list):
    with tf.Session() as sess:
        log_num = len(checkpoint_list)  # an int in [0, 1000)
        print("log_num:%d" % log_num)
        for var_name, _ in tf.contrib.framework.list_variables('log_0'):
            if 'Momentum' in var_name:
                # skip all Momentum slots, conv layers included
                continue
            if not var_name.startswith(FLAGS.fc_prefix):
                # not an FC variable: copy it over from log_0 unchanged
                var_tmp = tf.contrib.framework.load_variable('log_0', var_name)
                var = tf.Variable(var_tmp, name=var_name)
                continue
            print("var_name:%s" % var_name)
            for id in range(0, log_num):
                # need to change the string out0 -> out1, out2, out3 ... out15
                if id != 0:
                    var_name = fix_var_name(id - 1, var_name)
                checkpoint_dir = 'log_' + str(id)
                print('checkpoint_dir:%s' % checkpoint_dir)
                var_tmp = tf.contrib.framework.load_variable(checkpoint_dir, var_name)
                if 'weights' in var_name:
                    if 'Momentum' in var_name:  # unreachable now that Momentum is skipped above
                        if id == 0:
                            mom_weights = var_tmp
                        else:
                            mom_weights = np.concatenate((mom_weights, var_tmp), axis=1)
                    else:
                        if id == 0:
                            weights = var_tmp
                        else:
                            weights = np.concatenate((weights, var_tmp), axis=1)
                else:
                    if 'Momentum' in var_name:  # unreachable as well
                        if id == 0:
                            mom_biases = var_tmp
                        else:
                            mom_biases = np.concatenate((mom_biases, var_tmp), axis=0)
                    else:
                        if id == 0:
                            biases = var_tmp
                        else:
                            biases = np.concatenate((biases, var_tmp), axis=0)
            flag1 = 'weights' in var_name
            flag2 = 'Momentum' in var_name  # always False here
            if flag1 and not flag2:
                weights = tf.Variable(weights, name='%s/weights' % FLAGS.fc_prefix)
            if not flag1 and not flag2:
                biases = tf.Variable(biases, name='%s/biases' % FLAGS.fc_prefix)
        print("writer running")
        #writer = tf.summary.FileWriter('./graphs', sess.graph)
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        saver.save(sess, './final_result', write_meta_graph=False)
        #writer.close()

def get_dir():
    checkpoint_list = []
    dir_list = os.listdir('./')
    for line in dir_list:
        if line.startswith('log') and os.path.isdir(line):
            checkpoint_list.append(line)
    return checkpoint_list

def main():
    os.environ['CUDA_VISIBLE_DEVICES'] = ""
    # get_dir returns all the log dirs under './'; their format is 'log_%d', e.g. log_0, log_1
    checkpoint_dir = get_dir()
    #checkpoint_dir = ['log_0','log_1','log_2']
    #checkpoint_dir = ['log_0','log_1','log_2','log_3','log_4','log_5','log_6','log_7','log_8','log_9','log_10','log_11']
    print(checkpoint_dir)
    merge_full_link_layer(checkpoint_dir)

if __name__ == '__main__':
    main()
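A last sanity check I would run afterwards (a sketch under assumptions: only two checkpoints log_0/log_1, the default fc_prefix='out'): the merged out/weights should equal the column-wise concatenation of the per-checkpoint pieces.

import numpy as np
import tensorflow as tf

merged = tf.contrib.framework.load_variable('./final_result', 'out/weights')
parts = [tf.contrib.framework.load_variable('log_%d' % i, 'out%d/weights' % i)
         for i in range(2)]  # hypothetical: log_0 and log_1 only
assert np.array_equal(merged, np.concatenate(parts, axis=1))
print('merged weights verified:', merged.shape)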