|
|
|
@ -180,7 +180,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
|
|
|
|
|
network (Cell): The training network. The network only supports single output.
|
|
|
|
|
optimizer (Cell): Optimizer for updating the weights.
|
|
|
|
|
scale_sense (Union[Tensor, Cell]): If this value is Cell type, the loss scaling update logic cell.If this value
|
|
|
|
|
is Tensor type, Tensor with shape :math:`()`.
|
|
|
|
|
is Tensor type, Tensor with shape :math:`()` or :math:`(1,)`.
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
|
|
|
|
@ -230,7 +230,10 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
|
|
|
|
|
self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32),
|
|
|
|
|
name="scale_sense")
|
|
|
|
|
elif isinstance(scale_sense, Tensor):
|
|
|
|
|
self.scale_sense = Parameter(scale_sense, name='scale_sense')
|
|
|
|
|
if scale_sense.shape == (1,) or scale_sense.shape == ():
|
|
|
|
|
self.scale_sense = Parameter(scale_sense, name='scale_sense')
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError("The shape of scale_sense must be (1,) or (), but got {}".format(scale_sense.shape))
|
|
|
|
|
else:
|
|
|
|
|
raise TypeError("The scale_sense must be Cell or Tensor, but got {}".format(type(scale_sense)))
|
|
|
|
|
|
|
|
|
@ -284,4 +287,4 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
|
|
|
|
|
if self.scale_sense and isinstance(sens, Tensor):
|
|
|
|
|
self.scale_sense.set_data(sens)
|
|
|
|
|
else:
|
|
|
|
|
raise TypeError("The input type must be Tensor,but got {}".format(type(sens)))
|
|
|
|
|
raise TypeError("The input type must be Tensor, but got {}".format(type(sens)))
|
|
|
|
|