分類器平均準確率計算:
correct = torch.zeros(1).squeeze().cuda()total = torch.zeros(1).squeeze().cuda()for i, (images, labels) in enumerate(train_loader): images = Variable(images.cuda()) labels = Variable(labels.cuda()) output = model(images) prediction = torch.argmax(output, 1) correct += (prediction == labels).sum().float() total += len(labels)acc_str = 'Accuracy: %f'%((correct/total).cpu().detach().data.numpy())
分類器各個子類準確率計算:
correct = list(0. for i in range(args.class_num))total = list(0. for i in range(args.class_num))for i, (images, labels) in enumerate(train_loader): images = Variable(images.cuda()) labels = Variable(labels.cuda()) output = model(images) prediction = torch.argmax(output, 1) res = prediction == labels for label_idx in range(len(labels)): label_single = label[label_idx] correct[label_single] += res[label_idx].item() total[label_single] += 1 acc_str = 'Accuracy: %f'%(sum(correct)/sum(total)) for acc_idx in range(len(train_class_correct)): try: acc = correct[acc_idx]/total[acc_idx] except: acc = 0 finally: acc_str += '/tclassID:%d/tacc:%f/t'%(acc_idx+1, acc)
以上這篇Pytorch 實現計算分類器準確率(總分類及子分類)就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持武林站長站。
新聞熱點
疑難解答