@@ -163,6 +163,25 @@ std::map<SignatureEnumDType, std::vector<size_t>> GetTypeIndex(const std::vector
   return type_indexes;
 }
 
+TypeId JudgeMaxType(TypeId max_type, bool has_scalar_float32, bool has_scalar_int64, bool has_tensor_int8) {
+  if (max_type == TypeId::kNumberTypeBool) {
+    if (has_scalar_int64) {
+      max_type = TypeId::kNumberTypeInt64;
+    }
+    if (has_scalar_float32) {
+      max_type = TypeId::kNumberTypeFloat32;
+    }
+  }
+  if (max_type != TypeId::kNumberTypeFloat16 && max_type != TypeId::kNumberTypeFloat32 &&
+      max_type != TypeId::kNumberTypeFloat64 && max_type != TypeId::kTypeUnknown && has_scalar_float32) {
+    max_type = TypeId::kNumberTypeFloat32;
+  }
+  if (max_type == TypeId::kNumberTypeUInt8 && has_tensor_int8) {
+    max_type = TypeId::kNumberTypeInt16;
+  }
+  return max_type;
+}
+
 std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args,
                                                 const std::map<SignatureEnumDType, std::vector<size_t>> &type_indexes) {
   std::map<SignatureEnumDType, TypeId> dst_type;
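
Review note: JudgeMaxType centralizes the implicit type-promotion rules that GetDstType previously applied inline (see the hunk at -202 below). As a quick sanity check of those rules, here is a standalone sketch; the TypeId enum is a minimal local stand-in for the real MindSpore type header (which this snippet deliberately does not include), and only the function body is copied from the patch:

    #include <cassert>

    // Minimal stand-in for MindSpore's TypeId; only the enumerators JudgeMaxType touches.
    enum class TypeId {
      kTypeUnknown, kNumberTypeBool, kNumberTypeInt16, kNumberTypeInt64,
      kNumberTypeUInt8, kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64
    };

    TypeId JudgeMaxType(TypeId max_type, bool has_scalar_float32, bool has_scalar_int64, bool has_tensor_int8) {
      // Same logic as the patch: bool promotes to the scalar kind, a float scalar
      // pulls integral results up to float32, and uint8 + int8 tensors widen to int16.
      if (max_type == TypeId::kNumberTypeBool) {
        if (has_scalar_int64) {
          max_type = TypeId::kNumberTypeInt64;
        }
        if (has_scalar_float32) {
          max_type = TypeId::kNumberTypeFloat32;
        }
      }
      if (max_type != TypeId::kNumberTypeFloat16 && max_type != TypeId::kNumberTypeFloat32 &&
          max_type != TypeId::kNumberTypeFloat64 && max_type != TypeId::kTypeUnknown && has_scalar_float32) {
        max_type = TypeId::kNumberTypeFloat32;
      }
      if (max_type == TypeId::kNumberTypeUInt8 && has_tensor_int8) {
        max_type = TypeId::kNumberTypeInt16;
      }
      return max_type;
    }

    int main() {
      // A bool max type with an int scalar in the mix promotes to int64 ...
      assert(JudgeMaxType(TypeId::kNumberTypeBool, false, true, false) == TypeId::kNumberTypeInt64);
      // ... and a float scalar wins even when an int scalar is also present.
      assert(JudgeMaxType(TypeId::kNumberTypeBool, true, true, false) == TypeId::kNumberTypeFloat32);
      // A non-float max type plus a float scalar promotes to float32.
      assert(JudgeMaxType(TypeId::kNumberTypeInt64, true, false, false) == TypeId::kNumberTypeFloat32);
      // uint8 mixed with an int8 tensor widens to int16.
      assert(JudgeMaxType(TypeId::kNumberTypeUInt8, false, false, true) == TypeId::kNumberTypeInt16);
      return 0;
    }
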
@@ -178,14 +197,13 @@ std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args,
     bool has_scalar_int64 = false;
     bool has_tensor_int8 = false;
     for (size_t index : indexes) {
-      if (!has_scalar_float32 && py::isinstance<py::float_>(py_args[index])) {
+      auto obj = py_args[index];
+      if (py::isinstance<py::float_>(obj)) {
         has_scalar_float32 = true;
       }
-      if (!has_scalar_int64 && !py::isinstance<py::bool_>(py_args[index]) && py::isinstance<py::int_>(py_args[index])) {
+      if (!py::isinstance<py::bool_>(obj) && py::isinstance<py::int_>(obj)) {
         has_scalar_int64 = true;
       }
-      auto obj = py_args[index];
       if (py::isinstance<tensor::Tensor>(obj)) {
         auto arg = py::cast<tensor::TensorPtr>(obj);
         TypeId arg_type_id = arg->data_type();
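
Review note: the rewritten loop caches py_args[index] in obj once and drops the !has_scalar_float32 / !has_scalar_int64 short-circuit guards. Dropping the guards is behavior-preserving, since re-setting an already-true flag is idempotent; the isinstance checks are simply repeated on later iterations instead of being skipped.
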
@@ -202,21 +220,7 @@ std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args,
         }
       }
     }
-    if (max_type == TypeId::kNumberTypeBool) {
-      if (has_scalar_int64) {
-        max_type = TypeId::kNumberTypeInt64;
-      }
-      if (has_scalar_float32) {
-        max_type = TypeId::kNumberTypeFloat32;
-      }
-    }
-    if (max_type != TypeId::kNumberTypeFloat16 && max_type != TypeId::kNumberTypeFloat32 &&
-        max_type != TypeId::kNumberTypeFloat64 && max_type != TypeId::kTypeUnknown && has_scalar_float32) {
-      max_type = TypeId::kNumberTypeFloat32;
-    }
-    if (max_type == TypeId::kNumberTypeUInt8 && has_tensor_int8) {
-      max_type = TypeId::kNumberTypeInt16;
-    }
+    max_type = JudgeMaxType(max_type, has_scalar_float32, has_scalar_int64, has_tensor_int8);
     (void)dst_type.emplace(std::make_pair(type, max_type));
   }
   return dst_type;
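
Review note: the promotion cascade that previously lived inline here is replaced by a single call to the new JudgeMaxType helper, so GetDstType is reduced to collecting the scalar/tensor flags per signature and delegating; any future change to the promotion rules now happens in one place.
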
@@ -274,11 +278,11 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info,
     }
   }
   // get prim and abstract info
   (void)graph_info.append(op_exec_info->prim_id + "_");
   (void)graph_info.append(op_exec_info->op_name + "_");
   // get attr info
   const auto &op_prim = op_exec_info->py_primitive;
   MS_EXCEPTION_IF_NULL(op_prim);
-  const auto &attr_map = op_prim->evaluate_added_attrs();
+  const auto &attr_map = op_prim->attrs();
   (void)std::for_each(attr_map.begin(), attr_map.end(),
                       [&](const auto &element) { (void)graph_info.append(element.second->ToString() + "_"); });
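
Review note: the attribute source for the graph_info cache key changes from evaluate_added_attrs() to attrs(). Assuming attrs() exposes the primitive's full declared attribute map rather than only attributes added during evaluation (an assumption about the Primitive API, not verified here), the keys become more discriminating: two calls differing in any declared attribute now produce different graph_info strings.
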
@@ -648,7 +652,6 @@ OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) {
   if (!prim->HasPyObj()) {
     MS_LOG(EXCEPTION) << "Pyobj is empty";
   }
   op_exec_info->prim_id = GetId(prim->GetPyObj());
   op_exec_info->py_primitive = prim;
-  op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
   op_exec_info->op_inputs = args[PY_INPUTS];
@@ -701,10 +704,10 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
       input_node = GetInput(obj, op_mask);
     }
     // update abstract
-    if (input_node != nullptr && input_node->abstract() != nullptr) {
-      abs = input_node->abstract();
-    }
+    if (input_node != nullptr) {
+      if (input_node->abstract() != nullptr) {
+        abs = input_node->abstract();
+      }
+      inputs.emplace_back(input_node);
+    }
   }
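
Review note: inputs.emplace_back(input_node) is now guarded by the input_node != nullptr check, so a null result from GetInput can no longer be appended to the CNode input list; the abstract update keeps its own inner null check on input_node->abstract().
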
@@ -2169,8 +2172,8 @@ void PynativeExecutor::UpdateCellDynamic(const std::string &cell_id) {
   }
 }
 
-void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
-                                       bool need_cloned, bool is_grad) {
+void PynativeExecutor::UpdateBpropCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
+                                            bool need_cloned, bool is_grad) {
   auto update_in_endgraph = need_cloned && !is_grad;
   if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
     // Bprop just save backward graph
@@ -2197,7 +2200,12 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt
       }
       return;
     }
   }
+}
+
+void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
+                                       bool need_cloned, bool is_grad) {
+  auto update_in_endgraph = need_cloned && !is_grad;
+  UpdateBpropCellGraph(cell, g, cell_id, need_cloned, is_grad);
   FuncGraphPtr tmp = g;
   if (!IsFirstGradStep(top_cell_id_) && CheckDynamicCell(cell_id) && !CheckRealDynamicCell(cell_id)) {
     MS_LOG(DEBUG) << "No need cloned";
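
Review note: the split leaves UpdateCellGraph as a thin wrapper that delegates the custom-bprop case to the new UpdateBpropCellGraph before running the clone logic below. Both functions now compute update_in_endgraph; if the bprop path does not actually use it, the copy in UpdateBpropCellGraph is worth double-checking for dead code.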