!14054 update interface for yolov4 and unet 310 infer

From: @lihongkang1
Reviewed-by: @c_34,@wuxuejian
Signed-off-by: @c_34
pull/14054/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 6fc62f4ee0

@ -34,8 +34,6 @@
using mindspore::Context; using mindspore::Context;
using mindspore::GlobalContext;
using mindspore::ModelContext;
using mindspore::Serialization; using mindspore::Serialization;
using mindspore::Model; using mindspore::Model;
using mindspore::Status; using mindspore::Status;
@ -57,19 +55,21 @@ int main(int argc, char **argv) {
return 1; return 1;
} }
GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310); auto context = std::make_shared<Context>();
GlobalContext::SetGlobalDeviceID(FLAGS_device_id); auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
auto graph = Serialization::LoadModel(FLAGS_mindir_path, ModelType::kMindIR); ascend310->SetDeviceID(FLAGS_device_id);
auto model_context = std::make_shared<Context>(); context->MutableDeviceInfo().push_back(ascend310);
Model model(GraphCell(graph), model_context); mindspore::Graph graph;
Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
Status ret = model.Build(); Model model;
Status ret = model.Build(GraphCell(graph), context);
if (ret != kSuccess) { if (ret != kSuccess) {
std::cout << "EEEEEEEERROR Build failed." << std::endl; std::cout << "EEEEEEEERROR Build failed." << std::endl;
return 1; return 1;
} }
std::vector<MSTensor> model_inputs = model.GetInputs();
std::vector<MSTensor> model_inputs = model.GetInputs();
auto all_files = GetAllFiles(FLAGS_dataset_path); auto all_files = GetAllFiles(FLAGS_dataset_path);
if (all_files.empty()) { if (all_files.empty()) {
std::cout << "ERROR: no input data." << std::endl; std::cout << "ERROR: no input data." << std::endl;

@ -15,7 +15,7 @@
# ============================================================================ # ============================================================================
if [[ $# -lt 2 || $# -gt 3 ]]; then if [[ $# -lt 2 || $# -gt 3 ]]; then
echo "Usage: sh run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID] echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero" DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
exit 1 exit 1
fi fi
@ -71,7 +71,7 @@ function compile_app()
if [ -f "Makefile" ]; then if [ -f "Makefile" ]; then
make clean make clean
fi fi
sh build.sh &> build.log bash build.sh &> build.log
} }
function infer() function infer()

@ -33,8 +33,6 @@
#include "include/api/types.h" #include "include/api/types.h"
using mindspore::Context; using mindspore::Context;
using mindspore::GlobalContext;
using mindspore::ModelContext;
using mindspore::Serialization; using mindspore::Serialization;
using mindspore::Model; using mindspore::Model;
using mindspore::Status; using mindspore::Status;
@ -63,24 +61,26 @@ int main(int argc, char **argv) {
return 1; return 1;
} }
GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310); auto context = std::make_shared<Context>();
GlobalContext::SetGlobalDeviceID(FLAGS_device_id); auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
ascend310->SetDeviceID(FLAGS_device_id);
context->MutableDeviceInfo().push_back(ascend310);
mindspore::Graph graph;
Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
auto graph = Serialization::LoadModel(FLAGS_mindir_path, ModelType::kMindIR);
auto model_context = std::make_shared<Context>();
if (!FLAGS_precision_mode.empty()) { if (!FLAGS_precision_mode.empty()) {
ModelContext::SetPrecisionMode(model_context, FLAGS_precision_mode); ascend310->SetPrecisionMode(FLAGS_precision_mode);
} }
if (!FLAGS_op_select_impl_mode.empty()) { if (!FLAGS_op_select_impl_mode.empty()) {
ModelContext::SetOpSelectImplMode(model_context, FLAGS_op_select_impl_mode); ascend310->SetOpSelectImplMode(FLAGS_op_select_impl_mode);
} }
if (!FLAGS_aipp_path.empty()) { if (!FLAGS_aipp_path.empty()) {
ModelContext::SetInsertOpConfigPath(model_context, FLAGS_aipp_path); ascend310->SetInsertOpConfigPath(FLAGS_aipp_path);
} }
Model model(GraphCell(graph), model_context); Model model;
Status ret = model.Build(); Status ret = model.Build(GraphCell(graph), context);
if (ret != kSuccess) { if (ret != kSuccess) {
std::cout << "EEEEEEEERROR Build failed." << std::endl; std::cout << "EEEEEEEERROR Build failed." << std::endl;
return 1; return 1;

@ -15,7 +15,7 @@
# ============================================================================ # ============================================================================
if [[ $# -lt 3 || $# -gt 4 ]]; then if [[ $# -lt 3 || $# -gt 4 ]]; then
echo "Usage: sh run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID] [ANN_FILE] echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID] [ANN_FILE]
DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero" DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
exit 1 exit 1
fi fi
@ -64,7 +64,7 @@ function compile_app()
if [ -f "Makefile" ]; then if [ -f "Makefile" ]; then
make clean make clean
fi fi
sh build.sh &> build.log bash build.sh &> build.log
} }
function infer() function infer()

Loading…
Cancel
Save