diff --git a/ge/offline/single_op_parser.cc b/ge/offline/single_op_parser.cc index bc1b39f8..72f742e9 100644 --- a/ge/offline/single_op_parser.cc +++ b/ge/offline/single_op_parser.cc @@ -48,6 +48,8 @@ constexpr char const *kKeyShapeRange = "shape_range"; constexpr char const *kKeyValue = "value"; constexpr char const *kKeyFormat = "format"; constexpr char const *kFileSuffix = ".om"; +constexpr char const *kKeyDynamicInput = "dynamic_input"; +constexpr char const *kKeyDynamicOutput = "dynamic_output"; constexpr int kDumpJsonIndent = 2; constexpr int kShapeRangePairSize = 2; constexpr int kShapeRangeLow = 0; @@ -124,6 +126,10 @@ void from_json(const Json &j, SingleOpTensorDesc &desc) { if (tensor_name != j.end()) { desc.name = tensor_name->get(); } + auto dynamic_input_name = j.find(kKeyDynamicInput); + if (dynamic_input_name != j.end()) { + desc.dynamic_input_name = dynamic_input_name->get(); + } } void from_json(const Json &j, SingleOpAttr &attr) { @@ -276,6 +282,23 @@ std::unique_ptr SingleOpParser::CreateOpDesc(const string &op_type) { return std::unique_ptr(new(std::nothrow) OpDesc(op_type, op_type)); } +Status SingleOpParser::UpdateDynamicTensorName(std::vector &desc) { + std::map dynamic_name_map; + for (auto &tensor : desc) { + if (tensor.dynamic_input_name.empty()) { + continue; + } + if (dynamic_name_map.find(tensor.dynamic_input_name) == dynamic_name_map.end()) { + dynamic_name_map[tensor.dynamic_input_name] = 0; + } else { + dynamic_name_map[tensor.dynamic_input_name]++; + } + tensor.name = tensor.dynamic_input_name + std::to_string(dynamic_name_map[tensor.dynamic_input_name]); + } + GELOGD("Update dynamic tensor name success!"); + return SUCCESS; +} + Status SingleOpParser::ConvertToBuildParam(int index, const SingleOpDesc &single_op_desc, SingleOpBuildParam &build_param) { @@ -471,6 +494,11 @@ Status SingleOpParser::ParseSingleOpList(const std::string &file, std::vector> dim_ranges; ge::Format format = ge::FORMAT_RESERVED; ge::DataType type = ge::DT_UNDEFINED; + std::string dynamic_input_name; }; struct SingleOpAttr { @@ -70,6 +71,7 @@ class SingleOpParser { static bool Validate(const SingleOpDesc &op_desc); static std::unique_ptr CreateOpDesc(const std::string &op_type); static Status ConvertToBuildParam(int index, const SingleOpDesc &single_op_desc, SingleOpBuildParam &build_param); + static Status UpdateDynamicTensorName(std::vector &desc); static Status VerifyOpInputOutputSizeByIr(const OpDesc ¤t_op_desc); static Status SetShapeRange(const std::string &op_name, const SingleOpTensorDesc &tensor_desc,