Caffe的數(shù)據(jù)接口主要有原始圖像(ImageData), HDF5, LMDB/LevelDB,。由于Caffe自帶的圖像轉(zhuǎn)LMDB接口只支持但label,對于多l(xiāng)abel的任務(wù),,往往需要使用HDF5,。
然而,Caffe對于HDF5的數(shù)據(jù),,需要預(yù)先讀取整個h5文件,,這對于小數(shù)據(jù)的情況不成問題,,而且一次性讀到內(nèi)存里還節(jié)省訓(xùn)練中的IO開銷??墒菍τ跀?shù)據(jù)量大的情況,,內(nèi)存可能放不下整個h5文件,則需要劃分成幾個小的h5文件,??蛇@樣的實現(xiàn)一方面不優(yōu)雅,另一方面訓(xùn)練中需要不停地輪流讀取h5文件,。一種可能的解決方案是把圖像數(shù)據(jù)放到lmdb,,label數(shù)據(jù)放到h5文件,prototxt里面label和data分別來自兩個data layer,??墒莻€人覺得這樣的實現(xiàn)也不好看,畢竟代碼里面要做HDF5和LMDB的存儲,。
最近從網(wǎng)上看到一種更直接的方法,,大致是結(jié)合Python的LMDB庫和Caffe的Python 接口 caffe.io.array_to_datum,把圖像數(shù)據(jù)和label,分別存儲到兩個lmdb文件,。而對于存儲好的lmdb,,又怎樣寫prototxt里面的datalayer來讀取呢?目前caffe的datalayer, 指明了LMDB作為backend的話,,默認(rèn)第一個top就是存儲lmdb時datum的data,,第二個top就是datum的label,在下面的代碼里沒有指定datum的label,,因此,,對于data和label的lmdb,分別寫一個datalayer, 每個datalayer的第一個top就是對應(yīng)lmdb里的內(nèi)容了,。而top的blob的名字是可以自己定義的。
代碼如下:
- def write_lmdb(image_name_list,label_array,lmdb_img_name,lmdb_label_name,resize_image = False):
- for lmdb_name in [lmdb_img_name, lmdb_label_name]:
- db_path = os.path.abspath(lmdb_name)
- if os.path.exists(db_path):
- shutil.rmtree(db_path)
- counter_img = 0
- counter_label = 0
- batchsz = 100
- fail_cnt = 0
- print("Processing {:d} images and labels...".format(len(image_name_list)))
- for i in xrange(int(np.ceil(len(image_name_list)/float(batchsz)))):
- image_name_batch = image_name_list[batchsz*i:batchsz*(i+1)]
- label_batch = label_array[batchsz*i:batchsz*(i+1),:]
- print label_batch[np.newaxis,np.newaxis,0].dtype
- raw_input('r')
- imgs, labels = [], []
- for idx,image_name in enumerate(image_name_batch):
- img = skimage.io.imread(image_name)
- if resize_image==True:
- img = skimage.transform.resize(img,(96,96))
- imgs.append(img)
- db_imgs = lmdb.open(lmdb_img_name, map_size=1e12)
- with db_imgs.begin(write=True) as txn_img:
- for img in imgs:
- datum = caffe.io.array_to_datum(np.expand_dims(img, axis=0))
- txn_img.put("{:0>10d}".format(counter_img), datum.SerializeToString())
- counter_img += 1
- print("Processed {:d} images".format(counter_img))
- db_labels = lmdb.open(lmdb_label_name, map_size=1e12)
- with db_labels.begin(write=True) as txn_label:
- for idx in range(label_batch.shape[0]):
- datum = caffe.io.array_to_datum(label_batch[np.newaxis,np.newaxis,idx])
- txn_label.put("{:0>10d}".format(counter_label), datum.SerializeToString())
- counter_label += 1
- print("Processed {:d} labels".format(counter_label))
- print fail_cnt,'images fail reading'
- db_imgs.close()
- db_labels.close()
|