pytorch 自定义梯度
最近的工作中需要对梯度进行自定义的修改,例如对于一个线性映射来说,我们只希望他更新其中某些位置的权重。
这时候就需要手动定义梯度反传的方式,将一些位置上的梯度赋值为0,下面就直接用简单的线性映射来举例。
这里主要用到的是torch.autograd.function
1 对function的理解
- Pytorch是利用Variable与Function来构建计算图的。回顾下Variable,Variable就像是计算图中的节点,保存计算结果(包括前向传播的激活值,反向传播的梯度),而Function就像计算图中的边,实现Variable的计算,并输出新的Variable;
- Function简单说就是对Variable的运算,如加减乘除,relu,pool等
但它不仅仅是简单的运算。 - 总结,Function与Variable构成了pytorch的自动求导机制,它定义的是各个Variable之间的计算关系
2 Function与Module的差异
- Function一般只定义一个操作,因为其无法保存参数,因此适用于激活函数、pooling等操作;
- Module是保存了参数,因此适合于定义一层,如线性层,卷积层,也适用于定义一个网络;
- 二者可以结合在一起,这次的linear层就是对二者进行了结合。
3 线性映射举例
3.1 自定义torch.autograd.function
class LinearFunction(torch.autograd.function.Function):
@staticmethod
def forward(ctx, input, weight, location, bias=None):
ctx.save_for_backward(input, weight, bias, location)
output = input.mm(weight.t()) #这里实现线性层的前向计算
if bias is not None:
output += bias.unsqueeze(0).expand_as(output)
return output
@staticmethod
def backward(ctx, grad_output):
#拿到所有需要的值以便于后面计算梯度
input, weight, bias, loc = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
#判断是否需要求导,对应的wx+b的求偏导方式
if ctx.needs_input_grad[0]:
grad_input = grad_output.mm(weight)
if ctx.needs_input_grad[1]:
grad_weight = grad_output.t().mm(input)
if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum(0)
for item in grad_weight:
for idx in range(20):
if idx not in loc:
item[idx]=0
# print(grad_weight)
#返回之后,就自动定义了计算图中的autograd 方法,loss进行backward的时候就直接自动计算了
return grad_input, grad_weight, grad_bias, grad_bias
3.2 自定义nn.Module
class LinearCraft(nn.Module):
def __init__(self, input_features, output_features, location, bias=True):
super(LinearCraft, self).__init__()
self.input_features = input_features
self.output_features = output_features
self.loc = torch.from_numpy(location)
self.weight = nn.Parameter(torch.empty(output_features, input_features))
if bias:
self.bias = nn.Parameter(torch.empty(output_features))
else:
self.register_parameter('bias', None)
nn.init.uniform_(self.weight, -0.1, 0.1)
if self.bias is not None:
nn.init.uniform_(self.bias, -0.1, 0.1)
def forward(self, input):
# 这里注意是直接在forward方法里调用写好的function即可,另外调用方式不是 .forward() 而是 .apply()
return LinearFunction.apply(input, self.weight, self.loc, self.bias)
def extra_repr(self):
return 'input_features={}, output_features={}, bias={}'.format(
self.input_features, self.output_features, self.bias is not None
)
这样的定义结束之后,就相当于LinearCraft就和nn.linear相同,但是使用的是自定义的梯度反传方式了。感兴趣的可以自己debug看一下反传的梯度,对应位置已经是0了。