• 如果您觉得本站非常有看点,那么赶紧使用Ctrl+D 收藏吧

从tfrecords导入数据时,批处理标签顺序错误

python 来源:Yuxiao Xu 4次浏览

从tfrecords文件导入数据时出现问题。在tfrecords每个样品由feautures矢量与lenght 100和长度13。我使用下面的代码来导入来自tfrecords数据的一个热标签矢量的,指的是正式指南https://www.tensorflow.org/programmers_guide/datasets从tfrecords导入数据时,批处理标签顺序错误

def read_data(examples): 
    features = {"features": tf.FixedLenFeature([seq_len], tf.int64), 
       "label": tf.FixedLenFeature([category], tf.int64)} 
    parsed_features = tf.parse_single_example(examples, features) 
    return parsed_features['features'], parsed_features['label'] 

# get next batch of data and label 
def next_batch(filename, batch_size): 
    data = tf.data.TFRecordDataset(filename) 
    data = data.map(read_data) 
    data = data.batch(batch_size) 
    iterator = data.make_one_shot_iterator() 
    next_data, next_label = iterator.get_next() 
    return next_data, next_label 

with tf.Session() as sess: 
    filetrain = 'train.tfrecords' 
    next_data, next_label = next_batch(filetrain, num_example_train) 
    sess.run(tf.global_variables_initializer()) 

    data = sess.run(next_data) 
    label = sess.run(next_label) 

问题批次后标签的顺序会出错。如果我删除了代码’data = data.batch’,一切正常。

我认为一个可能的原因是功能和标签是独立分批的。所以我试图解析批处理后的例子,但得到一个错误“输入序列化必须是标量”。请帮助我,如果你知道如何处理这个问题,非常感谢!

===========解决方案如下:

我确定这是重复的,但我找不到其他问题,所以我会在这里回答。

您的问题是拨打sess.run()两次的数据和标签。 无论何时您致电sess.run,您的图表评估为(即,新的批次被提取并贯穿图表,直到全部作为第一个参数传递给run的列表中张量的值已知)。

这样做,您的datalabel是指两个不同的批次(因此他们看起来错了)。

你需要让他们在相同的呼叫与:

data, label = sess.run([next_data, next_label]) 

版权声明:本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系管理员进行删除。
喜欢 (0)