[Sharding] add hybrid-dp feature (#29518)

* Sharding add hybrid-dp feature

* update sharding in distributed_strategy

* update sharding unittest

* revise code format for sharding
JZ-LIANG 4 years ago committed by GitHub
parent 1e72e03217
commit d33d468f02

@@ -26,6 +26,8 @@ message RecomputeConfig { repeated string checkpoints = 1; }
message ShardingConfig {
optional float fuse_broadcast_MB = 1 [ default = 32.0 ];
optional bool hybrid_dp = 2 [ default = false ];
optional int32 sharding_group_size = 3 [ default = 8 ];
}
message AMPConfig {
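For context, these proto fields surface through fleet's DistributedStrategy on the Python side. A minimal sketch of turning on hybrid-dp sharding, assuming the sharding_configs keys mirror the proto field names (the values here are illustrative, not recommendations):

```python
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.sharding = True
# keys mirror the ShardingConfig fields above; values are illustrative
strategy.sharding_configs = {
    "fuse_broadcast_MB": 32.0,
    "hybrid_dp": True,          # run data parallelism across sharding groups
    "sharding_group_size": 8,   # number of ranks inside one sharding group
}
# the strategy is then handed to fleet.distributed_optimizer(opt, strategy=strategy)
```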

@@ -71,7 +71,11 @@ class FP16Utils(object):
return inserted_op_num
@staticmethod
def prune_fp16(block, shard, reduced_grads_to_param, nrings):
def prune_fp16(block, shard, reduced_grads_to_param, ring_id):
"""
1. prune all cast_fp32_to_fp16 ops if the param does not belong to this shard
2. revise amp infinite grad checking for sharding
"""
# remove cast
for idx, op in reversed(list(enumerate(block.ops))):
if not FP16Utils.is_fp32_cast_op(block, op):
@@ -79,9 +83,9 @@ class FP16Utils(object):
output_name = op.desc.output_arg_names()[0]
param_name = output_name.strip("@GRAD")
if param_name not in shard.global_params:
raise ValueError("Input 'X' of check_finite_and_unscale must"
"be grads, but {} is not a grad".format(
input_name))
raise ValueError("Output 'X' of cast_op must be a grad of"
"model param, but {} is not a grad".format(
output_name))
if output_name in reduced_grads_to_param:
continue
if shard.has_param(param_name):
@@ -137,10 +141,12 @@ class FP16Utils(object):
type='c_allreduce_max',
inputs={'X': inf_var_fp32},
outputs={'Out': inf_var_fp32},
attrs={'ring_id': 0,
attrs={'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Optimize})
comm_op_num = insert_sync_comm_ops(
block, update_loss_scaling_op_idx + 3, nrings, [inf_var_fp32])
comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3,
ring_id, [inf_var_fp32])
block._insert_op_without_sync(
update_loss_scaling_op_idx + 3 + comm_op_num,
type='cast',
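Here insert_sync_comm_ops (which looped over all nrings) is replaced by a single-ring insert_sync_comm_op. The helper's exact body lives in the sharding utils module; below is only a rough sketch of what a single-ring version could look like, assuming it returns the number of inserted ops:

```python
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole

def insert_sync_comm_op(block, insert_idx, ring_id, comm_dep_vars):
    # sketch: wait on the communication stream of one ring before the
    # ops that consume comm_dep_vars are allowed to run
    block._insert_op_without_sync(
        insert_idx,
        type='c_sync_comm_stream',
        inputs={'X': comm_dep_vars},
        outputs={'Out': comm_dep_vars},
        attrs={'ring_id': ring_id,
               OP_ROLE_KEY: OpRole.Optimize})
    return 1  # number of ops inserted; callers offset later indices by this
```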

@@ -16,14 +16,19 @@ from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
class GradientClipHelper(object):
def __init__(self):
pass
def __init__(self, sharding_ring_id):
self.sharding_ring_id = sharding_ring_id
def _is_gradient_clip_op(self, op):
return op.desc.has_attr("op_namescope") \
and op.desc.attr("op_namescope").startswith("/gradient_clip")
def prune_gradient_clip(self, block, shard):
"""
prune gradient_clip related ops for params that do not belong to the current shard
prune: square, reduce_sum, elementwise_mul
keep: sum, sqrt, elementwise_max, elementwise_div
"""
deperated_vars = set()
deperate_op_idx = set()
for idx, op in enumerate(block.ops):
@@ -75,8 +80,10 @@ class GradientClipHelper(object):
type='c_allreduce_sum',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={'ring_id': 0,
OP_ROLE_KEY: OpRole.Optimize})
attrs={
'ring_id': self.sharding_ring_id,
OP_ROLE_KEY: OpRole.Optimize
})
block._insert_op_without_sync(
idx + 1,
type='c_sync_calc_stream',
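Passing sharding_ring_id through the constructor keeps prune_gradient_clip's signature unchanged while the global-norm allreduce now runs on the sharding group's communicator instead of the hard-coded ring 0. A hedged usage sketch (the ring id value and the surrounding objects are placeholders, not the actual wiring in this PR):

```python
# hypothetical wiring inside the sharding optimizer pass
grad_clip_helper = GradientClipHelper(sharding_ring_id=1)
grad_clip_helper.prune_gradient_clip(main_block, shard)
```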

@@ -43,6 +43,7 @@ class ProgramDeps(object):
return None
def _build_deps(self, ):
for var_name in self._start_vars:
self._var_to_use_op[var_name] = []
self._var_to_generate_op[var_name] = []

@@ -124,6 +124,14 @@ class Shard(object):
return True
return False
def filter_grads(self, grads):
grads_in_shard = []
for grad in grads:
param = grad.split("@")[0]
if self.has_param(param):
grads_in_shard.append(grad)
return grads_in_shard
class ProgramSegment(object):
def __init__(self, block):
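filter_grads keeps only the gradients whose parameter is owned by the current shard (has_param). A standalone illustration of the same filtering rule in plain Python (not the fleet API):

```python
def filter_grads_like(shard_params, grads):
    # keep a gradient only if the parameter name before "@GRAD"
    # belongs to this shard's parameter set
    return [g for g in grads if g.split("@")[0] in shard_params]

print(filter_grads_like({"fc_0.w_0"}, ["fc_0.w_0@GRAD", "fc_1.w_0@GRAD"]))
# -> ['fc_0.w_0@GRAD']
```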
