久久国产成人av_抖音国产毛片_a片网站免费观看_A片无码播放手机在线观看,色五月在线观看,亚洲精品m在线观看,女人自慰的免费网址,悠悠在线观看精品视频,一级日本片免费的,亚洲精品久,国产精品成人久久久久久久

分享

實(shí)例介紹TensorFlow的輸入流水線

 520jefferson 2021-04-20

   作者:葉   虎

編輯:趙一帆

前  言


在訓(xùn)練模型時(shí),,我們首先要處理的就是訓(xùn)練數(shù)據(jù)的加載與預(yù)處理的問題,,這里稱這個(gè)過程為輸入流水線(input pipelines,,或輸入管道,,[參考:https://www./performance/datasets_performance])。在TensorFlow中,,典型的輸入流水線包含三個(gè)流程(ETL流程):

  1. 提?。‥xtract):從存儲(chǔ)介質(zhì)(如硬盤)中讀取數(shù)據(jù),可能是本地讀取,,也可能是遠(yuǎn)程讀?。ū热缭诜植际酱鎯?chǔ)系統(tǒng)HDFS)

  2. 預(yù)處理(Transform):利用CPU處理器解析和預(yù)處理提取的數(shù)據(jù),如圖像解壓縮,,數(shù)據(jù)擴(kuò)增或者變換,,然后會(huì)做random shuffle,并形成batch,。

  3. 加載(load):將預(yù)處理后的數(shù)據(jù)加載到加速設(shè)備中(如GPUs)來執(zhí)行模型的訓(xùn)練,。

輸入流水線對(duì)于加速模型訓(xùn)練還是很重要的,如果你的CPU處理數(shù)據(jù)能力跟不上GPU的處理速度,,此時(shí)CPU預(yù)處理數(shù)據(jù)就成為了訓(xùn)練模型的瓶頸環(huán)節(jié),。除此之外,上述輸入流水線本身也有很多優(yōu)化的地方,。比如,,一個(gè)典型的模型訓(xùn)練過程中,CPU預(yù)處理數(shù)據(jù)時(shí),,GPU是閑置的,,當(dāng)GPU訓(xùn)練模型時(shí),,CPU是閑置的,這個(gè)過程如下所示:

圖片

這樣一個(gè)訓(xùn)練step中所花費(fèi)的時(shí)間是CPU預(yù)處理數(shù)據(jù)和GPU訓(xùn)練模型時(shí)間的總和,。顯然這個(gè)過程中有資源浪費(fèi),,一個(gè)改進(jìn)的方法就是交叉CPU數(shù)據(jù)處理和GPU模型訓(xùn)練這兩個(gè)過程,當(dāng)GPU處于第個(gè)訓(xùn)練階段,,CPU正在準(zhǔn)備第N+1步所需的數(shù)據(jù),,如下圖所示:

圖片

