共计 2717 个字符,预计需要花费 7 分钟才能阅读完成。
前面的描述中讲完了序列模型和函数式模型的理论,对于keras而言后面所有模型代码的实现都是基于这两种方式来实现,所以这也是有讨论,接下来就是要从每一点去学习,自己计划的方式是从input到output依次来学习,所以最先开始讲的就是输入数据相关,总的顺序是这样,每一步都可以发散很多点来了解。先从tfrecord学习开始。
tfrecord也是官方推荐的一种数据存储方式,也是基于pb协议存储的方式。为什么推荐使用tfrecord?现在大部分场景下数据量都是很大的,所以你打算使用基于内存型的方式那么肯定不行的,一下子将数据塞入内存效率比较低,那你肯定想说我一点一点的写入到内存基于batch方式来训练模型,但是也要兼顾效率的问题。
使用tfrecord的方式可以借助多线程io并行读取数据用来训练,tf.data对此也有很好的处理,使用pipline加速数据的加载处理。
TFRecord
整体上建立TFRecord文件的流程主要如下;
- 在TFRecord数据文件中,任何数据都是以bytes列表或float列表或int64列表的形式存储(注意:是列表形式),因此,将每条数据转化为列表格式。
- 创建的每条数据列表都必须由一个Feature类包装,并且,每个feature都存储在一个key-value键值对中,其中key对应每个feature的名称。这些key将在后面从TFRecord提取数据时使用。
- 当所需的字典创建完之后,会传递给Features类。
- 最后,将features对象作为输入传递给example类,然后这个example类对象会被追加到TFRecord中。
- 对于所有数据,重复上述过程。
TFRecord建立
下面给出一个实例来说明tfrecord
import tensorflow as tf
data_arr = [
{
'clothes_category': 10, # 整型
'clothes_prices':100.6, #浮点型
'clothes_name':'jack jones'.encode(), # 字符串型,python3下转化为byte
'clothes_topic':[110,120,78] # 列表型
},
{clothes_category': 11, # 整型
'clothes_prices':101.6, #浮点型
'clothes_name':'cat'.encode(), # 字符串型,python3下转化为byte
'clothes_topic':[89,130,87,522] # 列表型
}
]
上面两条数据列了四个字段,第一个是衣服的分类数据,就是常见的category特征,第二个是衣服的价格,是个浮点型数据,
第三个是字符串类型,一般情况下可以经过哈希处理,第四个是衣服的topic 主题向量,这是自己瞎编的。注意一点第四个特征是变长的。
完整的代码程序如下所示
# -*- coding: utf-8 -*-
# @Time : 2019-08-20 23:13
# @Author : zhusimaji
# @File : gen_tfrecord.py
# @Software: PyCharm
import tensorflow as tf
data_arr = [
{
'clothes_category': 10, # 整型
'clothes_prices':100.6, #浮点型
'clothes_name':'jack jones'.encode(), # 字符串型,python3下转化为byte
'clothes_topic':[110,120,78] # 列表型
},
{'clothes_category': 11, # 整型
'clothes_prices':101.6, #浮点型
'clothes_name':'cat'.encode(), # 字符串型,python3下转化为byte
'clothes_topic':[89,130,87,522] # 列表型
}
]
def get_example_object(data_record):
# 将数据转化为int64 float 或bytes类型的列表
# 注意都是list形式
int_list1 = tf.train.Int64List(value = [data_record['clothes_category']])
float_list1 = tf.train.FloatList(value = [data_record['clothes_prices']])
str_list1 = tf.train.BytesList(value = [data_record['clothes_name']])
float_list2 = tf.train.FloatList(value = data_record['clothes_topic'])
feature_key_value_pair = {
'clothes_category':tf.train.Feature(int64_list = int_list1),
'clothes_prices': tf.train.Feature(float_list=float_list1),
'clothes_name': tf.train.Feature(bytes_list=str_list1),
'clothes_topic': tf.train.Feature(float_list=float_list2),
}
# 创建一个features
features = tf.train.Features(feature=feature_key_value_pair)
# 创建一个example
example = tf.train.Example(features=features)
return example
with tf.python_io.TFRecordWriter('./resources/clothes.tfrecord') as tfwriter:
#遍历所有数据
for data_record in data_arr:
example = get_example_object(data_record)
# 写入tfrecord数据文件
tfwriter.write(example.SerializeToString())
此时你发现tfrecord数据已经生成了。上面的模式其实跟python正常读写文件的方式是一样的,只是写数据的内容不一样。
上面是借助python io方法来实现tfrecord建立,你也可以使用tf.data来创建tfrecord,tf.data还是很强大的,后面关于的数据流这块应该都会使用到它。
下一节再去学习读 tfrecord,可以的话在把tf.data生成tfrecord方法也写出来。