在tensorflow中,有三種方式輸入數據
1. 利用feed_dict送入numpy數組
2. 利用隊列從文件中直接讀取數據
3. 預加載數據
其中第一種方式很常用,在tensorflow的MNIST訓練源碼中可以看到,通過feed_dict={},可以將任意數據送入tensor中。
第二種方式相比于第一種,速度更快,可以利用多線程的優勢把數據送入隊列,再以batch的方式出隊,并且在這個過程中可以很方便地對圖像進行隨機裁剪、翻轉、改變對比度等預處理,同時可以選擇是否對數據隨機打亂,可以說是非常方便。該部分的源碼在tensorflow官方的CIFAR-10訓練源碼中可以看到,但是對于剛學習tensorflow的人來說,比較難以理解,本篇博客就當成我調試完成后寫的一篇總結,以防自己再忘記具體細節。
讀取CIFAR-10數據集
按照第一種方式的話,CIFAR-10的讀取只需要寫一段非常簡單的代碼即可將測試集與訓練集中的圖像分別讀?。?/p>
path = 'E:/Dataset/cifar-10/cifar-10-batches-py'# extract train examplesnum_train_examples = 50000x_train = np.empty((num_train_examples, 32, 32, 3), dtype='uint8')y_train = np.empty((num_train_examples), dtype='uint8')for i in range(1, 6): fpath = os.path.join(path, 'data_batch_' + str(i)) (x_train[(i - 1) * 10000: i * 10000, :, :, :], y_train[(i - 1) * 10000: i * 10000]) = load_and_decode(fpath)# extract test examplesfpath = os.path.join(path, 'test_batch')x_test, y_test = load_and_decode(fpath)return x_train, y_train, x_test, np.array(y_test)
其中load_and_decode函數只需要按照CIFAR-10官網給出的方式decode就行,最終返回的x_train是一個[50000, 32, 32, 3]的ndarray,但對于ndarray來說,進行預處理就要麻煩很多,為了取mini-SGD的batch,還自己寫了一個類,通過調用train_set.next_batch()函數來取,總而言之就是什么都要自己動手,效率確實不高
但對于第二種方式,讀取起來就要麻煩很多,但使用起來,又快又方便
首先,把CIFAR-10的測試集文件讀取出來,生成文件名列表
path = 'E:/Dataset/cifar-10/cifar-10-batches-py'filenames = [os.path.join(path, 'data_batch_%d' % i) for i in range(1, 6)]
有了列表以后,利用tf.train.string_input_producer函數生成一個讀取隊列
filename_queue = tf.train.string_input_producer(filenames)
接下來,我們調用read_cifar10函數,得到一幅一幅的圖像,該函數的代碼如下:
def read_cifar10(filename_queue): label_bytes = 1 IMAGE_SIZE = 32 CHANNELS = 3 image_bytes = IMAGE_SIZE*IMAGE_SIZE*3 record_bytes = label_bytes+image_bytes # define a reader reader = tf.FixedLengthRecordReader(record_bytes) key, value = reader.read(filename_queue) record_bytes = tf.decode_raw(value, tf.uint8) label = tf.strided_slice(record_bytes, [0], [label_bytes]) depth_major = tf.reshape(tf.strided_slice(record_bytes, [label_bytes], [label_bytes + image_bytes]), [CHANNELS, IMAGE_SIZE, IMAGE_SIZE]) image = tf.transpose(depth_major, [1, 2, 0]) return image, label
新聞熱點
疑難解答