!11803 allow list as parameter input & store op info using op_name instead of primitive id

From: @simson_wu
Reviewed-by: 
Signed-off-by:
pull/11803/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 5eddd9dbc0

@ -52,7 +52,6 @@ enum RunOpArgsEnum { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_ARGS_NUM };
struct OpExecInfo {
std::string op_name;
std::string op_index;
std::string prim_id;
PrimitivePyPtr py_primitive;
AbstractBasePtr abstract;

@ -163,6 +163,25 @@ std::map<SignatureEnumDType, std::vector<size_t>> GetTypeIndex(const std::vector
return type_indexes;
}
TypeId JudgeMaxType(TypeId max_type, bool has_scalar_float32, bool has_scalar_int64, bool has_tensor_int8) {
if (max_type == TypeId::kNumberTypeBool) {
if (has_scalar_int64) {
max_type = TypeId::kNumberTypeInt64;
}
if (has_scalar_float32) {
max_type = TypeId::kNumberTypeFloat32;
}
}
if (max_type != TypeId::kNumberTypeFloat16 && max_type != TypeId::kNumberTypeFloat32 &&
max_type != TypeId::kNumberTypeFloat64 && max_type != TypeId::kTypeUnknown && has_scalar_float32) {
max_type = TypeId::kNumberTypeFloat32;
}
if (max_type == TypeId::kNumberTypeUInt8 && has_tensor_int8) {
max_type = TypeId::kNumberTypeInt16;
}
return max_type;
}
std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args,
const std::map<SignatureEnumDType, std::vector<size_t>> &type_indexes) {
std::map<SignatureEnumDType, TypeId> dst_type;
@ -178,14 +197,13 @@ std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args,
bool has_scalar_int64 = false;
bool has_tensor_int8 = false;
for (size_t index : indexes) {
if (!has_scalar_float32 && py::isinstance<py::float_>(py_args[index])) {
auto obj = py_args[index];
if (py::isinstance<py::float_>(obj)) {
has_scalar_float32 = true;
}
if (!has_scalar_int64 && !py::isinstance<py::bool_>(py_args[index]) && py::isinstance<py::int_>(py_args[index])) {
if (!py::isinstance<py::bool_>(obj) && py::isinstance<py::int_>(obj)) {
has_scalar_int64 = true;
}
auto obj = py_args[index];
if (py::isinstance<tensor::Tensor>(obj)) {
auto arg = py::cast<tensor::TensorPtr>(obj);
TypeId arg_type_id = arg->data_type();
@ -202,21 +220,7 @@ std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args,
}
}
}
if (max_type == TypeId::kNumberTypeBool) {
if (has_scalar_int64) {
max_type = TypeId::kNumberTypeInt64;
}
if (has_scalar_float32) {
max_type = TypeId::kNumberTypeFloat32;
}
}
if (max_type != TypeId::kNumberTypeFloat16 && max_type != TypeId::kNumberTypeFloat32 &&
max_type != TypeId::kNumberTypeFloat64 && max_type != TypeId::kTypeUnknown && has_scalar_float32) {
max_type = TypeId::kNumberTypeFloat32;
}
if (max_type == TypeId::kNumberTypeUInt8 && has_tensor_int8) {
max_type = TypeId::kNumberTypeInt16;
}
max_type = JudgeMaxType(max_type, has_scalar_float32, has_scalar_int64, has_tensor_int8);
(void)dst_type.emplace(std::make_pair(type, max_type));
}
return dst_type;
@ -274,11 +278,11 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info,
}
}
// get prim and abstract info
(void)graph_info.append(op_exec_info->prim_id + "_");
(void)graph_info.append(op_exec_info->op_name + "_");
// get attr info
const auto &op_prim = op_exec_info->py_primitive;
MS_EXCEPTION_IF_NULL(op_prim);
const auto &attr_map = op_prim->evaluate_added_attrs();
const auto &attr_map = op_prim->attrs();
(void)std::for_each(attr_map.begin(), attr_map.end(),
[&](const auto &element) { (void)graph_info.append(element.second->ToString() + "_"); });
@ -648,7 +652,6 @@ OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) {
if (!prim->HasPyObj()) {
MS_LOG(EXCEPTION) << "Pyobj is empty";
}
op_exec_info->prim_id = GetId(prim->GetPyObj());
op_exec_info->py_primitive = prim;
op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
op_exec_info->op_inputs = args[PY_INPUTS];
@ -701,10 +704,10 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
input_node = GetInput(obj, op_mask);
}
// update abstract
if (input_node != nullptr && input_node->abstract() != nullptr) {
if (input_node != nullptr) {
if (input_node->abstract() != nullptr) {
abs = input_node->abstract();
}
if (input_node != nullptr) {
inputs.emplace_back(input_node);
}
}
@ -2169,7 +2172,7 @@ void PynativeExecutor::UpdateCellDynamic(const std::string &cell_id) {
}
}
void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
void PynativeExecutor::UpdateBpropCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
bool need_cloned, bool is_grad) {
auto update_in_endgraph = need_cloned && !is_grad;
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
@ -2197,7 +2200,12 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt
}
return;
}
}
void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
bool need_cloned, bool is_grad) {
auto update_in_endgraph = need_cloned && !is_grad;
UpdateBpropCellGraph(cell, g, cell_id, need_cloned, is_grad);
FuncGraphPtr tmp = g;
if (!IsFirstGradStep(top_cell_id_) && CheckDynamicCell(cell_id) && !CheckRealDynamicCell(cell_id)) {
MS_LOG(DEBUG) << "No need cloned";

@ -241,6 +241,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
bool CheckCellGraph(const std::string &cell_id, bool is_grad = false);
bool CheckDynamicCell(const std::string &cell_id);
bool CheckRealDynamicCell(const std::string &cell_id);
void UpdateBpropCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, bool need_cloned,
bool is_grad);
void UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
bool need_cloned = false, bool is_grad = false);
void ClearCnodeRes(const AnfNodePtr &node, std::unordered_set<AnfNodePtr> *node_set);

@ -159,7 +159,7 @@ class Parameter(Tensor_):
Tensor_.__init__(self, mstype.int64, ())
elif isinstance(default_input, float):
Tensor_.__init__(self, mstype.float32, ())
elif isinstance(default_input, np.ndarray):
elif isinstance(default_input, (np.ndarray, list)):
Tensor_.__init__(self, default_input)
else:
raise TypeError(f"Parameter input must be [`Tensor`, `Number`]."

Loading…
Cancel
Save