# Paddle/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from functools import reduce

from paddle.fluid import core
from paddle.distributed.fleet.meta_optimizers.common import (
    OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY, is_loss_grad_op)

def check_broadcast(block):
    """
    If a var is broadcast, there should be a sync_comm op between the
    broadcast op and the first op that uses the var; otherwise raise an
    error. If the broadcast var is produced by a fill_constant op, that
    fill_constant op must come before the broadcast op and before a
    sync_calc op; otherwise raise an error.
    """
    broadcast_vars = {}
    for idx, op in enumerate(block.ops):
        if op.type == "c_broadcast":
            var_name = op.desc.input_arg_names()[0]
            if "@BroadCast" in var_name:
                if var_name in broadcast_vars:
                    raise ValueError("var_name already exists: {}, "
                                     "the old pos is {}, the new pos is {}".
                                     format(var_name, broadcast_vars[var_name][
                                         "broadcast_pos"], idx))
                broadcast_vars[var_name] = {
                    "fill_constant_pos": -1,
                    "broadcast_pos": idx,
                }

    for idx, op in enumerate(block.ops):
        if op.type == "fill_constant":
            var_name = op.desc.output_arg_names()[0]
            if var_name in broadcast_vars:
                broadcast_vars[var_name]["fill_constant_pos"] = idx
            continue

    last_sync_comm_op_idx = -1
    last_sync_calc_op_idx = -1
    for idx, op in enumerate(block.ops):
        if op.type == "c_sync_comm_stream":
            last_sync_comm_op_idx = idx
            continue
        if op.type == "c_sync_calc_stream":
            last_sync_calc_op_idx = idx
            continue
        if op.type == "c_broadcast":
            var_name = op.desc.input_arg_names()[0]
            if "@BroadCast" in var_name:
                if broadcast_vars[var_name]["fill_constant_pos"] != -1:
                    assert (last_sync_calc_op_idx != -1)
                    assert (broadcast_vars[var_name]["fill_constant_pos"] <
                            last_sync_calc_op_idx)
                    assert (last_sync_calc_op_idx < idx)
                continue
        for input_name in op.desc.input_arg_names():
            if input_name in broadcast_vars:
                assert (broadcast_vars[input_name]["broadcast_pos"] != -1)
                assert (broadcast_vars[input_name]["broadcast_pos"] <
                        last_sync_comm_op_idx)
                assert (last_sync_comm_op_idx < idx)
    return
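
# Illustrative sketch of the op order that check_broadcast enforces for a
# var named like "w@BroadCast" (the op/var names below are hypothetical):
#
#     fill_constant(Out=w@BroadCast)   # optional local zero-init
#     c_sync_calc_stream               # calc stream flushed before comm
#     c_broadcast(X=w@BroadCast)       # receive value from the root device
#     c_sync_comm_stream               # comm stream flushed before use
#     some_op(X=w@BroadCast)           # first consumer of the var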


def check_allreduce_sum(block):
    """
    If a Var is all-reduced, the op order should be:
        - 0: op that generates the Var
        - 1: sync_calc
        - 2: c_allreduce_sum op
        - 3: sync_comm
        - 4: op that uses the Var
    """
    var_status = {}
    for op in block.ops:
        if op.type == "c_allreduce_sum":
            var_name = op.desc.input_arg_names()[0]
            var_status[var_name] = -1

    for op in block.ops:
        if op.type == "c_sync_calc_stream":
            for var_name in var_status:
                if var_status[var_name] == 0:
                    var_status[var_name] = 1
        elif op.type == "c_allreduce_sum":
            var_name = op.desc.input_arg_names()[0]
            if var_status[var_name] == -1:
                raise ValueError("{} is not generated, but you are "
                                 "trying to all-reduce it".format(var_name))
            if var_status[var_name] == 0:
                raise ValueError("There should be a sync_calc op "
                                 "after generating Var: {} and before the "
                                 "c_allreduce_sum op".format(var_name))
            assert (var_status[var_name] == 1)
            var_status[var_name] = 2
        elif op.type == "c_sync_comm_stream":
            for var_name in op.desc.input_arg_names():
                if var_name in var_status and var_status[var_name] == 2:
                    var_status[var_name] = 3
        else:
            for input_name in op.desc.input_arg_names():
                if input_name in var_status:
                    if var_status[input_name] != 3:
                        raise ValueError("There should be a sync_comm op "
                                         "after all-reducing the Var: {}".
                                         format(input_name))
            for output_name in op.desc.output_arg_names():
                if output_name in var_status and \
                        var_status[output_name] == -1:
                    var_status[output_name] = 0
    return
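
# The checker above is a small state machine per all-reduced var:
# -1 (registered) -> 0 (generated) -> 1 (sync_calc seen) -> 2 (all-reduced)
# -> 3 (sync_comm seen, safe to consume). An illustrative conforming slice
# (names hypothetical):
#
#     grad_op(Out=x@GRAD)         # -1 -> 0
#     c_sync_calc_stream          #  0 -> 1
#     c_allreduce_sum(X=x@GRAD)   #  1 -> 2
#     c_sync_comm_stream          #  2 -> 3
#     optimizer_op(X=x@GRAD)      # legal only once status == 3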


def insert_sync_calc_op(block, insert_idx, calc_dep_vars):
    """
    Insert a c_sync_calc_stream op at insert_idx that depends on
    calc_dep_vars.
    """
    op_role = block.ops[insert_idx].attr('op_role')
    block._insert_op_without_sync(
        insert_idx,
        type='c_sync_calc_stream',
        inputs={'X': calc_dep_vars},
        outputs={'Out': calc_dep_vars},
        attrs={OP_ROLE_KEY: op_role})
    return


def insert_sync_comm_ops(block, insert_idx, nrings, comm_dep_vars):
    """
    Insert one c_sync_comm_stream op per communication ring at insert_idx,
    each depending on comm_dep_vars.
    """
    op_role = block.ops[insert_idx].attr('op_role')
    for i in range(nrings):
        block._insert_op_without_sync(
            insert_idx,
            type='c_sync_comm_stream',
            inputs={'X': comm_dep_vars},
            outputs={'Out': comm_dep_vars},
            attrs={'ring_id': i,
                   OP_ROLE_KEY: op_role})
    return nrings
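
# Hedged usage sketch: with nrings == 2, the call below would insert two
# c_sync_comm_stream ops (ring_id 0 and ring_id 1) in front of
# block.ops[idx], both waiting on the same dependency vars ("idx" and the
# var names are assumptions for illustration):
#
#     insert_sync_comm_ops(block, idx, 2, ['x@GRAD', 'y@GRAD'])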


def insert_fill_constant_ops(block, insert_idx, fill_constant_vars):
    """
    Insert fill_constant ops that zero-initialize the given broadcast vars.
    """
    op_role = block.ops[insert_idx].attr('op_role')
    for broadcast_name in fill_constant_vars:
        broadcast_var = block.var(broadcast_name)
        block._insert_op_without_sync(
            insert_idx,
            type="fill_constant",
            outputs={"Out": broadcast_var.name},
            attrs={
                "shape": broadcast_var.shape,
                "dtype": broadcast_var.dtype,
                "value": 0.0,
                OP_ROLE_KEY: op_role
            })
    return


def insert_cast_ops(block, insert_idx, cast_ops):
    """
    Insert cast ops that cast each fp32 var to its fp16 counterpart.
    cast_ops maps fp16 var names to their source fp32 var names.
    """
    op_role = block.ops[insert_idx].attr('op_role')
    for fp16_name, fp32_name in cast_ops.items():
        block._insert_op_without_sync(
            insert_idx,
            type="cast",
            inputs={"X": fp32_name},
            outputs={"Out": fp16_name},
            attrs={
                "in_dtype": core.VarDesc.VarType.FP32,
                "out_dtype": core.VarDesc.VarType.FP16,
                OP_ROLE_KEY: op_role
            })
    return


def insert_allreduce_ops(block, insert_idx, nrings, allreduce_vars):
    """
    Insert c_allreduce_sum ops for the given vars, assigning vars to
    communication rings round-robin.
    """
    ring_id = -1
    for var in allreduce_vars:
        ring_id = (ring_id + 1) % nrings
        block._insert_op_without_sync(
            insert_idx,
            type='c_allreduce_sum',
            inputs={'X': var},
            outputs={'Out': var},
            attrs={'ring_id': ring_id,
                   OP_ROLE_KEY: OpRole.Backward})
    return
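
# Note the round-robin ring assignment above: with nrings == 2 and
# allreduce_vars == [g0, g1, g2, g3], the inserted c_allreduce_sum ops get
# ring_id 0, 1, 0, 1, spreading communication across rings. Because every
# op is inserted at the same insert_idx, the final program order is the
# reverse of the iteration order.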


def insert_broadcast_ops(block, insert_idx, nrings, broadcast2root):
    """
    Insert c_broadcast ops for (var, root_device) pairs, assigning vars to
    communication rings round-robin.
    """
    ring_id = -1
    op_role = block.ops[insert_idx].attr('op_role')
    for broadcast_name, root_device in broadcast2root:
        ring_id = (ring_id + 1) % nrings
        block._insert_op_without_sync(
            insert_idx,
            type='c_broadcast',
            inputs={'X': broadcast_name},
            outputs={'Out': broadcast_name},
            attrs={
                'ring_id': ring_id,
                'root': root_device,
                OP_ROLE_KEY: op_role
            })
    return


DtypeToSize = {
    core.VarDesc.VarType.FP16: 2,
    core.VarDesc.VarType.FP32: 4,
    core.VarDesc.VarType.FP64: 8,
    core.VarDesc.VarType.INT16: 2,
    core.VarDesc.VarType.INT32: 4,
    core.VarDesc.VarType.INT64: 8,
    core.VarDesc.VarType.BOOL: 1,
    core.VarDesc.VarType.UINT8: 1,
}


def get_var_size(param):
    """
    input:
        - param: var
    return:
        var size in MB
    """
    assert -1 not in param.shape
    return reduce(lambda x, y: x * y,
                  param.shape) * DtypeToSize[param.dtype] / 1024.0 / 1024.0
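
# Worked example: an FP32 param of shape [1024, 1024] yields
# 1024 * 1024 * 4 / 1024.0 / 1024.0 = 4.0, i.e. 4 MB.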


def insert_scale_loss_grad_ops(block, scale=1.0):
    '''
    In order to keep the learning rate consistent across different numbers
    of training workers, we scale the loss grad by the number of workers.
    '''
    for idx, op in reversed(list(enumerate(block.ops))):
        if is_loss_grad_op(op):
            loss_grad_var = block.vars[op.output_arg_names[0]]
            block._insert_op_without_sync(
                idx + 1,
                type='scale',
                inputs={'X': loss_grad_var},
                outputs={'Out': loss_grad_var},
                attrs={'scale': scale,
                       OP_ROLE_KEY: OpRole.Backward})
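
# Hedged usage sketch: to average summed gradients across workers, a caller
# would typically pass the reciprocal of the worker count ("worker_num" is
# an assumed name for illustration):
#
#     insert_scale_loss_grad_ops(block, scale=1.0 / worker_num)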