任務要求:
自定義一個層主要是定義該層的實現函數,只需要重載Function的forward和backward函數即可,如下:
import torchfrom torch.autograd import Functionfrom torch.autograd import Variable
定義二值化函數
class BinarizedF(Function): def forward(self, input): self.save_for_backward(input) a = torch.ones_like(input) b = -torch.ones_like(input) output = torch.where(input>=0,a,b) return output def backward(self, output_grad): input, = self.saved_tensors input_abs = torch.abs(input) ones = torch.ones_like(input) zeros = torch.zeros_like(input) input_grad = torch.where(input_abs<=1,ones, zeros) return input_grad
定義一個module
class BinarizedModule(nn.Module): def __init__(self): super(BinarizedModule, self).__init__() self.BF = BinarizedF() def forward(self,input): print(input.shape) output =self.BF(input) return output
進行測試
a = Variable(torch.randn(4,480,640), requires_grad=True)output = BinarizedModule()(a)output.backward(torch.ones(a.size()))print(a)print(a.grad)
其中, 二值化函數部分也可以按照方式寫,但是速度慢了0.05s
class BinarizedF(Function): def forward(self, input): self.save_for_backward(input) output = torch.ones_like(input) output[input<0] = -1 return output def backward(self, output_grad): input, = self.saved_tensors input_grad = output_grad.clone() input_abs = torch.abs(input) input_grad[input_abs>1] = 0 return input_grad
以上這篇pytorch自定義二值化網絡層方式就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持武林網之家。
新聞熱點
疑難解答