博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
tensorflow生成tfrecord格式的数据
阅读量:7246 次
发布时间:2019-06-29

本文共 2120 字,大约阅读时间需要 7 分钟。

tensorflow生成tfrecord格式的数据

tfrecord格式数据能高效的组织数据,提高训练时的IO性能

1,2步骤定义了函数,3步骤生成tfrecord格式的数据 1.TF-Feature 将数据(values)封装于tf.train.Feature

def int64_feature(values):  """Returns a TF-Feature of int64s.  Args:    values: A scalar or list of values.  Returns:    A TF-Feature.  """  if not isinstance(values, (tuple, list)):    values = [values]  return tf.train.Feature(int64_list=tf.train.Int64List(value=values))def bytes_feature(values):  """Returns a TF-Feature of bytes.  Args:    values: A string.  Returns:    A TF-Feature.  """  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))def float_feature(values):  """Returns a TF-Feature of floats.  Args:    values: A scalar of list of values.  Returns:    A TF-Feature.  """  if not isinstance(values, (tuple, list)):    values = [values]  return tf.train.Feature(float_list=tf.train.FloatList(value=values))复制代码

2.tf.train.Example

def create_tf_example(features_dict):    '''    :param features_dict: { img_query_bytes:bytes    :return: tf.train.Example    '''    feature_map={  'img_query':bytes_feature(features_dict['img_query_bytes'])                  }    return tf.train.Example(features=tf.train.Features(feature=feature_map))复制代码

3.写入tfrecord文件 使用tf.python_io.TFRecordWriter(out_path) 写入 tf.train.Example

out_path = './dataset/train.record'    with tf.python_io.TFRecordWriter(out_path) as writer:            features_dict=dict()            with tf.gfile.GFile(img_path,'rb') as fid:                features_dict['img_query_bytes']=fid.read()            example=create_tf_example(features_dict)            writer.write(example.SerializeToString())            if iter_num%1000==0:                print('done : {} % {}'.format(iter_num,iter_steps))复制代码

一个Example结构

dict 表示字典类型

tf.train.Example {    features: tf.train.Features{        feature: dict{            'feature_name':tf.train.Feature{                int64_list:tf.train.Int64List{value:list}                bytes_list:tf.train.BytesList{value:list}                float_list:tf.train.FloatList{value:list}            }        }    }}复制代码

转载于:https://juejin.im/post/5c380e55e51d455272017e41

你可能感兴趣的文章
javascript 词法作用域
查看>>
Wireshark捕获理解TCP三次握手四次断开
查看>>
Go面向对象编程
查看>>
小程序分包加载
查看>>
scrollIntoView与键盘遮挡
查看>>
Nodejs API - events 提纲式笔记
查看>>
以太坊ganache CLI命令行参数详解
查看>>
iOSURL中文解决方法
查看>>
快速掌握 MongoDB 数据库
查看>>
Kotlin和RecyclerView的一个demo
查看>>
【译】Bootstrap的网格体系
查看>>
在AS中自定义字体库报错:java.lang.RuntimeException: native typeface cannot be made
查看>>
在Linux系统里安装Virtual Box的详细步骤
查看>>
Python 日志库 logging 的理解和实践经验
查看>>
Vert.x 文件上传Client
查看>>
iOS之SVN
查看>>
MySQL日志故障的处理和分析
查看>>
tcp和udp使用总结
查看>>
koa,node,express通用方法连接mysql
查看>>
(转)关于敏捷团队领任务的几个误区
查看>>