明顯上述設(shè)計(jì)可以充分最大化利用CPU和GPU,從而減少資源的閑置,。另外當(dāng)存在多個(gè)CPU核心時(shí),,這又會(huì)涉及到CPU的并行化技術(shù)(多線程)來加速數(shù)據(jù)預(yù)處理過程,因?yàn)槊總€(gè)訓(xùn)練樣本的預(yù)處理過程往往是互相獨(dú)立的,。關(guān)于輸入流程線的優(yōu)化可以參考TensorFlow官網(wǎng)上的Pipeline Performance Guide(https://www./performance/datasets_performance),,相信你會(huì)受益匪淺。

幸運(yùn)的是,,最新的TensorFlow版本提供了tf.data這一套APIs來幫助我們快速實(shí)現(xiàn)高效又靈活的輸入流水線,。在TensorFlow中最常見的加載訓(xùn)練數(shù)據(jù)的方式是通過Feeding(https://www./api_guides/python/reading_data#Feeding)方式,其主要是定義placeholder,,然后將通過Session.run()的feed_dict參數(shù)送入數(shù)據(jù),,但是這其實(shí)是最低效的加載數(shù)據(jù)方式。后來,,TensorFlow增加了QueueRunner(https://www./api_guides/python/reading_data#_QueueRunner_)機(jī)制,,其主要是基于文件隊(duì)列以及多線程技術(shù),實(shí)現(xiàn)了更高效的輸入流水線,,但是其APIs很是讓人難懂,,所以就有了現(xiàn)在的tf.data來替代它。

這里我們通過mnist實(shí)例來講解如何使用tf.data建立簡(jiǎn)潔而高效的輸入流水線,,在介紹之前,,我們先介紹如何制作TFRecords文件,這是TensorFlow支持的一種標(biāo)準(zhǔn)文件格式

1

制作TFRecords文件

TFRecords文件是TensorFlow中的標(biāo)準(zhǔn)數(shù)據(jù)格式,,它是基于protobuf的二進(jìn)制文件,,每個(gè)TFRecord文件的基本元素是tf.train.Example,其對(duì)應(yīng)的是數(shù)據(jù)集中的一個(gè)樣本數(shù)據(jù),,每個(gè)Example包含F(xiàn)eatures,,存儲(chǔ)該樣本的各個(gè)feature,每個(gè)feature包含一個(gè)鍵值對(duì),,分別對(duì)應(yīng)feature的特征名與實(shí)際值,。下面是一個(gè)Example實(shí)例:

// An Example for a movie recommendation application:
      features {
        feature {
          key: 'age'
          value { float_list {
            value: 29.0
          }}
        }
        feature {
          key: 'movie'
          value { bytes_list {
            value: 'The Shawshank Redemption'
            value: 'Fight Club'
          }}
        }
        feature {
          key: 'movie_ratings'
          value { float_list {
            value: 9.0
            value: 9.7
          }}
        }
        feature {
          key: 'suggestion'
          value { bytes_list {
            value: 'Inception'
          }}
        }
        feature {
          key: 'suggestion_purchased'
          value { float_list {
            value: 1.0
          }}
       }
        feature {
          key: 'purchase_price'
          value { float_list {
            value: 9.99
          }}
        }
     }

上面是一個(gè)電影推薦系統(tǒng)中的一個(gè)樣本,可以看到它共含有6個(gè)特征,,每個(gè)特征都是key-value類型,,key是特征名,,而value是特征值,值得注意的是value其實(shí)存儲(chǔ)的是一個(gè)list,,根據(jù)數(shù)據(jù)類型共分為三種:bytes_list, float_listint64_list,分別存儲(chǔ)字節(jié),、浮點(diǎn)及整數(shù)類型(見這里:https://github.com/tensorflow/tensorflow/blob/r1.6/tensorflow/core/example/feature.proto),。

作為標(biāo)準(zhǔn)數(shù)據(jù)格式,TensorFlow當(dāng)然提供了創(chuàng)建TFRecords文件的python接口,,下面我們創(chuàng)建mnist數(shù)據(jù)集對(duì)應(yīng)的TFRecords文件,。對(duì)于mnist數(shù)據(jù)集,每個(gè)Example需要存儲(chǔ)兩個(gè)feature,,一個(gè)是圖像的像素值,,這里可以用bytes類型,因?yàn)橐粋€(gè)像素點(diǎn)正好可以用一個(gè)字節(jié)存儲(chǔ),,另外是圖像的標(biāo)簽值,,只能用int64類型存儲(chǔ)了。因此,,我們先定義這兩個(gè)類型的接口函數(shù):

 # int64
   def _int64_feature(value):
       return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))    # bytes
   def _bytes_feature(value):
       return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
創(chuàng)建TFRecord文件,,主要通過TF中的tf.python_io.TFRecordWriter函數(shù)來實(shí)現(xiàn),具體代碼如下:

