发新帖

Keras的epoch概念问题

[复制链接]
4696 6

快来加入 TensorFlowers 大家庭!

您需要 登录 才可以下载或查看,没有帐号?加入社区

x
一般来讲一个epoch是把所有训练图片训练一遍。
而Keras的fit的参数中,每个epoch的训练图片数量其实是 ( steps_per_epoch 乘以 Batch大小 ),steps_per_epoch又要求是Integer类型,所以它这边的一个epoch其实并不是准准的把所有图片过一遍。比如13张图片,你的batch size=8的时候,它跑完第一个epoch其实才跑了8张图片。

当然这不是什么大问题,就是有点变扭。

model.fit(
    train_tfdata.make_one_shot_iterator(),
    steps_per_epoch=int(train_no / _BATCH_SIZE),
    epochs=_EPOCHS)
虽然可以重写,用tf.data在epoch完的时候,丢出来的tf.errors.OutOfRangeError来判断一个epoch的终止,可是代码量会上去,不容易读。
各位有没有简单的解决途径?
livernana已获得悬赏 10 金币+10 金币

最佳答案

可以通过一小段代码来检验最后一个不完整的batch是否会被tf.data.Iterator循环到: 试验一下就会发现最后一个不完整的batch是可以被循环到的。那么为什么在keras的model.fit中就不可以呢?原因出在steps_per_epoch= ...
本楼点评(0) 收起

精彩评论6

舟3332  TF芽芽  发表于 2018-10-9 00:24:09 来自手机  | 显示全部楼层
repeat。方法 无限循环 如何?
本楼点评(1) 收起
  • 树涛别扭主要起因是Keras的epoch概念和我们一般指的epoch概念不一样。
    2018-10-10 10:18 回复
wangzhe258369  TF荚荚  发表于 2018-10-9 11:09:21 | 显示全部楼层
对,当num_samples = 13,但是batch_szie = 8的时候,一个epoch的大小其实就是8,而且后面的5个在第二轮epoch里也不会用到。
本楼点评(1) 收起
livernana  TF荚荚  发表于 2018-10-9 16:49:22 | 显示全部楼层
本帖最后由 livernana 于 2018-10-9 16:50 编辑

可以通过一小段代码来检验最后一个不完整的batch是否会被tf.data.Iterator循环到:
  1. x_tensor_train, y_tensor_train = dataset_train.make_one_shot_iterator().get_next()
  2. with tf.Session() as sess:
  3.     while True:
  4.         try:
  5.             y_batch = sess.run(y_tensor_train)
  6.             print(y_batch.shape)
  7.         except tf.errors.OutOfRangeError:
  8.             break
复制代码

试验一下就会发现最后一个不完整的batch是可以被循环到的。那么为什么在keras的model.fit中就不可以呢?原因出在steps_per_epoch=int(train_no / _BATCH_SIZE)上。

通过检查tf.keras.models.Model的fit方法中的代码,可以逐渐定位到下面这段代码:
  1.   for step_index in range(steps_per_epoch):
  2.     batch_logs = {}
  3.     batch_logs['batch'] = step_index
  4.     batch_logs['size'] = 1
  5.     callbacks.on_batch_begin(step_index, batch_logs)
  6.     try:
  7.       outs = f(ins)
  8.     except errors.OutOfRangeError:
  9.       logging.warning('Your dataset iterator ran out of data; '
  10.                       'interrupting training. Make sure that your dataset '
  11.                       'can generate at least `steps_per_epoch * epochs` '
  12.                       'batches (in this case, %d batches).' %
  13.                       steps_per_epoch * epochs)
  14.       break

  15.     if not isinstance(outs, list):
  16.       outs = [outs]
  17.     for l, o in zip(out_labels, outs):
  18.       batch_logs[l] = o

  19.     callbacks.on_batch_end(step_index, batch_logs)
  20.     if callback_model.stop_training:
  21.       break
复制代码

如果steps_per_epoch=int(train_no / _BATCH_SIZE) + 1的话,最后一个batch就会被上面的代码循环到了。注意到过大的steps_per_epoch可能会导致steps_per_epoch * epochs大于你dataset中可循环的batch的个数,所以建议最大也不要超过int(train_no / _BATCH_SIZE) + 1。
希望能对你有所帮助。
本楼点评(1) 收起
  • 树涛 1的话,后面会报错的,如errors.OutOfRangeError的warning所示,`steps_per_epoch * epochs`不能大于总的dataset样本数量。

    其实Keras的epoch和我们一般讲的epoch是有区别的,我主要是想讨论一下这个。

    我们一般的epoch的认识如下:
       dataset = tf.data.Dataset.range(13)
        dataset = dataset.batch(8)
        iterator = dataset.make_one_shot_iterator()
        next_element = iterator.get_next()

        # Compute for 100 epochs.
        for _ in range(100):
            while True:
                try:
                    sess.run(next_element)
                except tf.errors.OutOfRangeError:
                    print('Epoch End.')
                    break
    2018-10-10 10:16 回复
您需要登录后才可以回帖 登录 | 加入社区

本版积分规则

主题

帖子

16

积分
快速回复 返回顶部 返回列表