|
|
@ -111,7 +111,6 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
|
|
|
|
def _check_param_value(decay_steps, warmup_steps, start_learning_rate,
|
|
|
|
def _check_param_value(decay_steps, warmup_steps, start_learning_rate,
|
|
|
|
end_learning_rate, power, beta1, beta2, eps, weight_decay, prim_name):
|
|
|
|
end_learning_rate, power, beta1, beta2, eps, weight_decay, prim_name):
|
|
|
|
"""Check the type of inputs."""
|
|
|
|
"""Check the type of inputs."""
|
|
|
|
_ = warmup_steps
|
|
|
|
|
|
|
|
validator.check_float_positive('start_learning_rate', start_learning_rate, prim_name)
|
|
|
|
validator.check_float_positive('start_learning_rate', start_learning_rate, prim_name)
|
|
|
|
validator.check_float_legal_value('start_learning_rate', start_learning_rate, prim_name)
|
|
|
|
validator.check_float_legal_value('start_learning_rate', start_learning_rate, prim_name)
|
|
|
|
validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name)
|
|
|
|
validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name)
|
|
|
@ -119,7 +118,7 @@ def _check_param_value(decay_steps, warmup_steps, start_learning_rate,
|
|
|
|
validator.check_float_positive('power', power, prim_name)
|
|
|
|
validator.check_float_positive('power', power, prim_name)
|
|
|
|
validator.check_float_legal_value('power', power, prim_name)
|
|
|
|
validator.check_float_legal_value('power', power, prim_name)
|
|
|
|
validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)
|
|
|
|
validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)
|
|
|
|
validator.check_integer('warmup_steps', decay_steps, 0, Rel.GT, prim_name)
|
|
|
|
validator.check_integer('warmup_steps', warmup_steps, 0, Rel.GE, prim_name)
|
|
|
|
validator.check_value_type("beta1", beta1, [float], prim_name)
|
|
|
|
validator.check_value_type("beta1", beta1, [float], prim_name)
|
|
|
|
validator.check_value_type("beta2", beta2, [float], prim_name)
|
|
|
|
validator.check_value_type("beta2", beta2, [float], prim_name)
|
|
|
|
validator.check_value_type("eps", eps, [float], prim_name)
|
|
|
|
validator.check_value_type("eps", eps, [float], prim_name)
|
|
|
|