!310 Fix aclInferShapeAndType core dump.

From: @dong-duo
Reviewed-by: @xchu42,@wqtshg
Signed-off-by:
pull/310/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit d1a9994f38

@ -35,6 +35,7 @@
#include "ir_build/atc_ir_common.h" #include "ir_build/atc_ir_common.h"
#include "model/ge_model.h" #include "model/ge_model.h"
#include "graph/shape_refiner.h" #include "graph/shape_refiner.h"
#include "graph/opsproto_manager.h"
using std::string; using std::string;
using namespace std; using namespace std;
@ -109,6 +110,37 @@ static graphStatus CheckGlobalOptions(std::map<std::string, std::string> &global
return GRAPH_SUCCESS; return GRAPH_SUCCESS;
} }
static void GetOpsProtoPath(string &opsproto_path) {
GELOGI("Start to get ops proto path schedule.");
const char *path_env = std::getenv("ASCEND_OPP_PATH");
if (path_env != nullptr) {
string path = path_env;
string file_path = RealPath(path.c_str());
if (file_path.empty()) {
GELOGE(FAILED, "File path %s is invalid.", path.c_str());
return;
}
opsproto_path = (path + "/op_proto/custom/" + ":") + (path + "/op_proto/built-in/");
GELOGI("Get opsproto so path from env : %s", path.c_str());
return;
}
string path_base = PluginManager::GetPath();
GELOGI("path_base is %s", path_base.c_str());
path_base = path_base.substr(0, path_base.rfind('/'));
path_base = path_base.substr(0, path_base.rfind('/') + 1);
opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/");
}
static void LoadOpsProto() {
string opsproto_path;
GetOpsProtoPath(opsproto_path);
GELOGI("Get opsproto path is %s", opsproto_path.c_str());
OpsProtoManager *manager = OpsProtoManager::Instance();
map<string, string> option_tmp;
option_tmp.emplace(std::pair<string, string>(string("ge.opsProtoLibPath"), opsproto_path));
(void)manager->Initialize(option_tmp);
}
graphStatus aclgrphBuildInitialize(std::map<std::string, std::string> global_options) { graphStatus aclgrphBuildInitialize(std::map<std::string, std::string> global_options) {
GELOGD("Enter aclgrphInitialize start!"); GELOGD("Enter aclgrphInitialize start!");
// check global options // check global options
@ -116,9 +148,12 @@ graphStatus aclgrphBuildInitialize(std::map<std::string, std::string> global_opt
GELOGE(GRAPH_PARAM_INVALID, "Check global options falied!"); GELOGE(GRAPH_PARAM_INVALID, "Check global options falied!");
return GRAPH_PARAM_INVALID; return GRAPH_PARAM_INVALID;
} }
// print global option map // print global option map
ge::PrintOptionMap(global_options, "global option"); ge::PrintOptionMap(global_options, "global option");
LoadOpsProto();
std::shared_ptr<ge::GELib> instance_ptr = ge::GELib::GetInstance(); std::shared_ptr<ge::GELib> instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
GELOGI("aclgrphInitialize start!"); GELOGI("aclgrphInitialize start!");
@ -172,6 +207,7 @@ class Impl {
bool is_dynamic_input); bool is_dynamic_input);
void SetRtSocVersion(); void SetRtSocVersion();
void UpdateThreadContext(); void UpdateThreadContext();
void LoadOpsProto();
public: public:
ge::GeGenerator generator_; ge::GeGenerator generator_;
std::map<std::string, std::string> options_; std::map<std::string, std::string> options_;
@ -442,6 +478,12 @@ graphStatus aclgrphInferShapeAndType(ge::Graph &graph) {
auto compute_graph = GraphUtils::GetComputeGraph(graph); auto compute_graph = GraphUtils::GetComputeGraph(graph);
GE_CHECK_NOTNULL(compute_graph); GE_CHECK_NOTNULL(compute_graph);
auto root_graph = compute_graph->GetParentGraph();
if (root_graph != nullptr) {
GELOGE(GRAPH_PARAM_INVALID, "Input param should not be subgraph");
return GRAPH_PARAM_INVALID;
}
auto ret = compute_graph->InferOriginFormat(); auto ret = compute_graph->InferOriginFormat();
if (ret != GRAPH_SUCCESS) { if (ret != GRAPH_SUCCESS) {
GELOGE(ret, "Acl InferOriginFormat failed."); GELOGE(ret, "Acl InferOriginFormat failed.");

Loading…
Cancel
Save