Merge pull request #14270 from sneaxiy/fix_rmsprop_enforce_bug

Fix rmsprop_op enforce bug
revert-14324-fix_vlog
Zeng Jinle 7 years ago committed by GitHub
commit fcbe84cb50
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -179,7 +179,7 @@ class RmspropOpKernel : public framework::OpKernel<T> {
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
auto mg = EigenVector<T>::Flatten(mg_tensor);
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
PADDLE_ENFORCE_EQ(&mg_tensor, mean_grad_out,
"MeanGrad and MeanGradOut must be the same Tensor");
auto mg_out = EigenVector<T>::Flatten(*mean_grad_out);
@ -198,7 +198,7 @@ class RmspropOpKernel : public framework::OpKernel<T> {
if (centered) {
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
PADDLE_ENFORCE_EQ(&mg_tensor, mean_grad_out,
"MeanGrad and MeanGradOut must be the same Tensor");
for_range(CenteredRmspropFunctor<T, DenseRmspropGradFunctor<T>>(
param_out->mutable_data<T>(ctx.GetPlace()),
@ -243,7 +243,7 @@ class RmspropOpKernel : public framework::OpKernel<T> {
if (centered) {
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
PADDLE_ENFORCE_EQ(&mg_tensor, mean_grad_out,
"MeanGrad and MeanGradOut must be the same Tensor");
for_range(CenteredRmspropFunctor<T, SparseRmspropGradFunctor<T>>(
param_out->mutable_data<T>(ctx.GetPlace()),

Loading…
Cancel
Save