|
|
|
@ -110,9 +110,15 @@ py::object GetTupleObj(const py::object &obj) {
|
|
|
|
|
return obj_tuple;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *out_args) {
|
|
|
|
|
py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *out_args) {
|
|
|
|
|
auto &py_args = *out_args;
|
|
|
|
|
py::tuple input_mask(args.size());
|
|
|
|
|
for (size_t i = 0; i < args.size(); ++i) {
|
|
|
|
|
if (py::hasattr(args[i], "__parameter__")) {
|
|
|
|
|
input_mask[i] = true;
|
|
|
|
|
} else {
|
|
|
|
|
input_mask[i] = false;
|
|
|
|
|
}
|
|
|
|
|
py_args[i] = GetTupleObj(args[i]);
|
|
|
|
|
}
|
|
|
|
|
auto signature = prim->signatures();
|
|
|
|
@ -121,7 +127,7 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *
|
|
|
|
|
[](const Signature &sig) { return sig.dtype; });
|
|
|
|
|
int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
|
|
|
|
|
if (dtypes.size() == 0 || static_cast<int>(dtypes.size()) == empty_dtype_count) {
|
|
|
|
|
return;
|
|
|
|
|
return input_mask;
|
|
|
|
|
}
|
|
|
|
|
std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
|
|
|
|
|
for (size_t i = 0; i < dtypes.size(); ++i) {
|
|
|
|
@ -160,6 +166,7 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return input_mask;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info) {
|
|
|
|
@ -167,7 +174,7 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn
|
|
|
|
|
AbstractBasePtrList args_spec_list;
|
|
|
|
|
for (size_t i = 0; i < size; i++) {
|
|
|
|
|
ValuePtr input_value = PyAttrValue(py_args[i]);
|
|
|
|
|
if (input_value->isa<tensor::Tensor>()) {
|
|
|
|
|
if (!py::hasattr(prim->GetPyObj(), "const_value") && input_value->isa<tensor::Tensor>()) {
|
|
|
|
|
args_spec_list.emplace_back(abstract::FromValueInside(input_value, true));
|
|
|
|
|
} else {
|
|
|
|
|
args_spec_list.emplace_back(abstract::FromValueInside(input_value, false));
|
|
|
|
@ -179,7 +186,7 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn
|
|
|
|
|
|
|
|
|
|
OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
|
|
|
|
|
if (args.size() != PY_ARGS_NUM) {
|
|
|
|
|
MS_LOG(ERROR) << "Four args are needed by RunOp";
|
|
|
|
|
MS_LOG(ERROR) << "Three args are needed by RunOp";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto op_exec_info = std::make_shared<OpExecInfo>();
|
|
|
|
@ -195,14 +202,13 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
|
|
|
|
|
size_t input_num = a.size();
|
|
|
|
|
op_exec_info->op_inputs = py::tuple(input_num);
|
|
|
|
|
|
|
|
|
|
ConvertInputs(prim, args[PY_INPUTS], &op_exec_info->op_inputs);
|
|
|
|
|
op_exec_info->inputs_mask = ConvertInputs(prim, args[PY_INPUTS], &op_exec_info->op_inputs);
|
|
|
|
|
// use python infer method
|
|
|
|
|
if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
|
|
|
|
|
PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get());
|
|
|
|
|
}
|
|
|
|
|
op_exec_info->py_primitive = prim;
|
|
|
|
|
op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
|
|
|
|
|
op_exec_info->inputs_mask = args[PY_INPUT_MASK];
|
|
|
|
|
if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) {
|
|
|
|
|
MS_LOG(ERROR) << "Op:" << op_exec_info->op_name << " inputs size not equal op_mask";
|
|
|
|
|
return nullptr;
|
|
|
|
@ -488,14 +494,14 @@ py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecIn
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtr PynativeExecutor::MakeCNode(const py::args &args, const py::tuple &out) {
|
|
|
|
|
AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out) {
|
|
|
|
|
if (!grad_flag_ || graph_info_map_.size() == 0) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
std::vector<AnfNodePtr> inputs;
|
|
|
|
|
auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
|
|
|
|
|
auto prim = op_exec_info->py_primitive;
|
|
|
|
|
inputs.push_back(NewValueNode(prim));
|
|
|
|
|
py::tuple op_masks = args[PY_INPUT_MASK];
|
|
|
|
|
py::tuple op_masks = op_exec_info->inputs_mask;
|
|
|
|
|
py::list op_args = args[PY_INPUTS];
|
|
|
|
|
AbstractBasePtrList args_spec_list;
|
|
|
|
|
for (size_t i = 0; i < op_args.size(); i++) {
|
|
|
|
@ -584,7 +590,7 @@ py::tuple RunOp(const py::args &args) {
|
|
|
|
|
return err_ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto node = PynativeExecutor::GetInstance()->MakeCNode(args, result);
|
|
|
|
|
auto node = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, args, result);
|
|
|
|
|
if (node != nullptr) {
|
|
|
|
|
node->set_abstract(op_exec_info->abstract);
|
|
|
|
|
MS_LOG(DEBUG) << "RunOp MakeCnode,new node is: " << node->DebugString();
|
|
|
|
@ -705,7 +711,7 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c
|
|
|
|
|
}
|
|
|
|
|
cell_graph_map_[cell_id] = curr_g_;
|
|
|
|
|
auto out_id = GetId(out);
|
|
|
|
|
if (!graph_info_map_[curr_g_].obj_node_map.count(out_id)) {
|
|
|
|
|
if (!graph_info_map_[curr_g_].obj_node_map.count(out_id) && !graph_info_map_[curr_g_].param_map.count(out_id)) {
|
|
|
|
|
// cell construct return x, y
|
|
|
|
|
if (py::isinstance<py::tuple>(out)) {
|
|
|
|
|
std::vector<AnfNodePtr> args;
|
|
|
|
@ -727,12 +733,26 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto output_node = GetObjNode(out);
|
|
|
|
|
AnfNodePtr output_node;
|
|
|
|
|
if (graph_info_map_[curr_g_].param_map.count(out_id)) {
|
|
|
|
|
output_node = graph_info_map_[curr_g_].param_map[out_id];
|
|
|
|
|
} else {
|
|
|
|
|
output_node = GetObjNode(out);
|
|
|
|
|
}
|
|
|
|
|
curr_g_->set_output(output_node);
|
|
|
|
|
std::vector<AnfNodePtr> inputs;
|
|
|
|
|
inputs.push_back(NewValueNode(curr_g_));
|
|
|
|
|
MS_LOG(DEBUG) << "Current graph" << curr_g_->output()->DebugString();
|
|
|
|
|
resource_->manager()->AddFuncGraph(curr_g_);
|
|
|
|
|
// custom bprop debug
|
|
|
|
|
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
|
|
|
|
|
MS_LOG(DEBUG) << "Use cell custom bprop function.";
|
|
|
|
|
FuncGraphPtr bprop_graph = parse::ConvertToBpropCut(cell);
|
|
|
|
|
if (bprop_graph != nullptr) {
|
|
|
|
|
(void)curr_g_->transforms().insert(std::make_pair(parse::CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph)));
|
|
|
|
|
(void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(curr_g_)));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto newfg = ad::Grad(curr_g_, resource_, curr_g_ == top_g_);
|
|
|
|
|
if (curr_g_ != top_g_) {
|
|
|
|
|
Popp();
|
|
|
|
|