update interface for yolov4 and unet 310 infer

pull/14054/head
lihongkang 4 years ago
parent d926e8e4f4
commit 65527c39bf

@@ -34,8 +34,6 @@
 using mindspore::Context;
-using mindspore::GlobalContext;
-using mindspore::ModelContext;
 using mindspore::Serialization;
 using mindspore::Model;
 using mindspore::Status;
@@ -57,19 +55,21 @@ int main(int argc, char **argv) {
     return 1;
   }
-  GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310);
-  GlobalContext::SetGlobalDeviceID(FLAGS_device_id);
-  auto graph = Serialization::LoadModel(FLAGS_mindir_path, ModelType::kMindIR);
-  auto model_context = std::make_shared<Context>();
-  Model model(GraphCell(graph), model_context);
+  auto context = std::make_shared<Context>();
+  auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
+  ascend310->SetDeviceID(FLAGS_device_id);
+  context->MutableDeviceInfo().push_back(ascend310);
+  mindspore::Graph graph;
+  Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
-  Status ret = model.Build();
+  Model model;
+  Status ret = model.Build(GraphCell(graph), context);
   if (ret != kSuccess) {
     std::cout << "EEEEEEEERROR Build failed." << std::endl;
     return 1;
   }
-  std::vector<MSTensor> model_inputs = model.GetInputs();
+  std::vector<MSTensor> model_inputs = model.GetInputs();
   auto all_files = GetAllFiles(FLAGS_dataset_path);
   if (all_files.empty()) {
     std::cout << "ERROR: no input data." << std::endl;

@@ -15,7 +15,7 @@
 # ============================================================================
 if [[ $# -lt 2 || $# -gt 3 ]]; then
-    echo "Usage: sh run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
+    echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
     DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
     exit 1
 fi
@@ -71,7 +71,7 @@ function compile_app()
     if [ -f "Makefile" ]; then
         make clean
     fi
-    sh build.sh &> build.log
+    bash build.sh &> build.log
 }
 function infer()

@@ -33,8 +33,6 @@
 #include "include/api/types.h"
 using mindspore::Context;
-using mindspore::GlobalContext;
-using mindspore::ModelContext;
 using mindspore::Serialization;
 using mindspore::Model;
 using mindspore::Status;
@@ -63,24 +61,26 @@ int main(int argc, char **argv) {
     return 1;
   }
-  GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310);
-  GlobalContext::SetGlobalDeviceID(FLAGS_device_id);
+  auto context = std::make_shared<Context>();
+  auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
+  ascend310->SetDeviceID(FLAGS_device_id);
+  context->MutableDeviceInfo().push_back(ascend310);
+  mindspore::Graph graph;
+  Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
-  auto graph = Serialization::LoadModel(FLAGS_mindir_path, ModelType::kMindIR);
-  auto model_context = std::make_shared<Context>();
   if (!FLAGS_precision_mode.empty()) {
-    ModelContext::SetPrecisionMode(model_context, FLAGS_precision_mode);
+    ascend310->SetPrecisionMode(FLAGS_precision_mode);
   }
   if (!FLAGS_op_select_impl_mode.empty()) {
-    ModelContext::SetOpSelectImplMode(model_context, FLAGS_op_select_impl_mode);
+    ascend310->SetOpSelectImplMode(FLAGS_op_select_impl_mode);
   }
   if (!FLAGS_aipp_path.empty()) {
-    ModelContext::SetInsertOpConfigPath(model_context, FLAGS_aipp_path);
+    ascend310->SetInsertOpConfigPath(FLAGS_aipp_path);
   }
-  Model model(GraphCell(graph), model_context);
-  Status ret = model.Build();
+  Model model;
+  Status ret = model.Build(GraphCell(graph), context);
   if (ret != kSuccess) {
     std::cout << "EEEEEEEERROR Build failed." << std::endl;
     return 1;
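The yolov4 change also moves the per-session Ascend options (precision mode, op-select implementation mode, AIPP insert-op config) from static ModelContext setters onto the Ascend310DeviceInfo object itself. A hedged sketch of that pattern as a small helper; the function name and the example option values in the comments are hypothetical, only the setter calls come from this diff:

```cpp
#include <memory>
#include <string>

#include "include/api/context.h"

// Build a Context for Ascend 310 inference with optional session options
// attached directly to the Ascend310DeviceInfo (new-API style).
std::shared_ptr<mindspore::Context> BuildAscend310Context(
    uint32_t device_id, const std::string &precision_mode,
    const std::string &op_select_impl_mode, const std::string &aipp_path) {
  auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
  ascend310->SetDeviceID(device_id);
  if (!precision_mode.empty()) {
    ascend310->SetPrecisionMode(precision_mode);          // e.g. "allow_fp32_to_fp16"
  }
  if (!op_select_impl_mode.empty()) {
    ascend310->SetOpSelectImplMode(op_select_impl_mode);  // e.g. "high_precision"
  }
  if (!aipp_path.empty()) {
    ascend310->SetInsertOpConfigPath(aipp_path);          // path to an AIPP config file
  }
  auto context = std::make_shared<mindspore::Context>();
  context->MutableDeviceInfo().push_back(ascend310);
  return context;
}
```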

@@ -15,7 +15,7 @@
 # ============================================================================
 if [[ $# -lt 3 || $# -gt 4 ]]; then
-    echo "Usage: sh run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID] [ANN_FILE]
+    echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID] [ANN_FILE]
     DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
     exit 1
 fi
@@ -64,7 +64,7 @@ function compile_app()
     if [ -f "Makefile" ]; then
         make clean
     fi
-    sh build.sh &> build.log
+    bash build.sh &> build.log
 }
 function infer()
