|
|
|
@ -299,6 +299,18 @@ REG_OP(ApplyMomentumCCE)
|
|
|
|
|
.ATTR(use_locking, Bool, false)
|
|
|
|
|
.OP_END_FACTORY_REG(ApplyMomentumCCE)
|
|
|
|
|
|
|
|
|
|
REG_OP(ApplyMomentumD)
|
|
|
|
|
.INPUT(var, TensorType::NumberType())
|
|
|
|
|
.INPUT(accum, TensorType::NumberType())
|
|
|
|
|
.INPUT(lr, TensorType::NumberType())
|
|
|
|
|
.INPUT(grad, TensorType::NumberType())
|
|
|
|
|
.INPUT(momentum, TensorType::NumberType())
|
|
|
|
|
.OUTPUT(var, TensorType::NumberType())
|
|
|
|
|
.OUTPUT(accum, TensorType::NumberType())
|
|
|
|
|
.ATTR(use_nesterov, Bool, false)
|
|
|
|
|
.ATTR(use_locking, Bool, false)
|
|
|
|
|
.OP_END_FACTORY_REG(ApplyMomentumD)
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
*@brief Updates "var" according to the AddSign update.\n
|
|
|
|
|
* t-1 mean previous period.
|
|
|
|
|