!884 fix checking bug of ApplyCenteredRMSProp

Merge pull request !884 from fary86/fix_checking_bug_of_ApplyCenteredRMSProp
pull/884/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 0b67ca85df

@ -1711,9 +1711,11 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer):
"mean_square": mean_square_dtype, "moment": moment_dtype, "grad": grad_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
args = {"learning_rate": learning_rate_dtype, "rho": rho_dtype, 'momentum': momentum_dtype,
"epsilon": epsilon_dtype}
validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
valid_types = [mstype.float16, mstype.float32]
args_rho = {"rho": rho_dtype, 'momentum': momentum_dtype, "epsilon": epsilon_dtype}
validator.check_type_same(args_rho, valid_types, self.name)
args_lr = {"learning_rate": learning_rate_dtype, "rho": rho_dtype}
validator.check_scalar_or_tensor_type_same(args_lr, valid_types, self.name, allow_mix=True)
return var_dtype

Loading…
Cancel
Save