久久国产成人av_抖音国产毛片_a片网站免费观看_A片无码播放手机在线观看,色五月在线观看,亚洲精品m在线观看,女人自慰的免费网址,悠悠在线观看精品视频,一级日本片免费的,亚洲精品久,国产精品成人久久久久久久

分享

Tensorflow中使用tfrecord方式讀取數(shù)據(jù)

 Rainbow_Heaven 2017-06-11

前言

本博客默認讀者對神經(jīng)網(wǎng)絡與Tensorflow有一定了解,,對其中的一些術語不再做具體解釋,。并且本博客主要以圖片數(shù)據(jù)為例進行介紹,如有錯誤,,敬請斧正,。

使用Tensorflow訓練神經(jīng)網(wǎng)絡時,我們可以用多種方式來讀取自己的數(shù)據(jù),。如果數(shù)據(jù)集比較小,,而且內(nèi)存足夠大,可以選擇直接將所有數(shù)據(jù)讀進內(nèi)存,,然后每次取一個batch的數(shù)據(jù)出來,。如果數(shù)據(jù)較多,可以每次直接從硬盤中進行讀取,,不過這種方式的讀取效率就比較低了,。此篇博客就主要講一下Tensorflow官方推薦的一種較為高效的數(shù)據(jù)讀取方式——tfrecord。

從宏觀來講,,tfrecord其實是一種數(shù)據(jù)存儲形式,。使用tfrecord時,實際上是先讀取原生數(shù)據(jù),,然后轉換成tfrecord格式,,再存儲在硬盤上。而使用時,,再把數(shù)據(jù)從相應的tfrecord文件中解碼讀取出來。那么使用tfrecord和直接從硬盤讀取原生數(shù)據(jù)相比到底有什么優(yōu)勢呢,?其實,Tensorflow有和tfrecord配套的一些函數(shù),可以加快數(shù)據(jù)的處理,。實際讀取tfrecord數(shù)據(jù)時,,先以相應的tfrecord文件為參數(shù),創(chuàng)建一個輸入隊列,,這個隊列有一定的容量(視具體硬件限制,,用戶可以設置不同的值),在一部分數(shù)據(jù)出隊列時,,tfrecord中的其他數(shù)據(jù)就可以通過預取進入隊列,,并且這個過程和網(wǎng)絡的計算是獨立進行的。也就是說,,網(wǎng)絡每一個iteration的訓練不必等待數(shù)據(jù)隊列準備好再開始,,隊列中的數(shù)據(jù)始終是充足的,而往隊列中填充數(shù)據(jù)時,,也可以使用多線程加速,。

下面,本文將從以下4個方面對tfrecord進行介紹:

  1. tfrecord格式簡介
  2. 利用自己的數(shù)據(jù)生成tfrecord文件
  3. 從tfrecord文件讀取數(shù)據(jù)
  4. 實例測試

1. tfrecord格式簡介

這部分主要參考了另一篇博文,,Tensorflow 訓練自己的數(shù)據(jù)集(二)(TFRecord)

tfecord文件中的數(shù)據(jù)是通過tf.train.Example Protocol Buffer的格式存儲的,,下面是tf.train.Example的定義

message Example {
 Features features = 1;
};

message Features{
 map<string,Feature> featrue = 1;
};

message Feature{
    oneof kind{
        BytesList bytes_list = 1;
        FloatList float_list = 2;
        Int64List int64_list = 3;
    }
};
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

從上述代碼可以看出,tf.train.Example 的數(shù)據(jù)結構很簡單,。tf.train.Example中包含了一個從屬性名稱到取值的字典,,其中屬性名稱為一個字符串,屬性的取值可以為字符串(BytesList ),,浮點數(shù)列表(FloatList )或整數(shù)列表(Int64List ),。例如我們可以將圖片轉換為字符串進行存儲,,圖像對應的類別標號作為整數(shù)存儲,,而用于回歸任務的ground-truth可以作為浮點數(shù)存儲。通過后面的代碼我們會對tfrecord的這種字典形式有更直觀的認識,。

2. 利用自己的數(shù)據(jù)生成tfrecord文件

先上一段代碼,,然后我再針對代碼進行相關介紹。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from scipy import misc
import scipy.io as sio


def _bytes_feature(value):
    return tf.train.Feature(bytes_list = tf.train.BytesList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list = tf.train.Int64List(value=[value]))


