|
|
|
@ -179,7 +179,7 @@ class RmspropOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
|
|
|
|
|
auto mg = EigenVector<T>::Flatten(mg_tensor);
|
|
|
|
|
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
|
|
|
|
|
PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
|
|
|
|
|
PADDLE_ENFORCE_EQ(&mg_tensor, mean_grad_out,
|
|
|
|
|
"MeanGrad and MeanGradOut must be the same Tensor");
|
|
|
|
|
auto mg_out = EigenVector<T>::Flatten(*mean_grad_out);
|
|
|
|
|
|
|
|
|
@ -198,7 +198,7 @@ class RmspropOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
if (centered) {
|
|
|
|
|
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
|
|
|
|
|
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
|
|
|
|
|
PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
|
|
|
|
|
PADDLE_ENFORCE_EQ(&mg_tensor, mean_grad_out,
|
|
|
|
|
"MeanGrad and MeanGradOut must be the same Tensor");
|
|
|
|
|
for_range(CenteredRmspropFunctor<T, DenseRmspropGradFunctor<T>>(
|
|
|
|
|
param_out->mutable_data<T>(ctx.GetPlace()),
|
|
|
|
@ -243,7 +243,7 @@ class RmspropOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
if (centered) {
|
|
|
|
|
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
|
|
|
|
|
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
|
|
|
|
|
PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
|
|
|
|
|
PADDLE_ENFORCE_EQ(&mg_tensor, mean_grad_out,
|
|
|
|
|
"MeanGrad and MeanGradOut must be the same Tensor");
|
|
|
|
|
for_range(CenteredRmspropFunctor<T, SparseRmspropGradFunctor<T>>(
|
|
|
|
|
param_out->mutable_data<T>(ctx.GetPlace()),
|
|
|
|
|