# coding:utf-8
"""Inspect a TensorFlow checkpoint: print every stored tensor name and value.

Adapted from:
  http://blog.csdn.net/u011961856/article/details/77064631
      (analysis of TensorFlow model checkpoint files)
  http://blog.csdn.net/u010698086/article/details/77916532
      (printing the information stored in a model checkpoint)
"""
import os

import tensorflow as tf
from tensorflow.python import pywrap_tensorflow

# Reference examples (TensorFlow 1.x API), kept for context only.
#
# Save a model:
#   v1 = tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
#   v2 = tf.Variable(tf.zeros([200]), name="v2")
#   v3 = tf.Variable(tf.zeros([100]), name="v3")
#   saver = tf.train.Saver()
#   with tf.Session() as sess:
#       init_op = tf.global_variables_initializer()
#       sess.run(init_op)
#       # saver.save(sess, "model.ckpt", global_step=1)
#       saver.save(sess, "./model.ckpt")
#
# Restore a model (restore() takes the checkpoint prefix, not the ".index" file):
#   with tf.Session() as sess:
#       saver.restore(sess, "./model.ckpt-10")

# Directory that holds the checkpoint. Built with os.path.join so the path is
# portable. The original ".\checkpoints\..." literal depended on unrecognized
# backslash escapes (a SyntaxWarning on Python 3.12+) and only resolved on Windows.
model_dir = os.path.join(
    ".",
    "checkpoints",
    "ped2_l_2_alpha_1_lp_1.0_adv_0.05_gdl_1.0_flow_2.0_channel_1_flow_2_his_6_numobject_5",
)
# Checkpoint prefix, without any ".index" / ".data-*" suffix.
checkpoint_path = os.path.join(model_dir, "model.ckpt-10")


def print_checkpoint(checkpoint_path, print_values=True):
    """Print the name (and optionally the value) of every tensor in a checkpoint.

    Args:
        checkpoint_path: Checkpoint prefix to read, e.g. ".../model.ckpt-10".
        print_values: If False, print only the tensor names.
    """
    reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
    # get_variable_to_shape_map() maps tensor name -> shape. Sort the names so
    # the output order is deterministic.
    for tensor_name in sorted(reader.get_variable_to_shape_map()):
        print("tensor_name: ", tensor_name)
        if print_values:
            print(reader.get_tensor(tensor_name))


if __name__ == "__main__":
    print_checkpoint(checkpoint_path)