@@ -36,92 +36,6 @@ def append_cast_op(i, o, prog):
                "out_dtype": o.dtype})
 
 
-def copy_to_master_param(p, block):
-    """
-    New a master parameter for the input parameter, and they two share the same
-    attributes except the data type.
-
-    Args:
-        p(Parameter): The input parameter in float16.
-        block(Program): The block in which the parameter is.
-    """
-    v = block.vars.get(p.name, None)
-    if v is None:
-        raise ValueError("no param name %s found!" % p.name)
-    new_p = framework.Parameter(
-        block=block,
-        shape=v.shape,
-        dtype=core.VarDesc.VarType.FP32,
-        type=v.type,
-        lod_level=v.lod_level,
-        stop_gradient=p.stop_gradient,
-        trainable=p.trainable,
-        optimize_attr=p.optimize_attr,
-        regularizer=p.regularizer,
-        gradient_clip_attr=p.gradient_clip_attr,
-        error_clip=p.error_clip,
-        name=v.name + ".master")
-    return new_p
-
-
-def create_master_params_grads(params_grads, main_prog, startup_prog,
-                               loss_scaling):
-    """
-    Create master parameters and gradients in float32 from params and grads
-    in float16.
-
-    Args:
-        params_grads (list): A list of tuple (parameter, gradient) in float32.
-        main_prog (Program): The main program for training.
-        startup_prog (Program): The startup program to initialize all parameters.
-        loss_scaling (float): The factor to scale loss and gradients.
-
-    Returns:
-        A list of master parameters and gradients.
-    """
-    master_params_grads = []
-    for p, g in params_grads:
-        # create master parameters
-        with main_prog._optimized_guard([p, g]):
-            # create master parameters
-            master_param = copy_to_master_param(p, main_prog.global_block())
-            startup_master_param = startup_prog.global_block()._clone_variable(
-                master_param)
-            startup_p = startup_prog.global_block().var(p.name)
-            # fp16 -> fp32
-            append_cast_op(startup_p, startup_master_param, startup_prog)
-            # cast fp16 gradients to fp32 before apply gradients
-            if g.name.find("batch_norm") > -1:
-                scaled_g = g / loss_scaling
-                master_params_grads.append([p, scaled_g])
-                continue
-            master_grad = layers.cast(x=g, dtype="float32")
-            master_grad = master_grad / loss_scaling
-            master_params_grads.append([master_param, master_grad])
-
-    return master_params_grads
-
-
-def master_param_to_train_param(master_params_grads, params_grads, main_prog):
-    """
-    Convert master master parameters and gradients in float32 to parameters and
-    gradients in float16 for forward computation.
-
-    Args:
-        master_params_grads (list): A list of master parameters and gradients in
-            float32.
-        params_grads (list): A list of parameters and gradients in float16.
-        main_prog (list): The main program for execution.
-    """
-    for idx, m_p_g in enumerate(master_params_grads):
-        train_p, _ = params_grads[idx]
-        if train_p.name.find("batch_norm") > -1:
-            continue
-        with main_prog._optimized_guard([m_p_g[0], m_p_g[1]]):
-            # fp32 -> fp16
-            append_cast_op(m_p_g[0], train_p, main_prog)
-
-
 def _rename_arg(op, old_name, new_name):
     """
     If an op has old_name input and output, rename these input
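
The three helpers deleted above implement the classic master-weights recipe for float16
training: keep a float32 copy of every float16 parameter, cast and unscale the float16
gradients, update the float32 copies, then cast them back for the next forward pass. A
minimal sketch of how they were typically wired together, assuming an fp16 network whose
loss has already been multiplied by the scaling factor (`scaled_loss`, `optimizer`,
`main_prog`, `startup_prog` and `loss_scaling` are placeholder names, not part of this
file):

    # Hypothetical wiring of the deleted helpers.
    params_grads = optimizer.backward(scaled_loss)
    master_params_grads = create_master_params_grads(
        params_grads, main_prog, startup_prog, loss_scaling)
    optimizer.apply_gradients(master_params_grads)  # update the fp32 master weights
    master_param_to_train_param(master_params_grads, params_grads, main_prog)  # fp32 -> fp16
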
@@ -172,6 +86,9 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
         core.VarDesc.VarType.LOD_TENSOR_ARRAY
     ]
     for in_name in op.input_names:
+        if src_dtype == core.VarDesc.VarType.FP32 and op.type == 'batch_norm':
+            if in_name != 'X':
+                continue
         for in_var_name in op.input(in_name):
            in_var = block.var(in_var_name)
            if in_var.type not in valid_types:
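
The three added lines keep every batch_norm input except 'X' in float32 when a program
is being cast from fp32 to fp16, so the layer's Scale, Bias, Mean and Variance inputs
are not down-cast. A minimal sketch of the same rule as a standalone predicate (the
helper name is illustrative, not part of this file):

    def _keep_fp32_input(op, in_name):
        # Only batch_norm's data input 'X' is down-cast; Scale, Bias, Mean and
        # Variance stay in float32 for numerical stability.
        return op.type == 'batch_norm' and in_name != 'X'
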
@@ -197,16 +114,18 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
             else:
                 if op.has_attr('in_dtype'):
                     op._set_attr('in_dtype', dest_dtype)
-    if src_dtype == core.VarDesc.VarType.FP16:
+    if src_dtype == core.VarDesc.VarType.FP32:
         for out_name in op.output_names:
+            if op.type == 'batch_norm' and out_name != 'Y':
+                continue
             for out_var_name in op.output(out_name):
                 out_var = block.var(out_var_name)
                 if out_var.type not in valid_types:
                     continue
-                if out_var.dtype == core.VarDesc.VarType.FP16:
-                    out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
+                if out_var.dtype == core.VarDesc.VarType.FP32:
+                    out_var.desc.set_dtype(core.VarDesc.VarType.FP16)
                     if op.has_attr('out_dtype'):
-                        op._set_attr('out_dtype', core.VarDesc.VarType.FP32)
+                        op._set_attr('out_dtype', core.VarDesc.VarType.FP16)
     return num_cast_ops
 
 
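This hunk flips the output-side handling: instead of promoting fp16 outputs back to
fp32 (the old FP16 branch), fp32 outputs and the op's out_dtype attribute are demoted
to fp16 when an fp32 program is cast down, with every batch_norm output other than 'Y'
left untouched. A rough sketch of how a cast pass might drive _insert_cast_op over a
block, assuming `block` is a fluid Block and that the function inserts its cast ops in
front of `op` (the driver loop itself is illustrative, not part of this file):

    idx = 0
    while idx < len(block.ops):
        op = block.ops[idx]
        num_cast_ops = _insert_cast_op(block, op, idx,
                                       core.VarDesc.VarType.FP32,
                                       core.VarDesc.VarType.FP16)
        # Skip past the newly inserted cast ops plus the op just processed.
        idx += num_cast_ops + 1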