diff --git a/mindspore/ccsrc/device/kernel_runtime.cc b/mindspore/ccsrc/device/kernel_runtime.cc
index d3fccc11fd..e77d348630 100644
--- a/mindspore/ccsrc/device/kernel_runtime.cc
+++ b/mindspore/ccsrc/device/kernel_runtime.cc
@@ -201,6 +201,7 @@ void KernelRuntime::RunOpAssignOutputMemory(const AnfNodePtr &kernel) {
   if (AnfAlgo::GetCNodeName(kernel) == "ApplyMomentum") {
     auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, 0);
     AnfAlgo::SetOutputAddr(device_address, 0, kernel.get());
+    AnfAlgo::SetOutputAddr(device_address, 1, kernel.get());
     return;
   }
 
diff --git a/mindspore/ops/_op_impl/tbe/apply_momentum.py b/mindspore/ops/_op_impl/tbe/apply_momentum.py
index 920a48e48d..deb8f0d387 100644
--- a/mindspore/ops/_op_impl/tbe/apply_momentum.py
+++ b/mindspore/ops/_op_impl/tbe/apply_momentum.py
@@ -29,23 +29,24 @@ apply_momentum_op_info = TBERegOp("ApplyMomentum") \
     .input(2, "lr", False, "required", "all") \
     .input(3, "grad", False, "required", "all") \
     .input(4, "momentum", False, "required", "all") \
-    .output(0, "out", False, "required", "all") \
+    .output(0, "var", False, "required", "all") \
+    .output(1, "accum", False, "required", "all") \
     .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
-                  DataType.F16_Default, DataType.F16_Default) \
+                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
     .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_5HD,
-                  DataType.F16_Default, DataType.F16_5HD) \
+                  DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD) \
     .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_C1HWNCoC0,
-                  DataType.F16_Default, DataType.F16_C1HWNCoC0) \
+                  DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
     .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ,
-                  DataType.F16_Default, DataType.F16_FracZ) \
+                  DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_Default, DataType.F32_Default) \
+                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
     .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD,
-                  DataType.F32_Default, DataType.F32_5HD) \
+                  DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD) \
     .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_C1HWNCoC0,
-                  DataType.F32_Default, DataType.F32_C1HWNCoC0) \
+                  DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
     .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ,
-                  DataType.F32_Default, DataType.F32_FracZ) \
+                  DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ) \
     .get_op_info()
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index fb8e908924..2a2dbe08a8 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -1427,8 +1427,11 @@ class ApplyMomentum(PrimitiveWithInfer):
     def __init__(self, use_nesterov=False, use_locking=False, gradient_scale=1.0):
         self.init_prim_io_names(inputs=['variable', 'accumulation', 'learning_rate', 'gradient', 'momentum'],
                                 outputs=['output'])
+        self.is_tbe = context.get_context("device_target") == "Ascend"
 
     def infer_shape(self, v_shape, a_shape, l_shape, g_shape, m_shape):
+        if self.is_tbe:
+            return v_shape, v_shape
         return v_shape
 
     def infer_dtype(self, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype):
@@ -1439,6 +1442,8 @@ class ApplyMomentum(PrimitiveWithInfer):
         validator.check_scalar_or_tensor_type_same({"l_dtype": l_dtype}, valid_types, self.name)
         validator.check_scalar_or_tensor_type_same({"g_dtype": g_dtype}, valid_types, self.name)
         validator.check_scalar_or_tensor_type_same({"m_dtype": m_dtype}, valid_types, self.name)
+        if self.is_tbe:
+            return g_dtype, g_dtype
         return g_dtype
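
A minimal usage sketch (not part of the patch; names, shapes, and hyperparameters are illustrative): with the TBE registration above, ApplyMomentum on an Ascend target infers two outputs (updated var and accum), so a call site unpacks a tuple, while other targets still see a single output.

# Sketch assuming an Ascend device is available; the call order follows
# init_prim_io_names above: (variable, accumulation, learning_rate, gradient, momentum).
import numpy as np
import mindspore as ms
from mindspore import Tensor, Parameter, context
from mindspore.ops import operations as P

context.set_context(device_target="Ascend")

apply_momentum = P.ApplyMomentum()
var = Parameter(Tensor(np.ones([2, 2]), ms.float32), name="var")
accum = Parameter(Tensor(np.zeros([2, 2]), ms.float32), name="accum")
grad = Tensor(np.ones([2, 2]), ms.float32)

# On Ascend (TBE) the primitive now yields (var, accum); elsewhere it yields var only.
out_var, out_accum = apply_momentum(var, accum, 0.01, grad, 0.9)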