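为 TensorFlow 对象检测 API 从 PTS 标签文件(示例代码按 PASCAL VOC 格式的 XML 文件解析)导出 TFRecords 的步骤如下:
1. 使用 `tf.io.TFRecordWriter` 类(TF 1.x 中为 `tf.python_io.TFRecordWriter`)创建一个 TFRecords 文件,用于存储导出的数据。
2. 使用 `parse_pascal_voc_xml` 函数解析标签文件,获取每个图像的标签信息。注意:该函数并非 TensorFlow 自带,需要自行实现。
3. 为每个图像构造一个 `tf.train.Example`,并调用其 `SerializeToString` 方法将其序列化为字符串。(`tf.train.Example.FromString` 的作用正好相反,是从字符串解析出 Example。)
4. 调用 `TFRecordWriter` 的 `write` 方法,将序列化后的字符串写入文件。
以下是一个示例代码,展示了如何从PTS标签文件中导出TFRecords: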
import io
import os

import numpy as np
import tensorflow as tf
from PIL import Image
def _bytes_feature(values):
    """Wrap a list of ``bytes`` values in a ``tf.train.Feature``."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))


def _float_feature(values):
    """Wrap a list of float values in a ``tf.train.Feature``."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=values))


def _int64_feature(values):
    """Wrap a list of integer values in a ``tf.train.Feature``."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def create_tf_example(image_path, labels):
    """Build a ``tf.train.Example`` for one image and its bounding boxes.

    The feature keys follow the ``image/...`` layout used by the TensorFlow
    Object Detection API.

    Args:
        image_path: Path of the image file. It is read through ``tf.io.gfile``,
            so local paths and any filesystem TensorFlow supports both work.
        labels: Iterable of dicts, one per object. Each dict holds the keys
            ``'xmin'``, ``'xmax'``, ``'ymin'`` and ``'ymax'`` as pixel
            coordinates (numbers or numeric strings), ``'class'`` as a str
            class name, and ``'class_id'`` as an integer id.

    Returns:
        A ``tf.train.Example``. The box coordinates are divided by the image
        width/height, so they are normalised to image size.
    """
    # tf.io.gfile exists in TF 2.x and late TF 1.x; tf.gfile is TF1-only.
    with tf.io.gfile.GFile(image_path, 'rb') as fid:
        encoded_image = fid.read()

    # Read the size from the bytes already in memory instead of reopening the
    # file. The `with` block ensures the PIL handle is closed.
    with Image.open(io.BytesIO(encoded_image)) as image:
        width, height = image.size
        # Previously always 'jpeg'. Report PNG correctly; keep 'jpeg' otherwise.
        image_format = 'png' if image.format == 'PNG' else 'jpeg'

    xmins, xmaxs, ymins, ymaxs = [], [], [], []
    classes_text, classes = [], []
    for label in labels:
        # float()/int() accept values parsed as strings (e.g. from XML text).
        xmins.append(float(label['xmin']) / width)
        xmaxs.append(float(label['xmax']) / width)
        ymins.append(float(label['ymin']) / height)
        ymaxs.append(float(label['ymax']) / height)
        classes_text.append(label['class'].encode('utf8'))
        classes.append(int(label['class_id']))

    filename = os.path.basename(image_path).encode('utf8')
    feature = {
        'image/height': _int64_feature([height]),
        'image/width': _int64_feature([width]),
        'image/filename': _bytes_feature([filename]),
        'image/source_id': _bytes_feature([filename]),
        'image/encoded': _bytes_feature([encoded_image]),
        'image/format': _bytes_feature([image_format.encode('utf8')]),
        'image/object/bbox/xmin': _float_feature(xmins),
        'image/object/bbox/xmax': _float_feature(xmaxs),
        'image/object/bbox/ymin': _float_feature(ymins),
        'image/object/bbox/ymax': _float_feature(ymaxs),
        'image/object/class/text': _bytes_feature(classes_text),
        'image/object/class/label': _int64_feature(classes),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))
def main(output_path='output.tfrecord', image_dir='images/',
         label_file='labels.xml'):
    """Convert the annotations in *label_file* into a TFRecord file.

    Args:
        output_path: Destination TFRecord file.
        image_dir: Directory containing the images named by the annotations.
        label_file: Annotation file handed to ``parse_pascal_voc_xml``.

    NOTE(review): ``parse_pascal_voc_xml`` is not part of TensorFlow and is
    not defined in this example, so it must be supplied by the user. From its
    use below, it must return an iterable of dicts. Each dict has a
    ``'filename'`` key (relative to *image_dir*) and an ``'objects'`` key in
    the format ``create_tf_example`` accepts.
    """
    # Parse the label file before creating the writer, so that a parse failure
    # does not leave an empty output file behind.
    labels = parse_pascal_voc_xml(label_file)

    # tf.io.TFRecordWriter replaces the TF1-only tf.python_io.TFRecordWriter.
    # The `with` block closes the writer even if an image fails to convert.
    with tf.io.TFRecordWriter(output_path) as writer:
        for label in labels:
            image_path = os.path.join(image_dir, label['filename'])
            tf_example = create_tf_example(image_path, label['objects'])
            writer.write(tf_example.SerializeToString())
    print('TFRecords导出完成!')
# Run the export only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
请注意,以上代码仅为示例,你需要根据自己的具体情况进行适当的修改和调整。
领取专属 10元无门槛券
手把手带您无忧上云