以讀取VOC2012語義分割數據集為例,具體見代碼注釋:
VocDataset.py
from PIL import Imageimport torchimport torch.utils.data as dataimport numpy as npimport osimport torchvisionimport torchvision.transforms as transformsimport time#VOC數據集分類對應顏色標簽VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]#顏色標簽空間轉到序號標簽空間,就他媽這里浪費巨量的時間,這里還他媽的有問題def voc_label_indices(colormap, colormap2label): """Assign label indices for Pascal VOC2012 Dataset.""" idx = ((colormap[:, :, 2] * 256 + colormap[ :, :,1]) * 256+ colormap[:, :,0]) #out = np.empty(idx.shape, dtype = np.int64) out = colormap2label[idx] out=out.astype(np.int64)#數據類型轉換 end = time.time() return outclass MyDataset(data.Dataset):#創建自定義的數據讀取類 def __init__(self, root, is_train, crop_size=(320,480)): self.rgb_mean =(0.485, 0.456, 0.406) self.rgb_std = (0.229, 0.224, 0.225) self.root=root self.crop_size=crop_size images = []#創建空列表存文件名稱 txt_fname = '%s/ImageSets/Segmentation/%s' % (root, 'train.txt' if is_train else 'val.txt') with open(txt_fname, 'r') as f: self.images = f.read().split() #數據名稱整理 self.files = [] for name in self.images: img_file = os.path.join(self.root, "JPEGImages/%s.jpg" % name) label_file = os.path.join(self.root, "SegmentationClass/%s.png" % name) self.files.append({ "img": img_file, "label": label_file, "name": name }) self.colormap2label = np.zeros(256**3) #整個循環的意思就是將顏色標簽映射為單通道的數組索引 for i, cm in enumerate(VOC_COLORMAP): self.colormap2label[(cm[2] * 256 + cm[1]) * 256 + cm[0]] = i #按照索引讀取每個元素的具體內容 def __getitem__(self, index): datafiles = self.files[index] name = datafiles["name"] image = Image.open(datafiles["img"]) label = Image.open(datafiles["label"]).convert('RGB')#打開的是PNG格式的圖片要轉到rgb的格式下,不然結果會比較要命 #以圖像中心為中心截取固定大小圖像,小于固定大小的圖像則自動填0 imgCenterCrop = transforms.Compose([ transforms.CenterCrop(self.crop_size), transforms.ToTensor(), transforms.Normalize(self.rgb_mean, self.rgb_std),#圖像數據正則化 ]) labelCenterCrop = transforms.CenterCrop(self.crop_size) cropImage=imgCenterCrop(image) croplabel=labelCenterCrop(label) croplabel=torch.from_numpy(np.array(croplabel)).long()#把標簽數據類型轉為torch #將顏色標簽圖轉為序號標簽圖 mylabel=voc_label_indices(croplabel, self.colormap2label) return cropImage,mylabel #返回圖像數據長度 def __len__(self): return len(self.files)
新聞熱點
疑難解答