最近用tensorflow寫了個OCR的程序,在實現(xiàn)的過程中,發(fā)現(xiàn)自己還是跳了不少坑,,在這里做一個記錄,,便于以后回憶。主要的內(nèi)容有l(wèi)stm+ctc具體的輸入輸出,,以及TF中的CTC和百度開源的warpCTC在具體使用中的區(qū)別。
正文輸入輸出因為我最后要最小化的目標(biāo)函數(shù)就是ctc_loss ,所以下面就從如何構(gòu)造輸入輸出說起,。
tf.nn.ctc_loss先從TF自帶的tf.nn.ctc_loss說起,官方給的定義如下,,因此我們需要做的就是將圖片的label(需要OCR出的結(jié)果),,圖片,,以及圖片的長度 轉(zhuǎn)換為label,input,,和sequence_length,。
1 2 3 4 5 6 7 8
| ctc_loss( labels, inputs, sequence_length, preprocess_collapse_repeated=False, ctc_merge_repeated=True, time_major=True )
|
input: 輸入(訓(xùn)練)數(shù)據(jù),是一個三維float型的數(shù)據(jù)結(jié)構(gòu)[max_time_step , batch_size , num_classes] ,,當(dāng)修改time_major = False時,,[batch_size,max_time_step,num_classes] 。 總體的數(shù)據(jù)流: image_batch ->[batch_size,max_time_step,num_features] ->lstm ->[batch_size,max_time_step,cell.output_size] ->reshape ->[batch_size*max_time_step,num_hidden] ->affine projection A*W+b ->[batch_size*max_time_step,num_classes] ->reshape ->[batch_size,max_time_step,num_classes] ->transpose ->[max_time_step,batch_size,num_classes] 下面詳細(xì)解釋一下,, 假如一張圖片有如下shape:[60,160,3],,我們?nèi)绻x取灰度圖則shape=[60,160],此時,,我們將其一列作為feature,,那么共有60個features,160個time_step,,這時假設(shè)一個batch為64,,那么我們此時獲得到了一個[batch_size,max_time_step,num_features] = [64,160,60] 的訓(xùn)練數(shù)據(jù)。
然后將該訓(xùn)練數(shù)據(jù)送入構(gòu)建的lstm網(wǎng)絡(luò)中,,(需要注意的是dynamic_rnn的輸入數(shù)據(jù)在一個batch內(nèi)的長度是固定的,,但是不同batch之間可以不同,我們需要給他一個sequence_length (長度為batch_size的向量)來記錄本次batch數(shù)據(jù)的長度,對于OCR這個問題,,sequence_length就是長度為64,,而值為160的一維向量) 得到形如[batch_size,max_time_step,cell.output_size] 的輸出,其中cell.output_size == num_hidden,。
下面我們需要做一個線性變換將其送入ctc_loos中進(jìn)行計算,,lstm中不同time_step之間共享權(quán)值,所以我們只需定義W 的結(jié)構(gòu)為[num_hidden,num_classes] ,,b 的結(jié)構(gòu)為[num_classes],。而tf.matmul操作中,兩個矩陣相乘階數(shù)應(yīng)當(dāng)匹配,,所以我們將上一步的輸出reshape成[batch_size*max_time_step,num_hidden] (num_hidden為自己定義的lstm的unit個數(shù))記為A ,,然后將其做一個線性變換,于是A*w+b 得到形如[batch_size*max_time_step,num_classes] 然后在reshape回來得到[batch_size,max_time_step,num_classes] 最后由于ctc_loss的要求,,我們再做一次轉(zhuǎn)置,,得到[max_time_step,batch_size,num_classes] 形狀的數(shù)據(jù)作為input
labels: 標(biāo)簽序列 由于OCR的結(jié)果是不定長的,所以label實際上是一個稀疏矩陣SparseTensor,, 其中:
indices :二維int64的矩陣,,代表非0的坐標(biāo)點
values :二維tensor,代表indice位置的數(shù)據(jù)值
dense_shape :一維,,代表稀疏矩陣的大小 比如有兩幅圖,,分別是123 ,和4567 那么 indecs = [[0,0],[0,1],[0,2],[1,0],[1,1],[1,2],[1,3]] values = [1,2,3,4,5,6,7] dense_shape = [2,4] 代表dense tensor:
seq_len: 在input一節(jié)中已經(jīng)講過,,一維數(shù)據(jù),[time_step,…,time_step]長度為batch_size,值為time_step
warpCTC對于warpCTC需要注意幾點
- 輸入格式:有四個輸入,,與標(biāo)準(zhǔn)CTC的三個輸入不同
- class_label問題,,標(biāo)準(zhǔn)CTC的情況下,如果自己使用的數(shù)據(jù)有
N 中類別,,分別是0~N-1,那么標(biāo)準(zhǔn)CTC會把默認(rèn)的blank 類別作為第N類,而warpCTC中0 被用作了默認(rèn)的blank 標(biāo)簽
proprocess_collapse_repeated 必須設(shè)為False,, ctc_merge_repeated 為True,,這是默認(rèn)值,無需過多注意,,另一方面來說過來說,,就是這個值不能被更改,warpCTC還不支持,,詳見這個鏈接
costs = warpctc_tensorflow.ctc(activations, flat_labels, label_lengths, input_lengths)
to be continue…
|