可以從官網加載預訓練好的模型:
import torchvision.models as models model = models.vgg16(pretrained = True)print(model)
但是經常會出現因為下載速度太慢而出現requests.exceptions.ConnectionError: ('Connection aborted.', TimeoutError(10060, '由于連接方在一段時間后沒有正確答復或連接的主機沒有反應,連接嘗試失敗。', None, 10060, None))這種錯誤,因此需要我們手動去下載 .pth 文件(百度云也很慢,如果你是SVIP,當我沒說;迅雷的速度也還可以),然后從本地加載。
從本地加載只需要把上面的代碼換成如下:
import torchvision.models as models model = models.vgg16(pretrained=False)pre=torch.load(r'./kaggle_dog_vs_cat/pretrain/vgg16-397923af.pth')model.load_state_dict(pre)
如果你模型不是用的vgg16,而是用的vgg11或者vgg13,只需要修改語句 model = models.vgg16(pretrained=False) 為對應模型的函數即可。
以上這篇pytorch實現從本地加載 .pth 格式模型就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持武林站長站。
新聞熱點
疑難解答