发新帖

模型训练准确率100%...是数据太少了么

[复制链接]
1423 23

快来加入 TensorFlowers 大家庭!

您需要 登录 才可以下载或查看,没有帐号?加入社区

x
本帖最后由 M丶Sulayman 于 2018-5-7 17:36 编辑
  1. #!/usr/bin/env python
  2. # -*- coding: UTF-8 -*-
  3. from __future__ import print_function

  4. import numpy as np
  5. from sklearn.model_selection import train_test_split
  6. import tensorflow as tf
  7. from tensorflow.contrib.tensor_forest.python import tensor_forest
  8. from tensorflow.python.ops import resources


  9. def load_data(file):
  10.     features = []
  11.     lables = []
  12.     file = open(file, 'r')
  13.     lines = file.readlines()
  14.     for line in lines:
  15.         items = line.strip().split(',')

  16.         list_to_string = ','.join(items)
  17.         for ch in ['Iris-setosa']:
  18.             if ch in list_to_string:
  19.                 list_to_string = list_to_string.replace(ch, '0')
  20.         for ch in ['Iris-versicolor']:
  21.             if ch in list_to_string:
  22.                 list_to_string = list_to_string.replace(ch, '1')
  23.         for ch in ['Iris-virginica']:
  24.             if ch in list_to_string:
  25.                 list_to_string = list_to_string.replace(ch, '2')

  26.         items = list_to_string.strip().split(',')

  27.         features.append([float(items[i]) for i in range(len(items) - 1)])
  28.         lables.append(float(items[-1]))

  29.     return np.array(features), np.array(lables)

  30. if __name__ == '__main__':
  31.     data, lables = load_data('iris.csv')
  32.     train_data, test_data, train_lables, test_lables = train_test_split(data, lables, test_size=0.3, random_state=33)

  33. # Parameters
  34.     num_steps = 1000
  35.     num_classes = 3
  36.     num_features = 4
  37.     num_trees = 10
  38.     max_nodes = 100

  39. # Input and Target data
  40.     X = tf.placeholder(tf.float32, shape=[None, num_features])
  41. # For random forest, labels must be integers (the class id)
  42.     Y = tf.placeholder(tf.int32, shape=[None])

  43. # Random Forest Parameters
  44.     hparams = tensor_forest.ForestHParams(num_classes=num_classes,
  45.                                           num_features=num_features,
  46.                                           num_trees=num_trees,
  47.                                           max_nodes=max_nodes).fill()

  48. # Build the Random Forest
  49.     forest_graph = tensor_forest.RandomForestGraphs(hparams)
  50. # Get training graph and loss
  51.     train_op = forest_graph.training_graph(X, Y)
  52.     loss_op = forest_graph.training_loss(X, Y)

  53. # Measure the accuracy
  54.     infer_op = forest_graph.inference_graph(X)
  55.     correct_prediction = tf.equal(tf.argmax(infer_op, 1), tf.cast(Y, tf.int64))
  56.     accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

  57. # Initialize the variables (i.e. assign their default value) and forest resources
  58.     init_vars = tf.group(tf.global_variables_initializer(),
  59.         resources.initialize_resources(resources.shared_resources()))

  60. # Start TensorFlow session
  61.     sess = tf.Session()

  62. # Run the initializer
  63.     sess.run(init_vars)

  64. # Training
  65.     for i in range(1, num_steps + 1):

  66.         _, l = sess.run([train_op, loss_op], feed_dict={X: train_data, Y: train_lables})
  67.         if i % 50 == 0 or i == 1:
  68.             acc = sess.run(accuracy_op, feed_dict={X: train_data, Y: train_lables})
  69.             print('Step %i, Loss: %f, Acc: %f' % (i, l, acc))

  70. # Test Model

  71.     print("Test Accuracy:", sess.run(accuracy_op, feed_dict={X: test_data, Y: test_lables}))
复制代码
QQ图片20180507104935.png iris.rar (878 Bytes, 下载次数: 91)
本楼点评(0) 收起

精彩评论23

M丶Sulayman  TF豆豆  发表于 2018-5-7 10:55:36 | 显示全部楼层
  1. train_data, test_data, train_lables, test_lables = train_test_split(data, lables, test_size=0.3, random_state=33)
复制代码

是数据集分割的原因么?
本楼点评(0) 收起
M丶Sulayman  TF豆豆  发表于 2018-5-7 11:00:57 | 显示全部楼层
是最大节点数的原因     max_nodes    设置过大,分的太细,造成了过拟合
本楼点评(0) 收起
slobber  TF荚荚  发表于 2018-5-7 11:43:19 | 显示全部楼层
iris 不用考虑数据量吧
本楼点评(0) 收起
M丶Sulayman  TF豆豆  发表于 2018-5-7 11:44:18 | 显示全部楼层
slobber 发表于 2018-5-7 11:43
iris 不用考虑数据量吧

是的,然后我就自我否定了,原来是分支节点太多造成了过拟合。
本楼点评(0) 收起
Oreo.  TF豆豆  发表于 2018-5-7 17:20:06 | 显示全部楼层
你好,请问你用的TensorFlow是多少版本的?
我用TensorFlow-gpu 1.4.0版本的,提示有错误
本楼点评(1) 收起
  • Oreo.我在stackoverflow查我的错误,好像都有涉及到关于TensorFlow版本的问题。

    tensorflow.python.framework.errors_impl.NotFoundError: Op type not registered 'FertileStatsResourceHandleOp' in binary running on PC-20160730LTFZ. Make sure the Op and Kernel are registered in the binary running in this process.
    2018-5-7 17:24 回复
