supporting stateless if/while

pull/278/head
chuxing 4 years ago
parent 2e193f8f19
commit 54fab4c9aa

@ -257,7 +257,7 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s
} }
// cond or branch need to be prepared before the execution of IF or CASE // cond or branch need to be prepared before the execution of IF or CASE
if (node_item.node_type == IF || node_item.node_type == CASE) { if (node_item.node_type == IF || node_item.node_type == STATELESSIF || node_item.node_type == CASE) {
const auto &in_anchor = ge_node->GetInDataAnchor(0); const auto &in_anchor = ge_node->GetInDataAnchor(0);
GE_CHECK_NOTNULL(in_anchor); GE_CHECK_NOTNULL(in_anchor);
const auto &peer_anchor = in_anchor->GetPeerOutAnchor(); const auto &peer_anchor = in_anchor->GetPeerOutAnchor();
@ -917,7 +917,7 @@ Status HybridModelBuilder::LoadGeModel(ComputeGraph &sub_graph, const GeModelPtr
auto parent_node = sub_graph.GetParentNode(); auto parent_node = sub_graph.GetParentNode();
GE_CHECK_NOTNULL(parent_node); GE_CHECK_NOTNULL(parent_node);
auto op_type = parent_node->GetType(); auto op_type = parent_node->GetType();
if (op_type == IF || op_type == CASE || op_type == WHILE) { if (IsControlOp(op_type)) {
GELOGD("Set ge_model for control op subgraph: [%s], task_size = %d", GELOGD("Set ge_model for control op subgraph: [%s], task_size = %d",
sub_graph.GetName().c_str(), sub_graph.GetName().c_str(),
ge_model->GetModelTaskDefPtr()->task_size()); ge_model->GetModelTaskDefPtr()->task_size());

@ -28,6 +28,9 @@ namespace hybrid {
namespace { namespace {
const char * const kAttrNameOriginalFusionGraph = "_original_fusion_graph"; const char * const kAttrNameOriginalFusionGraph = "_original_fusion_graph";
const char * const kNodeTypeRetVal = "_RetVal"; const char * const kNodeTypeRetVal = "_RetVal";
std::set<std::string> kControlOpTypes {
IF, STATELESSIF, CASE, WHILE, STATELESSWHILE
}
Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) { Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) {
uint32_t parent_index = 0; uint32_t parent_index = 0;
@ -102,6 +105,11 @@ Status ParseFusedSubgraph(NodeItem &node_item) {
return SUCCESS; return SUCCESS;
} }
} // namespace } // namespace
bool IsControlOp(const std::string &op_type) {
return kControlOpTypes.count(op_type) > 0;
}
NodeItem::NodeItem(NodePtr node): node(std::move(node)) { NodeItem::NodeItem(NodePtr node): node(std::move(node)) {
this->op_desc = this->node->GetOpDesc().get(); this->op_desc = this->node->GetOpDesc().get();
this->node_id = this->op_desc->GetId(); this->node_id = this->op_desc->GetId();
@ -153,8 +161,7 @@ Status NodeItem::Init() {
} }
bool NodeItem::IsControlOp() const { bool NodeItem::IsControlOp() const {
auto op_type = op_desc->GetType(); return ge::hybrid::IsControlOp(op_desc->GetType());
return op_type == IF || op_type == CASE || op_type == WHILE || op_type == FOR;
} }
std::string NodeItem::DebugString() const { std::string NodeItem::DebugString() const {

@ -36,6 +36,8 @@ struct FusedSubgraph {
ComputeGraphPtr graph; ComputeGraphPtr graph;
}; };
bool IsControlOp(const std::string &op_type);
// for caching static information across execution // for caching static information across execution
struct NodeItem { struct NodeItem {
explicit NodeItem(NodePtr node); explicit NodeItem(NodePtr node);

@ -404,11 +404,11 @@ Status ControlOpNodeExecutor::LoadTask(const HybridModel &model,
unique_ptr<ControlOpNodeTask> node_task; unique_ptr<ControlOpNodeTask> node_task;
auto node_type = node->GetType(); auto node_type = node->GetType();
if (node_type == IF) { if (node_type == IF || node_type == STATELESSIF) {
node_task.reset(new(std::nothrow) IfOpNodeTask()); node_task.reset(new(std::nothrow) IfOpNodeTask());
} else if (node_type == CASE) { } else if (node_type == CASE) {
node_task.reset(new(std::nothrow) CaseOpNodeTask()); node_task.reset(new(std::nothrow) CaseOpNodeTask());
} else if (node_type == WHILE) { } else if (node_type == WHILE || node_type == STATELESSWHILE) {
node_task.reset(new(std::nothrow) WhileOpNodeTask()); node_task.reset(new(std::nothrow) WhileOpNodeTask());
} else { } else {
GELOGE(PARAM_INVALID, "[%s] Unsupported type: %s", node->GetName().c_str(), node_type.c_str()); GELOGE(PARAM_INVALID, "[%s] Unsupported type: %s", node->GetName().c_str(), node_type.c_str());

@ -97,7 +97,7 @@ NodeExecutorManager::ExecutorType NodeExecutorManager::ResolveExecutorType(Node
return ExecutorType::GE_LOCAL; return ExecutorType::GE_LOCAL;
} }
if (op_type == IF || op_type == CASE || op_type == WHILE) { if (IsControlOp(op_type)) {
return ExecutorType::CONTROL_OP; return ExecutorType::CONTROL_OP;
} }

Loading…
Cancel
Save