Understanding Jacobian Tensor Gradients In Pytorch
Solution 1:
We will go through the entire process: from computing the Jacobian to applying it to obtain the resulting gradient for this input. We are looking at the operation f(x) = (x + 1)². In the simple scalar setting, we get df/dx = 2(x + 1) as the derivative.
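The scalar case is easy to verify with autograd directly (a minimal sketch; the sample value is mine):

```python
import torch

# f(x) = (x + 1)^2, so df/dx = 2(x + 1)
x = torch.tensor(3.0, requires_grad=True)
y = (x + 1) ** 2
y.backward()

print(x.grad)  # 2 * (3 + 1) = 8
```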
In the multi-dimensional setting, we have an input x_ij and an output y_mn, indexed by (i, j) and (m, n) respectively. The function mapping is defined as y_mn = (x_mn + 1)².
First, we should look at the Jacobian itself. This corresponds to the tensor J containing all partial derivatives J_ijmn = dy_mn/dx_ij. From the expression of y_mn we can say that for all i, j, m, and n: dy_mn/dx_ij = d(x_mn + 1)²/dx_ij, which is 0 if m≠i or n≠j. Otherwise, i.e. when m=i and n=j, we have d(x_mn + 1)²/dx_ij = d(x_ij + 1)²/dx_ij = 2(x_ij + 1).
As a result, J_ijmn can simply be defined as:

    J_ijmn = 2(x_ij + 1)   if i=m, j=n
             0             otherwise
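This closed form can be checked against PyTorch's own Jacobian utility (a small sketch; the input shape and helper names are mine):

```python
import torch
from torch.autograd import functional

x = torch.rand(2, 3)
f = lambda t: (t + 1) ** 2

# functional.jacobian returns a tensor of shape output_shape + input_shape,
# i.e. J[m, n, i, j] = dy_mn / dx_ij
J = functional.jacobian(f, x)

# Closed form: 2(x_ij + 1) on the "diagonal" (i=m, j=n), zero elsewhere
expected = torch.zeros(2, 3, 2, 3)
for i in range(2):
    for j in range(3):
        expected[i, j, i, j] = 2 * (x[i, j] + 1)

assert torch.allclose(J, expected)
```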
From the chain rule, the gradient of the output with respect to the input x is given by dL/dx = dL/dy * dy/dx. From a PyTorch perspective we have the following relationships:

- x.grad = dL/dx, shaped like x,
- dL/dy is the incoming gradient: the gradient argument in the backward function,
- dy/dx is the Jacobian tensor described above.
As explained in the documentation, applying backward doesn't actually provide the Jacobian. It computes the chain-rule product directly and stores the gradient (i.e. dL/dx inside x.grad).
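PyTorch also exposes this Jacobian-free product directly via torch.autograd.functional.vjp (vector-Jacobian product). A minimal sketch of the same computation without going through backward (the input values are mine):

```python
import torch
from torch.autograd import functional

f = lambda x: (x + 1) ** 2
inp = torch.eye(2)
v = torch.ones_like(inp)  # plays the role of the incoming gradient dL/dy

# vjp returns (f(inp), v contracted with the Jacobian) without
# ever materializing the full Jacobian tensor
out, grad = functional.vjp(f, inp, v)

print(grad)
# tensor([[4., 2.],
#         [2., 4.]])
```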
In terms of shapes, the Jacobian multiplication dL/dy * dy/dx = gradient * J reduces to a tensor of the same shape as x. The operation performed is defined by: [dL/dx]_ij = ∑_mn ([dL/dy]_mn * J_ijmn).
Let's apply this to your example. We have x = 1(i=j) (where 1(k): (k == True) -> 1 is the indicator function), essentially just the identity matrix.
We compute the Jacobian:

    J_ijmn = 2(1(i=j) + 1)   if i=m, j=n
             0               otherwise

which becomes

    J_ijmn = 2(1 + 1) = 4    if i=j=m=n
             2(0 + 1) = 2    if i=m, j=n, i≠j
             0               otherwise
For visualization purposes, we will stick with x = torch.eye(2):

>>> import torch
>>> import torch.autograd.functional as A
>>> inp = torch.eye(2, requires_grad=True)
>>> f = lambda x: (x+1)**2
>>> J = A.jacobian(f, inp)
>>> J
tensor([[[[4., 0.],
          [0., 0.]],
         [[0., 2.],
          [0., 0.]]],
        [[[0., 0.],
          [2., 0.]],
         [[0., 0.],
          [0., 4.]]]])
Then we compute the contraction using torch.einsum (I won't go into details; look through this, then this for an in-depth overview of the einsum summation operator). Note that torch.autograd.functional.jacobian returns the output dimensions first, i.e. J[m, n, i, j] = dy_mn/dx_ij, so the incoming gradient is contracted over the first two dimensions:

>>> torch.einsum('mn,mnij->ij', torch.ones_like(inp), J)
tensor([[4., 2.],
        [2., 4.]])
This matches what you get when back propagating from out
with torch.ones_like(inp)
as incoming gradient:
>>> out = f(inp)
>>> out.backward(torch.ones_like(inp))
>>> inp.grad
tensor([[4., 2.],
[2., 4.]])
If you backpropagate twice (while retaining the graph, of course) you end up computing the same operation, accumulating on the parameter's grad attribute. So, naturally, after two backward passes you have twice the gradient:

>>> out = f(inp)
>>> out.backward(torch.ones_like(inp), retain_graph=True)
>>> out.backward(torch.ones_like(inp))
>>> inp.grad
tensor([[8., 4.],
[4., 8.]])
Those gradients accumulate; you can reset them by calling the in-place function zero_: inp.grad.zero_(). From there, if you backpropagate again, you will end up with a single accumulated gradient only.

In practice, you would register your parameters on an optimizer, whose zero_grad method lets you handle and reset all parameters in that collection in one go.
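For instance (a sketch; any torch.optim optimizer behaves the same way, and note that recent PyTorch versions set .grad to None by default in zero_grad rather than zeroing it in place):

```python
import torch

inp = torch.eye(2, requires_grad=True)
opt = torch.optim.SGD([inp], lr=0.1)

out = (inp + 1) ** 2
out.backward(torch.ones_like(inp))
print(inp.grad)  # tensor([[4., 2.], [2., 4.]])

# resets the .grad attribute of every registered parameter in one call
opt.zero_grad()
print(inp.grad)  # None (default set_to_none=True) or a zero tensor
```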
Note: I have imported torch.autograd.functional as A.