零、學(xué)習(xí)目標(biāo)
- tensorflow 數(shù)據(jù)讀取原理
- 深度學(xué)習(xí)數(shù)據(jù)增強(qiáng)原理
一,、CIFAR-10數(shù)據(jù)集簡(jiǎn)介
是用于普通物體識(shí)別的小型數(shù)據(jù)集,,一共包含 10個(gè)類別 的 RGB彩×××片(包含:(飛機(jī)、汽車,、鳥類,、貓、鹿,、狗,、蛙、馬,、船,、卡車)。圖片大小均為 3232像素**,,數(shù)據(jù)集中一共有 50000 張訓(xùn)練圖片和 1000*** 張測(cè)試圖片,。部分代碼來自于tensorflow官方,以下表格列出了所需的官方代碼,。
文件 |
用途 |
cifar10.py |
建立CIFAR-1O預(yù)測(cè)模型 |
cifar10_input.py |
在tensorflow中讀入CIFAR-10訓(xùn)練圖片 |
cifar10_input_test.py |
cifar10_input 的測(cè)試用例文件 |
cifar10_train.py |
使用單個(gè)GPU或CPU訓(xùn)練模型 |
cifar10_train_multi_gpu.py |
使用多個(gè)gpu訓(xùn)練模型 |
cifar10_eval.py |
在測(cè)試集上測(cè)試模型的性能 |
二,、下載CIFAR-10數(shù)據(jù)
在工程根目錄創(chuàng)建 cifar10_download.py ,輸入如下代碼創(chuàng)建下載數(shù)據(jù)的程序:
# 引入當(dāng)前目錄中已經(jīng)編寫好的cifar10模塊
import cifar10
# 引入tensorflow
import tensorflow as tf
# 定義全局變量存儲(chǔ)器,可用于命令行參數(shù)的處理
# tf.app.flags.FLAGS 是tensorflow 內(nèi)部的一個(gè)全局變量存儲(chǔ)器
FLAGS = tf.app.flags.FLAGS
# 在cifar10 模塊中預(yù)先定義了cifar-10的數(shù)據(jù)存儲(chǔ)路徑,,修改數(shù)據(jù)存儲(chǔ)路徑
FLAGS.data_dir = 'cifar10_data/'
# 如果數(shù)據(jù)不存在,,則下載
cifar10.maybe_download_and_extract()
執(zhí)行完這段代碼后,CIFAR-10數(shù)據(jù)集會(huì)下載到目錄 cifar10_data 目錄下,。默認(rèn)的存儲(chǔ)路徑書 tmp/cifar10_data,,定義在代碼文件cifar10.py中,,位置大約在53行附近。 修改完數(shù)據(jù)存儲(chǔ)路徑后,,通過 cifar10.maybe_download_and_extract() 來下載數(shù)據(jù),,下載期間如果數(shù)據(jù)存在于數(shù)據(jù)文件夾中則跳過下載數(shù)據(jù),反之下載數(shù)據(jù),。下載成功后會(huì)提示 Successfully downloaded cifar-10-binary.tar.gz 170052171 bytes. 下載完成后,,cifar10_data/cifar-10-batches-bin 中將出現(xiàn)8個(gè)文件,名稱和用途如下表:
文件名 |
用途 |
batches.meta.txt |
存儲(chǔ)每個(gè)類別的英文名 |
data_batch_1.bin、......,、data_batch_5.bin |
CIFAR-10的五個(gè)訓(xùn)練集,,每個(gè)訓(xùn)練集用二進(jìn)制格式存儲(chǔ)了10000張32*32的彩×××像和圖相對(duì)應(yīng)的標(biāo)簽,沒個(gè)樣本由3073個(gè)字節(jié)組成,,第一個(gè)字節(jié)未標(biāo)簽,,剩下的字節(jié)未圖像數(shù)據(jù) |
test_batch.bin |
存儲(chǔ)1000張用于測(cè)試的圖像和對(duì)應(yīng)的標(biāo)簽 |
readme.html |
數(shù)據(jù)集介紹文件 |
三、TensorFlow 讀取數(shù)據(jù)的機(jī)制
- 普通方式
將硬盤上的數(shù)據(jù)讀入內(nèi)存中,,然后提供給CPU或者GPU處理
- 內(nèi)存隊(duì)列方式
普通方式讀取數(shù)據(jù)會(huì)出現(xiàn)GPU或CPU在一段時(shí)間內(nèi)存在空閑,,導(dǎo)致運(yùn)算效率降低。利用內(nèi)存隊(duì)列,,將數(shù)據(jù)讀取和計(jì)算放在兩個(gè)線程中,,讀取線程只需向內(nèi)存隊(duì)列中讀入文件,而計(jì)算線程只用從內(nèi)存隊(duì)列中讀取計(jì)算需要的數(shù)據(jù),,這樣就解決了GPU或者CPU的空閑問題,。
- 文件名隊(duì)列+內(nèi)存隊(duì)列
TensorFlow采用 文件名隊(duì)列+內(nèi)存隊(duì)列,這種方式可以很好的管理epoch(注1)和避免計(jì)算單元的空閑問題,。舉個(gè)例子,,假設(shè)有三個(gè)數(shù)據(jù)文件要執(zhí)行一次epoch,那么就在文件名隊(duì)列中放入這三個(gè)數(shù)據(jù)文件各一次,,并且在最后放入的數(shù)據(jù)文件后面標(biāo)注隊(duì)列結(jié)束,。內(nèi)存隊(duì)列依次從文件名隊(duì)列的頂部讀取數(shù)據(jù)文件,讀到結(jié)束標(biāo)記后就會(huì)自動(dòng)拋出異常,,捕獲這個(gè)異常后程序就可以結(jié)束,。如果是執(zhí)行N次epoch,那么就把每個(gè)數(shù)據(jù)文件放入文件名隊(duì)列N次,。
注1: 對(duì)于數(shù)據(jù)集來說,,運(yùn)行一次epoch就是將數(shù)據(jù)集里的所有數(shù)據(jù)完整的計(jì)算一遍,以此類推運(yùn)行N次epoch就是將數(shù)據(jù)集里的所有數(shù)據(jù)完整的計(jì)算N遍
四,、創(chuàng)建文件名隊(duì)列和內(nèi)存隊(duì)列
- 創(chuàng)建文件名隊(duì)列
利用tensorflow的 tf.train.string_input_producer() (注2) 函數(shù),。給函數(shù)傳入一個(gè)文件名列表,系統(tǒng)將會(huì)轉(zhuǎn)換未文件名隊(duì)列,。tf.train.string_input_producer() 函數(shù)有兩個(gè)重要的參數(shù),,分別是 num_epochs 和 shuffle ,,num_epochs表示epochs數(shù),shuffle表示是否打亂文件名隊(duì)列內(nèi)文件的順序,,如果是True表示不按照文件名列表添加的順序進(jìn)入文件名隊(duì)列,,如果是Flase表示按照文件名列表添加的順序進(jìn)入文件名隊(duì)列,。
-
創(chuàng)建內(nèi)存隊(duì)列 在tensorflow中不手動(dòng)創(chuàng)建內(nèi)存隊(duì)列,,只需使用 reader 對(duì)象從文件名隊(duì)列中讀取數(shù)據(jù)就可以了。
注2: 使用tf.train.string_input_producer() 創(chuàng)建完文件名隊(duì)列后,,文件名并沒有被加入到隊(duì)列中,,如果此時(shí)開始計(jì)算,會(huì)導(dǎo)致整個(gè)系統(tǒng)處于阻塞狀態(tài),。 在創(chuàng)建完文件名隊(duì)列后,,應(yīng)調(diào)用 tf.train.start_queue_runners 方法才會(huì)啟動(dòng)文件名隊(duì)列的填充,整個(gè)程序才能正常運(yùn)行起來,。
- 代碼
import tensorflow as tf
# 新建session
with tf.Session() as sess:
# 要讀取的三張圖片
filename = ['img/1.jpg', 'img/2.jpg', 'img/3.jpg']
# 創(chuàng)建文件名隊(duì)列
filename_queue = tf.train.string_input_producer(filename, num_epochs=5, shuffle=False)
reader = tf.WholeFileReader()
key, value = reader.read(filename_queue)
# 初始化變量(epoch)
tf.local_variables_initializer().run()
threads = tf.train.start_queue_runners(sess=sess)
i = 0
while True:
i += 1
# 獲取圖片保存數(shù)據(jù)
image_data = sess.run(value)
with open('read/test_%d.jpg' % i, 'wb') as f:
f.write(image_data)
五,、數(shù)據(jù)增強(qiáng)
對(duì)于圖像數(shù)據(jù)來說,數(shù)據(jù)增強(qiáng)方法就是利用平移,、縮放,、顏色等變換增大訓(xùn)練集樣本個(gè)數(shù),從而達(dá)到更好的效果(注3),,使用數(shù)據(jù)增強(qiáng)可以大大提高模型的泛化能力,,并且能夠預(yù)防過擬合。 常用的圖像數(shù)據(jù)增強(qiáng)方法如下表
方法 |
說明 |
平移 |
將圖像在一定尺度范圍內(nèi)平移 |
旋轉(zhuǎn) |
將圖像在一定角度范圍內(nèi)旋轉(zhuǎn) |
翻轉(zhuǎn) |
水平翻轉(zhuǎn)或者上下翻轉(zhuǎn)圖片 |
裁剪 |
在原圖上裁剪出一塊 |
縮放 |
將圖像在一定尺度內(nèi)放大或縮小 |
顏色變換 |
對(duì)圖像的RGB顏色空間進(jìn)行一些變換 |
噪聲擾動(dòng) |
給圖像加入一些人工生成的噪聲 |
注3: 使用數(shù)據(jù)增強(qiáng)的方法前提是,,這些數(shù)據(jù)增強(qiáng)方法不會(huì)改變圖像的原有標(biāo)簽,。比如數(shù)字6的圖片,經(jīng)過上下翻轉(zhuǎn)之后就變成了數(shù)字9的圖片,。
六,、CIFAR-10識(shí)別模型
建立模型的代碼在cifar10.py文件額inference函數(shù)中,代碼在這里不進(jìn)行詳解,,讀者可以去閱讀代碼中的注釋,。 這里我們通過以下命令訓(xùn)練模型:
python cifar10_train.py --train_dir cifar10_train/ --data_dir cifar10_data/
這段命令中 --data_dir cifar10_data/ 表示數(shù)據(jù)保存的位置, --train_dir cifar10_train/ 表示保存模型參數(shù)和訓(xùn)練時(shí)日志信息的位置
七,、查看訓(xùn)練進(jìn)度
在訓(xùn)練的時(shí)候我們往往需要知道損失的變化和每層的訓(xùn)練情況,,這個(gè)時(shí)候我們就會(huì)用到tensorflow提供的 TensorBoard。打開一個(gè)新的命令行,,輸入如下命令:
tensorboard --logdir cifar10_train/
其中 --logdir cifar10_train/ 表示模型訓(xùn)練日志保存的位置,,運(yùn)行該命令后將會(huì)在命令行看到類似如下的內(nèi)容
在瀏覽器上輸入顯示的地址,即可訪問TensorBoard,。簡(jiǎn)單解釋一下常用的幾個(gè)標(biāo)簽:
標(biāo)簽 |
說明 |
total_loss_1 |
loss 的變化曲線,,變化曲線會(huì)根據(jù)時(shí)間實(shí)時(shí)變化 |
learning_rate |
學(xué)習(xí)率變化曲線 |
global_step |
美妙訓(xùn)練步數(shù)的情況,,如果訓(xùn)練速度變化較大,或者越來越慢,,就說明程序有可能存在錯(cuò)誤 |
八,、檢測(cè)模型的準(zhǔn)確性
在命令行窗口輸入如下命令:
python cifar10_eval.py --data_dir cifar10_data/ --eval_dir cifar10_eval/ --checkpoint_dir cifar10_train/
--data_dir cifar10_data/ 表 示 CIFAR-10 數(shù)據(jù)集的存儲(chǔ)位置 。 --heckpoint_dir cifar1O_train/ 則表示程序模型保存在 cifar10_train/文件夾下,。 --eval_dir cifar10_eval/ 指定了一個(gè)保存測(cè)試信息的文件夾 輸入以下命令,,在TensorBoard上查看準(zhǔn)確率歲訓(xùn)練步數(shù)的變化情況:
tensorboard --logdir cifar10_eval/ --port 6007
在瀏覽器中輸入:http://127.0.0.1:6007,展開 Precision @ 1 選項(xiàng)卡,,就可以看到準(zhǔn)確率隨訓(xùn)練步數(shù)變化的情況,。
九、代碼下載
Git地址:https:///bugback/ai_learning.git 百度網(wǎng)盤:https://pan.baidu.com/s/17HdfI2R9gsOMKi4pgundSA
|