|
|
@ -179,8 +179,8 @@ class RmspropOpKernel : public framework::OpKernel<T> {
|
|
|
|
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
|
|
|
|
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
|
|
|
|
auto mg = EigenVector<T>::Flatten(mg_tensor);
|
|
|
|
auto mg = EigenVector<T>::Flatten(mg_tensor);
|
|
|
|
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
|
|
|
|
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
|
|
|
|
PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
|
|
|
|
PADDLE_ENFORCE_EQ(&mg_tensor, mean_grad_out,
|
|
|
|
"MeanGrad and MeanGradOut must be the same Tensor");
|
|
|
|
"MeanGrad and MeanGradOut must be the same Tensor");
|
|
|
|
auto mg_out = EigenVector<T>::Flatten(*mean_grad_out);
|
|
|
|
auto mg_out = EigenVector<T>::Flatten(*mean_grad_out);
|
|
|
|
|
|
|
|
|
|
|
|
mg_out.device(place) = rho * mg + (1 - rho) * g;
|
|
|
|
mg_out.device(place) = rho * mg + (1 - rho) * g;
|
|
|
@ -198,8 +198,8 @@ class RmspropOpKernel : public framework::OpKernel<T> {
|
|
|
|
if (centered) {
|
|
|
|
if (centered) {
|
|
|
|
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
|
|
|
|
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
|
|
|
|
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
|
|
|
|
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
|
|
|
|
PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
|
|
|
|
PADDLE_ENFORCE_EQ(&mg_tensor, mean_grad_out,
|
|
|
|
"MeanGrad and MeanGradOut must be the same Tensor");
|
|
|
|
"MeanGrad and MeanGradOut must be the same Tensor");
|
|
|
|
for_range(CenteredRmspropFunctor<T, DenseRmspropGradFunctor<T>>(
|
|
|
|
for_range(CenteredRmspropFunctor<T, DenseRmspropGradFunctor<T>>(
|
|
|
|
param_out->mutable_data<T>(ctx.GetPlace()),
|
|
|
|
param_out->mutable_data<T>(ctx.GetPlace()),
|
|
|
|
mean_square_out->mutable_data<T>(ctx.GetPlace()),
|
|
|
|
mean_square_out->mutable_data<T>(ctx.GetPlace()),
|
|
|
@ -243,8 +243,8 @@ class RmspropOpKernel : public framework::OpKernel<T> {
|
|
|
|
if (centered) {
|
|
|
|
if (centered) {
|
|
|
|
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
|
|
|
|
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
|
|
|
|
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
|
|
|
|
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
|
|
|
|
PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
|
|
|
|
PADDLE_ENFORCE_EQ(&mg_tensor, mean_grad_out,
|
|
|
|
"MeanGrad and MeanGradOut must be the same Tensor");
|
|
|
|
"MeanGrad and MeanGradOut must be the same Tensor");
|
|
|
|
for_range(CenteredRmspropFunctor<T, SparseRmspropGradFunctor<T>>(
|
|
|
|
for_range(CenteredRmspropFunctor<T, SparseRmspropGradFunctor<T>>(
|
|
|
|
param_out->mutable_data<T>(ctx.GetPlace()),
|
|
|
|
param_out->mutable_data<T>(ctx.GetPlace()),
|
|
|
|
mean_square_out->mutable_data<T>(ctx.GetPlace()),
|
|
|
|
mean_square_out->mutable_data<T>(ctx.GetPlace()),
|
|
|
|