|
|
|
@ -46,8 +46,9 @@ using TypedPrimitiveAbstractClosurePtr = std::shared_ptr<abstract::TypedPrimitiv
|
|
|
|
|
std::vector<PrimitivePtr> nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
|
|
|
|
|
prim::kPrimMakeTuple, prim::kPrimBpropCut};
|
|
|
|
|
const std::vector<PrimitivePtr> &GetMsNonlinearOps() {
|
|
|
|
|
static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
|
|
|
|
|
prim::kPrimBpropCut};
|
|
|
|
|
static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial,
|
|
|
|
|
prim::kPrimSwitch, prim::kPrimMakeTuple,
|
|
|
|
|
prim::kPrimBpropCut, prim::kPrimSwitchLayer};
|
|
|
|
|
return ms_nonlinear_ops;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -187,6 +188,30 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &
|
|
|
|
|
std::reverse(result.begin(), result.end());
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool IsSubGraph(const AnfNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
if (node->isa<CNode>()) {
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
auto &inputs = cnode->inputs();
|
|
|
|
|
if (inputs.empty()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtr fn = inputs[0];
|
|
|
|
|
MS_EXCEPTION_IF_NULL(fn);
|
|
|
|
|
if (!IsValueNode<Primitive>(fn)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto node_prim = GetValueNode<PrimitivePtr>(fn);
|
|
|
|
|
if (node_prim->name() == prim::kPrimPartial->name()) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
} else if (IsValueNode<FuncGraph>(node)) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list)
|
|
|
|
@ -235,6 +260,15 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(ms_context);
|
|
|
|
|
ms_context->set_enable_pynative_hook(true);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (backend_->name() == kMsConvert && prim->name() == prim::kPrimMakeTuple->name()) {
|
|
|
|
|
if (inputs.size() < 2) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto ret = IsSubGraph(inputs[1]);
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -466,6 +500,8 @@ int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node)
|
|
|
|
|
} else if (IsPrimitive(fn, prim::kPrimSwitch)) {
|
|
|
|
|
AddSwitch(node);
|
|
|
|
|
AddSinkSwitch(node);
|
|
|
|
|
} else if (IsPrimitive(fn, prim::kPrimSwitchLayer)) {
|
|
|
|
|
AddSwitchLayer(node);
|
|
|
|
|
} else if (IsPrimitive(fn, prim::kPrimMakeTuple)) {
|
|
|
|
|
AddMakeTuple(node);
|
|
|
|
|
} else {
|
|
|
|
@ -622,6 +658,17 @@ void CompileGraph::AddSwitch(const CNodePtr &node) {
|
|
|
|
|
AddInst(Instruction::kSwitch, args);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CompileGraph::AddSwitchLayer(const CNodePtr &node) {
|
|
|
|
|
auto inputs = node->inputs();
|
|
|
|
|
if (inputs.size() != 3) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Switch layer must have index and branches.";
|
|
|
|
|
}
|
|
|
|
|
VectorRef args;
|
|
|
|
|
args.emplace_back(Ref(inputs[1]));
|
|
|
|
|
args.emplace_back(Ref(inputs[2]));
|
|
|
|
|
AddInst(Instruction::kSwitchLayer, args);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CompileGraph::AddReturn(const CNodePtr &node) {
|
|
|
|
|
VectorRef args;
|
|
|
|
|
if (backend_->simu_flag()) {
|
|
|
|
|