Merge pull request #3179 from gangliao/eigen_refine

Refine compute code in operators
cblas_new
gangliao 8 years ago committed by GitHub
commit 28db149187

@ -28,9 +28,13 @@ public:
output->mutable_data<T>(context.GetPlace());
EigenVector<T>::Flatten(*output).device(context.GetEigenDevice<Place>()) =
framework::EigenVector<T>::Flatten(*input0) +
framework::EigenVector<T>::Flatten(*input1);
auto X = EigenVector<T>::Flatten(*input0);
auto Y = EigenVector<T>::Flatten(*input1);
auto Z = EigenVector<T>::Flatten(*output);
auto place = context.GetEigenDevice<Place>();
Z.device(place) = X + Y;
}
};

@ -27,8 +27,11 @@ public:
output->mutable_data<T>(context.GetPlace());
EigenScalar<T>::From(*output).device(context.GetEigenDevice<Place>()) =
EigenVector<T>::Flatten(*input).mean();
auto X = EigenVector<T>::Flatten(*input);
auto y = EigenScalar<T>::From(*output);
auto place = context.GetEigenDevice<Place>();
y.device(place) = X.mean();
}
};

@ -26,13 +26,18 @@ public:
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
{Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};
auto input0 = context.Input<Tensor>("X");
auto input1 = context.Input<Tensor>("Y");
auto output = context.Output<Tensor>(0);
output->mutable_data<T>(context.GetPlace());
EigenMatrix<T>::From(*output).device(context.GetEigenDevice<Place>()) =
EigenMatrix<T>::From(*context.Input<Tensor>("X"))
.contract(EigenMatrix<T>::From(*context.Input<Tensor>("Y")),
dim_pair);
auto X = EigenMatrix<T>::From(*input0);
auto Y = EigenMatrix<T>::From(*input1);
auto Z = EigenMatrix<T>::From(*output);
auto place = context.GetEigenDevice<Place>();
Z.device(place) = X.contract(Y, dim_pair);
}
};
} // namespace operators

@ -29,8 +29,12 @@ public:
param_out->mutable_data<T>(ctx.GetPlace());
EigenVector<T>::Flatten(*param_out).device(ctx.GetEigenDevice<Place>()) =
EigenVector<T>::Flatten(*param) - lr * EigenVector<T>::Flatten(*grad);
auto p = EigenVector<T>::Flatten(*param);
auto g = EigenVector<T>::Flatten(*grad);
auto o = EigenVector<T>::Flatten(*param_out);
auto place = ctx.GetEigenDevice<Place>();
o.device(place) = p - lr * g;
}
};

@ -27,8 +27,11 @@ public:
auto output = context.Output<Tensor>(0);
output->mutable_data<T>(context.GetPlace());
EigenVector<T>::Flatten(*output).device(context.GetEigenDevice<Place>()) =
1.0 / (1.0 + (-1.0 * EigenVector<T>::Flatten(*input)).exp());
auto X = EigenVector<T>::Flatten(*input);
auto Y = EigenVector<T>::Flatten(*output);
auto place = context.GetEigenDevice<Place>();
Y.device(place) = 1.0 / (1.0 + (-1.0 * X).exp());
}
};
} // namespace operators

Loading…
Cancel
Save