前言
本博客默認讀者對神經網絡與Tensorflow有一定了解,對其中的一些術語不再做具體解釋。并且本博客主要以圖片數據為例進行介紹,如有錯誤,敬請斧正。
使用Tensorflow訓練神經網絡時,我們可以用多種方式來讀取自己的數據。如果數據集比較小,而且內存足夠大,可以選擇直接將所有數據讀進內存,然后每次取一個batch的數據出來。如果數據較多,可以每次直接從硬盤中進行讀取,不過這種方式的讀取效率就比較低了。此篇博客就主要講一下Tensorflow官方推薦的一種較為高效的數據讀取方式——tfrecord。
從宏觀來講,tfrecord其實是一種數據存儲形式。使用tfrecord時,實際上是先讀取原生數據,然后轉換成tfrecord格式,再存儲在硬盤上。而使用時,再把數據從相應的tfrecord文件中解碼讀取出來。那么使用tfrecord和直接從硬盤讀取原生數據相比到底有什么優勢呢?其實,Tensorflow有和tfrecord配套的一些函數,可以加快數據的處理。實際讀取tfrecord數據時,先以相應的tfrecord文件為參數,創建一個輸入隊列,這個隊列有一定的容量(視具體硬件限制,用戶可以設置不同的值),在一部分數據出隊列時,tfrecord中的其他數據就可以通過預取進入隊列,并且這個過程和網絡的計算是獨立進行的。也就是說,網絡每一個iteration的訓練不必等待數據隊列準備好再開始,隊列中的數據始終是充足的,而往隊列中填充數據時,也可以使用多線程加速。
下面,本文將從以下4個方面對tfrecord進行介紹:
1. tfrecord格式簡介
這部分主要參考了另一篇博文,Tensorflow 訓練自己的數據集(二)(TFRecord)
tfecord文件中的數據是通過tf.train.Example Protocol Buffer的格式存儲的,下面是tf.train.Example的定義
message Example { Features features = 1;};message Features{ map<string,Feature> featrue = 1;};message Feature{ oneof kind{ BytesList bytes_list = 1; FloatList float_list = 2; Int64List int64_list = 3; }};
從上述代碼可以看出,tf.train.Example 的數據結構很簡單。tf.train.Example中包含了一個從屬性名稱到取值的字典,其中屬性名稱為一個字符串,屬性的取值可以為字符串(BytesList ),浮點數列表(FloatList )或整數列表(Int64List )。例如我們可以將圖片轉換為字符串進行存儲,圖像對應的類別標號作為整數存儲,而用于回歸任務的ground-truth可以作為浮點數存儲。通過后面的代碼我們會對tfrecord的這種字典形式有更直觀的認識。
2. 利用自己的數據生成tfrecord文件
先上一段代碼,然后我再針對代碼進行相關介紹。
新聞熱點
疑難解答