|
|
|
|
@ -35,8 +35,10 @@ namespace internal {
|
|
|
|
|
const char PARAMETER_MODULE[] = "mindspore.common.parameter";
|
|
|
|
|
const char PARAMETER_CLASS[] = "Parameter";
|
|
|
|
|
const char SET_PARAM[] = "__setattr__";
|
|
|
|
|
AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph);
|
|
|
|
|
AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const MatchResultPtr &res);
|
|
|
|
|
AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph,
|
|
|
|
|
const FuncGraphPtr &top_graph);
|
|
|
|
|
AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph,
|
|
|
|
|
const MatchResultPtr &res);
|
|
|
|
|
void ReflectParamBackToPython(const AnfNodePtr ¶m, string param_name, tensor::TensorPtr default_input,
|
|
|
|
|
bool requires_grad, bool layerwise_parallel);
|
|
|
|
|
|
|
|
|
|
@ -72,7 +74,8 @@ AnfNodePtr BuildNewTensor(const PatternPtr &pattern, const MatchResultPtr &res)
|
|
|
|
|
return std::make_shared<ValueNode>(input_tensor);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &fg) {
|
|
|
|
|
AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &fg,
|
|
|
|
|
const FuncGraphPtr &top_graph) {
|
|
|
|
|
auto call_pattern = pattern->cast<CallPtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(call_pattern);
|
|
|
|
|
auto prim = call_pattern->prim_value();
|
|
|
|
|
@ -81,20 +84,20 @@ AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultP
|
|
|
|
|
}
|
|
|
|
|
auto prim_pattern = call_pattern->prim_pattern();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prim_pattern);
|
|
|
|
|
return ProcessSinglePattern(prim_pattern, res, fg);
|
|
|
|
|
return ProcessSinglePattern(prim_pattern, res, fg, top_graph);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtr BuildNewParameter(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph) {
|
|
|
|
|
AnfNodePtr BuildNewParameter(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &top_graph) {
|
|
|
|
|
auto new_para_pattern = pattern->cast<NewParameterPtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(new_para_pattern);
|
|
|
|
|
if (!new_para_pattern->built()) {
|
|
|
|
|
static int parameter_id = 0;
|
|
|
|
|
auto para_name = new_para_pattern->para_name() + new_para_pattern->unique_name() + std::to_string(parameter_id++);
|
|
|
|
|
auto para_node = std::make_shared<Parameter>(func_graph);
|
|
|
|
|
auto para_node = std::make_shared<Parameter>(top_graph);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(para_node);
|
|
|
|
|
para_node->set_name(para_name);
|
|
|
|
|
// Set function graph
|
|
|
|
|
para_node->set_func_graph(func_graph);
|
|
|
|
|
para_node->set_func_graph(top_graph);
|
|
|
|
|
// Set Debug Info
|
|
|
|
|
auto debug_info = std::make_shared<NodeDebugInfo>(para_name);
|
|
|
|
|
para_node->set_debug_info(debug_info);
|
|
|
|
|
@ -103,7 +106,7 @@ AnfNodePtr BuildNewParameter(const PatternPtr &pattern, const MatchResultPtr &re
|
|
|
|
|
MS_EXCEPTION_IF_NULL(default_value);
|
|
|
|
|
para_node->set_abstract(default_value->ToAbstract()->Broaden());
|
|
|
|
|
res->add_entry(pattern, para_node);
|
|
|
|
|
func_graph->add_parameter(para_node);
|
|
|
|
|
top_graph->add_parameter(para_node);
|
|
|
|
|
// Reflect back to Cell._params
|
|
|
|
|
internal::ReflectParamBackToPython(para_node, para_name, default_value, new_para_pattern->requires_grad(),
|
|
|
|
|
new_para_pattern->layerwise_parallel());
|
|
|
|
|
@ -126,7 +129,8 @@ AnfNodePtr BuildImmNode(const PatternPtr &pattern, const MatchResultPtr &res) {
|
|
|
|
|
return std::make_shared<ValueNode>(scalar_value_ptr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph) {
|
|
|
|
|
AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph,
|
|
|
|
|
const FuncGraphPtr &top_graph) {
|
|
|
|
|
auto target_node = res->get_node(pattern);
|
|
|
|
|
if (target_node != nullptr) {
|
|
|
|
|
// If pattern is NewParameter, check whether it shouldn't last and is not built
|
|
|
|
|
@ -141,9 +145,10 @@ AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr
|
|
|
|
|
} else if (pattern->isa<NewTensor>()) {
|
|
|
|
|
return BuildNewTensor(pattern, res);
|
|
|
|
|
} else if (pattern->isa<Call>()) {
|
|
|
|
|
return BuildPrimitiveValueNode(pattern, res, func_graph);
|
|
|
|
|
return BuildPrimitiveValueNode(pattern, res, func_graph, top_graph);
|
|
|
|
|
} else if (pattern->isa<NewParameter>()) {
|
|
|
|
|
return BuildNewParameter(pattern, res, func_graph);
|
|
|
|
|
// Add new parameter to top graph instead of current graph
|
|
|
|
|
return BuildNewParameter(pattern, res, top_graph);
|
|
|
|
|
} else if (pattern->isa<Imm>()) {
|
|
|
|
|
return BuildImmNode(pattern, res);
|
|
|
|
|
} else {
|
|
|
|
|
@ -154,17 +159,18 @@ AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtr ProcessComplexPatternFirstInput(const PatternPtr &pattern, const MatchResultPtr &res,
|
|
|
|
|
const FuncGraphPtr &func_graph) {
|
|
|
|
|
const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph) {
|
|
|
|
|
if (pattern->isa<Call>()) {
|
|
|
|
|
return BuildPrimitiveValueNode(pattern, res, func_graph);
|
|
|
|
|
return BuildPrimitiveValueNode(pattern, res, func_graph, top_graph);
|
|
|
|
|
}
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const MatchResultPtr &res) {
|
|
|
|
|
AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph,
|
|
|
|
|
const MatchResultPtr &res) {
|
|
|
|
|
auto target_inputs = pattern->inputs();
|
|
|
|
|
if (target_inputs.size() == 0) {
|
|
|
|
|
auto new_node = ProcessSinglePattern(pattern, res, func_graph);
|
|
|
|
|
auto new_node = ProcessSinglePattern(pattern, res, func_graph, top_graph);
|
|
|
|
|
if (new_node != nullptr) {
|
|
|
|
|
res->add_entry(pattern, new_node);
|
|
|
|
|
}
|
|
|
|
|
@ -172,14 +178,14 @@ AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph
|
|
|
|
|
}
|
|
|
|
|
// Build up the AnfNode in a recursive manner
|
|
|
|
|
std::vector<AnfNodePtr> new_inputs;
|
|
|
|
|
auto prim_value_node = ProcessComplexPatternFirstInput(pattern, res, func_graph);
|
|
|
|
|
auto prim_value_node = ProcessComplexPatternFirstInput(pattern, res, func_graph, top_graph);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prim_value_node);
|
|
|
|
|
new_inputs.push_back(prim_value_node);
|
|
|
|
|
for (auto &iter : target_inputs) {
|
|
|
|
|
if (iter == pattern) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Circle references. Got pattern: " + pattern->unique_name() + "\n";
|
|
|
|
|
}
|
|
|
|
|
auto input_node = BuildTarget(iter, func_graph, res);
|
|
|
|
|
auto input_node = BuildTarget(iter, func_graph, top_graph, res);
|
|
|
|
|
if (input_node == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failed to build input node for pattern : " + iter->unique_name() + "\n";
|
|
|
|
|
}
|
|
|
|
|
@ -240,11 +246,12 @@ void Reset(PatternPtr pattern) {
|
|
|
|
|
|
|
|
|
|
} // namespace internal
|
|
|
|
|
|
|
|
|
|
AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const MatchResultPtr &res) {
|
|
|
|
|
AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph, const AnfNodePtr &node,
|
|
|
|
|
const MatchResultPtr &res) {
|
|
|
|
|
auto match_res = src_pattern_->match(node);
|
|
|
|
|
if (match_res != nullptr) {
|
|
|
|
|
res->merge(match_res);
|
|
|
|
|
auto new_node = internal::BuildTarget(dst_pattern_, func_graph, res);
|
|
|
|
|
auto new_node = internal::BuildTarget(dst_pattern_, func_graph, top_graph, res);
|
|
|
|
|
internal::Reset(dst_pattern());
|
|
|
|
|
return new_node;
|
|
|
|
|
}
|
|
|
|
|
@ -284,16 +291,19 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res)
|
|
|
|
|
}
|
|
|
|
|
FuncGraphManagerPtr manager = func_graph->manager();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(manager);
|
|
|
|
|
manager->AddFuncGraph(func_graph);
|
|
|
|
|
auto graph_nodes_sorted = TopoSort(func_graph->output());
|
|
|
|
|
auto func_graphs = manager->func_graphs();
|
|
|
|
|
bool changes = false;
|
|
|
|
|
|
|
|
|
|
// Traverse once
|
|
|
|
|
for (auto &node : graph_nodes_sorted) {
|
|
|
|
|
AnfNodePtr new_node = Run(func_graph, node, res);
|
|
|
|
|
if (new_node != nullptr && new_node != node) {
|
|
|
|
|
(void)manager->Replace(node, new_node);
|
|
|
|
|
changes = true;
|
|
|
|
|
for (auto &fg : func_graphs) {
|
|
|
|
|
manager->AddFuncGraph(fg);
|
|
|
|
|
auto graph_nodes_sorted = TopoSort(fg->output());
|
|
|
|
|
// Traverse once
|
|
|
|
|
for (auto &node : graph_nodes_sorted) {
|
|
|
|
|
AnfNodePtr new_node = Run(fg, func_graph, node, res);
|
|
|
|
|
if (new_node != nullptr && new_node != node) {
|
|
|
|
|
MS_LOG(WARNING) << "Matched";
|
|
|
|
|
(void)manager->Replace(node, new_node);
|
|
|
|
|
changes = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return changes;
|
|
|
|
|
|