parent 2d24f56a7a
commit 427c5529ea
@@ -0,0 +1,126 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import paddle.fluid as fluid
from paddle.fluid import core, unique_name
from ..base.private_helper_function import wait_server_ready

OpRole = core.op_proto_and_checker_maker.OpRole

OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
OP_ROLE_VAR_KEY = core.op_proto_and_checker_maker.kOpRoleVarAttrName()


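# Predicates that classify ops by their inputs and op-role attributes; the
# meta optimizers use them to locate parameter-update, loss-grad, backward,
# and optimizer ops when rewriting a program.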
def is_update_op(op):
    return 'Param' in op.input_names and 'Grad' in op.input_names and \
        "LearningRate" in op.input_names


def is_loss_grad_op(op):
    if OP_ROLE_KEY not in op.attr_names:
        return False
    op_role = int(op.all_attrs()[OP_ROLE_KEY])
    return op_role & int(OpRole.Backward) and op_role & int(OpRole.Loss)


def is_backward_op(op):
    return OP_ROLE_KEY in op.attr_names and \
        int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Backward)


def is_optimizer_op(op):
    return OP_ROLE_KEY in op.attr_names and \
        int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Optimize)


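# Helper that rewrites a startup program so that every trainer joins the
# NCCL communication ring(s) and all trainers start from identical
# parameter values.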
class CollectiveHelper(object):
    def __init__(self, role_maker, nrings=1, wait_port='6174'):
        self.nrings = nrings
        self.wait_port = wait_port
        self.role_maker = role_maker

    def update_startup_program(self, startup_program=None):
        self.startup_program = startup_program
        if startup_program is None:
            self.startup_program = fluid.default_startup_program()

        endpoints = self.role_maker.get_trainer_endpoints()
        current_endpoint = endpoints[self.role_maker.worker_index()]
        for ring_id in range(self.nrings):
            self._init_communicator(
                self.startup_program, current_endpoint, endpoints,
                self.role_maker.worker_index(), ring_id, self.wait_port)
        self._broadcast_params()

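    # Sets up the NCCL communicator for one ring: rank 0 waits for the other
    # endpoints, a c_gen_nccl_id op generates and exchanges the NCCL unique
    # id, and c_comm_init initializes the communicator on every rank.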
    def _init_communicator(self, program, current_endpoint, endpoints, rank,
                           ring_id, wait_port):
        nranks = len(endpoints)
        other_endpoints = endpoints[:]
        other_endpoints.remove(current_endpoint)
        if rank == 0 and wait_port:
            wait_server_ready(other_endpoints)

        block = program.global_block()
        nccl_id_var = block.create_var(
            name=unique_name.generate('nccl_id'),
            persistable=True,
            type=core.VarDesc.VarType.RAW)
        block.append_op(
            type='c_gen_nccl_id',
            inputs={},
            outputs={'Out': nccl_id_var},
            attrs={
                'rank': rank,
                'endpoint': current_endpoint,
                'other_endpoints': other_endpoints,
                OP_ROLE_KEY: OpRole.Forward
            })
        block.append_op(
            type='c_comm_init',
            inputs={'X': nccl_id_var},
            outputs={},
            attrs={
                'nranks': nranks,
                'rank': rank,
                'ring_id': ring_id,
                OP_ROLE_KEY: OpRole.Forward
            })

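    # Broadcasts every non-distributed parameter from rank 0 so that all
    # trainers start from the same weights; broadcasts are spread round-robin
    # over the rings and the communication streams are synchronized at the end.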
    def _broadcast_params(self):
        block = self.startup_program.global_block()
        ring_id = -1
        for param in block.iter_parameters():
            if param.is_distributed:
                continue

            ring_id = (ring_id + 1) % self.nrings
            block.append_op(
                type='c_broadcast',
                inputs={'X': param},
                outputs={'Out': param},
                attrs={
                    'ring_id': ring_id,
                    'root': 0,
                    OP_ROLE_KEY: OpRole.Forward
                })

        for ring_id in range(self.nrings):
            block.append_op(
                type='c_sync_comm_stream',
                inputs={'X': param},
                outputs={'Out': param},
                attrs={'ring_id': ring_id,
                       OP_ROLE_KEY: OpRole.Forward})
@@ -0,0 +1,193 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from paddle.fluid import program_guard, layers, default_startup_program
from paddle.fluid.optimizer import Momentum, SGD
from .meta_optimizer_base import MetaOptimizerBase
from .common import OpRole, OP_ROLE_KEY, CollectiveHelper, is_update_op


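# Meta optimizer implementing Local SGD: each trainer runs the wrapped
# optimizer (SGD or Momentum) locally and the parameters are averaged across
# trainers every k_steps iterations.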
class LocalSGDOptimizer(MetaOptimizerBase):
    def __init__(self, optimizer):
        super(LocalSGDOptimizer, self).__init__(optimizer)
        self.inner_opt = optimizer
        self.meta_optimizers_white_list = []
        self.snapshot_key = '@SNAPSHOT'

    def _can_apply(self):
        if not self.user_defined_strategy.localsgd:
            return False

        if self.role_maker.worker_num() <= 1:
            return False

        return isinstance(self.inner_opt, Momentum) \
            or isinstance(self.inner_opt, SGD)

    def _disable_strategy(self, dist_strategy):
        dist_strategy.localsgd = False
        dist_strategy.localsgd_configs = {'k_steps': 1}

    def snapshot_name(self, param_name):
        return param_name + self.snapshot_key

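    # Applies the inner optimizer, then adds the Local SGD control flow to the
    # main program: a global step counter, the (optionally adaptive) interval
    # k_steps, and a conditional communicate() block that averages parameters
    # across trainers once every k_steps local updates.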
    def minimize_impl(self,
                      loss,
                      startup_program=None,
                      parameter_list=None,
                      no_grad_set=None):
        minimized = self.inner_opt.minimize(
            loss, startup_program=startup_program)

        init_k_steps = self.user_defined_strategy.localsgd_configs['k_steps']
        auto_steps = self.user_defined_strategy.auto

        if startup_program is None:
            startup_program = default_startup_program()
        main_block = loss.block

        self.nrings = 2
        collective_helper = CollectiveHelper(self.role_maker, self.nrings)
        collective_helper.update_startup_program(startup_program)

        with program_guard(main_block.program):
            step = layers.autoincreased_step_counter(begin=0)
            k_steps = layers.create_global_var(
                name="k_steps",
                shape=[1],
                value=init_k_steps,
                dtype='int64',
                persistable=True)
            last_step = layers.create_global_var(
                name="last_step",
                shape=[1],
                value=int(0),
                dtype='int64',
                persistable=True)

            if auto_steps:
                lr_0 = layers.create_global_var(
                    name="lr_0",
                    shape=[1],
                    value=float(0),
                    dtype='float32',
                    persistable=True)
                loss_0 = layers.create_global_var(
                    name="loss_0",
                    shape=[1],
                    value=float(0),
                    dtype='float32',
                    persistable=True)

                global_lr = self.inner_opt._global_learning_rate()

                def initialize():
                    layers.assign(loss, loss_0)
                    layers.assign(global_lr, lr_0)

                layers.cond(step == 0, initialize)

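            # One synchronization round: for each parameter, compute
            # snapshot - param, allreduce and average the differences across
            # trainers, subtract the averaged difference from the snapshot to
            # recover the cross-trainer parameter average, and refresh the
            # snapshot with that average.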
            def communicate():
                ordered_param_snapshot = []
                ring_id = -1
                for idx, op in reversed(list(enumerate(main_block.ops))):
                    if is_update_op(op):
                        param = main_block.vars[op.input('Param')[0]]
                        if param.is_distributed:
                            continue

                        snapshot = main_block.create_var(
                            name=self.snapshot_name(param.name),
                            shape=param.shape,
                            persistable=True,
                            stop_gradient=True,
                            dtype=param.dtype)

                        main_block._insert_op(
                            idx + 1,
                            type='elementwise_sub',
                            inputs={'X': [snapshot],
                                    'Y': [param]},
                            outputs={'Out': [param]},
                            attrs={OP_ROLE_KEY: OpRole.Optimize})
                        main_block._insert_op(
                            idx + 2,
                            type='c_sync_calc_stream',
                            inputs={'X': param},
                            outputs={'Out': param},
                            attrs={OP_ROLE_KEY: OpRole.Optimize})
                        ring_id = (ring_id + 1) % self.nrings
                        main_block._insert_op(
                            idx + 3,
                            type='c_allreduce_sum',
                            inputs={'X': [param]},
                            outputs={'Out': [param]},
                            attrs={
                                'ring_id': ring_id,
                                OP_ROLE_KEY: OpRole.Optimize
                            })

                        ordered_param_snapshot.append((param, snapshot))

                for ring_id in range(self.nrings):
                    main_block.append_op(
                        type='c_sync_comm_stream',
                        inputs={'X': param},
                        outputs={'Out': param},
                        attrs={
                            'ring_id': ring_id,
                            OP_ROLE_KEY: OpRole.Optimize
                        })

                for param_snapshot in reversed(ordered_param_snapshot):
                    param = param_snapshot[0]
                    snapshot = param_snapshot[1]
                    main_block.append_op(
                        type='scale',
                        inputs={'X': [param]},
                        outputs={'Out': [param]},
                        attrs={
                            'scale': 1.0 / self.role_maker.worker_num(),
                            OP_ROLE_KEY: OpRole.Optimize
                        })
                    main_block.append_op(
                        type='elementwise_sub',
                        inputs={'X': [snapshot],
                                'Y': [param]},
                        outputs={'Out': [param]},
                        attrs={OP_ROLE_KEY: OpRole.Optimize})
                    main_block.append_op(
                        type='assign',
                        inputs={'X': [param]},
                        outputs={'Out': [snapshot]},
                        attrs={OP_ROLE_KEY: OpRole.Optimize})

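                # Adaptive interval (strategy.auto): scale the next k_steps as
                # ceil(sqrt(lr_0 * loss / (lr * loss_0) * init_k_steps)),
                # capped at 16 local steps.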
                if auto_steps:
                    next_local_steps = layers.cast(
                        layers.ceil(
                            layers.sqrt(lr_0 * loss / (global_lr * loss_0) *
                                        float(init_k_steps))),
                        dtype='int64')
                    max_local_steps = layers.fill_constant(
                        shape=[1], dtype='int64', value=16)
                    next_local_steps = layers.elementwise_min(next_local_steps,
                                                              max_local_steps)
                    layers.assign(next_local_steps, k_steps)

                layers.assign(step, last_step)

            layers.cond(step - last_step == k_steps, communicate)

        return minimized
@@ -0,0 +1,55 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import paddle
import os

import paddle.fleet as fleet
import paddle.fluid.incubate.fleet.base.role_maker as role_maker


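# Checks that the LocalSGD meta optimizer can be enabled via
# DistributedStrategy and applied to a simple two-layer network.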
class TestFleetLocalSGDMetaOptimizer(unittest.TestCase):
    def setUp(self):
        os.environ["PADDLE_TRAINER_ID"] = "1"
        os.environ[
            "PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.1:36002"

    def test_localsgd_optimizer(self):
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        input_x = paddle.fluid.layers.data(
            name="x", shape=[32], dtype='float32')
        input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64')

        fc = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
        prediction = paddle.fluid.layers.fc(input=[fc], size=2, act='softmax')
        cost = paddle.fluid.layers.cross_entropy(
            input=prediction, label=input_y)
        avg_cost = paddle.fluid.layers.mean(x=cost)

        strategy = paddle.fleet.DistributedStrategy()
        strategy.localsgd = True
        strategy.auto = True
        config = strategy.localsgd_configs
        config['k_steps'] = 1
        strategy.localsgd_configs = config

        optimizer = paddle.optimizer.SGD(learning_rate=0.01)
        optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
        optimizer.minimize(avg_cost)


if __name__ == "__main__":
    unittest.main()