发新帖

持久化模型后再调用的一个问题

[复制链接]
813 1

快来加入 TensorFlowers 大家庭!

您需要 登录 才可以下载或查看,没有帐号?加入社区

x
保存模型代码:
  1. import tensorflow as tf
  2. from tensorflow.python.framework import graph_util

  3. v1=tf.constant([10000.0],name='v1')
  4. #v1 = tf.placeholder(tf.float32,shape=[1],name='v1')
  5. v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "v2")
  6. result = v1 + v2

  7. init_op = tf.global_variables_initializer()
  8. with tf.Session() as sess:
  9.     sess.run(init_op,{v1:[100]})
  10.     print (sess.run(result,{v1:[1000]}))
  11.     writer = tf.summary.FileWriter('./graphs/model_graph', sess.graph)
  12.     writer.close()   
  13.     graph_def = tf.get_default_graph().as_graph_def()
  14.     output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def,['add'])
  15.     with tf.gfile.GFile("Saved_model/combined_model.pb", "wb") as f:
  16.         f.write(output_graph_def.SerializeToString())
复制代码

加载模型并使用的代码:
  1. import tensorflow as tf
  2. import numpy as np
  3. from numpy.random import RandomState
  4. from tensorflow.python.platform import gfile
  5. with tf.Session() as sess:
  6.     model_filename = "Saved_model/combined_model.pb"  
  7.     #model_filename = "inception_dec_2015/tensorflow_inception_graph.pb"  
  8.    
  9.     with gfile.FastGFile(model_filename, 'rb') as f:
  10.         graph_def = tf.GraphDef()
  11.         graph_def.ParseFromString(f.read())
  12.     """
  13.     f = open("xiaojie2.txt", "w")
  14.     print ("xiaojie2\n",file = f)
  15.     print (graph_def,file=f) #输出加载参数及变量的信息
  16.     f.close()
  17.     writer = tf.summary.FileWriter('./graphs/model_graph2', graph_def)
  18.     writer.close()
  19.     """
  20.     #result2 = tf.import_graph_def(graph_def, return_elements=["v1:0"])
  21.     v1= tf.import_graph_def(graph_def, return_elements=["v1:0"])
  22.     print (sess.run(v1))
  23.     v2= tf.import_graph_def(graph_def, return_elements=["v2:0"])
  24.     print (sess.run(v2))
  25.     result = tf.import_graph_def(graph_def, return_elements=["add:0"])
  26.     print (sess.run(result))
  27.     x=np.array([2000.0])
  28.     print (sess.run(result,feed_dict={v1: x}))
  29.    
  30.    
复制代码

最后,总是报错:
  1. [array([10000.], dtype=float32)]
  2. [array([2.], dtype=float32)]
  3. [array([10002.], dtype=float32)]
  4. ---------------------------------------------------------------------------
  5. TypeError                                 Traceback (most recent call last)
  6. <ipython-input-28-51ef401848b6> in <module>()
  7.      26     print (sess.run(result))
  8.      27     x=np.array([2000.0])
  9. ---> 28     print (sess.run(result,feed_dict={v1: x}))
  10.      29
  11.      30

  12. TypeError: unhashable type: 'list'
复制代码

即使,将保存模型时的:
  1. v1=tf.constant([10000.0],name='v1')
  2. #v1 = tf.placeholder(tf.float32,shape=[1],name='v1')
复制代码

改为tf.placeholde也没用。该如何接近?
我知道答案 回答被采纳将会获得10 金币 + 5 金币 已有1人回答
本楼点评(0) 收起

精彩评论1

申克  TF荚荚  发表于 2018-6-23 14:52:19 | 显示全部楼层
问题自己解决了。
  1. import tensorflow as tf
  2. import numpy as np
  3. from numpy.random import RandomState
  4. from tensorflow.python.platform import gfile
  5. with tf.Graph().as_default():
  6.     output_graph_def = tf.GraphDef()
  7.     output_graph_path='Saved_model/combined_model.pb'
  8.     with open(output_graph_path, "rb") as f:
  9.         output_graph_def.ParseFromString(f.read())
  10.         _ = tf.import_graph_def(output_graph_def, name="")
  11.     with tf.Session() as sess:
  12.         """④输出所有可训练的变量名称,也就是神经网络的参数"""
  13.         trainable_variables=tf.trainable_variables()
  14.         variable_list_name = [c.name for c in tf.trainable_variables()]
  15.         variable_list = sess.run(variable_list_name)
  16.         for k,v in zip(variable_list_name,variable_list):
  17.             print("variable name:",k)
  18.             print("shape:",v.shape)
  19.             #print(v)
  20.         """④输出所有可训练的变量名称,也就是神经网络的参数"""
  21.         input_x = sess.graph.get_tensor_by_name("v1:0")
  22.         print (input_x)
复制代码

因为我以前写的只是获取一个值。
而现在修正的v1则是一个tensor。我们可以修正tensor的值。
所以,tensorflow 实战google深度学习框架中有重大bug。
不懂的联系我手机 18627711314 杰
本楼点评(0) 收起
您需要登录后才可以回帖 登录 | 加入社区

本版积分规则

主题

帖子

7

积分
快速回复 返回顶部 返回列表