发新帖

代码出错,麻烦大神们帮我看看错误怎样修正

[复制链接]
494 5

快来加入 TensorFlowers 大家庭!

您需要 登录 才可以下载或查看,没有帐号?加入社区

x
  1. """ Random Forest.
  2. Implement Random Forest algorithm with TensorFlow, and apply it to classify
  3. handwritten digit images. This example is using the MNIST database of
  4. handwritten digits as training samples (http://yann.lecun.com/exdb/mnist/).
  5. Author: Aymeric Damien
  6. Project: https://github.com/aymericdamien/TensorFlow-Examples/
  7. """

  8. from __future__ import print_function

  9. import tensorflow as tf
  10. from tensorflow.contrib.tensor_forest.python import tensor_forest
  11. from tensorflow.python.ops import resources

  12. # Ignore all GPUs, tf random forest does not benefit from it.
  13. import os
  14. os.environ["CUDA_VISIBLE_DEVICES"] = ""

  15. # Import MNIST data
  16. from tensorflow.examples.tutorials.mnist import input_data
  17. mnist = input_data.read_data_sets("/tmp/data/", one_hot=False)

  18. # Parameters
  19. num_steps = 500 # Total steps to train
  20. batch_size = 1024 # The number of samples per batch
  21. num_classes = 10 # The 10 digits
  22. num_features = 784 # Each image is 28x28 pixels
  23. num_trees = 10
  24. max_nodes = 1000

  25. # Input and Target data
  26. X = tf.placeholder(tf.float32, shape=[None, num_features])
  27. # For random forest, labels must be integers (the class id)
  28. Y = tf.placeholder(tf.int32, shape=[None])

  29. # Random Forest Parameters
  30. hparams = tensor_forest.ForestHParams(num_classes=num_classes,
  31.                                       num_features=num_features,
  32.                                       num_trees=num_trees,
  33.                                       max_nodes=max_nodes).fill()

  34. # Build the Random Forest
  35. forest_graph = tensor_forest.RandomForestGraphs(hparams)
  36. # Get training graph and loss
  37. train_op = forest_graph.training_graph(X, Y)
  38. loss_op = forest_graph.training_loss(X, Y)

  39. # Measure the accuracy
  40. infer_op, _, _ = forest_graph.inference_graph(X)
  41. correct_prediction = tf.equal(tf.argmax(infer_op, 1), tf.cast(Y, tf.int64))
  42. accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

  43. # Initialize the variables (i.e. assign their default value) and forest resources
  44. init_vars = tf.group(tf.global_variables_initializer(),
  45.     resources.initialize_resources(resources.shared_resources()))

  46. # Start TensorFlow session
  47. sess = tf.Session()

  48. # Run the initializer
  49. sess.run(init_vars)

  50. # Training
  51. for i in range(1, num_steps + 1):
  52.     # Prepare Data
  53.     # Get the next batch of MNIST data (only images are needed, not labels)
  54.     batch_x, batch_y = mnist.train.next_batch(batch_size)
  55.     _, l = sess.run([train_op, loss_op], feed_dict={X: batch_x, Y: batch_y})
  56.     if i % 50 == 0 or i == 1:
  57.         acc = sess.run(accuracy_op, feed_dict={X: batch_x, Y: batch_y})
  58.         print('Step %i, Loss: %f, Acc: %f' % (i, l, acc))

  59. # Test Model
  60. test_x, test_y = mnist.test.images, mnist.test.labels
  61. print("Test Accuracy:", sess.run(accuracy_op, feed_dict={X: test_x, Y: test_y}))
复制代码


错误信息如下:

QQ图片20180430200046.png
舟3332已获得悬赏 10 金币+50 金币

最佳答案

https://github.com/tensorflow/tensorflow/blob/e7f158858479400f17a1b6351e9827e3aa83e7ff/tensorflow/contrib/tensor_forest/python/tensor_forest.py#L481 哎呀。这个我就不太理解了。你看我上边贴的那个链接 ...
本楼点评(0) 收起

精彩评论5

M丶Sulayman  TF豆豆  发表于 2018-4-30 20:14:01 | 显示全部楼层
  1. infer_op, _, _ = forest_graph.inference_graph(X)
复制代码

改为:
  1. infer_op = forest_graph.inference_graph(X)
复制代码


=,=好了,正常了......
本楼点评(0) 收起
M丶Sulayman  TF豆豆  发表于 2018-4-30 20:15:16 | 显示全部楼层
但是49行代码,为什么要写成那样啊?我直接改成一楼这种样子,是不是丢失了什么信息?
本楼点评(0) 收起
舟3332  TF芽芽  发表于 2018-5-1 10:28:06 | 显示全部楼层

https://github.com/tensorflow/tensorflow/blob/e7f158858479400f17a1b6351e9827e3aa83e7ff/tensorflow/contrib/tensor_forest/python/tensor_forest.py#L481

哎呀。这个我就不太理解了。你看我上边贴的那个链接。貌似这个函数的返回值是一个三元组,所以写成 v, _, _ = graph.inference_graph() 的意思是可以拿到三元组的第一个元素。
至于为什么你把代码改动后就能跑,我就不是特别理解了。你可以用 print type(v) 看下你拿到的变量是什么类型~

至于 python 的多返回值函数语法你可以看看这里: https://www.geeksforgeeks.org/g- ... n-values-in-python/
本楼点评(0) 收起
M丶Sulayman  TF豆豆  发表于 2018-5-1 10:47:05 | 显示全部楼层
舟3332 发表于 2018-5-1 10:28
https://github.com/tensorflow/tensorflow/blob/e7f158858479400f17a1b6351e9827e3aa83e7ff/tensorflow/ ...

好的,我发现我好多函数都没见过......我还是继续学习吧
本楼点评(0) 收起
neverchange  TF豆豆  发表于 2018-7-3 21:41:10 | 显示全部楼层
infer_op, _, _ = forest_graph.inference_graph(X)
这种写法我都没见过,而且没有意义
本楼点评(0) 收起
您需要登录后才可以回帖 登录 | 加入社区

本版积分规则

快速回复 返回顶部 返回列表