large scale kv speedup (#26510)
* rename communicator meet -> BatchesCounter
* fix param recv for sparse
* geo sparse init from pserver
* optimize init from pserver
* add large scale optimizer fuse (SGD/ADAM)
* rectify init_worker and exe.run(startup_program)
parent d7b7dcd10e
commit bc5f0246a8
@@ -0,0 +1,153 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_adam_op.h"

#include <string>

namespace paddle {
namespace operators {

class LargeScaleFuseAdamOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Grad"),
                   "Input(Grad) of LargeScaleFuseAdamOp should not be null.");
    PADDLE_ENFORCE(
        ctx->HasInput("LearningRate"),
        "Input(LearningRate) of LargeScaleFuseAdamOp should not be null.");

    auto lr_dims = ctx->GetInputDim("LearningRate");

    PADDLE_ENFORCE_NE(framework::product(lr_dims), 0,
                      "Maybe the Input variable LearningRate has not "
                      "been initialized. You may need to confirm "
                      "if you put exe.run(startup_program) "
                      "after optimizer.minimize function.");

    PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
                      "Learning rate should have 1 element");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Grad");
    return framework::OpKernelType(data_type, ctx.device_context());
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name, const framework::Tensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const {
    if (var_name == "LearningRate") {
      return framework::OpKernelType(tensor.type(), tensor.place(),
                                     tensor.layout());
    }
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};

class LargeScaleFuseAdamOpInferVarType : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    auto in_var_type = ctx->GetInputType("Grad");
    PADDLE_ENFORCE_EQ(
        in_var_type == framework::proto::VarType::SELECTED_ROWS ||
            in_var_type == framework::proto::VarType::LOD_TENSOR,
        true,
        platform::errors::InvalidArgument(
            "The input Var's type should be LoDTensor or "
            "SelectedRows, but the received type is %s",
            in_var_type));
  }
};

class LargeScaleFuseAdamOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Grad",
             "(SelectedRows) Grad's type should be SelectedRows. "
             "The ids to be looked up in W.");

    AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator");
    AddInput("Beta2Pow", "(Tensor) Input beta2 power accumulator");
    AddInput("LearningRate", "(Tensor) Learning rate of SGD");
    AddOutput("Beta1PowOut", "(Tensor) Output beta1 power accumulator");
    AddOutput("Beta2PowOut", "(Tensor) Output beta2 power accumulator");

    AddAttr<float>("beta1",
                   "(float, default 0.9) "
                   "Exponential decay rate for the "
                   "first moment estimates.")
        .SetDefault(0.9f);

    AddAttr<float>("beta2",
                   "(float, default 0.999) "
                   "Exponential decay rate for the "
                   "second moment estimates.")
        .SetDefault(0.999f);

    AddAttr<float>("epsilon",
                   "(float, default 1.0e-8) "
                   "Constant for numerical stability")
        .SetDefault(1.0e-8f);

    AddAttr<bool>("is_entry",
                  "(bool) "
                  "whether the sparse table uses an entry threshold");

    AddAttr<std::string>("tablename",
                         "(string) "
                         "sparse table name");

    AddAttr<std::vector<std::string>>("value_names",
                                      "(strings) "
                                      "names of the values stored in the sparse table");

    AddComment(R"DOC(
Adam Optimizer.

This implements the Adam optimizer from Section 2 of the Adam
paper: https://arxiv.org/abs/1412.6980.
Adam is a first-order gradient-based optimization method based on
adaptive estimates of lower-order moments.

Adam updates:

$$
moment\_1\_out = \beta_1 * moment\_1 + (1 - \beta_1) * grad \\
moment\_2\_out = \beta_2 * moment\_2 + (1 - \beta_2) * grad * grad \\
learning\_rate = learning\_rate *
    \frac{\sqrt{1 - \beta_{2\_pow}}}{1 - \beta_{1\_pow}} \\
param\_out = param - learning\_rate * \frac{moment\_1}{\sqrt{moment\_2} + \epsilon}
$$

)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(
    lookup_sparse_table_fuse_adam, ops::LargeScaleFuseAdamOp,
    ops::LargeScaleFuseAdamOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
    ops::LargeScaleFuseAdamOpInferVarType);

REGISTER_OP_CPU_KERNEL(
    lookup_sparse_table_fuse_adam,
    ops::LargeScaleFuseAdamOpKernel<paddle::platform::CPUDeviceContext, float>);
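Note: the DOC string above states the Adam update in LaTeX. As a cross-check against the CPU kernel added later in this diff, here is a minimal standalone sketch of the same per-element arithmetic. It is an illustration only; the names (AdamState, AdamUpdateElement) are hypothetical and not part of the operator.

// Sketch: one Adam step for a single parameter element, matching the
// equations in the DOC string above. Not part of the diff.
#include <cmath>
#include <cstdio>

struct AdamState {
  float moment1 = 0.f;  // first moment estimate
  float moment2 = 0.f;  // second moment estimate
};

float AdamUpdateElement(float param, float grad, AdamState *s, float lr,
                        float beta1_pow, float beta2_pow, float beta1 = 0.9f,
                        float beta2 = 0.999f, float epsilon = 1.0e-8f) {
  s->moment1 = beta1 * s->moment1 + (1 - beta1) * grad;
  s->moment2 = beta2 * s->moment2 + (1 - beta2) * grad * grad;
  // bias-corrected learning rate, as in the DOC formula
  float lr_t = lr * std::sqrt(1 - beta2_pow) / (1 - beta1_pow);
  return param - lr_t * s->moment1 / (std::sqrt(s->moment2) + epsilon);
}

int main() {
  AdamState s;
  float param = 1.0f;
  // one step with lr = 0.1, grad = 0.5; beta*_pow already advanced once
  param = AdamUpdateElement(param, 0.5f, &s, 0.1f, 0.9f, 0.999f);
  std::printf("%f\n", param);
}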
@@ -0,0 +1,142 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <math.h>  // for sqrt in CPU and CUDA
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/distributed/large_scale_kv.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class LargeScaleFuseAdamOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override;
};

template <typename T>
class LargeScaleFuseAdamOpKernel<platform::CPUDeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using paddle::framework::LoDTensor;

    const auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");
    const auto *grad_var = ctx.InputVar("Grad");

    PADDLE_ENFORCE(
        grad_var->IsType<framework::SelectedRows>(),
        platform::errors::InvalidArgument(
            "in large scale optimization, gradient should only be "
            "SelectedRows"));

    const auto &grad = grad_var->Get<framework::SelectedRows>();

    // for distributed training, a sparse var may be empty,
    // just skip updating.
    if (grad.rows().size() == 0) {
      return;
    }

    // merge duplicate row ids in the sparse gradient before updating
    framework::SelectedRows tmp_grad_merge;
    const framework::SelectedRows *grad_merge_ptr;
    math::scatter::MergeAdd<platform::CPUDeviceContext, T> merge_func;
    merge_func(ctx.template device_context<platform::CPUDeviceContext>(), grad,
               &tmp_grad_merge, true);
    grad_merge_ptr = &tmp_grad_merge;

    std::vector<int64_t> in_rows;
    in_rows.reserve(grad_merge_ptr->rows().size());
    std::copy(grad_merge_ptr->rows().begin(), grad_merge_ptr->rows().end(),
              std::back_inserter(in_rows));

    const auto *lr = learning_rate->data<T>();
    auto grad_v = grad_merge_ptr->value();
    auto grad_width = grad_v.dims()[1];

    // auto is_entry = context.Attr<bool>("is_entry");
    auto tablename = ctx.Attr<std::string>("tablename");
    auto value_names = ctx.Attr<std::vector<std::string>>("value_names");

    auto *beta1_pow = ctx.Input<LoDTensor>("Beta1Pow");
    auto *beta2_pow = ctx.Input<LoDTensor>("Beta2Pow");
    auto *beta1_pow_out = ctx.Output<LoDTensor>("Beta1PowOut");
    auto *beta2_pow_out = ctx.Output<LoDTensor>("Beta2PowOut");
    T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
    T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
    T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));

    PADDLE_ENFORCE_EQ(beta1_pow_out->numel(), 1,
                      platform::errors::InvalidArgument(
                          "beta1 pow output size should be 1, but received "
                          "value is:%d.",
                          beta1_pow_out->numel()));

    PADDLE_ENFORCE_EQ(beta2_pow_out->numel(), 1,
                      platform::errors::InvalidArgument(
                          "beta2 pow output size should be 1, but received "
                          "value is:%d.",
                          beta2_pow_out->numel()));

    // update beta1 and beta2 power accumulators
    beta1_pow_out->mutable_data<T>(ctx.GetPlace())[0] =
        beta1 * beta1_pow->data<T>()[0];
    beta2_pow_out->mutable_data<T>(ctx.GetPlace())[0] =
        beta2 * beta2_pow->data<T>()[0];

    // fetch the parameter and moment rows from the large scale KV table
    std::vector<std::vector<std::vector<float> *>> values;
    std::vector<int64_t> dims;

    auto *ins = distributed::LargeScaleKV::GetInstance();
    auto *table = ins->Get(tablename);
    table->Get(in_rows, value_names, &values);
    table->Dims({"Param"}, &dims);

    PADDLE_ENFORCE_EQ(dims[0], grad_width,
                      platform::errors::InvalidArgument(
                          "param_row should have the same size as grad_row"));

    T lr_ = lr[0];
    T beta1_pow_ = beta1_pow->data<T>()[0];
    T beta2_pow_ = beta2_pow->data<T>()[0];

    // bias-corrected learning rate
    lr_ *= sqrt(1 - beta2_pow_) / (1 - beta1_pow_);

    for (size_t i = 0; i < in_rows.size(); i++) {
      auto &params = values[i][0];
      auto &moment_1 = values[i][1];
      auto &moment_2 = values[i][2];

      auto *p_data = params->data();
      auto *m1_data = moment_1->data();
      auto *m2_data = moment_2->data();

      for (int x = 0; x < grad_width; ++x) {
        auto g = grad_v.data<T>()[grad_width * i + x];
        m1_data[x] = beta1 * m1_data[x] + (1 - beta1) * g;
        m2_data[x] = beta2 * m2_data[x] + (1 - beta2) * g * g;
        p_data[x] -= lr_ * (m1_data[x] / (sqrt(m2_data[x]) + epsilon));
      }
    }
  }
};
}  // namespace operators
}  // namespace paddle
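Note on the MergeAdd step in the kernel above: a row-sparse gradient can carry the same row id several times (the same feature id hit by several samples in a batch), and math::scatter::MergeAdd sums those duplicates so each id is updated exactly once. A rough standalone sketch of that merge, with hypothetical names and plain STL containers instead of SelectedRows, is shown below.

// Sketch of what MergeAdd does conceptually: sum duplicate rows of a
// row-sparse gradient. Names are hypothetical; the real implementation lives
// in paddle/fluid/operators/math/selected_rows_functor.h.
#include <cstdint>
#include <map>
#include <vector>

struct SparseGrad {
  std::vector<int64_t> rows;               // row ids, may contain duplicates
  std::vector<std::vector<float>> values;  // one dense row per id above
};

SparseGrad MergeAddSketch(const SparseGrad &in) {
  std::map<int64_t, std::vector<float>> merged;
  for (size_t i = 0; i < in.rows.size(); ++i) {
    auto &dst = merged[in.rows[i]];
    if (dst.empty()) {
      dst = in.values[i];
    } else {
      for (size_t j = 0; j < dst.size(); ++j) dst[j] += in.values[i][j];
    }
  }
  SparseGrad out;
  for (const auto &kv : merged) {
    out.rows.push_back(kv.first);
    out.values.push_back(kv.second);
  }
  return out;
}

int main() {
  // row id 3 appears twice and is summed into a single row {1.5, 1.5}
  SparseGrad g{{3, 1, 3}, {{1.f, 1.f}, {2.f, 2.f}, {0.5f, 0.5f}}};
  SparseGrad merged = MergeAddSketch(g);
}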
@@ -0,0 +1,120 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_sgd_op.h"

#include <string>

namespace paddle {
namespace operators {

class LargeScaleFuseSGDOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Grad"),
                   "Input(Grad) of LargeScaleFuseSGDOp should not be null.");
    PADDLE_ENFORCE(
        ctx->HasInput("LearningRate"),
        "Input(LearningRate) of LargeScaleFuseSGDOp should not be null.");

    auto lr_dims = ctx->GetInputDim("LearningRate");

    PADDLE_ENFORCE_NE(framework::product(lr_dims), 0,
                      "Maybe the Input variable LearningRate has not "
                      "been initialized. You may need to confirm "
                      "if you put exe.run(startup_program) "
                      "after optimizer.minimize function.");

    PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
                      "Learning rate should have 1 element");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Grad");
    return framework::OpKernelType(data_type, ctx.device_context());
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name, const framework::Tensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const {
    if (var_name == "LearningRate") {
      return framework::OpKernelType(tensor.type(), tensor.place(),
                                     tensor.layout());
    }
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};

class LargeScaleFuseSGDOpInferVarType : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    auto in_var_type = ctx->GetInputType("Grad");
    PADDLE_ENFORCE_EQ(
        in_var_type == framework::proto::VarType::SELECTED_ROWS ||
            in_var_type == framework::proto::VarType::LOD_TENSOR,
        true,
        platform::errors::InvalidArgument(
            "The input Var's type should be LoDTensor or "
            "SelectedRows, but the received type is %s",
            in_var_type));
  }
};

class LargeScaleFuseSGDOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Grad",
             "(SelectedRows) Grad's type should be SelectedRows. "
             "The ids to be looked up in W.");
    AddInput("LearningRate", "(Tensor) Learning rate of SGD");

    AddAttr<bool>("is_entry",
                  "(bool) "
                  "whether the sparse table uses an entry threshold");

    AddAttr<std::string>("tablename",
                         "(string) "
                         "sparse table name");

    AddAttr<std::vector<std::string>>("value_names",
                                      "(strings) "
                                      "names of the values stored in the sparse table");

    AddComment(R"DOC(

LargeScaleFuseSGD operator

This operator implements one step of the stochastic gradient descent algorithm.

$$param\_out = param - learning\_rate * grad$$

)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(
    lookup_sparse_table_fuse_sgd, ops::LargeScaleFuseSGDOp,
    ops::LargeScaleFuseSGDOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
    ops::LargeScaleFuseSGDOpInferVarType);

REGISTER_OP_CPU_KERNEL(
    lookup_sparse_table_fuse_sgd,
    ops::LargeScaleFuseSGDOpKernel<paddle::platform::CPUDeviceContext, float>);
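Note: the DOC string above gives the dense rule param_out = param - learning_rate * grad; in this operator it is applied only to the rows named by the SelectedRows gradient. A minimal standalone sketch of that row-wise update (hypothetical names, no Paddle dependencies) follows.

// Sketch: SGD applied only to the rows addressed by a row-sparse gradient.
// `params` maps a row id to its dense parameter row. Illustration only.
#include <cstdint>
#include <unordered_map>
#include <vector>

void SparseSgdStep(std::unordered_map<int64_t, std::vector<float>> *params,
                   const std::vector<int64_t> &grad_rows,
                   const std::vector<std::vector<float>> &grad_values,
                   float lr) {
  for (size_t i = 0; i < grad_rows.size(); ++i) {
    auto &p = (*params)[grad_rows[i]];
    for (size_t j = 0; j < p.size(); ++j) {
      p[j] -= lr * grad_values[i][j];  // param_out = param - lr * grad
    }
  }
}

int main() {
  std::unordered_map<int64_t, std::vector<float>> params{{0, {1.f, 1.f}},
                                                         {7, {2.f, 2.f}}};
  SparseSgdStep(&params, {7}, {{0.5f, -0.5f}}, 0.1f);
  // params[7] is now {1.95, 2.05}; params[0] is untouched
}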
@@ -0,0 +1,105 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <string>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/distributed/large_scale_kv.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class LargeScaleFuseSGDOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override;
};

template <typename T>
class LargeScaleFuseSGDOpKernel<platform::CPUDeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");

    const auto *grad_var = ctx.InputVar("Grad");

    PADDLE_ENFORCE(
        grad_var->IsType<framework::SelectedRows>(),
        platform::errors::InvalidArgument(
            "in large scale optimization, gradient should only be "
            "SelectedRows"));

    const auto &grad = grad_var->Get<framework::SelectedRows>();

    // for distributed training, a sparse var may be empty,
    // just skip updating.
    if (grad.rows().size() == 0) {
      return;
    }

    // merge duplicate row ids in the sparse gradient before updating
    framework::SelectedRows tmp_grad_merge;
    const framework::SelectedRows *grad_merge_ptr;
    math::scatter::MergeAdd<platform::CPUDeviceContext, T> merge_func;
    merge_func(ctx.template device_context<platform::CPUDeviceContext>(), grad,
               &tmp_grad_merge, true);
    grad_merge_ptr = &tmp_grad_merge;

    std::vector<int64_t> in_rows;
    in_rows.reserve(grad_merge_ptr->rows().size());
    std::copy(grad_merge_ptr->rows().begin(), grad_merge_ptr->rows().end(),
              std::back_inserter(in_rows));

    const auto *lr = learning_rate->data<T>();
    auto grad_v = grad_merge_ptr->value();
    auto grad_width = grad_v.dims()[1];

    // auto is_entry = context.Attr<bool>("is_entry");
    auto tablename = ctx.Attr<std::string>("tablename");
    auto value_names = ctx.Attr<std::vector<std::string>>("value_names");

    // fetch the parameter rows from the large scale KV table
    std::vector<std::vector<std::vector<float> *>> values;
    std::vector<int64_t> dims;

    auto *ins = distributed::LargeScaleKV::GetInstance();
    auto *table = ins->Get(tablename);
    table->Get(in_rows, value_names, &values);
    table->Dims({"Param"}, &dims);

    PADDLE_ENFORCE_EQ(dims[0], grad_width,
                      platform::errors::InvalidArgument(
                          "param_row should have the same size as grad_row"));

    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);

    std::vector<T> grads;
    framework::TensorToVector(grad_v, ctx.device_context(), &grads);

    // grads <- learning_rate * grads
    blas.SCAL(grads.size(), lr[0], grads.data());

    for (int x = 0; x < static_cast<int>(in_rows.size()); ++x) {
      auto &params = values[x][0];
      // params <- params - learning_rate * grads[row x]
      blas.VSUB(grad_width, params->data(), grads.data() + grad_width * x,
                params->data());
    }
  }
};
}  // namespace operators
}  // namespace paddle
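In the kernel above the update is expressed with two BLAS-style calls: SCAL scales the merged gradient by the learning rate in place, and VSUB then subtracts the scaled gradient row from the parameter row. Written out as plain loops, a sketch only (hypothetical helper names, no Paddle or BLAS dependency):

// What blas.SCAL + blas.VSUB amount to for one row, written as plain loops.
#include <vector>

// SCAL: x <- alpha * x
void Scal(float alpha, std::vector<float> *x) {
  for (auto &v : *x) v *= alpha;
}

// VSUB: z[i] = x[i] - y[i] for one row of width n
void Vsub(int n, const float *x, const float *y, float *z) {
  for (int i = 0; i < n; ++i) z[i] = x[i] - y[i];
}

int main() {
  std::vector<float> grad = {0.5f, -0.25f};
  std::vector<float> param = {1.0f, 2.0f};
  Scal(0.1f, &grad);                                 // grad <- lr * grad
  Vsub(2, param.data(), grad.data(), param.data());  // param <- param - lr*grad
  // param is now {0.95, 2.025}
}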