From b79382ab324a56d47a81c01bb3b9767bd1617715 Mon Sep 17 00:00:00 2001
From: zhaozhenlong
Date: Wed, 29 Apr 2020 12:42:06 +0800
Subject: [PATCH] me ge use ApplyAdamD

---
 graphengine                             | 2 +-
 mindspore/ccsrc/transform/convert.cc    | 1 +
 mindspore/ccsrc/transform/op_declare.cc | 9 +++++++++
 mindspore/ccsrc/transform/op_declare.h  | 2 ++
 mindspore/ops/operations/nn_ops.py      | 8 ++++++--
 5 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/graphengine b/graphengine
index 976d1e31b7..1ab4fa8eb5 160000
--- a/graphengine
+++ b/graphengine
@@ -1 +1 @@
-Subproject commit 976d1e31b777d65f87333c3a125093946e682a6e
+Subproject commit 1ab4fa8eb55b4f98e9e5e871a54909a1eaedffd3
diff --git a/mindspore/ccsrc/transform/convert.cc b/mindspore/ccsrc/transform/convert.cc
index da70f402e6..d15f4ae43a 100755
--- a/mindspore/ccsrc/transform/convert.cc
+++ b/mindspore/ccsrc/transform/convert.cc
@@ -391,6 +391,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
     {string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSProp)}};
 #ifdef ENABLE_GE
   adpt_map[string(kNamePrint)] = ADPT_DESC(Print);
+  adpt_map[string(kNameApplyAdam)] = ADPT_DESC(ApplyAdamD);
 #endif
   return adpt_map;
 }
diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc
index 8407e46519..d6da49f85d 100644
--- a/mindspore/ccsrc/transform/op_declare.cc
+++ b/mindspore/ccsrc/transform/op_declare.cc
@@ -468,6 +468,15 @@ ATTR_MAP(ApplyAdam) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())
                        {"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())}};
 OUTPUT_MAP(ApplyAdam) = {{0, OUTPUT_DESC(var)}};
 
+// ApplyAdamD
+INPUT_MAP(ApplyAdamD) = {{1, INPUT_DESC(var)},         {2, INPUT_DESC(m)},           {3, INPUT_DESC(v)},
+                         {4, INPUT_DESC(beta1_power)}, {5, INPUT_DESC(beta2_power)}, {6, INPUT_DESC(lr)},
+                         {7, INPUT_DESC(beta1)},       {8, INPUT_DESC(beta2)},       {9, INPUT_DESC(epsilon)},
+                         {10, INPUT_DESC(grad)}};
+ATTR_MAP(ApplyAdamD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())},
+                        {"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())}};
+OUTPUT_MAP(ApplyAdamD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(m)}, {2, OUTPUT_DESC(v)}};
+
 // Relu6
 INPUT_MAP(Relu6) = {{1, INPUT_DESC(x)}};
 ATTR_MAP(Relu6) = EMPTY_ATTR_MAP;
diff --git a/mindspore/ccsrc/transform/op_declare.h b/mindspore/ccsrc/transform/op_declare.h
index cfec43a43c..ccc6578a61 100755
--- a/mindspore/ccsrc/transform/op_declare.h
+++ b/mindspore/ccsrc/transform/op_declare.h
@@ -124,6 +124,8 @@ DECLARE_OP_ADAPTER(ResizeNearestNeighborV2Grad)
 DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborV2Grad)
 DECLARE_OP_ADAPTER(ApplyAdam)
 DECLARE_OP_USE_OUTPUT(ApplyAdam)
+DECLARE_OP_ADAPTER(ApplyAdamD)
+DECLARE_OP_USE_OUTPUT(ApplyAdamD)
 DECLARE_OP_ADAPTER(Relu6)
 DECLARE_OP_USE_OUTPUT(Relu6)
 DECLARE_OP_ADAPTER(Relu6Grad)
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 6371fc4654..0687806bb2 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -2323,7 +2323,11 @@ class Adam(PrimitiveWithInfer):
         - **gradient** (Tensor) - Gradients.
 
     Outputs:
+        Tuple of 3 Tensor, the updated parameters.
+        - **var** (Tensor) - The same shape and data type as `var`.
+        - **m** (Tensor) - The same shape and data type as `m`.
+        - **v** (Tensor) - The same shape and data type as `v`.
     """
 
     @prim_attr_register
@@ -2336,7 +2340,7 @@ class Adam(PrimitiveWithInfer):
         validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
         validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
         validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
-        return var_shape
+        return var_shape, m_shape, v_shape
 
     def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype,
                     beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype):
@@ -2346,7 +2350,7 @@ class Adam(PrimitiveWithInfer):
         args = {"beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, 'lr': lr_dtype,
                 "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype}
         validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name, True)
-        return var_dtype
+        return var_dtype, m_dtype, v_dtype
 
 
 class BinaryCrossEntropy(PrimitiveWithInfer):