|
|
|
@ -442,10 +442,10 @@ void DfGraphConvertor::InitLoopVar(std::vector<ge::Operator> *init_input) {
|
|
|
|
|
|
|
|
|
|
int64_t value = 0;
|
|
|
|
|
auto const_iter_num = std::make_shared<Constant>("const/npu_runconfig/iterations_per_loop");
|
|
|
|
|
if (ConfigManager::GetInstance().dataset_mode() == DS_GRAPH_MODE) {
|
|
|
|
|
if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
|
|
|
|
|
value = ConfigManager::GetInstance().iter_num();
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(INFO) << "Run with feed mode, the iterator number will always be 1";
|
|
|
|
|
MS_LOG(INFO) << "Run with non-sink mode, the iterator number will always be 1";
|
|
|
|
|
value = 1;
|
|
|
|
|
ConfigManager::GetInstance().set_iter_num(value);
|
|
|
|
|
}
|
|
|
|
@ -576,7 +576,7 @@ void DfGraphConvertor::SetupParamInitSubGraph(const TensorOrderMap &tensors, std
|
|
|
|
|
|
|
|
|
|
void DfGraphConvertor::MakeDatasetHandler(const std::string &name, const size_t &input_idx, const AnfNodePtr &it) {
|
|
|
|
|
MS_LOG(INFO) << "The " << name << " is the " << input_idx << "(st/nd/th) input";
|
|
|
|
|
if (ConfigManager::GetInstance().dataset_mode() == DS_GRAPH_MODE) {
|
|
|
|
|
if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
|
|
|
|
|
auto getnext_idx = static_cast<int64_t>(input_idx);
|
|
|
|
|
DatasetGraphParam param = ConfigManager::GetInstance().dataset_param();
|
|
|
|
|
if (!param.input_indexes().empty() && input_idx <= param.input_indexes().size()) {
|
|
|
|
@ -868,7 +868,7 @@ DfGraphConvertor &DfGraphConvertor::ConvertAllNode() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Create dataset iterator and iterator_getnext node
|
|
|
|
|
if (ConfigManager::GetInstance().dataset_mode() == DS_GRAPH_MODE) {
|
|
|
|
|
if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
|
|
|
|
|
DatasetGraphParam param = ConfigManager::GetInstance().dataset_param();
|
|
|
|
|
MS_LOG(INFO) << "Dataset param is " << param.ToString() << ".";
|
|
|
|
|
// GetNext
|
|
|
|
@ -977,7 +977,7 @@ void DfGraphConvertor::TraceOutputFromParameter(const AnfNodePtr &anf_out) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetupDatasetIterGetNextNode(const OperatorPtr &op) {
|
|
|
|
|
if (ConfigManager::GetInstance().dataset_mode() == DS_GRAPH_MODE) {
|
|
|
|
|
if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
|
|
|
|
|
DatasetGraphParam param = ConfigManager::GetInstance().dataset_param();
|
|
|
|
|
size_t output_num = param.ge_types().size();
|
|
|
|
|
MS_LOG(INFO) << "Set iterator_getnext op's output num = " << output_num << ".";
|
|
|
|
@ -1036,7 +1036,7 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() {
|
|
|
|
|
|
|
|
|
|
// set graph input according to the order from anf graph
|
|
|
|
|
std::vector<Operator> inputs;
|
|
|
|
|
if (ConfigManager::GetInstance().dataset_mode() == DS_GRAPH_MODE) {
|
|
|
|
|
if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
|
|
|
|
|
inputs.push_back(*dataset_iter_getnext_);
|
|
|
|
|
} else {
|
|
|
|
|
auto params = anf_graph_->parameters();
|
|
|
|
|