This article wraps raw numpy array data as a Dataset class in pytorch, so that it can supply data for training deep networks later on.
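As a rough picture of the goal, the sketch below wraps in-memory arrays in a `Dataset` subclass. It is a minimal sketch and not necessarily the class this article ends up with: the name `ArrayDataset`, the scaling to [0, 1], and the random stand-in data are all assumptions made for illustration.

```python
import numpy as np
import torch
from torch.utils.data import Dataset


class ArrayDataset(Dataset):
    """Hypothetical wrapper around in-memory image and label arrays."""

    def __init__(self, images, labels):
        assert len(images) == len(labels)
        self.images = images
        self.labels = labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        # Scale uint8 pixels to [0, 1] and add a channel axis: (1, H, W).
        pixels = self.images[index].astype(np.float32) / 255.0
        image = torch.from_numpy(pixels).unsqueeze(0)
        label = int(self.labels[index])
        return image, label


# Random stand-in data with MNIST's image shape and dtype (assumption).
rng = np.random.default_rng(0)
demo_images = rng.integers(0, 256, size=(8, 28, 28), dtype=np.uint8)
demo_labels = rng.integers(0, 10, size=8)
demo_set = ArrayDataset(demo_images, demo_labels)
```

An object like `demo_set` can be passed to `torch.utils.data.DataLoader` for batching, since it implements both `__len__` and `__getitem__`.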
Loading and saving the image data
First, import the required libraries and define the paths.
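import os
import matplotlib.image  # "import matplotlib" alone does not guarantee the image submodule is loaded
from keras.datasets import mnist
import numpy as np
from torch.utils.data.dataset import Dataset
from PIL import Image
# import scipy.misc  # only needed for the commented-out scipy.misc.imsave calls

root_path = 'E:/coding_ex/pytorch/Alexnet/data/'
base_path = 'baseset/'
training_path = 'trainingset/'
test_path = 'testset/'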
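The data is split into three sets: baseset holds all the data (trainingset + testset), trainingset is the training set, and testset is the test set. MNIST is loaded directly through keras.datasets; if the automatic download fails, you can download the .npz file manually and save it to the corresponding directory.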
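For the manual-download case, one alternative (an assumption, not the method used in this article) is to read the archive directly with `np.load`. The sketch assumes the archive uses the key names `x_train`, `y_train`, `x_test`, and `y_test`; the helper name `load_mnist_npz` is hypothetical, and a tiny stand-in archive is written first so the example runs without a download.

```python
import os
import tempfile

import numpy as np


def load_mnist_npz(npz_path):
    """Read the four MNIST arrays from a local .npz archive (hypothetical helper)."""
    with np.load(npz_path) as data:
        return (data['x_train'], data['y_train']), (data['x_test'], data['y_test'])


# Stand-in archive with the same key names, so the sketch runs offline.
demo_path = os.path.join(tempfile.mkdtemp(), 'mnist_demo.npz')
np.savez(
    demo_path,
    x_train=np.zeros((6, 28, 28), dtype=np.uint8),
    y_train=np.arange(6, dtype=np.uint8),
    x_test=np.zeros((2, 28, 28), dtype=np.uint8),
    y_test=np.arange(2, dtype=np.uint8),
)
(x_train, y_train), (x_test, y_test) = load_mnist_npz(demo_path)
```

The returned tuples have the same layout as the result of `mnist.load_data()`, so the rest of the code would not need to change.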
def LoadData(root_path, base_path, training_path, test_path):
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_baseset = np.concatenate((x_train, x_test))
    y_baseset = np.concatenate((y_train, y_test))
    train_num = len(x_train)
    test_num = len(x_test)

    # baseset: every sample (training + test)
    file_img = open(os.path.join(root_path, base_path) + 'baseset_img.txt', 'w')
    file_label = open(os.path.join(root_path, base_path) + 'baseset_label.txt', 'w')
    for i in range(train_num + test_num):
        file_img.write(root_path + base_path + 'img/' + str(i) + '.png\n')  # image path
        file_label.write(str(y_baseset[i]) + '\n')  # label
        # scipy.misc.imsave(root_path + base_path + '/img/' + str(i) + '.png', x_baseset[i])
        matplotlib.image.imsave(root_path + base_path + 'img/' + str(i) + '.png', x_baseset[i])
    file_img.close()
    file_label.close()

    # trainingset
    file_img = open(os.path.join(root_path, training_path) + 'trainingset_img.txt', 'w')
    file_label = open(os.path.join(root_path, training_path) + 'trainingset_label.txt', 'w')
    for i in range(train_num):
        file_img.write(root_path + training_path + 'img/' + str(i) + '.png\n')  # image path
        file_label.write(str(y_train[i]) + '\n')  # label
        # scipy.misc.imsave(root_path + training_path + '/img/' + str(i) + '.png', x_train[i])
        matplotlib.image.imsave(root_path + training_path + 'img/' + str(i) + '.png', x_train[i])
    file_img.close()
    file_label.close()

    # testset
    file_img = open(os.path.join(root_path, test_path) + 'testset_img.txt', 'w')
    file_label = open(os.path.join(root_path, test_path) + 'testset_label.txt', 'w')
    for i in range(test_num):
        file_img.write(root_path + test_path + 'img/' + str(i) + '.png\n')  # image path
        file_label.write(str(y_test[i]) + '\n')  # label
        # scipy.misc.imsave(root_path + test_path + '/img/' + str(i) + '.png', x_test[i])
        matplotlib.image.imsave(root_path + test_path + 'img/' + str(i) + '.png', x_test[i])
    file_img.close()
    file_label.close()
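Before running LoadData, create the corresponding folders. Under ./data there are three folders, baseset/, trainingset/ and testset/, and each needs an img/ subfolder for the saved PNG files, which are named by index (0.png, 1.png, and so on). The two list files of each set (for example baseset_img.txt and baseset_label.txt) sit next to img/. open(..., 'w') creates them if they do not exist yet, but it does not create missing folders.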
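The folders can also be created from code instead of by hand. The following is a minimal sketch, assuming the layout implied by the paths used in LoadData (one `img/` folder under each of the three set folders); the helper name `make_data_dirs` is hypothetical, and a temporary directory stands in for `root_path`.

```python
import os
import tempfile


def make_data_dirs(root_path, subset_paths):
    """Create <root>/<subset>/img for each subset; existing folders are left alone."""
    created = []
    for subset in subset_paths:
        img_dir = os.path.join(root_path, subset, 'img')
        os.makedirs(img_dir, exist_ok=True)  # also creates the parent set folder
        created.append(img_dir)
    return created


# A temporary directory stands in for root_path in this example.
demo_root = tempfile.mkdtemp()
img_dirs = make_data_dirs(demo_root, ['baseset/', 'trainingset/', 'testset/'])
```

With `exist_ok=True` the call is safe to repeat, so it could run at the start of every session before the images are written.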
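MNIST is actually a grayscale dataset, but the images saved with matplotlib here are pseudo-color images, because imsave applies its default colormap when it is given a 2-D array.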
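If true grayscale files are preferred, one option (an alternative to the matplotlib call used in this article, not part of it) is to save through PIL, which is already imported. The sketch below uses a random 28x28 uint8 array as a stand-in for one MNIST digit and writes to a temporary folder.

```python
import os
import tempfile

import numpy as np
from PIL import Image

# Random stand-in for one 28x28 uint8 MNIST digit (assumption).
digit = np.random.default_rng(0).integers(0, 256, size=(28, 28), dtype=np.uint8)

out_path = os.path.join(tempfile.mkdtemp(), '0.png')
# A 2-D uint8 array is converted to mode 'L' (8-bit grayscale).
Image.fromarray(digit).save(out_path)

# Read the file back to check that mode and pixel values survive the round trip.
with Image.open(out_path) as reloaded_img:
    reloaded_mode = reloaded_img.mode
    reloaded = np.array(reloaded_img)
```

Because PNG is lossless, the reloaded array matches the original pixel values. Passing `cmap='gray'` to `matplotlib.image.imsave` is another option that should give gray-looking output.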