diff --git a/mindspore/ccsrc/pipeline/init.cc b/mindspore/ccsrc/pipeline/init.cc index 3b2a89c909..49dd9b721d 100644 --- a/mindspore/ccsrc/pipeline/init.cc +++ b/mindspore/ccsrc/pipeline/init.cc @@ -95,7 +95,7 @@ PYBIND11_MODULE(_c_expression, m) { (void)m.def("verify_inputs_signature", &mindspore::pipeline::VerifyInputSignature, "Verify input signature."); (void)m.def("init_exec_dataset", &mindspore::pipeline::InitExecDataset, py::arg("queue_name"), py::arg("size"), py::arg("batch_size"), py::arg("types"), py::arg("shapes"), py::arg("input_indexs"), - py::arg("phase") = py::str("dataset"), "Init and exec dataset."); + py::arg("phase") = py::str("dataset"), py::arg("use_run") = py::bool_(true), "Init and exec dataset."); (void)m.def("_set_dataset_mode_config", &mindspore::ConfigManager::SetDatasetModeConfig, "API for set dataset mode."); (void)m.def("init_backend", &mindspore::pipeline::InitBackend, "Init Backend."); diff --git a/mindspore/ccsrc/pipeline/pipeline.cc b/mindspore/ccsrc/pipeline/pipeline.cc index 2d1dafbb5f..36e6fdcc23 100644 --- a/mindspore/ccsrc/pipeline/pipeline.cc +++ b/mindspore/ccsrc/pipeline/pipeline.cc @@ -688,7 +688,7 @@ void ExecutorPy::RunInitGraph(const py::dict &init_params, const std::string &ph bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size, const std::vector &types, const std::vector> &shapes, - const std::vector &input_indexes, const std::string &phase) { + const std::vector &input_indexes, const std::string &phase, bool use_run) { std::string name = MsContext::GetInstance()->backend_policy(); #ifndef NO_DLIB auto ms_context = MsContext::GetInstance(); @@ -698,7 +698,7 @@ bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t ba } #endif if (name == kMsConvert || name == kMsVm) { - return InitExecDatasetVm(queue_name, iter_num, batch_size, types, shapes, input_indexes); + return InitExecDatasetVm(queue_name, iter_num, batch_size, types, shapes, input_indexes, use_run); } #if ENABLE_GE return InitExecDatasetGe(queue_name, iter_num, batch_size, types, shapes, input_indexes, phase); @@ -713,7 +713,7 @@ bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t ba bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size, const std::vector &types, const std::vector> &shapes, - const std::vector &input_indexes) { + const std::vector &input_indexes, bool use_run) { MS_LOG(INFO) << "Start InitDataSet Entry"; std::vector int_input_indexes; (void)std::transform(input_indexes.begin(), input_indexes.end(), std::back_inserter(int_input_indexes), @@ -766,7 +766,9 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc // launch init dataset runner without inputs and outputs VectorRef args; auto fn = runner.run; - (void)(*fn)(args); + if (use_run){ + (void)(*fn)(args); + } MS_LOG(DEBUG) << "InitDataSetVm End."; return true; } diff --git a/mindspore/ccsrc/pipeline/pipeline.h b/mindspore/ccsrc/pipeline/pipeline.h index 6a99d4dbcd..6ed0c8ddea 100644 --- a/mindspore/ccsrc/pipeline/pipeline.h +++ b/mindspore/ccsrc/pipeline/pipeline.h @@ -127,12 +127,12 @@ void ExportGraph(const std::string &file_name, const std::string &, const std::s // init and exec dataset sub graph bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size, const std::vector &types, const std::vector> &shapes, - const std::vector &input_indexes, const std::string &phase); + const std::vector &input_indexes, const std::string &phase, bool use_run); // Build and run dataset subgraph for ms backend bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size, const std::vector &types, const std::vector> &shapes, - const std::vector &input_indexes); + const std::vector &input_indexes, bool use_run); } // namespace pipeline } // namespace mindspore