[Sharding] add hybrid-dp feature (#29518)

* Sharding add hybrid-dp feature

* update sharding in distributed_strategy

* update sharding unittest

* revise code format for sharding
JZ-LIANG committed by GitHub
parent 1e72e03217
commit d33d468f02

@@ -26,6 +26,8 @@ message RecomputeConfig { repeated string checkpoints = 1; }
 message ShardingConfig {
   optional float fuse_broadcast_MB = 1 [ default = 32.0 ];
+  optional bool hybrid_dp = 2 [ default = false ];
+  optional int32 sharding_group_size = 3 [ default = 8 ];
 }

 message AMPConfig {
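
The two new fields are the user-facing knobs for hybrid data parallelism: hybrid_dp turns on an extra data-parallel ring across sharding groups, and sharding_group_size sets how many ranks share one sharding group. A minimal sketch of how they might be set through fleet.DistributedStrategy, assuming the keys of sharding_configs mirror the proto fields above (values are illustrative):

import paddle.distributed.fleet as fleet

# Sketch only: assumes DistributedStrategy.sharding_configs accepts the
# ShardingConfig field names defined above; values are illustrative.
strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "fuse_broadcast_MB": 32,
    "hybrid_dp": True,           # enable data parallelism across sharding groups
    "sharding_group_size": 8,    # ranks inside one sharding group
}
# strategy is then handed to fleet.distributed_optimizer(...) as usual.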

@@ -71,7 +71,11 @@ class FP16Utils(object):
         return inserted_op_num

     @staticmethod
-    def prune_fp16(block, shard, reduced_grads_to_param, nrings):
+    def prune_fp16(block, shard, reduced_grads_to_param, ring_id):
+        """
+        1. prune all cast_fp32_to_fp16 ops if the param does not belong to this shard
+        2. revise amp infinite grad checking for sharding
+        """
         # remove cast
         for idx, op in reversed(list(enumerate(block.ops))):
             if not FP16Utils.is_fp32_cast_op(block, op):
@@ -79,9 +83,9 @@ class FP16Utils(object):
                 output_name = op.desc.output_arg_names()[0]
                 param_name = output_name.strip("@GRAD")
                 if param_name not in shard.global_params:
-                    raise ValueError("Input 'X' of check_finite_and_unscale must"
-                                     "be grads, but {} is not a grad".format(
-                                         input_name))
+                    raise ValueError("Output 'X' of cast_op must be a grad of "
+                                     "model param, but {} is not a grad".format(
+                                         output_name))
                 if output_name in reduced_grads_to_param:
                     continue
                 if shard.has_param(param_name):
@@ -137,10 +141,12 @@ class FP16Utils(object):
             type='c_allreduce_max',
             inputs={'X': inf_var_fp32},
             outputs={'Out': inf_var_fp32},
-            attrs={'ring_id': 0,
+            attrs={'ring_id': ring_id,
                    OP_ROLE_KEY: OpRole.Optimize})
-        comm_op_num = insert_sync_comm_ops(
-            block, update_loss_scaling_op_idx + 3, nrings, [inf_var_fp32])
+
+        comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3,
+                                          ring_id, [inf_var_fp32])
         block._insert_op_without_sync(
             update_loss_scaling_op_idx + 3 + comm_op_num,
             type='cast',

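The second point of the prune_fp16 docstring is why the found_inf flag must now be c_allreduce_max'd over the sharding ring: after pruning, every rank only checks the grads of its own shard, so a rank-local answer is not the global one. A small numpy sketch of that reduction (ranks and values are made up):

import numpy as np

# Two ranks, each holding the found_inf flag for its own shard of grads.
local_found_inf = [np.array(0.0),   # rank 0: all of its grads are finite
                   np.array(1.0)]   # rank 1: saw an inf/nan grad

# c_allreduce_max across the sharding ring gives every rank the global flag.
global_found_inf = np.maximum.reduce(local_found_inf)
print(bool(global_found_inf))  # True -> loss scaling is updated on all ranks
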
@@ -16,14 +16,19 @@ from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole

 class GradientClipHelper(object):
-    def __init__(self):
-        pass
+    def __init__(self, sharding_ring_id):
+        self.sharding_ring_id = sharding_ring_id

     def _is_gradient_clip_op(self, op):
         return op.desc.has_attr("op_namescope") \
             and op.desc.attr("op_namescope").startswith("/gradient_clip")

     def prune_gradient_clip(self, block, shard):
+        """
+        prune gradient_clip related ops for params that do not belong to cur shard
+        prune: square, reduce_sum, elementwise_mul
+        keep: sum, sqrt, elementwise_max, elementwise_div
+        """
         deperated_vars = set()
         deperate_op_idx = set()
         for idx, op in enumerate(block.ops):
@@ -75,8 +80,10 @@ class GradientClipHelper(object):
             type='c_allreduce_sum',
             inputs={'X': sum_res},
             outputs={'Out': sum_res},
-            attrs={'ring_id': 0,
-                   OP_ROLE_KEY: OpRole.Optimize})
+            attrs={
+                'ring_id': self.sharding_ring_id,
+                OP_ROLE_KEY: OpRole.Optimize
+            })
         block._insert_op_without_sync(
             idx + 1,
             type='c_sync_calc_stream',

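The pruning rule above relies on the global norm decomposing into per-shard partial sums: each rank keeps square/reduce_sum only for the params it owns, and the c_allreduce_sum inserted over the sharding ring restores the full sum before sqrt. A small numpy check of that identity (the two-way shard split is made up):

import numpy as np

np.random.seed(0)
grads = [np.random.randn(4) for _ in range(6)]

# Reference: global norm computed over all grads on one worker.
global_norm = np.sqrt(sum(np.sum(g * g) for g in grads))

# Sharded: each of two ranks squares and sums only its own params, then the
# partial sums are combined by c_allreduce_sum over the sharding ring.
partial_sums = [
    sum(np.sum(g * g) for g in grads[:3]),   # rank 0's shard
    sum(np.sum(g * g) for g in grads[3:]),   # rank 1's shard
]
allreduced = sum(partial_sums)
assert np.isclose(np.sqrt(allreduced), global_norm)
print(global_norm, np.sqrt(allreduced))
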
@@ -43,6 +43,7 @@ class ProgramDeps(object):
         return None

     def _build_deps(self, ):
+
         for var_name in self._start_vars:
             self._var_to_use_op[var_name] = []
             self._var_to_generate_op[var_name] = []

@@ -124,6 +124,14 @@ class Shard(object):
             return True
         return False

+    def filter_grads(self, grads):
+        grads_in_shard = []
+        for grad in grads:
+            param = grad.split("@")[0]
+            if self.has_param(param):
+                grads_in_shard.append(grad)
+        return grads_in_shard
+

 class ProgramSegment(object):
     def __init__(self, block):

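filter_grads keeps only the grads whose parameter is placed on the current shard, using the usual "@GRAD" naming convention. A tiny stand-alone sketch of the same filter (names are hypothetical):

grads = ["fc_0.w_0@GRAD", "fc_0.b_0@GRAD", "fc_1.w_0@GRAD"]
owned_params = {"fc_0.w_0", "fc_0.b_0"}   # params this shard owns

grads_in_shard = [g for g in grads if g.split("@")[0] in owned_params]
print(grads_in_shard)  # ['fc_0.w_0@GRAD', 'fc_0.b_0@GRAD']
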
@@ -78,52 +78,137 @@ def check_broadcast(block):
     return


-def check_allreduce_sum(block):
+def check_allreduce_sum(block, shard, dp_ring_id=-1):
     """
-    if a Var is allreduced, the op order should be:
+    the op order should be:
+        grad:
         - 0: op that generate Var
         - 1: sync_calc
-        - 2: allreduce_sum op
+        - 2: allreduce_sum_sharding
         - 3: sync_comm
-        - 4: op that use Var
+        - 4: allreduce_sum_dp (dp_grads)
+        - 5: sync_comm (dp_grads)
+        - 6: op that use Var (dp_grads & sum)
     """
-    var_status = {}
-    for op in block.ops:
+    vars_status = {}
+    dp_grads_status = {}
+    idx_last_grad_allreduce = -1
+    idx_amp_allreduce = -1
+    idx_gradient_clip_allreduce = -1
+    for idx, op in enumerate(block.ops):
         if op.type == "c_allreduce_sum":
+            ring_id = op.desc.attr("ring_id")
             var_name = op.desc.input_arg_names()[0]
-            var_status[var_name] = -1
+            param = var_name.split("@")[0]
+            assert 'sum' in var_name or ("@GRAD" in var_name)
+            if 'sum' in var_name or (not shard.has_param(param)):
+                vars_status[var_name] = -1
+            else:
+                dp_grads_status[var_name] = -1
+            if ring_id != 0:
+                assert shard.has_param(param)
+                assert ring_id == dp_ring_id
+            if "sum" in var_name:
+                idx_amp_allreduce = idx
+            elif "@GRAD" in var_name:
+                idx_last_grad_allreduce = idx
+        if op.type == "c_allreduce_max":
+            idx_gradient_clip_allreduce = idx

     for op in block.ops:
         if op.type == "c_sync_calc_stream":
-            for var_name in var_status:
-                if var_name in var_status and var_status[var_name] == 0:
-                    var_status[var_name] = 1
+            for var_name in vars_status:
+                if var_name in vars_status and vars_status[var_name] == 0:
+                    vars_status[var_name] = 1
+            for var_name in dp_grads_status:
+                if var_name in dp_grads_status and dp_grads_status[
+                        var_name] == 0:
+                    dp_grads_status[var_name] = 1
         elif op.type == "c_allreduce_sum":
             var_name = op.desc.input_arg_names()[0]
-            if var_status[var_name] == -1:
-                raise ValueError("{} is not generated, but you are"
-                                 "trying to all-reduce it".format(var_name))
-            if var_status[var_name] == 0:
-                raise ValueError("There should be a sync_calc op "
-                                 "after generate Var: {} and before the"
-                                 "c_allreduce_sum op".format(var_name))
-            assert (var_status[var_name] == 1)
-            var_status[var_name] = 2
+            ring_id = op.desc.attr("ring_id")
+            if ring_id == 0:
+                if var_name in vars_status:
+                    _status = vars_status[var_name]
+                else:
+                    _status = dp_grads_status[var_name]
+                if _status == -1:
+                    raise ValueError("{} is not generated, but you are"
+                                     "trying to all-reduce it".format(var_name))
+                if _status == 0:
+                    raise ValueError("There should be a sync_calc op "
+                                     "after generate Var: {} and before the"
+                                     "c_allreduce_sum op".format(var_name))
+                assert (_status == 1)
+                if var_name in vars_status:
+                    vars_status[var_name] = 2
+                else:
+                    dp_grads_status[var_name] = 2
+            else:
+                assert ring_id == dp_ring_id
+                param = var_name.split("@")[0]
+                assert shard.has_param(param)
+                assert dp_grads_status[var_name] == 3
+                dp_grads_status[var_name] = 4
         elif op.type == "c_sync_comm_stream":
-            for var_name in op.desc.input_arg_names():
-                if var_name in var_status and var_status[var_name] == 2:
-                    var_status[var_name] = 3
+            var_name = op.desc.input_arg_names()[0]
+            ring_id = op.desc.attr("ring_id")
+            if ring_id == 0:
+                for var_name in op.desc.input_arg_names():
+                    if var_name in vars_status:
+                        assert vars_status[var_name] == 2
+                        vars_status[var_name] = 3
+                    elif var_name in dp_grads_status:
+                        assert dp_grads_status[var_name] == 2
+                        dp_grads_status[var_name] = 3
+            else:
+                for var_name in op.desc.input_arg_names():
+                    param = var_name.split("@")[0]
+                    assert ring_id == dp_ring_id
+                    assert shard.has_param(param)
+                    assert dp_grads_status[var_name] == 4
+                    dp_grads_status[var_name] = 5
         else:
             for input_name in op.desc.input_arg_names():
-                if input_name in var_status:
-                    if var_status[input_name] != 3:
+                if input_name in vars_status:
+                    if vars_status[input_name] != 3:
                         raise ValueError("There should be a sync_comm op "
                                          "after allreduce the Var: {}".format(
-                                             var_name))
+                                             input_name))
+                if input_name in dp_grads_status:
+                    if dp_ring_id == -1:
+                        if dp_grads_status[input_name] != 3:
+                            raise ValueError("There should be a sync_comm op "
                                             "after allreduce the Var: {}".
+                                             format(input_name))
+                    else:
+                        if dp_grads_status[input_name] != 5:
+                            raise ValueError(
+                                "The grad in shard should be allreduced and "
+                                "synced twice before usage {}".format(input_name))
             for output_name in op.desc.output_arg_names():
-                if output_name in var_status and \
-                    var_status[output_name] == -1:
-                    var_status[output_name] = 0
+                if output_name in vars_status and \
+                    vars_status[output_name] == -1:
+                    vars_status[output_name] = 0
+                if output_name in dp_grads_status and \
+                    dp_grads_status[output_name] == -1:
+                    dp_grads_status[output_name] = 0
+
+    # check sharding with amp
+    if idx_amp_allreduce != -1:
+        assert idx_amp_allreduce > idx_last_grad_allreduce
+
+    # check sharding with gradient_clip_by_global_norm
+    if idx_gradient_clip_allreduce != -1:
+        assert idx_gradient_clip_allreduce > idx_last_grad_allreduce
     return
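
For a grad that belongs to this shard under hybrid-dp, the checker therefore expects the status to walk 0 -> 1 -> 2 -> 3 -> 4 -> 5 before any consumer op. The toy walk-through below replays that state machine with plain tuples in place of real framework ops (names are illustrative; ring 0 is the sharding ring, the other ring is the dp ring):

# (op_type, ring) pairs sketching the expected pipeline for one dp grad.
expected_ops = [
    ("generate", None),                  # status -1 -> 0
    ("c_sync_calc_stream", None),        # 0 -> 1
    ("c_allreduce_sum", "sharding"),     # 1 -> 2
    ("c_sync_comm_stream", "sharding"),  # 2 -> 3
    ("c_allreduce_sum", "dp"),           # 3 -> 4
    ("c_sync_comm_stream", "dp"),        # 4 -> 5
    ("use", None),                       # consumer: requires status 5
]

status = -1
for op_type, ring in expected_ops:
    if op_type == "generate":
        status = 0
    elif op_type == "c_sync_calc_stream":
        assert status == 0
        status = 1
    elif op_type == "c_allreduce_sum":
        assert status == (1 if ring == "sharding" else 3)
        status = 2 if ring == "sharding" else 4
    elif op_type == "c_sync_comm_stream":
        assert status == (2 if ring == "sharding" else 4)
        status = 3 if ring == "sharding" else 5
    else:  # the op that uses the grad
        assert status == 5
print("dp grad is ready at status", status)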
@@ -155,20 +240,34 @@ def insert_sync_calc_op(block, insert_idx, calc_dep_vars):
     return


-def insert_sync_comm_ops(block, insert_idx, nrings, comm_dep_vars):
+def insert_sync_comm_op(block, insert_idx, ring_id, comm_dep_vars):
     """
-    _insert_sync_comm_ops
+    insert sync_comm_op for single var
     """
     op_role = get_valid_op_role(block, insert_idx)
-    for i in range(nrings):
-        block._insert_op_without_sync(
-            insert_idx,
-            type='c_sync_comm_stream',
-            inputs={'X': comm_dep_vars},
-            outputs={'Out': comm_dep_vars},
-            attrs={'ring_id': i,
-                   OP_ROLE_KEY: op_role})
-    return nrings
+    block._insert_op_without_sync(
+        insert_idx,
+        type='c_sync_comm_stream',
+        inputs={'X': comm_dep_vars},
+        outputs={'Out': comm_dep_vars},
+        attrs={'ring_id': ring_id,
+               OP_ROLE_KEY: op_role})
+    return 1
+
+
+def insert_sync_comm_ops(block, insert_idx, ring_id, comm_dep_vars):
+    """
+    insert sync_comm_op for vars
+    """
+    op_role = get_valid_op_role(block, insert_idx)
+    block._insert_op_without_sync(
+        insert_idx,
+        type='c_sync_comm_stream',
+        inputs={'X': comm_dep_vars},
+        outputs={'Out': comm_dep_vars},
+        attrs={'ring_id': int(ring_id),
+               OP_ROLE_KEY: op_role})
+    return 1


 def insert_fill_constant_ops(block, insert_idx, fill_constant_vars):
@@ -210,13 +309,11 @@ def insert_cast_ops(block, insert_idx, cast_ops):
     return


-def insert_allreduce_ops(block, insert_idx, nrings, allreduce_vars):
+def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars):
     """
     _add_allreduce_ops
     """
-    ring_id = -1
     for var in allreduce_vars:
-        ring_id = (ring_id + 1) % nrings
         block._insert_op_without_sync(
             insert_idx,
             type='c_allreduce_sum',
@@ -224,17 +321,16 @@ def insert_allreduce_ops(block, insert_idx, nrings, allreduce_vars):
             outputs={'Out': var},
             attrs={'ring_id': ring_id,
                    OP_ROLE_KEY: OpRole.Backward})
+
     return


-def insert_broadcast_ops(block, insert_idx, nrings, broadcast2root):
+def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root):
     """
     _add_broadcast_ops
     """
-    ring_id = -1
     op_role = get_valid_op_role(block, insert_idx)
     for broadcast_name, root_device in broadcast2root:
-        ring_id = (ring_id + 1) % nrings
         block._insert_op_without_sync(
             insert_idx,
             type='c_broadcast',
@@ -245,6 +341,7 @@ def insert_broadcast_ops(block, insert_idx, nrings, broadcast2root):
                 'root': root_device,
                 OP_ROLE_KEY: op_role
             })
+
     return
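
A design note on the signature change: the old helpers round-robined vars across nrings communication rings, while the new ones take one explicit ring_id per call, so the sharding ring and the dp ring stay separated and check_allreduce_sum can tell the two phases apart by ring_id alone. A toy comparison of the two assignment schemes (pure Python, ring ids are illustrative):

vars_to_allreduce = ["g0", "g1", "g2", "g3"]

# Old behaviour: spread vars across nrings communication rings.
nrings = 2
old_assignment = [(v, i % nrings) for i, v in enumerate(vars_to_allreduce)]

# New behaviour: the caller picks one ring per purpose, e.g. ring 0 for the
# sharding group and ring 1 for the data-parallel group.
sharding_ring_id, dp_ring_id = 0, 1
new_assignment = [(v, sharding_ring_id) for v in vars_to_allreduce]

print(old_assignment)  # [('g0', 0), ('g1', 1), ('g2', 0), ('g3', 1)]
print(new_assignment)  # every var goes to the ring of its communication group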
