[Sharding] add hybrid-dp feature (#29518)

* Sharding add hybrid-dp feature

* update sharding in distributed_strategy

* update sharding unittest

* revise code format for sharding
Committed by JZ-LIANG (commit d33d468f02, parent 1e72e03217)

@@ -26,6 +26,8 @@ message RecomputeConfig { repeated string checkpoints = 1; }
 message ShardingConfig {
   optional float fuse_broadcast_MB = 1 [ default = 32.0 ];
+  optional bool hybrid_dp = 2 [ default = false ];
+  optional int32 sharding_group_size = 3 [ default = 8 ];
 }
 message AMPConfig {
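For context (not part of this diff): the two new ShardingConfig fields are set from the Python side through fleet's DistributedStrategy. A minimal sketch, assuming the usual collective fleet setup; the dictionary keys mirror the proto fields above:

import paddle.distributed.fleet as fleet

fleet.init(is_collective=True)

strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "fuse_broadcast_MB": 32,       # existing field: broadcast fusion threshold in MB
    "hybrid_dp": True,             # new: run data parallelism across sharding groups
    "sharding_group_size": 8,      # new: number of ranks inside one sharding group
}
# optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)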

@@ -71,7 +71,11 @@ class FP16Utils(object):
         return inserted_op_num
     @staticmethod
-    def prune_fp16(block, shard, reduced_grads_to_param, nrings):
+    def prune_fp16(block, shard, reduced_grads_to_param, ring_id):
+        """
+        1. prune all cast_fp32_to_fp16 ops if the param does not belong to this shard
+        2. revise amp infinite grad checking for sharding
+        """
         # remove cast
         for idx, op in reversed(list(enumerate(block.ops))):
             if not FP16Utils.is_fp32_cast_op(block, op):
@@ -79,9 +83,9 @@ class FP16Utils(object):
             output_name = op.desc.output_arg_names()[0]
             param_name = output_name.strip("@GRAD")
             if param_name not in shard.global_params:
-                raise ValueError("Input 'X' of check_finite_and_unscale must"
-                                 "be grads, but {} is not a grad".format(
-                                     input_name))
+                raise ValueError("Output 'X' of cast_op must be a grad of"
+                                 "model param, but {} is not a grad".format(
+                                     output_name))
             if output_name in reduced_grads_to_param:
                 continue
             if shard.has_param(param_name):
@@ -137,10 +141,12 @@ class FP16Utils(object):
             type='c_allreduce_max',
             inputs={'X': inf_var_fp32},
             outputs={'Out': inf_var_fp32},
-            attrs={'ring_id': 0,
-                   OP_ROLE_KEY: OpRole.Optimize})
-        comm_op_num = insert_sync_comm_ops(
-            block, update_loss_scaling_op_idx + 3, nrings, [inf_var_fp32])
+            attrs={'ring_id': ring_id,
+                   OP_ROLE_KEY: OpRole.Optimize})
+        comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3,
+                                          ring_id, [inf_var_fp32])
         block._insert_op_without_sync(
             update_loss_scaling_op_idx + 3 + comm_op_num,
             type='cast',
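The hunks above change prune_fp16 to take a single ring_id (the sharding ring) instead of nrings, so the AMP infinite-value check is allreduced over just that ring. The cast-pruning rule itself is simple to state; an illustrative toy sketch (not the Paddle implementation, all names made up):

# keep an fp32->fp16 cast only if its parameter is owned by the local shard
cast_ops = [
    {"type": "cast", "param": "fc_0.w_0"},
    {"type": "cast", "param": "fc_1.w_0"},
]
owned_params = {"fc_0.w_0"}   # stand-in for shard.has_param()

kept = [op for op in cast_ops if op["param"] in owned_params]
print(kept)   # only the cast for fc_0.w_0 survives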

@@ -16,14 +16,19 @@ from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
 class GradientClipHelper(object):
-    def __init__(self):
-        pass
+    def __init__(self, sharding_ring_id):
+        self.sharding_ring_id = sharding_ring_id
     def _is_gradient_clip_op(self, op):
         return op.desc.has_attr("op_namescope") \
             and op.desc.attr("op_namescope").startswith("/gradient_clip")
     def prune_gradient_clip(self, block, shard):
+        """
+        prune gradient_clip related ops for params that do not belong to the current shard
+        prune: square, reduce_sum, elementwise_mul
+        keep: sum, sqrt, elementwise_max, elementwise_div
+        """
         deperated_vars = set()
         deperate_op_idx = set()
         for idx, op in enumerate(block.ops):
@@ -75,8 +80,10 @@ class GradientClipHelper(object):
             type='c_allreduce_sum',
             inputs={'X': sum_res},
             outputs={'Out': sum_res},
-            attrs={'ring_id': 0,
-                   OP_ROLE_KEY: OpRole.Optimize})
+            attrs={
+                'ring_id': self.sharding_ring_id,
+                OP_ROLE_KEY: OpRole.Optimize
+            })
         block._insert_op_without_sync(
             idx + 1,
             type='c_sync_calc_stream',
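With sharding, each rank only keeps the square/reduce_sum ops for its own parameters' gradients, so gradient clipping needs the c_allreduce_sum above over the sharding ring to rebuild the global norm. A toy check of that decomposition (not Paddle code, values made up):

import math

shard_grads = {0: [3.0, 4.0], 1: [12.0]}   # hypothetical grads owned by 2 ranks

partial = {rank: sum(g * g for g in grads) for rank, grads in shard_grads.items()}
global_sq = sum(partial.values())          # what c_allreduce_sum yields on every rank
global_norm = math.sqrt(global_sq)

all_grads = [g for grads in shard_grads.values() for g in grads]
assert math.isclose(global_norm, math.sqrt(sum(g * g for g in all_grads)))
print(global_norm)   # 13.0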

@@ -43,6 +43,7 @@ class ProgramDeps(object):
         return None
     def _build_deps(self, ):
         for var_name in self._start_vars:
             self._var_to_use_op[var_name] = []
             self._var_to_generate_op[var_name] = []

@@ -124,6 +124,14 @@ class Shard(object):
                 return True
         return False
+    def filter_grads(self, grads):
+        grads_in_shard = []
+        for grad in grads:
+            param = grad.split("@")[0]
+            if self.has_param(param):
+                grads_in_shard.append(grad)
+        return grads_in_shard
 class ProgramSegment(object):
     def __init__(self, block):
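The new Shard.filter_grads keeps only the gradients whose parameter is assigned to the local shard, relying on Paddle's "<param>@GRAD" naming. A standalone toy mirroring that logic (not the Paddle class; the owned set is made up):

owned = {"fc_0.w_0", "fc_0.b_0"}           # stand-in for shard.has_param()

def filter_grads(grads):
    return [g for g in grads if g.split("@")[0] in owned]

print(filter_grads(["fc_0.w_0@GRAD", "fc_1.w_0@GRAD", "fc_0.b_0@GRAD"]))
# ['fc_0.w_0@GRAD', 'fc_0.b_0@GRAD']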
