|
|
|
@ -269,13 +269,41 @@ bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePa
|
|
|
|
|
|
|
|
|
|
bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPasses); }
|
|
|
|
|
|
|
|
|
|
static bool IsCtrlSink() {
|
|
|
|
|
auto ms_ctx = MsContext::GetInstance();
|
|
|
|
|
std::string device_target = ms_ctx->device_target();
|
|
|
|
|
if (device_target != kAscendDevice) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!ms_ctx->enable_task_sink()) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
char *enable_ctrl_sink = std::getenv("ENABLE_CTRL_SINK");
|
|
|
|
|
if (enable_ctrl_sink == nullptr) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
std::string enable_ctrl_sink_str(enable_ctrl_sink);
|
|
|
|
|
if (enable_ctrl_sink_str == "0") {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool TaskEmitAction(const ResourcePtr &res) {
|
|
|
|
|
if (res->func_graph() == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "TaskEmit args error";
|
|
|
|
|
}
|
|
|
|
|
FuncGraphPtr func_graph = res->func_graph();
|
|
|
|
|
|
|
|
|
|
auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>();
|
|
|
|
|
|
|
|
|
|
if (IsCtrlSink()) {
|
|
|
|
|
res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph));
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<PrimitivePtr> cut_list = compile::nonlinear_ops;
|
|
|
|
|
if (bc_ptr->name() == kMsConvert) {
|
|
|
|
|
cut_list = compile::GetMsNonlinearOps();
|
|
|
|
@ -286,10 +314,31 @@ bool TaskEmitAction(const ResourcePtr &res) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool ExecuteAction(const ResourcePtr &res) {
|
|
|
|
|
if (res->results().count(kOutput) == 0 || !res->results()[kOutput].is<compile::FinalVMPtr>()) {
|
|
|
|
|
if (res->results().count(kOutput) == 0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Execute args error";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (IsCtrlSink()) {
|
|
|
|
|
if (!res->results()[kOutput].is<GraphId>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Execute args error";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto graph_id = res->results()[kOutput].cast<GraphId>();
|
|
|
|
|
auto bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::MsBackend>>();
|
|
|
|
|
compile::VmEvalFuncPtr run =
|
|
|
|
|
std::make_shared<compile::VmEvalFunc>([&bc_ptr, graph_id](const VectorRef &args) -> BaseRef {
|
|
|
|
|
MS_LOG(INFO) << "Execute args size" << args.size();
|
|
|
|
|
auto outs = bc_ptr->RunGraph(graph_id, args);
|
|
|
|
|
MS_LOG(DEBUG) << "out size" << outs.size();
|
|
|
|
|
return outs[0];
|
|
|
|
|
});
|
|
|
|
|
res->results()[kOutput] = run;
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!res->results()[kOutput].is<compile::FinalVMPtr>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Execute args error";
|
|
|
|
|
}
|
|
|
|
|
compile::FinalVMPtr vm = res->results()[kOutput].cast<compile::FinalVMPtr>();
|
|
|
|
|
if (vm == nullptr) {
|
|
|
|
|
MS_LOG(INFO) << "Call GE to Run the func_graph instead of VM";
|
|
|
|
|