@ -23,10 +23,6 @@ limitations under the License. */
namespace paddle {
namespace operators {
template < typename T , int MajorType = Eigen : : RowMajor ,
typename IndexType = Eigen : : DenseIndex >
using EigenVector = framework : : EigenVector < T , MajorType , IndexType > ;
template < typename T >
struct DenseRmspropGradFunctor {
inline explicit DenseRmspropGradFunctor ( const T * grad ) : grad_ ( grad ) { }
@ -169,25 +165,25 @@ class RmspropOpKernel : public framework::OpKernel<T> {
* ctx . template device_context < DeviceContext > ( ) . eigen_device ( ) ;
auto lr_value = lr_tensor . data < T > ( ) [ 0 ] ;
auto p = EigenVector< T > : : Flatten ( p_tensor ) ;
auto ms = EigenVector< T > : : Flatten ( ms_tensor ) ;
auto g = EigenVector< T > : : Flatten ( grad_tensor ) ;
auto mom = EigenVector< T > : : Flatten ( mom_tensor ) ;
auto p = framework: : EigenVector< T > : : Flatten ( p_tensor ) ;
auto ms = framework: : EigenVector< T > : : Flatten ( ms_tensor ) ;
auto g = framework: : EigenVector< T > : : Flatten ( grad_tensor ) ;
auto mom = framework: : EigenVector< T > : : Flatten ( mom_tensor ) ;
auto p_out = EigenVector< T > : : Flatten ( * param_out ) ;
auto mom_out = EigenVector< T > : : Flatten ( * moment_out ) ;
auto ms_out = EigenVector< T > : : Flatten ( * mean_square_out ) ;
auto p_out = framework: : EigenVector< T > : : Flatten ( * param_out ) ;
auto mom_out = framework: : EigenVector< T > : : Flatten ( * moment_out ) ;
auto ms_out = framework: : EigenVector< T > : : Flatten ( * mean_square_out ) ;
ms_out . device ( place ) = rho * ms + ( 1 - rho ) * g * g ;
if ( centered ) {
auto & mg_tensor = * ctx . Input < LoDTensor > ( " MeanGrad " ) ;
auto mg = EigenVector< T > : : Flatten ( mg_tensor ) ;
auto mg = framework: : EigenVector< T > : : Flatten ( mg_tensor ) ;
auto * mean_grad_out = ctx . Output < LoDTensor > ( " MeanGradOut " ) ;
PADDLE_ENFORCE_EQ (
& mg_tensor , mean_grad_out ,
platform : : errors : : InvalidArgument (
" MeanGrad and MeanGradOut must be the same Tensor " ) ) ;
auto mg_out = EigenVector< T > : : Flatten ( * mean_grad_out ) ;
auto mg_out = framework: : EigenVector< T > : : Flatten ( * mean_grad_out ) ;
mg_out . device ( place ) = rho * mg + ( 1 - rho ) * g ;
mom_out . device ( place ) =