import tensorflow as tf
import cv2

# Build a TF1 queue-based pipeline that reads one serialized example at a
# time from a TFRecord file and decodes it into an image tensor and a label.
filename = "train.tfrecords"
filename_queue = tf.train.string_input_producer([filename])

reader = tf.TFRecordReader()
# reader.read returns (key, value); only the serialized record is needed.
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
    serialized_example,
    features={
        'label': tf.FixedLenFeature([], tf.int64),
        'image': tf.FixedLenFeature([], tf.string),
    },
)

img = tf.decode_raw(features['image'], tf.uint8)
# NOTE: this shape must be consistent with the shape the images had when they
# were serialized. A different element count fails at run time; the same count
# with a different layout silently scrambles the image.
img = tf.reshape(img, [300, 300, 3])
# Scale uint8 [0, 255] to float32 [-0.5, 0.5]. The original 1/128 factor
# produced [-0.5, ~1.49], which is not centered.
img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
label = tf.cast(features['label'], tf.int32)
# ---------------------------------------------------------------------------
# Function wrapper: the decoding logic above packaged as a reusable function.
# ---------------------------------------------------------------------------
import tensorflow as tf
import cv2


def read_decode(filename, shape=(500, 500, 3)):
    """Build a TF1 queue pipeline that decodes one example from a TFRecord file.

    Args:
        filename: path to the TFRecord file. It is fed to
            ``tf.train.string_input_producer``, which by default cycles over
            the file indefinitely.
        shape: shape the raw image bytes are reshaped into. It must be
            consistent with the shape used when the images were serialized.
            Defaults to ``(500, 500, 3)``, the previously hard-coded value.

    Returns:
        ``(img, label)``. ``img`` is a float32 tensor of ``shape``, scaled from
        uint8 [0, 255] to [-0.5, 0.5]. ``label`` is an int32 scalar tensor.
        Queue runners must be started (``tf.train.start_queue_runners``)
        before either tensor can be evaluated.
    """
    filename_queue = tf.train.string_input_producer([filename])

    reader = tf.TFRecordReader()
    # reader.read returns (key, value); only the serialized record is needed.
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'image': tf.FixedLenFeature([], tf.string),
        },
    )

    img = tf.decode_raw(features['image'], tf.uint8)
    img = tf.reshape(img, list(shape))
    # Center to [-0.5, 0.5]. The original 1/128 factor gave [-0.5, ~1.49].
    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
    label = tf.cast(features['label'], tf.int32)
    return img, label


# ---------------------------------------------------------------------------
# Usage example: `aaa` is the module that contains read_decode.
# ---------------------------------------------------------------------------
import tensorflow as tf
import cv2
import aaa
filename = "train.tfrecords"
img, label = aaa.read_decode(filename)

img_batch, label_batch = tf.train.shuffle_batch(
    [img, label], batch_size=1, capacity=10, min_after_dequeue=1)

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    # The Coordinator lets us stop and join the queue-runner threads cleanly.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for _ in range(10):
            # Fetch image and label in ONE run call. Two separate calls each
            # dequeue a different batch, pairing an image with the wrong label.
            val, label_val = sess.run([img_batch, label_batch])
            frame = val[0]  # drop the batch dimension (batch_size=1)
            # cv2.imshow displays float images assuming a [0, 1] range, so
            # stretch the normalized pixels into that range for viewing.
            frame = cv2.normalize(frame, None, 0.0, 1.0, cv2.NORM_MINMAX)
            cv2.imshow("cool", frame)
            # Without waitKey the HighGUI window never renders; 0 waits for a key.
            cv2.waitKey(0)
            print(label_val)
    finally:
        coord.request_stop()
        coord.join(threads)
        cv2.destroyAllWindows()
# ---------------------------------------------------------------------------