圖片顯示
pytorch 載入的數據集是元組tuple 形式,里面包括了數據及標簽(train_data,label),其中的train_data數據可以轉換為torch.Tensor形式,方便后面計算使用。
同樣給一些剛入門的同學在使用載入的數據顯示圖片的時候帶來一些難以理解的地方,這里主要是將Tensor與numpy轉換的過程,理解了這些就可以就行轉換了
CIAFA10數據集
首先載入數據集,這里做了一些數據處理,包括圖片尺寸、數據歸一化等
import torchfrom torch.autograd import Variable import matplotlib.pyplot as plt import torchvision.datasets as dsetimport torchvision.transforms as transformsfrom autoencoder import AutoEncoderimport torch.nn as nnimport torchvisionimport numpy as npdataset = dset.CIFAR10(root='../train/data', download=True, transform=transforms.Compose([ transforms.Scale(200), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), transforms.Gray() ]))
在這里 dataset 是一個CIFAR10對象,(大家可以查看一下他的源代碼)
方式一
dataset[1] = ([torch.FloatTensor of size 1x200x200],9)
載入的第二個數據是個tensor格式,包含一個標簽 9
這里我們做的就是將torch.FloatTensor 轉換為numpy,然后顯示
b = dataset[1][0].numpy()#取數據,不取標簽
因為這里的b仍然是1*200*200的大小,所以要重新reshape一下,適合輸出圖像
plt.imshow(b.reshape(200,200),cmap = 'gray')plt.show()
然后可以顯示圖像了
方式二
利用torch的接口
img = torchvision.utils.make_grid(dataset[1][0]).numpy()plt.imshow(np.transpose(img,(1,2,0)))plt.show()
這用np.transpose 是因為plt.imshow在顯示 時候輸入的是(imgsize,imgsieze,channels),而這里得到的img是(3,200,200)的格式,所以進行了轉換,才能顯示
以上這篇pytorch 數據集圖片顯示方法就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持武林站長站。
新聞熱點
疑難解答