发新帖

Dataset&TFrecord编写CNN代码,loss不下降,预测概率相当于随机...

[复制链接]
112 0

快来加入 TensorFlowers 大家庭!

您需要 登录 才可以下载或查看,没有帐号?加入社区

x
数据有4类,都是图片,搭建的网络结构在matlab中表现很好。但在tensorflow中,尝试了很多参数,梯度算法,结果都是loss很高,或者一会大一会小。预测概率永远是0.25(总共有4类),下面是代码,求解答,是我哪里写错了吗
  1. import os
  2. import tensorflow as tf
  3. import numpy as np
  4. import tensorflow.contrib.slim as slim

  5. def lenet_slim(x,is_tra):
  6.     with slim.arg_scope([slim.conv2d,slim.fully_connected],                        
  7.                         biases_initializer=tf.constant_initializer(0.1),
  8.                         weights_regularizer=slim.l2_regularizer(0.004),
  9.                         activation_fn=tf.nn.relu,
  10.                          weights_initializer=tf.truncated_normal_initializer(stddev=0.1)                       
  11.                        ):
  12.         x=slim.conv2d(x,30,[5,5],scope='conv1',reuse=tf.AUTO_REUSE)
  13.         x=slim.batch_norm(x,is_training=is_tra,scope='BN1',reuse=tf.AUTO_REUSE)
  14.         x=slim.max_pool2d(x,2,stride=2,scope='pool1')
  15. #         print(x)
  16.         x=slim.conv2d(x,50,[5,5],scope='conv2',reuse=tf.AUTO_REUSE)
  17.         x=slim.batch_norm(x,is_training=is_tra,scope='BN2',reuse=tf.AUTO_REUSE)
  18.         x=slim.max_pool2d(x,2,stride=2,scope='pool2')     
  19. #         x=slim.conv2d(x,64,[3,3],scope='conv3',reuse=tf.AUTO_REUSE)
  20. #         x=slim.batch_norm(x,is_training=is_tra,scope='BN3',reuse=tf.AUTO_REUSE)
  21. #         x=slim.max_pool2d(x,2,stride=2,scope='pool3')  
  22.         x=slim.flatten(x,scope="flatten")
  23. #         print(x)
  24.        # x=slim.fully_connected(x,512,scope='fc1')
  25.         x=slim.fully_connected(x,400,scope='fc2',reuse=tf.AUTO_REUSE)
  26. #         x=slim.dropout(x,0.5,is_training=is_tra,scope="drop1")
  27. #         x=slim.fully_connected(x,128,scope='fc3',reuse=tf.AUTO_REUSE)
  28. #         print(x)
  29.         logits=slim.fully_connected(x,4,scope='finally',reuse=tf.AUTO_REUSE)  
  30.     return logits
  31. tf.reset_default_graph()
  32. def parser(record):
  33.     features=tf.parse_single_example(record,
  34.                                     features={ 'image_raw':tf.FixedLenFeature([],tf.string),
  35.         'label':tf.FixedLenFeature([],tf.int64)})
  36.     decode_image=tf.decode_raw(features['image_raw'],tf.uint8)
  37. #     decode_image.set_shape([Image_Size,Image_Size,Num_Channels])
  38.     decode_image=tf.reshape(decode_image,[256,256,3])
  39.     decode_image=tf.cast(decode_image,tf.float32)
  40.     label=features['label']
  41.     label=tf.cast(label,tf.float32)
  42.     return decode_image,label
  43. # train_file_path='D:\\Data\\Data_Picture_useML\\tensorflow_format\\UC_4Class_train.tfrecords'
  44. train_file_path=tf.placeholder(tf.string)
  45. dataset=tf.data.TFRecordDataset(train_file_path)
  46. dataset=dataset.map(parser)
  47. dataset=dataset.shuffle(100).batch(32)
  48. dataset=dataset.repeat(3)
  49. iterator=dataset.make_initializable_iterator()
  50. f1,f2=iterator.get_next()
  51. logits=lenet_slim(f1,True)
  52. global_step=tf.Variable(0,trainable=False)
  53. # prediction=tf.argmax(logits,1)
  54. # prediction=tf.cast(prediction,tf.int32)
  55. f2=tf.cast(f2,tf.int32)
  56. # print([prediction,f2])
  57. # cross_entropy=tf.nn.softmax_cross_entropy_with_logits_v2(logits=prediction,labels=f2)
  58. cross_entropy=tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,labels=f2)
  59. # cross_entropy=tf.nn.softmax_cross_entropy_with_logits(logits=prediction,labels=f2)
  60. cross_mean=tf.reduce_mean(cross_entropy)
  61. Learn_Rate=0.001
  62. Learn_Decay=0.98
  63. learn_rate=tf.train.exponential_decay(Learn_Rate,global_step,4,Learn_Decay,staircase=True)
  64. train_step=tf.train.MomentumOptimizer(learn_rate,momentum=0.7).minimize(cross_mean)
  65. # train_step=tf.train.GradientDescentOptimizer(0.01).minimize(cross_mean,global_step=global_step)

  66. #测试数据集
  67. test_file_path=tf.placeholder(tf.string)
  68. test_dataset=tf.data.TFRecordDataset(test_file_path)
  69. test_dataset=test_dataset.map(parser)
  70. test_dataset=test_dataset.shuffle(100).batch(32)

  71. test_iter=test_dataset.make_initializable_iterator()
  72. test_image_batch,test_label_batch=test_iter.get_next()

  73. # #测试成功率
  74. test_logit=lenet_slim(test_image_batch,False)
  75. prediction=tf.argmax(test_logit,axis=-1,output_type=tf.int32)

  76. with tf.Session() as sess:
  77. #     for i in range(10):
  78. #         f3,f4=sess.run([f1,f2])
  79. #         print(f3)
  80.     #一定要加tf.local_variables_initializer(),不然
  81.     sess.run((tf.global_variables_initializer(),tf.local_variables_initializer()))
  82.     sess.run(iterator.initializer,feed_dict={train_file_path:'D:\\Data\\Data_Picture_useML\\tensorflow_format\\UC_4Class_train.tfrecords'})
  83.     while True:
  84.         try:
  85.             
  86. #             f3,f4=sess.run([f1,f2])
  87.             f5,_=sess.run([cross_mean,train_step])
  88. #             f7=sess.run(cross_entropy)
  89. #             print(f7)
  90.             print(f5)
  91. #             print(f4.shape)
  92.         except tf.errors.OutOfRangeError:
  93.             break
  94.     print('train finish')
  95.     sess.run(test_iter.initializer,feed_dict={test_file_path:'D:\\Data\\Data_Picture_useML\\tensorflow_format\\UC_4Class_test.tfrecords'})
  96.     test_results=[]
  97.     test_labels=[]
  98.     while True:
  99.       
  100.         try:
  101. #             t3,t4=sess.run([test_logit,prediction])
  102. #             print([t3[0:2],t4[0:2]])
  103.             log,pred,label=sess.run([test_logit,prediction,test_label_batch])
  104.             print([log[0:2],pred[0:2]])
  105.             test_results.extend(pred)
  106.             test_labels.extend(label)
  107. #             t1,t2=sess.run([test_image_batch,test_label_batch])
  108. #             print(t2.shape)
  109. #             print("/n")
  110.         except tf.errors.OutOfRangeError:
  111.             break
  112.     # 计算准确率
  113.     print('11',test_results,test_labels)
  114.     correct = [float(y == y_) for (y, y_) in zip (test_results, test_labels)]
  115.     print("准确率:",correct)
  116.     accuracy = sum(correct) / len(correct)
  117.     print("Test accuracy is:", accuracy)
复制代码
我知道答案 回答被采纳将会获得10 金币 + 20 金币 已有0人回答
本楼点评(0) 收起
您需要登录后才可以回帖 登录 | 加入社区

本版积分规则

快速回复 返回顶部 返回列表