From fa03a66433201edf4ab26ec482dcba758346f72f Mon Sep 17 00:00:00 2001 From: zhaoting Date: Mon, 13 Apr 2020 10:43:12 +0800 Subject: [PATCH] change adam output numbers adapter to tbe --- mindspore/ops/_op_impl/tbe/apply_adam.py | 46 ++++++++++++++++++++---- mindspore/ops/operations/nn_ops.py | 4 +-- 2 files changed, 41 insertions(+), 9 deletions(-) diff --git a/mindspore/ops/_op_impl/tbe/apply_adam.py b/mindspore/ops/_op_impl/tbe/apply_adam.py index ae6b7d782e..1d5c383515 100644 --- a/mindspore/ops/_op_impl/tbe/apply_adam.py +++ b/mindspore/ops/_op_impl/tbe/apply_adam.py @@ -88,7 +88,8 @@ from mindspore.ops.op_info_register import op_info_register "float16","float16","float16","float16","float","float","float", "float" ], "format": [ - "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ" + "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", + "DefaultFormat", "DefaultFormat" ], "name": "beta1_power", "need_compile": false, @@ -101,7 +102,8 @@ from mindspore.ops.op_info_register import op_info_register "float16","float16","float16","float16","float","float","float","float" ], "format": [ - "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ" + "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", + "DefaultFormat", "DefaultFormat" ], "name": "beta2_power", "need_compile": false, @@ -114,7 +116,8 @@ from mindspore.ops.op_info_register import op_info_register "float16","float16","float16","float16","float","float","float", "float" ], "format": [ - "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ" + "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", + "DefaultFormat", "DefaultFormat" ], "name": "lr", "need_compile": false, @@ -127,7 +130,8 @@ from mindspore.ops.op_info_register import op_info_register "float16","float16","float16","float16","float","float","float", "float" ], "format": [ - "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ" + "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", + "DefaultFormat", "DefaultFormat" ], "name": "beta1", "need_compile": false, @@ -140,7 +144,8 @@ from mindspore.ops.op_info_register import op_info_register "float16","float16","float16","float16","float","float","float", "float" ], "format": [ - "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ" + "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", + "DefaultFormat", "DefaultFormat" ], "name": "beta2", "need_compile": false, @@ -153,7 +158,8 @@ from mindspore.ops.op_info_register import op_info_register "float16","float16","float16","float16","float","float","float", "float" ], "format": [ - "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ" + "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", + "DefaultFormat", "DefaultFormat" ], "name": "epsilon", "need_compile": false, @@ -161,7 +167,7 @@ from mindspore.ops.op_info_register import op_info_register "shape": "all" }, { - "index": 8, + "index": 9, "dtype": [ "float16","float16","float16","float16","float","float","float", "float" ], @@ -187,6 +193,32 @@ from mindspore.ops.op_info_register import op_info_register "need_compile": false, "param_type": "required", "shape": "all" + }, + { + "index": 1, + "dtype": [ + "float16","float16","float16","float16","float","float","float","float" + ], + "format": [ + "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ" + ], + "name": "m", + "need_compile": false, + "param_type": "required", + "shape": "all" + }, + { + "index": 2, + "dtype": [ + "float16","float16","float16","float16","float","float","float","float" + ], + "format": [ + "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ", "NC1HWC0", "C1HWNCoC0", "DefaultFormat", "FracZ" + ], + "name": "v", + "need_compile": false, + "param_type": "required", + "shape": "all" } ] }""") diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 83f76455e0..538d7f3826 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -2149,7 +2149,7 @@ class Adam(PrimitiveWithInfer): validator.check_param_equal("var_shape", var_shape, "m_shape", m_shape) validator.check_param_equal("var_shape", var_shape, "v_shape", v_shape) validator.check_param_equal("var_shape", var_shape, "grad_shape", grad_shape) - return var_shape + return var_shape, m_shape, v_shape def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype, beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype): @@ -2159,7 +2159,7 @@ class Adam(PrimitiveWithInfer): args = {"beta1_power_dtype": beta1_power_dtype, "beta2_power_dtype": beta2_power_dtype, 'lr_dtype': lr_dtype, "beta1_dtype": beta1_dtype, "beta2_dtype": beta2_dtype, "epsilon_dtype": epsilon_dtype} validator.check_type_same(args, [mstype.float16, mstype.float32]) - return var_dtype + return var_dtype, m_dtype, v_dtype class BinaryCrossEntropy(PrimitiveWithInfer):