supporting stateless if/while

pull/278/head
chuxing 4 years ago
parent 2e193f8f19
commit 54fab4c9aa

@ -257,7 +257,7 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s
}
// cond or branch need to be prepared before the execution of IF or CASE
if (node_item.node_type == IF || node_item.node_type == CASE) {
if (node_item.node_type == IF || node_item.node_type == STATELESSIF || node_item.node_type == CASE) {
const auto &in_anchor = ge_node->GetInDataAnchor(0);
GE_CHECK_NOTNULL(in_anchor);
const auto &peer_anchor = in_anchor->GetPeerOutAnchor();
@ -917,7 +917,7 @@ Status HybridModelBuilder::LoadGeModel(ComputeGraph &sub_graph, const GeModelPtr
auto parent_node = sub_graph.GetParentNode();
GE_CHECK_NOTNULL(parent_node);
auto op_type = parent_node->GetType();
if (op_type == IF || op_type == CASE || op_type == WHILE) {
if (IsControlOp(op_type)) {
GELOGD("Set ge_model for control op subgraph: [%s], task_size = %d",
sub_graph.GetName().c_str(),
ge_model->GetModelTaskDefPtr()->task_size());

@ -28,6 +28,9 @@ namespace hybrid {
namespace {
const char * const kAttrNameOriginalFusionGraph = "_original_fusion_graph";
const char * const kNodeTypeRetVal = "_RetVal";
std::set<std::string> kControlOpTypes {
IF, STATELESSIF, CASE, WHILE, STATELESSWHILE
}
Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) {
uint32_t parent_index = 0;
@ -102,6 +105,11 @@ Status ParseFusedSubgraph(NodeItem &node_item) {
return SUCCESS;
}
} // namespace
bool IsControlOp(const std::string &op_type) {
return kControlOpTypes.count(op_type) > 0;
}
NodeItem::NodeItem(NodePtr node): node(std::move(node)) {
this->op_desc = this->node->GetOpDesc().get();
this->node_id = this->op_desc->GetId();
@ -153,8 +161,7 @@ Status NodeItem::Init() {
}
bool NodeItem::IsControlOp() const {
auto op_type = op_desc->GetType();
return op_type == IF || op_type == CASE || op_type == WHILE || op_type == FOR;
return ge::hybrid::IsControlOp(op_desc->GetType());
}
std::string NodeItem::DebugString() const {

@ -36,6 +36,8 @@ struct FusedSubgraph {
ComputeGraphPtr graph;
};
bool IsControlOp(const std::string &op_type);
// for caching static information across execution
struct NodeItem {
explicit NodeItem(NodePtr node);

@ -404,11 +404,11 @@ Status ControlOpNodeExecutor::LoadTask(const HybridModel &model,
unique_ptr<ControlOpNodeTask> node_task;
auto node_type = node->GetType();
if (node_type == IF) {
if (node_type == IF || node_type == STATELESSIF) {
node_task.reset(new(std::nothrow) IfOpNodeTask());
} else if (node_type == CASE) {
node_task.reset(new(std::nothrow) CaseOpNodeTask());
} else if (node_type == WHILE) {
} else if (node_type == WHILE || node_type == STATELESSWHILE) {
node_task.reset(new(std::nothrow) WhileOpNodeTask());
} else {
GELOGE(PARAM_INVALID, "[%s] Unsupported type: %s", node->GetName().c_str(), node_type.c_str());

@ -97,7 +97,7 @@ NodeExecutorManager::ExecutorType NodeExecutorManager::ResolveExecutorType(Node
return ExecutorType::GE_LOCAL;
}
if (op_type == IF || op_type == CASE || op_type == WHILE) {
if (IsControlOp(op_type)) {
return ExecutorType::CONTROL_OP;
}

Loading…
Cancel
Save