@@ -131,21 +131,21 @@ template <typename DeviceContext, typename T>
 class RmspropOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
-    using Tensor = framework::LoDTensor;
+    using LoDTensor = framework::LoDTensor;
     auto *grad_var = ctx.InputVar("Grad");
-    auto *param_out = ctx.Output<Tensor>("ParamOut");
-    auto *moment_out = ctx.Output<Tensor>("MomentOut");
-    auto *mean_square_out = ctx.Output<Tensor>("MeanSquareOut");
+    auto *param_out = ctx.Output<LoDTensor>("ParamOut");
+    auto *moment_out = ctx.Output<LoDTensor>("MomentOut");
+    auto *mean_square_out = ctx.Output<LoDTensor>("MeanSquareOut");

     auto epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
     auto rho = static_cast<T>(ctx.Attr<float>("decay"));
     auto momentum = static_cast<T>(ctx.Attr<float>("momentum"));
     bool centered = ctx.Attr<bool>("centered");

-    auto &p_tensor = *ctx.Input<Tensor>("Param");
-    auto &ms_tensor = *ctx.Input<Tensor>("MeanSquare");
-    auto &lr_tensor = *ctx.Input<Tensor>("LearningRate");
-    auto &mom_tensor = *ctx.Input<Tensor>("Moment");
+    auto &p_tensor = *ctx.Input<LoDTensor>("Param");
+    auto &ms_tensor = *ctx.Input<LoDTensor>("MeanSquare");
+    auto &lr_tensor = *ctx.Input<LoDTensor>("LearningRate");
+    auto &mom_tensor = *ctx.Input<LoDTensor>("Moment");

     PADDLE_ENFORCE_EQ(&p_tensor, param_out,
                       "Param and ParamOut must be the same Tensor");
@@ -157,8 +157,8 @@ class RmspropOpKernel : public framework::OpKernel<T> {
     auto &dev_ctx = ctx.template device_context<DeviceContext>();
     size_t limit = static_cast<size_t>(ms_tensor.numel());

-    if (grad_var->IsType<Tensor>()) {
-      auto &grad_tensor = grad_var->Get<Tensor>();
+    if (grad_var->IsType<LoDTensor>()) {
+      auto &grad_tensor = grad_var->Get<LoDTensor>();

       if (std::is_same<DeviceContext, platform::CPUDeviceContext>::value) {
         auto &place =
@@ -176,9 +176,9 @@ class RmspropOpKernel : public framework::OpKernel<T> {

         ms_out.device(place) = rho * ms + (1 - rho) * g * g;
         if (centered) {
-          auto &mg_tensor = *ctx.Input<Tensor>("MeanGrad");
+          auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
           auto mg = EigenVector<T>::Flatten(mg_tensor);
-          auto *mean_grad_out = ctx.Output<Tensor>("MeanGradOut");
+          auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
           PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
                          "MeanGrad and MeanGradOut must be the same Tensor");
           auto mg_out = EigenVector<T>::Flatten(*mean_grad_out);
@@ -196,8 +196,8 @@ class RmspropOpKernel : public framework::OpKernel<T> {
         DenseRmspropGradFunctor<T> grad_func(grad_tensor.data<T>());
         platform::ForRange<DeviceContext> for_range(dev_ctx, limit);
         if (centered) {
-          auto &mg_tensor = *ctx.Input<Tensor>("MeanGrad");
-          auto *mean_grad_out = ctx.Output<Tensor>("MeanGradOut");
+          auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
+          auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
           PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
                          "MeanGrad and MeanGradOut must be the same Tensor");
           for_range(CenteredRmspropFunctor<T, DenseRmspropGradFunctor<T>>(
@@ -241,8 +241,8 @@ class RmspropOpKernel : public framework::OpKernel<T> {
                                             row_numel, row_count);

       if (centered) {
-        auto &mg_tensor = *ctx.Input<Tensor>("MeanGrad");
-        auto *mean_grad_out = ctx.Output<Tensor>("MeanGradOut");
+        auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
+        auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
         PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
                        "MeanGrad and MeanGradOut must be the same Tensor");
         for_range(CenteredRmspropFunctor<T, SparseRmspropGradFunctor<T>>(
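For reference, a minimal standalone sketch of the element-wise update that the dense (Centered)Rmsprop path above computes. This is not PaddlePaddle API: the function name RmspropStep, the raw-array layout, and the main driver are illustrative assumptions; only the formulas mirror the ms_out, mg_out, mom_out, and p_out expressions in the kernel (epsilon added inside the square root, as in the diff).

// rmsprop_sketch.cc -- standalone illustration, not Paddle code
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

// One RMSProp step over n elements.
// p: parameters, ms: mean square, mom: momentum, mg: mean gradient (used only
// when centered == true), g: gradient.
void RmspropStep(float *p, float *ms, float *mom, float *mg, const float *g,
                 std::size_t n, float lr, float rho, float momentum,
                 float epsilon, bool centered) {
  for (std::size_t i = 0; i < n; ++i) {
    // ms_out = rho * ms + (1 - rho) * g * g
    ms[i] = rho * ms[i] + (1 - rho) * g[i] * g[i];
    float denom;
    if (centered) {
      // mg_out = rho * mg + (1 - rho) * g, and the variance estimate is
      // ms_out - mg_out^2 (centered branch of the kernel).
      mg[i] = rho * mg[i] + (1 - rho) * g[i];
      denom = std::sqrt(ms[i] - mg[i] * mg[i] + epsilon);
    } else {
      denom = std::sqrt(ms[i] + epsilon);
    }
    // mom_out = momentum * mom + lr * g / sqrt(...), p_out = p - mom_out
    mom[i] = momentum * mom[i] + lr * g[i] / denom;
    p[i] -= mom[i];
  }
}

int main() {
  std::vector<float> p{1.0f, -2.0f}, ms(2, 0.0f), mom(2, 0.0f), mg(2, 0.0f);
  std::vector<float> g{0.1f, -0.3f};
  RmspropStep(p.data(), ms.data(), mom.data(), mg.data(), g.data(), p.size(),
              /*lr=*/0.01f, /*rho=*/0.9f, /*momentum=*/0.0f,
              /*epsilon=*/1e-10f, /*centered=*/true);
  std::cout << p[0] << " " << p[1] << "\n";
  return 0;
}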