发新帖

【已解决】继续RF模型的学习......问题一个接一个,求救

[复制链接]
877 6

快来加入 TensorFlowers 大家庭!

您需要 登录 才可以下载或查看,没有帐号?加入社区

x
本帖最后由 M丶Sulayman 于 2018-5-10 09:54 编辑

QQ图片20180507093403.png
我脑子有问题......三个标签,我填了个2    下面代码的44行    num_classes = 2 应该是 num_classes = 3


虽然跑出结果了,但是又出现了新问题(下图)

QQ图片20180507100748.png
QQ图片20180507100759.png 自己解决了......代码最后一行Y传入的值应该是test_lables,然后就好了
QQ图片20180507101635.png
下面附上代码,数据

  1. #!/usr/bin/env python
  2. # -*- coding: UTF-8 -*-
  3. from __future__ import print_function

  4. import numpy as np
  5. from sklearn.model_selection import train_test_split
  6. import tensorflow as tf
  7. from tensorflow.contrib.tensor_forest.python import tensor_forest
  8. from tensorflow.python.ops import resources


  9. def load_data(file):
  10.     features = []
  11.     lables = []
  12.     file = open(file, 'r')
  13.     lines = file.readlines()
  14.     for line in lines:
  15.         items = line.strip().split(',')

  16.         list_to_string = ','.join(items)
  17.         for ch in ['Iris-setosa']:
  18.             if ch in list_to_string:
  19.                 list_to_string = list_to_string.replace(ch, '0')
  20.         for ch in ['Iris-versicolor']:
  21.             if ch in list_to_string:
  22.                 list_to_string = list_to_string.replace(ch, '1')
  23.         for ch in ['Iris-virginica']:
  24.             if ch in list_to_string:
  25.                 list_to_string = list_to_string.replace(ch, '2')

  26.         items = list_to_string.strip().split(',')

  27.         features.append([float(items[i]) for i in range(len(items) - 1)])
  28.         lables.append(float(items[-1]))

  29.     return np.array(features), np.array(lables)

  30. if __name__ == '__main__':
  31.     data, lables = load_data('iris.csv')
  32.     train_data, test_data, train_lables, test_lables = train_test_split(data, lables, test_size=0.3, random_state=33)

  33. # Parameters
  34.     num_steps = 100
  35.     num_classes = 2
  36.     num_features = 4
  37.     num_trees = 10
  38.     max_nodes = 10

  39. # Input and Target data
  40.     X = tf.placeholder(tf.float32, shape=[None, num_features])
  41. # For random forest, labels must be integers (the class id)
  42.     Y = tf.placeholder(tf.int32, shape=[None])

  43. # Random Forest Parameters
  44.     hparams = tensor_forest.ForestHParams(num_classes=num_classes,
  45.                                           num_features=num_features,
  46.                                           num_trees=num_trees,
  47.                                           max_nodes=max_nodes).fill()

  48. # Build the Random Forest
  49.     forest_graph = tensor_forest.RandomForestGraphs(hparams)
  50. # Get training graph and loss
  51.     train_op = forest_graph.training_graph(X, Y)
  52.     loss_op = forest_graph.training_loss(X, Y)

  53. # Measure the accuracy
  54.     infer_op = forest_graph.inference_graph(X)
  55.     correct_prediction = tf.equal(tf.argmax(infer_op, 1), tf.cast(Y, tf.int64))
  56.     accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

  57. # Initialize the variables (i.e. assign their default value) and forest resources
  58.     init_vars = tf.group(tf.global_variables_initializer(),
  59.         resources.initialize_resources(resources.shared_resources()))

  60. # Start TensorFlow session
  61.     sess = tf.Session()

  62. # Run the initializer
  63.     sess.run(init_vars)

  64. # Training
  65.     for i in range(1, num_steps + 1):

  66.         _, l = sess.run([train_op, loss_op], feed_dict={X: train_data, Y: train_lables})
  67.         if i % 50 == 0 or i == 1:
  68.             acc = sess.run(accuracy_op, feed_dict={X: train_data, Y: train_lables})
  69.             print('Step %i, Loss: %f, Acc: %f' % (i, l, acc))

  70. # Test Model

  71.     print("Test Accuracy:", sess.run(accuracy_op, feed_dict={X: test_data, Y: train_lables}))

复制代码
借鉴的是这篇    https://github.com/aymericdamien ... ls/random_forest.py




我知道答案 回答被采纳将会获得10 金币 + 50 金币 已有6人回答

iris.rar

878 Bytes, 下载次数: 126

本楼点评(0) 收起

精彩评论6

M丶Sulayman  TF豆豆  发表于 2018-5-7 09:38:51 | 显示全部楼层
2018-05-07 09:32:56.962720: F C:\tf_jenkins\home\workspace\rel-win\M\windows\PY\35\tensorflow\contrib\tensor_forest\kernels\count_extremely_random_stats_op.cc:400]
Check failed: column < num_classes_ (3 vs. 3)
这个错误啥意思......
本楼点评(0) 收起
M丶Sulayman  TF豆豆  发表于 2018-5-7 09:41:19 | 显示全部楼层
数据预处理,划分数据集都没问题,然后我就模仿那篇代码把数据加载进去跑,怎么把我的Python给跑死了......
本楼点评(0) 收起
M丶Sulayman  TF豆豆  发表于 2018-5-7 10:09:50 | 显示全部楼层
tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes: [45] vs. [105]
这又是个什么鬼......
本楼点评(0) 收起
M丶Sulayman  TF豆豆  发表于 2018-5-7 10:20:44 | 显示全部楼层
总结一点,粘贴复制是下下策,还是要自己理解原理,理解代码,理解数据流
本楼点评(1) 收起
舟3332  TF芽芽  发表于 2018-5-9 23:02:10 | 显示全部楼层
已经解决了的话就把文章标记成已解决吧~
本楼点评(0) 收起
M丶Sulayman  TF豆豆  发表于 2018-5-10 09:54:13 | 显示全部楼层
舟3332 发表于 2018-5-9 23:02
已经解决了的话就把文章标记成已解决吧~

好的~
本楼点评(0) 收起
您需要登录后才可以回帖 登录 | 加入社区

本版积分规则

快速回复 返回顶部 返回列表