|
|
@ -156,26 +156,8 @@ def append_input_output(block, op_proto, np_list, is_input, dtype):
|
|
|
|
return var_dict
|
|
|
|
return var_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def var_cast(block, input):
|
|
|
|
|
|
|
|
if input.dtype == core.VarDesc.VarType.FP32 or input.dtype == core.VarDesc.VarType.FP32:
|
|
|
|
|
|
|
|
return input
|
|
|
|
|
|
|
|
out = block.create_var(dtype="float32", shape=[1])
|
|
|
|
|
|
|
|
op = block.append_op(
|
|
|
|
|
|
|
|
inputs={"X": input},
|
|
|
|
|
|
|
|
outputs={"Out": out},
|
|
|
|
|
|
|
|
type='cast',
|
|
|
|
|
|
|
|
attrs={
|
|
|
|
|
|
|
|
'out_dtype': core.VarDesc.VarType.FP32,
|
|
|
|
|
|
|
|
'in_dtype': input.dtype
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
op.desc.infer_var_type(block.desc)
|
|
|
|
|
|
|
|
op.desc.infer_shape(block.desc)
|
|
|
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def append_loss_ops(block, output_names):
|
|
|
|
def append_loss_ops(block, output_names):
|
|
|
|
mean_inputs = list(map(block.var, output_names))
|
|
|
|
mean_inputs = list(map(block.var, output_names))
|
|
|
|
mean_inputs = [var_cast(block, x) for x in mean_inputs]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if len(mean_inputs) == 1:
|
|
|
|
if len(mean_inputs) == 1:
|
|
|
|
loss = block.create_var(dtype=mean_inputs[0].dtype, shape=[1])
|
|
|
|
loss = block.create_var(dtype=mean_inputs[0].dtype, shape=[1])
|
|
|
|