# Callbacks

A callback is a set of functions applied at given stages of the training procedure. You can use callbacks to get a view of the internal states and statistics of the network during training, by passing a list of callbacks to the model's `.fit()` method.

【Tips】Although we call them callback "functions", Keras callbacks are in fact classes; "callback function" is just the customary name.

## CallbackList

```python
keras.callbacks.CallbackList(callbacks=[], queue_length=10)
```

## Callback

```python
keras.callbacks.Callback()
```

This is the abstract base class of callbacks; a new callback must subclass it.

Class properties:

- `params`: a dict of training parameters (e.g. verbosity, batch size, number of epochs)
- `model`: a reference to the `keras.models.Model` instance being trained
The callback methods receive a dict `logs` as an argument, containing information about the current batch or epoch. Currently, the model's `.fit()` records the following quantities in `logs`:

- at the end of each epoch (`on_epoch_end`): `loss` and `acc`, plus `val_loss` and `val_acc` if validation and accuracy monitoring are enabled
- at the beginning of each batch (`on_batch_begin`): `size`, the number of samples in the current batch
- at the end of each batch (`on_batch_end`): `loss`, plus `acc` if accuracy monitoring is enabled
## BaseLogger

```python
keras.callbacks.BaseLogger()
```

This callback accumulates epoch-level averages of the monitored metrics. It is automatically applied to every Keras model.

## ProgbarLogger

```python
keras.callbacks.ProgbarLogger()
```

This callback prints the monitored metrics to standard output.

## History

```python
keras.callbacks.History()
```

This callback is automatically applied to every Keras model; the `History` object it produces is returned by the model's `fit` method.

## ModelCheckpoint

```python
keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=False, mode='auto', period=1)
```

This callback saves the model to `filepath` after every epoch.
`filepath` can be a formatted string, whose placeholders are filled by the epoch number and the keys of the `logs` dict passed to `on_epoch_end`. For example, if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`, a separate file will be generated per epoch, named with the epoch number and the validation loss.

Parameters:

- `filepath`: path of the file to save the model to; may contain formatting placeholders as described above
- `monitor`: quantity to monitor
- `verbose`: verbosity mode, 0 or 1
- `save_best_only`: if True, only save when the monitored quantity has improved, i.e. keep only the best model seen so far
- `save_weights_only`: if True, only the model's weights are saved (`model.save_weights(filepath)`); otherwise the full model is saved (`model.save(filepath)`)
- `mode`: one of 'auto', 'min', 'max'. With `save_best_only=True`, the decision to overwrite the current save file is based on maximizing or minimizing the monitored quantity; for `val_acc` this should be 'max', for `val_loss` it should be 'min', and in 'auto' mode the direction is inferred from the name of the monitored quantity
- `period`: interval (number of epochs) between checkpoints
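The placeholder filling can be illustrated with plain Python string formatting — a minimal sketch with made-up values (the exact epoch indexing Keras uses in the filename is not shown here):

```python
# Sketch: how a templated filepath is filled from the epoch number and
# the logs dict. The val_loss value here is hypothetical.
filepath = "weights.{epoch:02d}-{val_loss:.2f}.hdf5"

logs = {"val_loss": 0.4567}            # as passed to on_epoch_end
name = filepath.format(epoch=3, **logs)
print(name)                            # weights.03-0.46.hdf5
```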
## EarlyStopping

```python
keras.callbacks.EarlyStopping(monitor='val_loss', patience=0, verbose=0, mode='auto')
```

This callback stops training when the monitored quantity has stopped improving.

Parameters:

- `monitor`: quantity to monitor
- `patience`: number of epochs with no improvement after which training will be stopped
- `verbose`: verbosity mode
- `mode`: one of 'auto', 'min', 'max'. In 'min' mode, training stops when the monitored quantity has stopped decreasing; in 'max' mode, when it has stopped increasing; in 'auto' mode the direction is inferred from the name of the monitored quantity
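The patience logic can be sketched in plain Python — a simplified model of the behavior for a quantity where lower is better (e.g. `val_loss`), not Keras's actual implementation:

```python
# Simplified sketch of EarlyStopping's patience counter (not real Keras code).
def stopped_epoch(val_losses, patience=0):
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:          # improvement: remember it, reset the counter
            best = loss
            wait = 0
        else:                    # no improvement this epoch
            if wait >= patience:
                return epoch     # training would stop after this epoch
            wait += 1
    return None                  # never triggered

# Two epochs of stagnation are tolerated, the third stops training.
print(stopped_epoch([0.9, 0.8, 0.85, 0.84, 0.83], patience=2))  # 4
```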
## RemoteMonitor

```python
keras.callbacks.RemoteMonitor(root='http://localhost:9000')
```

This callback streams events to a server; it requires the `requests` library.

Parameters:

- `root`: root URL of the target server; events are sent to it as HTTP POST requests
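The callback serializes the epoch number together with the `logs` dict as JSON for each POST. A sketch of building such a payload — the `/publish/epoch/end/` path and the `data` form field are assumptions (they match the parameters exposed in later Keras versions; the signature above only exposes `root`):

```python
import json

# Hypothetical per-epoch logs as Keras passes them to on_epoch_end.
epoch = 5
logs = {"loss": 0.31, "acc": 0.91}

root = "http://localhost:9000"
url = root + "/publish/epoch/end/"   # assumed endpoint path
payload = {"data": json.dumps(dict(epoch=epoch, **logs))}

# To actually send the event (requires the `requests` library):
# import requests
# requests.post(url, data=payload)
print(url)
```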
## LearningRateScheduler

```python
keras.callbacks.LearningRateScheduler(schedule)
```

This callback is a learning-rate scheduler.

Parameters:

- `schedule`: a function that takes the epoch index (an integer, starting from 0) and returns a new learning rate (a float)
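Any function with that signature works as a schedule. For example, a step-decay schedule (the base rate, drop factor, and step size here are made-up values):

```python
# Step-decay schedule: halve a hypothetical base rate of 0.1 every 10 epochs.
def step_decay(epoch):
    base_lr, drop, epochs_drop = 0.1, 0.5, 10
    return base_lr * (drop ** (epoch // epochs_drop))

print(step_decay(0))    # 0.1
print(step_decay(10))   # 0.05

# In Keras:
# model.fit(..., callbacks=[LearningRateScheduler(step_decay)])
```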
## TensorBoard

```python
keras.callbacks.TensorBoard(log_dir='./logs', histogram_freq=0)
```

This callback is a visualization helper. TensorBoard is the visualization tool shipped with TensorFlow; the callback writes log information for TensorBoard, so that you can dynamically inspect graphs of your training and test metrics, as well as activation histograms of the different layers of your model.

If you have installed TensorFlow with pip, you can launch TensorBoard from the command line:

```shell
tensorboard --logdir=/full_path_to_your_logs
```

For more information, see the TensorBoard documentation.

Parameters:

- `log_dir`: the directory in which to write the log files, which TensorBoard will parse for visualization
- `histogram_freq`: frequency (in epochs) at which to compute activation histograms for the layers of the model; 0 means histograms are not computed
## ReduceLROnPlateau

```python
keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10, verbose=0, mode='auto', epsilon=0.0001, cooldown=0, min_lr=0)
```

This callback reduces the learning rate when a monitored metric has stopped improving. Models often benefit from cutting the learning rate by a factor of 2 or 10 once learning stagnates. The callback monitors the given quantity, and if no improvement is seen for `patience` epochs, the learning rate is reduced.

Parameters:

- `monitor`: quantity to monitor
- `factor`: factor by which the learning rate will be reduced: `new_lr = lr * factor`
- `patience`: number of epochs with no improvement after which the learning rate will be reduced
- `verbose`: verbosity mode
- `mode`: one of 'auto', 'min', 'max'. In 'min' mode the learning rate is reduced when the monitored quantity has stopped decreasing; in 'max' mode, when it has stopped increasing; in 'auto' mode the direction is inferred from the name of the monitored quantity
- `epsilon`: threshold for measuring a new optimum; only significant changes count as improvement
- `cooldown`: number of epochs to wait after a learning-rate reduction before resuming normal operation
- `min_lr`: lower bound on the learning rate
Example:

```python
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,
                              patience=5, min_lr=0.001)
model.fit(X_train, Y_train, callbacks=[reduce_lr])
```

## CSVLogger

```python
keras.callbacks.CSVLogger(filename, separator=',', append=False)
```

This callback streams the per-epoch training results to a csv file. It supports any value that can be represented as a string, including 1D iterables of numbers such as `np.ndarray`.

Parameters:

- `filename`: name of the csv file
- `separator`: the string used to separate elements in the csv file
- `append`: if True, append to the file if it already exists (useful when resuming training); if False, overwrite it
Example:

```python
csv_logger = CSVLogger('training.log')
model.fit(X_train, Y_train, callbacks=[csv_logger])
```

## LambdaCallback

```python
keras.callbacks.LambdaCallback(on_epoch_begin=None, on_epoch_end=None, on_batch_begin=None, on_batch_end=None, on_train_begin=None, on_train_end=None)
```

A callback class for creating simple callbacks on the fly. The anonymous functions you pass in will be called at the appropriate time. Note that the callback assumes certain positional arguments:

- `on_epoch_begin` and `on_epoch_end` expect two positional arguments: `epoch`, `logs`
- `on_batch_begin` and `on_batch_end` expect two positional arguments: `batch`, `logs`
- `on_train_begin` and `on_train_end` expect one positional argument: `logs`

Parameters:

- `on_epoch_begin`: called at the beginning of every epoch
- `on_epoch_end`: called at the end of every epoch
- `on_batch_begin`: called at the beginning of every batch
- `on_batch_end`: called at the end of every batch
- `on_train_begin`: called at the beginning of training
- `on_train_end`: called at the end of training
Example:

```python
# Print the batch number at the beginning of every batch.
batch_print_callback = LambdaCallback(
    on_batch_begin=lambda batch, logs: print(batch))

# Plot the loss after every epoch.
import numpy as np
import matplotlib.pyplot as plt
plot_loss_callback = LambdaCallback(
    on_epoch_end=lambda epoch, logs: plt.plot(np.arange(epoch),
                                              logs['loss']))

# Terminate some processes after having finished model training.
processes = ...
cleanup_callback = LambdaCallback(
    on_train_end=lambda logs: [p.terminate()
                               for p in processes if p.is_alive()])

model.fit(...,
          callbacks=[batch_print_callback,
                     plot_loss_callback,
                     cleanup_callback])
```

## Writing your own callback

You can create a custom callback by subclassing `keras.callbacks.Callback`. The callback can access its associated model through the class property `self.model`.

Here is a simple callback that saves the loss of every batch:

```python
class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.losses = []

    def on_batch_end(self, batch, logs={}):
        self.losses.append(logs.get('loss'))
```

### Example: recording the loss history

```python
class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.losses = []

    def on_batch_end(self, batch, logs={}):
        self.losses.append(logs.get('loss'))

model = Sequential()
model.add(Dense(10, input_dim=784, init='uniform'))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

history = LossHistory()
model.fit(X_train, Y_train, batch_size=128, nb_epoch=20,
          verbose=0, callbacks=[history])

print(history.losses)
# outputs
'''
[0.66047596406559383, 0.3547245744908703, ..., 0.25953155204159617, 0.25901699725311789]
'''
```

### Example: model checkpoints

```python
from keras.callbacks import ModelCheckpoint

model = Sequential()
model.add(Dense(10, input_dim=784, init='uniform'))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

# saves the model weights after each epoch if the validation loss decreased
checkpointer = ModelCheckpoint(filepath='/tmp/weights.hdf5',
                               verbose=1, save_best_only=True)
model.fit(X_train, Y_train, batch_size=128, nb_epoch=20, verbose=0,
          validation_data=(X_test, Y_test), callbacks=[checkpointer])
```