|
|
|
@ -32,7 +32,21 @@ bool CheckValueNodeInputOfMul(const AnfNodePtr &node) {
|
|
|
|
|
std::vector<size_t> mul_input_shape = AnfAlgo::GetOutputInferShape(node, 0);
|
|
|
|
|
return mul_input_shape.empty() || (mul_input_shape.size() == 1 && mul_input_shape[0] == 1);
|
|
|
|
|
}
|
|
|
|
|
void AddInputToOutput(const FuncGraphPtr &func_graph, const CNodePtr &old_cnode, const AnfNodePtr &new_node,
|
|
|
|
|
std::vector<AnfNodePtr> *new_outputs) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(old_cnode);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(new_node);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(new_outputs);
|
|
|
|
|
auto node_to_output = old_cnode->input(kAccumIndex + 1);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node_to_output);
|
|
|
|
|
AbstractBasePtrList abstract_list{old_cnode->abstract(), node_to_output->abstract()};
|
|
|
|
|
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
|
|
|
|
new_node->set_abstract(abstract_tuple);
|
|
|
|
|
// Create Output
|
|
|
|
|
CreateMultipleOutputsOfAnfNode(func_graph, new_node, kFusedMulApplyMomentumOutputNum, new_outputs);
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
const BaseRef MomentumLossscaleFusion::DefinePattern() const {
|
|
|
|
|
VarPtr Xs = std::make_shared<SeqVar>();
|
|
|
|
|
VarPtr X0 = std::make_shared<Var>();
|
|
|
|
@ -80,15 +94,10 @@ const AnfNodePtr MomentumLossscaleFusion::Process(const FuncGraphPtr &func_graph
|
|
|
|
|
input_names_value[3] = "x1";
|
|
|
|
|
input_names_value.emplace_back("x2");
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names_value), new_node);
|
|
|
|
|
auto node_to_output = cnode->input(kAccumIndex + 1);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node_to_output);
|
|
|
|
|
AbstractBasePtrList abstract_list{node->abstract(), node_to_output->abstract()};
|
|
|
|
|
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
|
|
|
|
new_node->set_abstract(abstract_tuple);
|
|
|
|
|
new_node->set_scope(node->scope());
|
|
|
|
|
// Create Output
|
|
|
|
|
// Create Outputs
|
|
|
|
|
std::vector<AnfNodePtr> new_outputs;
|
|
|
|
|
CreateMultipleOutputsOfAnfNode(func_graph, new_node, kFusedMulApplyMomentumOutputNum, &new_outputs);
|
|
|
|
|
AddInputToOutput(func_graph, cnode, new_node, &new_outputs);
|
|
|
|
|
if (new_outputs.size() != kFusedMulApplyMomentumOutputNum) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failed to create outputs of " << new_node->DebugString();
|
|
|
|
|
}
|
|
|
|
|