def my_model(features, labels, mode, params):
    """Estimator model_fn: builds the embedding network and triplet loss.

    Args:
        features: input images, shape (batch_size, 784).
        labels: class labels, shape (batch_size,).
        mode: a tf.estimator.ModeKeys value (TRAIN / EVAL / PREDICT).
        params: dict of hyper-parameters — expects keys 'image_size',
            'triplet_strategy', 'margin', 'squared', 'learning_rate',
            'use_batch_norm'.

    Returns:
        tf.estimator.EstimatorSpec appropriate for the requested mode.
    """
    training = (mode == tf.estimator.ModeKeys.TRAIN)

    # Flat pixel vectors -> (batch_size, image_size, image_size, 1).
    images = tf.reshape(
        features, shape=[-1, params['image_size'], params['image_size'], 1])

    with tf.variable_scope("model"):
        embeddings = build_model(training, images, params)

    # PREDICT: return the embeddings directly, no loss or metrics needed.
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
            mode=mode, predictions={'embeddings': embeddings})

    labels = tf.cast(labels, tf.int64)

    # Dispatch to the configured triplet-mining strategy.
    strategy = params['triplet_strategy']
    if strategy == 'batch_all':
        loss, fraction = batch_all_triplet_loss(
            labels, embeddings,
            margin=params['margin'], squared=params['squared'])
    elif strategy == 'batch_hard':
        loss = batch_hard_triplet_loss(
            labels, embeddings,
            margin=params['margin'], squared=params['squared'])
    else:
        raise ValueError(
            "triplet_strategy 配置不正確: {}".format(strategy))

    # Monitor the mean L2 norm of the embeddings.
    embedding_mean_norm = tf.reduce_mean(tf.norm(embeddings, axis=1))
    tf.summary.scalar("embedding_mean_norm", embedding_mean_norm)

    with tf.variable_scope("metrics"):
        eval_metric_ops = {
            'embedding_mean_norm': tf.metrics.mean(embedding_mean_norm)}
        if strategy == 'batch_all':
            eval_metric_ops['fraction_positive_triplets'] = (
                tf.metrics.mean(fraction))

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode, loss=loss, eval_metric_ops=eval_metric_ops)

    # TRAIN: summaries plus the optimization op.
    tf.summary.scalar('loss', loss)
    if strategy == "batch_all":
        tf.summary.scalar('fraction_positive_triplets', fraction)
    # max_outputs=1: log at most one image from the batch per summary step.
    tf.summary.image('train_image', images, max_outputs=1)

    optimizer = tf.train.AdamOptimizer(learning_rate=params['learning_rate'])
    global_step = tf.train.get_global_step()
    if params['use_batch_norm']:
        # Batch norm accumulates its moving mean/variance update ops in the
        # UPDATE_OPS collection; they must run before (with) each train step,
        # which tf.control_dependencies enforces here.
        with tf.control_dependencies(
                tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            train_op = optimizer.minimize(loss, global_step=global_step)
    else:
        train_op = optimizer.minimize(loss, global_step=global_step)

    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)