我就廢話不多說,看例子吧!
import torch.nn as nnoutputs = model(data)loss= loss_fn(outputs, target)optimizer.zero_grad()loss.backward()nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)optimizer.step()
nn.utils.clip_grad_norm_ 的參數:
parameters – 一個基于變量的迭代器,會進行梯度歸一化
max_norm – 梯度的最大范數
norm_type – 規定范數的類型,默認為L2
以上這篇pytorch梯度剪裁方式就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持武林站長站。
新聞熱點
疑難解答