diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py
index b11b9dc9dc..f0d822fa3d 100644
--- a/mindspore/nn/wrap/cell_wrapper.py
+++ b/mindspore/nn/wrap/cell_wrapper.py
@@ -18,7 +18,6 @@ from types import FunctionType, MethodType
 from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
                                        _get_parallel_mode)
 from mindspore.context import ParallelMode
-from ...common.tensor import Tensor
 from ...common import dtype as mstype
 from ...common.parameter import Parameter, ParameterTuple
 from ...ops import composite as C
@@ -197,15 +196,16 @@ class ForwardValueAndGrad(Cell):
             If sens_param is False, a 'ones_like(outputs)' sensitivity will be attached automatically.
             Default: False.
             If the sensor_param is True, a sensitivity (gradient with respect to output) needs to be transferred through
-            the location parameter or key-value pair parameter. If the value is transferred through the key-value pair
-            parameter, the key must be sens.
-        sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
+            the input parameter.

     Inputs:
         - **(\*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
+        - **(\*sens)** - A sensitivity (gradient with respect to output) as the input of backpropagation.
+          If network has single output, the sens is a tensor.
+          If network has multiple outputs, the sens is the tuple(tensor).

     Outputs:
-        - **forward value** (a scalar Tensor with shape :math:`()`) - The result of network forward running.
+        - **forward value** - The result of network forward running.
         - **gradients** (tuple(tensor)) - The gradients of network parameters and inputs.

     Supported Platforms:
@@ -219,8 +219,8 @@
         >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
         >>> #1) Using the WithLossCell existing provide
         >>> loss_net = nn.WithLossCell(net, loss_fn)
-        >>> forward_value_and_grad = nn.ForwardValueAndGrad(loss_net, weights=weight, get_by_list=True, sens_param=True)
-        >>> loss, grads = forward_value_and_grad(inputs, labels, 1.0)
+        >>> forward_value_and_grad = nn.ForwardValueAndGrad(loss_net, weights=weights, get_by_list=True)
+        >>> loss, grads = forward_value_and_grad(inputs, labels)
         >>>
         >>> #2) Using user-defined WithLossCell
         >>> class MyWithLossCell(Cell):
@@ -238,40 +238,40 @@ class ForwardValueAndGrad(Cell):
         ...        return self._backbone
         ...
         >>> loss_net = MyWithLossCell(net, loss_fn)
-        >>> forward_value_and_grad = nn.ForwardValueAndGrad(loss_net, weights=weight, get_by_list=True, sens_param=True)
-        >>> loss, grads = forward_value_and_grad(inputs, labels, 1.0)
+        >>> forward_value_and_grad = nn.ForwardValueAndGrad(loss_net, weights=weights, get_by_list=True)
+        >>> loss, grads = forward_value_and_grad(inputs, labels)
     """

-    def __init__(self, network, weights=None, get_all=False, get_by_list=False, sens_param=False, sens=1.0):
+    def __init__(self, network, weights=None, get_all=False, get_by_list=False, sens_param=False):
         super(ForwardValueAndGrad, self).__init__(auto_prefix=False)
         if not isinstance(network, (Cell, FunctionType, MethodType)):
             raise TypeError(f"The type of training network should be cell, function type or method type, "
                             f"but got '{type(network)}'")
+        if not isinstance(get_all, bool):
+            raise TypeError(f"The type of get_all should be bool, but got '{type(get_all)}'")
+        if not isinstance(get_by_list, bool):
+            raise TypeError(f"The type of get_by_list should be bool, but got '{type(get_by_list)}'")
         if get_by_list and not isinstance(weights, ParameterTuple):
             raise TypeError(f"When get_by_list is set to True, the parameters of training network should be "
                             f"ParameterTuple type, but got '{type(weights)}'")
-        if get_by_list is not True and weights is not None:
-            raise TypeError(f"When get_by_list is set to False, the parameters of training network should be "
-                            f"NoneType, but got '{type(weights)}'")
         self.network = network
-        self.network.set_grad()
+        if isinstance(network, Cell):
+            self.network.set_grad()
         self.weights = weights
         self.get_all = get_all
         self.get_by_list = get_by_list
         self.sens_param = sens_param
-        self.sens = sens
         self.grad = C.GradOperation(get_all=self.get_all,
                                     get_by_list=self.get_by_list,
                                     sens_param=self.sens_param)

     def construct(self, *inputs):
-        weights = self.weights
-        loss = self.network(*inputs)
+        grad_inputs = inputs
         if self.sens_param:
-            sens = self.sens
-            if not isinstance(self.sens, Tensor):
-                sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
-            grads = self.grad(self.network, weights)(*inputs, sens)
+            inputs = inputs[:-1]
+        loss = self.network(*inputs)
+        if self.get_by_list:
+            grads = self.grad(self.network, self.weights)(*grad_inputs)
         else:
-            grads = self.grad(self.network, weights)(*inputs)
+            grads = self.grad(self.network)(*grad_inputs)
         return loss, grads
diff --git a/tests/st/networks/test_gpu_resnet.py b/tests/st/networks/test_gpu_resnet.py
index 85f2a2a1f8..06c9785065 100644
--- a/tests/st/networks/test_gpu_resnet.py
+++ b/tests/st/networks/test_gpu_resnet.py
@@ -414,40 +414,14 @@ def test_trainTensor_with_new_interface(num_classes=10, epoch=8, batch_size=1):
     weights = ParameterTuple(filter(lambda x: x.requires_grad, net.get_parameters()))
     optimizer = Momentum(weights, 0.1, 0.9)

-    train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True,
-                                        sens=1.0)
+    train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True)
     losses = []
     for i in range(0, epoch):
         data = Tensor(np.ones([batch_size, 3, 224, 224]
                               ).astype(np.float32) * 0.01)
         label = Tensor(np.ones([batch_size]).astype(np.int32))
-        loss, grads = train_network(data, label)
-        grads = F.identity(grads)
-        optimizer(grads)
-        losses.append(loss)
-    assert (losses[-1].asnumpy() < 0.8)
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_big_batchSize_with_new_interface(num_classes=10, epoch=8, batch_size=338):
-    net = resnet50(num_classes)
-    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
-    net_with_criterion = WithLossCell(net, criterion)
-    net_with_criterion.set_train()
-
-    weights = ParameterTuple(filter(lambda x: x.requires_grad, net.get_parameters()))
-    optimizer = Momentum(weights, 0.1, 0.9)
-
-    train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True,
-                                        sens=1.0)
-    losses = []
-    for i in range(0, epoch):
-        data = Tensor(np.ones([batch_size, 3, 224, 224]
-                              ).astype(np.float32) * 0.01)
-        label = Tensor(np.ones([batch_size]).astype(np.int32))
-        loss, grads = train_network(data, label)
+        sens = Tensor(np.ones([1]).astype(np.float32))
+        loss, grads = train_network(data, label, sens)
         grads = F.identity(grads)
         optimizer(grads)
         losses.append(loss)
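Note on the interface change: this patch removes the `sens` constructor argument from `ForwardValueAndGrad`. When `sens_param=True`, the sensitivity tensor is now appended as the last positional input of the call, and `construct` slices it off (`inputs[:-1]`) before running the forward pass. A minimal usage sketch follows; the `nn.Dense` backbone, shapes, and values are illustrative assumptions, only the calling convention is taken from this patch:

```python
# Minimal usage sketch of the revised ForwardValueAndGrad interface.
# The backbone, shapes, and values here are illustrative, not from this patch.
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, ParameterTuple

net = nn.Dense(10, 10)  # hypothetical backbone
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
loss_net = nn.WithLossCell(net, loss_fn)
weights = ParameterTuple(loss_net.trainable_params())

data = Tensor(np.ones([32, 10]).astype(np.float32))
label = Tensor(np.zeros([32]).astype(np.int32))

# Default sens_param=False: a ones_like(outputs) sensitivity is attached
# automatically, so the call takes only the network inputs.
fvg = nn.ForwardValueAndGrad(loss_net, weights=weights, get_by_list=True)
loss, grads = fvg(data, label)

# With sens_param=True, the sensitivity is passed as the last positional
# input of the call (the old sens=1.0 constructor argument is gone).
fvg_sens = nn.ForwardValueAndGrad(loss_net, weights=weights, get_by_list=True,
                                  sens_param=True)
sens = Tensor(np.ones([1]).astype(np.float32))  # as in the updated GPU test
loss, grads = fvg_sens(data, label, sens)
```

Passing the sensitivity through `*inputs` as a real tensor is also what lets the patch drop `P.Fill`-based sens construction from `construct`, and with it the `Tensor` import in cell_wrapper.py.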