|
|
|
@ -1156,14 +1156,13 @@ class OutScaleForTrainingPass(object):
|
|
|
|
|
assert isinstance(graph,
|
|
|
|
|
IrGraph), 'graph must be the instance of IrGraph.'
|
|
|
|
|
self._is_test = graph.is_test()
|
|
|
|
|
ops = graph.all_op_nodes()
|
|
|
|
|
for op_node in ops:
|
|
|
|
|
name = op_node.name()
|
|
|
|
|
if name in self._teller_set:
|
|
|
|
|
if len(op_node.output_arg_names()) != 1:
|
|
|
|
|
continue
|
|
|
|
|
in_node = graph._find_node_by_name(
|
|
|
|
|
op_node.outputs, op_node.output_arg_names()[0])
|
|
|
|
|
target_ops = []
|
|
|
|
|
for op in graph.all_op_nodes():
|
|
|
|
|
if op.name() in self._teller_set:
|
|
|
|
|
target_ops.append(op)
|
|
|
|
|
for op in target_ops:
|
|
|
|
|
for output_var_name in _get_op_output_var_names(op):
|
|
|
|
|
in_node = graph._find_node_by_name(op.outputs, output_var_name)
|
|
|
|
|
out_node = graph.create_var_node_from_desc(in_node.var())
|
|
|
|
|
scale_node = graph.create_persistable_node(
|
|
|
|
|
name=self._scale_name(in_node.name()),
|
|
|
|
@ -1263,13 +1262,13 @@ class OutScaleForInferencePass(object):
|
|
|
|
|
"""
|
|
|
|
|
assert isinstance(graph,
|
|
|
|
|
IrGraph), 'graph must be the instance of IrGraph.'
|
|
|
|
|
ops = graph.all_op_nodes()
|
|
|
|
|
for op_node in ops:
|
|
|
|
|
name = op_node.name()
|
|
|
|
|
if name in self._teller_set:
|
|
|
|
|
if len(op_node.output_arg_names()) != 1:
|
|
|
|
|
continue
|
|
|
|
|
scale_name = self._scale_name(op_node.output_arg_names()[0])
|
|
|
|
|
op_nodes = graph.all_op_nodes()
|
|
|
|
|
for op_node in op_nodes:
|
|
|
|
|
if op_node.name() in self._teller_set:
|
|
|
|
|
output_var_name = _get_op_output_var_names(op_node)
|
|
|
|
|
assert len(output_var_name) == 1, "Only support collecting " \
|
|
|
|
|
"output for op that only has an activation output for now."
|
|
|
|
|
scale_name = self._scale_name(output_var_name[0])
|
|
|
|
|
scale_v = np.array(
|
|
|
|
|
self._scope.find_var(scale_name).get_tensor())[0]
|
|
|
|
|
op_node.op()._set_attr("out_threshold", float(scale_v))
|
|
|
|
|