|
|
|
@ -74,7 +74,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
|
|
|
|
|
continue
|
|
|
|
|
for in_var_name in op.input(in_name):
|
|
|
|
|
in_var = block.var(in_var_name)
|
|
|
|
|
if in_var.type not in valid_types:
|
|
|
|
|
if in_var.type not in valid_types or in_var.dtype == dest_dtype:
|
|
|
|
|
continue
|
|
|
|
|
if in_var.dtype == src_dtype:
|
|
|
|
|
cast_name = in_var.name + '.cast_' + _dtype_to_str(dest_dtype)
|
|
|
|
@ -84,7 +84,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
|
|
|
|
|
name=cast_name,
|
|
|
|
|
dtype=dest_dtype,
|
|
|
|
|
persistable=False,
|
|
|
|
|
stop_gradient=False)
|
|
|
|
|
stop_gradient=in_var.stop_gradient)
|
|
|
|
|
|
|
|
|
|
block._insert_op(
|
|
|
|
|
idx,
|
|
|
|
@ -100,7 +100,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
|
|
|
|
|
else:
|
|
|
|
|
if op.has_attr('in_dtype'):
|
|
|
|
|
op._set_attr('in_dtype', dest_dtype)
|
|
|
|
|
if src_dtype == core.VarDesc.VarType.FP32:
|
|
|
|
|
if src_dtype == core.VarDesc.VarType.FP32 and dest_dtype == core.VarDesc.VarType.FP16:
|
|
|
|
|
for out_name in op.output_names:
|
|
|
|
|
if op.type == 'batch_norm' and out_name != 'Y':
|
|
|
|
|
continue
|
|
|
|
|