快来加入 TensorFlowers 大家庭!
您需要 登录 才可以下载或查看,没有帐号?加入社区
x
数据有4类,都是图片,搭建的网络结构在matlab中表现很好。但在tensorflow中,尝试了很多参数,梯度算法,结果都是loss很高,或者一会大一会小。预测概率永远是0.25(总共有4类),下面是代码,求解答,是我哪里写错了吗
- import os
- import tensorflow as tf
- import numpy as np
- import tensorflow.contrib.slim as slim
- def lenet_slim(x,is_tra):
- with slim.arg_scope([slim.conv2d,slim.fully_connected],
- biases_initializer=tf.constant_initializer(0.1),
- weights_regularizer=slim.l2_regularizer(0.004),
- activation_fn=tf.nn.relu,
- weights_initializer=tf.truncated_normal_initializer(stddev=0.1)
- ):
- x=slim.conv2d(x,30,[5,5],scope='conv1',reuse=tf.AUTO_REUSE)
- x=slim.batch_norm(x,is_training=is_tra,scope='BN1',reuse=tf.AUTO_REUSE)
- x=slim.max_pool2d(x,2,stride=2,scope='pool1')
- # print(x)
- x=slim.conv2d(x,50,[5,5],scope='conv2',reuse=tf.AUTO_REUSE)
- x=slim.batch_norm(x,is_training=is_tra,scope='BN2',reuse=tf.AUTO_REUSE)
- x=slim.max_pool2d(x,2,stride=2,scope='pool2')
- # x=slim.conv2d(x,64,[3,3],scope='conv3',reuse=tf.AUTO_REUSE)
- # x=slim.batch_norm(x,is_training=is_tra,scope='BN3',reuse=tf.AUTO_REUSE)
- # x=slim.max_pool2d(x,2,stride=2,scope='pool3')
- x=slim.flatten(x,scope="flatten")
- # print(x)
- # x=slim.fully_connected(x,512,scope='fc1')
- x=slim.fully_connected(x,400,scope='fc2',reuse=tf.AUTO_REUSE)
- # x=slim.dropout(x,0.5,is_training=is_tra,scope="drop1")
- # x=slim.fully_connected(x,128,scope='fc3',reuse=tf.AUTO_REUSE)
- # print(x)
- logits=slim.fully_connected(x,4,scope='finally',reuse=tf.AUTO_REUSE)
- return logits
- tf.reset_default_graph()
- def parser(record):
- features=tf.parse_single_example(record,
- features={ 'image_raw':tf.FixedLenFeature([],tf.string),
- 'label':tf.FixedLenFeature([],tf.int64)})
- decode_image=tf.decode_raw(features['image_raw'],tf.uint8)
- # decode_image.set_shape([Image_Size,Image_Size,Num_Channels])
- decode_image=tf.reshape(decode_image,[256,256,3])
- decode_image=tf.cast(decode_image,tf.float32)
- label=features['label']
- label=tf.cast(label,tf.float32)
- return decode_image,label
- # train_file_path='D:\\Data\\Data_Picture_useML\\tensorflow_format\\UC_4Class_train.tfrecords'
- train_file_path=tf.placeholder(tf.string)
- dataset=tf.data.TFRecordDataset(train_file_path)
- dataset=dataset.map(parser)
- dataset=dataset.shuffle(100).batch(32)
- dataset=dataset.repeat(3)
- iterator=dataset.make_initializable_iterator()
- f1,f2=iterator.get_next()
- logits=lenet_slim(f1,True)
- global_step=tf.Variable(0,trainable=False)
- # prediction=tf.argmax(logits,1)
- # prediction=tf.cast(prediction,tf.int32)
- f2=tf.cast(f2,tf.int32)
- # print([prediction,f2])
- # cross_entropy=tf.nn.softmax_cross_entropy_with_logits_v2(logits=prediction,labels=f2)
- cross_entropy=tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,labels=f2)
- # cross_entropy=tf.nn.softmax_cross_entropy_with_logits(logits=prediction,labels=f2)
- cross_mean=tf.reduce_mean(cross_entropy)
- Learn_Rate=0.001
- Learn_Decay=0.98
- learn_rate=tf.train.exponential_decay(Learn_Rate,global_step,4,Learn_Decay,staircase=True)
- train_step=tf.train.MomentumOptimizer(learn_rate,momentum=0.7).minimize(cross_mean)
- # train_step=tf.train.GradientDescentOptimizer(0.01).minimize(cross_mean,global_step=global_step)
- #测试数据集
- test_file_path=tf.placeholder(tf.string)
- test_dataset=tf.data.TFRecordDataset(test_file_path)
- test_dataset=test_dataset.map(parser)
- test_dataset=test_dataset.shuffle(100).batch(32)
- test_iter=test_dataset.make_initializable_iterator()
- test_image_batch,test_label_batch=test_iter.get_next()
- # #测试成功率
- test_logit=lenet_slim(test_image_batch,False)
- prediction=tf.argmax(test_logit,axis=-1,output_type=tf.int32)
- with tf.Session() as sess:
- # for i in range(10):
- # f3,f4=sess.run([f1,f2])
- # print(f3)
- #一定要加tf.local_variables_initializer(),不然
- sess.run((tf.global_variables_initializer(),tf.local_variables_initializer()))
- sess.run(iterator.initializer,feed_dict={train_file_path:'D:\\Data\\Data_Picture_useML\\tensorflow_format\\UC_4Class_train.tfrecords'})
- while True:
- try:
-
- # f3,f4=sess.run([f1,f2])
- f5,_=sess.run([cross_mean,train_step])
- # f7=sess.run(cross_entropy)
- # print(f7)
- print(f5)
- # print(f4.shape)
- except tf.errors.OutOfRangeError:
- break
- print('train finish')
- sess.run(test_iter.initializer,feed_dict={test_file_path:'D:\\Data\\Data_Picture_useML\\tensorflow_format\\UC_4Class_test.tfrecords'})
- test_results=[]
- test_labels=[]
- while True:
-
- try:
- # t3,t4=sess.run([test_logit,prediction])
- # print([t3[0:2],t4[0:2]])
- log,pred,label=sess.run([test_logit,prediction,test_label_batch])
- print([log[0:2],pred[0:2]])
- test_results.extend(pred)
- test_labels.extend(label)
- # t1,t2=sess.run([test_image_batch,test_label_batch])
- # print(t2.shape)
- # print("/n")
- except tf.errors.OutOfRangeError:
- break
- # 计算准确率
- print('11',test_results,test_labels)
- correct = [float(y == y_) for (y, y_) in zip (test_results, test_labels)]
- print("准确率:",correct)
- accuracy = sum(correct) / len(correct)
- print("Test accuracy is:", accuracy)
复制代码
我知道答案
回答被采纳将会获得 10 金币 + 20 金币 已有0人回答
|