久久国产成人av_抖音国产毛片_a片网站免费观看_A片无码播放手机在线观看,色五月在线观看,亚洲精品m在线观看,女人自慰的免费网址,悠悠在线观看精品视频,一级日本片免费的,亚洲精品久,国产精品成人久久久久久久

分享

tensorflow LSTM+CTC/warpCTC使用詳解 | Scott''''s Notes

 Rainbow_Heaven 2018-03-05

最近用tensorflow寫了個OCR的程序,在實現(xiàn)的過程中,發(fā)現(xiàn)自己還是跳了不少坑,,在這里做一個記錄,,便于以后回憶。主要的內(nèi)容有l(wèi)stm+ctc具體的輸入輸出,,以及TF中的CTC和百度開源的warpCTC在具體使用中的區(qū)別。

正文

輸入輸出

因為我最后要最小化的目標(biāo)函數(shù)就是ctc_loss,所以下面就從如何構(gòu)造輸入輸出說起,。

tf.nn.ctc_loss

先從TF自帶的tf.nn.ctc_loss說起,官方給的定義如下,,因此我們需要做的就是將圖片的label(需要OCR出的結(jié)果),,圖片,,以及圖片的長度轉(zhuǎn)換為label,input,,和sequence_length,。

1
2
3
4
5
6
7
8
ctc_loss(
labels,
inputs,
sequence_length,
preprocess_collapse_repeated=False,
ctc_merge_repeated=True,
time_major=True
)

input: 輸入(訓(xùn)練)數(shù)據(jù),是一個三維float型的數(shù)據(jù)結(jié)構(gòu)[max_time_step , batch_size , num_classes],,當(dāng)修改time_major = False時,,[batch_size,max_time_step,num_classes]
總體的數(shù)據(jù)流:
image_batch
->[batch_size,max_time_step,num_features]->lstm
->[batch_size,max_time_step,cell.output_size]->reshape
->[batch_size*max_time_step,num_hidden]->affine projection A*W+b
->[batch_size*max_time_step,num_classes]->reshape
->[batch_size,max_time_step,num_classes]->transpose
->[max_time_step,batch_size,num_classes]
下面詳細(xì)解釋一下,,
假如一張圖片有如下shape:[60,160,3],,我們?nèi)绻x取灰度圖則shape=[60,160],此時,,我們將其一列作為feature,,那么共有60個features,160個time_step,,這時假設(shè)一個batch為64,,那么我們此時獲得到了一個[batch_size,max_time_step,num_features] = [64,160,60]的訓(xùn)練數(shù)據(jù)。

然后將該訓(xùn)練數(shù)據(jù)送入構(gòu)建的lstm網(wǎng)絡(luò)中,,(需要注意的是dynamic_rnn的輸入數(shù)據(jù)在一個batch內(nèi)的長度是固定的,,但是不同batch之間可以不同,我們需要給他一個sequence_length(長度為batch_size的向量)來記錄本次batch數(shù)據(jù)的長度,對于OCR這個問題,,sequence_length就是長度為64,,而值為160的一維向量)
得到形如[batch_size,max_time_step,cell.output_size]的輸出,其中cell.output_size == num_hidden,。

下面我們需要做一個線性變換將其送入ctc_loos中進(jìn)行計算,,lstm中不同time_step之間共享權(quán)值,所以我們只需定義W的結(jié)構(gòu)為[num_hidden,num_classes],,b的結(jié)構(gòu)為[num_classes],。而tf.matmul操作中,兩個矩陣相乘階數(shù)應(yīng)當(dāng)匹配,,所以我們將上一步的輸出reshape成[batch_size*max_time_step,num_hidden](num_hidden為自己定義的lstm的unit個數(shù))記為A,,然后將其做一個線性變換,于是A*w+b得到形如[batch_size*max_time_step,num_classes]然后在reshape回來得到[batch_size,max_time_step,num_classes]最后由于ctc_loss的要求,,我們再做一次轉(zhuǎn)置,,得到[max_time_step,batch_size,num_classes]形狀的數(shù)據(jù)作為input


labels: 標(biāo)簽序列
由于OCR的結(jié)果是不定長的,所以label實際上是一個稀疏矩陣SparseTensor,,
其中:

  • indices:二維int64的矩陣,,代表非0的坐標(biāo)點
  • values:二維tensor,代表indice位置的數(shù)據(jù)值
  • dense_shape:一維,,代表稀疏矩陣的大小
    比如有兩幅圖,,分別是123,和4567那么
    indecs = [[0,0],[0,1],[0,2],[1,0],[1,1],[1,2],[1,3]]
    values = [1,2,3,4,5,6,7]
    dense_shape = [2,4]
    代表dense tensor:
    1
    2
    [[1,2,3,0]
    [4,5,6,7]]

seq_len: 在input一節(jié)中已經(jīng)講過,,一維數(shù)據(jù),[time_step,…,time_step]長度為batch_size,值為time_step

warpCTC

對于warpCTC需要注意幾點

  • 輸入格式:有四個輸入,,與標(biāo)準(zhǔn)CTC的三個輸入不同
  • class_label問題,,標(biāo)準(zhǔn)CTC的情況下,如果自己使用的數(shù)據(jù)有N中類別,,分別是0~N-1,那么標(biāo)準(zhǔn)CTC會把默認(rèn)的blank類別作為第N類,而warpCTC中0被用作了默認(rèn)的blank標(biāo)簽
  • proprocess_collapse_repeated必須設(shè)為False,, ctc_merge_repeated為True,,這是默認(rèn)值,無需過多注意,,另一方面來說過來說,,就是這個值不能被更改,warpCTC還不支持,,詳見這個鏈接

costs = warpctc_tensorflow.ctc(activations, flat_labels, label_lengths, input_lengths)

to be continue…

    本站是提供個人知識管理的網(wǎng)絡(luò)存儲空間,,所有內(nèi)容均由用戶發(fā)布,不代表本站觀點,。請注意甄別內(nèi)容中的聯(lián)系方式,、誘導(dǎo)購買等信息,謹(jǐn)防詐騙,。如發(fā)現(xiàn)有害或侵權(quán)內(nèi)容,,請點擊一鍵舉報。
    轉(zhuǎn)藏 分享 獻(xiàn)花(0

    0條評論

    發(fā)表

    請遵守用戶 評論公約

    類似文章 更多