• 如果您觉得本站非常有看点,那么赶紧使用Ctrl+D 收藏吧

如何保存tensorflow模型(省略标签张量),没有定义

python 来源:Ayush Pandey 5次浏览

变量定义

我tensorflow模型如下:如何保存tensorflow模型(省略标签张量),没有定义

<code class="prettyprint-override">X = tf.placeholder(tf.float32, [None,training_set.shape[1]],name = 'X') 
Y = tf.placeholder(tf.float32,[None,training_labels.shape[1]], name = 'Y') 
A1 = tf.contrib.layers.fully_connected(X, num_outputs = 50, activation_fn = tf.nn.relu) 
A1 = tf.nn.dropout(A1, 0.8) 
A2 = tf.contrib.layers.fully_connected(A1, num_outputs = 2, activation_fn = None) 
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = A2, labels = Y))  
global_step = tf.Variable(0, trainable=False) 
start_learning_rate = 0.001 
learning_rate = tf.train.exponential_decay(start_learning_rate, global_step, 200, 0.1, True) 
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost) 
</code>

现在我要救这个模型遗漏张YY是标签张量对于培训,X是实际的输入)。同时在提及freeze_graph.py时提及输出节点时,我应该提及"A2"还是以其他名称保存?

===========解决方案如下:

虽然您尚未手动定义变量,但上面的代码片段实际上包含15个可保存的变量。你可以使用这个内部tensorflow功能看到他们:

<code class="prettyprint-override">from tensorflow.<a href="http://www.fixbbs.com/p/tag/python" title="查看更多关于python的文章" target="_blank">python</a>.ops.variables import _all_saveable_objects 
for obj in _all_saveable_objects(): 
    print(obj) 
</code>

对于上面的代码,它产生以下列表:

<code class="prettyprint-override"><tf.Variable 'fully_connected/weights:0' shape=(100, 50) dtype=float32_ref> 
<tf.Variable 'fully_connected/biases:0' shape=(50,) dtype=float32_ref> 
<tf.Variable 'fully_connected_1/weights:0' shape=(50, 2) dtype=float32_ref> 
<tf.Variable 'fully_connected_1/biases:0' shape=(2,) dtype=float32_ref> 
<tf.Variable 'Variable:0' shape=() dtype=int32_ref> 
<tf.Variable 'beta1_power:0' shape=() dtype=float32_ref> 
<tf.Variable 'beta2_power:0' shape=() dtype=float32_ref> 
<tf.Variable 'fully_connected/weights/Adam:0' shape=(100, 50) dtype=float32_ref> 
<tf.Variable 'fully_connected/weights/Adam_1:0' shape=(100, 50) dtype=float32_ref> 
<tf.Variable 'fully_connected/biases/Adam:0' shape=(50,) dtype=float32_ref> 
<tf.Variable 'fully_connected/biases/Adam_1:0' shape=(50,) dtype=float32_ref> 
<tf.Variable 'fully_connected_1/weights/Adam:0' shape=(50, 2) dtype=float32_ref> 
<tf.Variable 'fully_connected_1/weights/Adam_1:0' shape=(50, 2) dtype=float32_ref> 
<tf.Variable 'fully_connected_1/biases/Adam:0' shape=(2,) dtype=float32_ref> 
<tf.Variable 'fully_connected_1/biases/Adam_1:0' shape=(2,) dtype=float32_ref> 
</code>

有来自fully_connected层变量和几个从亚当来优化器(请参阅this question)。请注意,此列表中没有XY占位符,因此不需要排除它们。当然,这些张量存在于元图中,但它们没有任何价值,因此无法保存。

_all_saveable_objects()列表是tensorflow保存程序默认保存的内容,如果未明确提供变量的话。因此,回答你的主要的问题很简单:

<code class="prettyprint-override">saver = tf.train.Saver() # all saveable objects! 
with tf.Session() as sess: 
    tf.global_variables_initializer().run() 
    saver.save(sess, "...") 
</code>

有没有办法提供了tf.contrib.layers.fully_connected函数的名称(如一个结果,它保存fully_connected_1/...),但我们鼓励你切换到tf.layers.dense,这有一个name的论点。无论如何,看看为什么这是一个好主意,看看this和this discussion。


版权声明:本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系管理员进行删除。
喜欢 (0)