久久国产成人av_抖音国产毛片_a片网站免费观看_A片无码播放手机在线观看,色五月在线观看,亚洲精品m在线观看,女人自慰的免费网址,悠悠在线观看精品视频,一级日本片免费的,亚洲精品久,国产精品成人久久久久久久

分享

實(shí)戰(zhàn) 遷移學(xué)習(xí) VGG19,、ResNet50、InceptionV3 實(shí)踐 貓狗大戰(zhàn) 問題

 昵稱56314485 2018-06-03

一,、實(shí)踐流程

1,、數(shù)據(jù)預(yù)處理

主要是對(duì)訓(xùn)練數(shù)據(jù)進(jìn)行隨機(jī)偏移、轉(zhuǎn)動(dòng)等變換圖像處理,,這樣可以盡可能讓訓(xùn)練數(shù)據(jù)多樣化

另外處理數(shù)據(jù)方式采用分批無序讀取的形式,,避免了數(shù)據(jù)按目錄排序訓(xùn)練

  1. #數(shù)據(jù)準(zhǔn)備  
  2. def DataGen(self, dir_path, img_row, img_col, batch_size, is_train):  
  3.     if is_train:  
  4.         datagen = ImageDataGenerator(rescale=1./255,  
  5.             zoom_range=0.25, rotation_range=15.,  
  6.             channel_shift_range=25., width_shift_range=0.02, height_shift_range=0.02,  
  7.             horizontal_flip=True, fill_mode='constant')  
  8.     else:  
  9.         datagen = ImageDataGenerator(rescale=1./255)  
  10.   
  11.     generator = datagen.flow_from_directory(  
  12.         dir_path, target_size=(img_row, img_col),  
  13.         batch_size=batch_size,  
  14.         shuffle=is_train)  
  15.   
  16.     return generator  
2、載入現(xiàn)有模型

這個(gè)部分是核心工作,,目的是使用ImageNet訓(xùn)練出的權(quán)重來做我們的特征提取器,,注意這里后面的分類層去掉

  1. base_model = InceptionV3(weights='imagenet', include_top=False, pooling=None,  
  2.                            input_shape=(img_rows, img_cols, color),  
  3.                            classes=nb_classes)  

然后是凍結(jié)這些層,因?yàn)槭怯?xùn)練好的

  1. for layer in base_model.layers:  
  2.     layer.trainable = False  
而分類部分,,需要我們根據(jù)現(xiàn)有需求來新定義的,,這里可以根據(jù)實(shí)際情況自己進(jìn)行調(diào)整,比如這樣
  1. x = base_model.output  
  2. # 添加自己的全鏈接分類層  
  3. x = GlobalAveragePooling2D()(x)  
  4. x = Dense(1024, activation='relu')(x)  
  5. predictions = Dense(nb_classes, activation='softmax')(x)  
或者

  1. x = base_model.output  
  2.  #添加自己的全鏈接分類層  
  3.  x = Flatten()(x)  
  4.  predictions = Dense(nb_classes, activation='softmax')(x)  
3,、訓(xùn)練模型

這里我們用fit_generator函數(shù),,它可以避免了一次性加載大量的數(shù)據(jù),并且生成器與模型將并行執(zhí)行以提高效率,。比如可以在CPU上進(jìn)行實(shí)時(shí)的數(shù)據(jù)提升,,同時(shí)在GPU上進(jìn)行模型訓(xùn)練

  1. history_ft = model.fit_generator(  
  2. train_generator,  
  3. steps_per_epoch=steps_per_epoch,  
  4. epochs=epochs,  
  5. validation_data=validation_generator,  
  6. validation_steps=validation_steps)  

二、貓狗大戰(zhàn)數(shù)據(jù)集


訓(xùn)練數(shù)據(jù)540M,,測(cè)試數(shù)據(jù)270M,,大家可以去官網(wǎng)下載

https://www./c/dogs-vs-cats-redux-kernels-edition/data

下載后把數(shù)據(jù)分成dog和cat兩個(gè)目錄來存放

三、訓(xùn)練

訓(xùn)練的時(shí)候會(huì)自動(dòng)去下權(quán)值,,比如vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5,,但是如果我們已經(jīng)下載好了的話,可以改源代碼,,讓他直接讀取我們的下載好的權(quán)值,比如在resnet50.py中


1,、VGG19

vgg19的深度有26層,,參數(shù)達(dá)到了549M,原模型最后有3個(gè)全連接層做分類器所以我還是加了一個(gè)1024的全連接層,,訓(xùn)練10輪的情況達(dá)到了89%


2,、ResNet50

ResNet50的深度達(dá)到了168層,但是參數(shù)只有99M,,分類模型我就簡單點(diǎn),,一層直接分類,,訓(xùn)練10輪的達(dá)到了96%的準(zhǔn)確率


3、inception_v3

