基于Tensorflow下的批量數據的輸入處理:
1.Tensor TFrecords格式
2.h5py的庫的數組方法
在tensorflow的框架下寫CNN代碼,我在書寫過程中,感覺不是框架內容難寫, 更多的是我在對圖像的預處理和輸入這部分花了很多精神。
使用了兩種方法:
方法一:
Tensor 以Tfrecords的格式存儲數據,如果對數據進行標簽,可以同時做到數據打標簽。
①創建TFrecords文件
orig_image = '/home/images/train_image/'gen_image = '/home/images/image_train.tfrecords'def create_record(): writer = tf.python_io.TFRecordWriter(gen_image) class_path = orig_image for img_name in os.listdir(class_path): #讀取每一幅圖像 img_path = class_path + img_name img = Image.open(img_path) #讀取圖像 #img = img.resize((256, 256)) #設置圖片大小, 在這里可以對圖像進行處理 img_raw = img.tobytes() #將圖片轉化為原聲bytes example = tf.train.Example( features=tf.train.Features(feature={ 'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[0])), #打標簽 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))#存儲數據 })) writer.write(example.SerializeToString()) writer.close()
②讀取TFrecords文件
def read_and_decode(filename): #創建文件隊列,不限讀取的數據 filename_queue = tf.train.string_input_producer([filename]) reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) features = tf.parse_single_example( serialized_example, features={ 'label': tf.FixedLenFeature([], tf.int64), 'img_raw': tf.FixedLenFeature([], tf.string)}) label = features['label'] img = features['img_raw'] img = tf.decode_raw(img, tf.uint8) #tf.float32 img = tf.image.convert_image_dtype(img, dtype=tf.float32) img = tf.reshape(img, [256, 256, 1]) label = tf.cast(label, tf.int32) return img, label
③批量讀取數據,使用tf.train.batch
min_after_dequeue = 10000capacity = min_after_dequeue + 3 * batch_sizenum_samples= len(os.listdir(orig_image))create_record()img, label = read_and_decode(gen_image)total_batch = int(num_samples/batch_size)image_batch, label_batch = tf.train.batch([img, label], batch_size=batch_size, num_threads=32, capacity=capacity) init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())with tf.Session() as sess: sess.run(init_op) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) for i in range(total_batch): cur_image_batch, cur_label_batch = sess.run([image_batch, label_batch]) coord.request_stop() coord.join(threads)
方法二:
使用h5py就是使用數組的格式來存儲數據
這個方法比較好,在CNN的過程中,會使用到多個數據類存儲,比較好用, 比如一個數據進行了兩種以上的變化,并且分類存儲,我認為這個方法會比較好用。
新聞熱點
疑難解答