|
|
|
@ -33,8 +33,6 @@
|
|
|
|
|
#include "include/api/types.h"
|
|
|
|
|
|
|
|
|
|
using mindspore::Context;
|
|
|
|
|
using mindspore::GlobalContext;
|
|
|
|
|
using mindspore::ModelContext;
|
|
|
|
|
using mindspore::Serialization;
|
|
|
|
|
using mindspore::Model;
|
|
|
|
|
using mindspore::Status;
|
|
|
|
@ -63,24 +61,26 @@ int main(int argc, char **argv) {
|
|
|
|
|
return 1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310);
|
|
|
|
|
GlobalContext::SetGlobalDeviceID(FLAGS_device_id);
|
|
|
|
|
auto context = std::make_shared<Context>();
|
|
|
|
|
auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
|
|
|
|
|
ascend310->SetDeviceID(FLAGS_device_id);
|
|
|
|
|
context->MutableDeviceInfo().push_back(ascend310);
|
|
|
|
|
mindspore::Graph graph;
|
|
|
|
|
Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
|
|
|
|
|
|
|
|
|
|
auto graph = Serialization::LoadModel(FLAGS_mindir_path, ModelType::kMindIR);
|
|
|
|
|
auto model_context = std::make_shared<Context>();
|
|
|
|
|
|
|
|
|
|
if (!FLAGS_precision_mode.empty()) {
|
|
|
|
|
ModelContext::SetPrecisionMode(model_context, FLAGS_precision_mode);
|
|
|
|
|
ascend310->SetPrecisionMode(FLAGS_precision_mode);
|
|
|
|
|
}
|
|
|
|
|
if (!FLAGS_op_select_impl_mode.empty()) {
|
|
|
|
|
ModelContext::SetOpSelectImplMode(model_context, FLAGS_op_select_impl_mode);
|
|
|
|
|
ascend310->SetOpSelectImplMode(FLAGS_op_select_impl_mode);
|
|
|
|
|
}
|
|
|
|
|
if (!FLAGS_aipp_path.empty()) {
|
|
|
|
|
ModelContext::SetInsertOpConfigPath(model_context, FLAGS_aipp_path);
|
|
|
|
|
ascend310->SetInsertOpConfigPath(FLAGS_aipp_path);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Model model(GraphCell(graph), model_context);
|
|
|
|
|
Status ret = model.Build();
|
|
|
|
|
Model model;
|
|
|
|
|
Status ret = model.Build(GraphCell(graph), context);
|
|
|
|
|
if (ret != kSuccess) {
|
|
|
|
|
std::cout << "EEEEEEEERROR Build failed." << std::endl;
|
|
|
|
|
return 1;
|
|
|
|
|