|
|
@ -78,52 +78,137 @@ def check_broadcast(block):
|
|
|
|
return
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_allreduce_sum(block):
|
|
|
|
def check_allreduce_sum(block, shard, dp_ring_id=-1):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
if a Var is allreduced, the op order should be:
|
|
|
|
the op order should be:
|
|
|
|
|
|
|
|
grad:
|
|
|
|
- 0: op that generate Var
|
|
|
|
- 0: op that generate Var
|
|
|
|
- 1: sync_calc
|
|
|
|
- 1: sync_calc
|
|
|
|
- 2: allreduce_sum op
|
|
|
|
- 2: allreduce_sum_sharding
|
|
|
|
- 3: sync_comm
|
|
|
|
- 3: sync_comm
|
|
|
|
- 4: op that use Var
|
|
|
|
- 4: allreuce_sum_dp (dp_grads)
|
|
|
|
|
|
|
|
- 5: sync_comm (dp_grads)
|
|
|
|
|
|
|
|
- 6: op that use Var (dp_grads & sum)
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
var_status = {}
|
|
|
|
vars_status = {}
|
|
|
|
for op in block.ops:
|
|
|
|
dp_grads_status = {}
|
|
|
|
|
|
|
|
idx_last_grad_allreduce = -1
|
|
|
|
|
|
|
|
idx_amp_allreduce = -1
|
|
|
|
|
|
|
|
idx_gradient_clip_allreduce = -1
|
|
|
|
|
|
|
|
for idx, op in enumerate(block.ops):
|
|
|
|
if op.type == "c_allreduce_sum":
|
|
|
|
if op.type == "c_allreduce_sum":
|
|
|
|
|
|
|
|
ring_id = op.desc.attr("ring_id")
|
|
|
|
var_name = op.desc.input_arg_names()[0]
|
|
|
|
var_name = op.desc.input_arg_names()[0]
|
|
|
|
var_status[var_name] = -1
|
|
|
|
param = var_name.split("@")[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert 'sum' in var_name or ("@GRAD" in var_name)
|
|
|
|
|
|
|
|
if 'sum' in var_name or (not shard.has_param(param)):
|
|
|
|
|
|
|
|
vars_status[var_name] = -1
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
dp_grads_status[var_name] = -1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if ring_id != 0:
|
|
|
|
|
|
|
|
assert shard.has_param(param)
|
|
|
|
|
|
|
|
assert ring_id == dp_ring_id
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if "sum" in var_name:
|
|
|
|
|
|
|
|
idx_amp_allreduce = idx
|
|
|
|
|
|
|
|
elif "@GRAD":
|
|
|
|
|
|
|
|
idx_last_grad_allreduce = idx
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if op.type == "c_allreduce_max":
|
|
|
|
|
|
|
|
idx_gradient_clip_allreduce = idx
|
|
|
|
|
|
|
|
|
|
|
|
for op in block.ops:
|
|
|
|
for op in block.ops:
|
|
|
|
if op.type == "c_sync_calc_stream":
|
|
|
|
if op.type == "c_sync_calc_stream":
|
|
|
|
for var_name in var_status:
|
|
|
|
for var_name in vars_status:
|
|
|
|
if var_name in var_status and var_status[var_name] == 0:
|
|
|
|
if var_name in vars_status and vars_status[var_name] == 0:
|
|
|
|
var_status[var_name] = 1
|
|
|
|
vars_status[var_name] = 1
|
|
|
|
|
|
|
|
for var_name in dp_grads_status:
|
|
|
|
|
|
|
|
if var_name in dp_grads_status and dp_grads_status[
|
|
|
|
|
|
|
|
var_name] == 0:
|
|
|
|
|
|
|
|
dp_grads_status[var_name] = 1
|
|
|
|
|
|
|
|
|
|
|
|
elif op.type == "c_allreduce_sum":
|
|
|
|
elif op.type == "c_allreduce_sum":
|
|
|
|
var_name = op.desc.input_arg_names()[0]
|
|
|
|
var_name = op.desc.input_arg_names()[0]
|
|
|
|
if var_status[var_name] == -1:
|
|
|
|
ring_id = op.desc.attr("ring_id")
|
|
|
|
|
|
|
|
if ring_id == 0:
|
|
|
|
|
|
|
|
if var_name in vars_status:
|
|
|
|
|
|
|
|
_status = vars_status[var_name]
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
_status = dp_grads_status[var_name]
|
|
|
|
|
|
|
|
if _status == -1:
|
|
|
|
raise ValueError("{} is not generated, but you are"
|
|
|
|
raise ValueError("{} is not generated, but you are"
|
|
|
|
"trying to all-reduce it".format(var_name))
|
|
|
|
"trying to all-reduce it".format(var_name))
|
|
|
|
if var_status[var_name] == 0:
|
|
|
|
if _status == 0:
|
|
|
|
raise ValueError("There should be a sync_calc op "
|
|
|
|
raise ValueError("There should be a sync_calc op "
|
|
|
|
"after generate Var: {} and before the"
|
|
|
|
"after generate Var: {} and before the"
|
|
|
|
"c_allreduce_sum op".format(var_name))
|
|
|
|
"c_allreduce_sum op".format(var_name))
|
|
|
|
assert (var_status[var_name] == 1)
|
|
|
|
assert (_status == 1)
|
|
|
|
var_status[var_name] = 2
|
|
|
|
if var_name in vars_status:
|
|
|
|
|
|
|
|
vars_status[var_name] = 2
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
dp_grads_status[var_name] = 2
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
assert ring_id == dp_ring_id
|
|
|
|
|
|
|
|
param = var_name.split("@")[0]
|
|
|
|
|
|
|
|
assert shard.has_param(param)
|
|
|
|
|
|
|
|
assert dp_grads_status[var_name] == 3
|
|
|
|
|
|
|
|
dp_grads_status[var_name] = 4
|
|
|
|
|
|
|
|
|
|
|
|
elif op.type == "c_sync_comm_stream":
|
|
|
|
elif op.type == "c_sync_comm_stream":
|
|
|
|
|
|
|
|
var_name = op.desc.input_arg_names()[0]
|
|
|
|
|
|
|
|
ring_id = op.desc.attr("ring_id")
|
|
|
|
|
|
|
|
if ring_id == 0:
|
|
|
|
for var_name in op.desc.input_arg_names():
|
|
|
|
for var_name in op.desc.input_arg_names():
|
|
|
|
if var_name in var_status and var_status[var_name] == 2:
|
|
|
|
if var_name in vars_status:
|
|
|
|
var_status[var_name] = 3
|
|
|
|
assert vars_status[var_name] == 2
|
|
|
|
|
|
|
|
vars_status[var_name] = 3
|
|
|
|
|
|
|
|
elif var_name in dp_grads_status:
|
|
|
|
|
|
|
|
assert dp_grads_status[var_name] == 2
|
|
|
|
|
|
|
|
dp_grads_status[var_name] = 3
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
for var_name in op.desc.input_arg_names():
|
|
|
|
|
|
|
|
param = var_name.split("@")[0]
|
|
|
|
|
|
|
|
assert ring_id == dp_ring_id
|
|
|
|
|
|
|
|
assert shard.has_param(param)
|
|
|
|
|
|
|
|
assert dp_grads_status[var_name] == 4
|
|
|
|
|
|
|
|
dp_grads_status[var_name] = 5
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
for input_name in op.desc.input_arg_names():
|
|
|
|
for input_name in op.desc.input_arg_names():
|
|
|
|
if input_name in var_status:
|
|
|
|
if input_name in vars_status:
|
|
|
|
if var_status[input_name] != 3:
|
|
|
|
if vars_status[input_name] != 3:
|
|
|
|
raise ValueError("There should be a sync_comm op "
|
|
|
|
raise ValueError("There should be a sync_comm op "
|
|
|
|
"after allreduce the Var: {}".format(
|
|
|
|
"after allreduce the Var: {}".format(
|
|
|
|
var_name))
|
|
|
|
input_name))
|
|
|
|
|
|
|
|
if input_name in dp_grads_status:
|
|
|
|
|
|
|
|
if dp_ring_id == -1:
|
|
|
|
|
|
|
|
if dp_grads_status[input_name] != 3:
|
|
|
|
|
|
|
|
raise ValueError("There should be a sync_comm op "
|
|
|
|
|
|
|
|
"after allreduce the Var: {}".
|
|
|
|
|
|
|
|
format(input_name))
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
if dp_grads_status[input_name] != 5:
|
|
|
|
|
|
|
|
raise ValueError(
|
|
|
|
|
|
|
|
"The grad in shard should be allreduce and sync"
|
|
|
|
|
|
|
|
"twice before usage {}".format(input_name))
|
|
|
|
|
|
|
|
|
|
|
|
for output_name in op.desc.output_arg_names():
|
|
|
|
for output_name in op.desc.output_arg_names():
|
|
|
|
if output_name in var_status and \
|
|
|
|
if output_name in vars_status and \
|
|
|
|
var_status[output_name] == -1:
|
|
|
|
vars_status[output_name] == -1:
|
|
|
|
var_status[output_name] = 0
|
|
|
|
vars_status[output_name] = 0
|
|
|
|
|
|
|
|
if output_name in dp_grads_status and \
|
|
|
|
|
|
|
|
dp_grads_status[output_name] == -1:
|
|
|
|
|
|
|
|
dp_grads_status[output_name] = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# check sharding with amp
|
|
|
|
|
|
|
|
if idx_amp_allreduce != -1:
|
|
|
|
|
|
|
|
assert idx_amp_allreduce > idx_last_grad_allreduce
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# check sharding with gradient_clip_by_global_norm
|
|
|
|
|
|
|
|
if idx_gradient_clip_allreduce != -1:
|
|
|
|
|
|
|
|
assert idx_gradient_clip_allreduce > idx_last_grad_allreduce
|
|
|
|
|
|
|
|
|
|
|
|
return
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -155,20 +240,34 @@ def insert_sync_calc_op(block, insert_idx, calc_dep_vars):
|
|
|
|
return
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def insert_sync_comm_ops(block, insert_idx, nrings, comm_dep_vars):
|
|
|
|
def insert_sync_comm_op(block, insert_idx, ring_id, comm_dep_vars):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
_insert_sync_comm_ops
|
|
|
|
insert sync_comm_op for single var
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
op_role = get_valid_op_role(block, insert_idx)
|
|
|
|
op_role = get_valid_op_role(block, insert_idx)
|
|
|
|
for i in range(nrings):
|
|
|
|
|
|
|
|
block._insert_op_without_sync(
|
|
|
|
block._insert_op_without_sync(
|
|
|
|
insert_idx,
|
|
|
|
insert_idx,
|
|
|
|
type='c_sync_comm_stream',
|
|
|
|
type='c_sync_comm_stream',
|
|
|
|
inputs={'X': comm_dep_vars},
|
|
|
|
inputs={'X': comm_dep_vars},
|
|
|
|
outputs={'Out': comm_dep_vars},
|
|
|
|
outputs={'Out': comm_dep_vars},
|
|
|
|
attrs={'ring_id': i,
|
|
|
|
attrs={'ring_id': ring_id,
|
|
|
|
OP_ROLE_KEY: op_role})
|
|
|
|
OP_ROLE_KEY: op_role})
|
|
|
|
return nrings
|
|
|
|
return 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def insert_sync_comm_ops(block, insert_idx, ring_id, comm_dep_vars):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
insert sync_comm_op for vars
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
op_role = get_valid_op_role(block, insert_idx)
|
|
|
|
|
|
|
|
block._insert_op_without_sync(
|
|
|
|
|
|
|
|
insert_idx,
|
|
|
|
|
|
|
|
type='c_sync_comm_stream',
|
|
|
|
|
|
|
|
inputs={'X': comm_dep_vars},
|
|
|
|
|
|
|
|
outputs={'Out': comm_dep_vars},
|
|
|
|
|
|
|
|
attrs={'ring_id': int(ring_id),
|
|
|
|
|
|
|
|
OP_ROLE_KEY: op_role})
|
|
|
|
|
|
|
|
return 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def insert_fill_constant_ops(block, insert_idx, fill_constant_vars):
|
|
|
|
def insert_fill_constant_ops(block, insert_idx, fill_constant_vars):
|
|
|
@ -210,13 +309,11 @@ def insert_cast_ops(block, insert_idx, cast_ops):
|
|
|
|
return
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def insert_allreduce_ops(block, insert_idx, nrings, allreduce_vars):
|
|
|
|
def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
_add_allreduce_ops
|
|
|
|
_add_allreduce_ops
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
ring_id = -1
|
|
|
|
|
|
|
|
for var in allreduce_vars:
|
|
|
|
for var in allreduce_vars:
|
|
|
|
ring_id = (ring_id + 1) % nrings
|
|
|
|
|
|
|
|
block._insert_op_without_sync(
|
|
|
|
block._insert_op_without_sync(
|
|
|
|
insert_idx,
|
|
|
|
insert_idx,
|
|
|
|
type='c_allreduce_sum',
|
|
|
|
type='c_allreduce_sum',
|
|
|
@ -224,17 +321,16 @@ def insert_allreduce_ops(block, insert_idx, nrings, allreduce_vars):
|
|
|
|
outputs={'Out': var},
|
|
|
|
outputs={'Out': var},
|
|
|
|
attrs={'ring_id': ring_id,
|
|
|
|
attrs={'ring_id': ring_id,
|
|
|
|
OP_ROLE_KEY: OpRole.Backward})
|
|
|
|
OP_ROLE_KEY: OpRole.Backward})
|
|
|
|
|
|
|
|
|
|
|
|
return
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def insert_broadcast_ops(block, insert_idx, nrings, broadcast2root):
|
|
|
|
def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
_add_broadcast_ops
|
|
|
|
_add_broadcast_ops
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
ring_id = -1
|
|
|
|
|
|
|
|
op_role = get_valid_op_role(block, insert_idx)
|
|
|
|
op_role = get_valid_op_role(block, insert_idx)
|
|
|
|
for broadcast_name, root_device in broadcast2root:
|
|
|
|
for broadcast_name, root_device in broadcast2root:
|
|
|
|
ring_id = (ring_id + 1) % nrings
|
|
|
|
|
|
|
|
block._insert_op_without_sync(
|
|
|
|
block._insert_op_without_sync(
|
|
|
|
insert_idx,
|
|
|
|
insert_idx,
|
|
|
|
type='c_broadcast',
|
|
|
|
type='c_broadcast',
|
|
|
@ -245,6 +341,7 @@ def insert_broadcast_ops(block, insert_idx, nrings, broadcast2root):
|
|
|
|
'root': root_device,
|
|
|
|
'root': root_device,
|
|
|
|
OP_ROLE_KEY: op_role
|
|
|
|
OP_ROLE_KEY: op_role
|
|
|
|
})
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
return
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|