代碼:
import torchclass_num = 10batch_size = 4label = torch.LongTensor(batch_size, 1).random_() % class_numprint(label.size())one_hot = torch.zeros(batch_size, class_num).scatter_(1, label, 1)print(one_hot)
輸出:
torch.Size([4, 1])tensor([[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])
注意:
label的形狀必須是[n,1]的,也就是必須是二維的,且第二個維度長度為1,如果是一維度的,則需要升維度,代碼如下:
import torchclass_num = 10batch_size = 4label = torch.LongTensor(batch_size).random_() % class_numprint(label.size())label = torch.unsqueeze(label,dim=1)print(label.size())
以上這篇pytorch標簽轉onehot形式實例就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持武林網之家。
新聞熱點
疑難解答