root_path = '/mount/temp/WZG/Multitask/Data/'
tfrecords_filename = root_path + 'tfrecords/train.tfrecords'
writer = tf.python_io.TFRecordWriter(tfrecords_filename)


height = 300
width = 300
meanfile = sio.loadmat(root_path + 'mats/mean300.mat')
meanvalue = meanfile['mean']

txtfile = root_path + 'txt/train.txt'
fr = open(txtfile)

for i in fr.readlines():
    item = i.split()
    img = np.float64(misc.imread(root_path + '/images/train_images/' + item[0]))
    img = img - meanvalue
    maskmat = sio.loadmat(root_path + '/mats/train_mats/' + item[1])
    mask = np.float64(maskmat['seg_mask'])
    label = int(item[2])
    img_raw = img.tostring()
    mask_raw = mask.tostring()
    example = tf.train.Example(features=tf.train.Features(feature={
        'height': _int64_feature(height),
        'width': _int64_feature(width),
        'name': _bytes_feature(item[0]),
        'image_raw': _bytes_feature(img_raw),
        'mask_raw': _bytes_feature(mask_raw),
        'label': _int64_feature(label)}))

    writer.write(example.SerializeToString())

writer.close()
fr.close()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48

代碼中前兩個函數(shù)(_bytes_feature和_int64_feature)是將我們的原生數(shù)據(jù)進行轉換用的,,尤其是圖片要轉換成字符串再進行存儲。這兩個函數(shù)的定義來自官方的示例。 
接下來,,我定義了數(shù)據(jù)的(路徑-label文件)txtfile,,它大概長這個樣子:

txtfile

這里稍微啰嗦下,介紹一下我的實驗內(nèi)容,。我做的是一個multi-task的實驗,,一支task做分割,一支task做分類,。所以txtfile中每一行是一個樣本,,每個樣本又包含3項,第一項為圖片名稱,,第二項為相應的ground-truth segmentation mask的名稱,,第三項是圖片的標簽。(txtfile中內(nèi)容形式無所謂,,只要能讀到想讀的數(shù)據(jù)就可以)

接著回到主題繼續(xù)講代碼,,之后我又定義了即將生成的tfrecord的文件路徑和名稱,即tfrecord_filename,,還有一個writer,,這個writer是進行寫操作用的。

接下來是圖片的高度,、寬度以及我事先在整個數(shù)據(jù)集上計算好的圖像均值文件,。高度、寬度其實完全沒必要引入,,這里只是為了說明tfrecord的生成而寫的,。而均值文件是為了對圖像進行事先的去均值化操作而引入的,在大多數(shù)機器學習任務中,,圖像去均值化對提高算法的性能還是很有幫助的,。

最后就是根據(jù)txtfile中的每一行進行相關數(shù)據(jù)的讀取、轉換以及tfrecord的生成了,。首先是根據(jù)圖片路徑讀取圖片內(nèi)容,,然后圖像減去之前讀入的均值,接著根據(jù)segmentation mask的路徑讀取mask(如果只是圖像分類任務,,那么就不會有這些額外的mask),,txtfile中的label讀出來是string格式,這里要轉換成int,。然后圖像和mask數(shù)據(jù)也要用相應的tosring函數(shù)轉換成string,。

真正的核心是下面這一小段代碼:

example = tf.train.Example(features=tf.train.Features(feature={
        'height': _int64_feature(height),
        'width': _int64_feature(width),
        'name': _bytes_feature(item[0]),
        'image_raw': _bytes_feature(img_raw),
        'mask_raw': _bytes_feature(mask_raw),
        'label': _int64_feature(label)}))

writer.write(example.SerializeToString())
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

這里很好地體現(xiàn)了tfrecord的字典特性,tfrecord中每一個樣本都是一個小字典,,這個字典可以包含任意多個鍵值對,。比如我這里就存儲了圖片的高度,、寬度、圖片名稱,、圖片內(nèi)容,、mask內(nèi)容以及圖片的label。對于我的任務來說,,其實height,、width、name都不是必需的,,這里僅僅是為了展示,。鍵值對的鍵全都是字符串,鍵起什么名字都可以,,只要能方便以后使用就可以,。