M丶Sulayman  TF豆豆  发表于 2018-5-7 17:29:51 | 显示全部楼层
Oreo. 发表于 2018-5-7 17:20
你好,请问你用的TensorFlow是多少版本的?
我用TensorFlow-gpu 1.4.0版本的,提示有错误 ...

QQ图片20180507172926.png 额,我的可以跑
本楼点评(0) 收起
Oreo.  TF豆豆  发表于 2018-5-7 19:25:49 | 显示全部楼层

我用的是TensorFlow-gpu-1.4.0跑有报错然后我uninstall后install了最新的TensorFlow 1.8.0还是会报一样的错
然后我装了1.3.0 居然成功了!!!
本楼点评(1) 收起
  • victor6510用1.3.0的話請問原本的  tf.app.run() 是要自己从 1.8 带过来吗? 貌似这个是1.4以后才支持的。
    2018-7-4 18:18 回复
M丶Sulayman  TF豆豆  发表于 2018-5-7 19:27:07 | 显示全部楼层
Oreo. 发表于 2018-5-7 19:25
我用的是TensorFlow-gpu-1.4.0跑有报错然后我uninstall后install了最新的TensorFlow 1.8.0还是会报一样的 ...

......看来是函数的问题
本楼点评(0) 收起
Oreo.  TF豆豆  发表于 2018-5-7 19:28:02 | 显示全部楼层
M丶Sulayman 发表于 2018-5-7 19:27
......看来是函数的问题

可能是用的函数比较老吧,新版本已经不能用了
本楼点评(0) 收起
Oreo.  TF豆豆  发表于 2018-5-7 20:35:08 | 显示全部楼层
M丶Sulayman 发表于 2018-5-7 19:27
......看来是函数的问题

已经跑了一下数据了,随机森林好容易就能跑到90以上,真是幸福
而且代码也很简洁,比我之前跑神经网络舒服多了呀
本楼点评(0) 收起
Oreo.  TF豆豆  发表于 2018-5-7 20:35:40 | 显示全部楼层
M丶Sulayman 发表于 2018-5-7 19:27
......看来是函数的问题

谢谢你在论坛里面提问了那么多,我也学习好多哈哈哈
本楼点评(1) 收起
M丶Sulayman  TF豆豆  发表于 2018-5-7 20:55:37 | 显示全部楼层
Oreo. 发表于 2018-5-7 20:35
谢谢你在论坛里面提问了那么多,我也学习好多哈哈哈

共同进步~
本楼点评(0) 收起
Oreo.  TF豆豆  发表于 2018-5-9 16:28:01 | 显示全部楼层
M丶Sulayman 发表于 2018-5-7 19:27
......看来是函数的问题

我去网上查了一下,在stackoverflow上的问题也没有明确的回答,有一位是选择放弃tensorflow用sklearn实现随机森林了直接。。。
在github上也有人提问,但是tensorflow官方账号也没有给出明确的答复,函数用的也不老,至少现在调用tensor_forest包还是一样路径,反正我是用1.3.0不会报错,1.4.0和1.8.0都会报错,而且在网上好像报错的基本都是从1.4.0版本的,估计是bug一直没有解决吧。。。
本楼点评(0) 收起
tking  TF荚荚  发表于 2018-5-9 20:48:59 | 显示全部楼层
学习了google官方的视频,其中在遇到100%正确率的时候极大可能是过拟合,需要检查自己的代码
本楼点评(0) 收起
M丶Sulayman  TF豆豆  发表于 2018-5-10 09:55:20 | 显示全部楼层
tking 发表于 2018-5-9 20:48
学习了google官方的视频,其中在遇到100%正确率的时候极大可能是过拟合,需要检查自己的代码 ...

谢谢~是过拟合了
本楼点评(0) 收起
M丶Sulayman  TF豆豆  发表于 2018-5-10 09:56:13 | 显示全部楼层
Oreo. 发表于 2018-5-9 16:28
我去网上查了一下,在stackoverflow上的问题也没有明确的回答,有一位是选择放弃tensorflow用sklearn实现 ...

嗯嗯,毕竟新出的东西,完善还需要一定的时间~
本楼点评(0) 收起
重庆不热  TF荚荚  发表于 2018-7-3 16:43:23
学习一下
本楼点评(0) 收起

fantasycheng  TF荚荚  发表于 2018-7-3 16:47:37 | 显示全部楼层
iris数据集本来就是给你拿来练手的,100%的准确度没什么好奇怪的
本楼点评(0) 收起
neverchange  TF豆豆  发表于 2018-7-4 12:21:43 | 显示全部楼层
不是的,不应该存在100%的结果,反向预测除非就一个数据。
本楼点评(0) 收起
Lemon  TF荚荚  发表于 2018-7-4 13:27:34 | 显示全部楼层
过拟合了吧。训练集准确率有了,试试你的测试集。
本楼点评(0) 收起
您需要登录后才可以回帖 登录 | 加入社区

本版积分规则

快速回复 返回顶部 返回列表