!4058 modify parameter input

Merge pull request !4058 from lijiaqi/cell_inputs
mindspore-ci-bot committed 5 years ago via Gitee
commit cfae4096d2

@@ -383,9 +383,13 @@ class Cell:
             inputs (Function or Cell): inputs of construct method.
         """
         parallel_inputs_run = []
-        if len(inputs) > self._construct_inputs_num:
-            raise ValueError('Len of inputs: {} is bigger than self._construct_inputs_num: {}.'.
-                             format(len(inputs), self._construct_inputs_num))
+        # judge if *args exists in input
+        if self.argspec[1] is not None:
+            prefix = self.argspec[1]
+            for i in range(len(inputs)):
+                key = prefix + str(i)
+                self._construct_inputs_names = self._construct_inputs_names + (key,)
+                self._construct_inputs_num = self._construct_inputs_num + 1
         for i, tensor in enumerate(inputs):
             key = self._construct_inputs_names[i]
             # if input is not used, self.parameter_layout_dict may not contain the key
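For context, a minimal standalone sketch (toy names, not MindSpore internals) of what the new branch does: inspect.getfullargspec returns a named tuple whose second field, varargs, holds the name of the *args parameter or None, and the loop mints one synthetic input name per runtime input.

import inspect

def construct(self, x, y, *args):
    pass

argspec = inspect.getfullargspec(construct)
print(argspec[1])  # 'args': argspec.varargs, None when no *args is declared

# Mirror of the new loop: extend the recorded input names with one
# synthetic entry per runtime input when *args is present.
construct_inputs_names = ('self', 'x', 'y')
inputs = ('t0', 't1', 't2', 't3')  # stand-ins for runtime tensors
if argspec[1] is not None:
    prefix = argspec[1]
    for i in range(len(inputs)):
        construct_inputs_names += (prefix + str(i),)
print(construct_inputs_names)
# ('self', 'x', 'y', 'args0', 'args1', 'args2', 'args3')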
@@ -412,7 +416,7 @@ class Cell:
         from mindspore._extends.parse.parser import get_parse_method_of_class
         fn = get_parse_method_of_class(self)
-        inspect.getfullargspec(fn)
+        self.argspec = inspect.getfullargspec(fn)
         self._construct_inputs_num = fn.__code__.co_argcount
         self._construct_inputs_names = fn.__code__.co_varnames
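The removed line computed the argspec but discarded the result; storing it in self.argspec matters because co_argcount counts only named parameters, so a *args parameter is invisible to the two lines below it. A quick illustration (toy class, not the real Cell):

import inspect

class Demo:
    def construct(self, x, *inputs):
        return x

fn = Demo.construct
argspec = inspect.getfullargspec(fn)
print(argspec.varargs)              # 'inputs': only the argspec sees *args
print(fn.__code__.co_argcount)      # 2: counts self and x, not *inputs
print(fn.__code__.co_varnames[:3])  # ('self', 'x', 'inputs')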

@@ -41,7 +41,7 @@ def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment,
 class Momentum(Optimizer):
-    """
+    r"""
     Implements the Momentum algorithm.

     Refer to the paper on the importance of initialization and momentum in deep learning for more details.

@@ -32,7 +32,7 @@ def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, accum, s
 class SGD(Optimizer):
-    """
+    r"""
     Implements stochastic gradient descent (optionally with momentum).

     Introduction to SGD can be found at https://en.wikipedia.org/wiki/Stochastic_gradient_descent.
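Both optimizer hunks make the same one-character change: a plain docstring becomes a raw one. The likely motivation (an inference, not stated in the diff) is that these docstrings embed LaTeX for Sphinx's :math: role, and in a plain string every backslash starts an escape sequence, while a raw string keeps it verbatim. A small illustration:

plain = "shape (N, \\ldots)"   # plain string: the backslash must be doubled
raw = r"shape (N, \ldots)"     # raw string: reads exactly as rendered
assert plain == raw
# An undoubled "\l" in a plain string is an invalid escape sequence and
# emits a DeprecationWarning on modern Python.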

@@ -82,7 +82,7 @@ class WithGradCell(Cell):
     Wraps the network with backward cell to compute gradients. A network with a loss function is necessary
     as argument. If loss function is None, the network must be a wrapper of network and loss function. This
-    Cell accepts data and label as inputs and returns gradients for each trainable parameter.
+    Cell accepts *inputs as inputs and returns gradients for each trainable parameter.

     Note:
         Run in PyNative mode.
@@ -95,8 +95,7 @@ class WithGradCell(Cell):
             output value. Default: None.

     Inputs:
-        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
-        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
+        - **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.

     Outputs:
         list, a list of Tensors with identical shapes as trainable weights.
@@ -126,12 +125,12 @@ class WithGradCell(Cell):
             self.network_with_loss = WithLossCell(self.network, self.loss_fn)
         self.network_with_loss.set_train()

-    def construct(self, data, label):
+    def construct(self, *inputs):
         weights = self.weights
         if self.sens is None:
-            grads = self.grad(self.network_with_loss, weights)(data, label)
+            grads = self.grad(self.network_with_loss, weights)(*inputs)
         else:
-            grads = self.grad(self.network_with_loss, weights)(data, label, self.sens)
+            grads = self.grad(self.network_with_loss, weights)(*inputs, self.sens)
         return grads
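A hedged usage sketch of the new signature (the network and shapes are illustrative placeholders, not part of this PR): callers now pass any number of tensors, and WithGradCell forwards them all to the wrapped network-with-loss.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

net = MyNet()  # hypothetical network whose construct takes one data tensor
loss_fn = nn.SoftmaxCrossEntropyWithLogits()
grad_cell = nn.WithGradCell(net, loss_fn)

data = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32))
label = Tensor(np.ones([32, 10]).astype(np.float32))
grads = grad_cell(data, label)  # both tensors travel through *inputs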
@@ -139,7 +138,7 @@ class TrainOneStepCell(Cell):
     r"""
     Network training package class.

-    Wraps the network with an optimizer. The resulting Cell can be trained with input data and label.
+    Wraps the network with an optimizer. The resulting Cell can be trained with *inputs.
     Backward graph will be created in the construct function to do parameter updating. Different
     parallel modes are available to run the training.
@@ -149,8 +148,7 @@ class TrainOneStepCell(Cell):
         sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.

     Inputs:
-        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
-        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
+        - **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.

     Outputs:
         Tensor, a scalar Tensor with shape :math:`()`.
@@ -181,11 +179,11 @@ class TrainOneStepCell(Cell):
             degree = _get_device_num()
             self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)

-    def construct(self, data, label):
+    def construct(self, *inputs):
         weights = self.weights
-        loss = self.network(data, label)
+        loss = self.network(*inputs)
         sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
-        grads = self.grad(self.network, weights)(data, label, sens)
+        grads = self.grad(self.network, weights)(*inputs, sens)
         if self.reducer_flag:
             # apply grad reducer on grads
             grads = self.grad_reducer(grads)
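The same pattern applies to training. A hedged sketch reusing the placeholder names from the WithGradCell example above: TrainOneStepCell's construct now unpacks whatever it receives into both the forward call and the gradient call.

optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
train_net = nn.TrainOneStepCell(nn.WithLossCell(net, loss_fn), optimizer)
train_net.set_train()
loss = train_net(data, label)  # data and label arrive as *inputs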
