From 7c2c6da9a7d1948db0d10701e12a821290e4e75d Mon Sep 17 00:00:00 2001 From: zhaoting Date: Thu, 7 May 2020 20:22:55 +0800 Subject: [PATCH] adapte FusedMulApplyMomentum --- .../_op_impl/tbe/fused_mul_apply_momentum.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py b/mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py index a8f84427d6..14e4323c6f 100644 --- a/mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py +++ b/mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py @@ -31,22 +31,23 @@ fused_mul_apply_momentum_op_info = TBERegOp("FusedMulApplyMomentum") \ .input(4, "momentum", False, "required", "all") \ .input(5, "x2", False, "required", "all") \ .output(0, "var", False, "required", "all") \ + .output(1, "accum", False, "required", "all") \ .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_5HD, - DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD) \ + DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_C1HWNCoC0, - DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \ + DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \ .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, - DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ, - DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ) \ + DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \ .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD, - DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD) \ + DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_C1HWNCoC0, - DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \ + DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \ .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, - DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ, - DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ) \ + DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \ .get_op_info()