|
|
|
@ -17,6 +17,7 @@
|
|
|
|
|
#include "hybrid_model_executor.h"
|
|
|
|
|
#include "graph/ge_context.h"
|
|
|
|
|
#include "graph/runtime_inference_context.h"
|
|
|
|
|
#include "graph/utils/tensor_utils.h"
|
|
|
|
|
#include "common/dump/dump_manager.h"
|
|
|
|
|
|
|
|
|
|
namespace ge {
|
|
|
|
@ -50,6 +51,11 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) {
|
|
|
|
|
auto root_graph_item = model_->GetRootGraphItem();
|
|
|
|
|
GE_CHECK_NOTNULL(root_graph_item);
|
|
|
|
|
|
|
|
|
|
if (root_graph_item->IsDynamic()) {
|
|
|
|
|
GE_CHK_STATUS_RET(CheckInputShapeByShapeRange(root_graph_item, args),
|
|
|
|
|
"[%s] check input node shape by shape range failed.",
|
|
|
|
|
root_graph_item->GetName().c_str());
|
|
|
|
|
}
|
|
|
|
|
GE_CHK_RT_RET(rtMemcpyAsync(context_.global_step, sizeof(uint64_t), &context_.iteration,
|
|
|
|
|
sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE_EX, context_.stream));
|
|
|
|
|
SubgraphExecutor executor(model_->GetRootGraphItem(), &context_);
|
|
|
|
@ -138,5 +144,55 @@ Status HybridModelExecutor::ResetExecutionContext(GraphExecutionContext &context
|
|
|
|
|
GE_CHK_GRAPH_STATUS_RET(RuntimeInferenceContext::CreateContext(ctx_id), "Failed to Destroy RuntimeInferenceContext");
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status HybridModelExecutor::CheckInputShapeByShapeRange(const GraphItem *graph_item,
|
|
|
|
|
HybridModelExecutor::ExecuteArgs &args) {
|
|
|
|
|
GE_CHECK_NOTNULL(graph_item);
|
|
|
|
|
auto input_nodes = graph_item->GetInputNodes();
|
|
|
|
|
if (args.input_desc.size() < input_nodes.size()) {
|
|
|
|
|
REPORT_INNER_ERROR("E19999", "[%s] Number of inputs [%zu] is not sufficient for graph which needs [%zu] inputs.",
|
|
|
|
|
graph_item->GetName().c_str(), args.input_desc.size(), input_nodes.size());
|
|
|
|
|
GELOGE(INTERNAL_ERROR, "[%s] Number of inputs [%zu] is not sufficient for graph which needs [%zu] inputs.",
|
|
|
|
|
graph_item->GetName().c_str(), args.input_desc.size(), input_nodes.size());
|
|
|
|
|
return INTERNAL_ERROR;
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < input_nodes.size(); ++i) {
|
|
|
|
|
auto &input_node = input_nodes[i];
|
|
|
|
|
if (input_node == nullptr) {
|
|
|
|
|
GELOGD("[%s] Input[%zu] is not needed by graph, skip it.", graph_item->GetName().c_str(), i);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
GeTensorDescPtr model_input_desc = input_node->MutableInputDesc(i);
|
|
|
|
|
GE_CHECK_NOTNULL(model_input_desc);
|
|
|
|
|
std::vector<std::pair<int64_t, int64_t>> shape_range;
|
|
|
|
|
if (model_input_desc->GetShapeRange(shape_range) != SUCCESS) {
|
|
|
|
|
REPORT_INNER_ERROR("E19999", "[%s] Input[%zu] get shape range failed", graph_item->GetName().c_str(), i);
|
|
|
|
|
GELOGE(INTERNAL_ERROR, "[%s] Input[%zu] get shape range failed", graph_item->GetName().c_str(), i);
|
|
|
|
|
return INTERNAL_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (shape_range.empty()) {
|
|
|
|
|
GELOGD("[%s] Input[%zu] shape is not needed to check by shape range, skip it.", graph_item->GetName().c_str(), i);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
ConstGeTensorDescPtr args_tensor_desc = args.input_desc[i];
|
|
|
|
|
GE_CHECK_NOTNULL(args_tensor_desc);
|
|
|
|
|
GeShape shape = args_tensor_desc->GetShape();
|
|
|
|
|
if (shape.IsUnknownShape()) {
|
|
|
|
|
REPORT_INNER_ERROR("E19999", "[%s] Input desc shape [%zu] designed by user must be static.",
|
|
|
|
|
graph_item->GetName().c_str(), i);
|
|
|
|
|
GELOGE(INTERNAL_ERROR, "[%s] Input desc shape [%zu] designed by user must be static.",
|
|
|
|
|
graph_item->GetName().c_str(), i);
|
|
|
|
|
return INTERNAL_ERROR;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (TensorUtils::CheckShapeByShapeRange(shape, shape_range) != SUCCESS) {
|
|
|
|
|
GELOGE(PARAM_INVALID, "[Check][InputShape] [%s] check input [%zu] shape failed by shape range.",
|
|
|
|
|
graph_item->GetName().c_str(), i);
|
|
|
|
|
return PARAM_INVALID;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
} // namespace hybrid
|
|
|
|
|
} // namespace ge
|
|
|
|
|