Keras学习-0x06-Tfrecord相关

3,051次阅读
没有评论

共计 2717 个字符,预计需要花费 7 分钟才能阅读完成。

前面的描述中讲完了序列模型和函数式模型的理论,对于keras而言后面所有模型代码的实现都是基于这两种方式来实现,所以这也是有讨论,接下来就是要从每一点去学习,自己计划的方式是从input到output依次来学习,所以最先开始讲的就是输入数据相关,总的顺序是这样,每一步都可以发散很多点来了解。先从tfrecord学习开始。

tfrecord也是官方推荐的一种数据存储方式,也是基于pb协议存储的方式。为什么推荐使用tfrecord?现在大部分场景下数据量都是很大的,所以你打算使用基于内存型的方式那么肯定不行的,一下子将数据塞入内存效率比较低,那你肯定想说我一点一点的写入到内存基于batch方式来训练模型,但是也要兼顾效率的问题。

使用tfrecord的方式可以借助多线程io并行读取数据用来训练,tf.data对此也有很好的处理,使用pipline加速数据的加载处理。

TFRecord

整体上建立TFRecord文件的流程主要如下;

  • 在TFRecord数据文件中,任何数据都是以bytes列表或float列表或int64列表的形式存储(注意:是列表形式),因此,将每条数据转化为列表格式。
  • 创建的每条数据列表都必须由一个Feature类包装,并且,每个feature都存储在一个key-value键值对中,其中key对应每个feature的名称。这些key将在后面从TFRecord提取数据时使用。
  • 当所需的字典创建完之后,会传递给Features类。
  • 最后,将features对象作为输入传递给example类,然后这个example类对象会被追加到TFRecord中。
  • 对于所有数据,重复上述过程。

TFRecord建立

下面给出一个实例来说明tfrecord

import  tensorflow as tf
data_arr = [
    {
        'clothes_category': 10, # 整型
        'clothes_prices':100.6, #浮点型
        'clothes_name':'jack jones'.encode(), # 字符串型,python3下转化为byte
        'clothes_topic':[110,120,78]  # 列表型
    },
    {clothes_category': 11, # 整型 
'clothes_prices':101.6, #浮点型 
'clothes_name':'cat'.encode(), # 字符串型,python3下转化为byte 
'clothes_topic':[89,130,87,522] # 列表型
    }
]

上面两条数据列了四个字段,第一个是衣服的分类数据,就是常见的category特征,第二个是衣服的价格,是个浮点型数据,

第三个是字符串类型,一般情况下可以经过哈希处理,第四个是衣服的topic 主题向量,这是自己瞎编的。注意一点第四个特征是变长的。

完整的代码程序如下所示

# -*- coding: utf-8 -*-
# @Time    : 2019-08-20 23:13
# @Author  : zhusimaji
# @File    : gen_tfrecord.py
# @Software: PyCharm


import  tensorflow as tf
data_arr = [
    {
        'clothes_category': 10, # 整型
        'clothes_prices':100.6, #浮点型
        'clothes_name':'jack jones'.encode(), # 字符串型,python3下转化为byte
        'clothes_topic':[110,120,78]  # 列表型
    },
    {'clothes_category': 11, # 整型
'clothes_prices':101.6, #浮点型
'clothes_name':'cat'.encode(), # 字符串型,python3下转化为byte
'clothes_topic':[89,130,87,522] # 列表型
    }
]

def get_example_object(data_record):
    # 将数据转化为int64 float 或bytes类型的列表
    # 注意都是list形式
    int_list1 = tf.train.Int64List(value = [data_record['clothes_category']])
    float_list1 = tf.train.FloatList(value = [data_record['clothes_prices']])
    str_list1 = tf.train.BytesList(value = [data_record['clothes_name']])
    float_list2 = tf.train.FloatList(value = data_record['clothes_topic'])
    feature_key_value_pair = {
        'clothes_category':tf.train.Feature(int64_list = int_list1),
        'clothes_prices': tf.train.Feature(float_list=float_list1),
        'clothes_name': tf.train.Feature(bytes_list=str_list1),
        'clothes_topic': tf.train.Feature(float_list=float_list2),
    }
    # 创建一个features
    features = tf.train.Features(feature=feature_key_value_pair)
    # 创建一个example
    example = tf.train.Example(features=features)
    return example
with tf.python_io.TFRecordWriter('./resources/clothes.tfrecord') as tfwriter:
    #遍历所有数据
    for data_record in data_arr:
        example = get_example_object(data_record)
        # 写入tfrecord数据文件
        tfwriter.write(example.SerializeToString())

此时你发现tfrecord数据已经生成了。上面的模式其实跟python正常读写文件的方式是一样的,只是写数据的内容不一样。

上面是借助python io方法来实现tfrecord建立,你也可以使用tf.data来创建tfrecord,tf.data还是很强大的,后面关于的数据流这块应该都会使用到它。

下一节再去学习读 tfrecord,可以的话在把tf.data生成tfrecord方法也写出来。

正文完
请博主喝杯咖啡吧!
post-qrcode
 
admin
版权声明:本站原创文章,由 admin 2019-08-20发表,共计2717字。
转载说明:除特殊说明外本站文章皆由CC-4.0协议发布,转载请注明出处。
评论(没有评论)
验证码