|
|
@ -71,7 +71,11 @@ class FP16Utils(object):
|
|
|
|
return inserted_op_num
|
|
|
|
return inserted_op_num
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def prune_fp16(block, shard, reduced_grads_to_param, nrings):
|
|
|
|
def prune_fp16(block, shard, reduced_grads_to_param, ring_id):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
1. prune all cast_fp32_to_fp16 ops if the param not belongs to this shard
|
|
|
|
|
|
|
|
2. revise amp inifine grad checking for sharding
|
|
|
|
|
|
|
|
"""
|
|
|
|
# remove cast
|
|
|
|
# remove cast
|
|
|
|
for idx, op in reversed(list(enumerate(block.ops))):
|
|
|
|
for idx, op in reversed(list(enumerate(block.ops))):
|
|
|
|
if not FP16Utils.is_fp32_cast_op(block, op):
|
|
|
|
if not FP16Utils.is_fp32_cast_op(block, op):
|
|
|
@ -79,9 +83,9 @@ class FP16Utils(object):
|
|
|
|
output_name = op.desc.output_arg_names()[0]
|
|
|
|
output_name = op.desc.output_arg_names()[0]
|
|
|
|
param_name = output_name.strip("@GRAD")
|
|
|
|
param_name = output_name.strip("@GRAD")
|
|
|
|
if param_name not in shard.global_params:
|
|
|
|
if param_name not in shard.global_params:
|
|
|
|
raise ValueError("Input 'X' of check_finite_and_unscale must"
|
|
|
|
raise ValueError("Output 'X' of cast_op must be a grad of"
|
|
|
|
"be grads, but {} is not a grad".format(
|
|
|
|
"model param, but {} is not a grad".format(
|
|
|
|
input_name))
|
|
|
|
output_name))
|
|
|
|
if output_name in reduced_grads_to_param:
|
|
|
|
if output_name in reduced_grads_to_param:
|
|
|
|
continue
|
|
|
|
continue
|
|
|
|
if shard.has_param(param_name):
|
|
|
|
if shard.has_param(param_name):
|
|
|
@ -137,10 +141,12 @@ class FP16Utils(object):
|
|
|
|
type='c_allreduce_max',
|
|
|
|
type='c_allreduce_max',
|
|
|
|
inputs={'X': inf_var_fp32},
|
|
|
|
inputs={'X': inf_var_fp32},
|
|
|
|
outputs={'Out': inf_var_fp32},
|
|
|
|
outputs={'Out': inf_var_fp32},
|
|
|
|
attrs={'ring_id': 0,
|
|
|
|
attrs={'ring_id': ring_id,
|
|
|
|
OP_ROLE_KEY: OpRole.Optimize})
|
|
|
|
OP_ROLE_KEY: OpRole.Optimize})
|
|
|
|
comm_op_num = insert_sync_comm_ops(
|
|
|
|
|
|
|
|
block, update_loss_scaling_op_idx + 3, nrings, [inf_var_fp32])
|
|
|
|
comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3,
|
|
|
|
|
|
|
|
ring_id, [inf_var_fp32])
|
|
|
|
|
|
|
|
|
|
|
|
block._insert_op_without_sync(
|
|
|
|
block._insert_op_without_sync(
|
|
|
|
update_loss_scaling_op_idx + 3 + comm_op_num,
|
|
|
|
update_loss_scaling_op_idx + 3 + comm_op_num,
|
|
|
|
type='cast',
|
|
|
|
type='cast',
|
|
|
|