!1440 fix issue of dynamic outputs compile failed

Merge pull request !1440 from wenchunjiang/fix_dynamic_output_bug
pull/1440/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 728afd23db

@ -15,6 +15,12 @@
"""tbe common"""
import json
import os
from attrdict import AttrDict
class ParamType(AttrDict):
Required = "required"
Dynamic = "dynamic"
Optional = "optional"
class TBEException(Exception):
@ -80,7 +86,62 @@ def _check_arg_info(item):
raise ValueError("Json string Errors, key:ori_format not found.")
if 'dtype' not in item or not item['dtype']:
raise ValueError("Json string Errors, key:dtype not found.")
if 'param_type' not in item or not item['param_type']:
raise ValueError("Json string Errors, key:param_type not found.")
def get_input_output(io_info, args):
"""
Parse args.
Args:
io_info (dict): input or output info dict.
args (list): the arguments list.
Raises:
Exception: If specific keyword is not found.
"""
for item in io_info:
arg = []
for info in item:
if 'valid' not in info:
raise ValueError("Json string Errors, key:valid not found.")
if info['valid']:
_check_arg_info(info)
del info['valid']
del info['name']
if len(item) > 1:
arg.append(info)
else:
if info['param_type'] == ParamType.Dynamic:
arg.append(info)
args.append(arg)
else:
args.append(info)
else:
if len(item) > 1:
arg.append(None)
else:
args.append(None)
if len(item) > 1:
args.append(arg)
def get_attr(attr_info, args):
"""
Parse args.
Args:
attr_info (dict): input or output info dict.
args (list): the arguments list.
Raises:
Exception: If specific keyword is not found.
"""
for item in attr_info:
if item["valid"]:
if 'value' not in item:
raise ValueError("Json string Errors, attr key:value not found.")
if item["name"] != "isRef":
args.append(item['value'])
def get_args(op_info, arg_type):
"""
@ -98,35 +159,12 @@ def get_args(op_info, arg_type):
args = []
if not op_info[arg_type]:
return args
if arg_type in ['inputs', 'outputs']:
for item in op_info[arg_type]:
arg = []
for info in item:
if 'valid' not in info:
raise ValueError("Json string Errors, key:valid not found.")
if info['valid']:
_check_arg_info(info)
del info['valid']
del info['name']
if len(item) > 1:
arg.append(info)
else:
args.append(info)
else:
if len(item) > 1:
arg.append(None)
else:
args.append(None)
if len(item) > 1:
args.append(arg)
arg_info = op_info[arg_type]
if arg_type in ['inputs', 'outputs']:
get_input_output(arg_info, args)
elif arg_type == 'attrs':
for item in op_info[arg_type]:
if item["valid"]:
if 'value' not in item:
raise ValueError("Json string Errors, attr key:value not found.")
if item["name"] != "isRef":
args.append(item['value'])
get_attr(arg_info, args)
return args

@ -147,6 +147,7 @@ bool TbeKernelJsonCreator::GenInputDescJson(const shared_ptr<AnfNode> &anf_node,
input_desc_json["format"] = format;
}
input_desc_json["valid"] = value;
input_desc_json["param_type"] = input_ptr->param_type();
input_list->emplace_back(input_desc_json);
}
return true;
@ -356,6 +357,7 @@ void TbeKernelJsonCreator::GenOutputList(const shared_ptr<AnfNode> &anf_node, co
output_obj["ori_format"] = kOpFormat_NCHW;
output_obj["name"] = output_ptr->name();
output_obj["valid"] = true;
output_obj["param_type"] = output_ptr->param_type();
output_list->emplace_back(output_obj);
(*output_idx)++;

Loading…
Cancel
Save