定義好一個example后就可以用之前的writer來把它真正寫入tfrecord文件了,這其實就跟把一行內(nèi)容寫入一個txt文件一樣,。代碼的最后就是writer和txt文件對象的關閉了,。

最后在指定文件夾下,就得到了指定名字的tfrecord文件,,如下所示:

tfrecord文件

需要注意的是,,生成的tfrecord文件比原生數(shù)據(jù)的大小還要大,這是正?,F(xiàn)象,。這種現(xiàn)象可能是因為圖片一般都存儲為jpg等壓縮格式,而tfrecord文件存儲的是解壓后的數(shù)據(jù),。

3. 從tfrecord文件讀取數(shù)據(jù)

還是代碼先行,。

from scipy import misc
import tensorflow as tf
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt

root_path = '/mount/temp/WZG/Multitask/Data/'
tfrecord_filename = root_path + 'tfrecords/test.tfrecords'

def read_and_decode(filename_queue, random_crop=False, random_clip=False, shuffle_batch=True):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
      serialized_example,
      features={
          'height': tf.FixedLenFeature([], tf.int64),
          'width': tf.FixedLenFeature([], tf.int64),
          'name': tf.FixedLenFeature([], tf.string),                           
          'image_raw': tf.FixedLenFeature([], tf.string),
          'mask_raw': tf.FixedLenFeature([], tf.string),                               
          'label': tf.FixedLenFeature([], tf.int64)
      })

    image = tf.decode_raw(features['image_raw'], tf.float64)
    image = tf.reshape(image, [300,300,3])

    mask = tf.decode_raw(features['mask_raw'], tf.float64)
    mask = tf.reshape(mask, [300,300])

    name = features['name']

    label = features['label']
    width = features['width']
    height = features['height']

#    if random_crop:
#        image = tf.random_crop(image, [227, 227, 3])
#    else:
#        image = tf.image.resize_image_with_crop_or_pad(image, 227, 227)

#    if random_clip:
#        image = tf.image.random_flip_left_right(image)


    if shuffle_batch:
        images, masks, names, labels, widths, heights = tf.train.shuffle_batch([image, mask, name, label, width, height],
                                                batch_size=4,
                                                capacity=8000,
                                                num_threads=4,
                                                min_after_dequeue=2000)
    else:
        images, masks, names, labels, widths, heights = tf.train.batch([image, mask, name, label, width, height],
                                        batch_size=4,
                                        capacity=8000,
                                        num_threads=4)
    return images, masks, names, labels, widths, heights
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56

讀取tfrecord文件中的數(shù)據(jù)主要是應用read_and_decode()這個函數(shù),可以看到其中有個參數(shù)是filename_queue,,其實我們并不是直接從tfrecord文件進行讀取,,而是要先利用tfrecord文件創(chuàng)建一個輸入隊列,如本文開頭所述那樣,。關于這點,,到后面真正的測試代碼我再介紹。

在read_and_decode()中,,一上來我們先定義一個reader對象,,然后使用reader得到serialized_example,,這是一個序列化的對象,,接著使用tf.parse_single_example()函數(shù)對此對象進行初步解析。從代碼中可以看到,,解析時,,我們要用到之前定義的那些鍵,。對于圖像、mask這種轉換成字符串的數(shù)據(jù),,要進一步使用tf.decode_raw()函數(shù)進行解析,,這里要特別注意函數(shù)里的第二個參數(shù),也就是解析后的類型,。之前圖片在轉成字符串之前是什么類型的數(shù)據(jù),,那么這里的參數(shù)就要填成對應的類型,否則會報錯,。對于name,、label、width,、height這樣的數(shù)據(jù)就不用再解析了,,我們得到的features對象就是個字典,利用鍵就可以拿到對應的值,,如代碼所示,。

我注釋掉的部分是用來做數(shù)據(jù)增強的,比如隨機的裁剪與翻轉,,除了這兩種,,其他形式的數(shù)據(jù)增強也可以寫在這里,讀者可以根據(jù)自己的需要,,決定是否使用各種數(shù)據(jù)增強方式,。

