@@ -92,18 +92,196 @@ def core(fn=None, **flags):
class GradOperation(GradOperation_):
    """
    A higher-order function which is used to generate the gradient function for the input function.

    The `GradOperation` higher-order function converts the network (function) into a
    back-propagation graph. The gradient function it generates can be customized by its
    construction arguments.

    Given an input function `net = Net()` that takes `x` and `y` as inputs and has a parameter `z`
    (see `Net` in Examples):

    To generate a gradient function that returns gradients with respect to the first input
    (see `GradNetWrtX` in Examples):

    1. Construct a `GradOperation` higher-order function with default arguments:
       `grad_op = GradOperation()`.

    2. Call it with the input function as argument to get the gradient function:
       `gradient_function = grad_op(net)`.

    3. Call the gradient function with the input function's inputs to get the gradients
       with respect to the first input: `gradient_function(x, y)`.
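
    Putting the three steps together (a minimal sketch; it assumes `net`, `x` and `y` are
    defined as in the Examples below):

    >>> grad_op = GradOperation()
    >>> gradient_function = grad_op(net)
    >>> dx = gradient_function(x, y)  # gradient with respect to the first input, x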

    To generate a gradient function that returns gradients with respect to all inputs
    (see `GradNetWrtXY` in Examples):

    1. Construct a `GradOperation` higher-order function with `get_all=True`, which indicates
       getting gradients with respect to all inputs; in the example function `Net()`, they are
       `x` and `y`: `grad_op = GradOperation(get_all=True)`.

    2. Call it with the input function as argument to get the gradient function:
       `gradient_function = grad_op(net)`.

    3. Call the gradient function with the input function's inputs to get the gradients
       with respect to all inputs: `gradient_function(x, y)`.
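
    Condensed into one snippet (a minimal sketch; `net`, `x` and `y` are as in the Examples below):

    >>> grad_op = GradOperation(get_all=True)
    >>> gradient_function = grad_op(net)
    >>> dx, dy = gradient_function(x, y)  # gradients with respect to x and y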

    To generate a gradient function that returns gradients with respect to given parameters
    (see `GradNetWithWrtParams` in Examples):

    1. Construct a `GradOperation` higher-order function with `get_by_list=True`:
       `grad_op = GradOperation(get_by_list=True)`.

    2. Construct a `ParameterTuple` that will be passed to the `GradOperation` higher-order
       function together with the input function; it works as a parameter filter that determines
       which gradients to return: `params = ParameterTuple(net.trainable_params())`.

    3. Call it with the input function and `params` as arguments to get the gradient function:
       `gradient_function = grad_op(net, params)`.

    4. Call the gradient function with the input function's inputs to get the gradients
       with respect to the given parameters: `gradient_function(x, y)`.
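
    Condensed into one snippet (a minimal sketch; `net`, `x` and `y` are as in the Examples below):

    >>> params = ParameterTuple(net.trainable_params())
    >>> grad_op = GradOperation(get_by_list=True)
    >>> gradient_function = grad_op(net, params)
    >>> grads_wrt_params = gradient_function(x, y)  # one gradient per parameter in `params`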

    To generate a gradient function that returns gradients with respect to all inputs and
    given parameters, in the format of ((dx, dy), (dz)) (see `GradNetWrtInputsAndParams`
    in Examples):

    1. Construct a `GradOperation` higher-order function with `get_all=True` and
       `get_by_list=True`: `grad_op = GradOperation(get_all=True, get_by_list=True)`.

    2. Construct a `ParameterTuple` that will be passed to the `GradOperation` higher-order
       function together with the input function: `params = ParameterTuple(net.trainable_params())`.

    3. Call it with the input function and `params` as arguments to get the gradient function:
       `gradient_function = grad_op(net, params)`.

    4. Call the gradient function with the input function's inputs to get the gradients
       with respect to all inputs and the given parameters: `gradient_function(x, y)`.
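
    Condensed into one snippet (a minimal sketch; `net`, `x` and `y` are as in the Examples
    below, and `net` has the single parameter `z`):

    >>> params = ParameterTuple(net.trainable_params())
    >>> grad_op = GradOperation(get_all=True, get_by_list=True)
    >>> gradient_function = grad_op(net, params)
    >>> (dx, dy), (dz,) = gradient_function(x, y)  # ((grads w.r.t. inputs), (grads w.r.t. parameters))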

    We can configure the sensitivity (gradient with respect to the output) by setting
    `sens_param=True` and passing an extra sensitivity input to the gradient function. The
    sensitivity input should have the same shape and type as the input function's output
    (see `GradNetWrtXYWithSensParam` in Examples):

    1. Construct a `GradOperation` higher-order function with `get_all=True` and
       `sens_param=True`: `grad_op = GradOperation(get_all=True, sens_param=True)`.

    2. Define `grad_wrt_output` as the `sens_param`, which works as the gradient with respect
       to the output: `grad_wrt_output = Tensor(np.ones([2, 2]).astype(np.float32))`.

    3. Call it with the input function as argument to get the gradient function:
       `gradient_function = grad_op(net)`.

    4. Call the gradient function with the input function's inputs and the `sens_param` to get
       the gradients with respect to all inputs: `gradient_function(x, y, grad_wrt_output)`.
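
    Condensed into one snippet (a minimal sketch; it assumes a `net` whose output has shape
    (2, 2), so that `grad_wrt_output` matches the output's shape and type):

    >>> grad_wrt_output = Tensor(np.ones([2, 2]).astype(np.float32))
    >>> grad_op = GradOperation(get_all=True, sens_param=True)
    >>> gradient_function = grad_op(net)
    >>> dx, dy = gradient_function(x, y, grad_wrt_output)  # gradients scaled by grad_wrt_output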

    Args:
        get_all (bool): If True, get all the gradients with respect to inputs. Default: False.
        get_by_list (bool): If True, get all the gradients with respect to Parameter variables.
            If get_all and get_by_list are both False, get the gradient with respect to the first
            input. If get_all and get_by_list are both True, get the gradients with respect to
            inputs and Parameter variables at the same time, in the form of
            ((gradients with respect to inputs), (gradients with respect to parameters)).
            Default: False.
        sens_param (bool): Whether to append sensitivity (gradient with respect to output) as
            input. If sens_param is False, a 'ones_like(outputs)' sensitivity will be attached
            automatically. Default: False.

    Returns:
        The higher-order function which takes a function as argument and returns a gradient
        function for it.

    Examples:
        >>> class Net(nn.Cell):
        >>>     def __init__(self):
        >>>         super(Net, self).__init__()
        >>>         self.matmul = P.MatMul()
        >>>         self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
        >>>     def construct(self, x, y):
        >>>         x = x * self.z
        >>>         out = self.matmul(x, y)
        >>>         return out
        >>>
        >>> class GradNetWrtX(nn.Cell):
        >>>     def __init__(self, net):
        >>>         super(GradNetWrtX, self).__init__()
        >>>         self.net = net
        >>>         self.grad_op = GradOperation()
        >>>     def construct(self, x, y):
        >>>         gradient_function = self.grad_op(self.net)
        >>>         return gradient_function(x, y)
        >>>
        >>> x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
        >>> y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
        >>> GradNetWrtX(Net())(x, y)
        Tensor(shape=[2, 3], dtype=Float32,
        [[1.4100001 1.5999999 6.6 ]
         [1.4100001 1.5999999 6.6 ]])
        >>>
        >>> class GradNetWrtXY(nn.Cell):
        >>>     def __init__(self, net):
        >>>         super(GradNetWrtXY, self).__init__()
        >>>         self.net = net
        >>>         self.grad_op = GradOperation(get_all=True)
        >>>     def construct(self, x, y):
        >>>         gradient_function = self.grad_op(self.net)
        >>>         return gradient_function(x, y)
        >>>
        >>> x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
        >>> y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
        >>> GradNetWrtXY(Net())(x, y)
        (Tensor(shape=[2, 3], dtype=Float32,
        [[4.5099998 2.7 3.6000001]
         [4.5099998 2.7 3.6000001]]), Tensor(shape=[3, 3], dtype=Float32,
        [[2.6 2.6 2.6 ]
         [1.9 1.9 1.9 ]
         [1.3000001 1.3000001 1.3000001]]))
        >>>
        >>> class GradNetWrtXYWithSensParam(nn.Cell):
        >>>     def __init__(self, net):
        >>>         super(GradNetWrtXYWithSensParam, self).__init__()
        >>>         self.net = net
        >>>         self.grad_op = GradOperation(get_all=True, sens_param=True)
        >>>         self.grad_wrt_output = Tensor([[0.1, 0.6, 0.2], [0.8, 1.3, 1.1]], dtype=mstype.float32)
        >>>     def construct(self, x, y):
        >>>         gradient_function = self.grad_op(self.net)
        >>>         return gradient_function(x, y, self.grad_wrt_output)
        >>>
        >>> x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
        >>> y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
        >>> GradNetWrtXYWithSensParam(Net())(x, y)
        (Tensor(shape=[2, 3], dtype=Float32,
        [[2.211 0.51 1.4900001]
         [5.588 2.68 4.07 ]]), Tensor(shape=[3, 3], dtype=Float32,
        [[1.52 2.82 2.14 ]
         [1.1 2.05 1.55 ]
         [0.90000004 1.55 1.25 ]]))
        >>>
        >>> class GradNetWithWrtParams(nn.Cell):
        >>>     def __init__(self, net):
        >>>         super(GradNetWithWrtParams, self).__init__()
        >>>         self.net = net
        >>>         self.params = ParameterTuple(net.trainable_params())
        >>>         self.grad_op = GradOperation(get_by_list=True)
        >>>     def construct(self, x, y):
        >>>         gradient_function = self.grad_op(self.net, self.params)
        >>>         return gradient_function(x, y)
        >>>
        >>> x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
        >>> y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
        >>> GradNetWithWrtParams(Net())(x, y)
        (Tensor(shape=[1], dtype=Float32, [21.536]),)
        >>>
        >>> class GradNetWrtInputsAndParams(nn.Cell):
        >>>     def __init__(self, net):
        >>>         super(GradNetWrtInputsAndParams, self).__init__()
        >>>         self.net = net
        >>>         self.params = ParameterTuple(net.trainable_params())
        >>>         self.grad_op = GradOperation(get_all=True, get_by_list=True)
        >>>     def construct(self, x, y):
        >>>         gradient_function = self.grad_op(self.net, self.params)
        >>>         return gradient_function(x, y)
        >>>
        >>> x = Tensor([[0.1, 0.6, 1.2], [0.5, 1.3, 0.1]], dtype=mstype.float32)
        >>> y = Tensor([[0.12, 2.3, 1.1], [1.3, 0.2, 2.4], [0.1, 2.2, 0.3]], dtype=mstype.float32)
        >>> GradNetWrtInputsAndParams(Net())(x, y)
        ((Tensor(shape=[2, 3], dtype=Float32,
        [[3.52 3.9 2.6 ]
         [3.52 3.9 2.6 ]]), Tensor(shape=[3, 3], dtype=Float32,
        [[0.6 0.6 0.6 ]
         [1.9 1.9 1.9 ]
         [1.3000001 1.3000001 1.3000001]])), (Tensor(shape=[1], dtype=Float32, [12.902]),))
    """

    def __init__(self, get_all=False, get_by_list=False, sens_param=False):