pull/840/head
chuxing 4 years ago
parent 84cd741be4
commit 68bbf9e41c

@ -95,6 +95,7 @@ Status HybridModelExecutor::InitExecutionContext() {
context_.stream = stream_;
context_.model = model_;
context_.is_eos_ = false;
context_.session_id = ::ge::GetContext().SessionId();
context_.ge_context = &GetThreadLocalContext();
GELOGD("session id from model = %lu, from context = %lu", model_->GetSessionId(), context_.session_id);

@ -82,6 +82,7 @@ struct NodeItem {
bool has_observer = false;
bool has_optional_inputs = false;
bool is_output_shape_static = true;
bool may_trigger_eos_ = false;
UnknowShapeOpType shape_inference_type = DEPEND_IN_SHAPE;
std::string node_name;
std::string node_type;

@ -21,8 +21,6 @@
#include "common/ge/ge_util.h"
#include "graph/attr_value.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"
#include "graph/load/new_model_manager/model_utils.h"
#include "graph/load/new_model_manager/model_manager.h"
#include "hybrid/executor/hybrid_execution_context.h"
@ -60,10 +58,6 @@ Status KnownNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> d
GELOGE(rt_ret, "rtModelExecute error, ret: hybrid_model_executorOx%X", rt_ret); return FAILED;);
RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodertModelExecute] End");
if (need_sync_) {
GELOGD("[%s] model need sync", context.GetNodeName());
GE_CHK_STATUS_RET_NOLOG(context.Synchronize());
}
GE_CHK_STATUS_RET_NOLOG(context.RegisterCallback(done_callback));
GELOGD("[%s] KnownNodeTask::ExecuteAsync success.", context.GetNodeName());
RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeTaskExecuteAsync] End");
@ -177,9 +171,7 @@ Status KnownNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node
GE_CHK_STATUS_RET(davinci_model->Assign(ge_model), "KnownNodeExecutor::LoadTask davincimodel assign failed.");
bool need_sync = false;
GE_CHK_STATUS_RET_NOLOG(NeedSync(*ge_model, need_sync));
task = MakeShared<KnownNodeTask>(davinci_model, need_sync);
task = MakeShared<KnownNodeTask>(davinci_model);
GE_CHECK_NOTNULL(task);
GELOGI("[%s] KnownNodeExecutor::LoadTask success.", node->GetName().c_str());
return SUCCESS;
@ -194,21 +186,5 @@ Status KnownNodeExecutor::ExecuteTask(NodeTask &task, TaskContext &context,
RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeExecutorExecuteTask] End");
return SUCCESS;
}
Status KnownNodeExecutor::NeedSync(GeModel &ge_model, bool &need_sync) {
auto compute_graph = GraphUtils::GetComputeGraph(ge_model.GetGraph());
GE_CHECK_NOTNULL(compute_graph);
for (auto &node : compute_graph->GetAllNodes()) {
auto type = NodeUtils::GetNodeType(node);
if (type == GETNEXT) {
GELOGD("Contains GetNext node: %s", node->GetName().c_str());
need_sync = true;
return SUCCESS;
}
}
need_sync = false;
return SUCCESS;
}
} // namespace hybrid
} // namespace ge

@ -27,8 +27,8 @@ class HybridModel;
class KnownNodeTask : public NodeTask {
public:
explicit KnownNodeTask(std::shared_ptr<DavinciModel> davinci_model, bool need_sync)
: davinci_model_(davinci_model), need_sync_(need_sync)
explicit KnownNodeTask(std::shared_ptr<DavinciModel> davinci_model)
: davinci_model_(davinci_model)
{}
~KnownNodeTask() {}
@ -39,7 +39,6 @@ class KnownNodeTask : public NodeTask {
private:
std::shared_ptr<DavinciModel> davinci_model_ = nullptr;
bool load_flag_ = false;
bool need_sync_;
};
class KnownNodeExecutor : public NodeExecutor {
@ -49,7 +48,6 @@ class KnownNodeExecutor : public NodeExecutor {
Status ExecuteTask(NodeTask &task, TaskContext &context, const std::function<void()> &callback) const;
~KnownNodeExecutor() {}
private:
static Status NeedSync(GeModel &ge_model, bool &need_sync);
std::shared_ptr<DavinciModel> davinci_model_ = nullptr;
};
} // namespace hybrid

Loading…
Cancel
Save