def convert_to_TFRecords(dataset, name):
   '''Convert mnist dataset to TFRecords'''
   images, labels = dataset.images, dataset.labels
   n_examples = dataset.num_examples

   filename = os.path.join(DIR, name + '.tfrecords')
   print('Writing', filename)
   with tf.python_io.TFRecordWriter(filename) as writer:
           for index in range(n_examples):
           image_bytes = images[index].tostring()
           label = labels[index]
           example = tf.train.Example(features=tf.train.Features(
               feature={'image': _bytes_feature(image_bytes),                         'label': _int64_feature(label)}))
           writer.write(example.SerializeToString())

對(duì)于mnist數(shù)據(jù)集,,主要分為train,、validation和test,利用上面的函數(shù)分別創(chuàng)建三個(gè)不同的TFRecords文件:

mnist_datasets = mnist.read_data_sets('mnist_data', dtype=tf.uint8, reshape=False)
convert_to_TFRecords(mnist_datasets.train, 'train')
convert_to_TFRecords(mnist_datasets.validation, 'validation')
convert_to_TFRecords(mnist_datasets.test, 'test')

好了,,這樣我們就創(chuàng)建3個(gè)TFRecords文件了,。


2

讀取TFRecords文件

上面我們創(chuàng)建了TFRecords文件,但是怎么去讀取它們呢,,當(dāng)然TF提供了讀取TFRecords文件的接口函數(shù),,這里首先介紹如何利用TF中操作TFRecord的python接口來讀取TFRecord文件,主要是tf.python_io.tf_record_iterator函數(shù),,它輸入TFRecord文件,,但是得到一個(gè)迭代器,每個(gè)元素是一個(gè)Example,,但是卻是一個(gè)字符串,,這里可以用tf.train.Example來解析它,具體代碼如下:

def read_TFRecords_test(name):
   filename = os.path.join(DIR, name + '.tfrecords')
   record_itr = tf.python_io.tf_record_iterator(path=filename)
   for r in record_itr:
       example = tf.train.Example()
       example.ParseFromString(r)

       label = example.features.feature['label'].int64_list.value[0]
       print('Label', label)
       image_bytes = example.features.feature['image'].bytes_list.value[0]
       img = np.fromstring(image_bytes, dtype=np.uint8).reshape(28, 28)
       print(img)
       plt.imshow(img, cmap='gray')
       plt.show()
       break  # 只讀取一個(gè)Example

上面僅是純python的讀取方式,,這不是TFRecords文件的正確使用方式,。既然是官方標(biāo)準(zhǔn)數(shù)據(jù)格式,TF也提供了使用TFRecords文件建立輸入流水線的方式,。在tf.data出現(xiàn)之前,,使用的是QueueRunner方式,,即文件隊(duì)列機(jī)制,其原理如下圖所示:

圖片

文件隊(duì)列機(jī)制主要分為兩個(gè)階段:第一個(gè)階段將輸入文件打亂,,并在文件隊(duì)列入列,,然后Reader從文件隊(duì)列中讀取一個(gè)文件,同時(shí)文件隊(duì)列出列這個(gè)文件,,Reader同時(shí)對(duì)文件進(jìn)行解碼,,然后生產(chǎn)數(shù)據(jù)樣本,并將樣本在樣本隊(duì)列中入列,,可以定義多個(gè)Reader并發(fā)地從多個(gè)文件同時(shí)讀取數(shù)據(jù),。從樣本隊(duì)列中的出列一定量的樣本數(shù)據(jù)即可以用于一個(gè)訓(xùn)練過程。TF提供了配套的API來完成這個(gè)過程,,注意的是這個(gè)輸入流水線是直接嵌入訓(xùn)練的Graph中,,即是整個(gè)圖模型的一部分。根據(jù)文件的不同,,可以使用不同類型的Reader,,對(duì)于TFRecord文件,可以使用tf.TFRecordReader,,下面是具體的實(shí)現(xiàn)代碼:

