Fix DGC algorithm flow to make it the same as paper (#20758)
parent ba45dce35d
commit 250e72d254
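The new dgc_momentum op follows the flow described in the Deep Gradient Compression paper: while current_step is still below rampup_begin_step it performs the ordinary momentum update, and from that step on it falls back to a plain SGD update, so that momentum is not applied twice once DGC's momentum correction takes over on the gradient side. A minimal NumPy sketch of that update rule (the function and variable names here are illustrative only, not part of the op; the formulas mirror the reference computation in the unit test below):

import numpy as np

def dgc_momentum_step(param, grad, velocity, lr, mu, current_step,
                      rampup_begin_step, use_nesterov=False):
    # Before the rampup step: regular momentum update.
    if current_step < rampup_begin_step:
        velocity_out = mu * velocity + grad
        if use_nesterov:
            param_out = param - grad * lr - velocity_out * mu * lr
        else:
            param_out = param - lr * velocity_out
        return param_out, velocity_out
    # From the rampup step on: plain SGD step; the velocity is assumed
    # to be left unchanged in this sketch.
    return param - lr * grad, velocity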
@@ -0,0 +1,68 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>

#include "paddle/fluid/operators/optimizers/dgc_momentum_op.h"

namespace paddle {
namespace operators {

class DGCMomentumOp : public MomentumOp {
 public:
  using MomentumOp::MomentumOp;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput("current_step"), true,
                      "current_step should be set.");
    return MomentumOp::InferShape(ctx);
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const framework::Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    if (var_name == "current_step") {
      VLOG(10) << "var_name:" << var_name << " does not need to be transformed";
      return expected_kernel_type;
    }

    return framework::OperatorWithKernel::GetKernelTypeForVar(
        var_name, tensor, expected_kernel_type);
  }
};

class DGCMomentumOpMaker : public MomentumOpMaker {
 public:
  void Make() override {
    AddInput("current_step", "(Tensor) Current step.");
    AddAttr<float>("rampup_begin_step",
                   "(float, default -1.0) "
                   "The step at which DGC begins.")
        .SetDefault(-1.0);

    return MomentumOpMaker::Make();
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(dgc_momentum, ops::DGCMomentumOp,
                             ops::DGCMomentumOpMaker);

REGISTER_OP_CPU_KERNEL(
    dgc_momentum,
    ops::DGCMomentumKernel<paddle::platform::CPUDeviceContext, float>);
@@ -0,0 +1,20 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/optimizers/dgc_momentum_op.h"

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    dgc_momentum,
    ops::DGCMomentumKernel<paddle::platform::CUDADeviceContext, float>);
@@ -0,0 +1,59 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>

#include "paddle/fluid/operators/optimizers/momentum_op.h"
#include "paddle/fluid/operators/optimizers/sgd_op.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class DGCMomentumKernel : public framework::OpKernel<T> {
 public:
  DGCMomentumKernel()
      : _momentum_op_kernel(new MomentumOpKernel<DeviceContext, T>()),
        _sgd_op_kernel(new SGDOpKernel<DeviceContext, T>()) {}

  void Compute(const framework::ExecutionContext& context) const override {
    auto rampup_begin_step = context.Attr<float>("rampup_begin_step");
    if (static_cast<int>(rampup_begin_step) < 0) {
      return;
    }

    auto current_step_tensor = context.Input<framework::Tensor>("current_step");
    auto* current_step = current_step_tensor->data<T>();

    VLOG(10) << "current_step:" << *current_step
             << ", rampup_begin_step:" << rampup_begin_step;

    if (static_cast<int>(*current_step) < static_cast<int>(rampup_begin_step)) {
      VLOG(10) << " so use momentum optimizer";
      return _momentum_op_kernel->Compute(context);
    }

    VLOG(10) << " so use sgd optimizer";
    return _sgd_op_kernel->Compute(context);
  }

 private:
  std::unique_ptr<MomentumOpKernel<DeviceContext, T>> _momentum_op_kernel;
  std::unique_ptr<SGDOpKernel<DeviceContext, T>> _sgd_op_kernel;
};

}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,134 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid


class TestDGCMomentumOp1(unittest.TestCase):
    def get_tensor(self, name, value, place=None):
        tensor = self.scope.var(name).get_tensor()
        tensor.set(value, self.place if place is None else place)
        return name, tensor

    def setup(self, place, step=0.0):
        self.scope = fluid.global_scope()
        self.place = place
        print("place:", place)

        self.op_type = "dgc_momentum"
        self.dtype = np.float32

        param = np.random.random((123, 321)).astype(self.dtype)
        grad = np.random.random((123, 321)).astype(self.dtype)
        velocity = np.zeros((123, 321)).astype(self.dtype)
        learning_rate = np.array([0.001]).astype(self.dtype)
        current_step = np.full((1), step).astype("float32")
        mu = 0.0001
        use_nesterov = False
        rampup_begin_step = 10.0

        self.param_name, self.param_tensor = self.get_tensor('Param', param)
        self.grad_name, self.grad_tensor = self.get_tensor('Grad', grad)
        self.velocity_name, self.velocity_tensor = self.get_tensor('Velocity',
                                                                   velocity)
        self.learning_rate_name, self.learning_rate_tensor = self.get_tensor(
            'LearningRate', learning_rate)
        self.current_step_name, self.current_step_tensor = self.get_tensor(
            'current_step', current_step, core.CPUPlace())

        self.kwargs = {
            # inputs
            'Param': self.param_name,
            'Grad': self.grad_name,
            'Velocity': self.velocity_name,
            'LearningRate': self.learning_rate_name,
            'current_step': self.current_step_name,

            # attrs
            'mu': mu,
            'use_nesterov': use_nesterov,
            'rampup_begin_step': rampup_begin_step,

            # outputs
            'ParamOut': self.param_name,
            'VelocityOut': self.velocity_name
        }

        velocity_out = mu * velocity + grad
        if use_nesterov:
            param_out = param - grad * learning_rate - \
                velocity_out * mu * learning_rate
        else:
            param_out = param - learning_rate * velocity_out

        sgd_out = param - learning_rate * grad

        self.outputs = {
            'ParamOut': param_out,
            'VelocityOut': velocity_out,
            'SGDOut': sgd_out
        }

    def check(self, actual_t, expect_t, place, out_name, atol=1e-5):
        self.assertTrue(
            np.allclose(
                actual_t, expect_t, atol=atol),
            "Output (" + out_name + ") has diff at " + str(place) + "\nExpect "
            + str(expect_t) + "\n" + "But Got" + str(actual_t))

    def check_momentum_step(self, place):
        self.setup(place=place)

        dgc_momentum_op = Operator(self.op_type, **self.kwargs)
        dgc_momentum_op.run(self.scope, self.place)

        self.check(
            np.array(self.param_tensor), self.outputs['ParamOut'], self.place,
            self.param_name)

        self.check(
            np.array(self.velocity_tensor), self.outputs['VelocityOut'],
            self.place, self.velocity_name)

    def check_sgd_step(self, place):
        self.setup(place=place, step=15.0)

        dgc_momentum_op = Operator(self.op_type, **self.kwargs)
        dgc_momentum_op.run(self.scope, self.place)

        self.check(
            np.array(self.param_tensor), self.outputs['SGDOut'], self.place,
            self.param_name)

    def test_cuda_place(self):
        if not core.is_compiled_with_cuda():
            return
        place = core.CUDAPlace(0)
        self.check_momentum_step(place)
        self.check_sgd_step(place)

    def test_cpu_place(self):
        place = core.CPUPlace()
        self.check_momentum_step(place)
        self.check_sgd_step(place)


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,108 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest

import paddle.fluid.framework as framework
import paddle.fluid.optimizer as optimizer
import paddle.compat as cpt
from paddle.fluid.backward import append_backward
from paddle.fluid.transpiler.details import program_to_code


class TestDGCMomentumOptimizer(unittest.TestCase):
    class MockDGCMomentum(optimizer.DGCMomentumOptimizer):
        def get_accumulators(self):
            return self._accumulators

        def get_velocity_str(self):
            return self._velocity_acc_str

    def check_dgc_momentum_optimizer(self, dims=[5, 10, 8], name="momentum"):
        init_program = framework.Program()
        program = framework.Program()
        block = program.global_block()
        mul_x = block.create_parameter(
            dtype="float32",
            shape=[dims[0], dims[1]],
            lod_level=0,
            name="mul.x",
            optimize_attr={'learning_rate': 1.1})
        mul_y = block.create_var(
            dtype="float32",
            shape=[dims[1], dims[2]],
            lod_level=0,
            name="mul.y")
        mul_out = block.create_var(
            dtype="float32",
            shape=[dims[0], dims[2]],
            lod_level=0,
            name="mul.out")
        block.append_op(
            type="mul",
            inputs={"X": mul_x,
                    "Y": mul_y},
            outputs={"Out": mul_out},
            attrs={"x_num_col_dims": 1})
        learning_rate = 0.01
        dgc_momentum_optimizer = self.MockDGCMomentum(
            learning_rate=learning_rate, momentum=0.2, rampup_begin_step=0)
        mean_out = block.create_var(
            dtype="float32", shape=[1], lod_level=0, name="mean.out")
        block.append_op(
            type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
        # params_grads = append_backward(mean_out)
        params_grads = dgc_momentum_optimizer.backward(mean_out)
        self.assertEqual(len(params_grads), 1)
        self.assertEqual(len(dgc_momentum_optimizer.get_accumulators()), 0)
        with framework.program_guard(program, init_program):
            opts = dgc_momentum_optimizer.apply_gradients(params_grads)
        self.assertEqual(len(opts), 2)
        sgd_op = opts[-1]
        self.assertEqual([op.type for op in opts], ["scale", name])
        self.assertFalse(sgd_op.attr('use_nesterov'))

        # Check accumulators
        accumulators = dgc_momentum_optimizer.get_accumulators()
        self.assertEqual(len(accumulators), 1)
        self.assertTrue(
            dgc_momentum_optimizer.get_velocity_str() in accumulators)
        velocity_acc = accumulators[dgc_momentum_optimizer.get_velocity_str()]
        self.assertEqual(len(velocity_acc), 1)
        self.assertTrue(mul_x.name in velocity_acc)

        # Check init_program
        init_ops = init_program.global_block().ops
        self.assertEqual(len(init_ops), 2)
        self.assertEqual(init_ops[0].type, "fill_constant")
        self.assertAlmostEqual(init_ops[0].attr('value'), learning_rate)
        self.assertEqual(init_ops[1].type, "fill_constant")
        self.assertAlmostEqual(init_ops[1].attr('value'), 0.0)

        with open("test_dgc_optimizer_" + name + ".log", "w") as f:
            program_to_code(program, fout=f)

    def test_momentum_without_dgc(self):
        self.check_dgc_momentum_optimizer()

    def test_momentum_with_dgc(self):
        # 16 * 1024 = 16384, use dgc momentum
        self.check_dgc_momentum_optimizer(
            dims=[16, 1024, 8], name="dgc_momentum")


if __name__ == '__main__':
    unittest.main()
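For context, a rough usage sketch of DGCMomentumOptimizer in a fluid program; the optimizer arguments mirror the ones used in the test above, while the small network below is illustrative only (in practice DGC targets distributed multi-GPU training):

import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
pred = fluid.layers.fc(input=x, size=1)
loss = fluid.layers.reduce_mean(fluid.layers.square_error_cost(pred, y))

# rampup_begin_step controls when the dgc_momentum op switches from the
# momentum update to the plain SGD update.
opt = fluid.optimizer.DGCMomentumOptimizer(
    learning_rate=0.01, momentum=0.2, rampup_begin_step=0)
opt.minimize(loss)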