|
|
|
@ -13,9 +13,12 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
# ============================================================================
|
|
|
|
|
"""Cell_wrapper."""
|
|
|
|
|
from types import FunctionType, MethodType
|
|
|
|
|
|
|
|
|
|
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
|
|
|
|
|
_get_parallel_mode)
|
|
|
|
|
from mindspore.context import ParallelMode
|
|
|
|
|
from ...common.tensor import Tensor
|
|
|
|
|
from ...common import dtype as mstype
|
|
|
|
|
from ...common.parameter import Parameter, ParameterTuple
|
|
|
|
|
from ...ops import composite as C
|
|
|
|
@ -174,6 +177,107 @@ class WithGradCell(Cell):
|
|
|
|
|
return grads
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ForwardValueAndGrad(Cell):
    r"""
    Network training package class.

    Including the network and a gradient function. The resulting Cell is trained with input '\*inputs'.
    The backward graph will be created in the gradient function to calculating gradient.

    Args:
        network (Union[Cell, function, method]): The training network. The network only supports single output.
        weights (ParameterTuple): The parameters of the training network that need to calculate the gradient.
            Must be a ParameterTuple when `get_by_list` is True and must be None otherwise. Default: None.
        get_all (bool): If True, get all the gradients with respect to inputs. Default: False.
        get_by_list (bool): If True, get all the gradients with respect to Parameter variables.
            If get_all and get_by_list are both False, get the gradient with respect to first input.
            If get_all and get_by_list are both True, get the gradients with respect to inputs and Parameter variables
            at the same time in the form of ((gradients with respect to inputs),
            (gradients with respect to parameters)). Default: False.
        sens_param (bool): Whether to append sensitivity (gradient with respect to output) as input.
            If sens_param is False, a 'ones_like(outputs)' sensitivity will be attached automatically.
            Default: False.
            If the sens_param is True, a sensitivity (gradient with respect to output) needs to be transferred through
            the location parameter or key-value pair parameter. If the value is transferred through the key-value pair
            parameter, the key must be sens. When passed by position, it is read from the last positional input.

    Inputs:
        - **(\*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
        - sens (Union[Tensor, Number]): The sensitivity used as the input of backpropagation. Only consumed when
          `sens_param` is True. A value that is not a Tensor is filled into a tensor with the dtype and shape of
          the forward output.

    Outputs:
        - **forward value** (a scalar Tensor with shape :math:`()`) - The result of network forward running.
        - **gradients** (tuple(tensor)) - The gradients of network parameters and inputs.

    Raises:
        TypeError: If `network` is not a Cell, a function or a method.
        TypeError: If `get_by_list` is True and `weights` is not a ParameterTuple.
        TypeError: If `get_by_list` is not True and `weights` is not None.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> inputs = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32))
        >>> labels = Tensor(np.ones([32]).astype(np.int32))
        >>> net = Net()
        >>> weights = ParameterTuple(filter(lambda x: x.requires_grad, net.get_parameters()))
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> #1) Using the WithLossCell existing provide
        >>> loss_net = nn.WithLossCell(net, loss_fn)
        >>> forward_value_and_grad = nn.ForwardValueAndGrad(loss_net, weights=weights, get_by_list=True, sens_param=True)
        >>> loss, grads = forward_value_and_grad(inputs, labels, 1.0)
        >>>
        >>> #2) Using user-defined WithLossCell
        >>> class MyWithLossCell(nn.Cell):
        ...    def __init__(self, backbone, loss_fn):
        ...        super(MyWithLossCell, self).__init__(auto_prefix=False)
        ...        self._backbone = backbone
        ...        self._loss_fn = loss_fn
        ...
        ...    def construct(self, x, label):
        ...        out = self._backbone(x)
        ...        return self._loss_fn(out, label)
        ...
        ...    @property
        ...    def backbone_network(self):
        ...        return self._backbone
        ...
        >>> loss_net = MyWithLossCell(net, loss_fn)
        >>> forward_value_and_grad = nn.ForwardValueAndGrad(loss_net, weights=weights, get_by_list=True, sens_param=True)
        >>> loss, grads = forward_value_and_grad(inputs, labels, 1.0)
    """

    def __init__(self, network, weights=None, get_all=False, get_by_list=False, sens_param=False):
        super(ForwardValueAndGrad, self).__init__(auto_prefix=False)
        if not isinstance(network, (Cell, FunctionType, MethodType)):
            raise TypeError(f"The type of training network should be cell, function type or method type, "
                            f"but got '{type(network)}'")
        if get_by_list and not isinstance(weights, ParameterTuple):
            raise TypeError(f"When get_by_list is set to True, the parameters of training network should be "
                            f"ParameterTuple type, but got '{type(weights)}'")
        if get_by_list is not True and weights is not None:
            raise TypeError(f"When get_by_list is set to False, the parameters of training network should be "
                            f"NoneType, but got '{type(weights)}'")
        self.network = network
        # Plain functions and bound methods pass the type check above but have no
        # `set_grad` attribute, so only Cell networks are switched into grad mode.
        if isinstance(network, Cell):
            self.network.set_grad()
        self.weights = weights
        self.get_all = get_all
        self.get_by_list = get_by_list
        self.sens_param = sens_param
        self.grad = C.GradOperation(get_all=self.get_all, get_by_list=self.get_by_list, sens_param=self.sens_param)

    def construct(self, *inputs):
        """Run the forward pass and the gradient function, returning `(forward value, gradients)`."""
        weights = self.weights
        if self.sens_param:
            # The sensitivity travels as the trailing positional input; strip it
            # off so that the network only receives its own inputs.
            sens = inputs[-1]
            inputs = inputs[:-1]
        else:
            sens = None
        loss = self.network(*inputs)
        if self.sens_param:
            if not isinstance(sens, Tensor):
                # Expand a non-Tensor sensitivity to the dtype and shape of the forward output.
                sens = P.Fill()(P.DType()(loss), P.Shape()(loss), sens)
            grads = self.grad(self.network, weights)(*inputs, sens)
        else:
            grads = self.grad(self.network, weights)(*inputs)
        return loss, grads
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TrainOneStepCell(Cell):
|
|
|
|
|
r"""
|
|
|
|
|
Network training package class.
|
|
|
|
|