def read_example(filename_queue):
   '''Read one example from filename_queue'''
   reader = tf.TFRecordReader()
   key, value = reader.read(filename_queue)
   features = tf.parse_single_example(value, features={'image': tf.FixedLenFeature([], tf.string),                                                            'label': tf.FixedLenFeature([], tf.int64)})
   image = tf.decode_raw(features['image'], tf.uint8)
   image = tf.reshape(image, [28, 28])
   label = tf.cast(features['label'], tf.int32)
   return image, label
   
if __name__ == '__main__':
   queue = tf.train.string_input_producer(['TFRecords/train.tfrecords'], num_epochs=10)
   image, label = read_example(queue)

   img_batch, label_batch = tf.train.shuffle_batch([image, label], batch_size=32, capacity=5000,
                                                       min_after_dequeue=2000, num_threads=4)
   with tf.Session() as sess:
       sess.run(tf.local_variables_initializer())
       sess.run(tf.global_variables_initializer())

       coord = tf.train.Coordinator()
       threads = tf.train.start_queue_runners(sess=sess, coord=coord)
       try:
               while not coord.should_stop():                # Run training steps or whatever
               images, labels = sess.run([img_batch, label_batch])
               print(images.shape, labels.shape)
       except tf.errors.OutOfRangeError:
           print('Done training -- epoch limit reached')

       coord.request_stop()
       coord.join(threads)

對(duì)于隊(duì)列機(jī)制,,估計(jì)大家看的云里霧里的,代碼確實(shí)讓人難懂,,但是其實(shí)只要按照官方提供的標(biāo)準(zhǔn)代碼,,還是很容易在自己的數(shù)據(jù)集上進(jìn)行修改的。不過現(xiàn)在有了tf.data,,可以更加優(yōu)雅地實(shí)現(xiàn)上面的過程,。

3

tf.data簡(jiǎn)介

使用tf.data可以更方便地創(chuàng)建高效的輸入流水線,但是其相比隊(duì)列機(jī)制API更友好,,這主要是因?yàn)?strong>tf.data提供了高級(jí)抽象,。第一個(gè)抽象是使用tf.data.Dataset來表示一個(gè)數(shù)據(jù)集合,集合里面的每個(gè)元素包含一個(gè)或者多個(gè)Tensor,,一般就是對(duì)應(yīng)一個(gè)訓(xùn)練樣本,。第二個(gè)抽象是使用tf.data.Iterator來從數(shù)據(jù)集中提取數(shù)據(jù),這是一個(gè)迭代器對(duì)象,,可以通過Iterator.get_next()Dataset中產(chǎn)生一個(gè)樣本,。利用這兩個(gè)抽象,Dataset的使用簡(jiǎn)化為三個(gè)步驟:

  1. 創(chuàng)建Dataset實(shí)例對(duì)象,;

  2. 創(chuàng)建遍歷DatasetIterator實(shí)例對(duì)象,;

  3. Iterator中不斷地產(chǎn)生樣本,并送入模型中進(jìn)行訓(xùn)練。

1.創(chuàng)建Dataset

TF提供了很多方式創(chuàng)建Dataset,,下面是幾種方式:

# 從Numpy的arraydataset1 = tf.data.Dataset.from_tensor_slices(np.random.randn((5, 10))
print(dataset1.output_types)  # ==> 'tf.float32'print(dataset1.output_shapes)  # ==> '(10,)'# 從Tensor

dataset2 = tf.data.Dataset.from_tensor_slices((tf.random_uniform([4]),
tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)))
print(dataset2.output_types)  # ==> '(tf.float32, tf.int32)'print(dataset2.output_shapes)  # ==> '((), (100,))'# 從文件
filenames = ['/var/data/file1.tfrecord', '/var/data/file2.tfrecord']
dataset3 = tf.data.TFRecordDataset(filenames)