函數(shù)最后就是使用解析出來的數(shù)據(jù)生成batch了。Tensorflow提供了兩種方式,,一種是shuffle_batch,,這種主要是用在訓練中,隨機選取樣本組成batch,。另外一種就是按照數(shù)據(jù)在tfrecord中的先后順序生成batch,。對于生成batch的函數(shù),建議讀者去官網(wǎng)查看API文檔進行細致了解,。這里稍微做一下介紹,,batch的大小,即batch_size就需要在生成batch的函數(shù)里指定,。另外,,capacity參數(shù)指定數(shù)據(jù)隊列一次性能放多少個樣本,此參數(shù)設置什么值需要視硬件環(huán)境而定,。num_threads參數(shù)指定可以開啟幾個線程來向數(shù)據(jù)隊列中填充數(shù)據(jù),,如果硬件性能不夠強,最好設小一點,,否則容易崩,。

4. 實例測試

實際使用時先指定好我們需要使用的tfrecord文件:

root_path = '/mount/temp/WZG/Multitask/Data/'
tfrecord_filename = root_path + 'tfrecords/test.tfrecords'
  • 1
  • 2
  • 1
  • 2

然后用該tfrecord文件創(chuàng)建一個輸入隊列:

filename_queue = tf.train.string_input_producer([tfrecord_filename],
                                                    num_epochs=3)
  • 1
  • 2
  • 1
  • 2

這里有個參數(shù)是num_epochs,,指定好之后,Tensorflow自然知道如何讀取數(shù)據(jù),,保證在遍歷數(shù)據(jù)集的一個epoch中樣本不會重復,,也知道數(shù)據(jù)讀取何時應該停止。

下面我將完整的測試代碼貼出:

def test_run(tfrecord_filename):
    filename_queue = tf.train.string_input_producer([tfrecord_filename],
                                                    num_epochs=3)
    images, masks, names, labels, widths, heights = read_and_decode(filename_queue)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    meanfile = sio.loadmat(root_path + 'mats/mean300.mat')
    meanvalue = meanfile['mean']


    with tf.Session() as sess:
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        for i in range(1):
            imgs, msks, nms, labs, wids, heis = sess.run([images, masks, names, labels, widths, heights])
            print 'batch' + str(i) + ': '
            #print type(imgs[0])

            for j in range(4):
                print nms[j] + ': ' + str(labs[j]) + ' ' + str(wids[j]) + ' ' + str(heis[j])
                img = np.uint8(imgs[j] + meanvalue)
                msk = np.uint8(msks[j])
                plt.subplot(4,2,j*2+1)
                plt.imshow(img)
                plt.subplot(4,2,j*2+2)
                plt.imshow(msk, vmin=0, vmax=5)
            plt.show()

        coord.request_stop()
        coord.join(threads)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34

函數(shù)中接下來就是利用之前定義的read_and_decode()來得到一個batch的數(shù)據(jù),,此后我又讀入了均值文件,,這是因為之前做了去均值處理,如果要正常顯示圖片需要再把均值加回來,。

再之后就是建立一個Tensorflow session,,然后初始化對象。這些是Tensorflow基本操作,,不再贅述,。下面的這兩句代碼非常重要,是讀取數(shù)據(jù)必不可少的,。

coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
  • 1
  • 2
  • 1
  • 2

然后是運行sess.run()拿到實際數(shù)據(jù),,之前只是相當于定義好了,并沒有得到真實數(shù)值,。為了簡單起見,,我在之后的循環(huán)里只測試了一個batch的數(shù)據(jù),關于tfrecord的標準使用我也建議讀者去官網(wǎng)的數(shù)據(jù)讀取部分看看示例,。循環(huán)里對數(shù)據(jù)的各種信息進行了展示,,結果如下:

結果展示

從圖片的名字可以看出,數(shù)據(jù)的確是進行了shuffle的,,標簽,、寬度、高度,、圖片本身以及對應的mask圖像也全部展示出來了,。

測試函數(shù)的最后,要使用以下兩句代碼進行停止,,就如同文件需要close()一樣:

coord.request_stop()
coord.join(threads)
  • 1
  • 2
  • 1
  • 2

    本站是提供個人知識管理的網(wǎng)絡存儲空間,,所有內(nèi)容均由用戶發(fā)布,不代表本站觀點,。請注意甄別內(nèi)容中的聯(lián)系方式,、誘導購買等信息,謹防詐騙,。如發(fā)現(xiàn)有害或侵權內(nèi)容,,請點擊一鍵舉報。
    轉藏 分享 獻花(0

    0條評論

    發(fā)表

    請遵守用戶 評論公約

    類似文章 更多