From 16c38f2897a430015399e465fdc670d49d94493c Mon Sep 17 00:00:00 2001
From: liuxiao93
Date: Fri, 17 Jul 2020 11:48:18 +0800
Subject: [PATCH] change ApplyMomentumD->ApplyMomentum for GE.

---
 mindspore/ccsrc/transform/graph_ir/convert.cc    |  2 +-
 mindspore/ccsrc/transform/graph_ir/op_declare.cc | 10 +++++-----
 mindspore/ccsrc/transform/graph_ir/op_declare.h  |  4 ++--
 mindspore/ops/operations/nn_ops.py               |  5 +++--
 4 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/mindspore/ccsrc/transform/graph_ir/convert.cc b/mindspore/ccsrc/transform/graph_ir/convert.cc
index 7419dd2cc9..1978181fec 100644
--- a/mindspore/ccsrc/transform/graph_ir/convert.cc
+++ b/mindspore/ccsrc/transform/graph_ir/convert.cc
@@ -216,7 +216,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
     {string(kNameIOU), ADPT_DESC(Iou)},
     {string(kNameGreaterEqual), ADPT_DESC(GreaterEqual)},
     {string(kNameSlice), ADPT_DESC(SliceD)},
-    {string(kNameApplyMomentum), ADPT_DESC(ApplyMomentumD)},
+    {string(kNameApplyMomentum), ADPT_DESC(ApplyMomentum)},
     {string(kNameMaxPool), ADPT_DESC(MaxPool)},
     {string(kNameAvgPool), ADPT_DESC(AvgPool)},
     {string(kNameMaxPoolWithArgmax), ADPT_DESC(MaxPoolWithArgmax)},
diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare.cc b/mindspore/ccsrc/transform/graph_ir/op_declare.cc
index e3751e0c92..b42519900e 100644
--- a/mindspore/ccsrc/transform/graph_ir/op_declare.cc
+++ b/mindspore/ccsrc/transform/graph_ir/op_declare.cc
@@ -143,12 +143,12 @@ INPUT_MAP(Constant) = EMPTY_INPUT_MAP;
 ATTR_MAP(Constant) = {{"value", ATTR_DESC(value, AnyTraits<AnyValue>())}};
 OUTPUT_MAP(Constant) = {{0, OUTPUT_DESC(y)}};
 
-// ApplyMomentumD
-INPUT_MAP(ApplyMomentumD) = {
+// ApplyMomentum
+INPUT_MAP(ApplyMomentum) = {
   {1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)}, {4, INPUT_DESC(grad)}, {5, INPUT_DESC(momentum)}};
-ATTR_MAP(ApplyMomentumD) = {{"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())},
-                            {"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
-OUTPUT_MAP(ApplyMomentumD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}};
+ATTR_MAP(ApplyMomentum) = {{"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())},
+                           {"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
+OUTPUT_MAP(ApplyMomentum) = {{0, OUTPUT_DESC(var)}};
 
 // ScalarSummary
 INPUT_MAP(Summary) = {{2, INPUT_DESC(x)}};
diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare.h b/mindspore/ccsrc/transform/graph_ir/op_declare.h
index e493ea0e52..b849461d56 100755
--- a/mindspore/ccsrc/transform/graph_ir/op_declare.h
+++ b/mindspore/ccsrc/transform/graph_ir/op_declare.h
@@ -334,8 +334,8 @@ DECLARE_OP_ADAPTER(Assign)
 DECLARE_OP_USE_OUTPUT(Assign)
 DECLARE_OP_ADAPTER(Constant)
 DECLARE_OP_USE_OUTPUT(Constant)
-DECLARE_OP_ADAPTER(ApplyMomentumD)
-DECLARE_OP_USE_OUTPUT(ApplyMomentumD)
+DECLARE_OP_ADAPTER(ApplyMomentum)
+DECLARE_OP_USE_OUTPUT(ApplyMomentum)
 
 // ** Summary Operations **
 DECLARE_OP_ADAPTER(Summary)
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index a9bdf07d28..83e6454a00 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -1615,9 +1615,10 @@ class ApplyMomentum(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['variable', 'accumulation', 'learning_rate', 'gradient', 'momentum'],
                                 outputs=['output'])
         self.is_tbe = context.get_context("device_target") == "Ascend"
+        self.is_ge = context.get_context("enable_ge")
 
     def infer_shape(self, v_shape, a_shape, l_shape, g_shape, m_shape):
-        if self.is_tbe:
+        if not self.is_ge and self.is_tbe:
             return v_shape, v_shape
         return v_shape
 
@@ -1629,7 +1630,7 @@ class ApplyMomentum(PrimitiveWithInfer):
         validator.check_scalar_or_tensor_type_same({"l_dtype": l_dtype}, valid_types, self.name)
         validator.check_scalar_or_tensor_type_same({"g_dtype": g_dtype}, valid_types, self.name)
         validator.check_scalar_or_tensor_type_same({"m_dtype": m_dtype}, valid_types, self.name)
-        if self.is_tbe:
+        if not self.is_ge and self.is_tbe:
             return g_dtype, g_dtype
         return g_dtype
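
Note (not part of the patch): a minimal standalone sketch of the backend dispatch that the nn_ops.py hunks introduce. The function apply_momentum_infer_shape and its keyword arguments are hypothetical names used only for illustration; in MindSpore the flags come from context.get_context("device_target") and context.get_context("enable_ge"). With GE enabled the primitive reports a single output shape, matching the GE ApplyMomentum adapter above whose OUTPUT_MAP keeps only OUTPUT_DESC(var); on TBE (Ascend without GE) it still reports shapes for both var and accum.

# Hypothetical, self-contained illustration; not MindSpore source code.
def apply_momentum_infer_shape(v_shape, is_tbe, is_ge):
    """Mirror the patched ApplyMomentum.infer_shape branch."""
    if not is_ge and is_tbe:
        # TBE kernel exposes two outputs: updated var and updated accum.
        return v_shape, v_shape
    # GE (and other backends): a single output, the updated var.
    return v_shape

if __name__ == "__main__":
    print(apply_momentum_infer_shape([32, 16], is_tbe=True, is_ge=False))  # ([32, 16], [32, 16])
    print(apply_momentum_infer_shape([32, 16], is_tbe=True, is_ge=True))   # [32, 16]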