InceptionV3的深度159層,,參數(shù)92M,,訓(xùn)練10輪的結(jié)果

這是一層直接分類的結(jié)果


這是加了一個(gè)512全連接的,大家可以隨意調(diào)整測(cè)試



四,、完整的代碼

  1. # -*- coding: utf-8 -*-  
  2. import os  
  3. from keras.utils import plot_model  
  4. from keras.applications.resnet50 import ResNet50  
  5. from keras.applications.vgg19 import VGG19  
  6. from keras.applications.inception_v3 import InceptionV3  
  7. from keras.layers import Dense,Flatten,GlobalAveragePooling2D  
  8. from keras.models import Model,load_model  
  9. from keras.optimizers import SGD  
  10. from keras.preprocessing.image import ImageDataGenerator  
  11. import matplotlib.pyplot as plt  
  12.   
  13. class PowerTransferMode:  
  14.     #數(shù)據(jù)準(zhǔn)備  
  15.     def DataGen(self, dir_path, img_row, img_col, batch_size, is_train):  
  16.         if is_train:  
  17.             datagen = ImageDataGenerator(rescale=1./255,  
  18.                 zoom_range=0.25, rotation_range=15.,  
  19.                 channel_shift_range=25., width_shift_range=0.02, height_shift_range=0.02,  
  20.                 horizontal_flip=True, fill_mode='constant')  
  21.         else:  
  22.             datagen = ImageDataGenerator(rescale=1./255)  
  23.   
  24.         generator = datagen.flow_from_directory(  
  25.             dir_path, target_size=(img_row, img_col),  
  26.             batch_size=batch_size,  
  27.             #class_mode='binary',  
  28.             shuffle=is_train)  
  29.   
  30.         return generator  
  31.   
  32.     #ResNet模型  
  33.     def ResNet50_model(self, lr=0.005, decay=1e-6, momentum=0.9, nb_classes=2, img_rows=197, img_cols=197, RGB=True, is_plot_model=False):  
  34.         color = 3 if RGB else 1  
  35.         base_model = ResNet50(weights='imagenet', include_top=False, pooling=None, input_shape=(img_rows, img_cols, color),  
  36.                               classes=nb_classes)  
  37.   
  38.         #凍結(jié)base_model所有層,,這樣就可以正確獲得bottleneck特征  
  39.         for layer in base_model.layers:  
  40.             layer.trainable = False  
  41.   
  42.         x = base_model.output  
  43.         #添加自己的全鏈接分類層  
  44.         x = Flatten()(x)  
  45.         #x = GlobalAveragePooling2D()(x)  
  46.         #x = Dense(1024, activation='relu')(x)  
  47.         predictions = Dense(nb_classes, activation='softmax')(x)  
  48.   
  49.         #訓(xùn)練模型  
  50.         model = Model(inputs=base_model.input, outputs=predictions)  
  51.         sgd = SGD(lr=lr, decay=decay, momentum=momentum, nesterov=True)  
  52.         model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])  
  53.   
  54.         #繪制模型  
  55.         if is_plot_model:  
  56.             plot_model(model, to_file='resnet50_model.png',show_shapes=True)  
  57.   
  58.         return model  
  59.   
  60.   
  61.     #VGG模型  
  62.     def VGG19_model(self, lr=0.005, decay=1e-6, momentum=0.9, nb_classes=2, img_rows=197, img_cols=197, RGB=True, is_plot_model=False):  
  63.         color = 3 if RGB else 1  
  64.         base_model = VGG19(weights='imagenet', include_top=False, pooling=None, input_shape=(img_rows, img_cols, color),  
  65.                               classes=nb_classes)  
  66.   
  67.         #凍結(jié)base_model所有層,這樣就可以正確獲得bottleneck特征  
  68.         for layer in base_model.layers:  
  69.             layer.trainable = False  
  70.   
  71.         x = base_model.output  
  72.         #添加自己的全鏈接分類層  
  73.         x = GlobalAveragePooling2D()(x)  
  74.         x = Dense(1024, activation='relu')(x)  
  75.         predictions = Dense(nb_classes, activation='softmax')(x)  
  76.   
  77.         #訓(xùn)練模型  
  78.         model = Model(inputs=base_model.input, outputs=predictions)  
  79.         sgd = SGD(lr=lr, decay=decay, momentum=momentum, nesterov=True)  
  80.         model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])  
  81.   
  82.         # 繪圖  
  83.         if is_plot_model:  
  84.             plot_model(model, to_file='vgg19_model.png',show_shapes=True)  
  85.   
  86.         return model  
  87.   
  88.     # InceptionV3模型  
  89.     def InceptionV3_model(self, lr=0.005, decay=1e-6, momentum=0.9, nb_classes=2, img_rows=197, img_cols=197, RGB=True,  
  90.                     is_plot_model=False):  
  91.         color = 3 if RGB else 1  
  92.         base_model = InceptionV3(weights='imagenet', include_top=False, pooling=None,  
  93.                            input_shape=(img_rows, img_cols, color),  
  94.                            classes=nb_classes)  
  95.   
  96.         # 凍結(jié)base_model所有層,,這樣就可以正確獲得bottleneck特征  
  97.         for layer in base_model.layers:  
  98.             layer.trainable = False  
  99.   
  100.         x = base_model.output  
  101.         # 添加自己的全鏈接分類層  
  102.         x = GlobalAveragePooling2D()(x)  
  103.         x = Dense(1024, activation='relu')(x)  
  104.         predictions = Dense(nb_classes, activation='softmax')(x)  
  105.   
  106.         # 訓(xùn)練模型  
  107.         model = Model(inputs=base_model.input, outputs=predictions)  
  108.         sgd = SGD(lr=lr, decay=decay, momentum=momentum, nesterov=True)  
  109.         model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])  
  110.   
  111.         # 繪圖  
  112.         if is_plot_model:  
  113.             plot_model(model, to_file='inception_v3_model.png', show_shapes=True)  
  114.   
  115.         return model  
  116.   
  117.     #訓(xùn)練模型  
  118.     def train_model(self, model, epochs, train_generator, steps_per_epoch, validation_generator, validation_steps, model_url, is_load_model=False):  
  119.         # 載入模型  
  120.         if is_load_model and os.path.exists(model_url):  
  121.             model = load_model(model_url)  
  122.   
  123.         history_ft = model.fit_generator(  
  124.             train_generator,  
  125.             steps_per_epoch=steps_per_epoch,  
  126.             epochs=epochs,  
  127.             validation_data=validation_generator,  
  128.             validation_steps=validation_steps)  
  129.         # 模型保存  
  130.         model.save(model_url,overwrite=True)  
  131.         return history_ft  
  132.   
  133.     # 畫圖  
  134.     def plot_training(self, history):  
  135.       acc = history.history['acc']  
  136.       val_acc = history.history['val_acc']  
  137.       loss = history.history['loss']  
  138.       val_loss = history.history['val_loss']  
  139.       epochs = range(len(acc))  
  140.       plt.plot(epochs, acc, 'b-')  
  141.       plt.plot(epochs, val_acc, 'r')  
  142.       plt.title('Training and validation accuracy')  
  143.       plt.figure()  
  144.       plt.plot(epochs, loss, 'b-')  
  145.       plt.plot(epochs, val_loss, 'r-')  
  146.       plt.title('Training and validation loss')  
  147.       plt.show()  
  148.   
  149.   
  150. if __name__ == '__main__':  
  151.     image_size = 197  
  152.     batch_size = 32  
  153.   
  154.     transfer = PowerTransferMode()  
  155.   
  156.     #得到數(shù)據(jù)  
  157.     train_generator = transfer.DataGen('data/cat_dog_Dataset/train', image_size, image_size, batch_size, True)  
  158.     validation_generator = transfer.DataGen('data/cat_dog_Dataset/test', image_size, image_size, batch_size, False)  
  159.   
  160.     #VGG19  
  161.     #model = transfer.VGG19_model(nb_classes=2, img_rows=image_size, img_cols=image_size, is_plot_model=False)  
  162.     #history_ft = transfer.train_model(model, 10, train_generator, 600, validation_generator, 60, 'vgg19_model_weights.h5', is_load_model=False)  
  163.   
  164.     #ResNet50  
  165.     model = transfer.ResNet50_model(nb_classes=2, img_rows=image_size, img_cols=image_size, is_plot_model=False)  
  166.     history_ft = transfer.train_model(model, 10, train_generator, 600, validation_generator, 60'resnet50_model_weights.h5', is_load_model=False)  
  167.   
  168.     #InceptionV3  
  169.     #model = transfer.InceptionV3_model(nb_classes=2, img_rows=image_size, img_cols=image_size, is_plot_model=True)  
  170.     #history_ft = transfer.train_model(model, 10, train_generator, 600, validation_generator, 60, 'inception_v3_model_weights.h5', is_load_model=False)  
  171.   
  172.     # 訓(xùn)練的acc_loss圖  
  173.     transfer.plot_training(history_ft)  



    本站是提供個(gè)人知識(shí)管理的網(wǎng)絡(luò)存儲(chǔ)空間,,所有內(nèi)容均由用戶發(fā)布,不代表本站觀點(diǎn),。請(qǐng)注意甄別內(nèi)容中的聯(lián)系方式,、誘導(dǎo)購買等信息,謹(jǐn)防詐騙,。如發(fā)現(xiàn)有害或侵權(quán)內(nèi)容,,請(qǐng)點(diǎn)擊一鍵舉報(bào)。
    轉(zhuǎn)藏 分享 獻(xiàn)花(0

    0條評(píng)論

    發(fā)表

    請(qǐng)遵守用戶 評(píng)論公約

    類似文章 更多