更重要的是Dataset可以進(jìn)行一系列的變換操作,,并且支持鏈?zhǔn)秸{(diào)用,這對(duì)于數(shù)據(jù)預(yù)處理很重要:

dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)  # 解析數(shù)據(jù)或者對(duì)數(shù)據(jù)預(yù)處理,,如normalize.

dataset = dataset.repeat()  # 重復(fù)數(shù)據(jù)集,,一般設(shè)置num_epochs

dataset = dataset.batch(32) # 形成batch

2.創(chuàng)建Iterator

創(chuàng)建了Dataset之后,我們需要?jiǎng)?chuàng)建Iterator來遍歷數(shù)據(jù)集,,返回的是迭代器對(duì)象,,并從中可以產(chǎn)生數(shù)據(jù),以用于模型訓(xùn)練,。TF共支持4中迭代器類型,分別是one-shot, initializable, reinitializablefeedable,。下面逐個(gè)介紹它們,。

One-shot Iterator

這是最簡(jiǎn)單的Iterator,它僅僅遍歷整個(gè)數(shù)據(jù)集一次,,而且不需要顯示初始化,,下面是個(gè)實(shí)例:

dataset = tf.data.Dataset.from_tensor_slices(np.arange(10))
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
   for i in range(10):
       sess.run(next_element) # 0, 1, ..., 9

Initializable Iterator

相比one-shot Iterator,它需要在使用前顯示初始化,,這樣就可以支持參數(shù)化,,每次初始化時(shí)送入不同的參數(shù),就可以支持?jǐn)?shù)據(jù)集的簡(jiǎn)單參數(shù)化,,下面是一個(gè)實(shí)例:

max_value = tf.placeholder(tf.int64, [])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:    # 需要顯示初始化
   sess.run(iterator.initializer, feed_dict={max_value: 10})
   for i in range(10):
       print(sess.run(next_element)) # 0, 1, ..., 9

Reinitializable Iterator

相比initializable Iterator,,它可以支持從不同的Dataset進(jìn)行初始化,有時(shí)候你需要訓(xùn)練集和測(cè)試集,,但是兩者并不同,,此時(shí)就可以定義兩個(gè)不同的Dataset,并配合reinitializable Iterator來定義一個(gè)通用的迭代器,,在使用前只需要送入不同的Dataset進(jìn)行初始化就可以,,下面是一個(gè)實(shí)例:

train_data = np.random.randn(100, 5)
test_data = np.random.randn(20, 5)
train_dataset = tf.data.Dataset.from_tensor_slices(train_data)
test_dataset = tf.data.Dataset.from_tensor_slices(test_data)# 創(chuàng)建一個(gè)reinitializable iterator

re_iterator = tf.data.Iterator.from_structure(train_dataset.output_types,
                                             train_dataset.output_shapes)
next_element = re_iterator.get_next()
train_init_op = re_iterator.make_initializer(train_dataset)
test_init_op = re_iterator.make_initializer(test_dataset)
with tf.Session() as sess:    # 訓(xùn)練
   n_epochs = 2
   for i in range(n_epochs):
       sess.run(train_init_op)
       for j in range(100):
           print(sess.run(next_element))
   # 測(cè)試
   sess.run(test_init_op)
   for i in range(20):
       print(sess.run(next_element))

Feedable Iterator

對(duì)于reinitializable iterator,它可以支持送入不同Dataset,,從而完成數(shù)據(jù)集的切換,,但是每次切換時(shí)必須要重新初始化。對(duì)于Feedable Iterator,,其可以認(rèn)為支持送入不同的Iterator,,通過切換迭代器的string handle來完成不同數(shù)據(jù)集的切換,并且在切換時(shí)迭代器的狀態(tài)還會(huì)被保留,,這相比reinitializable iterator更加靈活,,下面是一個(gè)實(shí)例:

train_data = np.random.randn(100, 5)
val_data = np.random.randn(20, 5)
n_epochs = 20train_dataset = tf.data.Dataset.from_tensor_slices(train_data).repeat(n_epochs)
val_dataset = tf.data.Dataset.from_tensor_slices(val_data)# 創(chuàng)建一個(gè)feedable iterator

handle = tf.placeholder(tf.string, [])
feed_iterator = tf.data.Iterator.from_string_handle(handle, train_dataset.output_types,
                                                 train_dataset.output_shapes)
next_element = feed_iterator.get_next()# 創(chuàng)建不同的iterator

train_iterator = train_dataset.make_one_shot_iterator()
val_iterator = val_dataset.make_initializable_iterator()
with tf.Session() as sess:
   # 生成對(duì)應(yīng)的handle
   train_handle = sess.run(train_iterator.string_handle())
   val_handle = sess.run(val_iterator.string_handle())
   # 訓(xùn)練
   for n in range(n_epochs):
       for i in range(100):
           print(i, sess.run(next_element, feed_dict={handle: train_handle}))
               # 驗(yàn)證
       if n % 10 == 0:
           sess.run(val_iterator.initializer)
                   for i in range(20):
               print(sess.run(next_element, feed_dict={handle: val_handle}))

關(guān)于tf.data的基礎(chǔ)知識(shí)就這么多了,更多內(nèi)容可以參考官方文檔,,另外這里要說一點(diǎn)就是,,對(duì)于迭代器對(duì)象,,當(dāng)其元素取盡之后,會(huì)拋出tf.errors.OutOfRangeError錯(cuò)誤,,當(dāng)然一般情況下你是知道自己的迭代器對(duì)象的元素?cái)?shù),,那么也就可以不用通過捕獲錯(cuò)誤來實(shí)現(xiàn)終止條件。下面,,我們將使用tf.data實(shí)現(xiàn)mnist的完整訓(xùn)練過程,。

4

MNIST完整實(shí)例

我們采用feedable Iterator來實(shí)現(xiàn)mnist數(shù)據(jù)集的訓(xùn)練過程,分別創(chuàng)建兩個(gè)Dataset,,一個(gè)為訓(xùn)練集,,一個(gè)為驗(yàn)證集,對(duì)于驗(yàn)證集不需要shuffle操作,。首先我們創(chuàng)建Dataset對(duì)象的輔助函數(shù),,主要是解析TFRecords文件,并對(duì)image做歸一化處理:

def decode(serialized_example):
   '''decode the serialized example'''
   features = tf.parse_single_example(serialized_example,
                           features={'image': tf.FixedLenFeature([], tf.string),                                      'label': tf.FixedLenFeature([], tf.int64)})
   image = tf.decode_raw(features['image'], tf.uint8)
   image = tf.cast(image, tf.float32)
   image = tf.reshape(image, [784])
   label = tf.cast(features['label'], tf.int64)
   return image, label
def normalize(image, label):
   '''normalize the image to [-0.5, 0.5]'''
   image = image / 255.0 - 0.5
   return image, label

然后定義創(chuàng)建Dataset的函數(shù),,對(duì)于訓(xùn)練集和驗(yàn)證集,,兩者的參數(shù)會(huì)不同:

def create_dataset(filename, batch_size=64, is_shuffle=False, n_repeats=0):
   '''create dataset for train and validation dataset'''
   dataset = tf.data.TFRecordDataset(filename)
   if n_repeats > 0:
       dataset = dataset.repeat(n_repeats) # for train

   dataset = dataset.map(decode).map(normalize) # decode and normalize

   if is_shuffle:
       dataset = dataset.shuffle(1000 + 3 * batch_size) # shuffle
   dataset = dataset.batch(batch_size)
   return dataset

我們使用一個(gè)簡(jiǎn)單的全連接層網(wǎng)絡(luò)來實(shí)現(xiàn)mnist的分類模型:

def model(inputs, hidden_sizes=(500, 500)):
   h1, h2 = hidden_sizes
   net = tf.layers.dense(inputs, h1, activation=tf.nn.relu)
   net = tf.layers.dense(net, h2, activation=tf.nn.relu)
   net = tf.layers.dense(net, 10, activation=None)
   return net

然后是訓(xùn)練的主體代碼:

n_train_examples = 55000n_val_examples = 5000n_epochs = 50batch_size = 64train_dataset = create_dataset('TFRecords/train.tfrecords', batch_size=batch_size, is_shuffle=True,
                              n_repeats=n_epochs)
val_dataset = create_dataset('TFRecords/validation.tfrecords', batch_size=batch_size)
# 創(chuàng)建一個(gè)feedable iterator

handle = tf.placeholder(tf.string, [])
feed_iterator = tf.data.Iterator.from_string_handle(handle, train_dataset.output_types,
                                                 train_dataset.output_shapes)
images, labels = feed_iterator.get_next()
# 創(chuàng)建不同的iterator

train_iterator = train_dataset.make_one_shot_iterator()
val_iterator = val_dataset.make_initializable_iterator()
# 創(chuàng)建模型

logits = model(images, [500, 500])
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
loss = tf.reduce_mean(loss)
train_op = tf.train.AdamOptimizer(learning_rate=1e-04).minimize(loss)
predictions = tf.argmax(logits, axis=1)
accuracy = tf.reduce_mean(tf.cast(tf.equal(predictions, labels), tf.float32))

init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
with tf.Session() as sess:
   sess.run(init_op)    # 生成對(duì)應(yīng)的handle
   train_handle = sess.run(train_iterator.string_handle())
   val_handle = sess.run(val_iterator.string_handle())
   # 訓(xùn)練
   for n in range(n_epochs):
       ls = []
       for i in range(n_train_examples // batch_size):
           _, l = sess.run([train_op, loss], feed_dict={handle: train_handle})
           ls.append(l)
       print('Epoch %d, train loss: %f' % (n, np.mean(ls)))
       if (n + 1) % 10 == 0:
           sess.run(val_iterator.initializer)
           accs = []
           for i in range(n_val_examples // batch_size):
               acc = sess.run(accuracy, feed_dict={handle: val_handle})
               accs.append(acc)
           print('\t validation accuracy: %f' % (np.mean(accs)))

大約可以在驗(yàn)證集上的accuracy達(dá)到98%。

小結(jié)

看起來最新的tf.data還是比較好用的,,如果你是TensorFlow用戶,,可以嘗試著使用它,當(dāng)然上面的例子并不能包含關(guān)于tf.data的所有內(nèi)容,,想繼續(xù)深入的話可以移步TF的官網(wǎng),。

參考資料
  1. [Programmers guide: import data](https://www./programmers_guide/datasets).

  2. [How to use Dataset in TensorFlow](https:///how-to-use-dataset-in-tensorflow-c758ef9e4428).

  3. [Reading data](https://www./api_guides/python/reading_data).

  4. [Performance: datasets performance](https://www./performance/datasets_performance).

  5. [Introduction to Artificial Neural Networks and Deep Learning: A Practical Guide with Applications in Python](https://github.com/rasbt/deep-learning-book/).

end

    本站是提供個(gè)人知識(shí)管理的網(wǎng)絡(luò)存儲(chǔ)空間,所有內(nèi)容均由用戶發(fā)布,,不代表本站觀點(diǎn),。請(qǐng)注意甄別內(nèi)容中的聯(lián)系方式、誘導(dǎo)購買等信息,,謹(jǐn)防詐騙,。如發(fā)現(xiàn)有害或侵權(quán)內(nèi)容,請(qǐng)點(diǎn)擊一鍵舉報(bào),。
    轉(zhuǎn)藏 分享 獻(xiàn)花(0

    0條評(píng)論

    發(fā)表

    請(qǐng)遵守用戶 評(píng)論公約

    類似文章 更多