Fix l1_norm_op and squared_l2_norm_op for debug mode (#5560)

mobile_baidu
emailweixu 8 years ago committed by Abhinav Arora
parent b6c262e12f
commit 85b839f0f1

@ -29,7 +29,7 @@ class L1NormKernel : public framework::OpKernel<T> {
Out->mutable_data<T>(context.GetPlace()); Out->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X); auto x = framework::EigenVector<T>::Flatten(*X);
auto out = framework::EigenVector<T>::Flatten(*Out); auto out = framework::EigenScalar<T>::From(*Out);
auto place = context.GetEigenDevice<Place>(); auto place = context.GetEigenDevice<Place>();
out.device(place) = x.abs().sum(); out.device(place) = x.abs().sum();

@ -29,7 +29,7 @@ class SquaredL2NormKernel : public framework::OpKernel<T> {
Out->mutable_data<T>(context.GetPlace()); Out->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X); auto x = framework::EigenVector<T>::Flatten(*X);
auto out = framework::EigenVector<T>::Flatten(*Out); auto out = framework::EigenScalar<T>::From(*Out);
auto place = context.GetEigenDevice<Place>(); auto place = context.GetEigenDevice<Place>();
out.device(place) = x.square().sum(); out.device(place) = x.square().sum();

Loading…
Cancel
Save