pytorch自定义梯度

pytorch自定义梯度

Scroll Down

pytorch 自定义梯度

最近的工作中需要对梯度进行自定义的修改,例如对于一个线性映射来说,我们只希望他更新其中某些位置的权重。
这时候就需要手动定义梯度反传的方式,将一些位置上的梯度赋值为0,下面就直接用简单的线性映射来举例。
这里主要用到的是torch.autograd.function

1 对function的理解

  • Pytorch是利用Variable与Function来构建计算图的。回顾下Variable,Variable就像是计算图中的节点,保存计算结果(包括前向传播的激活值,反向传播的梯度),而Function就像计算图中的边,实现Variable的计算,并输出新的Variable;
  • Function简单说就是对Variable的运算,如加减乘除,relu,pool等
    但它不仅仅是简单的运算。
  • 总结,Function与Variable构成了pytorch的自动求导机制,它定义的是各个Variable之间的计算关系

2 Function与Module的差异

  • Function一般只定义一个操作,因为其无法保存参数,因此适用于激活函数、pooling等操作;
  • Module是保存了参数,因此适合于定义一层,如线性层,卷积层,也适用于定义一个网络;
  • 二者可以结合在一起,这次的linear层就是对二者进行了结合。

3 线性映射举例

3.1 自定义torch.autograd.function

class LinearFunction(torch.autograd.function.Function):
    @staticmethod
    def forward(ctx, input, weight, location, bias=None):
        ctx.save_for_backward(input, weight, bias, location)
        output = input.mm(weight.t()) #这里实现线性层的前向计算
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
         #拿到所有需要的值以便于后面计算梯度
        input, weight, bias, loc = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        #判断是否需要求导,对应的wx+b的求偏导方式
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        for item in grad_weight:
            for idx in range(20):
                if idx not in loc:
                    item[idx]=0
        # print(grad_weight)
	#返回之后,就自动定义了计算图中的autograd 方法,loss进行backward的时候就直接自动计算了
        return grad_input, grad_weight, grad_bias, grad_bias

3.2 自定义nn.Module

class LinearCraft(nn.Module):
    def __init__(self, input_features, output_features, location, bias=True):
        super(LinearCraft, self).__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.loc = torch.from_numpy(location)
        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            self.register_parameter('bias', None)

        nn.init.uniform_(self.weight, -0.1, 0.1)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -0.1, 0.1)

    def forward(self, input):
	# 这里注意是直接在forward方法里调用写好的function即可,另外调用方式不是 .forward() 而是 .apply()
        return LinearFunction.apply(input, self.weight, self.loc, self.bias)

    def extra_repr(self):
        return 'input_features={}, output_features={}, bias={}'.format(
            self.input_features, self.output_features, self.bias is not None
        )

这样的定义结束之后,就相当于LinearCraft就和nn.linear相同,但是使用的是自定义的梯度反传方式了。感兴趣的可以自己debug看一下反传的梯度,对应位置已经是0了。