Feature: atc support dynamic input

pull/271/head
wxl 4 years ago
parent 5840f16819
commit efd77465b5

@ -48,6 +48,8 @@ constexpr char const *kKeyShapeRange = "shape_range";
constexpr char const *kKeyValue = "value";
constexpr char const *kKeyFormat = "format";
constexpr char const *kFileSuffix = ".om";
constexpr char const *kKeyDynamicInput = "dynamic_input";
constexpr char const *kKeyDynamicOutput = "dynamic_output";
constexpr int kDumpJsonIndent = 2;
constexpr int kShapeRangePairSize = 2;
constexpr int kShapeRangeLow = 0;
@ -124,6 +126,10 @@ void from_json(const Json &j, SingleOpTensorDesc &desc) {
if (tensor_name != j.end()) {
desc.name = tensor_name->get<string>();
}
auto dynamic_input_name = j.find(kKeyDynamicInput);
if (dynamic_input_name != j.end()) {
desc.dynamic_input_name = dynamic_input_name->get<string>();
}
}
void from_json(const Json &j, SingleOpAttr &attr) {
@ -276,6 +282,23 @@ std::unique_ptr<OpDesc> SingleOpParser::CreateOpDesc(const string &op_type) {
return std::unique_ptr<OpDesc>(new(std::nothrow) OpDesc(op_type, op_type));
}
Status SingleOpParser::UpdateDynamicTensorName(std::vector<SingleOpTensorDesc> &desc) {
std::map<std::string, int> dynamic_name_map;
for (auto &tensor : desc) {
if (tensor.dynamic_input_name.empty()) {
continue;
}
if (dynamic_name_map.find(tensor.dynamic_input_name) == dynamic_name_map.end()) {
dynamic_name_map[tensor.dynamic_input_name] = 0;
} else {
dynamic_name_map[tensor.dynamic_input_name]++;
}
tensor.name = tensor.dynamic_input_name + std::to_string(dynamic_name_map[tensor.dynamic_input_name]);
}
GELOGD("Update dynamic tensor name success!");
return SUCCESS;
}
Status SingleOpParser::ConvertToBuildParam(int index,
const SingleOpDesc &single_op_desc,
SingleOpBuildParam &build_param) {
@ -471,6 +494,11 @@ Status SingleOpParser::ParseSingleOpList(const std::string &file, std::vector<Si
SingleOpDesc single_op_desc;
GELOGI("Parsing op[%d], jsonStr = %s", index, single_op_json.dump(kDumpJsonIndent).c_str());
single_op_desc = single_op_json;
if (UpdateDynamicTensorName(single_op_desc.input_desc) != SUCCESS) {
GELOGE(FAILED, "Update dynamic tensor name failed!");
return FAILED;
}
if (!Validate(single_op_desc)) {
GELOGE(PARAM_INVALID, "Validate the index[%d] of op failed when read json file[%s].", index, file.c_str());
return PARAM_INVALID;

@ -33,6 +33,7 @@ struct SingleOpTensorDesc {
std::vector<std::vector<int64_t>> dim_ranges;
ge::Format format = ge::FORMAT_RESERVED;
ge::DataType type = ge::DT_UNDEFINED;
std::string dynamic_input_name;
};
struct SingleOpAttr {
@ -70,6 +71,7 @@ class SingleOpParser {
static bool Validate(const SingleOpDesc &op_desc);
static std::unique_ptr<OpDesc> CreateOpDesc(const std::string &op_type);
static Status ConvertToBuildParam(int index, const SingleOpDesc &single_op_desc, SingleOpBuildParam &build_param);
static Status UpdateDynamicTensorName(std::vector<SingleOpTensorDesc> &desc);
static Status VerifyOpInputOutputSizeByIr(const OpDesc &current_op_desc);
static Status SetShapeRange(const std::string &op_name,
const SingleOpTensorDesc &tensor_desc,

Loading…
Cancel
Save