diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py
index 024f93f5d3..3d61ba1aac 100755
--- a/mindspore/nn/cell.py
+++ b/mindspore/nn/cell.py
@@ -374,9 +374,13 @@ class Cell:
             inputs (Function or Cell): inputs of construct method.
         """
         parallel_inputs_run = []
-        if len(inputs) > self._construct_inputs_num:
-            raise ValueError('Len of inputs: {} is bigger than self._construct_inputs_num: {}.'.
-                             format(len(inputs), self._construct_inputs_num))
+        # check whether construct declares *args in its inputs
+        if self.argspec[1] is not None:
+            prefix = self.argspec[1]
+            for i in range(len(inputs)):
+                key = prefix + str(i)
+                self._construct_inputs_names = self._construct_inputs_names + (key,)
+                self._construct_inputs_num = self._construct_inputs_num + 1
         for i, tensor in enumerate(inputs):
             key = self._construct_inputs_names[i]
             # if input is not used, self.parameter_layout_dict may not contain the key
@@ -404,7 +408,7 @@ class Cell:
         from mindspore._extends.parse.parser import get_parse_method_of_class
 
         fn = get_parse_method_of_class(self)
-        inspect.getfullargspec(fn)
+        self.argspec = inspect.getfullargspec(fn)
         self._construct_inputs_num = fn.__code__.co_argcount
         self._construct_inputs_names = fn.__code__.co_varnames
 
diff --git a/mindspore/nn/optim/momentum.py b/mindspore/nn/optim/momentum.py
index 014cc8f823..7781e52d57 100755
--- a/mindspore/nn/optim/momentum.py
+++ b/mindspore/nn/optim/momentum.py
@@ -41,7 +41,7 @@ def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment,
 
 
 class Momentum(Optimizer):
-    """
+    r"""
     Implements the Momentum algorithm.
 
     Refer to the paper on the importance of initialization and momentum in deep learning for more details.
@@ -56,13 +56,13 @@ class Momentum(Optimizer):
 
     .. math::
             v_{t} = v_{t-1} \ast u + gradients
 
-    If use_nesterov is True:
-    .. math::
-            p_{t} = p_{t-1} - (grad \ast lr + v_{t} \ast u \ast lr)
-    If use_nesterov is Flase:
-    .. math::
-            p_{t} = p_{t-1} - lr \ast v_{t}
+    If use_nesterov is True:
+        .. math::
+                p_{t} = p_{t-1} - (grad \ast lr + v_{t} \ast u \ast lr)
+    If use_nesterov is False:
+        .. math::
+                p_{t} = p_{t-1} - lr \ast v_{t}
 
     Here: where grad, lr, p, v and u denote the gradients, learning_rate, params, moments, and momentum respectively.
 
diff --git a/mindspore/nn/optim/sgd.py b/mindspore/nn/optim/sgd.py
index 1d79cfad42..e684fae22f 100755
--- a/mindspore/nn/optim/sgd.py
+++ b/mindspore/nn/optim/sgd.py
@@ -32,7 +32,7 @@ def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, accum, s
 
 
 class SGD(Optimizer):
-    """
+    r"""
     Implements stochastic gradient descent (optionally with momentum).
 
     Introduction to SGD can be found at https://en.wikipedia.org/wiki/Stochastic_gradient_descent.
@@ -47,15 +47,15 @@ class SGD(Optimizer):
 
     To improve parameter groups performance, the customized order of parameters can be supported.
 
     .. math::
-            v_{t+1} = u \ast v_{t} + gradient \ast (1-dampening)
+        v_{t+1} = u \ast v_{t} + gradient \ast (1-dampening)
 
-    If nesterov is True:
-    .. math::
-            p_{t+1} = p_{t} - lr \ast (gradient + u \ast v_{t+1})
-    If nesterov is Flase:
-    .. math::
-            p_{t+1} = p_{t} - lr \ast v_{t+1}
+    If nesterov is True:
+        .. math::
+            p_{t+1} = p_{t} - lr \ast (gradient + u \ast v_{t+1})
+    If nesterov is False:
+        .. math::
+            p_{t+1} = p_{t} - lr \ast v_{t+1}
 
     To be noticed, for the first step, v_{t+1} = gradient
 
diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py
index a7fb4adcf0..980585e270 100644
--- a/mindspore/nn/wrap/cell_wrapper.py
+++ b/mindspore/nn/wrap/cell_wrapper.py
@@ -82,7 +82,7 @@ class WithGradCell(Cell):
     Wraps the network with backward cell to compute gradients. A network with a loss
     function is necessary as argument. If loss function in None, the network must be
    a wrapper of network and loss function. This
-    Cell accepts data and label as inputs and returns gradients for each trainable parameter.
+    Cell accepts *inputs as inputs and returns gradients for each trainable parameter.
 
     Note:
         Run in PyNative mode.
@@ -95,8 +95,7 @@ class WithGradCell(Cell):
             output value. Default: None.
 
     Inputs:
-        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
-        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
+        - **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
 
     Outputs:
         list, a list of Tensors with identical shapes as trainable weights.
@@ -126,12 +125,12 @@ class WithGradCell(Cell):
             self.network_with_loss = WithLossCell(self.network, self.loss_fn)
         self.network_with_loss.set_train()
 
-    def construct(self, data, label):
+    def construct(self, *inputs):
         weights = self.weights
         if self.sens is None:
-            grads = self.grad(self.network_with_loss, weights)(data, label)
+            grads = self.grad(self.network_with_loss, weights)(*inputs)
         else:
-            grads = self.grad(self.network_with_loss, weights)(data, label, self.sens)
+            grads = self.grad(self.network_with_loss, weights)(*inputs, self.sens)
         return grads
 
 
@@ -139,7 +138,7 @@ class TrainOneStepCell(Cell):
     r"""
     Network training package class.
 
-    Wraps the network with an optimizer. The resulting Cell be trained with input data and label.
+    Wraps the network with an optimizer. The resulting Cell can be trained with input *inputs.
     Backward graph will be created in the construct function to do parameter updating. Different
     parallel modes are available to run the training.
 
@@ -149,8 +148,7 @@ class TrainOneStepCell(Cell):
         sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
 
     Inputs:
-        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
-        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
+        - **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
 
     Outputs:
         Tensor, a scalar Tensor with shape :math:`()`.
 
@@ -181,11 +179,11 @@ class TrainOneStepCell(Cell):
             degree = _get_device_num()
             self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
 
-    def construct(self, data, label):
+    def construct(self, *inputs):
         weights = self.weights
-        loss = self.network(data, label)
+        loss = self.network(*inputs)
         sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
-        grads = self.grad(self.network, weights)(data, label, sens)
+        grads = self.grad(self.network, weights)(*inputs, sens)
         if self.reducer_flag:
             # apply grad reducer on grads
             grads = self.grad_reducer(grads)
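For reference (not part of the patch): a minimal, MindSpore-free sketch of the naming scheme that the new _load_inputs logic in cell.py relies on. The helper name synthesize_input_names is hypothetical; it only mirrors the argspec[1] (varargs) check and the prefix + str(i) key generation added above, using the same co_argcount/co_varnames values that _get_construct_inputs_number_and_name reads.

import inspect


def synthesize_input_names(fn, inputs, names, num):
    # Sketch of the *args handling added to Cell._load_inputs: when construct
    # declares *args, extend the known input names with one synthetic name per
    # positional input, built from the varargs prefix ('args0', 'args1', ...).
    argspec = inspect.getfullargspec(fn)
    if argspec.varargs is not None:      # same check as self.argspec[1] in the patch
        prefix = argspec.varargs
        for i in range(len(inputs)):
            names = names + (prefix + str(i),)
            num = num + 1
    return names, num


def construct(self, *args):              # toy stand-in for a Cell.construct signature
    return args


# co_argcount / co_varnames are what _get_construct_inputs_number_and_name caches.
names, num = synthesize_input_names(
    construct, inputs=(1, 2, 3),
    names=construct.__code__.co_varnames, num=construct.__code__.co_argcount)
print(names, num)  # ('self', 'args', 'args0', 'args1', 'args2') 4

Called with three positional inputs, the sketch produces the synthetic names args0..args2, which is what lets a Cell whose construct takes *args look up a per-input parallel layout (and lets WithGradCell/TrainOneStepCell forward an arbitrary number of tensors) instead of being limited to a fixed data/label pair.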