@ -100,8 +100,8 @@ size_t FindPos(const std::vector<int64_t>& rows, int64_t value) {
} // namespace
template < typename T >
struct SparseAdagradFunctor < platform : : CPU Place , T > {
void operator ( ) ( const platform : : DeviceContext& context ,
struct SparseAdagradFunctor < platform : : CPU DeviceContext , T > {
void operator ( ) ( const platform : : CPU DeviceContext& context ,
const framework : : SelectedRows & grad ,
const framework : : Tensor & learning_rate , T epsilon ,
framework : : Tensor * moment , framework : : Tensor * param ) {
@ -120,7 +120,7 @@ struct SparseAdagradFunctor<platform::CPUPlace, T> {
{ static_cast < int64_t > ( merge_rows . size ( ) ) , grad_width } ) ,
context . GetPlace ( ) ) ;
math : : SetConstant < platform : : CPU Place , T > constant_functor ;
math : : SetConstant < platform : : CPU DeviceContext , T > constant_functor ;
constant_functor ( context , grad_merge - > mutable_value ( ) , 0.0 ) ;
auto * grad_merge_data = grad_merge - > mutable_value ( ) - > data < T > ( ) ;
@ -144,9 +144,9 @@ struct SparseAdagradFunctor<platform::CPUPlace, T> {
auto gs =
framework : : EigenVector < T > : : Flatten ( * ( grad_square - > mutable_value ( ) ) ) ;
auto gm = framework : : EigenVector < T > : : Flatten ( grad_merge - > value ( ) ) ;
gs . device ( * context . GetEigenDevice< platform : : CPUPlace > ( ) ) = gm * gm ;
gs . device ( * context . eigen_device ( ) ) = gm * gm ;
math : : SelectedRowsAddToTensor < platform : : CPU Place , T > functor ;
math : : SelectedRowsAddToTensor < platform : : CPU DeviceContext , T > functor ;
functor ( context , * grad_square , moment ) ;
// 3. update parameter
@ -164,13 +164,13 @@ struct SparseAdagradFunctor<platform::CPUPlace, T> {
}
} ;
template struct SparseAdagradFunctor < platform : : CPU Place , float > ;
template struct SparseAdagradFunctor < platform : : CPU Place , double > ;
template struct SparseAdagradFunctor < platform : : CPU DeviceContext , float > ;
template struct SparseAdagradFunctor < platform : : CPU DeviceContext , double > ;
} // namespace operators
} // namespace paddle
namespace ops = paddle : : operators ;
REGISTER_OP_WITHOUT_GRADIENT ( adagrad , ops : : AdagradOp , ops : : AdagradOpMaker ) ;
REGISTER_OP_CPU_KERNEL (
adagrad , ops : : AdagradOpKernel < paddle : : platform : : CPU Place , float > ,
ops : : AdagradOpKernel < paddle : : platform : : CPU Place , double > ) ;
adagrad , ops : : AdagradOpKernel < paddle : : platform : : CPU DeviceContext , float > ,
ops : : AdagradOpKernel < paddle : : platform : : CPU DeviceContext , double > ) ;