|
|
|
@ -465,14 +465,8 @@ class MomentumOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
// Kernel entry point.
//
// Reads the boolean `multi_precision` attribute from the execution context
// and forwards to InnerCompute, instantiated with MPDType when the attribute
// is true and with T otherwise. The attribute value is passed through to
// InnerCompute unchanged in both cases.
void Compute(const framework::ExecutionContext& ctx) const override {
  const bool multi_precision = ctx.Attr<bool>("multi_precision");
  if (multi_precision) {
    InnerCompute<MPDType>(ctx, multi_precision);
  } else {
    InnerCompute<T>(ctx, multi_precision);
  }
}
|
|
|
|
@ -504,17 +498,6 @@ class MomentumOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
const framework::Tensor* master_param = nullptr;
|
|
|
|
|
framework::Tensor* master_param_out = nullptr;
|
|
|
|
|
if (multi_precision) {
|
|
|
|
|
LOG_FIRST_N(INFO, 1) << R"CODE(
|
|
|
|
|
bool has_master =
|
|
|
|
|
ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
|
|
|
|
|
PADDLE_ENFORCE_EQ(has_master, true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The Input(MasterParam) and Output(MasterParamOut) "
|
|
|
|
|
"should not be null when "
|
|
|
|
|
"the attr `multi_precision` is true"));
|
|
|
|
|
master_param = ctx.Input<framework::Tensor>("MasterParam");
|
|
|
|
|
master_param_out = ctx.Output<framework::Tensor>("MasterParamOut");
|
|
|
|
|
)CODE";
|
|
|
|
|
bool has_master =
|
|
|
|
|
ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
|
|
|
|
|
PADDLE_ENFORCE_EQ(has_master, true,
|
|
|
|
@ -547,14 +530,6 @@ class MomentumOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
static_cast<const DeviceContext&>(ctx.device_context()),
|
|
|
|
|
param->numel());
|
|
|
|
|
if (use_nesterov) {
|
|
|
|
|
LOG_FIRST_N(INFO, 1) << R"CODE(
|
|
|
|
|
DenseMomentumFunctor<T, MT, UseNesterov> functor(
|
|
|
|
|
param->data<T>(), grad->data<T>(), velocity->data<MT>(),
|
|
|
|
|
learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad,
|
|
|
|
|
param->numel(), regularization_flag, regularization_coeff,
|
|
|
|
|
param_out->mutable_data<T>(ctx.GetPlace()),
|
|
|
|
|
velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
|
|
|
|
|
)CODE";
|
|
|
|
|
DenseMomentumFunctor<T, MT, UseNesterov> functor(
|
|
|
|
|
param->data<T>(), grad->data<T>(), velocity->data<MT>(),
|
|
|
|
|
learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad,
|
|
|
|
@ -564,14 +539,6 @@ class MomentumOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
for_range(functor);
|
|
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
LOG_FIRST_N(INFO, 1) << R"CODE(
|
|
|
|
|
DenseMomentumFunctor<T, MT, NoNesterov> functor(
|
|
|
|
|
param->data<T>(), grad->data<T>(), velocity->data<MT>(),
|
|
|
|
|
learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad,
|
|
|
|
|
param->numel(), regularization_flag, regularization_coeff,
|
|
|
|
|
param_out->mutable_data<T>(ctx.GetPlace()),
|
|
|
|
|
velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data);
|
|
|
|
|
)CODE";
|
|
|
|
|
DenseMomentumFunctor<T, MT, NoNesterov> functor(
|
|
|
|
|
param->data<T>(), grad->data<T>(), velocity->data<MT>(),
|
|
|
|
|
learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad,
|
|
|
|
|