me ge: use ApplyAdamD

pull/915/head
zhaozhenlong 5 years ago
parent 77cdb89669
commit b79382ab32

@@ -1 +1 @@
-Subproject commit 976d1e31b777d65f87333c3a125093946e682a6e
+Subproject commit 1ab4fa8eb55b4f98e9e5e871a54909a1eaedffd3

@@ -391,6 +391,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_map()
     {string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSProp)}};
 #ifdef ENABLE_GE
   adpt_map[string(kNamePrint)] = ADPT_DESC(Print);
+  adpt_map[string(kNameApplyAdam)] = ADPT_DESC(ApplyAdamD);
 #endif
   return adpt_map;
 }
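Note: get_adpt_map() builds the table that maps front-end op names to GE operator adapters; with ENABLE_GE defined, the ApplyAdam op now resolves to the ApplyAdamD adapter. A rough Python analogue of this registration, purely illustrative (the string values below are placeholders, not MindSpore API):

ENABLE_GE = True  # stands in for the #ifdef ENABLE_GE compile-time switch

adpt_map = {"ApplyCenteredRMSProp": "ApplyCenteredRMSProp adapter"}
if ENABLE_GE:
    adpt_map["Print"] = "Print adapter"
    # With this commit, the front-end ApplyAdam op is lowered through the
    # GE ApplyAdamD adapter, whose OUTPUT_MAP exposes var, m and v.
    adpt_map["ApplyAdam"] = "ApplyAdamD adapter"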

@@ -468,6 +468,15 @@ ATTR_MAP(ApplyAdam) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())
                        {"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())}};
 OUTPUT_MAP(ApplyAdam) = {{0, OUTPUT_DESC(var)}};
+// ApplyAdamD
+INPUT_MAP(ApplyAdamD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)}, {3, INPUT_DESC(v)},
+                         {4, INPUT_DESC(beta1_power)}, {5, INPUT_DESC(beta2_power)}, {6, INPUT_DESC(lr)},
+                         {7, INPUT_DESC(beta1)}, {8, INPUT_DESC(beta2)}, {9, INPUT_DESC(epsilon)},
+                         {10, INPUT_DESC(grad)}};
+ATTR_MAP(ApplyAdamD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())},
+                        {"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())}};
+OUTPUT_MAP(ApplyAdamD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(m)}, {2, OUTPUT_DESC(v)}};
 // Relu6
 INPUT_MAP(Relu6) = {{1, INPUT_DESC(x)}};
 ATTR_MAP(Relu6) = EMPTY_ATTR_MAP;
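For reference, the ten inputs mapped above are exactly the operands of the standard Adam update. Below is a minimal Python sketch of the element-wise rule such a kernel computes; it illustrates the textbook algorithm (including the use_nesterov attribute mapped above), not the fused GE kernel's actual source:

import math

def adam_update(var, m, v, beta1_power, beta2_power, lr,
                beta1, beta2, epsilon, grad, use_nesterov=False):
    """Scalar sketch of one Adam step over the ten ApplyAdamD inputs."""
    m = beta1 * m + (1.0 - beta1) * grad          # first-moment estimate
    v = beta2 * v + (1.0 - beta2) * grad * grad   # second-moment estimate
    lr_t = lr * math.sqrt(1.0 - beta2_power) / (1.0 - beta1_power)
    if use_nesterov:
        update = (beta1 * m + (1.0 - beta1) * grad) / (math.sqrt(v) + epsilon)
    else:
        update = m / (math.sqrt(v) + epsilon)
    var = var - lr_t * update
    return var, m, v  # ApplyAdamD likewise outputs var, m and v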

@@ -124,6 +124,8 @@ DECLARE_OP_ADAPTER(ResizeNearestNeighborV2Grad)
 DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborV2Grad)
 DECLARE_OP_ADAPTER(ApplyAdam)
 DECLARE_OP_USE_OUTPUT(ApplyAdam)
+DECLARE_OP_ADAPTER(ApplyAdamD)
+DECLARE_OP_USE_OUTPUT(ApplyAdamD)
 DECLARE_OP_ADAPTER(Relu6)
 DECLARE_OP_USE_OUTPUT(Relu6)
 DECLARE_OP_ADAPTER(Relu6Grad)

@@ -2323,7 +2323,11 @@ class Adam(PrimitiveWithInfer):
         - **gradient** (Tensor) - Gradients.

     Outputs:
+        Tuple of 3 Tensors, the updated parameters.
+
         - **var** (Tensor) - The same shape and data type as `var`.
+        - **m** (Tensor) - The same shape and data type as `m`.
+        - **v** (Tensor) - The same shape and data type as `v`.
     """

     @prim_attr_register
@@ -2336,7 +2340,7 @@ class Adam(PrimitiveWithInfer):
         validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
         validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
         validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
-        return var_shape
+        return var_shape, m_shape, v_shape

     def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype,
                     beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype):
@@ -2346,7 +2350,7 @@ class Adam(PrimitiveWithInfer):
         args = {"beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, 'lr': lr_dtype,
                 "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype}
         validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name, True)
-        return var_dtype
+        return var_dtype, m_dtype, v_dtype


 class BinaryCrossEntropy(PrimitiveWithInfer):
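With infer_shape and infer_dtype now returning three-element tuples, a call to the Adam primitive yields all three updated tensors instead of only var. A minimal usage sketch, assuming an Ascend/GE backend and PyNative-style execution (shapes and hyperparameter values are illustrative):

import numpy as np
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P

adam = P.Adam()
shape = (2, 2)
var = Parameter(Tensor(np.ones(shape, np.float32)), name="var")
m = Parameter(Tensor(np.ones(shape, np.float32)), name="m")
v = Parameter(Tensor(np.ones(shape, np.float32)), name="v")
grad = Tensor(np.ones(shape, np.float32))
# Ten positional inputs: var, m, v, beta1_power, beta2_power, lr,
# beta1, beta2, epsilon, gradient -- matching the adapter's INPUT_MAP.
out_var, out_m, out_v = adam(var, m, v, 0.9, 0.999, 0.001, 0.9, 0.999,
                             1e-8, grad)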
