发新帖

生成tfrecord文件的程序无法运行

[复制链接]
150 1
  1. import os
  2. import io
  3. import pandas as pd
  4. import tensorflow as tf

  5. from PIL import Image
  6. from object_detection.utils import dataset_util
  7. from collections import namedtuple, OrderedDict
复制代码
  # Command-line flags.
  # NOTE: the third argument of DEFINE_string is the *help text*, not the
  # default value. The original code put the intended paths in the help slot,
  # so both flags silently defaulted to '' and reading/writing failed whenever
  # the script was started without explicit --csv_input/--output_path
  # (e.g. from Jupyter). Supply both on the command line, e.g.
  #   python <script>.py --csv_input=F:/GradeTwo/train/csv/<labels>.csv \
  #                      --output_path=F:/GradeTwo/train/TFRecord/<name>.record
  flags = tf.app.flags
  flags.DEFINE_string('csv_input', '',
                      'Path to the annotation CSV file '
                      '(e.g. a .csv file under F:/GradeTwo/train/csv).')
  flags.DEFINE_string('output_path', '',
                      'Path of the TFRecord file to write '
                      '(e.g. a file under F:/GradeTwo/train/TFRecord).')
  FLAGS = flags.FLAGS
  # Directory holding the images named in the CSV 'filename' column;
  # create_tf_example() opens image files from here.
  image_path = r'F:/GradeTwo/train/images'
复制代码
  def class_text_to_int(row_label, label_map=None):
      """Map a class name from the CSV to its integer label id.

      Args:
          row_label: Class name as it appears in the CSV 'class' column.
          label_map: Optional mapping of class name -> integer id. Defaults
              to {'ship': 1}. Extend the mapping (e.g. {'ship': 1,
              'vehicle': 2}) instead of adding more if/elif branches.

      Returns:
          The integer id for ``row_label``, or None if it is not in the
          mapping. The original ``else`` branch held a bare ``None``
          expression (a no-op), so None was returned only implicitly; the
          return is now explicit.

      NOTE(review): a None id is later appended to the 'image/object/class/
      label' int64 list, which likely fails there -- make sure every class
      in the CSV is present in the mapping.
      """
      if label_map is None:
          label_map = {'ship': 1}
      return label_map.get(row_label)

  def split(df, group):
      """Group annotation rows by the column ``group`` (normally 'filename').

      Returns:
          A list of ``data(filename, object)`` namedtuples, one per distinct
          value of ``group``, where ``object`` is the sub-DataFrame holding
          the rows for that value. Order follows the groupby keys (pandas'
          default ``sort=True``).
      """
      data = namedtuple('data', ['filename', 'object'])
      # Iterate the GroupBy directly: a single pass instead of one
      # get_group() lookup per key. The original zipped gb.groups.keys()
      # with gb.groups, i.e. the same keys twice.
      return [data(filename, rows) for filename, rows in df.groupby(group)]

  def create_tf_example(group, path):
      """Build one tf.train.Example holding an image and all of its boxes.

      Args:
          group: A ``data(filename, object)`` record as returned by ``split``;
              ``object`` has one row per bounding box with the columns
              'xmin', 'xmax', 'ymin', 'ymax' and 'class'.
          path: NOTE(review): not used -- the image is always read from the
              module-level ``image_path``, not from this argument.

      Returns:
          A tf.train.Example whose box coordinates are divided by the image
          width (x) and height (y).
      """
      image_file = os.path.join(image_path, '{}'.format(group.filename))
      with tf.gfile.GFile(image_file, 'rb') as fh:
          raw_bytes = fh.read()
      width, height = Image.open(io.BytesIO(raw_bytes)).size

      name_bytes = group.filename.encode('utf8')

      left, right, top, bottom = [], [], [], []
      label_names, label_ids = [], []
      for _, box in group.object.iterrows():
          left.append(box['xmin'] / width)
          right.append(box['xmax'] / width)
          top.append(box['ymin'] / height)
          bottom.append(box['ymax'] / height)
          label_names.append(box['class'].encode('utf8'))
          label_ids.append(class_text_to_int(box['class']))

      feature_map = {
          'image/height': dataset_util.int64_feature(height),
          'image/width': dataset_util.int64_feature(width),
          'image/filename': dataset_util.bytes_feature(name_bytes),
          'image/source_id': dataset_util.bytes_feature(name_bytes),
          'image/encoded': dataset_util.bytes_feature(raw_bytes),
          'image/format': dataset_util.bytes_feature(b'jpg'),
          'image/object/bbox/xmin': dataset_util.float_list_feature(left),
          'image/object/bbox/xmax': dataset_util.float_list_feature(right),
          'image/object/bbox/ymin': dataset_util.float_list_feature(top),
          'image/object/bbox/ymax': dataset_util.float_list_feature(bottom),
          'image/object/class/text': dataset_util.bytes_list_feature(label_names),
          'image/object/class/label': dataset_util.int64_list_feature(label_ids),
      }
      return tf.train.Example(features=tf.train.Features(feature=feature_map))

  # A second, identical definition of main() used to sit here. Python keeps
  # only the last definition of a name, so this copy was dead code and has
  # been removed; the main() defined further down is the entry point.

  def main(_):
      """Convert the annotation CSV into a single TFRecord file.

      Reads FLAGS.csv_input, groups its rows by 'filename', builds one
      tf.train.Example per image and writes them all to FLAGS.output_path.

      Raises:
          ValueError: if --csv_input or --output_path was not supplied.
      """
      # Both flags are defined with default '' in the DEFINE_string calls, so
      # fail fast with an actionable message instead of an obscure I/O error.
      if not FLAGS.csv_input or not FLAGS.output_path:
          raise ValueError(
              'Both --csv_input and --output_path must be set, e.g. '
              'python <script>.py --csv_input=labels.csv '
              '--output_path=train.record')
      # NOTE(review): create_tf_example() currently ignores this argument and
      # reads images from the module-level image_path -- confirm intent.
      path = os.path.join(os.getcwd(), 'images')
      examples = pd.read_csv(FLAGS.csv_input)
      grouped = split(examples, 'filename')
      # Open the writer only after the CSV was read, so a bad CSV does not
      # leave an empty output file behind.
      writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
      try:
          for group in grouped:
              tf_example = create_tf_example(group, path)
              writer.write(tf_example.SerializeToString())
      finally:
          # Always flush and close the file, even if one image fails.
          writer.close()
      output_path = os.path.join(os.getcwd(), FLAGS.output_path)
      print('Successfully created the TFRecords: {}'.format(output_path))

  if __name__ == '__main__':
      # tf.app.run() parses the command-line arguments into FLAGS and then
      # calls main(). Launch this file from a command prompt (e.g. cmd) and
      # pass --csv_input / --output_path there; from a notebook there is no
      # way to supply these flags, so they keep their '' defaults.
      tf.app.run(main=main)
复制代码
以上是代码,我在 Jupyter 中运行,最后一步(最后一个单元格)是:
  if __name__ == '__main__':
      # The call must be indented under the if statement; without the indent
      # this cell raises an IndentationError.
      tf.app.run()
复制代码
如下图所示,运行后始终没有出现表示正在执行的沙漏标志,也没有生成任何结果文件,请问有人知道是怎么回事吗?

我知道答案 回答被采纳将会获得10 金币 + 5 金币 已有1人回答

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有帐号?加入社区

x
本楼点评(0) 收起

精彩评论1

blackhandguy  TF荚荚  发表于 2019-10-17 16:37:02 | 显示全部楼层
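该程序需要在命令行(如 cmd)中运行,并通过参数传入路径,例如:python <脚本名>.py --csv_input=<CSV 文件路径> --output_path=<输出的 .record 文件路径>。原因是 csv_input 和 output_path 两个 flag 的默认值都是空字符串(代码中把路径误写在了说明文字参数的位置),而在 Jupyter 中无法通过命令行为 tf.app.run() 传入这两个参数。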
本楼点评(0) 收起
您需要登录后才可以回帖 登录 | 加入社区

本版积分规则

快速回复 返回顶部 返回列表