|
|
|
@ -85,6 +85,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
|
|
|
|
|
core.VarDesc.VarType.LOD_TENSOR, core.VarDesc.VarType.SELECTED_ROWS,
|
|
|
|
|
core.VarDesc.VarType.LOD_TENSOR_ARRAY
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
for in_name in op.input_names:
|
|
|
|
|
if src_dtype == core.VarDesc.VarType.FP32 and op.type == 'batch_norm':
|
|
|
|
|
if in_name != 'X':
|
|
|
|
@ -94,22 +95,25 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
|
|
|
|
|
if in_var.type not in valid_types:
|
|
|
|
|
continue
|
|
|
|
|
if in_var.dtype == src_dtype:
|
|
|
|
|
out_var = block.create_var(
|
|
|
|
|
name=in_var.name + \
|
|
|
|
|
'.cast_' + _dtype_to_str(dest_dtype),
|
|
|
|
|
dtype=dest_dtype,
|
|
|
|
|
persistable=False,
|
|
|
|
|
stop_gradient=False)
|
|
|
|
|
block._insert_op(
|
|
|
|
|
idx,
|
|
|
|
|
type="cast",
|
|
|
|
|
inputs={"X": in_var},
|
|
|
|
|
outputs={"Out": out_var},
|
|
|
|
|
attrs={
|
|
|
|
|
"in_dtype": in_var.dtype,
|
|
|
|
|
"out_dtype": out_var.dtype
|
|
|
|
|
})
|
|
|
|
|
num_cast_ops += 1
|
|
|
|
|
cast_name = in_var.name + '.cast_' + _dtype_to_str(dest_dtype)
|
|
|
|
|
out_var = block.vars.get(cast_name)
|
|
|
|
|
if out_var is None or out_var.dtype != dest_dtype:
|
|
|
|
|
out_var = block.create_var(
|
|
|
|
|
name=cast_name,
|
|
|
|
|
dtype=dest_dtype,
|
|
|
|
|
persistable=False,
|
|
|
|
|
stop_gradient=False)
|
|
|
|
|
|
|
|
|
|
block._insert_op(
|
|
|
|
|
idx,
|
|
|
|
|
type="cast",
|
|
|
|
|
inputs={"X": in_var},
|
|
|
|
|
outputs={"Out": out_var},
|
|
|
|
|
attrs={
|
|
|
|
|
"in_dtype": in_var.dtype,
|
|
|
|
|
"out_dtype": out_var.dtype
|
|
|
|
|
})
|
|
|
|
|
num_cast_ops += 1
|
|
|
|
|
_rename_arg(op, in_var.name, out_var.name)
|
|
|
|
|
else:
|
|
|
|
|
if op.has_attr('in_dtype'):
|
|
|
|
@ -155,6 +159,18 @@ def find_true_prev_op(ops, cur_op, var_name):
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_in_black_varnames(op, amp_lists):
|
|
|
|
|
for in_name in op.input_arg_names:
|
|
|
|
|
if in_name in amp_lists.black_varnames:
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
for out_name in op.output_arg_names:
|
|
|
|
|
if out_name in amp_lists.black_varnames:
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def rewrite_program(main_prog, amp_lists):
|
|
|
|
|
"""
|
|
|
|
|
Traverse all ops in current block and insert cast op according to
|
|
|
|
@ -180,6 +196,11 @@ def rewrite_program(main_prog, amp_lists):
|
|
|
|
|
white_op_set = set()
|
|
|
|
|
black_op_set = set()
|
|
|
|
|
for op in ops:
|
|
|
|
|
if amp_lists.black_varnames is not None and _is_in_black_varnames(
|
|
|
|
|
op, amp_lists):
|
|
|
|
|
black_op_set.add(op)
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
if op.type in amp_lists.black_list:
|
|
|
|
|
black_op_set.add(op)
|
|
|
|
|
elif op.type in amp_lists.white_list:
|
|
|
|
|