|
|
@ -23,6 +23,7 @@
|
|
|
|
namespace mindspore {
|
|
|
|
namespace mindspore {
|
|
|
|
namespace opt {
|
|
|
|
namespace opt {
|
|
|
|
namespace {
|
|
|
|
namespace {
|
|
|
|
|
|
|
|
constexpr size_t kAccumIndex = 1;
|
|
|
|
bool CheckValueNodeInputOfMul(const AnfNodePtr &node) {
|
|
|
|
bool CheckValueNodeInputOfMul(const AnfNodePtr &node) {
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
if (!node->isa<ValueNode>()) {
|
|
|
|
if (!node->isa<ValueNode>()) {
|
|
|
@ -79,9 +80,19 @@ const AnfNodePtr MomentumLossscaleFusion::Process(const FuncGraphPtr &func_graph
|
|
|
|
input_names_value[3] = "x1";
|
|
|
|
input_names_value[3] = "x1";
|
|
|
|
input_names_value.emplace_back("x2");
|
|
|
|
input_names_value.emplace_back("x2");
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names_value), new_node);
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names_value), new_node);
|
|
|
|
new_node->set_abstract(node->abstract());
|
|
|
|
auto node_to_output = cnode->input(kAccumIndex + 1);
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node_to_output);
|
|
|
|
|
|
|
|
AbstractBasePtrList abstract_list{node->abstract(), node_to_output->abstract()};
|
|
|
|
|
|
|
|
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
|
|
|
|
|
|
|
new_node->set_abstract(abstract_tuple);
|
|
|
|
new_node->set_scope(node->scope());
|
|
|
|
new_node->set_scope(node->scope());
|
|
|
|
return new_node;
|
|
|
|
// Create Output
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> new_outputs;
|
|
|
|
|
|
|
|
CreateMultipleOutputsOfAnfNode(func_graph, new_node, kFusedMulApplyMomentumOutputNum, &new_outputs);
|
|
|
|
|
|
|
|
if (new_outputs.size() != kFusedMulApplyMomentumOutputNum) {
|
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failed to create outputs of " << new_node->DebugString();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return new_outputs[0];
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} // namespace opt
|
|
|
|
} // namespace opt
|
|
|
|
} // namespace mindspore
|
|
|
|
} // namespace mindspore
|
|
|
|