diff --git a/mindspore/nn/wrap/loss_scale.py b/mindspore/nn/wrap/loss_scale.py index 1dfc91743c..dce621a765 100644 --- a/mindspore/nn/wrap/loss_scale.py +++ b/mindspore/nn/wrap/loss_scale.py @@ -179,7 +179,7 @@ class TrainOneStepWithLossScaleCell(Cell): network (Cell): The training network. The network only supports single output. optimizer (Cell): Optimizer for updating the weights. scale_sense (Union[Tensor, Cell]): If this value is Cell type, the loss scaling update logic cell.If this value - is Tensor type, Tensor with shape :math:`()`. Default: None. + is Tensor type, Tensor with shape :math:`()`. Inputs: - **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`. @@ -189,6 +189,7 @@ class TrainOneStepWithLossScaleCell(Cell): - **loss** (Tensor) - Tensor with shape :math:`()`. - **overflow** (Tensor) - Tensor with shape :math:`()`, type is bool. + - **loss scaling value** (Tensor) - Tensor with shape :math:`()` Examples: >>> net_with_loss = Net() @@ -203,7 +204,7 @@ class TrainOneStepWithLossScaleCell(Cell): >>> output = train_network(inputs, label, scaling_sens) """ - def __init__(self, network, optimizer, scale_sense=None): + def __init__(self, network, optimizer, scale_sense): super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) self.network = network self.network.set_grad() @@ -236,14 +237,15 @@ class TrainOneStepWithLossScaleCell(Cell): self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE - self.scale_sense = None self.loss_scaling_manager = None if isinstance(scale_sense, Cell): self.loss_scaling_manager = scale_sense self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32), name="scale_sense") - if isinstance(scale_sense, Tensor): + elif isinstance(scale_sense, Tensor): self.scale_sense = Parameter(scale_sense, name='scale_sense') + else: + raise TypeError("The scale_sense must be Cell or Tensor, but got {}".format(type(scale_sense))) @C.add_flags(has_effect=True) def construct(self, *inputs): @@ -293,4 +295,6 @@ class TrainOneStepWithLossScaleCell(Cell): """If the user has set the sens in the training process and wants to reassign the value, he can call this function again to make modification, and sens needs to be of type Tensor.""" if self.scale_sense and isinstance(sens, Tensor): - self.self.scale_sense.set_data(sens) + self.scale_sense.set_data(sens) + else: + raise TypeError("The input type must be Tensor,but got {}".format(type(sens)))