在項目中遇到需要處理超級大量的數據集,無法載入內存的問題就不用說了,單線程分批讀取和處理(雖然這個處理也只是特別簡單的首尾相連的操作)也會使瓶頸出現在CPU性能上,所以研究了一下多線程和多進程的數據讀取和預處理,都是通過調用dataset api實現
1. 多線程數據讀取
第一種方法是可以直接從csv里讀取數據,但返回值是tensor,需要在sess里run一下才能返回真實值,無法實現真正的并行處理,但如果直接用csv文件或其他什么文件存了特征值,可以直接讀取后進行訓練,可使用這種方法.
import tensorflow as tf#這里是返回的數據類型,具體內容無所謂,類型對應就好了,比如我這個,就是一個四維的向量,前三維是字符串類型 最后一維是int類型record_defaults = [[""], [""], [""], [0]]def decode_csv(line): parsed_line = tf.decode_csv(line, record_defaults) label = parsed_line[-1] # label del parsed_line[-1] # delete the last element from the list features = tf.stack(parsed_line) # Stack features so that you can later vectorize forward prop., etc. #label = tf.stack(label) #NOT needed. Only if more than 1 column makes the label... batch_to_return = features, label return batch_to_returnfilenames = tf.placeholder(tf.string, shape=[None])dataset5 = tf.data.Dataset.from_tensor_slices(filenames)#在這里設置線程數目dataset5 = dataset5.flat_map(lambda filename: tf.data.TextLineDataset(filename).skip(1).map(decode_csv,num_parallel_calls=15)) dataset5 = dataset5.shuffle(buffer_size=1000)dataset5 = dataset5.batch(32) #batch_sizeiterator5 = dataset5.make_initializable_iterator()next_element5 = iterator5.get_next()#這里是需要加載的文件名training_filenames = ["train.csv"]validation_filenames = ["vali.csv"]with tf.Session() as sess: for _ in range(2): #通過文件名初始化迭代器 sess.run(iterator5.initializer, feed_dict={filenames: training_filenames}) while True: try: #這里獲得真實值 features, labels = sess.run(next_element5) # Train... # print("(train) features: ") # print(features) # print("(train) labels: ") # print(labels) except tf.errors.OutOfRangeError: print("Out of range error triggered (looped through training set 1 time)") break # Validate (cost, accuracy) on train set print("/nDone with the first iterator/n") sess.run(iterator5.initializer, feed_dict={filenames: validation_filenames}) while True: try: features, labels = sess.run(next_element5) # Validate (cost, accuracy) on dev set # print("(dev) features: ") # print(features) # print("(dev) labels: ") # print(labels) except tf.errors.OutOfRangeError: print("Out of range error triggered (looped through dev set 1 time only)") break
第二種方法,基于生成器,可以進行預處理操作了,sess里run出來的結果可以直接進行輸入訓練,但需要自己寫一個生成器,我使用的測試代碼如下:
新聞熱點
疑難解答