发新帖

使用自己的数据,如何理解conv2d参数里的batch?

[复制链接]
562 6

快来加入 TensorFlowers 大家庭!

您需要 登录 才可以下载或查看,没有帐号?加入社区

x
新手入坑TensorFlow,复现了几个demo之后,开始自己写网络。
之前对照教程写过使用MNIST数据集的CNN,可在我自己写的时候,问题出现了。
我是想用自己的图片来做训练,为此写了个读取一个文件夹里所有图片且保存为其中元素类型为ndarray的list返回。
问题出现在conv2d里第一个参数 input上, 这个参数除了图片分辨率和通道数外还有一个batch。
我已知batch是批训练的参数,代表着每次训练使用多少个数据。

之前按照教程用MNIST,数据都是封装好的不需要操心,但我使用自己的数据时问题来了。
这个batch该如何设置?我之前的构想是在ndarray的list里一个一个元素提出来做feed_dict,结果就是在input参数上报错。
我的list里每一个元素都是ndarray, shape= 64x64x3,这个我确认过没有问题。
feed_dict是可以传ndarray的,这个我用全连接网络做数值预测也是确认过。
目前的问题就是,我自己导入的数据没有像MNIST的那样做过封装,Batch也不知道该如何理解。

请各位大神指点一下,这个batch该如何理解,或者克服?
如果我想用自己的数据做batch训练怎么处理?
或者有更好的读取自己数据的方法,请告知。


我知道答案 回答被采纳将会获得10 金币 + 5 金币 已有6人回答
本楼点评(0) 收起

精彩评论6

ZMikkelsen  TF荚荚  发表于 2018-6-1 20:13:21 | 显示全部楼层
本帖最后由 ZMikkelsen 于 2018-6-1 20:15 编辑

# -*- coding:utf-8 -*-
# @Author: Z Mikkelsen
# @Date  : 2018/6/1


import numpy as np
import os
from PIL import Image

def get_all_image_in_file(file_path='./imaget'):
    '''
:param file_path:  file path which you want to open
:return: an ndarray of all the images in the file , shape = [batch, a, b, c]


    '''
image_matrice = []
    image_list = os.listdir(file_path)
    for image in image_list:
        img = Image.open(file_path+'/'+image)
        matrix = np.array(img)
        image_matrice.append(matrix)
    return np.vstack(tuple(image_matrice))

# test
# i = get_all_image_in_file()
# print(i.shape)
# print(i.reshape(-1, 200, 200, 3).shape)

自己搞定了,具体操作如代码所示。
使用np.vstack一个tuple将所有的图片都垂直堆叠起来。
我这里是4张200x200的图片。
I的shape是800x200x3。
然后reshape成四个维度, 对应Batch的设置为-1。
最后结果为 4x200x200x3

本楼点评(0) 收起
ZMikkelsen  TF荚荚  发表于 2018-6-1 21:10:35 | 显示全部楼层

运行好像还是出了问题  莫名其妙。
我现在已经把数据转成包含batch的维度了, 但我用placeholder传的时候还是会报错,提示我的datatype是int。
我用了很多转类型的方法都还是会报错。



本楼点评(2) 收起
  • yunhai_luo图片加载不出来。楼主的打包貌似没问题,能说一下placeholder的定义和具体错误信息吗?
    2018-6-1 23:45 回复
  • ZMikkelsen是我搞错了,现在解决了
    2018-6-2 11:12 回复
Googler  TF荚荚  发表于 2018-6-2 10:32:31 | 显示全部楼层
为啥不用官方的batch API
  1. tf.train.shuffle_batch
复制代码
本楼点评(0) 收起
ZMikkelsen  TF荚荚  发表于 2018-6-2 11:09:05 | 显示全部楼层
Googler 发表于 2018-6-2 10:32
为啥不用官方的batch API

不太了解tf提供的关于加载数据的API,看着都挺麻烦的
用python不就是图个简单...


本楼点评(2) 收起
  • Googler那建议用keras
    2018-6-2 11:09 回复
  • ZMikkelsen回复 Googler : tensorflow虽然麻烦,但是同时也提供了更自由的搭建。
    要说完全简单,pytorch真的很好用
    2018-6-2 16:48 回复
重庆不热  TF荚荚  发表于 2018-7-3 16:26:52 | 显示全部楼层
batch可以理解成为样本的个数。conv2d接收的是四维的输入,或许你应该把你的输入reshape为[batch, height, width,channel]
本楼点评(0) 收起
ViolinSolo  TF豆豆  发表于 2018-7-3 20:32:35 | 显示全部楼层
batch其实就是一批下的样本的数量,一个epoch由多个batch组成
本楼点评(0) 收起
您需要登录后才可以回帖 登录 | 加入社区

本版积分规则

快速回复 返回顶部 返回列表