Skip to content

Latest commit

 

History

History
97 lines (68 loc) · 3.39 KB

File metadata and controls

97 lines (68 loc) · 3.39 KB

五、如何将原始图片数据转换为 TFRecords

大家好! 与前一个教程一样,本教程的重点是自动化数据输入流水线。

大多数情况下,我们的数据集太大而无法读取到内存,因此我们必须准备一个流水线,用于从硬盘批量读取数据。 我总是将我的原始数据(文本,图像,表格)处理为 TFRecords,因为它让我的生活变得更加容易。

教程的流程图

本教程将包含以下部分:

  • 创建一个函数,读取原始图像并将其转换为 TFRecords 的。
  • 创建一个函数,将 TFRecords 解析为 TF 张量。

所以废话不多说,让我们开始吧。

导入有用的库

import tensorflow as tf
import tensorflow.contrib.eager as tfe
import glob

# 开启 Eager 模式。一旦开启不能撤销!只执行一次。
tfe.enable_eager_execution()

将原始数据转换为 TFRecords

对于此任务,我们将使用 FER2013 数据集中的一些图像,你可以在datasets/dummy_images文件夹中找到这些图像。 情感标签可以在图像的文件名中找到。 例如,图片id7_3.jpg情感标签为 3,其对应于状态'Happy'(快乐),如下面的字典中所示。

# 获取每个情感的下标的含义
emotion_cat = {0:'Angry', 1:'Disgust', 2:'Fear', 3:'Happy', 4:'Sad', 5:'Surprise', 6:'Neutral'}

def img2tfrecords(path_data='datasets/dummy_images/', image_format='jpeg'):
    ''' 用于将原始图像以及它们标签转换为 TFRecords 的函数
        辅助函数的原始的源代码:https://goo.gl/jEhp2B
        
        Args:
            path_data: the location of the raw images
            image_format: the format of the raw images (e.g. 'png', 'jpeg')
    '''
    
    def _int64_feature(value):
        '''辅助函数'''
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
    
    def _bytes_feature(value):
        '''辅助函数'''
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
    
    # 获取目录中每个图像的文件名
    filenames = glob.glob(path_data + '*' + image_format)
    
    # 创建 TFRecordWriter
    writer = tf.python_io.TFRecordWriter(path_data + 'dummy.tfrecords')
    
    # 遍历每个图像,并将其写到 TFrecords 文件中
    for filename in filenames:
        # 读取原始图像
        img = tf.read_file(filename).numpy()
        # 从文件名中解析它的标签
        label = int(filename.split('_')[-1].split('.')[0])
        # 创建样本(图像,标签)
        example = tf.train.Example(features=tf.train.Features(feature={
            'label': _int64_feature(label),
            'image': _bytes_feature(img)}))
        # 向 TFRecords 写出序列化样本
        writer.write(example.SerializeToString())

# 将原始数据转换为 TFRecords
img2tfrecords()

将 TFRecords 解析为 TF 张量

def parser(record):
    '''解析 TFRecords 样本的函数'''
    
    # 定义你想要解析的特征
    features = {'image': tf.FixedLenFeature((), tf.string),
                'label': tf.FixedLenFeature((), tf.int64)}
    
    # 解析样本
    parsed = tf.parse_single_example(record, features)

    # 解码图像
    img = tf.image.decode_image(parsed['image'])
   
    return img, parsed['label']

如果你希望我在本教程中添加任何内容,请告诉我,我将很乐意进一步改善它。