Experiment environment
Win10 + Anaconda + Jupyter Notebook
PyTorch 1.1.0
Python 3.7
GPU environment (optional)
Introduction to the MNIST dataset
MNIST consists of 60,000 28x28 training samples and 10,000 test samples; it is often called the "Hello World" of computer vision. The CNN used in this article pushes the recognition accuracy on MNIST to 99%. Let's get started.
Import packages
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

torch.__version__
Define hyperparameters
BATCH_SIZE = 512
EPOCHS = 20
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Dataset
We use the dataset classes built into PyTorch (torchvision) directly and read the training and test data with DataLoader. If you have already downloaded the dataset, you can set download=False here.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=BATCH_SIZE, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=BATCH_SIZE, shuffle=True)
Define the network
The network consists of two convolutional layers and two linear layers; the final output has 10 dimensions, one for each of the digits 0-9.
class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)   # input: (1, 28, 28)  output: (10, 24, 24)
        self.conv2 = nn.Conv2d(10, 20, 3)  # input: (10, 12, 12) output: (20, 10, 10)
        self.fc1 = nn.Linear(20 * 10 * 10, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        in_size = x.size(0)
        out = self.conv1(x)
        out = F.relu(out)
        out = F.max_pool2d(out, 2, 2)
        out = self.conv2(out)
        out = F.relu(out)
        out = out.view(in_size, -1)
        out = self.fc1(out)
        out = F.relu(out)
        out = self.fc2(out)
        out = F.log_softmax(out, dim=1)
        return out
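The shape comments in the class can be verified with a dummy input, which also shows why fc1 expects 20*10*10 = 2000 features (a quick sanity check, not part of the original article):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 1, 28, 28)                           # one fake MNIST image
x = F.max_pool2d(F.relu(nn.Conv2d(1, 10, 5)(x)), 2, 2)
print(tuple(x.shape))  # (1, 10, 12, 12): 28 - 5 + 1 = 24, halved by pooling
x = F.relu(nn.Conv2d(10, 20, 3)(x))
print(tuple(x.shape))  # (1, 20, 10, 10): 12 - 3 + 1 = 10
print(x.view(1, -1).shape[1])  # 2000 = 20*10*10, the fc1 input size
```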
Instantiate the network
model = ConvNet().to(DEVICE)                # move the network to the GPU (if available)
optimizer = optim.Adam(model.parameters())  # use the Adam optimizer
Define the training function
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if (batch_idx + 1) % 30 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))