!1260 for second order subgraph switch

Merge pull request !1260 from zongha/master
pull/1260/MERGE
mindspore-ci-bot authored 5 years ago, committed by Gitee
commit bb73bfdf3a

@@ -95,7 +95,7 @@ PYBIND11_MODULE(_c_expression, m) {
   (void)m.def("verify_inputs_signature", &mindspore::pipeline::VerifyInputSignature, "Verify input signature.");
   (void)m.def("init_exec_dataset", &mindspore::pipeline::InitExecDataset, py::arg("queue_name"), py::arg("size"),
               py::arg("batch_size"), py::arg("types"), py::arg("shapes"), py::arg("input_indexs"),
-              py::arg("phase") = py::str("dataset"), "Init and exec dataset.");
+              py::arg("phase") = py::str("dataset"), py::arg("need_run") = py::bool_(true), "Init and exec dataset.");
   (void)m.def("_set_dataset_mode_config", &mindspore::ConfigManager::SetDatasetModeConfig, "API for set dataset mode.");
   (void)m.def("init_backend", &mindspore::pipeline::InitBackend, "Init Backend.");

@@ -694,7 +694,7 @@ void ExecutorPy::RunInitGraph(const py::dict &init_params, const std::string &ph
 bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size,
                      const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
-                     const std::vector<int64_t> &input_indexes, const std::string &phase) {
+                     const std::vector<int64_t> &input_indexes, const std::string &phase, bool need_run) {
   std::string name = MsContext::GetInstance()->backend_policy();
 #ifndef NO_DLIB
   auto ms_context = MsContext::GetInstance();
@@ -704,7 +704,7 @@ bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t ba
   }
 #endif
   if (name == kMsConvert || name == kMsVm) {
-    return InitExecDatasetVm(queue_name, iter_num, batch_size, types, shapes, input_indexes);
+    return InitExecDatasetVm(queue_name, iter_num, batch_size, types, shapes, input_indexes, need_run);
   }
 #if ENABLE_GE
   return InitExecDatasetGe(queue_name, iter_num, batch_size, types, shapes, input_indexes, phase);
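
For orientation, the dispatch above threads need_run only into the ms/vm backend path; the GE branch keeps its original argument list. A simplified, self-contained sketch of that routing (backend names and helper functions here are hypothetical stand-ins, not the MindSpore APIs):

#include <iostream>
#include <string>

bool InitExecDatasetVmSketch(bool need_run) {
  std::cout << "vm path, need_run=" << need_run << "\n";
  return true;
}

bool InitExecDatasetGeSketch() {
  std::cout << "ge path, need_run not used\n";
  return true;
}

bool InitExecDatasetSketch(const std::string &backend, bool need_run) {
  if (backend == "ms" || backend == "vm") {
    return InitExecDatasetVmSketch(need_run);  // the new flag is forwarded only here
  }
  return InitExecDatasetGeSketch();  // unchanged path
}

int main() {
  InitExecDatasetSketch("ms", false);
  InitExecDatasetSketch("ge", true);
}
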
@@ -719,7 +719,7 @@ bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t ba
 bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size,
                        const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
-                       const std::vector<int64_t> &input_indexes) {
+                       const std::vector<int64_t> &input_indexes, bool need_run) {
   MS_LOG(INFO) << "Start InitDataSet Entry";
   std::vector<int> int_input_indexes;
   (void)std::transform(input_indexes.begin(), input_indexes.end(), std::back_inserter(int_input_indexes),
@@ -772,7 +772,9 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
   // launch init dataset runner without inputs and outputs
   VectorRef args;
   auto fn = runner.run;
-  (void)(*fn)(args);
+  if (need_run) {
+    (void)(*fn)(args);
+  }
   MS_LOG(DEBUG) << "InitDataSetVm End.";
   return true;
 }
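
The guard added above is the heart of the change: the init-dataset subgraph is still compiled into a runner, but the runner is invoked only when need_run is true. A self-contained sketch of that pattern (VectorRef and the compiled runner are replaced by standard-library stand-ins; names are hypothetical):

#include <functional>
#include <iostream>
#include <vector>

using ArgsSketch = std::vector<int>;  // stand-in for mindspore's VectorRef

bool InitExecDatasetVmSketch(bool need_run) {
  // Stand-in for the runner produced by compiling the "init dataset" subgraph.
  std::function<void(const ArgsSketch &)> fn = [](const ArgsSketch &) {
    std::cout << "init dataset graph executed\n";
  };
  ArgsSketch args;  // launched without inputs and outputs
  if (need_run) {
    fn(args);  // skipped entirely when need_run is false
  }
  std::cout << "InitDataSetVm end\n";
  return true;
}

int main() {
  InitExecDatasetVmSketch(false);  // graph is prepared but not run
  InitExecDatasetVmSketch(true);   // graph is prepared and run
}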

@@ -127,12 +127,12 @@ void ExportGraph(const std::string &file_name, const std::string &, const std::s
 // init and exec dataset sub graph
 bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size,
                      const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
-                     const std::vector<int64_t> &input_indexes, const std::string &phase);
+                     const std::vector<int64_t> &input_indexes, const std::string &phase, bool need_run);
 // Build and run dataset subgraph for ms backend
 bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size,
                        const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
-                       const std::vector<int64_t> &input_indexes);
+                       const std::vector<int64_t> &input_indexes, bool need_run);
 } // namespace pipeline
 } // namespace mindspore
