You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py

515 lines
23 KiB

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid import unique_name, core
import paddle.fluid as fluid
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_VAR_KEY, CollectiveHelper
from paddle.distributed.fleet.meta_optimizers.common import is_backward_op
from paddle.distributed.fleet.meta_optimizers.meta_optimizer_base import MetaOptimizerBase
from paddle.distributed.fleet.meta_optimizers.sharding.shard import Shard, ProgramSegment
from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils
from paddle.distributed.fleet.meta_optimizers.sharding.weight_decay_helper import WeightDecayHelper
from paddle.distributed.fleet.meta_optimizers.sharding.gradient_clip_helper import GradientClipHelper
from paddle.distributed.fleet.meta_optimizers.sharding.prune import ProgramDeps
from paddle.distributed.fleet.meta_optimizers.sharding.utils import *
import logging
from functools import reduce
__all__ = ["ShardingOptimizer"]
class ShardingOptimizer(MetaOptimizerBase):
def __init__(self, optimizer):
super(ShardingOptimizer, self).__init__(optimizer)
self.inner_opt = optimizer
self.meta_optimizers_white_list = [
"RecomputeOptimizer",
"AMPOptimizer",
"LarsOptimizer",
"LambOptimizer",
]
self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
self._main_program = None
self._startup_program = None
self._segments = []
# params and fp16 params is for broadcast
self._params = set([])
self._broadcast_vars = set([])
# reduced grads to param name
self._reduced_grads_to_param = {}
self._shard = Shard()
def _can_apply(self):
if not self.role_maker._is_collective:
return False
if self.role_maker._worker_num() <= 1:
return False
return self.user_defined_strategy.sharding
def _disable_strategy(self, dist_strategy):
dist_strategy.sharding = False
dist_strategy.sharding_configs = {}
def _enable_strategy(self, dist_strategy, context):
dist_strategy.sharding = True
dist_strategy.sharding_configs = {"fuse_broadcast_MB": 32}
def minimize_impl(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None):
# TODO: (JZ-LIANG) support multiple comm in future
# self._nrings = self.user_defined_strategy.nccl_comm_num
self._nrings_sharding = 1
self._nrings_dp = 1
self._fuse_broadcast_MB = self.user_defined_strategy.sharding_configs[
"fuse_broadcast_MB"]
self.hybrid_dp = self.user_defined_strategy.sharding_configs[
"hybrid_dp"]
if self.inner_opt is None:
raise ValueError(
"self.inner_opt of ShardingOptimizer should not be None.")
optimize_ops, params_grads = self.inner_opt.minimize(
loss, startup_program, parameter_list, no_grad_set)
if startup_program is None:
startup_program = default_startup_program()
main_block = loss.block
startup_block = startup_program.global_block()
self._main_program = main_block.program
self._startup_program = startup_program
# step1: set_up
self._set_up(params_grads)
# step2: split_program
self._split_program(main_block)
# step3: add broadcast and reduce ops
self._add_broadcast_allreduce(main_block)
main_block._sync_with_cpp()
startup_block._sync_with_cpp()
# step4: insert reduce_sum for grad
insert_scale_loss_grad_ops(
main_block, scale=1.0 / self.role_maker._worker_num())
main_block._sync_with_cpp()
# step5: remove unneeded ops and vars from block
self._prune_main_program(main_block)
self._prune_startup_program(startup_block)
# check op dependecy
check_broadcast(main_block)
check_allreduce_sum(main_block, self._shard, self.dp_ring_id)
self._wait()
return optimize_ops, params_grads
def _set_up(self, params_grads):
# step 1: initialize nccl
self.global_word_size = self.role_maker._worker_num()
self.global_rank = self.role_maker._worker_index()
self.endpoints = self.role_maker._get_trainer_endpoints()
self.current_endpoint = self.endpoints[self.global_rank]
self._collective_helper = CollectiveHelper(self.role_maker,
self._nrings_sharding)
# config sharding & dp groups
self._init_comm()
# sharding
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
self.sharding_group_endpoints, self.sharding_rank,
self.sharding_ring_id, True)
# dp
if self.hybrid_dp:
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
self.dp_group_endpoints, self.dp_rank, self.dp_ring_id, True)
startup_block = self._startup_program.global_block()
startup_block._sync_with_cpp()
# step 2: split params
self._params = set([x[0].name for x in params_grads])
self._shard.setup(params_grads, self.sharding_rank,
self.sharding_group_size)
# step 3: get broadcast vars
self._broadcast_vars = self._shard.find_broadcast_params(
self._main_program.global_block())
def _wait(self, ):
endpoints = self.role_maker._get_trainer_endpoints()
current_endpoint = endpoints[self.role_maker._worker_index()]
if self.role_maker._worker_index() == 0:
self._collective_helper._wait(current_endpoint, endpoints)
def _split_program(self, block):
for op_idx, op in reversed(list(enumerate(block.ops))):
if int(op.attr('op_role')) != int(OpRole.Optimize):
last_backward_op_idx = op_idx + 1
break
segment = ProgramSegment(block)
segment._end_idx = last_backward_op_idx
for op_idx in reversed(range(last_backward_op_idx)):
op = block.ops[op_idx]
assert (int(op.attr('op_role')) != int(OpRole.Optimize))
if segment._param_mem >= self._fuse_broadcast_MB:
segment._start_idx = op_idx + 1
self._segments.insert(0, segment)
segment = ProgramSegment(block)
segment._end_idx = op_idx + 1
# find broadcast vars
for input_name in op.desc.input_arg_names():
if input_name not in self._broadcast_vars:
continue
if input_name in segment._param2broadcast:
# skip broadcast because it reuse the old broadcast var
broadcast_name = segment._param2broadcast[input_name]
if input_name != broadcast_name:
op._rename_input(input_name, broadcast_name)
continue
if self._shard.has_param(input_name):
broadcast_var_name = input_name
else:
broadcast_var_name = unique_name.generate(input_name +
"@BroadCast")
segment._fill_constant_vars.append(broadcast_var_name)
segment._param2broadcast[input_name] = broadcast_var_name
segment._broadcast_vars.append((broadcast_var_name,
self._shard.device(input_name)))
segment._param_mem += get_var_size(
self._main_program.global_block().var(input_name))
# find reduce vars
if is_backward_op(op) and \
OP_ROLE_VAR_KEY in op.attr_names:
op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
if len(op_role_var) != 0:
assert len(op_role_var) % 2 == 0
for i in range(0, len(op_role_var), 2):
param, reduced_grad = op_role_var[i], op_role_var[i + 1]
segment._allreduce_vars.append(reduced_grad)
assert (
reduced_grad not in self._reduced_grads_to_param)
self._reduced_grads_to_param[reduced_grad] = param
# find cast op
if FP16Utils.is_fp16_cast_op(block, op, self._params):
fp32_param = op.desc.input_arg_names()[0]
fp16_param = op.desc.output_arg_names()[0]
if self._shard.has_param(fp32_param):
segment._cast_ops[fp16_param] = fp32_param
if segment._param_mem > 0:
segment._start_idx = 0
self._segments.insert(0, segment)
return
def _prune_main_program(self, block):
"""
calculate deps from allredce op to optimize op,
remove ops and vars not needed in this worker
1. prune regularization (weight decay)
2. prune cast_fp32_to_fp16; update amp_infine_checking
3. prune gradient_clip related; update global_norm_sum
4. prune optimizer op + param + gradient
"""
weightdecay_helper = WeightDecayHelper()
weightdecay_helper.prune_weight_decay(block, self._shard)
FP16Utils.prune_fp16(block, self._shard, self._reduced_grads_to_param,
self.sharding_ring_id)
gradientclip_helper = GradientClipHelper(self.sharding_ring_id)
gradientclip_helper.prune_gradient_clip(block, self._shard)
# build prog deps
reduced_grads = []
for idx, op in enumerate(block.ops):
input_names = op.desc.input_arg_names()
output_names = op.desc.output_arg_names()
if op.type == "c_allreduce_sum":
assert (len(output_names) == 1)
output_name = output_names[0]
reduced_grads.append(output_name)
# prune optimizer state and param
pruned_opti_vars = []
for var_name in list(block.vars.keys()):
if self._shard.is_opti_var(var_name) and \
not self._shard.has_opt_var(var_name):
pruned_opti_vars.append(var_name)
program_deps = ProgramDeps(block, reduced_grads, pruned_opti_vars)
# Init
for var_name in program_deps._end_vars:
program_deps._should_removed_var.add(var_name)
# Prune
for idx, op in reversed(list(enumerate(block.ops))):
if op.type in [
"c_allreduce_sum", "c_sync_comm_stream",
"c_calc_comm_stream", "c_gen_nccl_id", "c_comm_init"
]:
pass
elif op.type == "conditional_block":
assert (op.desc.has_attr("sub_block"))
subblock_idx = op.desc.attr("sub_block").id
subblock_deps = program_deps.get_sub_block_deps(subblock_idx)
# only prune amp subblock
if subblock_deps is None or not self._is_amp_subblock(op):
continue
# init
reversed_output_vars = []
for output_name in op.desc.output("Out"):
if output_name in program_deps._should_removed_var:
subblock_deps._should_removed_var.add(output_name)
program_deps.crop_output_var_from_op(idx, output_name)
else:
reversed_output_vars.append(output_name)
# prune
for sub_op_idx, _ in reversed(
list(enumerate(subblock_deps._block.ops))):
if subblock_deps.should_remove_op(sub_op_idx):
subblock_deps.remove_op(sub_op_idx)
reversed_input_vars = []
for input_name in op.desc.input('Input'):
if input_name not in subblock_deps._should_removed_var:
reversed_input_vars.append(input_name)
else:
program_deps.crop_input_var_from_op(idx, input_name)
op.desc.set_input('Input', reversed_input_vars)
op.desc.set_output('Out', reversed_output_vars)
else:
# if all outputs of this op are in _should_removed_var
# _should_removed_var: opt state not cur shard
if program_deps.should_remove_op(idx):
program_deps.remove_op(idx)
block._sync_with_cpp()
return
def _add_broadcast_allreduce(self, block):
"""
_add_broadcast_allreduce
"""
if len(self._segments) < 1:
return
# sharding
if self._segments[-1]._allreduce_vars:
shard_allredue_vars = self._shard.filter_grads(self._segments[-1]
._allreduce_vars)
if self.hybrid_dp and len(shard_allredue_vars) >= 1:
insert_sync_comm_ops(block, self._segments[-1]._end_idx,
self.dp_ring_id, shard_allredue_vars)
insert_allreduce_ops(block, self._segments[-1]._end_idx,
self.dp_ring_id, shard_allredue_vars)
insert_sync_comm_ops(block, self._segments[-1]._end_idx,
self.sharding_ring_id,
self._segments[-1]._allreduce_vars)
insert_allreduce_ops(block, self._segments[-1]._end_idx,
self.sharding_ring_id,
self._segments[-1]._allreduce_vars)
for idx, segment in reversed(list(enumerate(self._segments))):
allreduce_vars = self._segments[
idx - 1]._allreduce_vars if idx > 0 else []
broadcast_vars = self._segments[idx +
1]._broadcast_vars if idx < len(
self._segments) - 1 else []
fill_constant_vars = self._segments[
idx + 2]._fill_constant_vars if idx < len(
self._segments) - 2 else []
cast_ops = self._segments[idx + 2]._cast_ops if idx < len(
self._segments) - 2 else {}
for op_idx in reversed(range(segment._start_idx, segment._end_idx)):
op = block.ops[op_idx]
for input_name in op.desc.input_arg_names():
if input_name in segment._param2broadcast and \
input_name != segment._param2broadcast[input_name]:
op._rename_input(input_name,
segment._param2broadcast[input_name])
for param_name, broadcast_name in segment._param2broadcast.items():
if param_name != broadcast_name:
block.create_var(
name=broadcast_name,
shape=self._main_program.global_block().var(
param_name).shape,
dtype=self._main_program.global_block().var(param_name)
.dtype,
persistable=False)
# step1: remove cast ops
block._sync_with_cpp()
segment._end_idx += FP16Utils.remove_cast_op(block, self._params,
segment, 0)
# step2: add Sync ops
shard_allredue_vars = self._shard.filter_grads(allreduce_vars)
if self.hybrid_dp and len(shard_allredue_vars) >= 1:
insert_sync_comm_ops(block, segment._end_idx, self.dp_ring_id,
shard_allredue_vars)
broad_cast_vars = [x[0] for x in broadcast_vars]
if len(broad_cast_vars) > 0:
insert_sync_comm_ops(block, segment._end_idx,
self.sharding_ring_id, broad_cast_vars)
else:
comm_dep_vars = allreduce_vars + [x[0] for x in broadcast_vars]
if len(comm_dep_vars) > 0:
insert_sync_comm_ops(block, segment._end_idx,
self.sharding_ring_id, comm_dep_vars)
calc_dep_vars = fill_constant_vars + [
k for k, v in cast_ops.items()
] + self._segments[idx]._allreduce_vars
if len(calc_dep_vars) > 0:
insert_sync_calc_op(block, segment._end_idx,
[calc_dep_vars[-1]])
# step3: insert `fill_constant` ops
insert_fill_constant_ops(block, segment._end_idx,
fill_constant_vars)
# step4: add `cast` ops
insert_cast_ops(block, segment._end_idx, cast_ops)
# step5: add broadcast ops
insert_broadcast_ops(block, segment._start_idx,
self.sharding_ring_id, broadcast_vars)
# step6: add all_reduce ops
# dp
if self.hybrid_dp and len(shard_allredue_vars) >= 1:
insert_allreduce_ops(block, segment._start_idx, self.dp_ring_id,
shard_allredue_vars)
insert_sync_comm_ops(block, segment._start_idx,
self.sharding_ring_id, allreduce_vars)
# sharding
insert_allreduce_ops(block, segment._start_idx,
self.sharding_ring_id, allreduce_vars)
block._sync_with_cpp()
if self._segments[0]._broadcast_vars:
broadcast_vars = [x[0] for x in self._segments[0]._broadcast_vars]
insert_sync_comm_ops(block, self._segments[0]._start_idx,
self.sharding_ring_id, broadcast_vars)
insert_broadcast_ops(block, self._segments[0]._start_idx,
self.sharding_ring_id,
self._segments[0]._broadcast_vars)
fill_constant_vars = []
for x in self._segments[:2]:
fill_constant_vars += x._fill_constant_vars
# Join
cast_ops = {}
for x in self._segments[:2]:
for k, v in x._cast_ops.items():
cast_ops[k] = v
calc_deps_vars = fill_constant_vars + [k for k, v in cast_ops.items()]
if fill_constant_vars or cast_ops:
insert_sync_calc_op(block, self._segments[0]._start_idx,
[calc_deps_vars[-1]])
if fill_constant_vars:
insert_fill_constant_ops(block, self._segments[0]._start_idx,
fill_constant_vars)
if cast_ops:
insert_cast_ops(block, self._segments[0]._start_idx, cast_ops)
return
def _prune_startup_program(self, block):
for idx, op in reversed(list(enumerate(block.ops))):
for output_name in op.desc.output_arg_names():
if self._shard.has_var(output_name):
continue
#TODO why do we remove op, when only one var is removed
block._remove_op(idx, sync=False)
break
for var_name in list(block.vars.keys()):
if self._shard.has_var(var_name):
continue
block._remove_var(var_name, sync=False)
block._sync_with_cpp()
def _init_comm(self):
if self.hybrid_dp:
self.sharding_group_size = self.user_defined_strategy.sharding_configs[
"sharding_group_size"]
self.sharding_ring_id = 0
self.sharding_rank = self.global_rank % self.sharding_group_size
self.dp_group_size = self.global_word_size // self.sharding_group_size
self.dp_rank = self.global_rank // self.sharding_group_size
self.dp_ring_id = self.sharding_rank + 1
self.sharding_group_endpoints = [
ep for idx, ep in enumerate(self.endpoints)
if (idx // self.sharding_group_size) == self.dp_rank
]
self.dp_group_endpoints = [
ep for idx, ep in enumerate(self.endpoints)
if (idx % self.sharding_group_size) == self.sharding_rank
]
assert self.global_word_size > self.sharding_group_size, \
"global_word_size: {} should be larger than sharding_group_size: {}".format(self.global_word_size, self.sharding_group_size)
assert self.global_word_size % self.sharding_group_size == 0, \
"global_word_size: {} should be divisible to the sharding_group_size: {}".format(self.global_word_size, self.sharding_group_size)
assert self.dp_group_size * self.sharding_group_size == self.global_word_size, \
"global_word_size: {} should be equal to the product of sharding_group_size: {} and dp_group_size: {}".format(
self.global_word_size,
self.sharding_group_size,
self.dp_group_size)
logging.info("Using Sharing&DP mode !")
else:
self.sharding_ring_id = 0
self.sharding_rank = self.global_rank
self.sharding_group_size = self.role_maker._worker_num()
self.sharding_group_endpoints = self.endpoints
self.dp_ring_id = -1
self.dp_rank = -1
self.dp_group_size = None
self.dp_group_endpoints = None
logging.info("Using Sharing alone mode !")
logging.info("global word size: {}".format(self.global_word_size))
logging.info("global rank: {}".format(self.global_rank))
logging.info("sharding group_size: {}".format(self.sharding_group_size))
logging.info("sharding rank: {}".format(self.sharding_rank))
logging.info("dp group size: {}".format(self.dp_group_size))
logging.info("dp rank: {}".format(self.dp_rank))
logging.info("current endpoint: {}".format(self.current_endpoint))
logging.info("sharding group endpoints: {}".format(
self.sharding_group_endpoints))
logging.info("dp group endpoints: {}".format(self.dp_group_endpoints))
logging.info("global word endpoints: {}".format(self.endpoints))
return