|
|
@ -2323,7 +2323,11 @@ class Adam(PrimitiveWithInfer):
|
|
|
|
- **gradient** (Tensor) - Gradients.
|
|
|
|
- **gradient** (Tensor) - Gradients.
|
|
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
Outputs:
|
|
|
|
|
|
|
|
Tuple of 3 Tensor, the updated parameters.
|
|
|
|
|
|
|
|
|
|
|
|
- **var** (Tensor) - The same shape and data type as `var`.
|
|
|
|
- **var** (Tensor) - The same shape and data type as `var`.
|
|
|
|
|
|
|
|
- **m** (Tensor) - The same shape and data type as `m`.
|
|
|
|
|
|
|
|
- **v** (Tensor) - The same shape and data type as `v`.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
@prim_attr_register
|
|
|
@ -2336,7 +2340,7 @@ class Adam(PrimitiveWithInfer):
|
|
|
|
validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
|
|
|
|
validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
|
|
|
|
validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
|
|
|
|
validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
|
|
|
|
validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
|
|
|
|
validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
|
|
|
|
return var_shape
|
|
|
|
return var_shape, m_shape, v_shape
|
|
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype,
|
|
|
|
def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype,
|
|
|
|
beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype):
|
|
|
|
beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype):
|
|
|
@ -2346,7 +2350,7 @@ class Adam(PrimitiveWithInfer):
|
|
|
|
args = {"beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, 'lr': lr_dtype,
|
|
|
|
args = {"beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, 'lr': lr_dtype,
|
|
|
|
"beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype}
|
|
|
|
"beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype}
|
|
|
|
validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name, True)
|
|
|
|
validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name, True)
|
|
|
|
return var_dtype
|
|
|
|
return var_dtype, m_dtype, v_dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BinaryCrossEntropy(PrimitiveWithInfer):
|
|
|
|
class BinaryCrossEntropy(PrimitiveWithInfer):
|
|
|
|