You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
466 lines
17 KiB
466 lines
17 KiB
/**
|
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "pipeline/action.h"
|
|
|
|
#include <memory>
|
|
#include <utility>
|
|
#include <vector>
|
|
#include <string>
|
|
#include <algorithm>
|
|
#include <functional>
|
|
|
|
#include "ir/func_graph_cloner.h"
|
|
#include "ir/param_value_py.h"
|
|
#include "parallel/costmodel_context.h"
|
|
#include "parallel/context.h"
|
|
#include "pipeline/pass.h"
|
|
#include "pipeline/parse/parse_base.h"
|
|
#include "pipeline/parse/data_converter.h"
|
|
#include "pipeline/static_analysis/abstract_value.h"
|
|
#include "pipeline/static_analysis/static_analysis.h"
|
|
#include "pipeline/static_analysis/program_specialize.h"
|
|
#include "pipeline/resource.h"
|
|
#include "utils/context/ms_context.h"
|
|
#include "pipeline/remove_value_node_dup.h"
|
|
#include "optimizer/optimizer.h"
|
|
#include "vm/transform.h"
|
|
|
|
namespace mindspore {
|
|
namespace pipeline {
|
|
using CompileGraphs = compile::CompileGraphs;
|
|
using abstract::AnalysisResult;
|
|
using mindspore::abstract::AnalysisContextPtr;
|
|
|
|
// Run abstract interpretation (type/shape inference) over `func_graph` with
// the given argument specs. When `clear` is set, cached inference results are
// wiped first so the analysis starts from a clean slate.
abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph,
                                         const abstract::AbstractBasePtrList &args_spec, bool clear) {
  MS_LOG(DEBUG) << "AbstractAnalyze start";
  auto analysis_engine = res->engine();
  MS_EXCEPTION_IF_NULL(analysis_engine);
  if (clear) {
    auto graph_manager = res->manager();
    MS_EXCEPTION_IF_NULL(graph_manager);
    analysis_engine->Clear();
    for (auto &node : graph_manager->all_nodes()) {
      MS_EXCEPTION_IF_NULL(node);
      const AbstractBasePtr &cached = node->abstract();
      // A ValueNode keeps its previously inferred value unless that value is an
      // AbstractFunction; every other node is reset and re-inferred.
      const bool keep = node->isa<ValueNode>() && !(cached != nullptr && cached->isa<abstract::AbstractFunction>());
      if (!keep) {
        node->set_abstract(nullptr);
        MS_LOG(DEBUG) << "Abstract of node " << node->ToString() << " is set to nullptr";
      }
    }
  }
  auto analysis_result = analysis_engine->Run(func_graph, args_spec);
  MS_LOG(DEBUG) << "AbstractAnalyze end";
  return analysis_result;
}
|
|
|
|
// Specialize the analyzed graph: replace abstract/polymorphic nodes with
// concrete ones according to the analysis `context`, then keep only graphs
// reachable from the specialized root.
FuncGraphPtr ProgramSpecialize(const ResourcePtr &res, const FuncGraphPtr &func_graph,
                               const abstract::AnalysisContextPtr &context) {
  MS_LOG(DEBUG) << "ProgramSpecialize start";
  abstract::ProgramSpecializer specializer(res->engine());
  auto specialized = specializer.Run(func_graph, context);
  auto graph_manager = res->manager();
  MS_EXCEPTION_IF_NULL(graph_manager);
  graph_manager->KeepRoots({specialized});
  MS_LOG(DEBUG) << "ProgramSpecialize end";
  return specialized;
}
|
|
|
|
// Re-run the analyze + specialize pipeline on `func_graph`. Used after passes
// that may have invalidated previously inferred types/shapes; the `true`
// argument below clears stale inference results before re-analysis.
FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph,
                         const abstract::AbstractBasePtrList &args_spec) {
  MS_LOG(DEBUG) << "Renormalize start";
#ifdef ENABLE_PROFILE
  double t1 = GetTime();
#endif
  abstract::AnalysisResult result = AbstractAnalyze(res, func_graph, args_spec, true);
#ifdef ENABLE_PROFILE
  double t2 = GetTime();
#endif
  auto ret = ProgramSpecialize(res, func_graph, result.context);
  // The renormalized graph becomes the resource's current graph.
  res->set_func_graph(ret);
#ifdef ENABLE_PROFILE
  double t3 = GetTime();
  MsProfile::StatTime("renormalize.infer", t2 - t1);
  MsProfile::StatTime("renormalize.specialize", t3 - t2);
#endif
  MS_LOG(DEBUG) << "Renormalize end";
  return ret;
}
|
|
|
|
bool ParseAction(const ResourcePtr &res) {
|
|
if (!res->input()) {
|
|
MS_LOG(EXCEPTION) << "Parse error";
|
|
}
|
|
|
|
py::object input = res->input();
|
|
parse::Parser::InitParserEnvironment(input);
|
|
py::module path = py::module::import("os.path");
|
|
std::string dir = path.attr("dirname")(py::globals()["__file__"]).cast<std::string>();
|
|
|
|
parse::python_adapter::set_python_env_flag(true);
|
|
parse::python_adapter::SetPythonPath(dir);
|
|
|
|
FuncGraphPtr fg = parse::ConvertToFuncGraph(input);
|
|
if (fg == nullptr) {
|
|
MS_LOG(EXCEPTION) << "Parse error.";
|
|
}
|
|
res->set_func_graph(fg);
|
|
|
|
FuncGraphManagerPtr manager = res->manager();
|
|
if (manager == nullptr) {
|
|
MS_LOG(EXCEPTION) << "Manager is nullptr.";
|
|
}
|
|
manager->AddFuncGraph(fg);
|
|
return true;
|
|
}
|
|
|
|
// obj_map's graphs have the same construct, these graphs can be optimized to one graph.
// This step do this optimize: graph1(x){xx(fv1),xxx(fv2)}, graph2(x){xxx(fv3),xxx(fv4)}->
// graph1(x){base_graph(x, fv1, fv2)}, graph1(x){base_graph(x, fv3, fv4)}, base_graph(x, fv...){xxx,xxx}
// all obj_map's graph shared base_graph
bool CombineLikeGraphs(const ResourcePtr &res) {
  auto &obj_map = parse::data_converter::GetObjGraphs();

  for (auto it : obj_map) {
    auto &graphs = it.second;
    MS_LOG(DEBUG) << "Start combine like graph:" << it.first << ", size:" << graphs.size();
    auto fg = graphs[0];
    FuncGraphPtrList func_graphs = {fg};
    // Clone the first graph of the group; the clone becomes the shared base graph.
    ClonerPtr cloner = std::make_shared<Cloner>(func_graphs, false, false, true, std::make_shared<TraceCopy>(),
                                                std::make_shared<TraceCombileLikeGraphs>());
    cloner->Run();
    auto base_graph = cloner->cloned_func_graph()[fg];
    MS_LOG(DEBUG) << "Basegraph:" << base_graph->ToString();

    // Nothing to share when the template graph captures no parameter objects
    // (free variables) or the group has only a single graph.
    if (fg->paramter_obj_nodes().size() == 0 || graphs.size() <= 1) {
      continue;
    }
    // Lift each captured free variable of the template graph into an explicit
    // parameter of the base graph, and rewire the cloned users to read it.
    for (auto &fv : fg->paramter_obj_nodes()) {
      TraceManager::DebugTrace(std::make_shared<TraceCombileLikeGraphs>(fv->debug_info()));
      auto param = base_graph->add_parameter();
      TraceManager::EndTrace();
      auto &node_users = res->manager()->node_users()[fv];
      for (auto &n : node_users) {
        // `n` is (user node, input index) in the original graph; redirect the
        // corresponding cloned node's input to the new parameter.
        auto repl_n = (*cloner->cloned_node())[n.first]->cast<CNodePtr>();
        repl_n->set_input(n.second, param);
      }
    }
    MS_LOG(DEBUG) << "Fg0 paramter_obj_nodes size :" << fg->paramter_obj_nodes().size();

    // Rewrite every graph in the group as a thin wrapper that calls the shared
    // base graph with its own (possibly precision-cast) parameters followed by
    // its own free variables.
    for (auto &g : graphs) {
      auto fvs = g->paramter_obj_nodes();
      std::vector<AnfNodePtr> new_node_inputs;
      new_node_inputs.push_back(NewValueNode(base_graph));
      for (auto &p : g->parameters()) {
        AnfNodePtr para_after_cast = parse::GetMixedPrecisionCastHelp(g, p);
        new_node_inputs.push_back(para_after_cast);
      }
      (void)new_node_inputs.insert(new_node_inputs.end(), fvs.begin(), fvs.end());
      AnfNodePtr out = g->NewCNode(new_node_inputs);
      g->set_output(out);
      MS_LOG(DEBUG) << "Combine graph newout:" << out->DebugString(4);
    }
    MS_LOG(DEBUG) << "End combine graph:" << it.first;
  }
  return true;
}
|
|
|
|
bool SymbolResolveAction(const ResourcePtr &res) {
|
|
if (res->manager() == nullptr) {
|
|
MS_LOG(EXCEPTION) << "SymbolResolve error, manager is null";
|
|
}
|
|
if (res->func_graph() == nullptr) {
|
|
MS_LOG(EXCEPTION) << "SymbolResolve error, graph is null";
|
|
}
|
|
FuncGraphPtr func_graph = res->func_graph();
|
|
auto succ = parse::ResolveFuncGraph(func_graph, res);
|
|
|
|
// Remove unused nodes in cnode order list.
|
|
func_graph->EraseUnusedNodeInOrder();
|
|
func_graph->ReleaseFullOrderToEffectOrder();
|
|
for (auto fg : func_graph->func_graphs_used_total()) {
|
|
MS_EXCEPTION_IF_NULL(fg);
|
|
fg->EraseUnusedNodeInOrder();
|
|
fg->ReleaseFullOrderToEffectOrder();
|
|
}
|
|
return succ;
|
|
}
|
|
|
|
// Guard that the resource holds a manager and a graph, then run the
// inference-optimization preparation pass.
bool InferenceOptPrepareAction(const ResourcePtr &res) {
  if (res->manager() == nullptr) {
    MS_LOG(EXCEPTION) << "InferenceOptPrepare error, manager is null.";
  }
  if (res->func_graph() == nullptr) {
    MS_LOG(EXCEPTION) << "InferenceOptPrepare error, graph is null.";
  }
  return InferenceOptPreparePass(res);
}
|
|
|
|
bool AbstractSpecializeAction(const ResourcePtr &res) {
|
|
if (res->func_graph() == nullptr) {
|
|
MS_LOG(EXCEPTION) << "AbstractSpecialize error";
|
|
}
|
|
|
|
FuncGraphPtr func_graph = res->func_graph();
|
|
abstract::AbstractBasePtrList args_spec = res->args_spec();
|
|
|
|
parallel::ParallelParameterContextInit(func_graph);
|
|
|
|
// suppose that there is not KeywordArgument for the top graph
|
|
// get the hyper parameter
|
|
for (const auto ¶m : func_graph->parameters()) {
|
|
auto param_node = std::static_pointer_cast<Parameter>(param);
|
|
if (param_node->has_default()) {
|
|
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_node->default_param());
|
|
AbstractBasePtr ptr = abstract::FromValue(parse::data_converter::PyDataToValue(param_value->value()), true);
|
|
|
|
parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr);
|
|
args_spec.push_back(ptr);
|
|
parallel::ParallelParameterContextCkptInTraining(func_graph, param_node, ptr);
|
|
}
|
|
}
|
|
// Analyze
|
|
AnalysisResult result = AbstractAnalyze(res, func_graph, args_spec);
|
|
// The top graph may be replaced by infer, update the top graph when the infer is done
|
|
parse::Parser::UpdateTopFuncGraph(result.context->func_graph());
|
|
|
|
// Specialize
|
|
FuncGraphPtr new_fg = ProgramSpecialize(res, result.context->func_graph(), result.context);
|
|
res->set_func_graph(new_fg);
|
|
|
|
MS_LOG(DEBUG) << "End graph: " << new_fg->ToString() << ", return: " << new_fg->get_return()->DebugString(true);
|
|
return true;
|
|
}
|
|
|
|
// Run each (name, pass-function) item in `passes` over the resource, timing
// every pass under the profiler and optionally dumping the graph after each
// one for debugging. Throws if any pass reports failure.
bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes) {
  size_t counter = 0;
  for (auto &pass : passes) {
    // WITH(...) executes the lambda body inside the profiler step named
    // after the pass (pass.first).
    WITH(MsProfile::GetProfile()->Step(pass.first))[&pass, &res, &counter]() {
      MS_LOG(DEBUG) << "Pass " << pass.first << " start ...";
      auto result = pass.second(res);
      if (!result) {
        MS_LOG(EXCEPTION) << "Pass running to end, failed in pass:" << pass.first;
      }
      // When graph saving is enabled, dump the IR after each pass; the counter
      // in the file name preserves pass ordering on disk.
      if (MsContext::GetInstance()->save_graphs_flag() && res->func_graph() != nullptr) {
        auto fg_name = "opt_pass_" + std::to_string(counter) + "_" + pass.first;
        auto func_graph = res->func_graph();
        MS_EXCEPTION_IF_NULL(func_graph);
        func_graph->DumpFuncGraph(fg_name);
        DumpIR(fg_name + ".ir", func_graph);
        MS_LOG(DEBUG) << "Dump " << fg_name << " func graph.";
      }
      counter++;
      MS_LOG(DEBUG) << "Pass " << pass.first << " end.";
    };
  }

  return true;
}
|
|
|
|
bool GeOptimizeAction(const ResourcePtr &res) {
  // Delegate to the generic driver with the GE-specific pass list.
  return OptimizeAction(res, kGePasses);
}
|
|
|
|
bool VmOptimizeAction(const ResourcePtr &res) {
  // Delegate to the generic driver with the VM-specific pass list.
  return OptimizeAction(res, kVmPasses);
}
|
|
|
|
bool PynativeOptimizeAction(const ResourcePtr &res) {
  // Delegate to the generic driver with the PyNative-specific pass list.
  return OptimizeAction(res, kPynativePasses);
}
|
|
|
|
static bool IsCtrlSink() {
|
|
auto ms_ctx = MsContext::GetInstance();
|
|
std::string device_target = ms_ctx->device_target();
|
|
if (device_target != kAscendDevice) {
|
|
return false;
|
|
}
|
|
|
|
if (!ms_ctx->enable_task_sink()) {
|
|
return false;
|
|
}
|
|
|
|
const char *enable_ctrl_sink = std::getenv("ENABLE_CTRL_SINK");
|
|
if (enable_ctrl_sink == nullptr) {
|
|
return false;
|
|
}
|
|
std::string enable_ctrl_sink_str(enable_ctrl_sink);
|
|
if (enable_ctrl_sink_str == "0") {
|
|
return false;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
// Compile the optimized graph with the configured backend and store the
// result under kOutput. With control sink the backend compiles the whole
// graph directly; otherwise the graph is partitioned at nonlinear ops and
// compiled/linked through CompileGraphs.
bool TaskEmitAction(const ResourcePtr &res) {
  if (res->func_graph() == nullptr) {
    MS_LOG(EXCEPTION) << "TaskEmit args error";
  }
  FuncGraphPtr func_graph = res->func_graph();
  auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>();
  if (IsCtrlSink()) {
    res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph));
    return true;
  }
  // Primitives at which the graph is cut into segments for the VM; the MS
  // backend uses its own cut list.
  std::vector<PrimitivePtr> cut_list = compile::nonlinear_ops;
  if (bc_ptr->name() == kMsConvert) {
    cut_list = compile::GetMsNonlinearOps();
  }

  std::shared_ptr<CompileGraphs> compile = std::make_shared<CompileGraphs>(bc_ptr, cut_list);
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  // Mixed-target graphs cannot use multi-graph/loop sink; otherwise, in graph
  // (non-PyNative) mode on Ascend, enable multi-graph sink.
  if (compile->ContainMixedTarget(func_graph)) {
    bc_ptr->set_is_multi_graph_sink(false);
    context_ptr->set_loop_sink_flag(false);
  } else if (context_ptr->execution_mode() != kPynativeMode) {
    std::string device_target = context_ptr->device_target();
    if (device_target == kAscendDevice) {
      bc_ptr->set_is_multi_graph_sink(true);
    }
  }
  res->results()[kOutput] = compile->CompileAndLink(func_graph);
  return true;
}
|
|
|
|
// Wrap the compiled output (stored under kOutput by TaskEmitAction) into a
// callable VmEvalFunc. With control sink the output is a backend GraphId run
// through MsBackend::RunGraph; otherwise it is a FinalVM whose Eval is bound
// directly. A null FinalVM is legal and means GE executes the graph instead.
bool ExecuteAction(const ResourcePtr &res) {
  if (res->results().count(kOutput) == 0) {
    MS_LOG(EXCEPTION) << "Execute args error";
  }

  if (IsCtrlSink()) {
    if (!res->results()[kOutput].is<GraphId>()) {
      MS_LOG(EXCEPTION) << "Execute args error";
    }

    auto graph_id = res->results()[kOutput].cast<GraphId>();
    std::shared_ptr<compile::Backend> bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::Backend>>();
    std::shared_ptr<compile::MsBackend> msbc_ptr = std::dynamic_pointer_cast<compile::MsBackend>(bc_ptr);
    MS_EXCEPTION_IF_NULL(msbc_ptr);
    // The closure captures the backend and graph id by value; the first output
    // of the graph run is returned to the caller.
    compile::VmEvalFuncPtr run =
      std::make_shared<compile::VmEvalFunc>([msbc_ptr, graph_id](const VectorRef &args) -> BaseRef {
        MS_LOG(INFO) << "Execute args size " << args.size();
        auto outs = msbc_ptr->RunGraph(graph_id, args);
        MS_LOG(DEBUG) << "out size " << outs.size();
        return outs[0];
      });
    res->results()[kOutput] = run;
    return true;
  }

  if (!res->results()[kOutput].is<compile::FinalVMPtr>()) {
    MS_LOG(EXCEPTION) << "Execute args error";
  }
  compile::FinalVMPtr vm = res->results()[kOutput].cast<compile::FinalVMPtr>();
  if (vm == nullptr) {
    // No VM was built: GE runs the func_graph itself, nothing to wrap.
    MS_LOG(INFO) << "Call GE to Run the func_graph instead of VM";
    return true;
  }
  compile::VmEvalFuncPtr run =
    std::make_shared<compile::VmEvalFunc>(std::bind(&compile::FinalVM::Eval, vm, std::placeholders::_1));
  res->results()[kOutput] = run;
  return true;
}
|
|
|
|
// The parallel primitive related valuenode might be partitioned so that its value changes by device,
|
|
// that will result in a syncronization error due to different executing order.
|
|
// Here we temporarily avoid the problem by skipping valuenode merging used by parallel related primitive,
|
|
// the final solution will be proposed later as a parallel feature.
|
|
bool KeepValueNodeDuplication(const AnfNodePtr &value_node, const ResourcePtr &res) {
|
|
auto &node_users = res->manager()->node_users();
|
|
auto &users = node_users[value_node];
|
|
auto used_by_keep_value_prim =
|
|
std::any_of(users.begin(), users.end(), [](const std::pair<AnfNodePtr, int> &user) -> bool {
|
|
MS_EXCEPTION_IF_NULL(user.first);
|
|
auto cnode = user.first->cast<CNodePtr>();
|
|
if (cnode == nullptr) {
|
|
return false;
|
|
}
|
|
auto prim_node = cnode->input(0);
|
|
if (IsValueNode<Primitive>(prim_node)) {
|
|
auto prim = GetValue<PrimitivePtr>(prim_node->cast<ValueNodePtr>()->value());
|
|
// value_node is referenced by some parallel primitive
|
|
return prim->HasAttr("keep_value_node_input");
|
|
}
|
|
return false;
|
|
});
|
|
return used_by_keep_value_prim;
|
|
}
|
|
|
|
bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) {
|
|
if (res->func_graph() == nullptr) {
|
|
MS_LOG(EXCEPTION) << "Remove value node duplications error.";
|
|
}
|
|
FuncGraphPtr func_graph = res->func_graph();
|
|
auto manager = res->manager();
|
|
// Remove duplicated value nodes, due to replace operation, can't use reference.
|
|
auto value_nodes = func_graph->value_nodes();
|
|
HashCache hash_cache;
|
|
HashValue hashes;
|
|
for (const auto &value_pair : value_nodes) {
|
|
if (KeepValueNodeDuplication(value_pair.first, res)) {
|
|
continue;
|
|
}
|
|
TryToDoReplace(manager.get(), value_pair.first, &hash_cache, &hashes);
|
|
}
|
|
return true;
|
|
}
|
|
|
|
bool ValidateAction(const ResourcePtr &res) {
  // Final validation pass over the optimized graph.
  return ValidatePass(res);
}
|
|
|
|
static std::vector<ActionItem> CommonPipeline() {
|
|
std::vector<ActionItem> actions;
|
|
|
|
// Parse the python ast to ANF graph
|
|
actions.emplace_back(std::make_pair("parse", ParseAction));
|
|
|
|
// Resolve the python func
|
|
actions.emplace_back(std::make_pair("symbol_resolve", SymbolResolveAction));
|
|
auto multi_graphs = parallel::CostModelContext::GetInstance()->is_multi_subgraphs();
|
|
if (!multi_graphs) {
|
|
actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs));
|
|
}
|
|
actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction));
|
|
// Evaluate type and shape, and specialize
|
|
actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction));
|
|
|
|
return actions;
|
|
}
|
|
|
|
std::vector<ActionItem> GePipeline() {
|
|
auto actions = CommonPipeline();
|
|
// optimize
|
|
actions.emplace_back(std::make_pair("optimize", GeOptimizeAction));
|
|
actions.emplace_back(std::make_pair("remove_value_node_duplications", RemoveValueNodeDuplicationsAction));
|
|
actions.emplace_back(std::make_pair("validate", ValidateAction));
|
|
return actions;
|
|
}
|
|
|
|
std::vector<ActionItem> VmPipeline() {
|
|
auto actions = CommonPipeline();
|
|
|
|
// optimize
|
|
actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
|
|
|
|
actions.emplace_back(std::make_pair("validate", ValidateAction));
|
|
|
|
// compile the ANF graph
|
|
actions.emplace_back(std::make_pair("task_emit", TaskEmitAction));
|
|
|
|
// to execute the graph
|
|
actions.emplace_back(std::make_pair("execute", ExecuteAction));
|
|
|
|
return actions;
|
|
}
|
|
} // namespace pipeline
|
|
} // namespace mindspore
|