!281 Change op ApplyMomentumD to ApplyMomentum for GE in Incubator.

Merge pull request !281 from liuxiao93/ApplyMomentumD-ApplyMoment-GE-incabator
Committed by mindspore-ci-bot via Gitee
commit ad13048335

@@ -216,7 +216,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
 {string(kNameIOU), ADPT_DESC(Iou)},
 {string(kNameGreaterEqual), ADPT_DESC(GreaterEqual)},
 {string(kNameSlice), ADPT_DESC(SliceD)},
-{string(kNameApplyMomentum), ADPT_DESC(ApplyMomentumD)},
+{string(kNameApplyMomentum), ADPT_DESC(ApplyMomentum)},
 {string(kNameMaxPool), ADPT_DESC(MaxPool)},
 {string(kNameAvgPool), ADPT_DESC(AvgPool)},
 {string(kNameMaxPoolWithArgmax), ADPT_DESC(MaxPoolWithArgmax)},
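
For orientation, a minimal sketch of what this table change means (a plain Python dict with invented names, not MindSpore's real adapter machinery): the converter looks up a front-end op name and emits the mapped GE operator, so ApplyMomentum is now lowered to the plain ApplyMomentum op instead of the fused ApplyMomentumD.

# Hypothetical stand-in for the adapter table above: front-end op name -> GE op name.
GE_ADAPTER_TABLE = {
    "Slice": "SliceD",
    "ApplyMomentum": "ApplyMomentum",  # previously mapped to "ApplyMomentumD"
    "MaxPool": "MaxPool",
}

def lower_to_ge(op_name: str) -> str:
    """Return the GE operator a front-end op is converted to (illustration only)."""
    return GE_ADAPTER_TABLE[op_name]

assert lower_to_ge("ApplyMomentum") == "ApplyMomentum"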

@@ -143,12 +143,12 @@ INPUT_MAP(Constant) = EMPTY_INPUT_MAP;
 ATTR_MAP(Constant) = {{"value", ATTR_DESC(value, AnyTraits<AnyValue>())}};
 OUTPUT_MAP(Constant) = {{0, OUTPUT_DESC(y)}};
-// ApplyMomentumD
-INPUT_MAP(ApplyMomentumD) = {
+// ApplyMomentum
+INPUT_MAP(ApplyMomentum) = {
 {1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)}, {4, INPUT_DESC(grad)}, {5, INPUT_DESC(momentum)}};
-ATTR_MAP(ApplyMomentumD) = {{"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())},
+ATTR_MAP(ApplyMomentum) = {{"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())},
 {"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
-OUTPUT_MAP(ApplyMomentumD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}};
+OUTPUT_MAP(ApplyMomentum) = {{0, OUTPUT_DESC(var)}};
 // ScalarSummary
 INPUT_MAP(Summary) = {{2, INPUT_DESC(x)}};
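
The substantive change in this hunk is the output map: the fused ApplyMomentumD exposed two outputs (var and accum), while the plain ApplyMomentum used for GE exposes only the updated variable. A small sketch of the index-to-name maps as plain Python dicts (illustrative only; the real code uses the INPUT_MAP/ATTR_MAP/OUTPUT_MAP macros above):

# Input indices are 1-based and output indices 0-based, mirroring the maps above.
APPLY_MOMENTUM_INPUT_MAP = {1: "var", 2: "accum", 3: "lr", 4: "grad", 5: "momentum"}
APPLY_MOMENTUM_ATTR_MAP = {"use_nesterov": bool, "use_locking": bool}

# Old fused op: two outputs.
APPLY_MOMENTUM_D_OUTPUT_MAP = {0: "var", 1: "accum"}
# New op used for GE: a single output, the updated variable.
APPLY_MOMENTUM_OUTPUT_MAP = {0: "var"}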

@@ -334,8 +334,8 @@ DECLARE_OP_ADAPTER(Assign)
 DECLARE_OP_USE_OUTPUT(Assign)
 DECLARE_OP_ADAPTER(Constant)
 DECLARE_OP_USE_OUTPUT(Constant)
-DECLARE_OP_ADAPTER(ApplyMomentumD)
-DECLARE_OP_USE_OUTPUT(ApplyMomentumD)
+DECLARE_OP_ADAPTER(ApplyMomentum)
+DECLARE_OP_USE_OUTPUT(ApplyMomentum)
 // ** Summary Operations **
 DECLARE_OP_ADAPTER(Summary)

@@ -1615,9 +1615,10 @@ class ApplyMomentum(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['variable', 'accumulation', 'learning_rate', 'gradient', 'momentum'],
                                 outputs=['output'])
         self.is_tbe = context.get_context("device_target") == "Ascend"
+        self.is_ge = context.get_context("enable_ge")

     def infer_shape(self, v_shape, a_shape, l_shape, g_shape, m_shape):
-        if self.is_tbe:
+        if not self.is_ge and self.is_tbe:
             return v_shape, v_shape
         return v_shape
@@ -1629,7 +1630,7 @@ class ApplyMomentum(PrimitiveWithInfer):
         validator.check_scalar_or_tensor_type_same({"l_dtype": l_dtype}, valid_types, self.name)
         validator.check_scalar_or_tensor_type_same({"g_dtype": g_dtype}, valid_types, self.name)
         validator.check_scalar_or_tensor_type_same({"m_dtype": m_dtype}, valid_types, self.name)
-        if self.is_tbe:
+        if not self.is_ge and self.is_tbe:
             return g_dtype, g_dtype
         return g_dtype
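
The effect of the new is_ge guard: with GE enabled, ApplyMomentum reports a single output (var), matching the single-output OUTPUT_MAP above; on TBE without GE it still reports two outputs (var, accum). A minimal, self-contained sketch of that branch (free-standing helper name invented for illustration; not the actual MindSpore primitive):

def infer_apply_momentum_shape(v_shape, is_ge, is_tbe):
    """Shape(s) ApplyMomentum reports, mirroring the branch in the diff above."""
    if not is_ge and is_tbe:
        # TBE on Ascend without GE: two outputs, var and accum, both shaped like var.
        return v_shape, v_shape
    # GE (and other backends): a single output shaped like var.
    return v_shape


print(infer_apply_momentum_shape([4, 8], is_ge=True, is_tbe=True))    # [4, 8]
print(infer_apply_momentum_shape([4, 8], is_ge=False, is_tbe=True))   # ([4, 8], [4, 8])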
