fix some cast error. (#26884)

revert-26856-strategy_example2
Zhen Wang 4 years ago committed by GitHub
parent 6a09b8f1cb
commit bcdbac1753
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -74,7 +74,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
continue
for in_var_name in op.input(in_name):
in_var = block.var(in_var_name)
if in_var.type not in valid_types:
if in_var.type not in valid_types or in_var.dtype == dest_dtype:
continue
if in_var.dtype == src_dtype:
cast_name = in_var.name + '.cast_' + _dtype_to_str(dest_dtype)
@ -84,7 +84,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
name=cast_name,
dtype=dest_dtype,
persistable=False,
stop_gradient=False)
stop_gradient=in_var.stop_gradient)
block._insert_op(
idx,
@ -100,7 +100,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
else:
if op.has_attr('in_dtype'):
op._set_attr('in_dtype', dest_dtype)
if src_dtype == core.VarDesc.VarType.FP32:
if src_dtype == core.VarDesc.VarType.FP32 and dest_dtype == core.VarDesc.VarType.FP16:
for out_name in op.output_names:
if op.type == 'batch_norm' and out_name != 'Y':
continue

Loading…
Cancel
Save